diff --git a/.gitattributes b/.gitattributes
index 441062e481ec29f9c71e288a6fc0a213ccbdfb21..8074aca63dfebffcbc746efe0b4b9b744a42da4e 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
imgs/vocoder/gan/MSSBCQTD.png filter=lfs diff=lfs merge=lfs -text
+models/codec/facodec/modules/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
+models/tts/maskgct/g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
+models/tts/maskgct/wav/prompt.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 3798706e36271865ba5b2e033f12182081aeedb9..05434ecf56db0608ec4c0d1009336cf540bfc056 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,169 @@
----
-title: Maskgct
-emoji: 🚀
-colorFrom: yellow
-colorTo: red
-sdk: gradio
-sdk_version: 5.1.0
-app_file: app.py
-pinned: false
-license: mit
-short_description: MaskGCT TTS Demo
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Amphion: An Open-Source Audio, Music, and Speech Generation Toolkit
+
+
+
+
+**Amphion (/æmˈfaɪən/) is a toolkit for Audio, Music, and Speech Generation.** Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development. Amphion offers a unique feature: **visualizations** of classic models or architectures. We believe that these visualizations are beneficial for junior researchers and engineers who wish to gain a better understanding of the model.
+
+**The North-Star objective of Amphion is to offer a platform for studying the conversion of any inputs into audio.** Amphion is designed to support individual generation tasks, including but not limited to,
+
+- **TTS**: Text to Speech (⛳ supported)
+- **SVS**: Singing Voice Synthesis (👨💻 developing)
+- **VC**: Voice Conversion (👨💻 developing)
+- **SVC**: Singing Voice Conversion (⛳ supported)
+- **TTA**: Text to Audio (⛳ supported)
+- **TTM**: Text to Music (👨💻 developing)
+- more…
+
+In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.
+
+## 🚀 News
+- **2024/10/19**: We release **MaskGCT**, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision. MaskGCT is trained on Emilia dataset and achieves SOTA zero-shot TTS perfermance. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/tts/maskgct/README.md)
+- **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911), [Emilia](https://arxiv.org/abs/2407.05361) and [DSFF-SVC](https://arxiv.org/abs/2310.11160) got accepted by IEEE SLT 2024! 🤗
+- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.gg/drhW7ajqAG) to stay connected and engage with our community!
+- **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉
+- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
+- **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md)
+- **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md)
+- **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md)
+- **2024/02/22**: The first Amphion visualization tool, **SingVisio**, release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/visualization/SingVisio/README.md)
+- **2023/12/18**: Amphion v0.1 release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2312.09911) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink)](https://huggingface.co/amphion) [![youtube](https://img.shields.io/badge/YouTube-Demo-red)](https://www.youtube.com/watch?v=1aw0HhcggvQ) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/39)
+- **2023/11/28**: Amphion alpha release. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/2)
+
+## ⭐ Key Features
+
+### TTS: Text to Speech
+
+- Amphion achieves state-of-the-art performance compared to existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures:
+ - [FastSpeech2](https://arxiv.org/abs/2006.04558): A non-autoregressive TTS architecture that utilizes feed-forward Transformer blocks.
+ - [VITS](https://arxiv.org/abs/2106.06103): An end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning
+ - [VALL-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes.
+ - [NaturalSpeech2](https://arxiv.org/abs/2304.09116): An architecture for TTS that utilizes a latent diffusion model to generate natural-sounding voices.
+ - [Jets](Jets): An end-to-end TTS model that jointly trains FastSpeech2 and HiFi-GAN with an alignment module.
+ - [MaskGCT](https://arxiv.org/abs/2409.00750): a fully non-autoregressive TTS architecture that eliminates the need for explicit alignment information between text and speech supervision.
+
+### SVC: Singing Voice Conversion
+
+- Ampion supports multiple content-based features from various pretrained models, including [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec). Their specific roles in SVC has been investigated in our SLT 2024 paper. [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2310.11160) [![code](https://img.shields.io/badge/README-Code-red)](egs/svc/MultipleContentsSVC)
+- Amphion implements several state-of-the-art model architectures, including diffusion-, transformer-, VAE- and flow-based models. The diffusion-based architecture uses [Bidirectional dilated CNN](https://openreview.net/pdf?id=a-xFK8Ymz5J) as a backend and supports several sampling algorithms such as [DDPM](https://arxiv.org/pdf/2006.11239.pdf), [DDIM](https://arxiv.org/pdf/2010.02502.pdf), and [PNDM](https://arxiv.org/pdf/2202.09778.pdf). Additionally, it supports single-step inference based on the [Consistency Model](https://openreview.net/pdf?id=FmqFfMTNnv).
+
+### TTA: Text to Audio
+
+- Amphion supports the TTA with a latent diffusion model. It is designed like [AudioLDM](https://arxiv.org/abs/2301.12503), [Make-an-Audio](https://arxiv.org/abs/2301.12661), and [AUDIT](https://arxiv.org/abs/2304.00830). It is also the official implementation of the text-to-audio generation part of our NeurIPS 2023 paper. [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2304.00830) [![code](https://img.shields.io/badge/README-Code-red)](egs/tta/RECIPE.md)
+
+### Vocoder
+
+- Amphion supports various widely-used neural vocoders, including:
+ - GAN-based vocoders: [MelGAN](https://arxiv.org/abs/1910.06711), [HiFi-GAN](https://arxiv.org/abs/2010.05646), [NSF-HiFiGAN](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts), [BigVGAN](https://arxiv.org/abs/2206.04658), [APNet](https://arxiv.org/abs/2305.07952).
+ - Flow-based vocoders: [WaveGlow](https://arxiv.org/abs/1811.00002).
+ - Diffusion-based vocoders: [Diffwave](https://arxiv.org/abs/2009.09761).
+ - Auto-regressive based vocoders: [WaveNet](https://arxiv.org/abs/1609.03499), [WaveRNN](https://arxiv.org/abs/1802.08435v1).
+- Amphion provides the official implementation of [Multi-Scale Constant-Q Transform Discriminator](https://arxiv.org/abs/2311.14957) (our ICASSP 2024 paper). It can be used to enhance any architecture GAN-based vocoders during training, and keep the inference stage (such as memory or speed) unchanged. [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2311.14957) [![code](https://img.shields.io/badge/README-Code-red)](egs/vocoder/gan/tfr_enhanced_hifigan)
+
+### Evaluation
+
+Amphion provides a comprehensive objective evaluation of the generated audio. The evaluation metrics contain:
+
+- **F0 Modeling**: F0 Pearson Coefficients, F0 Periodicity Root Mean Square Error, F0 Root Mean Square Error, Voiced/Unvoiced F1 Score, etc.
+- **Energy Modeling**: Energy Root Mean Square Error, Energy Pearson Coefficients, etc.
+- **Intelligibility**: Character/Word Error Rate, which can be calculated based on [Whisper](https://github.com/openai/whisper) and more.
+- **Spectrogram Distortion**: Frechet Audio Distance (FAD), Mel Cepstral Distortion (MCD), Multi-Resolution STFT Distance (MSTFT), Perceptual Evaluation of Speech Quality (PESQ), Short Time Objective Intelligibility (STOI), etc.
+- **Speaker Similarity**: Cosine similarity, which can be calculated based on [RawNet3](https://github.com/Jungjee/RawNet), [Resemblyzer](https://github.com/resemble-ai/Resemblyzer), [WeSpeaker](https://github.com/wenet-e2e/wespeaker), [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) and more.
+
+### Datasets
+
+- Amphion unifies the data preprocess of the open-source datasets including [AudioCaps](https://audiocaps.github.io/), [LibriTTS](https://www.openslr.org/60/), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [M4Singer](https://github.com/M4Singer/M4Singer), [Opencpop](https://wenet.org.cn/opencpop/), [OpenSinger](https://github.com/Multi-Singer/Multi-Singer.github.io), [SVCC](http://vc-challenge.org/), [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), and more. The supported dataset list can be seen [here](egs/datasets/README.md) (updating).
+- Amphion (exclusively) supports the [**Emilia**](preprocessors/Emilia/README.md) dataset and its preprocessing pipeline **Emilia-Pipe** for in-the-wild speech data!
+
+### Visualization
+
+Amphion provides visualization tools to interactively illustrate the internal processing mechanism of classic models. This provides an invaluable resource for educational purposes and for facilitating understandable research.
+
+Currently, Amphion supports [SingVisio](egs/visualization/SingVisio/README.md), a visualization tool of the diffusion model for singing voice conversion. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view)
+
+
+## 📀 Installation
+
+Amphion can be installed through either Setup Installer or Docker Image.
+
+### Setup Installer
+
+```bash
+git clone https://github.com/open-mmlab/Amphion.git
+cd Amphion
+
+# Install Python Environment
+conda create --name amphion python=3.9.15
+conda activate amphion
+
+# Install Python Packages Dependencies
+sh env.sh
+```
+
+### Docker Image
+
+1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).
+
+2. Run the following commands:
+```bash
+git clone https://github.com/open-mmlab/Amphion.git
+cd Amphion
+
+docker pull realamphion/amphion
+docker run --runtime=nvidia --gpus all -it -v .:/app realamphion/amphion
+```
+Mount dataset by argument `-v` is necessary when using Docker. Please refer to [Mount dataset in Docker container](egs/datasets/docker.md) and [Docker Docs](https://docs.docker.com/engine/reference/commandline/container_run/#volume) for more details.
+
+
+## 🐍 Usage in Python
+
+We detail the instructions of different tasks in the following recipes:
+
+- [Text to Speech (TTS)](egs/tts/README.md)
+- [Singing Voice Conversion (SVC)](egs/svc/README.md)
+- [Text to Audio (TTA)](egs/tta/README.md)
+- [Vocoder](egs/vocoder/README.md)
+- [Evaluation](egs/metrics/README.md)
+- [Visualization](egs/visualization/README.md)
+
+## 👨💻 Contributing
+We appreciate all contributions to improve Amphion. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
+
+## 🙏 Acknowledgement
+
+
+- [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) and [jaywalnut310's VITS](https://github.com/jaywalnut310/vits) for model architecture code.
+- [lifeiteng's VALL-E](https://github.com/lifeiteng/vall-e) for training pipeline and model architecture design.
+- [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) for semantic-distilled tokenizer design.
+- [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), [ContentVec](https://github.com/auspicious3000/contentvec), and [RawNet3](https://github.com/Jungjee/RawNet) for pretrained models and inference code.
+- [HiFi-GAN](https://github.com/jik876/hifi-gan) for GAN-based Vocoder's architecture design and training strategy.
+- [Encodec](https://github.com/facebookresearch/encodec) for well-organized GAN Discriminator's architecture and basic blocks.
+- [Latent Diffusion](https://github.com/CompVis/latent-diffusion) for model architecture design.
+- [TensorFlowTTS](https://github.com/TensorSpeech/TensorFlowTTS) for preparing the MFA tools.
+
+
+## ©️ License
+
+Amphion is under the [MIT License](LICENSE). It is free for both research and commercial use cases.
+
+## 📚 Citations
+
+```bibtex
+@inproceedings{amphion,
+ author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and Li, Jiaqi and He, Haorui and Wang, Chaoren and Song, Ting and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
+ title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
+ booktitle={{IEEE} Spoken Language Technology Workshop, {SLT} 2024},
+ year={2024}
+}
+```
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/base/__init__.py b/models/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0221047a62e0b9b3ddd112c79a700c48834fd1
--- /dev/null
+++ b/models/base/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .new_trainer import BaseTrainer
+from .new_inference import BaseInference
diff --git a/models/base/base_dataset.py b/models/base/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c1216a2395e2039bee6b79bf1b438eb5b967774
--- /dev/null
+++ b/models/base/base_dataset.py
@@ -0,0 +1,464 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+import torch.utils.data
+from torch.nn.utils.rnn import pad_sequence
+import librosa
+
+from utils.data_utils import *
+from processors.acoustic_extractor import cal_normalized_mel
+from text import text_to_sequence
+from text.text_token_collation import phoneIDCollation
+
+
+class BaseOfflineDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+
+ assert isinstance(dataset, str)
+
+ # self.data_root = processed_data_dir
+ self.cfg = cfg
+
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
+ self.metadata = self.get_metadata()
+
+ """
+ load spk2id and utt2spk from json file
+ spk2id: {spk1: 0, spk2: 1, ...}
+ utt2spk: {dataset_uid: spk1, ...}
+ """
+ if cfg.preprocess.use_spkid:
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
+ with open(spk2id_path, "r") as f:
+ self.spk2id = json.load(f)
+
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
+ self.utt2spk = dict()
+ with open(utt2spk_path, "r") as f:
+ for line in f.readlines():
+ utt, spk = line.strip().split("\t")
+ self.utt2spk[utt] = spk
+
+ if cfg.preprocess.use_uv:
+ self.utt2uv_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2uv_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.uv_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_frame_pitch:
+ self.utt2frame_pitch_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2frame_pitch_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.pitch_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_frame_energy:
+ self.utt2frame_energy_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2frame_energy_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.energy_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_mel:
+ self.utt2mel_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2mel_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.mel_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_linear:
+ self.utt2linear_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2linear_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.linear_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_audio:
+ self.utt2audio_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2audio_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.audio_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_label:
+ self.utt2label_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2label_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.label_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_one_hot:
+ self.utt2one_hot_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2one_hot_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.one_hot_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
+ self.utt2seq = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if cfg.preprocess.use_text:
+ text = utt_info["Text"]
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
+ elif cfg.preprocess.use_phone:
+ # load phoneme squence from phone file
+ phone_path = os.path.join(
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
+ )
+ with open(phone_path, "r") as fin:
+ phones = fin.readlines()
+ assert len(phones) == 1
+ phones = phones[0].strip()
+ phones_seq = phones.split(" ")
+
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
+
+ self.utt2seq[utt] = sequence
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ return metadata
+
+ def get_dataset_name(self):
+ return self.metadata[0]["Dataset"]
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_spkid:
+ single_feature["spk_id"] = np.array(
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
+ )
+
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
+ if self.cfg.preprocess.use_min_max_norm_mel:
+ # do mel norm
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+ single_feature["mel"] = mel.T # [T, n_mels]
+
+ if self.cfg.preprocess.use_linear:
+ linear = np.load(self.utt2linear_path[utt])
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = linear.shape[1]
+ single_feature["linear"] = linear.T # [T, n_linear]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
+ frame_pitch = np.load(frame_pitch_path)
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_uv:
+ frame_uv_path = self.utt2uv_path[utt]
+ frame_uv = np.load(frame_uv_path)
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
+ aligned_frame_uv = [
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
+ ]
+ aligned_frame_uv = np.array(aligned_frame_uv)
+ single_feature["frame_uv"] = aligned_frame_uv
+
+ if self.cfg.preprocess.use_frame_energy:
+ frame_energy_path = self.utt2frame_energy_path[utt]
+ frame_energy = np.load(frame_energy_path)
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_energy)
+ aligned_frame_energy = align_length(
+ frame_energy, single_feature["target_len"]
+ )
+ single_feature["frame_energy"] = aligned_frame_energy
+
+ if self.cfg.preprocess.use_audio:
+ audio = np.load(self.utt2audio_path[utt])
+ single_feature["audio"] = audio
+ single_feature["audio_len"] = audio.shape[0]
+
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
+ single_feature["phone_len"] = len(self.utt2seq[utt])
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class BaseOfflineCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [b]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "phone_len":
+ packed_batch_features["phone_len"] = torch.LongTensor(
+ [b["phone_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["phn_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "audio_len":
+ packed_batch_features["audio_len"] = torch.LongTensor(
+ [b["audio_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
+ ]
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ return packed_batch_features
+
+
+class BaseOnlineDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+ assert isinstance(dataset, str)
+
+ self.cfg = cfg
+ self.sample_rate = cfg.preprocess.sample_rate
+ self.hop_size = self.cfg.preprocess.hop_size
+
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
+ self.metadata = self.get_metadata()
+
+ """
+ load spk2id and utt2spk from json file
+ spk2id: {spk1: 0, spk2: 1, ...}
+ utt2spk: {dataset_uid: spk1, ...}
+ """
+ if cfg.preprocess.use_spkid:
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
+ with open(spk2id_path, "r") as f:
+ self.spk2id = json.load(f)
+
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
+ self.utt2spk = dict()
+ with open(utt2spk_path, "r") as f:
+ for line in f.readlines():
+ utt, spk = line.strip().split("\t")
+ self.utt2spk[utt] = spk
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ return metadata
+
+ def get_dataset_name(self):
+ return self.metadata[0]["Dataset"]
+
+ def __getitem__(self, index):
+ """
+ single_feature:
+ wav: (T)
+ wav_len: int
+ target_len: int
+ mask: (n_frames, 1)
+ spk_id: (1)
+ """
+ utt_item = self.metadata[index]
+
+ wav_path = utt_item["Path"]
+ wav, _ = librosa.load(wav_path, sr=self.sample_rate)
+ # wav: (T)
+ wav = torch.as_tensor(wav, dtype=torch.float32)
+ wav_len = len(wav)
+ # mask: (n_frames, 1)
+ frame_len = wav_len // self.hop_size
+ mask = torch.ones(frame_len, 1, dtype=torch.long)
+
+ single_feature = {
+ "wav": wav,
+ "wav_len": wav_len,
+ "target_len": frame_len,
+ "mask": mask,
+ }
+
+ if self.cfg.preprocess.use_spkid:
+ utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
+ single_feature["spk_id"] = torch.tensor(
+ [self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
+ )
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class BaseOnlineCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step (For on-the-fly features extraction, whose iterative item contains only wavs)"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ """
+ BaseOnlineDataset.__getitem__:
+ wav: (T,)
+ wav_len: int
+ target_len: int
+ mask: (n_frames, 1)
+ spk_id: (1)
+
+ Returns:
+ wav: (B, T), torch.float32
+ wav_len: (B), torch.long
+ target_len: (B), torch.long
+ mask: (B, n_frames, 1), torch.long
+ spk_id: (B, 1), torch.int32
+ """
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if key in ["wav_len", "target_len"]:
+ packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
+ else:
+ packed_batch_features[key] = pad_sequence(
+ [b[key] for b in batch], batch_first=True, padding_value=0
+ )
+ return packed_batch_features
+
+
+class BaseTestDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, args):
+ raise NotImplementedError
+
+ def get_metadata(self):
+ raise NotImplementedError
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class BaseTestCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ raise NotImplementedError
+
+ def __call__(self, batch):
+ raise NotImplementedError
diff --git a/models/base/base_inference.py b/models/base/base_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2713f19a0d61f06bca1f01de5ccd8a3b4d2cc02f
--- /dev/null
+++ b/models/base/base_inference.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import re
+import time
+from pathlib import Path
+
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models.vocoders.vocoder_inference import synthesis
+from torch.utils.data import DataLoader
+from utils.util import set_all_random_seed
+from utils.util import load_config
+
+
+def parse_vocoder(vocoder_dir):
+ r"""Parse vocoder config"""
+ vocoder_dir = os.path.abspath(vocoder_dir)
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
+ ckpt_path = str(ckpt_list[0])
+ vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
+ vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
+ return vocoder_cfg, ckpt_path
+
+
+class BaseInference(object):
+ def __init__(self, cfg, args):
+ self.cfg = cfg
+ self.args = args
+ self.model_type = cfg.model_type
+ self.avg_rtf = list()
+ set_all_random_seed(10086)
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ else:
+ self.device = torch.device("cpu")
+ torch.set_num_threads(10) # inference on 1 core cpu.
+
+ # Load acoustic model
+ self.model = self.create_model().to(self.device)
+ state_dict = self.load_state_dict()
+ self.load_model(state_dict)
+ self.model.eval()
+
+ # Load vocoder model if necessary
+ if self.args.checkpoint_dir_vocoder is not None:
+ self.get_vocoder_info()
+
+ def create_model(self):
+ raise NotImplementedError
+
+ def load_state_dict(self):
+ self.checkpoint_file = self.args.checkpoint_file
+ if self.checkpoint_file is None:
+ assert self.args.checkpoint_dir is not None
+ checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
+ self.checkpoint_file = os.path.join(
+ self.args.checkpoint_dir, checkpoint_filename
+ )
+
+ self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
+
+ print("Restore acoustic model from {}".format(self.checkpoint_file))
+ raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
+ self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
+
+ return raw_state_dict
+
+ def load_model(self, model):
+ raise NotImplementedError
+
+ def get_vocoder_info(self):
+ self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
+ self.vocoder_cfg = os.path.join(
+ os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
+ )
+ self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
+ self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
+ self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
+
+ def build_test_utt_data(self):
+ raise NotImplementedError
+
+ def build_testdata_loader(self, args, target_speaker=None):
+ datasets, collate = self.build_test_dataset()
+ self.test_dataset = datasets(self.cfg, args, target_speaker)
+ self.test_collate = collate(self.cfg)
+ self.test_batch_size = min(
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
+ )
+ test_loader = DataLoader(
+ self.test_dataset,
+ collate_fn=self.test_collate,
+ num_workers=self.args.num_workers,
+ batch_size=self.test_batch_size,
+ shuffle=False,
+ )
+ return test_loader
+
+ def inference_each_batch(self, batch_data):
+ raise NotImplementedError
+
+ def inference_for_batches(self, args, target_speaker=None):
+ ###### Construct test_batch ######
+ loader = self.build_testdata_loader(args, target_speaker)
+
+ n_batch = len(loader)
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
+ print(
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
+ now, self.test_batch_size, n_batch
+ )
+ )
+ self.model.eval()
+
+ ###### Inference for each batch ######
+ pred_res = []
+ with torch.no_grad():
+ for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
+ # Put the data to device
+ for k, v in batch_data.items():
+ batch_data[k] = batch_data[k].to(self.device)
+
+ y_pred, stats = self.inference_each_batch(batch_data)
+
+ pred_res += y_pred
+
+ return pred_res
+
+ def inference(self, feature):
+ raise NotImplementedError
+
+ def synthesis_by_vocoder(self, pred):
+ audios_pred = synthesis(
+ self.vocoder_cfg,
+ self.checkpoint_dir_vocoder,
+ len(pred),
+ pred,
+ )
+ return audios_pred
+
+ def __call__(self, utt):
+ feature = self.build_test_utt_data(utt)
+ start_time = time.time()
+ with torch.no_grad():
+ outputs = self.inference(feature)[0]
+ time_used = time.time() - start_time
+ rtf = time_used / (
+ outputs.shape[1]
+ * self.cfg.preprocess.hop_size
+ / self.cfg.preprocess.sample_rate
+ )
+ print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
+ self.avg_rtf.append(rtf)
+ audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
+ return audios
+
+
+def base_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config", default="config.json", help="json files for configurations."
+ )
+ parser.add_argument("--use_ddp_inference", default=False)
+ parser.add_argument("--n_workers", default=1, type=int)
+ parser.add_argument("--local_rank", default=-1, type=int)
+ parser.add_argument(
+ "--batch_size", default=1, type=int, help="Batch size for inference"
+ )
+ parser.add_argument(
+ "--num_workers",
+ default=1,
+ type=int,
+ help="Worker number for inference dataloader",
+ )
+ parser.add_argument(
+ "--checkpoint_dir",
+ type=str,
+ default=None,
+ help="Checkpoint dir including model file and configuration",
+ )
+ parser.add_argument(
+ "--checkpoint_file", help="checkpoint file", type=str, default=None
+ )
+ parser.add_argument(
+ "--test_list", help="test utterance list for testing", type=str, default=None
+ )
+ parser.add_argument(
+ "--checkpoint_dir_vocoder",
+ help="Vocoder's checkpoint dir including model file and configuration",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default=None,
+ help="Output dir for saving generated results",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ parser = base_parser()
+ args = parser.parse_args()
+ cfg = load_config(args.config)
+
+ # Build inference
+ inference = BaseInference(cfg, args)
+ inference()
diff --git a/models/base/base_sampler.py b/models/base/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e882ac209bc2928d7945c3b2d6cb98a3a553fe
--- /dev/null
+++ b/models/base/base_sampler.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+ BatchSampler,
+ RandomSampler,
+ Sampler,
+ SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+ """A sampler that samples data from a given concat-dataset.
+
+ Args:
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+ batch_size (int): batch size
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
+ logger (logging.Logger): logger to print warning message
+
+ Usage:
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
+ >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
+ """
+
+ def __init__(
+ self,
+ concat_dataset,
+ batch_size,
+ holistic_shuffle,
+ logger=None,
+ loader_type="train",
+ ):
+ if not isinstance(concat_dataset, ConcatDataset):
+ raise ValueError(
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+ type(concat_dataset)
+ )
+ )
+ if not isinstance(batch_size, int):
+ raise ValueError(
+ "batch_size must be an integer, but got {}".format(type(batch_size))
+ )
+ if not isinstance(holistic_shuffle, bool):
+ raise ValueError(
+ "holistic_shuffle must be a boolean, but got {}".format(
+ type(holistic_shuffle)
+ )
+ )
+
+ self.concat_dataset = concat_dataset
+ self.batch_size = batch_size
+ self.holistic_shuffle = holistic_shuffle
+
+ affected_dataset_name = []
+ affected_dataset_len = []
+ for dataset in concat_dataset.datasets:
+ dataset_len = len(dataset)
+ dataset_name = dataset.get_dataset_name()
+ if dataset_len < batch_size:
+ affected_dataset_name.append(dataset_name)
+ affected_dataset_len.append(dataset_len)
+
+ self.type = loader_type
+ for dataset_name, dataset_len in zip(
+ affected_dataset_name, affected_dataset_len
+ ):
+ if not loader_type == "valid":
+ logger.warning(
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
+ loader_type, dataset_name, dataset_len, batch_size
+ )
+ )
+
+ def __len__(self):
+ # the number of batches with drop last
+ num_of_batches = sum(
+ [
+ math.floor(len(dataset) / self.batch_size)
+ for dataset in self.concat_dataset.datasets
+ ]
+ )
+ # if samples are not enough for one batch, we don't drop last
+ if self.type == "valid" and num_of_batches < 1:
+ return len(self.concat_dataset)
+ return num_of_batches * self.batch_size
+
+ def __iter__(self):
+ iters = []
+ for dataset in self.concat_dataset.datasets:
+ iters.append(
+ SequentialSampler(dataset).__iter__()
+ if not self.holistic_shuffle
+ else RandomSampler(dataset).__iter__()
+ )
+ # e.g. [0, 200, 400]
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
+ output_batches = []
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
+ cur_batch = []
+ for idx in iters[dataset_idx]:
+ cur_batch.append(idx + init_indices[dataset_idx])
+ if len(cur_batch) == self.batch_size:
+ output_batches.append(cur_batch)
+ cur_batch = []
+ # if loader_type is valid, we don't need to drop last
+ if self.type == "valid" and len(cur_batch) > 0:
+ output_batches.append(cur_batch)
+
+ # force drop last in training
+ random.shuffle(output_batches)
+ output_indices = [item for sublist in output_batches for item in sublist]
+ return iter(output_indices)
+
+
+def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
+ sampler = ScheduledSampler(
+ concat_dataset,
+ cfg.train.batch_size,
+ cfg.train.sampler.holistic_shuffle,
+ logger,
+ loader_type,
+ )
+ batch_sampler = BatchSampler(
+ sampler,
+ cfg.train.batch_size,
+ cfg.train.sampler.drop_last if not loader_type == "valid" else False,
+ )
+ return sampler, batch_sampler
+
+
+class VariableSampler(BatchSampler):
+ def __init__(self, sampler, drop_last: bool, use_random_sampler=False):
+ self.data_list = sampler
+ if use_random_sampler:
+ self.sampler = RandomSampler(sampler)
+ else:
+ self.sampler = SequentialSampler(sampler)
+
+ super().__init__(self.sampler, 1, drop_last)
+
+ def __iter__(self):
+ for batch_ids in self.data_list:
+ yield batch_ids
+
+ def __len__(self):
+ if self.drop_last:
+ return len(self.sampler) // self.batch_size
+ else:
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
diff --git a/models/base/base_trainer.py b/models/base/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8782216dc13ce5d9de05ae790faeb82cf7cfd501
--- /dev/null
+++ b/models/base/base_trainer.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+import json
+import os
+import sys
+import time
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from models.base.base_sampler import BatchSampler
+from utils.util import (
+ Logger,
+ remove_older_ckpt,
+ save_config,
+ set_all_random_seed,
+ ValueWindow,
+)
+
+
+class BaseTrainer(object):
+ def __init__(self, args, cfg):
+ self.args = args
+ self.log_dir = args.log_dir
+ self.cfg = cfg
+
+ self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ if not cfg.train.ddp or args.local_rank == 0:
+ self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
+ self.logger = self.build_logger()
+ self.time_window = ValueWindow(50)
+
+ self.step = 0
+ self.epoch = -1
+ self.max_epochs = self.cfg.train.epochs
+ self.max_steps = self.cfg.train.max_steps
+
+ # set random seed & init distributed training
+ set_all_random_seed(self.cfg.train.random_seed)
+ if cfg.train.ddp:
+ dist.init_process_group(backend="nccl")
+
+ if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
+ self.singers = self.build_singers_lut()
+
+ # setup data_loader
+ self.data_loader = self.build_data_loader()
+
+ # setup model & enable distributed training
+ self.model = self.build_model()
+ print(self.model)
+
+ if isinstance(self.model, dict):
+ for key, value in self.model.items():
+ value.cuda(self.args.local_rank)
+ if key == "PQMF":
+ continue
+ if cfg.train.ddp:
+ self.model[key] = DistributedDataParallel(
+ value, device_ids=[self.args.local_rank]
+ )
+ else:
+ self.model.cuda(self.args.local_rank)
+ if cfg.train.ddp:
+ self.model = DistributedDataParallel(
+ self.model, device_ids=[self.args.local_rank]
+ )
+
+ # create criterion
+ self.criterion = self.build_criterion()
+ if isinstance(self.criterion, dict):
+ for key, value in self.criterion.items():
+ self.criterion[key].cuda(args.local_rank)
+ else:
+ self.criterion.cuda(self.args.local_rank)
+
+ # optimizer
+ self.optimizer = self.build_optimizer()
+ self.scheduler = self.build_scheduler()
+
+ # save config file
+ self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
+
+ def build_logger(self):
+ log_file = os.path.join(self.checkpoint_dir, "train.log")
+ logger = Logger(log_file, level=self.args.log_level).logger
+
+ return logger
+
+ def build_dataset(self):
+ raise NotImplementedError
+
+ def build_data_loader(self):
+ Dataset, Collator = self.build_dataset()
+ # build dataset instance for each dataset and combine them by ConcatDataset
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = ConcatDataset(datasets_list)
+
+ train_collate = Collator(self.cfg)
+ # TODO: multi-GPU training
+ if self.cfg.train.ddp:
+ raise NotImplementedError("DDP is not supported yet.")
+
+ # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
+ batch_sampler = BatchSampler(
+ cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
+ )
+
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ num_workers=self.args.num_workers,
+ batch_sampler=batch_sampler,
+ pin_memory=False,
+ )
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = ConcatDataset(datasets_list)
+ valid_collate = Collator(self.cfg)
+ batch_sampler = BatchSampler(
+ cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
+ )
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ num_workers=1,
+ batch_sampler=batch_sampler,
+ )
+ else:
+ raise NotImplementedError("DDP is not supported yet.")
+ # valid_loader = None
+ data_loader = {"train": train_loader, "valid": valid_loader}
+ return data_loader
+
+ def build_singers_lut(self):
+ # combine singers
+ if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
+ singers = collections.OrderedDict()
+ else:
+ with open(
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
+ ) as singer_file:
+ singers = json.load(singer_file)
+ singer_count = len(singers)
+ for dataset in self.cfg.dataset:
+ singer_lut_path = os.path.join(
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
+ )
+ with open(singer_lut_path, "r") as singer_lut_path:
+ singer_lut = json.load(singer_lut_path)
+ for singer in singer_lut.keys():
+ if singer not in singers:
+ singers[singer] = singer_count
+ singer_count += 1
+ with open(
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
+ ) as singer_file:
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
+ print(
+ "singers have been dumped to {}".format(
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
+ )
+ )
+ return singers
+
+ def build_model(self):
+ raise NotImplementedError()
+
+ def build_optimizer(self):
+ raise NotImplementedError
+
+ def build_scheduler(self):
+ raise NotImplementedError()
+
+ def build_criterion(self):
+ raise NotImplementedError
+
+ def get_state_dict(self):
+ raise NotImplementedError
+
+ def save_config_file(self):
+ save_config(self.config_save_path, self.cfg)
+
+ # TODO, save without module.
+ def save_checkpoint(self, state_dict, saved_model_path):
+ torch.save(state_dict, saved_model_path)
+
+ def load_checkpoint(self):
+ checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
+ assert os.path.exists(checkpoint_path)
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
+ model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+ assert os.path.exists(model_path)
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
+ self.logger.info(f"Re(store) from {model_path}")
+ checkpoint = torch.load(model_path, map_location="cpu")
+ return checkpoint
+
+ def load_model(self, checkpoint):
+ raise NotImplementedError
+
+ def restore(self):
+ checkpoint = self.load_checkpoint()
+ self.load_model(checkpoint)
+
+ def train_step(self, data):
+ raise NotImplementedError(
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
+ f"your sub-class of {self.__class__.__name__}. "
+ )
+
+ @torch.no_grad()
+ def eval_step(self):
+ raise NotImplementedError(
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
+ f"your sub-class of {self.__class__.__name__}. "
+ )
+
+ def write_summary(self, losses, stats):
+ raise NotImplementedError(
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
+ f"your sub-class of {self.__class__.__name__}. "
+ )
+
+ def write_valid_summary(self, losses, stats):
+ raise NotImplementedError(
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
+ f"your sub-class of {self.__class__.__name__}. "
+ )
+
+ def echo_log(self, losses, mode="Training"):
+ message = [
+ "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
+ mode, self.epoch + 1, self.step, self.time_window.average
+ )
+ ]
+
+ for key in sorted(losses.keys()):
+ if isinstance(losses[key], dict):
+ for k, v in losses[key].items():
+ message.append(
+ str(k).split("/")[-1] + "=" + str(round(float(v), 5))
+ )
+ else:
+ message.append(
+ str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
+ )
+ self.logger.info(", ".join(message))
+
+ def eval_epoch(self):
+ self.logger.info("Validation...")
+ valid_losses = {}
+ for i, batch_data in enumerate(self.data_loader["valid"]):
+ for k, v in batch_data.items():
+ if isinstance(v, torch.Tensor):
+ batch_data[k] = v.cuda()
+ valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
+ for key in valid_loss:
+ if key not in valid_losses:
+ valid_losses[key] = 0
+ valid_losses[key] += valid_loss[key]
+
+ # Add mel and audio to the Tensorboard
+ # Average loss
+ for key in valid_losses:
+ valid_losses[key] /= i + 1
+ self.echo_log(valid_losses, "Valid")
+ return valid_losses, valid_stats
+
+ def train_epoch(self):
+ for i, batch_data in enumerate(self.data_loader["train"]):
+ start_time = time.time()
+ # Put the data to cuda device
+ for k, v in batch_data.items():
+ if isinstance(v, torch.Tensor):
+ batch_data[k] = v.cuda(self.args.local_rank)
+
+ # Training step
+ train_losses, train_stats, total_loss = self.train_step(batch_data)
+ self.time_window.append(time.time() - start_time)
+
+ if self.args.local_rank == 0 or not self.cfg.train.ddp:
+ if self.step % self.args.stdout_interval == 0:
+ self.echo_log(train_losses, "Training")
+
+ if self.step % self.cfg.train.save_summary_steps == 0:
+ self.logger.info(f"Save summary as step {self.step}")
+ self.write_summary(train_losses, train_stats)
+
+ if (
+ self.step % self.cfg.train.save_checkpoints_steps == 0
+ and self.step != 0
+ ):
+ saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
+ self.step, total_loss
+ )
+ saved_model_path = os.path.join(
+ self.checkpoint_dir, saved_model_name
+ )
+ saved_state_dict = self.get_state_dict()
+ self.save_checkpoint(saved_state_dict, saved_model_path)
+ self.save_config_file()
+ # keep max n models
+ remove_older_ckpt(
+ saved_model_name,
+ self.checkpoint_dir,
+ max_to_keep=self.cfg.train.keep_checkpoint_max,
+ )
+
+ if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].eval()
+ else:
+ self.model.eval()
+ # Evaluate one epoch and get average loss
+ valid_losses, valid_stats = self.eval_epoch()
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].train()
+ else:
+ self.model.train()
+ # Write validation losses to summary.
+ self.write_valid_summary(valid_losses, valid_stats)
+ self.step += 1
+
+ def train(self):
+ for epoch in range(max(0, self.epoch), self.max_epochs):
+ self.train_epoch()
+ self.epoch += 1
+ if self.step > self.max_steps:
+ self.logger.info("Training finished!")
+ break
diff --git a/models/base/new_dataset.py b/models/base/new_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2201bb4132ab86d1110092d7ab9e509296367a22
--- /dev/null
+++ b/models/base/new_dataset.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from abc import abstractmethod
+from pathlib import Path
+
+import json5
+import torch
+import yaml
+
+
+# TODO: for training and validating
+class BaseDataset(torch.utils.data.Dataset):
+ r"""Base dataset for training and validating."""
+
+ def __init__(self, args, cfg, is_valid=False):
+ pass
+
+
+class BaseTestDataset(torch.utils.data.Dataset):
+ r"""Test dataset for inference."""
+
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ assert infer_type in ["from_dataset", "from_file"]
+
+ self.args = args
+ self.cfg = cfg
+ self.infer_type = infer_type
+
+ @abstractmethod
+ def __getitem__(self, index):
+ pass
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def get_metadata(self):
+ path = Path(self.args.source)
+ if path.suffix == ".json" or path.suffix == ".jsonc":
+ metadata = json5.load(open(self.args.source, "r"))
+ elif path.suffix == ".yaml" or path.suffix == ".yml":
+ metadata = yaml.full_load(open(self.args.source, "r"))
+ else:
+ raise ValueError(f"Unsupported file type: {path.suffix}")
+
+ return metadata
diff --git a/models/base/new_inference.py b/models/base/new_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..01dce86d35bd0a04349bdd091537e6a6b0340ac8
--- /dev/null
+++ b/models/base/new_inference.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import random
+import re
+import time
+from abc import abstractmethod
+from pathlib import Path
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.logging import get_logger
+from torch.utils.data import DataLoader
+
+from models.vocoders.vocoder_inference import synthesis
+from utils.io import save_audio
+from utils.util import load_config
+from utils.audio_slicer import is_silence
+
+EPS = 1.0e-12
+
+
+class BaseInference(object):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ super().__init__()
+
+ start = time.monotonic_ns()
+ self.args = args
+ self.cfg = cfg
+
+ assert infer_type in ["from_dataset", "from_file"]
+ self.infer_type = infer_type
+
+ # init with accelerate
+ self.accelerator = accelerate.Accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Use accelerate logger for distributed inference
+ with self.accelerator.main_process_first():
+ self.logger = get_logger("inference", log_level=args.log_level)
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+
+ self.acoustics_dir = args.acoustics_dir
+ self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
+ self.vocoder_dir = args.vocoder_dir
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
+ # should be in svc inferencer
+ # self.target_singer = args.target_singer
+ # self.logger.info(f"Target singers: {args.target_singer}")
+ # self.trans_key = args.trans_key
+ # self.logger.info(f"Trans key: {args.trans_key}")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # setup data_loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.test_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # setup model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ # self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
+
+ # init with accelerate
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self.accelerator = accelerate.Accelerator()
+ self.model = self.accelerator.prepare(self.model)
+ end = time.monotonic_ns()
+ self.accelerator.wait_for_everyone()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
+
+ with self.accelerator.main_process_first():
+ self.logger.info("Loading checkpoint...")
+ start = time.monotonic_ns()
+ # TODO: Also, suppose only use latest one yet
+ self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
+ end = time.monotonic_ns()
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
+
+ self.model.eval()
+ self.accelerator.wait_for_everyone()
+
+ ### Abstract methods ###
+ @abstractmethod
+ def _build_test_dataset(self):
+ pass
+
+ @abstractmethod
+ def _build_model(self):
+ pass
+
+ @abstractmethod
+ @torch.inference_mode()
+ def _inference_each_batch(self, batch_data):
+ pass
+
+ ### Abstract methods end ###
+
+ @torch.inference_mode()
+ def inference(self):
+ for i, batch in enumerate(self.test_dataloader):
+ y_pred = self._inference_each_batch(batch).cpu()
+
+ # Judge whether the min-max normliazation is used
+ if self.cfg.preprocess.use_min_max_norm_mel:
+ mel_min, mel_max = self.test_dataset.target_mel_extrema
+ y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
+
+ y_ls = y_pred.chunk(self.test_batch_size)
+ tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
+ j = 0
+ for it, l in zip(y_ls, tgt_ls):
+ l = l.item()
+ it = it.squeeze(0)[:l]
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
+ j += 1
+
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+
+ res = synthesis(
+ cfg=vocoder_cfg,
+ vocoder_weight_file=vocoder_ckpt,
+ n_samples=None,
+ pred=[
+ torch.load(
+ os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
+ ).numpy(force=True)
+ for i in self.test_dataset.metadata
+ ],
+ )
+
+ output_audio_files = []
+ for it, wav in zip(self.test_dataset.metadata, res):
+ uid = it["Uid"]
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
+ output_audio_files.append(file)
+
+ wav = wav.numpy(force=True)
+ save_audio(
+ file,
+ wav,
+ self.cfg.preprocess.sample_rate,
+ add_silence=False,
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
+ )
+ os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
+
+ return sorted(output_audio_files)
+
+ # TODO: LEGACY CODE
+ def _build_dataloader(self):
+ datasets, collate = self._build_test_dataset()
+ self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
+ self.test_collate = collate(self.cfg)
+ self.test_batch_size = min(
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
+ )
+ test_dataloader = DataLoader(
+ self.test_dataset,
+ collate_fn=self.test_collate,
+ num_workers=1,
+ batch_size=self.test_batch_size,
+ shuffle=False,
+ )
+ return test_dataloader
+
+ def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = []
+ for i in Path(checkpoint_dir).iterdir():
+ if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
+ ls.append(i)
+ ls.sort(
+ key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
+ )
+ checkpoint_path = ls[0]
+ else:
+ checkpoint_path = Path(checkpoint_path)
+ self.accelerator.load_state(str(checkpoint_path))
+ # set epoch and step
+ self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
+ self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
+ return str(checkpoint_path)
+
+ @staticmethod
+ def _set_random_seed(seed):
+ r"""Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ @staticmethod
+ def _parse_vocoder(vocoder_dir):
+ r"""Parse vocoder config"""
+ vocoder_dir = os.path.abspath(vocoder_dir)
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
+ ckpt_path = str(ckpt_list[0])
+ vocoder_cfg = load_config(
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
+ )
+ return vocoder_cfg, ckpt_path
+
+ @staticmethod
+ def __count_parameters(model):
+ return sum(p.numel() for p in model.parameters())
+
+ def __dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
diff --git a/models/base/new_trainer.py b/models/base/new_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bec327dd31daf629ecad39022a38e1017f38980
--- /dev/null
+++ b/models/base/new_trainer.py
@@ -0,0 +1,727 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import random
+import shutil
+import time
+from abc import abstractmethod
+from pathlib import Path
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import ConcatDataset, DataLoader
+from tqdm import tqdm
+
+from models.base.base_sampler import build_samplers
+from optimizer.optimizers import NoamLR
+
+
+class BaseTrainer(object):
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
+
+ def __init__(self, args=None, cfg=None):
+ super().__init__()
+
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # init with accelerate
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Use accelerate logger for distributed training
+ with self.accelerator.main_process_first():
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # init counts
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check values
+ if self.accelerator.is_main_process:
+ self.__check_basic_configs()
+ # Set runtime configs
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.keep_last = [
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # setup data_loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # setup model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
+ )
+ # optimizer & scheduler
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # accelerate prepare
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self._accelerator_prepare()
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # create criterion
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterion = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume or Finetune
+ with self.accelerator.main_process_first():
+ if args.resume:
+ if args.resume_from_ckpt_path == "":
+ ## Automatically resume according to the current exprimental name
+ self.logger.info(
+ "Automatically resuming from latest checkpoint in {}...".format(
+ self.checkpoint_dir
+ )
+ )
+ start = time.monotonic_ns()
+ ckpt_path = self._load_model(
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
+ )
+ else:
+ ## Resume from the given checkpoint path
+ if not os.path.exists(args.resume_from_ckpt_path):
+ raise ValueError(
+ "[Error] The resumed checkpoint path {} don't exist.".format(
+ args.resume_from_ckpt_path
+ )
+ )
+ self.logger.info(
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
+ )
+ start = time.monotonic_ns()
+ ckpt_path = self._load_model(
+ checkpoint_path=args.resume_from_ckpt_path,
+ resume_type=args.resume_type,
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # save config file path
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+ def _accelerator_prepare(self):
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.model,
+ self.optimizer,
+ self.scheduler,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.model,
+ self.optimizer,
+ self.scheduler,
+ )
+
+ ### Following are abstract methods that should be implemented in child classes ###
+ @abstractmethod
+ def _build_dataset(self):
+ r"""Build dataset for model training/validating/evaluating."""
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _build_criterion():
+ r"""Build criterion function for model loss calculation."""
+ pass
+
+ @abstractmethod
+ def _build_model(self):
+ r"""Build model for training/validating/evaluating."""
+ pass
+
+ @abstractmethod
+ def _forward_step(self, batch):
+ r"""One forward step of the neural network. This abstract method is trying to
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
+ However, for special case that using different forward step pattern for
+ training and validating, you could just override this method with ``pass`` and
+ implement ``_train_step`` and ``_valid_step`` separately.
+ """
+ pass
+
+ @abstractmethod
+ def _save_auxiliary_states(self):
+ r"""To save some auxiliary states when saving model's ckpt"""
+ pass
+
+ ### Abstract methods end ###
+
+ ### THIS IS MAIN ENTRY ###
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # dump config file
+ if self.accelerator.is_main_process:
+ self.__dump_cfg(self.config_save_path)
+ self.model.train()
+ self.optimizer.zero_grad()
+ # Wait to ensure good to go
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
+ ### It's inconvenient for the model with multiple losses
+ # Do training & validating epoch
+ train_loss = self._train_epoch()
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
+ valid_loss = self._valid_epoch()
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
+ self.accelerator.log(
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
+ step=self.epoch,
+ )
+
+ self.accelerator.wait_for_everyone()
+ # TODO: what is scheduler?
+ self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
+
+ # Check if hit save_checkpoint_stride and run_eval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ hit_dix = []
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ hit_dix.append(i)
+ run_eval |= self.run_eval[i]
+
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, train_loss
+ ),
+ )
+ self.tmp_checkpoint_save_path = path
+ self.accelerator.save_state(path)
+ print(f"save checkpoint in {path}")
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+ self._save_auxiliary_states()
+
+ # Remove old checkpoints
+ to_remove = []
+ for idx in hit_dix:
+ self.checkpoints_path[idx].append(path)
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+ # Search conflicts
+ total = set()
+ for i in self.checkpoints_path:
+ total |= set(i)
+ do_remove = set()
+ for idx, path in to_remove[::-1]:
+ if path in total:
+ self.checkpoints_path[idx].insert(0, path)
+ else:
+ do_remove.add(path)
+
+ # Remove old checkpoints
+ for path in do_remove:
+ shutil.rmtree(path, ignore_errors=True)
+ self.logger.debug(f"Remove old checkpoint: {path}")
+
+ self.accelerator.wait_for_everyone()
+ if run_eval:
+ # TODO: run evaluation
+ pass
+
+ # Update info for each epoch
+ self.epoch += 1
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_loss
+ ),
+ )
+ )
+ self._save_auxiliary_states()
+
+ self.accelerator.end_training()
+
+ ### Following are methods that can be used directly in child classes ###
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.train()
+ epoch_sum_loss: float = 0.0
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ loss = self._train_step(batch)
+ self.accelerator.backward(loss)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.batch_count += 1
+
+ # Update info for each step
+ # TODO: step means BP counts or batch counts?
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += loss
+ self.accelerator.log(
+ {
+ "Step/Train Loss": loss,
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
+ },
+ step=self.step,
+ )
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+ return (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.eval()
+ epoch_sum_loss = 0.0
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ batch_loss = self._valid_step(batch)
+ epoch_sum_loss += batch_loss.item()
+
+ self.accelerator.wait_for_everyone()
+ return epoch_sum_loss / len(self.valid_dataloader)
+
+ def _train_step(self, batch):
+ r"""Training forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_train_epoch`` for usage.
+ """
+ return self._forward_step(batch)
+
+ @torch.inference_mode()
+ def _valid_step(self, batch):
+ r"""Testing forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_test_epoch`` for usage.
+ """
+ return self._forward_step(batch)
+
+ def _load_model(
+ self,
+ checkpoint_dir: str = None,
+ checkpoint_path: str = None,
+ resume_type: str = "",
+ ):
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ self.logger.info("Resume from {}...".format(checkpoint_path))
+
+ if resume_type in ["resume", ""]:
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
+ self.accelerator.load_state(input_dir=checkpoint_path)
+
+ # set epoch and step
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+
+ elif resume_type == "finetune":
+ # Load only the model weights
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune...")
+
+ else:
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
+
+ return checkpoint_path
+
+ def _build_dataloader(self):
+ Dataset, Collator = self._build_dataset()
+
+ # build dataset instance for each dataset and combine them by ConcatDataset
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = ConcatDataset(datasets_list)
+ train_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
+ train_loader = DataLoader(
+ train_dataset,
+ # shuffle=True,
+ collate_fn=train_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+
+ # Build valid dataloader
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = ConcatDataset(datasets_list)
+ valid_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ return train_loader, valid_loader
+
+ @staticmethod
+ def _set_random_seed(seed):
+ r"""Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ def _check_nan(self, loss, y_pred, y_gt):
+ if torch.any(torch.isnan(loss)):
+ self.logger.error("Fatal Error: Training is down since loss has Nan!")
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
+
+ ### y_pred ###
+ if torch.any(torch.isnan(y_pred)):
+ self.logger.error(
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
+ )
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
+ else:
+ self.logger.debug(
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
+ )
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
+
+ ### y_gt ###
+ if torch.any(torch.isnan(y_gt)):
+ self.logger.error(
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
+ )
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
+ else:
+ self.logger.debug(
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
+ )
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
+
+ self.accelerator.end_training()
+ raise RuntimeError("Loss has Nan! See log for more info.")
+
+ ### Protected methods end ###
+
+ ## Following are private methods ##
+ def _build_optimizer(self):
+ r"""Build optimizer for model."""
+ # Make case-insensitive matching
+ if self.cfg.train.optimizer.lower() == "adadelta":
+ optimizer = torch.optim.Adadelta(
+ self.model.parameters(), **self.cfg.train.adadelta
+ )
+ self.logger.info("Using Adadelta optimizer.")
+ elif self.cfg.train.optimizer.lower() == "adagrad":
+ optimizer = torch.optim.Adagrad(
+ self.model.parameters(), **self.cfg.train.adagrad
+ )
+ self.logger.info("Using Adagrad optimizer.")
+ elif self.cfg.train.optimizer.lower() == "adam":
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
+ self.logger.info("Using Adam optimizer.")
+ elif self.cfg.train.optimizer.lower() == "adamw":
+ optimizer = torch.optim.AdamW(
+ self.model.parameters(), **self.cfg.train.adamw
+ )
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
+ optimizer = torch.optim.SparseAdam(
+ self.model.parameters(), **self.cfg.train.sparseadam
+ )
+ elif self.cfg.train.optimizer.lower() == "adamax":
+ optimizer = torch.optim.Adamax(
+ self.model.parameters(), **self.cfg.train.adamax
+ )
+ elif self.cfg.train.optimizer.lower() == "asgd":
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
+ optimizer = torch.optim.LBFGS(
+ self.model.parameters(), **self.cfg.train.lbfgs
+ )
+ elif self.cfg.train.optimizer.lower() == "nadam":
+ optimizer = torch.optim.NAdam(
+ self.model.parameters(), **self.cfg.train.nadam
+ )
+ elif self.cfg.train.optimizer.lower() == "radam":
+ optimizer = torch.optim.RAdam(
+ self.model.parameters(), **self.cfg.train.radam
+ )
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
+ optimizer = torch.optim.RMSprop(
+ self.model.parameters(), **self.cfg.train.rmsprop
+ )
+ elif self.cfg.train.optimizer.lower() == "rprop":
+ optimizer = torch.optim.Rprop(
+ self.model.parameters(), **self.cfg.train.rprop
+ )
+ elif self.cfg.train.optimizer.lower() == "sgd":
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
+ else:
+ raise NotImplementedError(
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
+ )
+ return optimizer
+
+ def _build_scheduler(self):
+ r"""Build scheduler for optimizer."""
+ # Make case-insensitive matching
+ if self.cfg.train.scheduler.lower() == "lambdalr":
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self.optimizer, **self.cfg.train.lambdalr
+ )
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
+ self.optimizer, **self.cfg.train.multiplicativelr
+ )
+ elif self.cfg.train.scheduler.lower() == "steplr":
+ scheduler = torch.optim.lr_scheduler.StepLR(
+ self.optimizer, **self.cfg.train.steplr
+ )
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ self.optimizer, **self.cfg.train.multisteplr
+ )
+ elif self.cfg.train.scheduler.lower() == "constantlr":
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
+ self.optimizer, **self.cfg.train.constantlr
+ )
+ elif self.cfg.train.scheduler.lower() == "linearlr":
+ scheduler = torch.optim.lr_scheduler.LinearLR(
+ self.optimizer, **self.cfg.train.linearlr
+ )
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
+ self.optimizer, **self.cfg.train.exponentiallr
+ )
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
+ self.optimizer, **self.cfg.train.polynomiallr
+ )
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ self.optimizer, **self.cfg.train.cosineannealinglr
+ )
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
+ self.optimizer, **self.cfg.train.sequentiallr
+ )
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ self.optimizer, **self.cfg.train.reducelronplateau
+ )
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
+ self.optimizer, **self.cfg.train.cycliclr
+ )
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ self.optimizer, **self.cfg.train.onecyclelr
+ )
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
+ )
+ elif self.cfg.train.scheduler.lower() == "noamlr":
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
+ else:
+ raise NotImplementedError(
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
+ )
+ return scheduler
+
+ def _init_accelerator(self):
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir,
+ logging_dir=os.path.join(self.exp_dir, "log"),
+ )
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def __check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ # TODO: check other values
+
+ @staticmethod
+ def __count_parameters(model):
+ model_param = 0.0
+ if isinstance(model, dict):
+ for key, value in model.items():
+ model_param += sum(p.numel() for p in model[key].parameters())
+ else:
+ model_param = sum(p.numel() for p in model.parameters())
+ return model_param
+
+ def __dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+ ### Private methods end ###
diff --git a/models/codec/__init__.py b/models/codec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/codec/amphion_codec/codec.py b/models/codec/amphion_codec/codec.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f8cb2b504d5d2cd37251400f16c19166ecfed2
--- /dev/null
+++ b/models/codec/amphion_codec/codec.py
@@ -0,0 +1,427 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from models.codec.amphion_codec.quantize import (
+ ResidualVQ,
+ VectorQuantize,
+ FactorizedVectorQuantize,
+ LookupFreeQuantize,
+)
+
+from models.codec.amphion_codec.vocos import Vocos
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1):
+ super().__init__()
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ y = self.block(x)
+ pad = (x.shape[-1] - y.shape[-1]) // 2
+ if pad > 0:
+ x = x[..., pad:-pad]
+ return x + y
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, dim: int = 16, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ ResidualUnit(dim // 2, dilation=1),
+ ResidualUnit(dim // 2, dilation=3),
+ ResidualUnit(dim // 2, dilation=9),
+ Snake1d(dim // 2),
+ WNConv1d(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class CodecEncoder(nn.Module):
+ def __init__(
+ self,
+ d_model: int = 64,
+ up_ratios: list = [4, 5, 5, 6],
+ out_channels: int = 256,
+ use_tanh: bool = False,
+ cfg=None,
+ ):
+ super().__init__()
+
+ d_model = cfg.d_model if cfg is not None else d_model
+ up_ratios = cfg.up_ratios if cfg is not None else up_ratios
+ out_channels = cfg.out_channels if cfg is not None else out_channels
+ use_tanh = cfg.use_tanh if cfg is not None else use_tanh
+
+ # Create first convolution
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride in up_ratios:
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride)]
+
+ # Create last convolution
+ self.block += [
+ Snake1d(d_model),
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+ ]
+
+ if use_tanh:
+ self.block += [nn.Tanh()]
+
+ # Wrap black into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ self.reset_parameters()
+
+ def forward(self, x):
+ return self.block(x)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ Snake1d(input_dim),
+ WNConvTranspose1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=stride // 2 + stride % 2,
+ output_padding=stride % 2,
+ ),
+ ResidualUnit(output_dim, dilation=1),
+ ResidualUnit(output_dim, dilation=3),
+ ResidualUnit(output_dim, dilation=9),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class CodecDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 256,
+ upsample_initial_channel: int = 1536,
+ up_ratios: list = [5, 5, 4, 2],
+ num_quantizers: int = 8,
+ codebook_size: int = 1024,
+ codebook_dim: int = 256,
+ quantizer_type: str = "vq",
+ quantizer_dropout: float = 0.5,
+ commitment: float = 0.25,
+ codebook_loss_weight: float = 1.0,
+ use_l2_normlize: bool = False,
+ codebook_type: str = "euclidean",
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.8,
+ eps: float = 1e-5,
+ threshold_ema_dead_code: int = 2,
+ weight_init: bool = False,
+ use_vocos: bool = False,
+ vocos_dim: int = 384,
+ vocos_intermediate_dim: int = 1152,
+ vocos_num_layers: int = 8,
+ n_fft: int = 800,
+ hop_size: int = 200,
+ padding: str = "same",
+ cfg=None,
+ ):
+ super().__init__()
+
+ in_channels = (
+ cfg.in_channels
+ if cfg is not None and hasattr(cfg, "in_channels")
+ else in_channels
+ )
+ upsample_initial_channel = (
+ cfg.upsample_initial_channel
+ if cfg is not None and hasattr(cfg, "upsample_initial_channel")
+ else upsample_initial_channel
+ )
+ up_ratios = (
+ cfg.up_ratios
+ if cfg is not None and hasattr(cfg, "up_ratios")
+ else up_ratios
+ )
+ num_quantizers = (
+ cfg.num_quantizers
+ if cfg is not None and hasattr(cfg, "num_quantizers")
+ else num_quantizers
+ )
+ codebook_size = (
+ cfg.codebook_size
+ if cfg is not None and hasattr(cfg, "codebook_size")
+ else codebook_size
+ )
+ codebook_dim = (
+ cfg.codebook_dim
+ if cfg is not None and hasattr(cfg, "codebook_dim")
+ else codebook_dim
+ )
+ quantizer_type = (
+ cfg.quantizer_type
+ if cfg is not None and hasattr(cfg, "quantizer_type")
+ else quantizer_type
+ )
+ quantizer_dropout = (
+ cfg.quantizer_dropout
+ if cfg is not None and hasattr(cfg, "quantizer_dropout")
+ else quantizer_dropout
+ )
+ commitment = (
+ cfg.commitment
+ if cfg is not None and hasattr(cfg, "commitment")
+ else commitment
+ )
+ codebook_loss_weight = (
+ cfg.codebook_loss_weight
+ if cfg is not None and hasattr(cfg, "codebook_loss_weight")
+ else codebook_loss_weight
+ )
+ use_l2_normlize = (
+ cfg.use_l2_normlize
+ if cfg is not None and hasattr(cfg, "use_l2_normlize")
+ else use_l2_normlize
+ )
+ codebook_type = (
+ cfg.codebook_type
+ if cfg is not None and hasattr(cfg, "codebook_type")
+ else codebook_type
+ )
+ kmeans_init = (
+ cfg.kmeans_init
+ if cfg is not None and hasattr(cfg, "kmeans_init")
+ else kmeans_init
+ )
+ kmeans_iters = (
+ cfg.kmeans_iters
+ if cfg is not None and hasattr(cfg, "kmeans_iters")
+ else kmeans_iters
+ )
+ decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
+ eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
+ threshold_ema_dead_code = (
+ cfg.threshold_ema_dead_code
+ if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
+ else threshold_ema_dead_code
+ )
+ weight_init = (
+ cfg.weight_init
+ if cfg is not None and hasattr(cfg, "weight_init")
+ else weight_init
+ )
+ use_vocos = (
+ cfg.use_vocos
+ if cfg is not None and hasattr(cfg, "use_vocos")
+ else use_vocos
+ )
+ vocos_dim = (
+ cfg.vocos_dim
+ if cfg is not None and hasattr(cfg, "vocos_dim")
+ else vocos_dim
+ )
+ vocos_intermediate_dim = (
+ cfg.vocos_intermediate_dim
+ if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
+ else vocos_intermediate_dim
+ )
+ vocos_num_layers = (
+ cfg.vocos_num_layers
+ if cfg is not None and hasattr(cfg, "vocos_num_layers")
+ else vocos_num_layers
+ )
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
+ hop_size = (
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
+ )
+ padding = (
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
+ )
+
+ if quantizer_type == "vq":
+ self.quantizer = ResidualVQ(
+ input_dim=in_channels,
+ num_quantizers=num_quantizers,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_type=quantizer_type,
+ quantizer_dropout=quantizer_dropout,
+ commitment=commitment,
+ codebook_loss_weight=codebook_loss_weight,
+ use_l2_normlize=use_l2_normlize,
+ codebook_type=codebook_type,
+ kmeans_init=kmeans_init,
+ kmeans_iters=kmeans_iters,
+ decay=decay,
+ eps=eps,
+ threshold_ema_dead_code=threshold_ema_dead_code,
+ weight_init=weight_init,
+ )
+ elif quantizer_type == "fvq":
+ self.quantizer = ResidualVQ(
+ input_dim=in_channels,
+ num_quantizers=num_quantizers,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_type=quantizer_type,
+ quantizer_dropout=quantizer_dropout,
+ commitment=commitment,
+ codebook_loss_weight=codebook_loss_weight,
+ use_l2_normlize=use_l2_normlize,
+ )
+ elif quantizer_type == "lfq":
+ self.quantizer = ResidualVQ(
+ input_dim=in_channels,
+ num_quantizers=num_quantizers,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_type=quantizer_type,
+ )
+ else:
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
+
+ if not use_vocos:
+ # Add first conv layer
+ channels = upsample_initial_channel
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(up_ratios):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+ # Add final conv layer
+ layers += [
+ Snake1d(output_dim),
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ if use_vocos:
+ self.model = Vocos(
+ input_channels=in_channels,
+ dim=vocos_dim,
+ intermediate_dim=vocos_intermediate_dim,
+ num_layers=vocos_num_layers,
+ adanorm_num_embeddings=None,
+ n_fft=n_fft,
+ hop_size=hop_size,
+ padding=padding,
+ )
+
+ self.reset_parameters()
+
+ def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
+ """
+ if vq is True, x = encoder output, then return quantized output;
+ else, x = quantized output, then return decoder output
+ """
+ if vq is True:
+ if eval_vq:
+ self.quantizer.eval()
+ (
+ quantized_out,
+ all_indices,
+ all_commit_losses,
+ all_codebook_losses,
+ all_quantized,
+ ) = self.quantizer(x, n_quantizers=n_quantizers)
+ return (
+ quantized_out,
+ all_indices,
+ all_commit_losses,
+ all_codebook_losses,
+ all_quantized,
+ )
+
+ return self.model(x)
+
+ def quantize(self, x, n_quantizers=None):
+ self.quantizer.eval()
+ quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
+ return quantized_out, vq
+
+ # TODO: check consistency of vq2emb and quantize
+ def vq2emb(self, vq, n_quantizers=None):
+ return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)
+
+ def decode(self, x):
+ return self.model(x)
+
+ def latent2dist(self, x, n_quantizers=None):
+ return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
diff --git a/models/codec/amphion_codec/quantize/__init__.py b/models/codec/amphion_codec/quantize/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7cdbcfd66785ba7c3ffe4b32d9a261ab52f608
--- /dev/null
+++ b/models/codec/amphion_codec/quantize/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
+ FactorizedVectorQuantize,
+)
+from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
+from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
+from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
diff --git a/models/codec/amphion_codec/quantize/factorized_vector_quantize.py b/models/codec/amphion_codec/quantize/factorized_vector_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c359f8cb60e8d2617a28f8d20806b2dbfd1b588
--- /dev/null
+++ b/models/codec/amphion_codec/quantize/factorized_vector_quantize.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class FactorizedVectorQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ codebook_size,
+ codebook_dim,
+ commitment=0.005,
+ codebook_loss_weight=1.0,
+ use_l2_normlize=True,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.commitment = commitment
+ self.codebook_loss_weight = codebook_loss_weight
+ self.use_l2_normlize = use_l2_normlize
+
+ if self.input_dim != self.codebook_dim:
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
+ self.out_project = WNConv1d(
+ self.codebook_dim, self.input_dim, kernel_size=1
+ )
+
+ else:
+ self.in_project = nn.Identity()
+ self.out_project = nn.Identity()
+
+ self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
+
+ def forward(self, z):
+ """
+ Parameters
+ ----------
+ z: torch.Tensor[B x D x T]
+
+ Returns
+ -------
+ z_q: torch.Tensor[B x D x T]
+ Quantized continuous representation of input
+ commit_loss: Tensor[B]
+ Commitment loss to train encoder to predict vectors closer to codebook entries
+ codebook_loss: Tensor[B]
+ Codebook loss to update the codebook
+ indices: torch.Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ z_e: torch.Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
+ z_e = self.in_project(z)
+ z_q, indices = self.decode_latents(z_e)
+
+ # Compute commitment loss and codebook loss
+ if self.training:
+ commit_loss = (
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ * self.commitment
+ )
+ codebook_loss = (
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+ * self.codebook_loss_weight
+ )
+ else:
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+ z_q = z_e + (z_q - z_e).detach()
+
+ z_q = self.out_project(z_q)
+
+ return z_q, commit_loss, codebook_loss, indices, z_e
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight
+
+ # L2 normalize encodings and codebook
+ if self.use_l2_normlize:
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance between encodings and codebook,
+ # if use_l2_normlize is True, the distance is equal to cosine distance
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+
+ return z_q, indices
+
+ def vq2emb(self, vq, out_proj=True):
+ emb = self.decode_code(vq)
+ if out_proj:
+ emb = self.out_project(emb)
+ return emb
+
+ def latent2dist(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight
+
+ # L2 normalize encodings and codebook
+ if self.use_l2_normlize:
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance between encodings and codebook,
+ # if use_l2_normlize is True, the distance is equal to cosine distance
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ ) # (b*t, k)
+
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
+ z_q = self.decode_code(indices)
+
+ return -dist, indices, z_q
diff --git a/models/codec/amphion_codec/quantize/lookup_free_quantize.py b/models/codec/amphion_codec/quantize/lookup_free_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b1801573bc8f5935cce465a48bd6e5192953e32
--- /dev/null
+++ b/models/codec/amphion_codec/quantize/lookup_free_quantize.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class LookupFreeQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ codebook_size,
+ codebook_dim,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+
+ assert 2**codebook_dim == codebook_size
+
+ if self.input_dim != self.codebook_dim:
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
+ self.out_project = WNConv1d(
+ self.codebook_dim, self.input_dim, kernel_size=1
+ )
+
+ else:
+ self.in_project = nn.Identity()
+ self.out_project = nn.Identity()
+
+ def forward(self, z):
+ z_e = self.in_project(z)
+ z_e = F.sigmoid(z_e)
+
+ z_q = z_e + (torch.round(z_e) - z_e).detach()
+
+ z_q = self.out_project(z_q)
+
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+ bits = (
+ 2
+ ** torch.arange(self.codebook_dim, device=z.device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .long()
+ ) # (1, d, 1)
+ indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
+
+ return z_q, commit_loss, codebook_loss, indices, z_e
+
+ def vq2emb(self, vq, out_proj=True):
+ emb = torch.zeros(
+ vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
+ ) # (B, d, T)
+ for i in range(self.codebook_dim):
+ emb[:, i, :] = (vq % 2).float()
+ vq = vq // 2
+ if out_proj:
+ emb = self.out_project(emb)
+ return emb
diff --git a/models/codec/amphion_codec/quantize/residual_vq.py b/models/codec/amphion_codec/quantize/residual_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b8856e2e2b0ca902676fc86f5d807927e64108
--- /dev/null
+++ b/models/codec/amphion_codec/quantize/residual_vq.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
+ FactorizedVectorQuantize,
+)
+from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
+from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
+
+
+class ResidualVQ(nn.Module):
+ """
+ Introduced in SoundStream: An end2end neural audio codec
+ https://arxiv.org/abs/2107.03312
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 256,
+ num_quantizers: int = 8,
+ codebook_size: int = 1024,
+ codebook_dim: int = 256,
+ quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
+ quantizer_dropout: float = 0.5,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.num_quantizers = num_quantizers
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.quantizer_type = quantizer_type
+ self.quantizer_dropout = quantizer_dropout
+
+ if quantizer_type == "vq":
+ VQ = VectorQuantize
+ elif quantizer_type == "fvq":
+ VQ = FactorizedVectorQuantize
+ elif quantizer_type == "lfq":
+ VQ = LookupFreeQuantize
+ else:
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
+
+ self.quantizers = nn.ModuleList(
+ [
+ VQ(
+ input_dim=input_dim,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ **kwargs,
+ )
+ for _ in range(num_quantizers)
+ ]
+ )
+
+ def forward(self, z, n_quantizers: int = None):
+ """
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ n_quantizers : int, optional
+ No. of quantizers to use
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
+ when in training mode, and a random number of quantizers is used.
+ Returns
+ -------
+ "quantized_out" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "all_indices" : Tensor[N x B x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "all_commit_losses" : Tensor[N]
+ "all_codebook_losses" : Tensor[N]
+ "all_quantized" : Tensor[N x B x D x T]
+ """
+
+ quantized_out = 0.0
+ residual = z
+
+ all_commit_losses = []
+ all_codebook_losses = []
+ all_indices = []
+ all_quantized = []
+
+ if n_quantizers is None:
+ n_quantizers = self.num_quantizers
+
+ if self.training:
+ n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
+ dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(z.device)
+
+ for i, quantizer in enumerate(self.quantizers):
+ if self.training is False and i >= n_quantizers:
+ break
+
+ z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+ residual
+ )
+
+ # Create mask to apply quantizer dropout
+ mask = (
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+ )
+ quantized_out = quantized_out + z_q_i * mask[:, None, None]
+ residual = residual - z_q_i
+
+ commit_loss_i = (commit_loss_i * mask).mean()
+ codebook_loss_i = (codebook_loss_i * mask).mean()
+
+ all_commit_losses.append(commit_loss_i)
+ all_codebook_losses.append(codebook_loss_i)
+ all_indices.append(indices_i)
+ all_quantized.append(z_q_i)
+
+ all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
+ torch.stack,
+ (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
+ )
+
+ return (
+ quantized_out,
+ all_indices,
+ all_commit_losses,
+ all_codebook_losses,
+ all_quantized,
+ )
+
+ def vq2emb(self, vq, n_quantizers=None):
+ quantized_out = 0.0
+ if n_quantizers is None:
+ n_quantizers = self.num_quantizers
+ for idx, quantizer in enumerate(self.quantizers):
+ if idx >= n_quantizers:
+ break
+ quantized_out += quantizer.vq2emb(vq[idx])
+ return quantized_out
+
+ def latent2dist(self, z, n_quantizers=None):
+ quantized_out = 0.0
+ residual = z
+
+ all_dists = []
+ all_indices = []
+
+ if n_quantizers is None:
+ n_quantizers = self.num_quantizers
+
+ for i, quantizer in enumerate(self.quantizers):
+ if self.training is False and i >= n_quantizers:
+ break
+ dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
+ all_dists.append(dist_i)
+ all_indices.append(indices_i)
+
+ quantized_out = quantized_out + z_q_i
+ residual = residual - z_q_i
+
+ all_dists = torch.stack(all_dists)
+ all_indices = torch.stack(all_indices)
+
+ return all_dists, all_indices
diff --git a/models/codec/amphion_codec/quantize/vector_quantize.py b/models/codec/amphion_codec/quantize/vector_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..8296893ac7422982cb4a794418f73d7d57c18c98
--- /dev/null
+++ b/models/codec/amphion_codec/quantize/vector_quantize.py
@@ -0,0 +1,401 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+def l2norm(t):
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories, eps=1e-5):
+ return (x + eps) / (x.sum() + n_categories * eps)
+
+
+def sample_vectors(samples, num):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ if use_cosine_sim:
+ dists = samples @ means.t()
+ else:
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
+ means, "c d -> () c d"
+ )
+ dists = -(diffs**2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ if use_cosine_sim:
+ new_means = l2norm(new_means)
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
+
+class EuclideanCodebook(nn.Module):
+ def __init__(
+ self,
+ dim,
+ codebook_size,
+ kmeans_init=False,
+ kmeans_iters=10,
+ decay=0.8,
+ eps=1e-5,
+ threshold_ema_dead_code=2,
+ weight_init=False,
+ ):
+ super().__init__()
+
+ self.decay = decay
+ init_fn = torch.randn if not weight_init else torch.zeros
+ embed = init_fn(codebook_size, dim)
+
+ if weight_init:
+ nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
+
+ self.codebook_size = codebook_size
+ self.kmeans_iters = kmeans_iters
+ self.eps = eps
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ self.register_buffer(
+ "initted", torch.Tensor([not kmeans_init])
+ ) # if kmeans_init is True, then initted is False; otherwise, initted is True
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ def init_embed_(self, data):
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embed.data.copy_(embed)
+ self.embed_avg.data.copy_(embed)
+ self.cluster_size.data.copy_(cluster_size)
+ self.initted.data.copy_(torch.Tensor([True]))
+
+ def replace(self, samples, mask):
+ modified_codebook = torch.where(
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ )
+ self.embed.data.copy_(modified_codebook)
+
+ def expire_codes_(self, batch_samples):
+ if self.threshold_ema_dead_code == 0:
+ return
+
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
+ if not torch.any(expired_codes):
+ return
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self.replace(batch_samples, mask=expired_codes)
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ flatten = rearrange(x, "... d -> (...) d")
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
+
+ if not self.initted:
+ self.init_embed_(flatten)
+
+ dist = -(
+ flatten.pow(2).sum(1, keepdim=True)
+ - 2 * flatten @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+
+ embed_ind = dist.max(dim=-1).indices
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+ embed_ind = embed_ind.view(*shape[:-1])
+ quantize = F.embedding(embed_ind, self.embed)
+
+ if self.training:
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+ embed_sum = (
+ flatten.t() @ embed_onehot
+ ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+ cluster_size = (
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
+ * self.cluster_size.sum()
+ )
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+ self.embed.data.copy_(embed_normalized)
+ self.expire_codes_(x)
+
+ return quantize, embed_ind
+
+ def vq2emb(self, vq):
+ quantize = F.embedding(vq, self.embed)
+ return quantize
+
+ def latent2dist(self, x):
+ shape, dtype = x.shape, x.dtype
+ flatten = rearrange(x, "... d -> (...) d")
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
+
+ if not self.initted:
+ self.init_embed_(flatten)
+
+ dist = -(
+ flatten.pow(2).sum(1, keepdim=True)
+ - 2 * flatten @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+
+ embed_ind = dist.max(dim=-1).indices
+ embed_ind = embed_ind.view(*shape[:-1])
+ quantize = F.embedding(embed_ind, self.embed)
+
+ dist = dist.view(*shape[:-1], -1)
+
+ return dist, embed_ind, quantize
+
+
+class SimpleCodebook(nn.Module):
+ def __init__(
+ self,
+ dim,
+ codebook_size,
+ use_l2_normlize=False,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.codebook_size = codebook_size
+ self.use_l2_normlize = use_l2_normlize
+
+ self.embed = nn.Embedding(self.codebook_size, self.dim)
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ flatten = rearrange(x, "... d -> (...) d")
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
+
+ if self.use_l2_normlize:
+ flatten = F.normalize(flatten)
+ embed = F.normalize(embed)
+
+ dist = -(
+ flatten.pow(2).sum(1, keepdim=True)
+ - 2 * flatten @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+
+ embed_ind = dist.max(dim=-1).indices
+ embed_ind = embed_ind.view(*shape[:-1])
+ quantize = F.embedding(embed_ind, self.embed)
+
+ return quantize, embed_ind
+
+ def vq2emb(self, vq):
+ quantize = F.embedding(vq, self.embed.weight)
+ return quantize
+
+ def latent2dist(self, x):
+ shape, dtype = x.shape, x.dtype
+ flatten = rearrange(x, "... d -> (...) d")
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
+
+ if self.use_l2_normlize:
+ flatten = F.normalize(flatten)
+ embed = F.normalize(embed)
+
+ dist = -(
+ flatten.pow(2).sum(1, keepdim=True)
+ - 2 * flatten @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+
+ embed_ind = dist.max(dim=-1).indices
+ embed_ind = embed_ind.view(*shape[:-1])
+ quantize = F.embedding(embed_ind, self.embed)
+
+ dist = dist.view(*shape[:-1], -1)
+
+ return dist, embed_ind, quantize
+
+
+class VectorQuantize(nn.Module):
+ """Vector quantization and factorized vecotor quantization implementation
+ Args:
+ input_dim (int): Dimension of input.
+ codebook_size (int): Codebook size.
+ codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
+ if use codebook_type == "euclidean", otherwise, if you want to use
+ factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
+ commitment (float): Weight for commitment loss.
+ use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
+ we suggest use it as True if you want to use factorized vector quantization
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ codebook_size,
+ codebook_dim,
+ commitment=0.005,
+ codebook_loss_weight=1.0,
+ use_l2_normlize=False,
+ codebook_type="euclidean", # "euclidean" or "simple"
+ kmeans_init=False,
+ kmeans_iters=10,
+ decay=0.8,
+ eps=1e-5,
+ threshold_ema_dead_code=2,
+ weight_init=False,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.commitment = commitment
+ self.codebook_loss_weight = codebook_loss_weight
+ self.use_l2_normlize = use_l2_normlize
+ self.codebook_type = codebook_type
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.decay = decay
+ self.eps = eps
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+ self.weight_init = weight_init
+
+ if self.input_dim != self.codebook_dim:
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
+ self.out_project = WNConv1d(
+ self.codebook_dim, self.input_dim, kernel_size=1
+ )
+
+ else:
+ self.in_project = nn.Identity()
+ self.out_project = nn.Identity()
+
+ if self.codebook_type == "euclidean":
+ self.codebook = EuclideanCodebook(
+ self.codebook_dim,
+ codebook_size=self.codebook_size,
+ kmeans_init=self.kmeans_init,
+ kmeans_iters=self.kmeans_iters,
+ decay=self.decay,
+ eps=self.eps,
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
+ weight_init=self.weight_init,
+ )
+ elif self.codebook_type == "simple":
+ self.codebook = SimpleCodebook(
+ self.codebook_dim,
+ codebook_size=self.codebook_size,
+ use_l2_normlize=self.use_l2_normlize,
+ )
+ else:
+ raise NotImplementedError(
+ f"codebook_type {self.codebook_type} is not implemented!"
+ )
+
+ def forward(self, z):
+ """
+ Parameters
+ ----------
+ z: torch.Tensor[B x D x T]
+
+ Returns
+ -------
+ z_q: torch.Tensor[B x D x T]
+ Quantized continuous representation of input
+ commit_loss: Tensor[B]
+ Commitment loss to train encoder to predict vectors closer to codebook entries
+ codebook_loss: Tensor[B]
+ Codebook loss to update the codebook
+ indices: torch.Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ z_e: torch.Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
+ z_e = self.in_project(z)
+ z_q, indices = self.decode_latents(z_e)
+
+ # Compute commitment loss and codebook loss
+ if self.training:
+ commit_loss = (
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ * self.commitment
+ )
+ codebook_loss = (
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+ * self.codebook_loss_weight
+ )
+ else:
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+ z_q = z_e + (z_q - z_e).detach()
+
+ z_q = self.out_project(z_q)
+
+ return z_q, commit_loss, codebook_loss, indices, z_e
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> b t d")
+ z_q, indices = self.codebook(encodings)
+ z_q = z_q.transpose(1, 2)
+ return z_q, indices
+
+ def vq2emb(self, vq, out_proj=True):
+ emb = self.codebook.vq2emb(vq)
+ emb = emb.transpose(1, 2)
+ if out_proj:
+ emb = self.out_project(emb)
+ return emb
+
+ def latent2dist(self, latents):
+ latents = rearrange(latents, "b d t -> b t d")
+ dist, embed_ind, quantize = self.codebook.latent2dist(latents)
+ return dist, embed_ind, quantize.transpose(1, 2)
diff --git a/models/codec/amphion_codec/vocos.py b/models/codec/amphion_codec/vocos.py
new file mode 100644
index 0000000000000000000000000000000000000000..038d8ef4fd932d22c7704fca07f89ab675637ced
--- /dev/null
+++ b/models/codec/amphion_codec/vocos.py
@@ -0,0 +1,881 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+import numpy as np
+import scipy
+import torch
+from torch import nn, view_as_real, view_as_complex
+from torch import nn
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+import librosa
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
+
+class STFT(nn.Module):
+ def __init__(
+ self,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ center=True,
+ ):
+ super().__init__()
+ self.center = center
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, T * hop_length)
+
+ if not self.center:
+ pad = self.win_length - self.hop_length
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
+
+ stft_spec = torch.stft(
+ x,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ return_complex=False,
+ ) # (B, n_fft // 2 + 1, T, 2)
+
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
+
+ log_mag = torch.log(
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
+ ) # (B, n_fft // 2 + 1, T)
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
+
+ return log_mag, phase
+
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
+ ):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(
+ spec,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.window,
+ center=True,
+ )
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ assert (window_envelope > 1e-11).all()
+ y = y / window_envelope
+
+ return y
+
+
+class MDCT(nn.Module):
+ """
+ Modified Discrete Cosine Transform (MDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
+ # https://github.com/pytorch/pytorch/issues/71613
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+ Args:
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+ and T is the length of the audio.
+
+ Returns:
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+ and N is the number of frequency bins.
+ """
+ if self.padding == "center":
+ audio = torch.nn.functional.pad(
+ audio, (self.frame_len // 2, self.frame_len // 2)
+ )
+ elif self.padding == "same":
+ # hop_length is 1/2 frame_len
+ audio = torch.nn.functional.pad(
+ audio, (self.frame_len // 4, self.frame_len // 4)
+ )
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+ N = self.frame_len // 2
+ x = x * self.window.expand(x.shape)
+ X = torch.fft.fft(
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
+ )[..., :N]
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+ return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+ """
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+ Args:
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+ L is the number of frames, and N is the number of frequency bins.
+
+ Returns:
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+ """
+ B, L, N = X.shape
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+ Y[..., :N] = X
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+ y = torch.fft.ifft(
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
+ )
+ y = (
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
+ * np.sqrt(N)
+ * np.sqrt(2)
+ )
+ result = y * self.window.expand(y.shape)
+ output_size = (1, (L + 1) * N)
+ audio = torch.nn.functional.fold(
+ result.transpose(1, 2),
+ output_size=output_size,
+ kernel_size=(1, self.frame_len),
+ stride=(1, self.frame_len // 2),
+ )[:, 0, 0, :]
+
+ if self.padding == "center":
+ pad = self.frame_len // 2
+ elif self.padding == "same":
+ pad = self.frame_len // 4
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ audio = audio[:, pad:-pad]
+ return audio
+
+
+class FourierHead(nn.Module):
+ """Base class for inverse fourier modules."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+ """
+ ISTFT Head module for predicting STFT complex coefficients.
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
+ the resolution of the input features.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+ super().__init__()
+ out_dim = n_fft + 2
+ self.out = torch.nn.Linear(dim, out_dim)
+ self.istft = ISTFT(
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ISTFTHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x).transpose(1, 2)
+ mag, p = x.chunk(2, dim=1)
+ mag = torch.exp(mag)
+ mag = torch.clip(
+ mag, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ # wrapping happens here. These two lines produce real and imaginary value
+ x = torch.cos(p)
+ y = torch.sin(p)
+ # recalculating phase here does not produce anything new
+ # only costs time
+ # phase = torch.atan2(y, x)
+ # S = mag * torch.exp(phase * 1j)
+ # better directly produce the complex value
+ S = mag * (x + 1j * y)
+ audio = self.istft(S)
+ return audio
+
+
+class IMDCTSymExpHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+ based on perceptual scaling. Defaults to None.
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ sample_rate: Optional[int] = None,
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ out_dim = mdct_frame_len // 2
+ self.out = nn.Linear(dim, out_dim)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+ self.clip_audio = clip_audio
+
+ if sample_rate is not None:
+ # optionally init the last layer following mel-scale
+ m_max = _hz_to_mel(sample_rate // 2)
+ m_pts = torch.linspace(0, m_max, out_dim)
+ f_pts = _mel_to_hz(m_pts)
+ scale = 1 - (f_pts / f_pts.max())
+
+ with torch.no_grad():
+ self.out.weight.mul_(scale.view(-1, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTSymExpHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ x = symexp(x)
+ x = torch.clip(
+ x, min=-1e2, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(x)
+ if self.clip_audio:
+ audio = torch.clip(x, min=-1.0, max=1.0)
+
+ return audio
+
+
+class IMDCTCosHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ self.clip_audio = clip_audio
+ self.out = nn.Linear(dim, mdct_frame_len)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTCosHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ m, p = x.chunk(2, dim=2)
+ m = torch.exp(m).clip(
+ max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(m * torch.cos(p))
+ if self.clip_audio:
+ audio = torch.clip(x, min=-1.0, max=1.0)
+ return audio
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ layer_scale_init_value: float,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=3, groups=dim
+ ) # depthwise conv
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ if self.adanorm:
+ assert cond_embedding_id is not None
+ x = self.norm(x, cond_embedding_id)
+ else:
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+ Args:
+ num_embeddings (int): Number of embeddings.
+ embedding_dim (int): Dimension of the embeddings.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.dim = embedding_dim
+ self.scale = nn.Embedding(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
+ )
+ self.shift = nn.Embedding(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
+ )
+ torch.nn.init.ones_(self.scale.weight)
+ torch.nn.init.zeros_(self.shift.weight)
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+ scale = self.scale(cond_embedding_id)
+ shift = self.shift(cond_embedding_id)
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+ x = x * scale + shift
+ return x
+
+
+class ResBlock1(nn.Module):
+ """
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+ but without upsampling layers.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+ Defaults to (1, 3, 5).
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+ Defaults to 0.1.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int = 3,
+ dilation: Tuple[int, int, int] = (1, 3, 5),
+ lrelu_slope: float = 0.1,
+ layer_scale_init_value: Optional[float] = None,
+ ):
+ super().__init__()
+ self.lrelu_slope = lrelu_slope
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self.get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self.get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self.get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+
+ self.gamma = nn.ParameterList(
+ [
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+ xt = c1(xt)
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+ xt = c2(xt)
+ if gamma is not None:
+ xt = gamma * xt
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+ @staticmethod
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class Backbone(nn.Module):
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+ C denotes output features, and L is the sequence length.
+
+ Returns:
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+ and H denotes the model dimension.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+ """
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+ num_layers (int): Number of ConvNeXtBlock layers.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional model. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels: int,
+ dim: int,
+ intermediate_dim: int,
+ num_layers: int,
+ layer_scale_init_value: Optional[float] = None,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+ self.convnext = nn.ModuleList(
+ [
+ ConvNeXtBlock(
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ layer_scale_init_value=layer_scale_init_value,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ bandwidth_id = kwargs.get("bandwidth_id", None)
+ x = self.embed(x)
+ if self.adanorm:
+ assert bandwidth_id is not None
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+ else:
+ x = self.norm(x.transpose(1, 2))
+ x = x.transpose(1, 2)
+ for conv_block in self.convnext:
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
+ x = self.final_layer_norm(x.transpose(1, 2))
+ return x
+
+
+class VocosResNetBackbone(Backbone):
+ """
+ Vocos backbone module built with ResBlocks.
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ num_blocks (int): Number of ResBlock1 blocks.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels,
+ dim,
+ num_blocks,
+ layer_scale_init_value=None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = weight_norm(
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
+ )
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+ self.resnet = nn.Sequential(
+ *[
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.embed(x)
+ x = self.resnet(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class Vocos(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 256,
+ dim: int = 384,
+ intermediate_dim: int = 1152,
+ num_layers: int = 8,
+ n_fft: int = 800,
+ hop_size: int = 200,
+ padding: str = "same",
+ adanorm_num_embeddings=None,
+ cfg=None,
+ ):
+ super().__init__()
+
+ input_channels = (
+ cfg.input_channels
+ if cfg is not None and hasattr(cfg, "input_channels")
+ else input_channels
+ )
+ dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
+ intermediate_dim = (
+ cfg.intermediate_dim
+ if cfg is not None and hasattr(cfg, "intermediate_dim")
+ else intermediate_dim
+ )
+ num_layers = (
+ cfg.num_layers
+ if cfg is not None and hasattr(cfg, "num_layers")
+ else num_layers
+ )
+ adanorm_num_embeddings = (
+ cfg.adanorm_num_embeddings
+ if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
+ else adanorm_num_embeddings
+ )
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
+ hop_size = (
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
+ )
+ padding = (
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
+ )
+
+ self.backbone = VocosBackbone(
+ input_channels=input_channels,
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ num_layers=num_layers,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ x = self.head(x)
+
+ return x[:, None, :]
diff --git a/models/codec/codec_dataset.py b/models/codec/codec_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..be0a30856a9ce8183a15013f71965b2f010647b4
--- /dev/null
+++ b/models/codec/codec_dataset.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable
+import torch
+import numpy as np
+import torch.utils.data
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from torch.utils.data import ConcatDataset, Dataset
+
+
+class CodecDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+ assert isinstance(dataset, str)
+
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
+ self.metadata = self.get_metadata()
+
+ self.data_root = processed_data_dir
+ self.cfg = cfg
+
+ if cfg.preprocess.use_audio:
+ self.utt2audio_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2audio_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.audio_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_label:
+ self.utt2label_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2label_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.label_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_one_hot:
+ self.utt2one_hot_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2one_hot_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.one_hot_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_mel:
+ self.utt2mel_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2mel_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.mel_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_frame_pitch:
+ self.utt2frame_pitch_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2frame_pitch_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.pitch_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_uv:
+ self.utt2uv_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2uv_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.uv_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_amplitude_phase:
+ self.utt2logamp_path = {}
+ self.utt2pha_path = {}
+ self.utt2rea_path = {}
+ self.utt2imag_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2logamp_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.log_amplitude_dir,
+ uid + ".npy",
+ )
+ self.utt2pha_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.phase_dir,
+ uid + ".npy",
+ )
+ self.utt2rea_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.real_dir,
+ uid + ".npy",
+ )
+ self.utt2imag_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.imaginary_dir,
+ uid + ".npy",
+ )
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+
+ single_feature["mel"] = mel
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_audio:
+ audio = np.load(self.utt2audio_path[utt])
+
+ single_feature["audio"] = audio
+
+ return single_feature
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ return metadata
+
+ def get_dataset_name(self):
+ return self.metadata[0]["Dataset"]
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class CodecConcatDataset(ConcatDataset):
+ def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
+ """Concatenate a series of datasets with their random inference audio merged."""
+ super().__init__(datasets)
+
+ self.cfg = self.datasets[0].cfg
+
+ self.metadata = []
+
+ # Merge metadata
+ for dataset in self.datasets:
+ self.metadata += dataset.metadata
+
+ # Merge random inference features
+ if full_audio_inference:
+ self.eval_audios = []
+ self.eval_dataset_names = []
+ if self.cfg.preprocess.use_mel:
+ self.eval_mels = []
+ if self.cfg.preprocess.use_frame_pitch:
+ self.eval_pitchs = []
+ for dataset in self.datasets:
+ self.eval_audios.append(dataset.eval_audio)
+ self.eval_dataset_names.append(dataset.get_dataset_name())
+ if self.cfg.preprocess.use_mel:
+ self.eval_mels.append(dataset.eval_mel)
+ if self.cfg.preprocess.use_frame_pitch:
+ self.eval_pitchs.append(dataset.eval_pitch)
+
+
+class CodecCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, n_mels, frame]
+ # frame_pitch: [b, frame]
+ # audios: [b, frame * hop_size]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "mel":
+ values = [torch.from_numpy(b[key]).T for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/codec/codec_inference.py b/models/codec/codec_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e354c5db80cbac986543fdf7923014426c5078
--- /dev/null
+++ b/models/codec/codec_inference.py
@@ -0,0 +1,515 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import json
+import json5
+import time
+import accelerate
+import random
+import numpy as np
+import shutil
+
+from pathlib import Path
+from tqdm import tqdm
+from glob import glob
+from accelerate.logging import get_logger
+from torch.utils.data import DataLoader
+
+from models.vocoders.vocoder_dataset import (
+ VocoderDataset,
+ VocoderCollator,
+ VocoderConcatDataset,
+)
+
+from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
+from models.vocoders.flow.waveglow import waveglow
+from models.vocoders.diffusion.diffwave import diffwave
+from models.vocoders.autoregressive.wavenet import wavenet
+from models.vocoders.autoregressive.wavernn import wavernn
+
+from models.vocoders.gan import gan_vocoder_inference
+from models.vocoders.diffusion import diffusion_vocoder_inference
+
+from utils.io import save_audio
+
+_vocoders = {
+ "diffwave": diffwave.DiffWave,
+ "wavernn": wavernn.WaveRNN,
+ "wavenet": wavenet.WaveNet,
+ "waveglow": waveglow.WaveGlow,
+ "nsfhifigan": nsfhifigan.NSFHiFiGAN,
+ "bigvgan": bigvgan.BigVGAN,
+ "hifigan": hifigan.HiFiGAN,
+ "melgan": melgan.MelGAN,
+ "apnet": apnet.APNet,
+}
+
+# Forward call for generalized Inferencor
+_vocoder_forward_funcs = {
+ # "world": world_inference.synthesis_audios,
+ # "wavernn": wavernn_inference.synthesis_audios,
+ # "wavenet": wavenet_inference.synthesis_audios,
+ "diffwave": diffusion_vocoder_inference.vocoder_inference,
+ "nsfhifigan": gan_vocoder_inference.vocoder_inference,
+ "bigvgan": gan_vocoder_inference.vocoder_inference,
+ "melgan": gan_vocoder_inference.vocoder_inference,
+ "hifigan": gan_vocoder_inference.vocoder_inference,
+ "apnet": gan_vocoder_inference.vocoder_inference,
+}
+
+# APIs for other tasks. e.g. SVC, TTS, TTA...
+_vocoder_infer_funcs = {
+ # "world": world_inference.synthesis_audios,
+ # "wavernn": wavernn_inference.synthesis_audios,
+ # "wavenet": wavenet_inference.synthesis_audios,
+ "diffwave": diffusion_vocoder_inference.synthesis_audios,
+ "nsfhifigan": gan_vocoder_inference.synthesis_audios,
+ "bigvgan": gan_vocoder_inference.synthesis_audios,
+ "melgan": gan_vocoder_inference.synthesis_audios,
+ "hifigan": gan_vocoder_inference.synthesis_audios,
+ "apnet": gan_vocoder_inference.synthesis_audios,
+}
+
+
+class VocoderInference(object):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ super().__init__()
+
+ start = time.monotonic_ns()
+ self.args = args
+ self.cfg = cfg
+ self.infer_type = infer_type
+
+ # Init accelerator
+ self.accelerator = accelerate.Accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Get logger
+ with self.accelerator.main_process_first():
+ self.logger = get_logger("inference", log_level=args.log_level)
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+
+ self.vocoder_dir = args.vocoder_dir
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ if os.path.exists(os.path.join(args.output_dir, "pred")):
+ shutil.rmtree(os.path.join(args.output_dir, "pred"))
+ if os.path.exists(os.path.join(args.output_dir, "gt")):
+ shutil.rmtree(os.path.join(args.output_dir, "gt"))
+ os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
+ os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # Setup inference mode
+ if self.infer_type == "infer_from_dataset":
+ self.cfg.dataset = self.args.infer_datasets
+ elif self.infer_type == "infer_from_feature":
+ self._build_tmp_dataset_from_feature()
+ self.cfg.dataset = ["tmp"]
+ elif self.infer_type == "infer_from_audio":
+ self._build_tmp_dataset_from_audio()
+ self.cfg.dataset = ["tmp"]
+
+ # Setup data loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.test_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # Build model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
+
+ # Init with accelerate
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self.accelerator = accelerate.Accelerator()
+ (self.model, self.test_dataloader) = self.accelerator.prepare(
+ self.model, self.test_dataloader
+ )
+ end = time.monotonic_ns()
+ self.accelerator.wait_for_everyone()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
+
+ with self.accelerator.main_process_first():
+ self.logger.info("Loading checkpoint...")
+ start = time.monotonic_ns()
+ if os.path.isdir(args.vocoder_dir):
+ if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
+ self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
+ else:
+ self._load_model(os.path.join(args.vocoder_dir))
+ else:
+ self._load_model(os.path.join(args.vocoder_dir))
+ end = time.monotonic_ns()
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
+
+ self.model.eval()
+ self.accelerator.wait_for_everyone()
+
+ def _build_tmp_dataset_from_feature(self):
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+ utts = []
+ mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
+ for i, mel in enumerate(mels):
+ uid = mel.split("/")[-1].split(".")[0]
+ utt = {"Dataset": "tmp", "Uid": uid, "index": i}
+ utts.append(utt)
+
+ os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
+ ) as f:
+ json.dump(utts, f)
+
+ meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
+
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
+ "w",
+ ) as f:
+ json.dump(meta_info, f)
+
+ features = glob(os.path.join(self.args.feature_folder, "*"))
+ for feature in features:
+ feature_name = feature.split("/")[-1]
+ if os.path.isfile(feature):
+ continue
+ shutil.copytree(
+ os.path.join(self.args.feature_folder, feature_name),
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
+ )
+
+ def _build_tmp_dataset_from_audio(self):
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+ utts = []
+ audios = glob(os.path.join(self.args.audio_folder, "*"))
+ for i, audio in enumerate(audios):
+ uid = audio.split("/")[-1].split(".")[0]
+ utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
+ utts.append(utt)
+
+ os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
+ ) as f:
+ json.dump(utts, f)
+
+ meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
+
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
+ "w",
+ ) as f:
+ json.dump(meta_info, f)
+
+ from processors import acoustic_extractor
+
+ acoustic_extractor.extract_utt_acoustic_features_serial(
+ utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
+ )
+
+ def _build_test_dataset(self):
+ return VocoderDataset, VocoderCollator
+
+ def _build_model(self):
+ model = _vocoders[self.cfg.model.generator](self.cfg)
+ return model
+
+ def _build_dataloader(self):
+ """Build dataloader which merges a series of datasets."""
+ Dataset, Collator = self._build_test_dataset()
+
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
+ test_collate = Collator(self.cfg)
+ test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
+ test_dataloader = DataLoader(
+ test_dataset,
+ collate_fn=test_collate,
+ num_workers=1,
+ batch_size=test_batch_size,
+ shuffle=False,
+ )
+ self.test_batch_size = test_batch_size
+ self.test_dataset = test_dataset
+ return test_dataloader
+
+ def _load_model(self, checkpoint_dir, from_multi_gpu=False):
+ """Load model from checkpoint. If a folder is given, it will
+ load the latest checkpoint in checkpoint_dir. If a path is given
+ it will load the checkpoint specified by checkpoint_path.
+ **Only use this method after** ``accelerator.prepare()``.
+ """
+ if os.path.isdir(checkpoint_dir):
+ if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
+ checkpoint_path = checkpoint_dir
+ else:
+ # Load the latest accelerator state dicts
+ ls = [
+ str(i)
+ for i in Path(checkpoint_dir).glob("*")
+ if not "audio" in str(i)
+ ]
+ ls.sort(
+ key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]),
+ reverse=True,
+ )
+ checkpoint_path = ls[0]
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ return str(checkpoint_path)
+ else:
+ # Load old .pt checkpoints
+ if self.cfg.model.generator in [
+ "bigvgan",
+ "hifigan",
+ "melgan",
+ "nsfhifigan",
+ ]:
+ ckpt = torch.load(
+ checkpoint_dir,
+ map_location=(
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("cpu")
+ ),
+ )
+ if from_multi_gpu:
+ pretrained_generator_dict = ckpt["generator_state_dict"]
+ generator_dict = self.model.state_dict()
+
+ new_generator_dict = {
+ k.split("module.")[-1]: v
+ for k, v in pretrained_generator_dict.items()
+ if (
+ k.split("module.")[-1] in generator_dict
+ and v.shape == generator_dict[k.split("module.")[-1]].shape
+ )
+ }
+
+ generator_dict.update(new_generator_dict)
+
+ self.model.load_state_dict(generator_dict)
+ else:
+ self.model.load_state_dict(ckpt["generator_state_dict"])
+ else:
+ self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
+ return str(checkpoint_dir)
+
+ def inference(self):
+ """Inference via batches"""
+ for i, batch in tqdm(enumerate(self.test_dataloader)):
+ if self.cfg.preprocess.use_frame_pitch:
+ audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
+ self.cfg,
+ self.model,
+ batch["mel"].transpose(-1, -2),
+ f0s=batch["frame_pitch"].float(),
+ device=next(self.model.parameters()).device,
+ )
+ else:
+ audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
+ self.cfg,
+ self.model,
+ batch["mel"].transpose(-1, -2),
+ device=next(self.model.parameters()).device,
+ )
+ audio_ls = audio_pred.chunk(self.test_batch_size)
+ audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
+ length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
+ j = 0
+ for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
+ l = l.item()
+ it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
+ it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+ save_audio(
+ os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
+ it,
+ self.cfg.preprocess.sample_rate,
+ )
+ save_audio(
+ os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
+ it_gt,
+ self.cfg.preprocess.sample_rate,
+ )
+ j += 1
+
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+ def _set_random_seed(self, seed):
+ """Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ def _count_parameters(self, model):
+ return sum(p.numel() for p in model.parameters())
+
+ def _dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+
+def load_nnvocoder(
+ cfg,
+ vocoder_name,
+ weights_file,
+ from_multi_gpu=False,
+):
+ """Load the specified vocoder.
+ cfg: the vocoder config filer.
+ weights_file: a folder or a .pt path.
+ from_multi_gpu: automatically remove the "module" string in state dicts if "True".
+ """
+ print("Loading Vocoder from Weights file: {}".format(weights_file))
+
+ # Build model
+ model = _vocoders[vocoder_name](cfg)
+ if not os.path.isdir(weights_file):
+ # Load from .pt file
+ if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
+ ckpt = torch.load(
+ weights_file,
+ map_location=(
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("cpu")
+ ),
+ )
+ if from_multi_gpu:
+ pretrained_generator_dict = ckpt["generator_state_dict"]
+ generator_dict = model.state_dict()
+
+ new_generator_dict = {
+ k.split("module.")[-1]: v
+ for k, v in pretrained_generator_dict.items()
+ if (
+ k.split("module.")[-1] in generator_dict
+ and v.shape == generator_dict[k.split("module.")[-1]].shape
+ )
+ }
+
+ generator_dict.update(new_generator_dict)
+
+ model.load_state_dict(generator_dict)
+ else:
+ model.load_state_dict(ckpt["generator_state_dict"])
+ else:
+ model.load_state_dict(torch.load(weights_file)["state_dict"])
+ else:
+ # Load from accelerator state dict
+ weights_file = os.path.join(weights_file, "checkpoint")
+ ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ accelerator = accelerate.Accelerator()
+ model = accelerator.prepare(model)
+ accelerator.load_state(checkpoint_path)
+
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ model = model.eval()
+ return model
+
+
+def tensorize(data, device, n_samples):
+ """
+ data: a list of numpy array
+ """
+ assert type(data) == list
+ if n_samples:
+ data = data[:n_samples]
+ data = [torch.as_tensor(x, device=device) for x in data]
+ return data
+
+
+def synthesis(
+ cfg,
+ vocoder_weight_file,
+ n_samples,
+ pred,
+ f0s=None,
+ batch_size=64,
+ fast_inference=False,
+):
+ """Synthesis audios from a given vocoder and series of given features.
+ cfg: vocoder config.
+ vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
+ pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
+ """
+
+ vocoder_name = cfg.model.generator
+
+ print("Synthesis audios using {} vocoder...".format(vocoder_name))
+
+ ###### TODO: World Vocoder Refactor ######
+ # if vocoder_name == "world":
+ # world_inference.synthesis_audios(
+ # cfg, dataset_name, split, n_samples, pred, save_dir, tag
+ # )
+ # return
+
+ # ====== Loading neural vocoder model ======
+ vocoder = load_nnvocoder(
+ cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
+ )
+ device = next(vocoder.parameters()).device
+
+ # ====== Inference for predicted acoustic features ======
+ # pred: (frame_len, n_mels) -> (n_mels, frame_len)
+ mels_pred = tensorize([p.T for p in pred], device, n_samples)
+ print("For predicted mels, #sample = {}...".format(len(mels_pred)))
+ audios_pred = _vocoder_infer_funcs[vocoder_name](
+ cfg,
+ vocoder,
+ mels_pred,
+ f0s=f0s,
+ batch_size=batch_size,
+ fast_inference=fast_inference,
+ )
+ return audios_pred
diff --git a/models/codec/codec_sampler.py b/models/codec/codec_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d29f88a291dcf7386cadaeae0d990c8e76ebf98
--- /dev/null
+++ b/models/codec/codec_sampler.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+ BatchSampler,
+ RandomSampler,
+ Sampler,
+ SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+ """A sampler that samples data from a given concat-dataset.
+
+ Args:
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+ batch_size (int): batch size
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
+ logger (logging.Logger): logger to print warning message
+
+ Usage:
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
+ >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]])))
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
+ """
+
+ def __init__(
+ self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
+ ):
+ if not isinstance(concat_dataset, ConcatDataset):
+ raise ValueError(
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+ type(concat_dataset)
+ )
+ )
+ if not isinstance(batch_size, int):
+ raise ValueError(
+ "batch_size must be an integer, but got {}".format(type(batch_size))
+ )
+ if not isinstance(holistic_shuffle, bool):
+ raise ValueError(
+ "holistic_shuffle must be a boolean, but got {}".format(
+ type(holistic_shuffle)
+ )
+ )
+
+ self.concat_dataset = concat_dataset
+ self.batch_size = batch_size
+ self.holistic_shuffle = holistic_shuffle
+
+ affected_dataset_name = []
+ affected_dataset_len = []
+ for dataset in concat_dataset.datasets:
+ dataset_len = len(dataset)
+ dataset_name = dataset.get_dataset_name()
+ if dataset_len < batch_size:
+ affected_dataset_name.append(dataset_name)
+ affected_dataset_len.append(dataset_len)
+
+ self.type = type
+ for dataset_name, dataset_len in zip(
+ affected_dataset_name, affected_dataset_len
+ ):
+ if not type == "valid":
+ logger.warning(
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
+ type, dataset_name, dataset_len, batch_size
+ )
+ )
+
+ def __len__(self):
+ # the number of batches with drop last
+ num_of_batches = sum(
+ [
+ math.floor(len(dataset) / self.batch_size)
+ for dataset in self.concat_dataset.datasets
+ ]
+ )
+ return num_of_batches * self.batch_size
+
+ def __iter__(self):
+ iters = []
+ for dataset in self.concat_dataset.datasets:
+ iters.append(
+ SequentialSampler(dataset).__iter__()
+ if self.holistic_shuffle
+ else RandomSampler(dataset).__iter__()
+ )
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
+ output_batches = []
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
+ cur_batch = []
+ for idx in iters[dataset_idx]:
+ cur_batch.append(idx + init_indices[dataset_idx])
+ if len(cur_batch) == self.batch_size:
+ output_batches.append(cur_batch)
+ cur_batch = []
+ if self.type == "valid" and len(cur_batch) > 0:
+ output_batches.append(cur_batch)
+ cur_batch = []
+ # force drop last in training
+ random.shuffle(output_batches)
+ output_indices = [item for sublist in output_batches for item in sublist]
+ return iter(output_indices)
+
+
+def build_samplers(concat_dataset: Dataset, cfg, logger, type):
+ sampler = ScheduledSampler(
+ concat_dataset,
+ cfg.train.batch_size,
+ cfg.train.sampler.holistic_shuffle,
+ logger,
+ type,
+ )
+ batch_sampler = BatchSampler(
+ sampler,
+ cfg.train.batch_size,
+ cfg.train.sampler.drop_last if not type == "valid" else False,
+ )
+ return sampler, batch_sampler
diff --git a/models/codec/codec_trainer.py b/models/codec/codec_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a6f838814c194b9d3bccfd5c8e66ea5881a33c6
--- /dev/null
+++ b/models/codec/codec_trainer.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import random
+from pathlib import Path
+import re
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models.codec.codec_sampler import build_samplers
+
+
+class CodecTrainer:
+ def __init__(self):
+ super().__init__()
+
+ def _init_accelerator(self):
+ """Initialize the accelerator components."""
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
+ )
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def _build_dataset(self):
+ pass
+
+ def _build_criterion(self):
+ pass
+
+ def _build_model(self):
+ pass
+
+ def _build_dataloader(self):
+ """Build dataloader which merges a series of datasets."""
+ # Build dataset instance for each dataset and combine them by ConcatDataset
+ Dataset, Collator = self._build_dataset()
+
+ # Build train set
+ train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False)
+ train_collate = Collator(self.cfg)
+ sampler = torch.utils.data.distributed.DistributedSampler(
+ train_dataset,
+ num_replicas=self.accelerator.num_processes,
+ rank=self.accelerator.local_process_index,
+ shuffle=True,
+ seed=self.cfg.train.random_seed,
+ )
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=self.cfg.train.batch_size,
+ collate_fn=train_collate,
+ sampler=sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ return train_loader, None
+
+ def _build_optimizer(self):
+ pass
+
+ def _build_scheduler(self):
+ pass
+
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If a folder is given, it will
+ load the latest checkpoint in checkpoint_dir. If a path is given
+ it will load the checkpoint specified by checkpoint_path.
+ **Only use this method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ if resume_type == "resume":
+ self.accelerator.load_state(checkpoint_path)
+ elif resume_type == "finetune":
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune SUCCESS!")
+ else:
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+ return checkpoint_path
+
+ def train_loop(self):
+ pass
+
+ def _train_epoch(self):
+ pass
+
+ def _valid_epoch(self):
+ pass
+
+ def _train_step(self):
+ pass
+
+ def _valid_step(self):
+ pass
+
+ def _inference(self):
+ pass
+
+ def _set_random_seed(self, seed):
+ """Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ def _check_nan(self, loss):
+ if torch.any(torch.isnan(loss)):
+ self.logger.fatal("Fatal Error: NaN!")
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
+
+ def _check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+
+ def _count_parameters(self):
+ pass
+
+ def _dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+ def _is_valid_pattern(self, directory_name):
+ directory_name = str(directory_name)
+ pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
+ return re.match(pattern, directory_name) is not None
diff --git a/models/codec/facodec/__init__.py b/models/codec/facodec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/codec/facodec/alias_free_torch/__init__.py b/models/codec/facodec/alias_free_torch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3bccdb97a3706bcb7149f48e04178cf00a5e877
--- /dev/null
+++ b/models/codec/facodec/alias_free_torch/__init__.py
@@ -0,0 +1,5 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+from .filter import *
+from .resample import *
+from .act import *
diff --git a/models/codec/facodec/alias_free_torch/act.py b/models/codec/facodec/alias_free_torch/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..779d58d5f1e889f8b639dd019a0ce951e69e4cfb
--- /dev/null
+++ b/models/codec/facodec/alias_free_torch/act.py
@@ -0,0 +1,29 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+import torch.nn as nn
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+ def __init__(
+ self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12,
+ ):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
diff --git a/models/codec/facodec/alias_free_torch/filter.py b/models/codec/facodec/alias_free_torch/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece8e02fce0e65e13522e990a80d1bfeeffd46ba
--- /dev/null
+++ b/models/codec/facodec/alias_free_torch/filter.py
@@ -0,0 +1,96 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(
+ x == 0,
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x,
+ )
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+def kaiser_sinc_filter1d(
+ cutoff, half_width, kernel_size
+): # return filter [1,1,kernel_size]
+ even = kernel_size % 2 == 0
+ half_size = kernel_size // 2
+
+ # For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.0:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.0:
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+ else:
+ beta = 0.0
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = torch.arange(-half_size, half_size) + 0.5
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(
+ self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = "replicate",
+ kernel_size: int = 12,
+ ):
+ # kernel_size should be even number for stylegan3 setup,
+ # in this implementation, odd number is also possible.
+ super().__init__()
+ if cutoff < -0.0:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = kernel_size % 2 == 0
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ # input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+
+ return out
diff --git a/models/codec/facodec/alias_free_torch/resample.py b/models/codec/facodec/alias_free_torch/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee993b10339141b469b67c3e11f5d73c5f4e0bca
--- /dev/null
+++ b/models/codec/facodec/alias_free_torch/resample.py
@@ -0,0 +1,57 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = (
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ )
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = (
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ )
+ filter = kaiser_sinc_filter1d(
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+ )
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+ )
+ x = x[..., self.pad_left : -self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = (
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ )
+ self.lowpass = LowPassFilter1d(
+ cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size,
+ )
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
diff --git a/models/codec/facodec/facodec_dataset.py b/models/codec/facodec/facodec_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e86b82d1f0e6e49395dd9340961bdd517b47f8b3
--- /dev/null
+++ b/models/codec/facodec/facodec_dataset.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import random
+
+import numpy as np
+
+import torchaudio
+import librosa
+from torch.nn import functional as F
+
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.codec.codec_dataset import CodecDataset
+
+
+class FAcodecDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+ self.data_root_dir = cfg.dataset
+ self.data_list = []
+ # walk through the dataset directory recursively, save all files ends with .wav/.mp3/.opus/.flac/.m4a
+ for root, _, files in os.walk(self.data_root_dir):
+ for file in files:
+ if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")):
+ self.data_list.append(os.path.join(root, file))
+ self.sr = cfg.preprocess_params.sr
+ self.duration_range = cfg.preprocess_params.duration_range
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
+ n_mels=cfg.preprocess_params.spect_params.n_mels,
+ n_fft=cfg.preprocess_params.spect_params.n_fft,
+ win_length=cfg.preprocess_params.spect_params.win_length,
+ hop_length=cfg.preprocess_params.spect_params.hop_length,
+ )
+ self.mean, self.std = -4, 4
+
+ def preprocess(self, wave):
+ wave_tensor = (
+ torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
+ )
+ mel_tensor = self.to_mel(wave_tensor)
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
+ return mel_tensor
+
+ def __len__(self):
+ # return len(self.data_list)
+ return len(self.data_list) # return a fixed number for testing
+
+ def __getitem__(self, index):
+ wave, _ = librosa.load(self.data_list[index], sr=self.sr)
+ wave = np.random.randn(self.sr * random.randint(*self.duration_range))
+ wave = wave / np.max(np.abs(wave))
+ mel = self.preprocess(wave).squeeze(0)
+ wave = torch.from_numpy(wave).float()
+ return wave, mel
+
+
+class FAcodecCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ # batch[0] = wave, mel, text, f0, speakerid
+ batch_size = len(batch)
+
+ # sort by mel length
+ lengths = [b[1].shape[1] for b in batch]
+ batch_indexes = np.argsort(lengths)[::-1]
+ batch = [batch[bid] for bid in batch_indexes]
+
+ nmels = batch[0][1].size(0)
+ max_mel_length = max([b[1].shape[1] for b in batch])
+ max_wave_length = max([b[0].size(0) for b in batch])
+
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
+ waves = torch.zeros((batch_size, max_wave_length)).float()
+
+ mel_lengths = torch.zeros(batch_size).long()
+ wave_lengths = torch.zeros(batch_size).long()
+
+ for bid, (wave, mel) in enumerate(batch):
+ mel_size = mel.size(1)
+ mels[bid, :, :mel_size] = mel
+ waves[bid, : wave.size(0)] = wave
+ mel_lengths[bid] = mel_size
+ wave_lengths[bid] = wave.size(0)
+
+ return waves, mels, wave_lengths, mel_lengths
diff --git a/models/codec/facodec/facodec_inference.py b/models/codec/facodec/facodec_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..c494349e4c1140d9d11e9c5742a8faa7e1560705
--- /dev/null
+++ b/models/codec/facodec/facodec_inference.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import shutil
+import warnings
+import argparse
+import torch
+import os
+import yaml
+
+warnings.simplefilter("ignore")
+
+from .modules.commons import *
+import time
+
+import torchaudio
+import librosa
+from collections import OrderedDict
+
+
+class FAcodecInference(object):
+ def __init__(self, args=None, cfg=None):
+ self.args = args
+ self.cfg = cfg
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.model = self._build_model()
+ self._load_checkpoint()
+
+ def _build_model(self):
+ model = build_model(self.cfg.model_params)
+ _ = [model[key].to(self.device) for key in model]
+ return model
+
+ def _load_checkpoint(self):
+ sd = torch.load(self.args.checkpoint_path, map_location="cpu")
+ sd = sd["net"] if "net" in sd else sd
+ new_params = dict()
+ for key, state_dict in sd.items():
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith("module."):
+ k = k[7:]
+ new_state_dict[k] = v
+ new_params[key] = new_state_dict
+ for key in new_params:
+ if key in self.model:
+ self.model[key].load_state_dict(new_params[key])
+ _ = [self.model[key].eval() for key in self.model]
+
+ @torch.no_grad()
+ def inference(self, source, output_dir):
+ source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
+ source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
+
+ z = self.model.encoder(source_audio[None, ...].to(self.device).float())
+ (
+ z,
+ quantized,
+ commitment_loss,
+ codebook_loss,
+ timbre,
+ codes,
+ ) = self.model.quantizer(
+ z,
+ source_audio[None, ...].to(self.device).float(),
+ n_c=self.cfg.model_params.n_c_codebooks,
+ return_codes=True,
+ )
+
+ full_pred_wave = self.model.decoder(z)
+
+ os.makedirs(output_dir, exist_ok=True)
+ source_name = source.split("/")[-1].split(".")[0]
+ torchaudio.save(
+ f"{output_dir}/reconstructed_{source_name}.wav",
+ full_pred_wave[0].cpu(),
+ self.cfg.preprocess_params.sr,
+ )
+
+ print(
+ "Reconstructed audio saved as: ",
+ f"{output_dir}/reconstructed_{source_name}.wav",
+ )
+
+ return quantized, codes
+
+ @torch.no_grad()
+ def voice_conversion(self, source, reference, output_dir):
+ source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
+ source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
+
+ reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
+ reference_audio = (
+ torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
+ )
+
+ z = self.model.encoder(source_audio[None, ...].to(self.device).float())
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
+ z,
+ source_audio[None, ...].to(self.device).float(),
+ n_c=self.cfg.model_params.n_c_codebooks,
+ )
+
+ z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
+ (
+ z_ref,
+ quantized_ref,
+ commitment_loss_ref,
+ codebook_loss_ref,
+ timbre_ref,
+ ) = self.model.quantizer(
+ z_ref,
+ reference_audio[None, ...].to(self.device).float(),
+ n_c=self.cfg.model_params.n_c_codebooks,
+ )
+
+ z_conv = self.model.quantizer.voice_conversion(
+ quantized[0] + quantized[1],
+ reference_audio[None, ...].to(self.device).float(),
+ )
+ full_pred_wave = self.model.decoder(z_conv)
+
+ os.makedirs(output_dir, exist_ok=True)
+ source_name = source.split("/")[-1].split(".")[0]
+ reference_name = reference.split("/")[-1].split(".")[0]
+ torchaudio.save(
+ f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
+ full_pred_wave[0].cpu(),
+ self.cfg.preprocess_params.sr,
+ )
+
+ print(
+ "Voice conversion results saved as: ",
+ f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
+ )
diff --git a/models/codec/facodec/facodec_trainer.py b/models/codec/facodec/facodec_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0b685739def8c036f319ce76b7dc7b827dba8e
--- /dev/null
+++ b/models/codec/facodec/facodec_trainer.py
@@ -0,0 +1,776 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+import random
+from pathlib import Path
+import re
+import glob
+
+import accelerate
+import json
+import numpy as np
+import torch
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+
+from accelerate.logging import get_logger
+
+from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
+from models.codec.codec_sampler import build_samplers
+from models.codec.codec_trainer import CodecTrainer
+
+from modules.dac.nn.loss import (
+ MultiScaleSTFTLoss,
+ MelSpectrogramLoss,
+ GANLoss,
+ L1Loss,
+ FocalLoss,
+)
+from audiotools import AudioSignal
+
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+
+try:
+ import nemo.collections.asr as nemo_asr
+except ImportError:
+ print(
+ "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
+ )
+ nemo_asr = None
+
+from models.codec.facodec.modules.commons import (
+ build_model,
+ load_checkpoint,
+ load_F0_models,
+ log_norm,
+)
+from models.codec.facodec.optimizer import build_optimizer
+
+
+class FAcodecTrainer(CodecTrainer):
+ def __init__(self, args, cfg):
+ super().__init__()
+
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # Init accelerator
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Init logger
+ with self.accelerator.main_process_first():
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
+
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Init training status
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check potential erorrs
+ if self.accelerator.is_main_process:
+ self._check_basic_configs()
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # Build dataloader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # Build model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ for _, model in self.model.items():
+ self.logger.debug(model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
+
+ # Build optimizers and schedulers
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # Build helper models
+ with self.accelerator.main_process_first():
+ self.logger.info("Building helper models...")
+ start = time.monotonic_ns()
+ self._built_helper_model()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building helper models done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # Accelerator preparing
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ for k in self.model:
+ self.model[k] = self.accelerator.prepare(self.model[k])
+ for k, v in self.optimizer.optimizers.items():
+ self.optimizer.optimizers[k] = self.accelerator.prepare(
+ self.optimizer.optimizers[k]
+ )
+ self.optimizer.schedulers[k] = self.accelerator.prepare(
+ self.optimizer.schedulers[k]
+ )
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # Build criterions
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterions = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume checkpoints
+ with self.accelerator.main_process_first():
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if args.resume_type:
+ self.logger.info("Resuming from checkpoint...")
+ start = time.monotonic_ns()
+ ckpt_path = Path(args.checkpoint)
+ if self._is_valid_pattern(ckpt_path.parts[-1]):
+ ckpt_path = self._load_model(args.checkpoint, args.resume_type)
+ else:
+ ckpt_path = self._load_model(
+ args.checkpoint, resume_type=args.resume_type
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
+ )
+
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Save config
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+ def _build_dataset(self):
+ return FAcodecDataset, FAcodecCollator
+
+ def _build_criterion(self):
+ criterions = dict()
+ stft_criterion = MultiScaleSTFTLoss()
+ mel_criterion = MelSpectrogramLoss(
+ n_mels=[5, 10, 20, 40, 80, 160, 320],
+ window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
+ mel_fmin=[0, 0, 0, 0, 0, 0, 0],
+ mel_fmax=[None, None, None, None, None, None, None],
+ pow=1.0,
+ mag_weight=0.0,
+ clamp_eps=1e-5,
+ )
+ content_criterion = FocalLoss(gamma=2)
+ l1_criterion = L1Loss()
+ criterions["stft"] = stft_criterion
+ criterions["mel"] = mel_criterion
+ criterions["l1"] = l1_criterion
+ criterions["content"] = content_criterion
+
+ return criterions
+
+ def _build_model(self):
+ model = build_model(self.cfg.model_params)
+ _ = [model[key].to(self.accelerator.device) for key in model]
+ return model
+
+ def _built_helper_model(self):
+ device = self.accelerator.device
+ self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)
+
+ # load model and processor
+ self.w2v_processor = Wav2Vec2Processor.from_pretrained(
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
+ )
+ self.w2v_model = Wav2Vec2ForCTC.from_pretrained(
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
+ ).to(device)
+ self.w2v_model.eval()
+
+ if nemo_asr is None:
+ self.speaker_model = None
+ else:
+ self.speaker_model = (
+ nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
+ "nvidia/speakerverification_en_titanet_large"
+ )
+ )
+ self.speaker_model = self.speaker_model.to(device)
+ self.speaker_model.eval()
+
+ def _build_optimizer(self):
+ scheduler_params = {
+ "warmup_steps": self.cfg.loss_params.warmup_steps,
+ "base_lr": self.cfg.loss_params.base_lr,
+ }
+ optimizer = build_optimizer(
+ {key: self.model[key] for key in self.model},
+ scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
+ lr=float(scheduler_params["base_lr"]),
+ )
+
+ return optimizer
+
+ def train_loop(self):
+ """Training process"""
+ self.accelerator.wait_for_everyone()
+
+ # Dump config
+ if self.accelerator.is_main_process:
+ self._dump_cfg(self.config_save_path)
+ _ = [self.model[key].train() for key in self.model]
+ self.optimizer.zero_grad()
+
+ # Sync and start training
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ # Train and Validate
+ train_total_loss, train_losses = self._train_epoch()
+ for key, loss in train_losses.items():
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+ self.accelerator.log(
+ {
+ "Epoch/Train Total Loss": train_total_loss,
+ },
+ step=self.epoch,
+ )
+
+ # Update scheduler
+ self.accelerator.wait_for_everyone()
+
+ # Check save checkpoint interval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ run_eval |= self.run_eval[i]
+
+ # Save checkpoints
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ print("Saving..")
+ state = {
+ "net": {key: self.model[key].state_dict() for key in self.model},
+ "optimizer": self.optimizer.state_dict(),
+ "scheduler": self.optimizer.scheduler_state_dict(),
+ "iters": self.step,
+ "epoch": self.epoch,
+ }
+ save_path = os.path.join(
+ self.checkpoint_dir,
+ "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
+ )
+ torch.save(state, save_path)
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+
+ self.accelerator.wait_for_everyone()
+
+ self.epoch += 1
+
+ # Finish training
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}".format(
+ self.epoch,
+ self.step,
+ ),
+ )
+ print("Saving..")
+ state = {
+ "net": {key: self.model[key].state_dict() for key in self.model},
+ "optimizer": self.optimizer.state_dict(),
+ "scheduler": self.optimizer.scheduler_state_dict(),
+ "iters": self.step,
+ "epoch": self.epoch,
+ }
+ save_path = os.path.join(
+ self.checkpoint_dir,
+ "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
+ )
+ torch.save(state, save_path)
+
+ def _train_epoch(self):
+ """Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ _ = [self.model[key].train() for key in self.model]
+
+ epoch_losses: dict = {}
+ epoch_total_loss: int = 0
+
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Get losses
+ total_loss, losses = self._train_step(batch)
+ self.batch_count += 1
+
+ # Log info
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ self.accelerator.log(
+ {
+ "Step/Learning Rate": (
+ self.optimizer.schedulers["encoder"].get_last_lr()[0]
+ if self.step != 0
+ else 0
+ )
+ },
+ step=self.step,
+ )
+ for key, _ in losses.items():
+ self.accelerator.log(
+ {
+ "Step/Train {} Loss".format(key): losses[key],
+ },
+ step=self.step,
+ )
+
+ if not epoch_losses:
+ epoch_losses = losses
+ else:
+ for key, value in losses.items():
+ epoch_losses[key] += value
+ epoch_total_loss += total_loss
+ self.step += 1
+
+ # Get and log total losses
+ self.accelerator.wait_for_everyone()
+ epoch_total_loss = (
+ epoch_total_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+ return epoch_total_loss, epoch_losses
+
+ def _train_step(self, data):
+ """Training forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_train_epoch`` for usage.
+ """
+ # Init losses
+ train_losses = {}
+ total_loss = 0
+
+ # Use input feature to get predictions
+ data = [b.to(self.accelerator.device, non_blocking=True) for b in data]
+ waves, mels, wave_lengths, mel_input_length = data
+
+ # extract semantic latent with w2v model
+ waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
+ w2v_input = self.w2v_processor(
+ waves_16k, sampling_rate=16000, return_tensors="pt"
+ ).input_values.to(self.accelerator.device)
+ with torch.no_grad():
+ w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
+ predicted_ids = torch.argmax(w2v_outputs, dim=-1)
+ phone_ids = (
+ F.interpolate(
+ predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
+ )
+ .long()
+ .squeeze(0)
+ )
+
+ # get clips
+ mel_seg_len = min(
+ [int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
+ )
+
+ gt_mel_seg = []
+ wav_seg = []
+ w2v_seg = []
+
+ for bib in range(len(mel_input_length)):
+ mel_length = int(mel_input_length[bib].item())
+
+ random_start = (
+ np.random.randint(0, mel_length - mel_seg_len)
+ if mel_length != mel_seg_len
+ else 0
+ )
+ gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])
+
+ # w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len])
+ w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])
+
+ y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]
+
+ wav_seg.append(y.to(self.accelerator.device))
+
+ gt_mel_seg = torch.stack(gt_mel_seg).detach()
+
+ wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
+ w2v_seg = torch.stack(w2v_seg).float().detach()
+
+ with torch.no_grad():
+ real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
+ F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))
+
+ # normalize f0
+ # Remove unvoiced frames (replace with -1)
+ gt_glob_f0s = []
+ f0_targets = []
+ for bib in range(len(F0_real)):
+ voiced_indices = F0_real[bib] > 5.0
+ f0_voiced = F0_real[bib][voiced_indices]
+
+ if len(f0_voiced) != 0:
+ # Convert to log scale
+ log_f0 = f0_voiced.log2()
+
+ # Calculate mean and standard deviation
+ mean_f0 = log_f0.mean()
+ std_f0 = log_f0.std()
+
+ # Normalize the F0 sequence
+ normalized_f0 = (log_f0 - mean_f0) / std_f0
+
+ # Create the normalized F0 sequence with unvoiced frames
+ normalized_sequence = torch.zeros_like(F0_real[bib])
+ normalized_sequence[voiced_indices] = normalized_f0
+ normalized_sequence[~voiced_indices] = (
+ -10
+ ) # Assign -10 to unvoiced frames
+
+ gt_glob_f0s.append(mean_f0)
+ else:
+ normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0
+ gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device))
+
+ # f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200])
+ f0_targets.append(normalized_sequence)
+ f0_targets = torch.stack(f0_targets).to(self.accelerator.device)
+ # fill nan with -10
+ f0_targets[torch.isnan(f0_targets)] = -10.0
+ # fill inf with -10
+ f0_targets[torch.isinf(f0_targets)] = -10.0
+ # if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate
+ if self.cfg.preprocess_params.frame_rate != 80:
+ f0_targets = F.interpolate(
+ f0_targets.unsqueeze(1),
+ mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
+ mode="nearest",
+ ).squeeze(1)
+ w2v_seg = F.interpolate(
+ w2v_seg,
+ mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
+ mode="nearest",
+ )
+
+ wav_seg_input = wav_seg
+ wav_seg_target = wav_seg
+
+ z = self.model.encoder(wav_seg_input)
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
+ z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
+ )
+ preds, rev_preds = self.model.fa_predictors(quantized, timbre)
+
+ pred_wave = self.model.decoder(z)
+
+ len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
+ if len_diff > 0:
+ wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]
+
+ # discriminator loss
+ d_fake = self.model.discriminator(pred_wave.detach())
+ d_real = self.model.discriminator(wav_seg_target)
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+
+ self.optimizer.zero_grad()
+ self.accelerator.backward(loss_d)
+ grad_norm_d = torch.nn.utils.clip_grad_norm_(
+ self.model.discriminator.parameters(), 10.0
+ )
+ self.optimizer.step("discriminator")
+ self.optimizer.scheduler(key="discriminator")
+
+ # generator loss
+ signal = AudioSignal(wav_seg_target, sample_rate=24000)
+ recons = AudioSignal(pred_wave, sample_rate=24000)
+ stft_loss = self.criterions["stft"](recons, signal)
+ mel_loss = self.criterions["mel"](recons, signal)
+ waveform_loss = self.criterions["l1"](recons, signal)
+
+ d_fake = self.model.discriminator(pred_wave)
+ d_real = self.model.discriminator(wav_seg_target)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+
+ pred_f0, pred_uv = preds["f0"], preds["uv"]
+ rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]
+
+ common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
+ f0_targets = f0_targets[..., :common_min_size]
+ real_norm = real_norm[..., :common_min_size]
+
+ f0_loss = F.smooth_l1_loss(
+ f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
+ )
+ uv_loss = F.smooth_l1_loss(
+ real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
+ )
+ rev_f0_loss = (
+ F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
+ if rev_pred_f0 is not None
+ else torch.FloatTensor([0]).to(self.accelerator.device)
+ )
+ rev_uv_loss = (
+ F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
+ if rev_pred_uv is not None
+ else torch.FloatTensor([0]).to(self.accelerator.device)
+ )
+
+ tot_f0_loss = f0_loss + rev_f0_loss
+ tot_uv_loss = uv_loss + rev_uv_loss
+
+ pred_content = preds["content"]
+ rev_pred_content = rev_preds["rev_content"]
+
+ target_content_latents = w2v_seg[..., :common_min_size]
+
+ content_loss = self.criterions["content"](
+ pred_content.transpose(1, 2)[..., :common_min_size],
+ target_content_latents.long(),
+ )
+ rev_content_loss = (
+ self.criterions["content"](
+ rev_pred_content.transpose(1, 2)[..., :common_min_size],
+ target_content_latents.long(),
+ )
+ if rev_pred_content is not None
+ else torch.FloatTensor([0]).to(self.accelerator.device)
+ )
+
+ tot_content_loss = content_loss + rev_content_loss
+
+ if self.speaker_model is not None:
+ spk_logits = torch.cat(
+ [
+ self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
+ for w16, wl in zip(waves_16k, wave_lengths)
+ ],
+ dim=0,
+ )
+ spk_labels = spk_logits.argmax(dim=-1)
+ else:
+ spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(
+ self.accelerator.device
+ )
+
+ spk_pred_logits = preds["timbre"]
+ spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
+ x_spk_pred_logits = rev_preds["x_timbre"]
+
+ x_spk_loss = (
+ F.cross_entropy(x_spk_pred_logits, spk_labels)
+ if x_spk_pred_logits is not None
+ else torch.FloatTensor([0]).to(self.accelerator.device)
+ )
+
+ tot_spk_loss = spk_loss + x_spk_loss
+
+ loss_gen_all = (
+ mel_loss * 15.0
+ + loss_feature * 1.0
+ + loss_g * 1.0
+ + commitment_loss * 0.25
+ + codebook_loss * 1.0
+ + tot_f0_loss * 1.0
+ + tot_uv_loss * 1.0
+ + tot_content_loss * 5.0
+ + tot_spk_loss * 5.0
+ )
+
+ self.optimizer.zero_grad()
+ self.accelerator.backward(loss_gen_all)
+
+ with torch.no_grad():
+ total_loss = loss_gen_all.item()
+ train_losses["stft"] = stft_loss.item()
+ train_losses["mel"] = mel_loss.item()
+ train_losses["l1"] = waveform_loss.item()
+ train_losses["f0"] = f0_loss.item()
+ train_losses["uv"] = uv_loss.item()
+ train_losses["content"] = content_loss.item()
+ train_losses["speaker"] = spk_loss.item()
+ train_losses["rev_f0"] = rev_f0_loss.item()
+ train_losses["rev_uv"] = rev_uv_loss.item()
+ train_losses["rev_content"] = rev_content_loss.item()
+ train_losses["rev_speaker"] = x_spk_loss.item()
+
+ train_losses["feature"] = loss_feature.item()
+ train_losses["generator"] = loss_g.item()
+ train_losses["commitment"] = commitment_loss.item()
+ train_losses["codebook"] = codebook_loss.item()
+
+ # discriminators
+ train_losses["discriminator"] = loss_d.item()
+
+ return total_loss, train_losses
+
+ def _inference(self, eval_wave):
+ """Inference during training for test audios."""
+ z = self.model.encoder(
+ eval_wave[None, None, ...].to(self.accelerator.device).float()
+ )
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
+ z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
+ )
+ full_pred_wave = self.model.decoder(z)
+ return full_pred_wave[0]
+
+ def _load_model(self, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if resume_type == "resume":
+ if checkpoint_path is None:
+ available_checkpoints = glob.glob(
+ os.path.join(self.checkpoint_dir, "FAcodc_epoch_*_step_*.pth")
+ )
+ # find the checkpoint that has the highest step number
+ latest_checkpoint = max(
+ available_checkpoints,
+ key=lambda x: int(x.split("_")[-1].split(".")[0]),
+ )
+ earliest_checkpoint = min(
+ available_checkpoints,
+ key=lambda x: int(x.split("_")[-1].split(".")[0]),
+ )
+ # delete the earliest checkpoint
+ if (
+ earliest_checkpoint != latest_checkpoint
+ and self.accelerator.is_main_process
+ and len(available_checkpoints) > 4
+ ):
+ os.remove(earliest_checkpoint)
+ print(f"Removed {earliest_checkpoint}")
+ else:
+ latest_checkpoint = checkpoint_path
+
+ self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
+ self.model,
+ self.optimizer,
+ latest_checkpoint,
+ load_only_params=False,
+ ignore_modules=[],
+ is_distributed=self.accelerator.num_processes > 1,
+ )
+
+ else:
+ raise ValueError("Invalid resume type")
+ return checkpoint_path
+
+ def _count_parameters(self):
+ total_num = sum(
+ sum(p.numel() for p in self.model[key].parameters()) for key in self.model
+ )
+ # trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+ return total_num
diff --git a/models/codec/facodec/modules/JDC/__init__.py b/models/codec/facodec/modules/JDC/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/models/codec/facodec/modules/JDC/__init__.py
@@ -0,0 +1 @@
+
diff --git a/models/codec/facodec/modules/JDC/bst.t7 b/models/codec/facodec/modules/JDC/bst.t7
new file mode 100644
index 0000000000000000000000000000000000000000..5aa5a7b89991a3ecce2fd13447d6cb65740d2a9b
--- /dev/null
+++ b/models/codec/facodec/modules/JDC/bst.t7
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
+size 21029926
diff --git a/models/codec/facodec/modules/JDC/model.py b/models/codec/facodec/modules/JDC/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..601ec960795c76be84417bb4e466ac7fe7754cb3
--- /dev/null
+++ b/models/codec/facodec/modules/JDC/model.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
+
+"""
+Implementation of model from:
+Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+Convolutional Recurrent Neural Networks" (2019)
+Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+"""
+import torch
+from torch import nn
+
+
+class JDCNet(nn.Module):
+ """
+ Joint Detection and Classification Network model for singing voice melody.
+ """
+
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
+ super().__init__()
+ self.num_class = num_class
+
+ # input = (b, 1, 31, 513), b = batch size
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
+ ), # out: (b, 64, 31, 513)
+ nn.BatchNorm2d(num_features=64),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
+ )
+
+ # res blocks
+ self.res_block1 = ResBlock(
+ in_channels=64, out_channels=128
+ ) # (b, 128, 31, 128)
+ self.res_block2 = ResBlock(
+ in_channels=128, out_channels=192
+ ) # (b, 192, 31, 32)
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
+
+ # pool block
+ self.pool_block = nn.Sequential(
+ nn.BatchNorm2d(num_features=256),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
+ nn.Dropout(p=0.2),
+ )
+
+ # maxpool layers (for auxiliary network inputs)
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
+
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
+ self.detector_conv = nn.Sequential(
+ nn.Conv2d(640, 256, 1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.Dropout(p=0.2),
+ )
+
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+ self.bilstm_classifier = nn.LSTM(
+ input_size=512, hidden_size=256, batch_first=True, bidirectional=True
+ ) # (b, 31, 512)
+
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+ self.bilstm_detector = nn.LSTM(
+ input_size=512, hidden_size=256, batch_first=True, bidirectional=True
+ ) # (b, 31, 512)
+
+ # input: (b * 31, 512)
+ self.classifier = nn.Linear(
+ in_features=512, out_features=self.num_class
+ ) # (b * 31, num_class)
+
+ # input: (b * 31, 512)
+ self.detector = nn.Linear(
+ in_features=512, out_features=2
+ ) # (b * 31, 2) - binary classifier
+
+ # initialize weights
+ self.apply(self.init_weights)
+
+ def get_feature_GAN(self, x):
+ seq_len = x.shape[-2]
+ x = x.float().transpose(-1, -2)
+
+ convblock_out = self.conv_block(x)
+
+ resblock1_out = self.res_block1(convblock_out)
+ resblock2_out = self.res_block2(resblock1_out)
+ resblock3_out = self.res_block3(resblock2_out)
+ poolblock_out = self.pool_block[0](resblock3_out)
+ poolblock_out = self.pool_block[1](poolblock_out)
+
+ return poolblock_out.transpose(-1, -2)
+
+ def get_feature(self, x):
+ seq_len = x.shape[-2]
+ x = x.float().transpose(-1, -2)
+
+ convblock_out = self.conv_block(x)
+
+ resblock1_out = self.res_block1(convblock_out)
+ resblock2_out = self.res_block2(resblock1_out)
+ resblock3_out = self.res_block3(resblock2_out)
+ poolblock_out = self.pool_block[0](resblock3_out)
+ poolblock_out = self.pool_block[1](poolblock_out)
+
+ return self.pool_block[2](poolblock_out)
+
+ def forward(self, x):
+ """
+ Returns:
+ classification_prediction, detection_prediction
+ sizes: (b, 31, 722), (b, 31, 2)
+ """
+ ###############################
+ # forward pass for classifier #
+ ###############################
+ seq_len = x.shape[-1]
+ x = x.float().transpose(-1, -2)
+
+ convblock_out = self.conv_block(x)
+
+ resblock1_out = self.res_block1(convblock_out)
+ resblock2_out = self.res_block2(resblock1_out)
+ resblock3_out = self.res_block3(resblock2_out)
+
+ poolblock_out = self.pool_block[0](resblock3_out)
+ poolblock_out = self.pool_block[1](poolblock_out)
+ GAN_feature = poolblock_out.transpose(-1, -2)
+ poolblock_out = self.pool_block[2](poolblock_out)
+
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+ classifier_out = (
+ poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
+ )
+ classifier_out, _ = self.bilstm_classifier(
+ classifier_out
+ ) # ignore the hidden states
+
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
+ classifier_out = self.classifier(classifier_out)
+ classifier_out = classifier_out.view(
+ (-1, seq_len, self.num_class)
+ ) # (b, 31, num_class)
+
+ # sizes: (b, 31, 722), (b, 31, 2)
+ # classifier output consists of predicted pitch classes per frame
+ # detector output consists of: (isvoice, notvoice) estimates per frame
+ return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
+
+ @staticmethod
+ def init_weights(m):
+ if isinstance(m, nn.Linear):
+ nn.init.kaiming_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+ for p in m.parameters():
+ if p.data is None:
+ continue
+
+ if len(p.shape) >= 2:
+ nn.init.orthogonal_(p.data)
+ else:
+ nn.init.normal_(p.data)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
+ super().__init__()
+ self.downsample = in_channels != out_channels
+
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+ self.pre_conv = nn.Sequential(
+ nn.BatchNorm2d(num_features=in_channels),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
+ )
+
+ # conv layers
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels),
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+ )
+
+ # 1 x 1 convolution layer to match the feature dimensions
+ self.conv1by1 = None
+ if self.downsample:
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+
+ def forward(self, x):
+ x = self.pre_conv(x)
+ if self.downsample:
+ x = self.conv(x) + self.conv1by1(x)
+ else:
+ x = self.conv(x) + x
+ return x
diff --git a/models/codec/facodec/modules/attentions.py b/models/codec/facodec/modules/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c29854fd97cefc66125301003198b2da6ea1e9be
--- /dev/null
+++ b/models/codec/facodec/modules/attentions.py
@@ -0,0 +1,437 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py
+
+import copy
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from . import commons
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ window_size=4,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ window_size=window_size,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=True,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+
+ self.drop = nn.Dropout(p_dropout)
+ self.self_attn_layers = nn.ModuleList()
+ self.norm_layers_0 = nn.ModuleList()
+ self.encdec_attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.self_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ proximal_bias=proximal_bias,
+ proximal_init=proximal_init,
+ )
+ )
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
+ self.encdec_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ causal=True,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, h, h_mask):
+ """
+ x: decoder input
+ h: encoder output
+ """
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+ device=x.device, dtype=x.dtype
+ )
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_0[i](x + y)
+
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ p_dropout=0.0,
+ window_size=None,
+ heads_share=True,
+ block_length=None,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert (
+ t_s == t_t
+ ), "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(
+ query / math.sqrt(self.k_channels), key_relative_embeddings
+ )
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert (
+ t_s == t_t
+ ), "Local attention is only available for self-attention."
+ block_mask = (
+ torch.ones_like(scores)
+ .triu(-self.block_length)
+ .tril(self.block_length)
+ )
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(
+ self.emb_rel_v, t_s
+ )
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = (
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+ )
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # padd along column
+ x = F.pad(
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+ )
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=0.0,
+ activation=None,
+ causal=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
diff --git a/models/codec/facodec/modules/commons.py b/models/codec/facodec/modules/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..89baaf4b06426595b7be1ab9ca4d94c5c99779d6
--- /dev/null
+++ b/models/codec/facodec/modules/commons.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+import os.path
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from munch import Munch
+import json
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def intersperse(lst, item):
+ result = [item] * (len(lst) * 2 + 1)
+ result[1::2] = lst
+ return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += (
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+ )
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def slice_segments_audio(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
+ dtype=torch.long
+ )
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+ num_timescales - 1
+ )
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+ )
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+ device = duration.device
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2, 3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1.0 / norm_type)
+ return total_norm
+
+
+def log_norm(x, mean=-4, std=4, dim=2):
+ """
+ normalized log mel -> mel -> norm -> log(norm)
+ """
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+ return x
+
+
+from huggingface_hub import hf_hub_download
+
+
+def load_F0_models(path):
+ # load F0 model
+ from .JDC.model import JDCNet
+
+ F0_model = JDCNet(num_class=1, seq_len=192)
+ if not os.path.exists(path):
+ path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
+ params = torch.load(path, map_location="cpu")["net"]
+ F0_model.load_state_dict(params)
+ _ = F0_model.train()
+
+ return F0_model
+
+
+# Generators
+from modules.dac.model.dac import Encoder, Decoder
+from .quantize import FAquantizer, FApredictors
+
+# Discriminators
+from modules.dac.model.discriminator import Discriminator
+
+
+def build_model(args):
+ encoder = Encoder(
+ d_model=args.DAC.encoder_dim,
+ strides=args.DAC.encoder_rates,
+ d_latent=1024,
+ causal=args.causal,
+ lstm=args.lstm,
+ )
+
+ quantizer = FAquantizer(
+ in_dim=1024,
+ n_p_codebooks=1,
+ n_c_codebooks=args.n_c_codebooks,
+ n_t_codebooks=2,
+ n_r_codebooks=3,
+ codebook_size=1024,
+ codebook_dim=8,
+ quantizer_dropout=0.5,
+ causal=args.causal,
+ separate_prosody_encoder=args.separate_prosody_encoder,
+ timbre_norm=args.timbre_norm,
+ )
+
+ fa_predictors = FApredictors(
+ in_dim=1024,
+ use_gr_content_f0=args.use_gr_content_f0,
+ use_gr_prosody_phone=args.use_gr_prosody_phone,
+ use_gr_residual_f0=True,
+ use_gr_residual_phone=True,
+ use_gr_timbre_content=True,
+ use_gr_timbre_prosody=args.use_gr_timbre_prosody,
+ use_gr_x_timbre=True,
+ norm_f0=args.norm_f0,
+ timbre_norm=args.timbre_norm,
+ use_gr_content_global_f0=args.use_gr_content_global_f0,
+ )
+
+ decoder = Decoder(
+ input_channel=1024,
+ channels=args.DAC.decoder_dim,
+ rates=args.DAC.decoder_rates,
+ causal=args.causal,
+ lstm=args.lstm,
+ )
+
+ discriminator = Discriminator(
+ rates=[],
+ periods=[2, 3, 5, 7, 11],
+ fft_sizes=[2048, 1024, 512],
+ sample_rate=args.DAC.sr,
+ bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
+ )
+
+ nets = Munch(
+ encoder=encoder,
+ quantizer=quantizer,
+ decoder=decoder,
+ discriminator=discriminator,
+ fa_predictors=fa_predictors,
+ )
+
+ return nets
+
+
+def load_checkpoint(
+ model,
+ optimizer,
+ path,
+ load_only_params=True,
+ ignore_modules=[],
+ is_distributed=False,
+):
+ state = torch.load(path, map_location="cpu")
+ params = state["net"]
+ for key in model:
+ if key in params and key not in ignore_modules:
+ if not is_distributed:
+ # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
+ for k in list(params[key].keys()):
+ if k.startswith("module."):
+ params[key][k[len("module.") :]] = params[key][k]
+ del params[key][k]
+ print("%s loaded" % key)
+ model[key].load_state_dict(params[key], strict=True)
+ _ = [model[key].eval() for key in model]
+
+ if not load_only_params:
+ epoch = state["epoch"] + 1
+ iters = state["iters"]
+ optimizer.load_state_dict(state["optimizer"])
+ optimizer.load_scheduler_state_dict(state["scheduler"])
+
+ else:
+ epoch = state["epoch"] + 1
+ iters = state["iters"]
+
+ return model, optimizer, epoch, iters
+
+
+def recursive_munch(d):
+ if isinstance(d, dict):
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
+ elif isinstance(d, list):
+ return [recursive_munch(v) for v in d]
+ else:
+ return d
diff --git a/models/codec/facodec/modules/gradient_reversal.py b/models/codec/facodec/modules/gradient_reversal.py
new file mode 100644
index 0000000000000000000000000000000000000000..d09396ea20c653b2a443e144ab429f534ce033fd
--- /dev/null
+++ b/models/codec/facodec/modules/gradient_reversal.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.autograd import Function
+import torch
+from torch import nn
+
+
+class GradientReversal(Function):
+ @staticmethod
+ def forward(ctx, x, alpha):
+ ctx.save_for_backward(x, alpha)
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = None
+ _, alpha = ctx.saved_tensors
+ if ctx.needs_input_grad[0]:
+ grad_input = -alpha * grad_output
+ return grad_input, None
+
+
+revgrad = GradientReversal.apply
+
+
+class GradientReversal(nn.Module):
+ def __init__(self, alpha):
+ super().__init__()
+ self.alpha = torch.tensor(alpha, requires_grad=False)
+
+ def forward(self, x):
+ return revgrad(x, self.alpha)
diff --git a/models/codec/facodec/modules/layers.py b/models/codec/facodec/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..186cbe7bef653fe7cbd6dbd40f38624457b8ecfa
--- /dev/null
+++ b/models/codec/facodec/modules/layers.py
@@ -0,0 +1,460 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+from torch import nn
+from typing import Optional, Any
+from torch import Tensor
+import torch.nn.functional as F
+import torchaudio
+import torchaudio.functional as audio_F
+
+import random
+
+random.seed(0)
+
+
+def _get_activation_fn(activ):
+ if activ == "relu":
+ return nn.ReLU()
+ elif activ == "lrelu":
+ return nn.LeakyReLU(0.2)
+ elif activ == "swish":
+ return lambda x: x * torch.sigmoid(x)
+ else:
+ raise RuntimeError(
+ "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
+ )
+
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
+ )
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=None,
+ dilation=1,
+ bias=True,
+ w_init_gain="linear",
+ param=None,
+ ):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert kernel_size % 2 == 1
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
+ )
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+class CausualConv(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=1,
+ dilation=1,
+ bias=True,
+ w_init_gain="linear",
+ param=None,
+ ):
+ super(CausualConv, self).__init__()
+ if padding is None:
+ assert kernel_size % 2 == 1
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
+ else:
+ self.padding = padding * 2
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
+ )
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = x[:, :, : -self.padding]
+ return x
+
+
+class CausualBlock(nn.Module):
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
+ super(CausualBlock, self).__init__()
+ self.blocks = nn.ModuleList(
+ [
+ self._get_conv(
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
+ )
+ for i in range(n_conv)
+ ]
+ )
+
+ def forward(self, x):
+ for block in self.blocks:
+ res = x
+ x = block(x)
+ x += res
+ return x
+
+ def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
+ layers = [
+ CausualConv(
+ hidden_dim,
+ hidden_dim,
+ kernel_size=3,
+ padding=dilation,
+ dilation=dilation,
+ ),
+ _get_activation_fn(activ),
+ nn.BatchNorm1d(hidden_dim),
+ nn.Dropout(p=dropout_p),
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
+ _get_activation_fn(activ),
+ nn.Dropout(p=dropout_p),
+ ]
+ return nn.Sequential(*layers)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
+ super().__init__()
+ self._n_groups = 8
+ self.blocks = nn.ModuleList(
+ [
+ self._get_conv(
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
+ )
+ for i in range(n_conv)
+ ]
+ )
+
+ def forward(self, x):
+ for block in self.blocks:
+ res = x
+ x = block(x)
+ x += res
+ return x
+
+ def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
+ layers = [
+ ConvNorm(
+ hidden_dim,
+ hidden_dim,
+ kernel_size=3,
+ padding=dilation,
+ dilation=dilation,
+ ),
+ _get_activation_fn(activ),
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
+ nn.Dropout(p=dropout_p),
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
+ _get_activation_fn(activ),
+ nn.Dropout(p=dropout_p),
+ ]
+ return nn.Sequential(*layers)
+
+
+class LocationLayer(nn.Module):
+ def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
+ super(LocationLayer, self).__init__()
+ padding = int((attention_kernel_size - 1) / 2)
+ self.location_conv = ConvNorm(
+ 2,
+ attention_n_filters,
+ kernel_size=attention_kernel_size,
+ padding=padding,
+ bias=False,
+ stride=1,
+ dilation=1,
+ )
+ self.location_dense = LinearNorm(
+ attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
+ )
+
+ def forward(self, attention_weights_cat):
+ processed_attention = self.location_conv(attention_weights_cat)
+ processed_attention = processed_attention.transpose(1, 2)
+ processed_attention = self.location_dense(processed_attention)
+ return processed_attention
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ attention_rnn_dim,
+ embedding_dim,
+ attention_dim,
+ attention_location_n_filters,
+ attention_location_kernel_size,
+ ):
+ super(Attention, self).__init__()
+ self.query_layer = LinearNorm(
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
+ )
+ self.memory_layer = LinearNorm(
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
+ )
+ self.v = LinearNorm(attention_dim, 1, bias=False)
+ self.location_layer = LocationLayer(
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
+ )
+ self.score_mask_value = -float("inf")
+
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
+ """
+ PARAMS
+ ------
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
+ RETURNS
+ -------
+ alignment (batch, max_time)
+ """
+
+ processed_query = self.query_layer(query.unsqueeze(1))
+ processed_attention_weights = self.location_layer(attention_weights_cat)
+ energies = self.v(
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
+ )
+
+ energies = energies.squeeze(-1)
+ return energies
+
+ def forward(
+ self,
+ attention_hidden_state,
+ memory,
+ processed_memory,
+ attention_weights_cat,
+ mask,
+ ):
+ """
+ PARAMS
+ ------
+ attention_hidden_state: attention rnn last output
+ memory: encoder outputs
+ processed_memory: processed encoder outputs
+ attention_weights_cat: previous and cummulative attention weights
+ mask: binary mask for padded data
+ """
+ alignment = self.get_alignment_energies(
+ attention_hidden_state, processed_memory, attention_weights_cat
+ )
+
+ if mask is not None:
+ alignment.data.masked_fill_(mask, self.score_mask_value)
+
+ attention_weights = F.softmax(alignment, dim=1)
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+ attention_context = attention_context.squeeze(1)
+
+ return attention_context, attention_weights
+
+
+class ForwardAttentionV2(nn.Module):
+ def __init__(
+ self,
+ attention_rnn_dim,
+ embedding_dim,
+ attention_dim,
+ attention_location_n_filters,
+ attention_location_kernel_size,
+ ):
+ super(ForwardAttentionV2, self).__init__()
+ self.query_layer = LinearNorm(
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
+ )
+ self.memory_layer = LinearNorm(
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
+ )
+ self.v = LinearNorm(attention_dim, 1, bias=False)
+ self.location_layer = LocationLayer(
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
+ )
+ self.score_mask_value = -float(1e20)
+
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
+ """
+ PARAMS
+ ------
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
+ RETURNS
+ -------
+ alignment (batch, max_time)
+ """
+
+ processed_query = self.query_layer(query.unsqueeze(1))
+ processed_attention_weights = self.location_layer(attention_weights_cat)
+ energies = self.v(
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
+ )
+
+ energies = energies.squeeze(-1)
+ return energies
+
+ def forward(
+ self,
+ attention_hidden_state,
+ memory,
+ processed_memory,
+ attention_weights_cat,
+ mask,
+ log_alpha,
+ ):
+ """
+ PARAMS
+ ------
+ attention_hidden_state: attention rnn last output
+ memory: encoder outputs
+ processed_memory: processed encoder outputs
+ attention_weights_cat: previous and cummulative attention weights
+ mask: binary mask for padded data
+ """
+ log_energy = self.get_alignment_energies(
+ attention_hidden_state, processed_memory, attention_weights_cat
+ )
+
+ # log_energy =
+
+ if mask is not None:
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
+
+ # attention_weights = F.softmax(alignment, dim=1)
+
+ # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
+ # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
+
+ # log_total_score = log_alpha + content_score
+
+ # previous_attention_weights = attention_weights_cat[:,0,:]
+
+ log_alpha_shift_padded = []
+ max_time = log_energy.size(1)
+ for sft in range(2):
+ shifted = log_alpha[:, : max_time - sft]
+ shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
+
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
+
+ log_alpha_new = biased + log_energy
+
+ attention_weights = F.softmax(log_alpha_new, dim=1)
+
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
+ attention_context = attention_context.squeeze(1)
+
+ return attention_context, attention_weights, log_alpha_new
+
+
+class PhaseShuffle2d(nn.Module):
+ def __init__(self, n=2):
+ super(PhaseShuffle2d, self).__init__()
+ self.n = n
+ self.random = random.Random(1)
+
+ def forward(self, x, move=None):
+ # x.size = (B, C, M, L)
+ if move is None:
+ move = self.random.randint(-self.n, self.n)
+
+ if move == 0:
+ return x
+ else:
+ left = x[:, :, :, :move]
+ right = x[:, :, :, move:]
+ shuffled = torch.cat([right, left], dim=3)
+ return shuffled
+
+
+class PhaseShuffle1d(nn.Module):
+ def __init__(self, n=2):
+ super(PhaseShuffle1d, self).__init__()
+ self.n = n
+ self.random = random.Random(1)
+
+ def forward(self, x, move=None):
+ # x.size = (B, C, M, L)
+ if move is None:
+ move = self.random.randint(-self.n, self.n)
+
+ if move == 0:
+ return x
+ else:
+ left = x[:, :, :move]
+ right = x[:, :, move:]
+ shuffled = torch.cat([right, left], dim=2)
+
+ return shuffled
+
+
+class MFCC(nn.Module):
+ def __init__(self, n_mfcc=40, n_mels=80):
+ super(MFCC, self).__init__()
+ self.n_mfcc = n_mfcc
+ self.n_mels = n_mels
+ self.norm = "ortho"
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
+ self.register_buffer("dct_mat", dct_mat)
+
+ def forward(self, mel_specgram):
+ if len(mel_specgram.shape) == 2:
+ mel_specgram = mel_specgram.unsqueeze(0)
+ unsqueezed = True
+ else:
+ unsqueezed = False
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
+ # -> (channel, time, n_mfcc).tranpose(...)
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
+
+ # unpack batch
+ if unsqueezed:
+ mfcc = mfcc.squeeze(0)
+ return mfcc
diff --git a/models/codec/facodec/modules/quantize.py b/models/codec/facodec/modules/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9cb55be173ff1aa7a15cccab7c7ac51fbd042c1
--- /dev/null
+++ b/models/codec/facodec/modules/quantize.py
@@ -0,0 +1,741 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from modules.dac.nn.quantize import ResidualVectorQuantize
+from torch import nn
+from .wavenet import WN
+from .style_encoder import StyleEncoder
+from .gradient_reversal import GradientReversal
+import torch
+import torchaudio
+import torchaudio.functional as audio_F
+import numpy as np
+from ..alias_free_torch import *
+from torch.nn.utils import weight_norm
+from torch import nn, sin, pow
+from einops.layers.torch import Rearrange
+from modules.dac.model.encodec import SConv1d
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class SnakeBeta(nn.Module):
+ """
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = snakebeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ """
+
+ def __init__(
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+ ):
+ """
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ """
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ """
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta := x + 1/b * sin^2 (xa)
+ """
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1):
+ super().__init__()
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ return x + self.block(x)
+
+
+class CNNLSTM(nn.Module):
+ def __init__(self, indim, outdim, head, global_pred=False):
+ super().__init__()
+ self.global_pred = global_pred
+ self.model = nn.Sequential(
+ ResidualUnit(indim, dilation=1),
+ ResidualUnit(indim, dilation=2),
+ ResidualUnit(indim, dilation=3),
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
+ Rearrange("b c t -> b t c"),
+ )
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
+
+ def forward(self, x):
+ # x: [B, C, T]
+ x = self.model(x)
+ if self.global_pred:
+ x = torch.mean(x, dim=1, keepdim=False)
+ outs = [head(x) for head in self.heads]
+ return outs
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class MFCC(nn.Module):
+ def __init__(self, n_mfcc=40, n_mels=80):
+ super(MFCC, self).__init__()
+ self.n_mfcc = n_mfcc
+ self.n_mels = n_mels
+ self.norm = "ortho"
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
+ self.register_buffer("dct_mat", dct_mat)
+
+ def forward(self, mel_specgram):
+ if len(mel_specgram.shape) == 2:
+ mel_specgram = mel_specgram.unsqueeze(0)
+ unsqueezed = True
+ else:
+ unsqueezed = False
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
+ # -> (channel, time, n_mfcc).tranpose(...)
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
+
+ # unpack batch
+ if unsqueezed:
+ mfcc = mfcc.squeeze(0)
+ return mfcc
+
+
+class FAquantizer(nn.Module):
+ def __init__(
+ self,
+ in_dim=1024,
+ n_p_codebooks=1,
+ n_c_codebooks=2,
+ n_t_codebooks=2,
+ n_r_codebooks=3,
+ codebook_size=1024,
+ codebook_dim=8,
+ quantizer_dropout=0.5,
+ causal=False,
+ separate_prosody_encoder=False,
+ timbre_norm=False,
+ ):
+ super(FAquantizer, self).__init__()
+ conv1d_type = SConv1d # if causal else nn.Conv1d
+ self.prosody_quantizer = ResidualVectorQuantize(
+ input_dim=in_dim,
+ n_codebooks=n_p_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+
+ self.content_quantizer = ResidualVectorQuantize(
+ input_dim=in_dim,
+ n_codebooks=n_c_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+
+ if not timbre_norm:
+ self.timbre_quantizer = ResidualVectorQuantize(
+ input_dim=in_dim,
+ n_codebooks=n_t_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+ else:
+ self.timbre_encoder = StyleEncoder(
+ in_dim=80, hidden_dim=512, out_dim=in_dim
+ )
+ self.timbre_linear = nn.Linear(1024, 1024 * 2)
+ self.timbre_linear.bias.data[:1024] = 1
+ self.timbre_linear.bias.data[1024:] = 0
+ self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)
+
+ self.residual_quantizer = ResidualVectorQuantize(
+ input_dim=in_dim,
+ n_codebooks=n_r_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+
+ if separate_prosody_encoder:
+ self.melspec_linear = conv1d_type(
+ in_channels=20, out_channels=256, kernel_size=1, causal=causal
+ )
+ self.melspec_encoder = WN(
+ hidden_channels=256,
+ kernel_size=5,
+ dilation_rate=1,
+ n_layers=8,
+ gin_channels=0,
+ p_dropout=0.2,
+ causal=causal,
+ )
+ self.melspec_linear2 = conv1d_type(
+ in_channels=256, out_channels=1024, kernel_size=1, causal=causal
+ )
+ else:
+ pass
+ self.separate_prosody_encoder = separate_prosody_encoder
+
+ self.prob_random_mask_residual = 0.75
+
+ SPECT_PARAMS = {
+ "n_fft": 2048,
+ "win_length": 1200,
+ "hop_length": 300,
+ }
+ MEL_PARAMS = {
+ "n_mels": 80,
+ }
+
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
+ n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
+ )
+ self.mel_mean, self.mel_std = -4, 4
+ self.frame_rate = 24000 / 300
+ self.hop_length = 300
+
+ self.is_timbre_norm = timbre_norm
+ if timbre_norm:
+ self.forward = self.forward_v2
+
+ def preprocess(self, wave_tensor, n_bins=20):
+ mel_tensor = self.to_mel(wave_tensor.squeeze(1))
+ mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
+ return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]
+
+ @torch.no_grad()
+ def decode(self, codes):
+ code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
+
+ z_c = self.content_quantizer.from_codes(code_c)[0]
+ z_p = self.prosody_quantizer.from_codes(code_p)[0]
+ z_t = self.timbre_quantizer.from_codes(code_t)[0]
+
+ z = z_c + z_p + z_t
+
+ return z, [z_c, z_p, z_t]
+
+ @torch.no_grad()
+ def encode(self, x, wave_segments, n_c=1):
+ outs = 0
+ if self.separate_prosody_encoder:
+ prosody_feature = self.preprocess(wave_segments)
+
+ f0_input = prosody_feature # (B, T, 20)
+ f0_input = self.melspec_linear(f0_input)
+ f0_input = self.melspec_encoder(
+ f0_input,
+ torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
+ .to(f0_input.device)
+ .bool(),
+ )
+ f0_input = self.melspec_linear2(f0_input)
+
+ common_min_size = min(f0_input.size(2), x.size(2))
+ f0_input = f0_input[:, :, :common_min_size]
+
+ x = x[:, :, :common_min_size]
+
+ (
+ z_p,
+ codes_p,
+ latents_p,
+ commitment_loss_p,
+ codebook_loss_p,
+ ) = self.prosody_quantizer(f0_input, 1)
+ outs += z_p.detach()
+ else:
+ (
+ z_p,
+ codes_p,
+ latents_p,
+ commitment_loss_p,
+ codebook_loss_p,
+ ) = self.prosody_quantizer(x, 1)
+ outs += z_p.detach()
+
+ (
+ z_c,
+ codes_c,
+ latents_c,
+ commitment_loss_c,
+ codebook_loss_c,
+ ) = self.content_quantizer(x, n_c)
+ outs += z_c.detach()
+
+ timbre_residual_feature = x - z_p.detach() - z_c.detach()
+
+ (
+ z_t,
+ codes_t,
+ latents_t,
+ commitment_loss_t,
+ codebook_loss_t,
+ ) = self.timbre_quantizer(timbre_residual_feature, 2)
+ outs += z_t # we should not detach timbre
+
+ residual_feature = timbre_residual_feature - z_t
+
+ (
+ z_r,
+ codes_r,
+ latents_r,
+ commitment_loss_r,
+ codebook_loss_r,
+ ) = self.residual_quantizer(residual_feature, 3)
+
+ return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
+
+ def forward(
+ self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
+ ):
+ # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
+ # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
+ outs = 0
+ if self.separate_prosody_encoder:
+ prosody_feature = self.preprocess(wave_segments)
+
+ f0_input = prosody_feature # (B, T, 20)
+ f0_input = self.melspec_linear(f0_input)
+ f0_input = self.melspec_encoder(
+ f0_input,
+ torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
+ .to(f0_input.device)
+ .bool(),
+ )
+ f0_input = self.melspec_linear2(f0_input)
+
+ common_min_size = min(f0_input.size(2), x.size(2))
+ f0_input = f0_input[:, :, :common_min_size]
+
+ x = x[:, :, :common_min_size]
+
+ (
+ z_p,
+ codes_p,
+ latents_p,
+ commitment_loss_p,
+ codebook_loss_p,
+ ) = self.prosody_quantizer(f0_input, 1)
+ outs += z_p.detach()
+ else:
+ (
+ z_p,
+ codes_p,
+ latents_p,
+ commitment_loss_p,
+ codebook_loss_p,
+ ) = self.prosody_quantizer(x, 1)
+ outs += z_p.detach()
+
+ (
+ z_c,
+ codes_c,
+ latents_c,
+ commitment_loss_c,
+ codebook_loss_c,
+ ) = self.content_quantizer(x, n_c)
+ outs += z_c.detach()
+
+ timbre_residual_feature = x - z_p.detach() - z_c.detach()
+
+ (
+ z_t,
+ codes_t,
+ latents_t,
+ commitment_loss_t,
+ codebook_loss_t,
+ ) = self.timbre_quantizer(timbre_residual_feature, n_t)
+ outs += z_t # we should not detach timbre
+
+ residual_feature = timbre_residual_feature - z_t
+
+ (
+ z_r,
+ codes_r,
+ latents_r,
+ commitment_loss_r,
+ codebook_loss_r,
+ ) = self.residual_quantizer(residual_feature, 3)
+
+ bsz = z_r.shape[0]
+ res_mask = np.random.choice(
+ [0, 1],
+ size=bsz,
+ p=[
+ self.prob_random_mask_residual,
+ 1 - self.prob_random_mask_residual,
+ ],
+ )
+ res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
+ res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
+ noise_must_on = noise_added_flags * recon_noisy_flags
+ noise_must_off = noise_added_flags * (~recon_noisy_flags)
+ res_mask[noise_must_on] = 1
+ res_mask[noise_must_off] = 0
+
+ outs += z_r * res_mask
+
+ quantized = [z_p, z_c, z_t, z_r]
+ commitment_losses = (
+ commitment_loss_p
+ + commitment_loss_c
+ + commitment_loss_t
+ + commitment_loss_r
+ )
+ codebook_losses = (
+ codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
+ )
+
+ return outs, quantized, commitment_losses, codebook_losses
+
+ def forward_v2(
+ self,
+ x,
+ wave_segments,
+ n_c=1,
+ n_t=2,
+ full_waves=None,
+ wave_lens=None,
+ return_codes=False,
+ ):
+ # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
+ if full_waves is None:
+ mel = self.preprocess(wave_segments, n_bins=80)
+ timbre = self.timbre_encoder(
+ mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
+ )
+ else:
+ mel = self.preprocess(full_waves, n_bins=80)
+ timbre = self.timbre_encoder(
+ mel,
+ sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
+ )
+ outs = 0
+ if self.separate_prosody_encoder:
+ prosody_feature = self.preprocess(wave_segments)
+
+ f0_input = prosody_feature # (B, T, 20)
+ f0_input = self.melspec_linear(f0_input)
+ f0_input = self.melspec_encoder(
+ f0_input,
+ torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
+ .to(f0_input.device)
+ .bool(),
+ )
+ f0_input = self.melspec_linear2(f0_input)
+
+ common_min_size = min(f0_input.size(2), x.size(2))
+ f0_input = f0_input[:, :, :common_min_size]
+
+ x = x[:, :, :common_min_size]
+
+ (
+ z_p,
+ codes_p,
+ latents_p,
+ commitment_loss_p,
+ codebook_loss_p,
+ ) = self.prosody_quantizer(f0_input, 1)
+ outs += z_p.detach()
+ else:
+ (
+ z_p,
+ codes_p,
+ latents_p,
+ commitment_loss_p,
+ codebook_loss_p,
+ ) = self.prosody_quantizer(x, 1)
+ outs += z_p.detach()
+
+ (
+ z_c,
+ codes_c,
+ latents_c,
+ commitment_loss_c,
+ codebook_loss_c,
+ ) = self.content_quantizer(x, n_c)
+ outs += z_c.detach()
+
+ residual_feature = x - z_p.detach() - z_c.detach()
+
+ (
+ z_r,
+ codes_r,
+ latents_r,
+ commitment_loss_r,
+ codebook_loss_r,
+ ) = self.residual_quantizer(residual_feature, 3)
+
+ bsz = z_r.shape[0]
+ res_mask = np.random.choice(
+ [0, 1],
+ size=bsz,
+ p=[
+ self.prob_random_mask_residual,
+ 1 - self.prob_random_mask_residual,
+ ],
+ )
+ res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
+ res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
+
+ if not self.training:
+ res_mask = torch.ones_like(res_mask)
+ outs += z_r * res_mask
+
+ quantized = [z_p, z_c, z_r]
+ codes = [codes_p, codes_c, codes_r]
+ commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
+ codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r
+
+ style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ outs = outs.transpose(1, 2)
+ outs = self.timbre_norm(outs)
+ outs = outs.transpose(1, 2)
+ outs = outs * gamma + beta
+
+ if return_codes:
+ return outs, quantized, commitment_losses, codebook_losses, timbre, codes
+ else:
+ return outs, quantized, commitment_losses, codebook_losses, timbre
+
+ def voice_conversion(self, z, ref_wave):
+ ref_mel = self.preprocess(ref_wave, n_bins=80)
+ ref_timbre = self.timbre_encoder(
+ ref_mel,
+ sequence_mask(
+ torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
+ ref_mel.size(-1),
+ ).unsqueeze(1),
+ )
+ style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ outs = z.transpose(1, 2)
+ outs = self.timbre_norm(outs)
+ outs = outs.transpose(1, 2)
+ outs = outs * gamma + beta
+
+ return outs
+
+
+class FApredictors(nn.Module):
+ def __init__(
+ self,
+ in_dim=1024,
+ use_gr_content_f0=False,
+ use_gr_prosody_phone=False,
+ use_gr_residual_f0=False,
+ use_gr_residual_phone=False,
+ use_gr_timbre_content=True,
+ use_gr_timbre_prosody=True,
+ use_gr_x_timbre=False,
+ norm_f0=True,
+ timbre_norm=False,
+ use_gr_content_global_f0=False,
+ ):
+ super(FApredictors, self).__init__()
+ self.f0_predictor = CNNLSTM(in_dim, 1, 2)
+ self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
+ if timbre_norm:
+ self.timbre_predictor = nn.Linear(in_dim, 20000)
+ else:
+ self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
+
+ self.use_gr_content_f0 = use_gr_content_f0
+ self.use_gr_prosody_phone = use_gr_prosody_phone
+ self.use_gr_residual_f0 = use_gr_residual_f0
+ self.use_gr_residual_phone = use_gr_residual_phone
+ self.use_gr_timbre_content = use_gr_timbre_content
+ self.use_gr_timbre_prosody = use_gr_timbre_prosody
+ self.use_gr_x_timbre = use_gr_x_timbre
+
+ self.rev_f0_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
+ )
+ self.rev_content_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
+ )
+ self.rev_timbre_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
+ )
+
+ self.norm_f0 = norm_f0
+ self.timbre_norm = timbre_norm
+ if timbre_norm:
+ self.forward = self.forward_v2
+ self.global_f0_predictor = nn.Linear(in_dim, 1)
+
+ self.use_gr_content_global_f0 = use_gr_content_global_f0
+ if use_gr_content_global_f0:
+ self.rev_global_f0_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
+ )
+
+ def forward(self, quantized):
+ prosody_latent = quantized[0]
+ content_latent = quantized[1]
+ timbre_latent = quantized[2]
+ residual_latent = quantized[3]
+ content_pred = self.phone_predictor(content_latent)[0]
+
+ if self.norm_f0:
+ spk_pred = self.timbre_predictor(timbre_latent)[0]
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent)
+ else:
+ spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
+
+ prosody_rev_latent = torch.zeros_like(quantized[0])
+ if self.use_gr_content_f0:
+ prosody_rev_latent += quantized[1]
+ if self.use_gr_timbre_prosody:
+ prosody_rev_latent += quantized[2]
+ if self.use_gr_residual_f0:
+ prosody_rev_latent += quantized[3]
+ rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
+
+ content_rev_latent = torch.zeros_like(quantized[1])
+ if self.use_gr_prosody_phone:
+ content_rev_latent += quantized[0]
+ if self.use_gr_timbre_content:
+ content_rev_latent += quantized[2]
+ if self.use_gr_residual_phone:
+ content_rev_latent += quantized[3]
+ rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
+
+ if self.norm_f0:
+ timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
+ else:
+ timbre_rev_latent = quantized[1] + quantized[3]
+ if self.use_gr_x_timbre:
+ x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
+ else:
+ x_spk_pred = None
+
+ preds = {
+ "f0": f0_pred,
+ "uv": uv_pred,
+ "content": content_pred,
+ "timbre": spk_pred,
+ }
+
+ rev_preds = {
+ "rev_f0": rev_f0_pred,
+ "rev_uv": rev_uv_pred,
+ "rev_content": rev_content_pred,
+ "x_timbre": x_spk_pred,
+ }
+ return preds, rev_preds
+
+ def forward_v2(self, quantized, timbre):
+ prosody_latent = quantized[0]
+ content_latent = quantized[1]
+ residual_latent = quantized[2]
+ content_pred = self.phone_predictor(content_latent)[0]
+
+ spk_pred = self.timbre_predictor(timbre)
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent)
+
+ prosody_rev_latent = torch.zeros_like(prosody_latent)
+ if self.use_gr_content_f0:
+ prosody_rev_latent += content_latent
+ if self.use_gr_residual_f0:
+ prosody_rev_latent += residual_latent
+ rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
+
+ content_rev_latent = torch.zeros_like(content_latent)
+ if self.use_gr_prosody_phone:
+ content_rev_latent += prosody_latent
+ if self.use_gr_residual_phone:
+ content_rev_latent += residual_latent
+ rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
+
+ timbre_rev_latent = prosody_latent + content_latent + residual_latent
+ if self.use_gr_x_timbre:
+ x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
+ else:
+ x_spk_pred = None
+
+ preds = {
+ "f0": f0_pred,
+ "uv": uv_pred,
+ "content": content_pred,
+ "timbre": spk_pred,
+ }
+
+ rev_preds = {
+ "rev_f0": rev_f0_pred,
+ "rev_uv": rev_uv_pred,
+ "rev_content": rev_content_pred,
+ "x_timbre": x_spk_pred,
+ }
+ return preds, rev_preds
diff --git a/models/codec/facodec/modules/style_encoder.py b/models/codec/facodec/modules/style_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e437c1adfc823af8f2324a24c4801b130eb69191
--- /dev/null
+++ b/models/codec/facodec/modules/style_encoder.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
+
+from . import attentions
+from torch import nn
+import torch
+from torch.nn import functional as F
+
+
+class Mish(nn.Module):
+ def __init__(self):
+ super(Mish, self).__init__()
+
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+class Conv1dGLU(nn.Module):
+ """
+ Conv1d + GLU(Gated Linear Unit) with residual connection.
+ For GLU refer to https://arxiv.org/abs/1612.08083 paper.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
+ super(Conv1dGLU, self).__init__()
+ self.out_channels = out_channels
+ self.conv1 = nn.Conv1d(
+ in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
+ )
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ residual = x
+ x = self.conv1(x)
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
+ x = x1 * torch.sigmoid(x2)
+ x = residual + self.dropout(x)
+ return x
+
+
+class StyleEncoder(torch.nn.Module):
+ def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
+
+ super().__init__()
+
+ self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+ self.kernel_size = 5
+ self.n_head = 2
+ self.dropout = 0.1
+
+ self.spectral = nn.Sequential(
+ nn.Conv1d(self.in_dim, self.hidden_dim, 1),
+ Mish(),
+ nn.Dropout(self.dropout),
+ nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
+ Mish(),
+ nn.Dropout(self.dropout),
+ )
+
+ self.temporal = nn.Sequential(
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+ )
+
+ self.slf_attn = attentions.MultiHeadAttention(
+ self.hidden_dim,
+ self.hidden_dim,
+ self.n_head,
+ p_dropout=self.dropout,
+ proximal_bias=False,
+ proximal_init=True,
+ )
+ self.atten_drop = nn.Dropout(self.dropout)
+ self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
+
+ def forward(self, x, mask=None):
+
+ # spectral
+ x = self.spectral(x) * mask
+ # temporal
+ x = self.temporal(x) * mask
+
+ # self-attention
+ attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
+ y = self.slf_attn(x, x, attn_mask=attn_mask)
+ x = x + self.atten_drop(y)
+
+ # fc
+ x = self.fc(x)
+
+ # temoral average pooling
+ w = self.temporal_avg_pool(x, mask=mask)
+
+ return w
+
+ def temporal_avg_pool(self, x, mask=None):
+ if mask is None:
+ out = torch.mean(x, dim=2)
+ else:
+ len_ = mask.sum(dim=2)
+ x = x.sum(dim=2)
+
+ out = torch.div(x, len_)
+ return out
diff --git a/models/codec/facodec/modules/wavenet.py b/models/codec/facodec/modules/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2da541be26d54f0ccc098637334ad812bc2374
--- /dev/null
+++ b/models/codec/facodec/modules/wavenet.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from modules.dac.model.encodec import SConv1d
+
+from . import commons
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels,
+ out_channels,
+ kernel_size,
+ n_layers,
+ p_dropout,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 0."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(
+ nn.Conv1d(
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(
+ nn.Conv1d(
+ hidden_channels,
+ hidden_channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class DDSConv(nn.Module):
+ """
+ Dialted and Depth-Separable Convolution
+ """
+
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.drop = nn.Dropout(p_dropout)
+ self.convs_sep = nn.ModuleList()
+ self.convs_1x1 = nn.ModuleList()
+ self.norms_1 = nn.ModuleList()
+ self.norms_2 = nn.ModuleList()
+ for i in range(n_layers):
+ dilation = kernel_size**i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs_sep.append(
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ )
+ )
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+ self.norms_1.append(LayerNorm(channels))
+ self.norms_2.append(LayerNorm(channels))
+
+ def forward(self, x, x_mask, g=None):
+ if g is not None:
+ x = x + g
+ for i in range(self.n_layers):
+ y = self.convs_sep[i](x * x_mask)
+ y = self.norms_1[i](y)
+ y = F.gelu(y)
+ y = self.convs_1x1[i](y)
+ y = self.norms_2[i](y)
+ y = F.gelu(y)
+ y = self.drop(y)
+ x = x + y
+ return x * x_mask
+
+
+class WN(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ p_dropout=0,
+ causal=False,
+ ):
+ super(WN, self).__init__()
+ conv1d_type = SConv1d
+ assert kernel_size % 2 == 1
+ self.hidden_channels = hidden_channels
+ self.kernel_size = (kernel_size,)
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ self.cond_layer = conv1d_type(
+ gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
+ )
+
+ for i in range(n_layers):
+ dilation = dilation_rate**i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = conv1d_type(
+ hidden_channels,
+ 2 * hidden_channels,
+ kernel_size,
+ dilation=dilation,
+ padding=padding,
+ norm="weight_norm",
+ causal=causal,
+ )
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = conv1d_type(
+ hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
+ )
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ torch.nn.utils.remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ torch.nn.utils.remove_weight_norm(l)
diff --git a/models/codec/facodec/optimizer.py b/models/codec/facodec/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d6d798d5f20a137a5140834d64407c423012673
--- /dev/null
+++ b/models/codec/facodec/optimizer.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os, sys
+import os.path as osp
+import numpy as np
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from functools import reduce
+from torch.optim import AdamW
+
+
+class MultiOptimizer:
+ def __init__(self, optimizers={}, schedulers={}):
+ self.optimizers = optimizers
+ self.schedulers = schedulers
+ self.keys = list(optimizers.keys())
+ self.param_groups = reduce(
+ lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
+ )
+
+ def state_dict(self):
+ state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
+ return state_dicts
+
+ def scheduler_state_dict(self):
+ state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys]
+ return state_dicts
+
+ def load_state_dict(self, state_dict):
+ for key, val in state_dict:
+ try:
+ self.optimizers[key].load_state_dict(val)
+ except:
+ print("Unloaded %s" % key)
+
+ def load_scheduler_state_dict(self, state_dict):
+ for key, val in state_dict:
+ try:
+ self.schedulers[key].load_state_dict(val)
+ except:
+ print("Unloaded %s" % key)
+
+ def step(self, key=None, scaler=None):
+ keys = [key] if key is not None else self.keys
+ _ = [self._step(key, scaler) for key in keys]
+
+ def _step(self, key, scaler=None):
+ if scaler is not None:
+ scaler.step(self.optimizers[key])
+ scaler.update()
+ else:
+ self.optimizers[key].step()
+
+ def zero_grad(self, key=None):
+ if key is not None:
+ self.optimizers[key].zero_grad()
+ else:
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
+
+ def scheduler(self, *args, key=None):
+ if key is not None:
+ self.schedulers[key].step(*args)
+ else:
+ _ = [self.schedulers[key].step_batch(*args) for key in self.keys]
+
+
+def define_scheduler(optimizer, params):
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])
+
+ return scheduler
+
+
+def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
+ optim = {}
+ for key, model in model_dict.items():
+ model_parameters = model.parameters()
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ if type == "AdamW":
+ optim[key] = AdamW(
+ model_parameters,
+ lr=lr,
+ betas=(0.9, 0.98),
+ eps=1e-9,
+ weight_decay=0.1,
+ )
+ else:
+ raise ValueError("Unknown optimizer type: %s" % type)
+
+ schedulers = dict(
+ [
+ (key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
+ for key, opt in optim.items()
+ ]
+ )
+
+ multi_optim = MultiOptimizer(optim, schedulers)
+ return multi_optim
diff --git a/models/codec/kmeans/repcodec_model.py b/models/codec/kmeans/repcodec_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5124444c7e1c6bf0f626a2b9cb0f9245b0b0de4
--- /dev/null
+++ b/models/codec/kmeans/repcodec_model.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ALL_COMPLETED
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch.nn import functional as F
+from einops import rearrange, repeat
+
+from models.codec.amphion_codec.quantize import ResidualVQ
+from models.codec.kmeans.vocos import VocosBackbone
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+def compute_codebook_perplexity(indices, codebook_size):
+ indices = indices.flatten()
+ prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
+ perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
+ return perp
+
+
+class RepCodec(nn.Module):
+ def __init__(
+ self,
+ codebook_size=8192,
+ hidden_size=1024,
+ codebook_dim=8,
+ vocos_dim=384,
+ vocos_intermediate_dim=2048,
+ vocos_num_layers=12,
+ num_quantizers=1,
+ downsample_scale=1,
+ cfg=None,
+ ):
+ super().__init__()
+ codebook_size = (
+ cfg.codebook_size
+ if cfg is not None and hasattr(cfg, "codebook_size")
+ else codebook_size
+ )
+ codebook_dim = (
+ cfg.codebook_dim
+ if cfg is not None and hasattr(cfg, "codebook_dim")
+ else codebook_dim
+ )
+ hidden_size = (
+ cfg.hidden_size
+ if cfg is not None and hasattr(cfg, "hidden_size")
+ else hidden_size
+ )
+ vocos_dim = (
+ cfg.vocos_dim
+ if cfg is not None and hasattr(cfg, "vocos_dim")
+ else vocos_dim
+ )
+ vocos_intermediate_dim = (
+ cfg.vocos_intermediate_dim
+ if cfg is not None and hasattr(cfg, "vocos_dim")
+ else vocos_intermediate_dim
+ )
+ vocos_num_layers = (
+ cfg.vocos_num_layers
+ if cfg is not None and hasattr(cfg, "vocos_dim")
+ else vocos_num_layers
+ )
+ num_quantizers = (
+ cfg.num_quantizers
+ if cfg is not None and hasattr(cfg, "num_quantizers")
+ else num_quantizers
+ )
+ downsample_scale = (
+ cfg.downsample_scale
+ if cfg is not None and hasattr(cfg, "downsample_scale")
+ else downsample_scale
+ )
+
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.hidden_size = hidden_size
+ self.vocos_dim = vocos_dim
+ self.vocos_intermediate_dim = vocos_intermediate_dim
+ self.vocos_num_layers = vocos_num_layers
+ self.num_quantizers = num_quantizers
+ self.downsample_scale = downsample_scale
+
+ if self.downsample_scale != None and self.downsample_scale > 1:
+ self.down = nn.Conv1d(
+ self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
+ )
+ self.up = nn.Conv1d(
+ self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
+ )
+
+ self.encoder = nn.Sequential(
+ VocosBackbone(
+ input_channels=self.hidden_size,
+ dim=self.vocos_dim,
+ intermediate_dim=self.vocos_intermediate_dim,
+ num_layers=self.vocos_num_layers,
+ adanorm_num_embeddings=None,
+ ),
+ nn.Linear(self.vocos_dim, self.hidden_size),
+ )
+ self.decoder = nn.Sequential(
+ VocosBackbone(
+ input_channels=self.hidden_size,
+ dim=self.vocos_dim,
+ intermediate_dim=self.vocos_intermediate_dim,
+ num_layers=self.vocos_num_layers,
+ adanorm_num_embeddings=None,
+ ),
+ nn.Linear(self.vocos_dim, self.hidden_size),
+ )
+
+ self.quantizer = ResidualVQ(
+ input_dim=hidden_size,
+ num_quantizers=num_quantizers,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_type="fvq",
+ quantizer_dropout=0.0,
+ commitment=0.15,
+ codebook_loss_weight=1.0,
+ use_l2_normlize=True,
+ )
+
+ self.reset_parameters()
+
+ def forward(self, x):
+
+ # downsample
+ if self.downsample_scale != None and self.downsample_scale > 1:
+ x = x.transpose(1, 2)
+ x = self.down(x)
+ x = F.gelu(x)
+ x = x.transpose(1, 2)
+
+ # encoder
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
+
+ # vq
+ (
+ quantized_out,
+ all_indices,
+ all_commit_losses,
+ all_codebook_losses,
+ _,
+ ) = self.quantizer(x)
+
+ # decoder
+ x = self.decoder(quantized_out)
+
+ # up
+ if self.downsample_scale != None and self.downsample_scale > 1:
+ x = x.transpose(1, 2)
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ x_rec = self.up(x).transpose(1, 2)
+
+ codebook_loss = (all_codebook_losses + all_commit_losses).mean()
+ all_indices = all_indices
+
+ return x_rec, codebook_loss, all_indices
+
+ def quantize(self, x):
+
+ if self.downsample_scale != None and self.downsample_scale > 1:
+ x = x.transpose(1, 2)
+ x = self.down(x)
+ x = F.gelu(x)
+ x = x.transpose(1, 2)
+
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
+
+ (
+ quantized_out,
+ all_indices,
+ all_commit_losses,
+ all_codebook_losses,
+ _,
+ ) = self.quantizer(x)
+
+ if all_indices.shape[0] == 1:
+ return all_indices.squeeze(0), quantized_out.transpose(1, 2)
+ return all_indices, quantized_out.transpose(1, 2)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+if __name__ == "__main__":
+ repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
+ print(repcodec)
+ print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
+ x = torch.randn(5, 10, 1024)
+ x_rec, codebook_loss, all_indices = repcodec(x)
+ print(x_rec.shape, codebook_loss, all_indices.shape)
+ vq_id, emb = repcodec.quantize(x)
+ print(vq_id.shape, emb.shape)
diff --git a/models/codec/kmeans/vocos.py b/models/codec/kmeans/vocos.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9f5a5ce9d8d4283ac313caeb983d3280afacd9
--- /dev/null
+++ b/models/codec/kmeans/vocos.py
@@ -0,0 +1,850 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+import numpy as np
+import scipy
+import torch
+from torch import nn, view_as_real, view_as_complex
+from torch import nn
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
+
+class STFT(nn.Module):
+ def __init__(
+ self,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ center=True,
+ ):
+ super().__init__()
+ self.center = center
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, T * hop_length)
+
+ if not self.center:
+ pad = self.win_length - self.hop_length
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
+
+ stft_spec = torch.stft(
+ x,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ return_complex=False,
+ ) # (B, n_fft // 2 + 1, T, 2)
+
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
+
+ log_mag = torch.log(
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
+ ) # (B, n_fft // 2 + 1, T)
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
+
+ return log_mag, phase
+
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
+ ):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(
+ spec,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.window,
+ center=True,
+ )
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ assert (window_envelope > 1e-11).all()
+ y = y / window_envelope
+
+ return y
+
+
+class MDCT(nn.Module):
+ """
+ Modified Discrete Cosine Transform (MDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
+ # https://github.com/pytorch/pytorch/issues/71613
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+ Args:
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+ and T is the length of the audio.
+
+ Returns:
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+ and N is the number of frequency bins.
+ """
+ if self.padding == "center":
+ audio = torch.nn.functional.pad(
+ audio, (self.frame_len // 2, self.frame_len // 2)
+ )
+ elif self.padding == "same":
+ # hop_length is 1/2 frame_len
+ audio = torch.nn.functional.pad(
+ audio, (self.frame_len // 4, self.frame_len // 4)
+ )
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+ N = self.frame_len // 2
+ x = x * self.window.expand(x.shape)
+ X = torch.fft.fft(
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
+ )[..., :N]
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+ return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+ """
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+ Args:
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+ L is the number of frames, and N is the number of frequency bins.
+
+ Returns:
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+ """
+ B, L, N = X.shape
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+ Y[..., :N] = X
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+ y = torch.fft.ifft(
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
+ )
+ y = (
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
+ * np.sqrt(N)
+ * np.sqrt(2)
+ )
+ result = y * self.window.expand(y.shape)
+ output_size = (1, (L + 1) * N)
+ audio = torch.nn.functional.fold(
+ result.transpose(1, 2),
+ output_size=output_size,
+ kernel_size=(1, self.frame_len),
+ stride=(1, self.frame_len // 2),
+ )[:, 0, 0, :]
+
+ if self.padding == "center":
+ pad = self.frame_len // 2
+ elif self.padding == "same":
+ pad = self.frame_len // 4
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ audio = audio[:, pad:-pad]
+ return audio
+
+
+class FourierHead(nn.Module):
+ """Base class for inverse fourier modules."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+ """
+ ISTFT Head module for predicting STFT complex coefficients.
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
+ the resolution of the input features.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+ super().__init__()
+ out_dim = n_fft + 2
+ self.out = torch.nn.Linear(dim, out_dim)
+ self.istft = ISTFT(
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ISTFTHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x).transpose(1, 2)
+ mag, p = x.chunk(2, dim=1)
+ mag = torch.exp(mag)
+ mag = torch.clip(
+ mag, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ # wrapping happens here. These two lines produce real and imaginary value
+ x = torch.cos(p)
+ y = torch.sin(p)
+ # recalculating phase here does not produce anything new
+ # only costs time
+ # phase = torch.atan2(y, x)
+ # S = mag * torch.exp(phase * 1j)
+ # better directly produce the complex value
+ S = mag * (x + 1j * y)
+ audio = self.istft(S)
+ return audio
+
+
+class IMDCTSymExpHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+ based on perceptual scaling. Defaults to None.
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ sample_rate: Optional[int] = None,
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ out_dim = mdct_frame_len // 2
+ self.out = nn.Linear(dim, out_dim)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+ self.clip_audio = clip_audio
+
+ if sample_rate is not None:
+ # optionally init the last layer following mel-scale
+ m_max = _hz_to_mel(sample_rate // 2)
+ m_pts = torch.linspace(0, m_max, out_dim)
+ f_pts = _mel_to_hz(m_pts)
+ scale = 1 - (f_pts / f_pts.max())
+
+ with torch.no_grad():
+ self.out.weight.mul_(scale.view(-1, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTSymExpHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ x = symexp(x)
+ x = torch.clip(
+ x, min=-1e2, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(x)
+ if self.clip_audio:
+ audio = torch.clip(x, min=-1.0, max=1.0)
+
+ return audio
+
+
+class IMDCTCosHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ self.clip_audio = clip_audio
+ self.out = nn.Linear(dim, mdct_frame_len)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTCosHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ m, p = x.chunk(2, dim=2)
+ m = torch.exp(m).clip(
+ max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(m * torch.cos(p))
+ if self.clip_audio:
+ audio = torch.clip(x, min=-1.0, max=1.0)
+ return audio
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ layer_scale_init_value: float,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=3, groups=dim
+ ) # depthwise conv
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ if self.adanorm:
+ assert cond_embedding_id is not None
+ x = self.norm(x, cond_embedding_id)
+ else:
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+ Args:
+ num_embeddings (int): Number of embeddings.
+ embedding_dim (int): Dimension of the embeddings.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.dim = embedding_dim
+ self.scale = nn.Embedding(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
+ )
+ self.shift = nn.Embedding(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
+ )
+ torch.nn.init.ones_(self.scale.weight)
+ torch.nn.init.zeros_(self.shift.weight)
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+ scale = self.scale(cond_embedding_id)
+ shift = self.shift(cond_embedding_id)
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+ x = x * scale + shift
+ return x
+
+
+class ResBlock1(nn.Module):
+ """
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+ but without upsampling layers.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+ Defaults to (1, 3, 5).
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+ Defaults to 0.1.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int = 3,
+ dilation: Tuple[int, int, int] = (1, 3, 5),
+ lrelu_slope: float = 0.1,
+ layer_scale_init_value: Optional[float] = None,
+ ):
+ super().__init__()
+ self.lrelu_slope = lrelu_slope
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self.get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self.get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self.get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+
+ self.gamma = nn.ParameterList(
+ [
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+ xt = c1(xt)
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+ xt = c2(xt)
+ if gamma is not None:
+ xt = gamma * xt
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+ @staticmethod
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class Backbone(nn.Module):
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+ C denotes output features, and L is the sequence length.
+
+ Returns:
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+ and H denotes the model dimension.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+ """
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+ num_layers (int): Number of ConvNeXtBlock layers.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional model. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels: int,
+ dim: int,
+ intermediate_dim: int,
+ num_layers: int,
+ layer_scale_init_value: Optional[float] = None,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+ self.convnext = nn.ModuleList(
+ [
+ ConvNeXtBlock(
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ layer_scale_init_value=layer_scale_init_value,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ bandwidth_id = kwargs.get("bandwidth_id", None)
+ x = self.embed(x)
+ if self.adanorm:
+ assert bandwidth_id is not None
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+ else:
+ x = self.norm(x.transpose(1, 2))
+ x = x.transpose(1, 2)
+ for conv_block in self.convnext:
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
+ x = self.final_layer_norm(x.transpose(1, 2))
+ return x
+
+
+class VocosResNetBackbone(Backbone):
+ """
+ Vocos backbone module built with ResBlocks.
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ num_blocks (int): Number of ResBlock1 blocks.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels,
+ dim,
+ num_blocks,
+ layer_scale_init_value=None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = weight_norm(
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
+ )
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+ self.resnet = nn.Sequential(
+ *[
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.embed(x)
+ x = self.resnet(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class Vocos(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 256,
+ dim: int = 384,
+ intermediate_dim: int = 1152,
+ num_layers: int = 8,
+ adanorm_num_embeddings: int = 4,
+ n_fft: int = 800,
+ hop_size: int = 200,
+ padding: str = "same",
+ ):
+ super().__init__()
+
+ self.backbone = VocosBackbone(
+ input_channels=input_channels,
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ num_layers=num_layers,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ x = self.head(x)
+
+ return x[:, None, :]
diff --git a/models/codec/ns3_codec/README.md b/models/codec/ns3_codec/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1283d677f41d16f72577586ceb57bcf49241280c
--- /dev/null
+++ b/models/codec/ns3_codec/README.md
@@ -0,0 +1,216 @@
+## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3
+
+[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/pdf/2403.03100.pdf)
+[![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/)
+[![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec)
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)
+
+## Overview
+
+FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveform into disentangled subspaces representing speech attributes of content, prosody, timbre, and acoustic details and reconstruct high-quality speech waveform from these attributes. FACodec decomposes complex speech into subspaces representing different attributes, thus simplifying the modeling of speech representation.
+
+Research can use FACodec to develop different modes of TTS models, such as non-autoregressive based discrete diffusion (NaturalSpeech 3) or autoregressive models (like VALL-E).
+
+
+
+
+
+
+
+
+
+
+
+
+
+## Useage
+
+Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)
+
+Install Amphion
+```bash
+git clone https://github.com/open-mmlab/Amphion.git
+```
+
+Few lines of code to use the pre-trained FACodec model
+```python
+from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
+from huggingface_hub import hf_hub_download
+
+fa_encoder = FACodecEncoder(
+ ngf=32,
+ up_ratios=[2, 4, 5, 5],
+ out_channels=256,
+)
+
+fa_decoder = FACodecDecoder(
+ in_channels=256,
+ upsample_initial_channel=1024,
+ ngf=32,
+ up_ratios=[5, 5, 4, 2],
+ vq_num_q_c=2,
+ vq_num_q_p=1,
+ vq_num_q_r=3,
+ vq_dim=256,
+ codebook_dim=8,
+ codebook_size_prosody=10,
+ codebook_size_content=10,
+ codebook_size_residual=10,
+ use_gr_x_timbre=True,
+ use_gr_residual_f0=True,
+ use_gr_residual_phone=True,
+)
+
+encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
+decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")
+
+fa_encoder.load_state_dict(torch.load(encoder_ckpt))
+fa_decoder.load_state_dict(torch.load(decoder_ckpt))
+
+fa_encoder.eval()
+fa_decoder.eval()
+
+```
+
+Inference
+```python
+test_wav_path = "test.wav"
+test_wav = librosa.load(test_wav_path, sr=16000)[0]
+test_wav = torch.from_numpy(test_wav).float()
+test_wav = test_wav.unsqueeze(0).unsqueeze(0)
+
+with torch.no_grad():
+
+ # encode
+ enc_out = fa_encoder(test_wav)
+ print(enc_out.shape)
+
+ # quantize
+ vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
+
+ # latent after quantization
+ print(vq_post_emb.shape)
+
+ # codes
+ print("vq id shape:", vq_id.shape)
+
+ # get prosody code
+ prosody_code = vq_id[:1]
+ print("prosody code shape:", prosody_code.shape)
+
+ # get content code
+ cotent_code = vq_id[1:3]
+ print("content code shape:", cotent_code.shape)
+
+ # get residual code (acoustic detail codes)
+ residual_code = vq_id[3:]
+ print("residual code shape:", residual_code.shape)
+
+ # speaker embedding
+ print("speaker embedding shape:", spk_embs.shape)
+
+ # decode (recommand)
+ recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
+ print(recon_wav.shape)
+ sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
+```
+
+FACodec can achieve zero-shot voice conversion with FACodecEncoderV2/FACodecDecoderV2 or FACodecRedecoder
+```python
+from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2
+
+# Same parameters as FACodecEncoder/FACodecDecoder
+fa_encoder_v2 = FACodecEncoderV2(...)
+fa_decoder_v2 = FACodecDecoderV2(...)
+
+encoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin")
+decoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin")
+
+fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt))
+fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt))
+
+with torch.no_grad():
+ enc_out_a = fa_encoder_v2(wav_a)
+ prosody_a = fa_encoder_v2.get_prosody_feature(wav_a)
+ enc_out_b = fa_encoder_v2(wav_b)
+ prosody_b = fa_encoder_v2.get_prosody_feature(wav_b)
+
+ vq_post_emb_a, vq_id_a, _, quantized, spk_embs_a = fa_decoder_v2(
+ enc_out_a, prosody_a, eval_vq=False, vq=True
+ )
+ vq_post_emb_b, vq_id_b, _, quantized, spk_embs_b = fa_decoder_v2(
+ enc_out_b, prosody_b, eval_vq=False, vq=True
+ )
+
+ vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)
+ recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b)
+```
+
+or
+
+```python
+from Amphion.models.codec.ns3_codec import FACodecRedecoder
+
+fa_redecoder = FACodecRedecoder()
+
+redecoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_redecoder.bin")
+
+fa_redecoder.load_state_dict(torch.load(redecoder_ckpt))
+
+with torch.no_grad():
+ enc_out_a = fa_encoder(wav_a)
+ enc_out_b = fa_encoder(wav_b)
+
+ vq_post_emb_a, vq_id_a, _, quantized_a, spk_embs_a = fa_decoder(enc_out_a, eval_vq=False, vq=True)
+ vq_post_emb_b, vq_id_b, _, quantized_b, spk_embs_b = fa_decoder(enc_out_b, eval_vq=False, vq=True)
+
+ # convert speaker
+ vq_post_emb_a_to_b = fa_redecoder.vq2emb(vq_id_a, spk_embs_b, use_residual=False)
+ recon_wav_a_to_b = fa_redecoder.inference(vq_post_emb_a_to_b, spk_embs_b)
+
+ sf.write("recon_a_to_b.wav", recon_wav_a_to_b[0][0].cpu().numpy(), 16000)
+```
+
+## Q&A
+
+Q1: What audio sample rate does FACodec support? What is the hop size? How many codes will be generated for each frame?
+
+A1: FACodec supports 16KHz speech audio. The hop size is 200 samples, and (16000/200) * 6 (total number of codebooks) codes will be generated for each frame.
+
+Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?
+
+A2: Yes. In fact, the authors of NaturalSpeech 3 have already employ explore the autoregressive generative model for discrete token generation with FACodec. They use an autoregressive language model to generate prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic details codes.
+
+Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec?
+
+A3: Yes. You can use the latent getted after quanzaition as the modelling target for the latent diffusion model.
+
+Q4: Can FACodec compress and reconstruct audio from other domains? Such as sound effects, music, etc.
+
+A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. However, it is possible to use the FACodec model to compress and reconstruct audio from other domains, but the quality may not be as good as the original audio.
+
+Q5: Can FACodec be used for content feature for some other tasks like voice conversion?
+
+A5: I think the answer is yes. Researchers can use the content code of FACodec as the content feature for voice conversion. We hope to see more research in this direction.
+
+## Citations
+
+If you use our FACodec model, please cite the following paper:
+
+```bibtex
+@article{ju2024naturalspeech,
+ title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
+ author={Ju, Zeqian and Wang, Yuancheng and Shen, Kai and Tan, Xu and Xin, Detai and Yang, Dongchao and Liu, Yanqing and Leng, Yichong and Song, Kaitao and Tang, Siliang and others},
+ journal={arXiv preprint arXiv:2403.03100},
+ year={2024}
+}
+
+@article{zhang2023amphion,
+ title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
+ author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
+ journal={arXiv},
+ year={2024},
+ volume={abs/2312.09911}
+}
+```
+
diff --git a/models/codec/ns3_codec/__init__.py b/models/codec/ns3_codec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f0e4c194e9a02d93f7f1c8f4bed05f460b36e20
--- /dev/null
+++ b/models/codec/ns3_codec/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .facodec import *
diff --git a/models/codec/ns3_codec/alias_free_torch/__init__.py b/models/codec/ns3_codec/alias_free_torch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3bccdb97a3706bcb7149f48e04178cf00a5e877
--- /dev/null
+++ b/models/codec/ns3_codec/alias_free_torch/__init__.py
@@ -0,0 +1,5 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+from .filter import *
+from .resample import *
+from .act import *
diff --git a/models/codec/ns3_codec/alias_free_torch/act.py b/models/codec/ns3_codec/alias_free_torch/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..779d58d5f1e889f8b639dd019a0ce951e69e4cfb
--- /dev/null
+++ b/models/codec/ns3_codec/alias_free_torch/act.py
@@ -0,0 +1,29 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+import torch.nn as nn
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+ def __init__(
+ self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12,
+ ):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
diff --git a/models/codec/ns3_codec/alias_free_torch/filter.py b/models/codec/ns3_codec/alias_free_torch/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece8e02fce0e65e13522e990a80d1bfeeffd46ba
--- /dev/null
+++ b/models/codec/ns3_codec/alias_free_torch/filter.py
@@ -0,0 +1,96 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(
+ x == 0,
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x,
+ )
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+def kaiser_sinc_filter1d(
+ cutoff, half_width, kernel_size
+): # return filter [1,1,kernel_size]
+ even = kernel_size % 2 == 0
+ half_size = kernel_size // 2
+
+ # For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.0:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.0:
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+ else:
+ beta = 0.0
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = torch.arange(-half_size, half_size) + 0.5
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(
+ self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = "replicate",
+ kernel_size: int = 12,
+ ):
+ # kernel_size should be even number for stylegan3 setup,
+ # in this implementation, odd number is also possible.
+ super().__init__()
+ if cutoff < -0.0:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = kernel_size % 2 == 0
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ # input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+
+ return out
diff --git a/models/codec/ns3_codec/alias_free_torch/resample.py b/models/codec/ns3_codec/alias_free_torch/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee993b10339141b469b67c3e11f5d73c5f4e0bca
--- /dev/null
+++ b/models/codec/ns3_codec/alias_free_torch/resample.py
@@ -0,0 +1,57 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = (
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ )
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = (
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ )
+ filter = kaiser_sinc_filter1d(
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+ )
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+ )
+ x = x[..., self.pad_left : -self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = (
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ )
+ self.lowpass = LowPassFilter1d(
+ cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size,
+ )
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
diff --git a/models/codec/ns3_codec/facodec.py b/models/codec/ns3_codec/facodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..87f661bdfa250e5be7514946934f6873a981b9c3
--- /dev/null
+++ b/models/codec/ns3_codec/facodec.py
@@ -0,0 +1,1222 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+from .alias_free_torch import *
+from .quantize import *
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from .transformer import TransformerEncoder
+from .gradient_reversal import GradientReversal
+from .melspec import MelSpectrogram
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class CNNLSTM(nn.Module):
+ def __init__(self, indim, outdim, head, global_pred=False):
+ super().__init__()
+ self.global_pred = global_pred
+ self.model = nn.Sequential(
+ ResidualUnit(indim, dilation=1),
+ ResidualUnit(indim, dilation=2),
+ ResidualUnit(indim, dilation=3),
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
+ Rearrange("b c t -> b t c"),
+ )
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
+
+ def forward(self, x):
+ # x: [B, C, T]
+ x = self.model(x)
+ if self.global_pred:
+ x = torch.mean(x, dim=1, keepdim=False)
+ outs = [head(x) for head in self.heads]
+ return outs
+
+
+class SnakeBeta(nn.Module):
+ """
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = snakebeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ """
+
+ def __init__(
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+ ):
+ """
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ """
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.beta = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ """
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta := x + 1/b * sin^2 (xa)
+ """
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1):
+ super().__init__()
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ return x + self.block(x)
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, dim: int = 16, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ ResidualUnit(dim // 2, dilation=1),
+ ResidualUnit(dim // 2, dilation=3),
+ ResidualUnit(dim // 2, dilation=9),
+ Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
+ WNConv1d(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=stride // 2 + stride % 2,
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class FACodecEncoder(nn.Module):
+ def __init__(
+ self,
+ ngf=32,
+ up_ratios=(2, 4, 5, 5),
+ out_channels=1024,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.up_ratios = up_ratios
+
+ # Create first convolution
+ d_model = ngf
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride in up_ratios:
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride)]
+
+ # Create last convolution
+ self.block += [
+ Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+ ]
+
+ # Wrap black into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ self.reset_parameters()
+
+ def forward(self, x):
+ out = self.block(x)
+ return out
+
+ def inference(self, x):
+ return self.block(x)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
+ WNConvTranspose1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=stride // 2 + stride % 2,
+ output_padding=stride % 2,
+ ),
+ ResidualUnit(output_dim, dilation=1),
+ ResidualUnit(output_dim, dilation=3),
+ ResidualUnit(output_dim, dilation=9),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class FACodecDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=256,
+ upsample_initial_channel=1536,
+ ngf=32,
+ up_ratios=(5, 5, 4, 2),
+ vq_num_q_c=2,
+ vq_num_q_p=1,
+ vq_num_q_r=3,
+ vq_dim=1024,
+ vq_commit_weight=0.005,
+ vq_weight_init=False,
+ vq_full_commit_loss=False,
+ codebook_dim=8,
+ codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
+ codebook_size_content=10,
+ codebook_size_residual=10,
+ quantizer_dropout=0.0,
+ dropout_type="linear",
+ use_gr_content_f0=False,
+ use_gr_prosody_phone=False,
+ use_gr_residual_f0=False,
+ use_gr_residual_phone=False,
+ use_gr_x_timbre=False,
+ use_random_mask_residual=True,
+ prob_random_mask_residual=0.75,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.ngf = ngf
+ self.up_ratios = up_ratios
+
+ self.use_random_mask_residual = use_random_mask_residual
+ self.prob_random_mask_residual = prob_random_mask_residual
+
+ self.vq_num_q_p = vq_num_q_p
+ self.vq_num_q_c = vq_num_q_c
+ self.vq_num_q_r = vq_num_q_r
+
+ self.codebook_size_prosody = codebook_size_prosody
+ self.codebook_size_content = codebook_size_content
+ self.codebook_size_residual = codebook_size_residual
+
+ quantizer_class = ResidualVQ
+
+ self.quantizer = nn.ModuleList()
+
+ # prosody
+ quantizer = quantizer_class(
+ num_quantizers=vq_num_q_p,
+ dim=vq_dim,
+ codebook_size=codebook_size_prosody,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ quantizer_dropout=quantizer_dropout,
+ dropout_type=dropout_type,
+ )
+ self.quantizer.append(quantizer)
+
+ # phone
+ quantizer = quantizer_class(
+ num_quantizers=vq_num_q_c,
+ dim=vq_dim,
+ codebook_size=codebook_size_content,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ quantizer_dropout=quantizer_dropout,
+ dropout_type=dropout_type,
+ )
+ self.quantizer.append(quantizer)
+
+ # residual
+ if self.vq_num_q_r > 0:
+ quantizer = quantizer_class(
+ num_quantizers=vq_num_q_r,
+ dim=vq_dim,
+ codebook_size=codebook_size_residual,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ quantizer_dropout=quantizer_dropout,
+ dropout_type=dropout_type,
+ )
+ self.quantizer.append(quantizer)
+
+ # Add first conv layer
+ channels = upsample_initial_channel
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(up_ratios):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+ # Add final conv layer
+ layers += [
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ self.timbre_encoder = TransformerEncoder(
+ enc_emb_tokens=None,
+ encoder_layer=4,
+ encoder_hidden=256,
+ encoder_head=4,
+ conv_filter_size=1024,
+ conv_kernel_size=5,
+ encoder_dropout=0.1,
+ use_cln=False,
+ )
+
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+ self.timbre_linear.bias.data[:in_channels] = 1
+ self.timbre_linear.bias.data[in_channels:] = 0
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+ self.f0_predictor = CNNLSTM(in_channels, 1, 2)
+ self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
+
+ self.use_gr_content_f0 = use_gr_content_f0
+ self.use_gr_prosody_phone = use_gr_prosody_phone
+ self.use_gr_residual_f0 = use_gr_residual_f0
+ self.use_gr_residual_phone = use_gr_residual_phone
+ self.use_gr_x_timbre = use_gr_x_timbre
+
+ if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
+ self.res_f0_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+ )
+
+ if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
+ self.res_phone_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+ )
+
+ if self.use_gr_content_f0:
+ self.content_f0_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+ )
+
+ if self.use_gr_prosody_phone:
+ self.prosody_phone_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+ )
+
+ if self.use_gr_x_timbre:
+ self.x_timbre_predictor = nn.Sequential(
+ GradientReversal(alpha=1),
+ CNNLSTM(in_channels, 245200, 1, global_pred=True),
+ )
+
+ self.reset_parameters()
+
+ def quantize(self, x, n_quantizers=None):
+ outs, qs, commit_loss, quantized_buf = 0, [], [], []
+
+ # prosody
+ f0_input = x # (B, d, T)
+ f0_quantizer = self.quantizer[0]
+ out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
+ outs += out
+ qs.append(q)
+ quantized_buf.append(quantized.sum(0))
+ commit_loss.append(commit)
+
+ # phone
+ phone_input = x
+ phone_quantizer = self.quantizer[1]
+ out, q, commit, quantized = phone_quantizer(
+ phone_input, n_quantizers=n_quantizers
+ )
+ outs += out
+ qs.append(q)
+ quantized_buf.append(quantized.sum(0))
+ commit_loss.append(commit)
+
+ # residual
+ if self.vq_num_q_r > 0:
+ residual_quantizer = self.quantizer[2]
+ residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
+ out, q, commit, quantized = residual_quantizer(
+ residual_input, n_quantizers=n_quantizers
+ )
+ outs += out
+ qs.append(q)
+ quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
+ commit_loss.append(commit)
+
+ qs = torch.cat(qs, dim=0)
+ commit_loss = torch.cat(commit_loss, dim=0)
+ return outs, qs, commit_loss, quantized_buf
+
+ def forward(
+ self,
+ x,
+ vq=True,
+ get_vq=False,
+ eval_vq=True,
+ speaker_embedding=None,
+ n_quantizers=None,
+ quantized=None,
+ ):
+ if get_vq:
+ return self.quantizer.get_emb()
+ if vq is True:
+ if eval_vq:
+ self.quantizer.eval()
+ x_timbre = x
+ outs, qs, commit_loss, quantized_buf = self.quantize(
+ x, n_quantizers=n_quantizers
+ )
+
+ x_timbre = x_timbre.transpose(1, 2)
+ x_timbre = self.timbre_encoder(x_timbre, None, None)
+ x_timbre = x_timbre.transpose(1, 2)
+ spk_embs = torch.mean(x_timbre, dim=2)
+ return outs, qs, commit_loss, quantized_buf, spk_embs
+
+ out = {}
+
+ layer_0 = quantized[0]
+ f0, uv = self.f0_predictor(layer_0)
+ f0 = rearrange(f0, "... 1 -> ...")
+ uv = rearrange(uv, "... 1 -> ...")
+
+ layer_1 = quantized[1]
+ (phone,) = self.phone_predictor(layer_1)
+
+ out = {"f0": f0, "uv": uv, "phone": phone}
+
+ if self.use_gr_prosody_phone:
+ (prosody_phone,) = self.prosody_phone_predictor(layer_0)
+ out["prosody_phone"] = prosody_phone
+
+ if self.use_gr_content_f0:
+ content_f0, content_uv = self.content_f0_predictor(layer_1)
+ content_f0 = rearrange(content_f0, "... 1 -> ...")
+ content_uv = rearrange(content_uv, "... 1 -> ...")
+ out["content_f0"] = content_f0
+ out["content_uv"] = content_uv
+
+ if self.vq_num_q_r > 0:
+ layer_2 = quantized[2]
+
+ if self.use_gr_residual_f0:
+ res_f0, res_uv = self.res_f0_predictor(layer_2)
+ res_f0 = rearrange(res_f0, "... 1 -> ...")
+ res_uv = rearrange(res_uv, "... 1 -> ...")
+ out["res_f0"] = res_f0
+ out["res_uv"] = res_uv
+
+ if self.use_gr_residual_phone:
+ (res_phone,) = self.res_phone_predictor(layer_2)
+ out["res_phone"] = res_phone
+
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ if self.vq_num_q_r > 0:
+ if self.use_random_mask_residual:
+ bsz = quantized[2].shape[0]
+ res_mask = np.random.choice(
+ [0, 1],
+ size=bsz,
+ p=[
+ self.prob_random_mask_residual,
+ 1 - self.prob_random_mask_residual,
+ ],
+ )
+ res_mask = (
+ torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
+ ) # (B, 1, 1)
+ res_mask = res_mask.to(
+ device=quantized[2].device, dtype=quantized[2].dtype
+ )
+ x = (
+ quantized[0].detach()
+ + quantized[1].detach()
+ + quantized[2] * res_mask
+ )
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
+ else:
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2]
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
+ else:
+ x = quantized[0].detach() + quantized[1].detach()
+ # x = quantized_perturbe[0].detach() + quantized[1].detach()
+
+ if self.use_gr_x_timbre:
+ (x_timbre,) = self.x_timbre_predictor(x)
+ out["x_timbre"] = x_timbre
+
+ x = x.transpose(1, 2)
+ x = self.timbre_norm(x)
+ x = x.transpose(1, 2)
+ x = x * gamma + beta
+
+ x = self.model(x)
+ out["audio"] = x
+
+ return out
+
+ def vq2emb(self, vq, use_residual_code=True):
+ # vq: [num_quantizer, B, T]
+ self.quantizer = self.quantizer.eval()
+ out = 0
+ out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
+ out += self.quantizer[1].vq2emb(
+ vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
+ )
+ if self.vq_num_q_r > 0 and use_residual_code:
+ out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
+ return out
+
+ def inference(self, x, speaker_embedding):
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ x = x.transpose(1, 2)
+ x = self.timbre_norm(x)
+ x = x.transpose(1, 2)
+ x = x * gamma + beta
+ x = self.model(x)
+ return x
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+class FACodecRedecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=256,
+ upsample_initial_channel=1280,
+ up_ratios=(5, 5, 4, 2),
+ vq_num_q_c=2,
+ vq_num_q_p=1,
+ vq_num_q_r=3,
+ vq_dim=256,
+ codebook_size_prosody=10,
+ codebook_size_content=10,
+ codebook_size_residual=10,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.up_ratios = up_ratios
+
+ self.vq_num_q_p = vq_num_q_p
+ self.vq_num_q_c = vq_num_q_c
+ self.vq_num_q_r = vq_num_q_r
+
+ self.vq_dim = vq_dim
+
+ self.codebook_size_prosody = codebook_size_prosody
+ self.codebook_size_content = codebook_size_content
+ self.codebook_size_residual = codebook_size_residual
+
+ self.prosody_embs = nn.ModuleList()
+ for i in range(self.vq_num_q_p):
+ emb_tokens = nn.Embedding(
+ num_embeddings=2**self.codebook_size_prosody,
+ embedding_dim=self.vq_dim,
+ )
+ emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+ self.prosody_embs.append(emb_tokens)
+ self.content_embs = nn.ModuleList()
+ for i in range(self.vq_num_q_c):
+ emb_tokens = nn.Embedding(
+ num_embeddings=2**self.codebook_size_content,
+ embedding_dim=self.vq_dim,
+ )
+ emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+ self.content_embs.append(emb_tokens)
+ self.residual_embs = nn.ModuleList()
+ for i in range(self.vq_num_q_r):
+ emb_tokens = nn.Embedding(
+ num_embeddings=2**self.codebook_size_residual,
+ embedding_dim=self.vq_dim,
+ )
+ emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+ self.residual_embs.append(emb_tokens)
+
+ # Add first conv layer
+ channels = upsample_initial_channel
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(up_ratios):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+ # Add final conv layer
+ layers += [
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+ self.timbre_linear.bias.data[:in_channels] = 1
+ self.timbre_linear.bias.data[in_channels:] = 0
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+ self.timbre_cond_prosody_enc = TransformerEncoder(
+ enc_emb_tokens=None,
+ encoder_layer=4,
+ encoder_hidden=256,
+ encoder_head=4,
+ conv_filter_size=1024,
+ conv_kernel_size=5,
+ encoder_dropout=0.1,
+ use_cln=True,
+ cfg=None,
+ )
+
+ def forward(
+ self,
+ vq,
+ speaker_embedding,
+ use_residual_code=False,
+ ):
+
+ x = 0
+
+ x_p = 0
+ for i in range(self.vq_num_q_p):
+ x_p = x_p + self.prosody_embs[i](vq[i]) # (B, T, d)
+ spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_p.shape[1], -1)
+ x_p = self.timbre_cond_prosody_enc(
+ x_p, key_padding_mask=None, condition=spk_cond
+ )
+ x = x + x_p
+
+ x_c = 0
+ for i in range(self.vq_num_q_c):
+ x_c = x_c + self.content_embs[i](vq[self.vq_num_q_p + i])
+
+ x = x + x_c
+
+ if use_residual_code:
+
+ x_r = 0
+ for i in range(self.vq_num_q_r):
+ x_r = x_r + self.residual_embs[i](
+ vq[self.vq_num_q_p + self.vq_num_q_c + i]
+ )
+ x = x + x_r
+
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ x = x.transpose(1, 2)
+ x = self.timbre_norm(x)
+ x = x.transpose(1, 2)
+ x = x * gamma + beta
+ x = self.model(x)
+
+ return x
+
+ def vq2emb(self, vq, speaker_embedding, use_residual=True):
+
+ out = 0
+
+ x_t = 0
+ for i in range(self.vq_num_q_p):
+ x_t += self.prosody_embs[i](vq[i]) # (B, T, d)
+ spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_t.shape[1], -1)
+ x_t = self.timbre_cond_prosody_enc(
+ x_t, key_padding_mask=None, condition=spk_cond
+ )
+
+ # prosody
+ out += x_t
+
+ # content
+ for i in range(self.vq_num_q_c):
+ out += self.content_embs[i](vq[self.vq_num_q_p + i])
+
+ # residual
+ if use_residual:
+ for i in range(self.vq_num_q_r):
+ out += self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
+
+ out = out.transpose(1, 2) # (B, T, d) -> (B, d, T)
+ return out
+
+ def inference(self, x, speaker_embedding):
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ x = x.transpose(1, 2)
+ x = self.timbre_norm(x)
+ x = x.transpose(1, 2)
+ x = x * gamma + beta
+ x = self.model(x)
+ return x
+
+
+class FACodecEncoderV2(nn.Module):
+ def __init__(
+ self,
+ ngf=32,
+ up_ratios=(2, 4, 5, 5),
+ out_channels=1024,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.up_ratios = up_ratios
+
+ # Create first convolution
+ d_model = ngf
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride in up_ratios:
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride)]
+
+ # Create last convolution
+ self.block += [
+ Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+ ]
+
+ # Wrap black into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ self.mel_transform = MelSpectrogram(
+ n_fft=1024,
+ num_mels=80,
+ sampling_rate=16000,
+ hop_size=200,
+ win_size=800,
+ fmin=0,
+ fmax=8000,
+ )
+
+ self.reset_parameters()
+
+ def forward(self, x):
+ out = self.block(x)
+ return out
+
+ def inference(self, x):
+ return self.block(x)
+
+ def get_prosody_feature(self, x):
+ return self.mel_transform(x.squeeze(1))[:, :20, :]
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
+
+
+class FACodecDecoderV2(nn.Module):
+ def __init__(
+ self,
+ in_channels=256,
+ upsample_initial_channel=1536,
+ ngf=32,
+ up_ratios=(5, 5, 4, 2),
+ vq_num_q_c=2,
+ vq_num_q_p=1,
+ vq_num_q_r=3,
+ vq_dim=1024,
+ vq_commit_weight=0.005,
+ vq_weight_init=False,
+ vq_full_commit_loss=False,
+ codebook_dim=8,
+ codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
+ codebook_size_content=10,
+ codebook_size_residual=10,
+ quantizer_dropout=0.0,
+ dropout_type="linear",
+ use_gr_content_f0=False,
+ use_gr_prosody_phone=False,
+ use_gr_residual_f0=False,
+ use_gr_residual_phone=False,
+ use_gr_x_timbre=False,
+ use_random_mask_residual=True,
+ prob_random_mask_residual=0.75,
+ ):
+ super().__init__()
+ self.hop_length = np.prod(up_ratios)
+ self.ngf = ngf
+ self.up_ratios = up_ratios
+
+ self.use_random_mask_residual = use_random_mask_residual
+ self.prob_random_mask_residual = prob_random_mask_residual
+
+ self.vq_num_q_p = vq_num_q_p
+ self.vq_num_q_c = vq_num_q_c
+ self.vq_num_q_r = vq_num_q_r
+
+ self.codebook_size_prosody = codebook_size_prosody
+ self.codebook_size_content = codebook_size_content
+ self.codebook_size_residual = codebook_size_residual
+
+ quantizer_class = ResidualVQ
+
+ self.quantizer = nn.ModuleList()
+
+ # prosody
+ quantizer = quantizer_class(
+ num_quantizers=vq_num_q_p,
+ dim=vq_dim,
+ codebook_size=codebook_size_prosody,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ quantizer_dropout=quantizer_dropout,
+ dropout_type=dropout_type,
+ )
+ self.quantizer.append(quantizer)
+
+ # phone
+ quantizer = quantizer_class(
+ num_quantizers=vq_num_q_c,
+ dim=vq_dim,
+ codebook_size=codebook_size_content,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ quantizer_dropout=quantizer_dropout,
+ dropout_type=dropout_type,
+ )
+ self.quantizer.append(quantizer)
+
+ # residual
+ if self.vq_num_q_r > 0:
+ quantizer = quantizer_class(
+ num_quantizers=vq_num_q_r,
+ dim=vq_dim,
+ codebook_size=codebook_size_residual,
+ codebook_dim=codebook_dim,
+ threshold_ema_dead_code=2,
+ commitment=vq_commit_weight,
+ weight_init=vq_weight_init,
+ full_commit_loss=vq_full_commit_loss,
+ quantizer_dropout=quantizer_dropout,
+ dropout_type=dropout_type,
+ )
+ self.quantizer.append(quantizer)
+
+ # Add first conv layer
+ channels = upsample_initial_channel
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(up_ratios):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+ # Add final conv layer
+ layers += [
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ self.timbre_encoder = TransformerEncoder(
+ enc_emb_tokens=None,
+ encoder_layer=4,
+ encoder_hidden=256,
+ encoder_head=4,
+ conv_filter_size=1024,
+ conv_kernel_size=5,
+ encoder_dropout=0.1,
+ use_cln=False,
+ )
+
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
+ self.timbre_linear.bias.data[:in_channels] = 1
+ self.timbre_linear.bias.data[in_channels:] = 0
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+ self.f0_predictor = CNNLSTM(in_channels, 1, 2)
+ self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
+
+ self.use_gr_content_f0 = use_gr_content_f0
+ self.use_gr_prosody_phone = use_gr_prosody_phone
+ self.use_gr_residual_f0 = use_gr_residual_f0
+ self.use_gr_residual_phone = use_gr_residual_phone
+ self.use_gr_x_timbre = use_gr_x_timbre
+
+ if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
+ self.res_f0_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+ )
+
+ if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
+ self.res_phone_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+ )
+
+ if self.use_gr_content_f0:
+ self.content_f0_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
+ )
+
+ if self.use_gr_prosody_phone:
+ self.prosody_phone_predictor = nn.Sequential(
+ GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
+ )
+
+ if self.use_gr_x_timbre:
+ self.x_timbre_predictor = nn.Sequential(
+ GradientReversal(alpha=1),
+ CNNLSTM(in_channels, 245200, 1, global_pred=True),
+ )
+
+ self.melspec_linear = nn.Linear(20, 256)
+ self.melspec_encoder = TransformerEncoder(
+ enc_emb_tokens=None,
+ encoder_layer=4,
+ encoder_hidden=256,
+ encoder_head=4,
+ conv_filter_size=1024,
+ conv_kernel_size=5,
+ encoder_dropout=0.1,
+ use_cln=False,
+ cfg=None,
+ )
+
+ self.reset_parameters()
+
+ def quantize(self, x, prosody_feature, n_quantizers=None):
+ outs, qs, commit_loss, quantized_buf = 0, [], [], []
+
+ # prosody
+ f0_input = prosody_feature.transpose(1, 2) # (B, T, 20)
+ f0_input = self.melspec_linear(f0_input)
+ f0_input = self.melspec_encoder(f0_input, None, None)
+ f0_input = f0_input.transpose(1, 2)
+ f0_quantizer = self.quantizer[0]
+ out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
+ outs += out
+ qs.append(q)
+ quantized_buf.append(quantized.sum(0))
+ commit_loss.append(commit)
+
+ # phone
+ phone_input = x
+ phone_quantizer = self.quantizer[1]
+ out, q, commit, quantized = phone_quantizer(
+ phone_input, n_quantizers=n_quantizers
+ )
+ outs += out
+ qs.append(q)
+ quantized_buf.append(quantized.sum(0))
+ commit_loss.append(commit)
+
+ # residual
+ if self.vq_num_q_r > 0:
+ residual_quantizer = self.quantizer[2]
+ residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
+ out, q, commit, quantized = residual_quantizer(
+ residual_input, n_quantizers=n_quantizers
+ )
+ outs += out
+ qs.append(q)
+ quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
+ commit_loss.append(commit)
+
+ qs = torch.cat(qs, dim=0)
+ commit_loss = torch.cat(commit_loss, dim=0)
+ return outs, qs, commit_loss, quantized_buf
+
+ def forward(
+ self,
+ x,
+ prosody_feature,
+ vq=True,
+ get_vq=False,
+ eval_vq=True,
+ speaker_embedding=None,
+ n_quantizers=None,
+ quantized=None,
+ ):
+ if get_vq:
+ return self.quantizer.get_emb()
+ if vq is True:
+ if eval_vq:
+ self.quantizer.eval()
+ x_timbre = x
+ outs, qs, commit_loss, quantized_buf = self.quantize(
+ x, prosody_feature, n_quantizers=n_quantizers
+ )
+
+ x_timbre = x_timbre.transpose(1, 2)
+ x_timbre = self.timbre_encoder(x_timbre, None, None)
+ x_timbre = x_timbre.transpose(1, 2)
+ spk_embs = torch.mean(x_timbre, dim=2)
+ return outs, qs, commit_loss, quantized_buf, spk_embs
+
+ out = {}
+
+ layer_0 = quantized[0]
+ f0, uv = self.f0_predictor(layer_0)
+ f0 = rearrange(f0, "... 1 -> ...")
+ uv = rearrange(uv, "... 1 -> ...")
+
+ layer_1 = quantized[1]
+ (phone,) = self.phone_predictor(layer_1)
+
+ out = {"f0": f0, "uv": uv, "phone": phone}
+
+ if self.use_gr_prosody_phone:
+ (prosody_phone,) = self.prosody_phone_predictor(layer_0)
+ out["prosody_phone"] = prosody_phone
+
+ if self.use_gr_content_f0:
+ content_f0, content_uv = self.content_f0_predictor(layer_1)
+ content_f0 = rearrange(content_f0, "... 1 -> ...")
+ content_uv = rearrange(content_uv, "... 1 -> ...")
+ out["content_f0"] = content_f0
+ out["content_uv"] = content_uv
+
+ if self.vq_num_q_r > 0:
+ layer_2 = quantized[2]
+
+ if self.use_gr_residual_f0:
+ res_f0, res_uv = self.res_f0_predictor(layer_2)
+ res_f0 = rearrange(res_f0, "... 1 -> ...")
+ res_uv = rearrange(res_uv, "... 1 -> ...")
+ out["res_f0"] = res_f0
+ out["res_uv"] = res_uv
+
+ if self.use_gr_residual_phone:
+ (res_phone,) = self.res_phone_predictor(layer_2)
+ out["res_phone"] = res_phone
+
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ if self.vq_num_q_r > 0:
+ if self.use_random_mask_residual:
+ bsz = quantized[2].shape[0]
+ res_mask = np.random.choice(
+ [0, 1],
+ size=bsz,
+ p=[
+ self.prob_random_mask_residual,
+ 1 - self.prob_random_mask_residual,
+ ],
+ )
+ res_mask = (
+ torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
+ ) # (B, 1, 1)
+ res_mask = res_mask.to(
+ device=quantized[2].device, dtype=quantized[2].dtype
+ )
+ x = (
+ quantized[0].detach()
+ + quantized[1].detach()
+ + quantized[2] * res_mask
+ )
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
+ else:
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2]
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
+ else:
+ x = quantized[0].detach() + quantized[1].detach()
+ # x = quantized_perturbe[0].detach() + quantized[1].detach()
+
+ if self.use_gr_x_timbre:
+ (x_timbre,) = self.x_timbre_predictor(x)
+ out["x_timbre"] = x_timbre
+
+ x = x.transpose(1, 2)
+ x = self.timbre_norm(x)
+ x = x.transpose(1, 2)
+ x = x * gamma + beta
+
+ x = self.model(x)
+ out["audio"] = x
+
+ return out
+
+ def vq2emb(self, vq, use_residual=True):
+ # vq: [num_quantizer, B, T]
+ self.quantizer = self.quantizer.eval()
+ out = 0
+ out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
+ out += self.quantizer[1].vq2emb(
+ vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
+ )
+ if self.vq_num_q_r > 0 and use_residual:
+ out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
+ return out
+
+ def inference(self, x, speaker_embedding):
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
+ x = x.transpose(1, 2)
+ x = self.timbre_norm(x)
+ x = x.transpose(1, 2)
+ x = x * gamma + beta
+ x = self.model(x)
+ return x
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m):
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+ torch.nn.utils.weight_norm(m)
+
+ self.apply(_apply_weight_norm)
+
+ def reset_parameters(self):
+ self.apply(init_weights)
diff --git a/models/codec/ns3_codec/gradient_reversal.py b/models/codec/ns3_codec/gradient_reversal.py
new file mode 100644
index 0000000000000000000000000000000000000000..d09396ea20c653b2a443e144ab429f534ce033fd
--- /dev/null
+++ b/models/codec/ns3_codec/gradient_reversal.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.autograd import Function
+import torch
+from torch import nn
+
+
+class GradientReversal(Function):
+ @staticmethod
+ def forward(ctx, x, alpha):
+ ctx.save_for_backward(x, alpha)
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = None
+ _, alpha = ctx.saved_tensors
+ if ctx.needs_input_grad[0]:
+ grad_input = -alpha * grad_output
+ return grad_input, None
+
+
+revgrad = GradientReversal.apply
+
+
+class GradientReversal(nn.Module):
+ def __init__(self, alpha):
+ super().__init__()
+ self.alpha = torch.tensor(alpha, requires_grad=False)
+
+ def forward(self, x):
+ return revgrad(x, self.alpha)
diff --git a/models/codec/ns3_codec/melspec.py b/models/codec/ns3_codec/melspec.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbf1cd24ebe533477be0833723b84b0b1d75c2d8
--- /dev/null
+++ b/models/codec/ns3_codec/melspec.py
@@ -0,0 +1,102 @@
+import torch
+import pyworld as pw
+import numpy as np
+import soundfile as sf
+import os
+from torchaudio.functional import pitch_shift
+import librosa
+from librosa.filters import mel as librosa_mel_fn
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+class MelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft,
+ num_mels,
+ sampling_rate,
+ hop_size,
+ win_size,
+ fmin,
+ fmax,
+ center=False,
+ ):
+ super(MelSpectrogram, self).__init__()
+ self.n_fft = n_fft
+ self.hop_size = hop_size
+ self.win_size = win_size
+ self.sampling_rate = sampling_rate
+ self.num_mels = num_mels
+ self.fmin = fmin
+ self.fmax = fmax
+ self.center = center
+
+ mel_basis = {}
+ hann_window = {}
+
+ mel = librosa_mel_fn(
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+ )
+ mel_basis = torch.from_numpy(mel).float()
+ hann_window = torch.hann_window(win_size)
+
+ self.register_buffer("mel_basis", mel_basis)
+ self.register_buffer("hann_window", hann_window)
+
+ def forward(self, y):
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ int((self.n_fft - self.hop_size) / 2),
+ int((self.n_fft - self.hop_size) / 2),
+ ),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_size,
+ win_length=self.win_size,
+ window=self.hann_window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.view_as_real(spec)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(self.mel_basis, spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
diff --git a/models/codec/ns3_codec/quantize/__init__.py b/models/codec/ns3_codec/quantize/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb7b4063ca2364ccc2658a8e19061fb65ddd7a7
--- /dev/null
+++ b/models/codec/ns3_codec/quantize/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .fvq import *
+from .rvq import *
diff --git a/models/codec/ns3_codec/quantize/fvq.py b/models/codec/ns3_codec/quantize/fvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ade35d889497a8d42034b6cf00aea48c92c5422
--- /dev/null
+++ b/models/codec/ns3_codec/quantize/fvq.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+class FactorizedVectorQuantize(nn.Module):
+ def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.commitment = commitment
+
+ if dim != self.codebook_dim:
+ self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
+ self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
+ else:
+ self.in_proj = nn.Identity()
+ self.out_proj = nn.Identity()
+ self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
+
+ @property
+ def codebook(self):
+ return self._codebook
+
+ def forward(self, z):
+ """Quantized the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+ # transpose since we use linear
+
+ z = rearrange(z, "b d t -> b t d")
+
+ # Factorized codes project input into low-dimensional space
+ z_e = self.in_proj(z) # z_e : (B x T x D)
+ z_e = rearrange(z_e, "b t d -> b d t")
+ z_q, indices = self.decode_latents(z_e)
+
+ if self.training:
+ commitment_loss = (
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ * self.commitment
+ )
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+ commit_loss = commitment_loss + codebook_loss
+ else:
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = rearrange(z_q, "b d t -> b t d")
+ z_q = self.out_proj(z_q)
+ z_q = rearrange(z_q, "b t d -> b d t")
+
+ return z_q, indices, commit_loss
+
+ def vq2emb(self, vq, proj=True):
+ emb = self.embed_code(vq)
+ if proj:
+ emb = self.out_proj(emb)
+ return emb.transpose(1, 2)
+
+ def get_emb(self):
+ return self.codebook.weight
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+ # L2 normalize encodings and codebook
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+ return z_q, indices
diff --git a/models/codec/ns3_codec/quantize/rvq.py b/models/codec/ns3_codec/quantize/rvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22d88d584df625234d865a63e0fdb709fdf77a2
--- /dev/null
+++ b/models/codec/ns3_codec/quantize/rvq.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+from torch import nn
+from .fvq import FactorizedVectorQuantize
+
+
+class ResidualVQ(nn.Module):
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
+
+ def __init__(self, *, num_quantizers, codebook_size, **kwargs):
+ super().__init__()
+ VQ = FactorizedVectorQuantize
+ if type(codebook_size) == int:
+ codebook_size = [codebook_size] * num_quantizers
+ self.layers = nn.ModuleList(
+ [VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
+ )
+ self.num_quantizers = num_quantizers
+ self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
+ self.dropout_type = kwargs.get("dropout_type", None)
+
+ def forward(self, x, n_quantizers=None):
+ quantized_out = 0.0
+ residual = x
+
+ all_losses = []
+ all_indices = []
+ all_quantized = []
+
+ if n_quantizers is None:
+ n_quantizers = self.num_quantizers
+ if self.training:
+ n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
+ if self.dropout_type == "linear":
+ dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
+ elif self.dropout_type == "exp":
+ dropout = torch.randint(
+ 1, int(math.log2(self.num_quantizers)), (x.shape[0],)
+ )
+ dropout = torch.pow(2, dropout)
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(x.device)
+
+ for idx, layer in enumerate(self.layers):
+ if not self.training and idx >= n_quantizers:
+ break
+ quantized, indices, loss = layer(residual)
+
+ mask = (
+ torch.full((x.shape[0],), fill_value=idx, device=x.device)
+ < n_quantizers
+ )
+
+ residual = residual - quantized
+
+ quantized_out = quantized_out + quantized * mask[:, None, None]
+
+ # loss
+ loss = (loss * mask).mean()
+
+ all_indices.append(indices)
+ all_losses.append(loss)
+ all_quantized.append(quantized)
+ all_losses, all_indices, all_quantized = map(
+ torch.stack, (all_losses, all_indices, all_quantized)
+ )
+ return quantized_out, all_indices, all_losses, all_quantized
+
+ def vq2emb(self, vq):
+ # vq: [n_quantizers, B, T]
+ quantized_out = 0.0
+ for idx, layer in enumerate(self.layers):
+ quantized = layer.vq2emb(vq[idx])
+ quantized_out += quantized
+ return quantized_out
+
+ def get_emb(self):
+ embs = []
+ for idx, layer in enumerate(self.layers):
+ embs.append(layer.get_emb())
+ return embs
diff --git a/models/codec/ns3_codec/transformer.py b/models/codec/ns3_codec/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..146d0f364dd17c385babb2c903f33378038556db
--- /dev/null
+++ b/models/codec/ns3_codec/transformer.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import math
+from torch.nn import functional as F
+
+
+class StyleAdaptiveLayerNorm(nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5):
+ super().__init__()
+ self.in_dim = normalized_shape
+ self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
+ self.style = nn.Linear(self.in_dim, self.in_dim * 2)
+ self.style.bias.data[: self.in_dim] = 1
+ self.style.bias.data[self.in_dim :] = 0
+
+ def forward(self, x, condition):
+ # x: (B, T, d); condition: (B, T, d)
+
+ style = self.style(torch.mean(condition, dim=1, keepdim=True))
+
+ gamma, beta = style.chunk(2, -1)
+
+ out = self.norm(x)
+
+ out = gamma * out + beta
+ return out
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout, max_len=5000):
+ super().__init__()
+
+ self.dropout = dropout
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+ )
+ pe = torch.zeros(max_len, 1, d_model)
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ x = x + self.pe[: x.size(0)]
+ return F.dropout(x, self.dropout, training=self.training)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(
+ self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
+ ):
+ super().__init__()
+
+ self.encoder_hidden = encoder_hidden
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+
+ self.ffn_1 = nn.Conv1d(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ padding=self.conv_kernel_size // 2,
+ )
+ self.ffn_1.weight.data.normal_(0.0, 0.02)
+ self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
+ self.ffn_2.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x):
+ # x: (B, T, d)
+ x = self.ffn_1(x.permute(0, 2, 1)).permute(
+ 0, 2, 1
+ ) # (B, T, d) -> (B, d, T) -> (B, T, d)
+ x = F.relu(x)
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ encoder_hidden,
+ encoder_head,
+ conv_filter_size,
+ conv_kernel_size,
+ encoder_dropout,
+ use_cln,
+ ):
+ super().__init__()
+ self.encoder_hidden = encoder_hidden
+ self.encoder_head = encoder_head
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+ self.use_cln = use_cln
+
+ if not self.use_cln:
+ self.ln_1 = nn.LayerNorm(self.encoder_hidden)
+ self.ln_2 = nn.LayerNorm(self.encoder_hidden)
+ else:
+ self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+ self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+
+ self.self_attn = nn.MultiheadAttention(
+ self.encoder_hidden, self.encoder_head, batch_first=True
+ )
+
+ self.ffn = TransformerFFNLayer(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ )
+
+ def forward(self, x, key_padding_mask, conditon=None):
+ # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
+
+ # self attention
+ residual = x
+ if self.use_cln:
+ x = self.ln_1(x, conditon)
+ else:
+ x = self.ln_1(x)
+
+ if key_padding_mask != None:
+ key_padding_mask_input = ~(key_padding_mask.bool())
+ else:
+ key_padding_mask_input = None
+ x, _ = self.self_attn(
+ query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
+ )
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+ x = residual + x
+
+ # ffn
+ residual = x
+ if self.use_cln:
+ x = self.ln_2(x, conditon)
+ else:
+ x = self.ln_2(x)
+ x = self.ffn(x)
+ x = residual + x
+
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ enc_emb_tokens=None,
+ encoder_layer=4,
+ encoder_hidden=256,
+ encoder_head=4,
+ conv_filter_size=1024,
+ conv_kernel_size=5,
+ encoder_dropout=0.1,
+ use_cln=False,
+ cfg=None,
+ ):
+ super().__init__()
+
+ self.encoder_layer = (
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
+ )
+ self.encoder_hidden = (
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+ )
+ self.encoder_head = (
+ encoder_head if encoder_head is not None else cfg.encoder_head
+ )
+ self.conv_filter_size = (
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+ )
+ self.conv_kernel_size = (
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+ )
+ self.encoder_dropout = (
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+ )
+ self.use_cln = use_cln if use_cln is not None else cfg.use_cln
+
+ if enc_emb_tokens != None:
+ self.use_enc_emb = True
+ self.enc_emb_tokens = enc_emb_tokens
+ else:
+ self.use_enc_emb = False
+
+ self.position_emb = PositionalEncoding(
+ self.encoder_hidden, self.encoder_dropout
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend(
+ [
+ TransformerEncoderLayer(
+ self.encoder_hidden,
+ self.encoder_head,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ self.use_cln,
+ )
+ for i in range(self.encoder_layer)
+ ]
+ )
+
+ if self.use_cln:
+ self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
+ else:
+ self.last_ln = nn.LayerNorm(self.encoder_hidden)
+
+ def forward(self, x, key_padding_mask, condition=None):
+ if len(x.shape) == 2 and self.use_enc_emb:
+ x = self.enc_emb_tokens(x)
+ x = self.position_emb(x)
+ else:
+ x = self.position_emb(x) # (B, T, d)
+
+ for layer in self.layers:
+ x = layer(x, key_padding_mask, condition)
+
+ if self.use_cln:
+ x = self.last_ln(x, condition)
+ else:
+ x = self.last_ln(x)
+
+ return x
diff --git a/models/codec/speechtokenizer/model.py b/models/codec/speechtokenizer/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b722d38653acdb348a40b1bd8ff3b94ceb2db563
--- /dev/null
+++ b/models/codec/speechtokenizer/model.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2023 Amphion.
+#
+# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
+# Licensed under Apache License 2.0
+
+from .modules.seanet import SEANetEncoder, SEANetDecoder
+from .modules.quantization import ResidualVectorQuantizer
+import torch.nn as nn
+from einops import rearrange
+import torch
+import numpy as np
+
+
+class SpeechTokenizer(nn.Module):
+ def __init__(self, config):
+ """
+
+ Parameters
+ ----------
+ config : json
+ Model Config.
+
+ """
+ super().__init__()
+ self.encoder = SEANetEncoder(
+ n_filters=config.get("n_filters"),
+ dimension=config.get("dimension"),
+ ratios=config.get("strides"),
+ lstm=config.get("lstm_layers"),
+ bidirectional=config.get("bidirectional"),
+ dilation_base=config.get("dilation_base"),
+ residual_kernel_size=config.get("residual_kernel_size"),
+ n_residual_layers=config.get("n_residual_layers"),
+ activation=config.get("activation"),
+ )
+ self.sample_rate = config.get("sample_rate")
+ self.n_q = config.get("n_q")
+ self.downsample_rate = np.prod(config.get("strides"))
+ if config.get("dimension") != config.get("semantic_dimension"):
+ self.transform = nn.Linear(
+ config.get("dimension"), config.get("semantic_dimension")
+ )
+ else:
+ self.transform = nn.Identity()
+ self.quantizer = ResidualVectorQuantizer(
+ dimension=config.get("dimension"),
+ n_q=config.get("n_q"),
+ bins=config.get("codebook_size"),
+ )
+ self.decoder = SEANetDecoder(
+ n_filters=config.get("n_filters"),
+ dimension=config.get("dimension"),
+ ratios=config.get("strides"),
+ lstm=config.get("lstm_layers"),
+ bidirectional=False,
+ dilation_base=config.get("dilation_base"),
+ residual_kernel_size=config.get("residual_kernel_size"),
+ n_residual_layers=config.get("n_residual_layers"),
+ activation=config.get("activation"),
+ )
+
+ @classmethod
+ def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
+ """
+
+ Parameters
+ ----------
+ config_path : str
+ Path of model configuration file.
+ ckpt_path : str
+ Path of model checkpoint.
+
+ Returns
+ -------
+ model : SpeechTokenizer
+ SpeechTokenizer model.
+
+ """
+ import json
+
+ with open(config_path) as f:
+ cfg = json.load(f)
+ model = cls(cfg)
+ params = torch.load(ckpt_path, map_location="cpu")
+ model.load_state_dict(params)
+ return model
+
+ def forward(self, x: torch.tensor, n_q: int = None, layers: list = [0]):
+ """
+
+ Parameters
+ ----------
+ x : torch.tensor
+ Input wavs. Shape: (batch, channels, timesteps).
+ n_q : int, optional
+ Number of quantizers in RVQ used to encode. The default is all layers.
+ layers : list[int], optional
+ Layers of RVQ should return quantized result. The default is the first layer.
+
+ Returns
+ -------
+ o : torch.tensor
+ Output wavs. Shape: (batch, channels, timesteps).
+ commit_loss : torch.tensor
+ Commitment loss from residual vector quantizers.
+ feature : torch.tensor
+ Output of RVQ's first layer. Shape: (batch, timesteps, dimension)
+
+ """
+ n_q = n_q if n_q else self.n_q
+ e = self.encoder(x)
+ quantized, codes, commit_loss, quantized_list = self.quantizer(
+ e, n_q=n_q, layers=layers
+ )
+ feature = rearrange(quantized_list[0], "b d t -> b t d")
+ feature = self.transform(feature)
+ o = self.decoder(quantized)
+ return o, commit_loss, feature
+
+ def forward_feature(self, x: torch.tensor, layers: list = None):
+ """
+
+ Parameters
+ ----------
+ x : torch.tensor
+ Input wavs. Shape should be (batch, channels, timesteps).
+ layers : list[int], optional
+ Layers of RVQ should return quantized result. The default is all layers.
+
+ Returns
+ -------
+ quantized_list : list[torch.tensor]
+ Quantized of required layers.
+
+ """
+ e = self.encoder(x)
+ layers = layers if layers else list(range(self.n_q))
+ quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers)
+ return quantized_list
+
+ def encode(self, x: torch.tensor, n_q: int = None, st: int = None):
+ """
+
+ Parameters
+ ----------
+ x : torch.tensor
+ Input wavs. Shape: (batch, channels, timesteps).
+ n_q : int, optional
+ Number of quantizers in RVQ used to encode. The default is all layers.
+ st : int, optional
+ Start quantizer index in RVQ. The default is 0.
+
+ Returns
+ -------
+ codes : torch.tensor
+ Output indices for each quantizer. Shape: (n_q, batch, timesteps)
+
+ """
+ e = self.encoder(x)
+ if st is None:
+ st = 0
+ n_q = n_q if n_q else self.n_q
+ codes = self.quantizer.encode(e, n_q=n_q, st=st)
+ return codes
+
+ def decode(self, codes: torch.tensor, st: int = 0):
+ """
+
+ Parameters
+ ----------
+ codes : torch.tensor
+ Indices for each quantizer. Shape: (n_q, batch, timesteps).
+ st : int, optional
+ Start quantizer index in RVQ. The default is 0.
+
+ Returns
+ -------
+ o : torch.tensor
+ Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
+
+ """
+ quantized = self.quantizer.decode(codes, st=st)
+ o = self.decoder(quantized)
+ return o
diff --git a/models/codec/speechtokenizer/modules/__init__.py b/models/codec/speechtokenizer/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0581347c1300a65bfb84e4ae581526cc6edcc1ca
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch modules."""
+
+# flake8: noqa
+from .conv import (
+ pad1d,
+ unpad1d,
+ NormConv1d,
+ NormConvTranspose1d,
+ NormConv2d,
+ NormConvTranspose2d,
+ SConv1d,
+ SConvTranspose1d,
+)
+from .lstm import SLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
diff --git a/models/codec/speechtokenizer/modules/conv.py b/models/codec/speechtokenizer/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..0352b8bfa322b0c166bb068fa18c3c3a46cb498e
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/conv.py
@@ -0,0 +1,346 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Convolutional layers wrappers and utilities."""
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .norm import ConvLayerNorm
+
+
+CONV_NORMALIZATIONS = frozenset(
+ [
+ "none",
+ "weight_norm",
+ "spectral_norm",
+ "time_layer_norm",
+ "layer_norm",
+ "time_group_norm",
+ ]
+)
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
+ assert norm in CONV_NORMALIZATIONS
+ if norm == "weight_norm":
+ return weight_norm(module)
+ elif norm == "spectral_norm":
+ return spectral_norm(module)
+ else:
+ # We already check was in CONV_NORMALIZATION, so any other choice
+ # doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
+) -> nn.Module:
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == "layer_norm":
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
+ elif norm == "time_group_norm":
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once you removed padding, we are missing one time step !
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(
+ x: torch.Tensor,
+ paddings: tp.Tuple[int, int],
+ mode: str = "zero",
+ value: float = 0.0,
+):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == "reflect":
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left:end]
+
+
+class NormConv1d(nn.Module):
+ """Wrapper around Conv1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ causal: bool = False,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose1d(nn.Module):
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ causal: bool = False,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(
+ nn.ConvTranspose1d(*args, **kwargs), norm
+ )
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose2d(nn.Module):
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(
+ nn.ConvTranspose2d(*args, **kwargs), norm
+ )
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class SConv1d(nn.Module):
+ """Conv1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ causal: bool = False,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ pad_mode: str = "reflect",
+ ):
+ super().__init__()
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ warnings.warn(
+ "SConv1d has been initialized with stride > 1 and dilation > 1"
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+ )
+ self.conv = NormConv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ causal=causal,
+ norm=norm,
+ norm_kwargs=norm_kwargs,
+ )
+ self.causal = causal
+ self.pad_mode = pad_mode
+
+ def forward(self, x):
+ B, C, T = x.shape
+ kernel_size = self.conv.conv.kernel_size[0]
+ stride = self.conv.conv.stride[0]
+ dilation = self.conv.conv.dilation[0]
+ padding_total = (kernel_size - 1) * dilation - (stride - 1)
+ extra_padding = get_extra_padding_for_conv1d(
+ x, kernel_size, stride, padding_total
+ )
+ if self.causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(
+ x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+ )
+ return self.conv(x)
+
+
+class SConvTranspose1d(nn.Module):
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ causal: bool = False,
+ norm: str = "none",
+ trim_right_ratio: float = 1.0,
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ ):
+ super().__init__()
+ self.convtr = NormConvTranspose1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ causal=causal,
+ norm=norm,
+ norm_kwargs=norm_kwargs,
+ )
+ self.causal = causal
+ self.trim_right_ratio = trim_right_ratio
+ assert (
+ self.causal or self.trim_right_ratio == 1.0
+ ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+ assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
+
+ def forward(self, x):
+ kernel_size = self.convtr.convtr.kernel_size[0]
+ stride = self.convtr.convtr.stride[0]
+ padding_total = kernel_size - stride
+
+ y = self.convtr(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
diff --git a/models/codec/speechtokenizer/modules/lstm.py b/models/codec/speechtokenizer/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f7e431238872e3175c5b379f69cc786bc0486a6
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/lstm.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+ """
+ LSTM without worrying about the hidden state, nor the layout of the data.
+ Expects input as convolutional layout.
+ """
+
+ def __init__(
+ self,
+ dimension: int,
+ num_layers: int = 2,
+ skip: bool = True,
+ bidirectional: bool = False,
+ ):
+ super().__init__()
+ self.bidirectional = bidirectional
+ self.skip = skip
+ self.lstm = nn.LSTM(
+ dimension, dimension, num_layers, bidirectional=bidirectional
+ )
+
+ def forward(self, x):
+ x = x.permute(2, 0, 1)
+ y, _ = self.lstm(x)
+ if self.bidirectional:
+ x = x.repeat(1, 1, 2)
+ if self.skip:
+ y = y + x
+ y = y.permute(1, 2, 0)
+ return y
diff --git a/models/codec/speechtokenizer/modules/norm.py b/models/codec/speechtokenizer/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5eaefd6b6103777d49c6fca2b071870371b7c5
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/norm.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+ """
+ Convolution-friendly LayerNorm that moves channels to last dimensions
+ before running the normalization and moves them back to original position right after.
+ """
+
+ def __init__(
+ self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
+ ):
+ super().__init__(normalized_shape, **kwargs)
+
+ def forward(self, x):
+ x = einops.rearrange(x, "b ... t -> b t ...")
+ x = super().forward(x)
+ x = einops.rearrange(x, "b t ... -> b ... t")
+ return
diff --git a/models/codec/speechtokenizer/modules/quantization/__init__.py b/models/codec/speechtokenizer/modules/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d90a1a2b310074fffdd0bd03cd3e60193a8de6
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
diff --git a/models/codec/speechtokenizer/modules/quantization/ac.py b/models/codec/speechtokenizer/modules/quantization/ac.py
new file mode 100644
index 0000000000000000000000000000000000000000..5695ea84451a110875c530558b93d9ea915500c7
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/ac.py
@@ -0,0 +1,317 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Arithmetic coder."""
+
+import io
+import math
+import random
+import typing as tp
+import torch
+
+from ..binary import BitPacker, BitUnpacker
+
+
+def build_stable_quantized_cdf(
+ pdf: torch.Tensor,
+ total_range_bits: int,
+ roundoff: float = 1e-8,
+ min_range: int = 2,
+ check: bool = True,
+) -> torch.Tensor:
+ """Turn the given PDF into a quantized CDF that splits
+ [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
+ to the PDF.
+
+ Args:
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
+ roundoff (float): will round the pdf up to that level to remove difference coming
+ from e.g. evaluating the Language Model on different architectures.
+ min_range (int): minimum range width. Should always be at least 2 for numerical
+ stability. Use this to avoid pathological behavior is a value
+ that is expected to be rare actually happens in real life.
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+ """
+ pdf = pdf.detach()
+ if roundoff:
+ pdf = (pdf / roundoff).floor() * roundoff
+ # interpolate with uniform distribution to achieve desired minimum probability.
+ total_range = 2**total_range_bits
+ cardinality = len(pdf)
+ alpha = min_range * cardinality / total_range
+ assert alpha <= 1, "you must reduce min_range"
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+ ranges += min_range
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
+ if min_range < 2:
+ raise ValueError("min_range must be at least 2.")
+ if check:
+ assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
+ if (
+ (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
+ ).any() or quantized_cdf[0] < min_range:
+ raise ValueError("You must increase your total_range_bits.")
+ return quantized_cdf
+
+
+class ArithmeticCoder:
+ """ArithmeticCoder,
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
+ sequence `(s_t)` by doing the following:
+
+ 1) Initialize the current range to` [0 ** 2 B - 1]`.
+ 2) For each time step t, split the current range into contiguous chunks,
+ one for each possible outcome, with size roughly proportional to `p`.
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+ would be `{[0, 2], [3, 3]}`.
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+ 4) When done encoding all the values, just select any value remaining in the range.
+
+ You will notice that this procedure can fail: for instance if at any point in time
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+ possible outcome. Intuitively, the more likely a value is, the less the range width
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+ coding scheme, likely outcomes would take less bits, and more of them can be coded
+ with a fixed budget.
+
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+ when the current range decreases below a given limit (given by `total_range_bits`), without
+ having to redo all the computations. If we encode mostly likely values, we will seldom
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+ code works for any sequence `(p_t)` possibly different for each timestep.
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+ the KL between the true distribution and `p_t`, the most efficient the coding will be.
+
+ Args:
+ fo (IO[bytes]): file-like object to which the bytes will be written to.
+ total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
+ Any time the current range width fall under this limit, new bits will
+ be injected to rescale the initial range.
+ """
+
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+ assert total_range_bits <= 30
+ self.total_range_bits = total_range_bits
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
+ self.low: int = 0
+ self.high: int = 0
+ self.max_bit: int = -1
+ self._dbg: tp.List[tp.Any] = []
+ self._dbg2: tp.List[tp.Any] = []
+
+ @property
+ def delta(self) -> int:
+ """Return the current range width."""
+ return self.high - self.low + 1
+
+ def _flush_common_prefix(self):
+ # If self.low and self.high start with the sames bits,
+ # those won't change anymore as we always just increase the range
+ # by powers of 2, and we can flush them out to the bit stream.
+ assert self.high >= self.low, (self.low, self.high)
+ assert self.high < 2 ** (self.max_bit + 1)
+ while self.max_bit >= 0:
+ b1 = self.low >> self.max_bit
+ b2 = self.high >> self.max_bit
+ if b1 == b2:
+ self.low -= b1 << self.max_bit
+ self.high -= b1 << self.max_bit
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
+ assert self.low >= 0
+ self.max_bit -= 1
+ self.packer.push(b1)
+ else:
+ break
+
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
+ """Push the given symbol on the stream, flushing out bits
+ if possible.
+
+ Args:
+ symbol (int): symbol to encode with the AC.
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate.
+ """
+ while self.delta < 2**self.total_range_bits:
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.max_bit += 1
+
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
+ range_high = quantized_cdf[symbol].item() - 1
+ effective_low = int(
+ math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
+ )
+ effective_high = int(
+ math.floor(range_high * (self.delta / (2**self.total_range_bits)))
+ )
+ assert self.low <= self.high
+ self.high = self.low + effective_high
+ self.low = self.low + effective_low
+ assert self.low <= self.high, (
+ effective_low,
+ effective_high,
+ range_low,
+ range_high,
+ )
+ self._dbg.append((self.low, self.high))
+ self._dbg2.append((self.low, self.high))
+ outs = self._flush_common_prefix()
+ assert self.low <= self.high
+ assert self.max_bit >= -1
+ assert self.max_bit <= 61, self.max_bit
+ return outs
+
+ def flush(self):
+ """Flush the remaining information to the stream."""
+ while self.max_bit >= 0:
+ b1 = (self.low >> self.max_bit) & 1
+ self.packer.push(b1)
+ self.max_bit -= 1
+ self.packer.flush()
+
+
+class ArithmeticDecoder:
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+
+ Note that this must be called with **exactly** the same parameters and sequence
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
+
+ If the AC encoder current range is [L, H], with `L` and `H` having the some common
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
+ For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+ `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+ for a specific sequence of symbols and a binary-search allows us to decode those symbols.
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+ and we will need to read new bits from the stream and repeat the process.
+
+ """
+
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+ self.total_range_bits = total_range_bits
+ self.low: int = 0
+ self.high: int = 0
+ self.current: int = 0
+ self.max_bit: int = -1
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
+ # Following is for debugging
+ self._dbg: tp.List[tp.Any] = []
+ self._dbg2: tp.List[tp.Any] = []
+ self._last: tp.Any = None
+
+ @property
+ def delta(self) -> int:
+ return self.high - self.low + 1
+
+ def _flush_common_prefix(self):
+ # Given the current range [L, H], if both have a common prefix,
+ # we know we can remove it from our representation to avoid handling large numbers.
+ while self.max_bit >= 0:
+ b1 = self.low >> self.max_bit
+ b2 = self.high >> self.max_bit
+ if b1 == b2:
+ self.low -= b1 << self.max_bit
+ self.high -= b1 << self.max_bit
+ self.current -= b1 << self.max_bit
+ assert self.high >= self.low
+ assert self.low >= 0
+ self.max_bit -= 1
+ else:
+ break
+
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
+ """Pull a symbol, reading as many bits from the stream as required.
+ This returns `None` when the stream has been exhausted.
+
+ Args:
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate. This must be **exatly**
+ the same cdf as the one used at encoding time.
+ """
+ while self.delta < 2**self.total_range_bits:
+ bit = self.unpacker.pull()
+ if bit is None:
+ return None
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.current = self.current * 2 + bit
+ self.max_bit += 1
+
+ def bin_search(low_idx: int, high_idx: int):
+ # Binary search is not just for coding interviews :)
+ if high_idx < low_idx:
+ raise RuntimeError("Binary search failed")
+ mid = (low_idx + high_idx) // 2
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
+ range_high = quantized_cdf[mid].item() - 1
+ effective_low = int(
+ math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
+ )
+ effective_high = int(
+ math.floor(range_high * (self.delta / (2**self.total_range_bits)))
+ )
+ low = effective_low + self.low
+ high = effective_high + self.low
+ if self.current >= low:
+ if self.current <= high:
+ return (mid, low, high, self.current)
+ else:
+ return bin_search(mid + 1, high_idx)
+ else:
+ return bin_search(low_idx, mid - 1)
+
+ self._last = (self.low, self.high, self.current, self.max_bit)
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
+ self._dbg.append((self.low, self.high, self.current))
+ self._flush_common_prefix()
+ self._dbg2.append((self.low, self.high, self.current))
+
+ return sym
+
+
+def test():
+ torch.manual_seed(1234)
+ random.seed(1234)
+ for _ in range(4):
+ pdfs = []
+ cardinality = random.randrange(4000)
+ steps = random.randrange(100, 500)
+ fo = io.BytesIO()
+ encoder = ArithmeticCoder(fo)
+ symbols = []
+ for step in range(steps):
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
+ pdfs.append(pdf)
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+ symbol = torch.multinomial(pdf, 1).item()
+ symbols.append(symbol)
+ encoder.push(symbol, q_cdf)
+ encoder.flush()
+
+ fo.seek(0)
+ decoder = ArithmeticDecoder(fo)
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+ decoded_symbol = decoder.pull(q_cdf)
+ assert decoded_symbol == symbol, idx
+ assert decoder.pull(torch.zeros(1)) is None
+
+
+if __name__ == "__main__":
+ test()
diff --git a/models/codec/speechtokenizer/modules/quantization/core_vq.py b/models/codec/speechtokenizer/modules/quantization/core_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..5799725598983bccb5c0644550f52303b15471c3
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/core_vq.py
@@ -0,0 +1,388 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This implementation is inspired from
+# https://github.com/lucidrains/vector-quantize-pytorch
+# which is released under MIT License. Hereafter, the original license:
+# MIT License
+#
+# Copyright (c) 2020 Phil Wang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Core vector quantization implementation."""
+import typing as tp
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .distrib import broadcast_tensors, rank
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+ return val if val is not None else d
+
+
+def ema_inplace(moving_avg, new, decay: float):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+ t = torch.empty(shape)
+ nn.init.kaiming_uniform_(t)
+ return t
+
+
+def sample_vectors(samples, num: int):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+ dim, dtype = samples.shape[-1], samples.dtype
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
+ dists = -(diffs**2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
+
+class EuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance.
+ Args:
+ dim (int): Dimension.
+ codebook_size (int): Codebook size.
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+ If set to true, run the k-means algorithm on the first training batch and use
+ the learned centroids as initialization.
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ kmeans_init: int = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.decay = decay
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
+ uniform_init if not kmeans_init else torch.zeros
+ )
+ embed = init_fn(codebook_size, dim)
+
+ self.codebook_size = codebook_size
+
+ self.kmeans_iters = kmeans_iters
+ self.epsilon = epsilon
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ @torch.jit.ignore
+ def init_embed_(self, data):
+ if self.inited:
+ return
+
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embed.data.copy_(embed)
+ self.embed_avg.data.copy_(embed.clone())
+ self.cluster_size.data.copy_(cluster_size)
+ self.inited.data.copy_(torch.Tensor([True]))
+ # Make sure all buffers across workers are in sync after initialization
+ # broadcast_tensors(self.buffers())
+
+ def replace_(self, samples, mask):
+ modified_codebook = torch.where(
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ )
+ self.embed.data.copy_(modified_codebook)
+
+ def expire_codes_(self, batch_samples):
+ if self.threshold_ema_dead_code == 0:
+ return
+
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
+ if not torch.any(expired_codes):
+ return
+
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self.replace_(batch_samples, mask=expired_codes)
+ # broadcast_tensors(self.buffers())
+
+ def preprocess(self, x):
+ x = rearrange(x, "... d -> (...) d")
+ return x
+
+ def quantize(self, x):
+ embed = self.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ embed_ind = dist.max(dim=-1).indices
+ return embed_ind
+
+ def postprocess_emb(self, embed_ind, shape):
+ return embed_ind.view(*shape[:-1])
+
+ def dequantize(self, embed_ind):
+ quantize = F.embedding(embed_ind, self.embed)
+ return quantize
+
+ def encode(self, x):
+ shape = x.shape
+ # pre-process
+ x = self.preprocess(x)
+ # quantize
+ embed_ind = self.quantize(x)
+ # post-process
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ return embed_ind
+
+ def decode(self, embed_ind):
+ quantize = self.dequantize(embed_ind)
+ return quantize
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ x = self.preprocess(x)
+
+ self.init_embed_(x)
+
+ embed_ind = self.quantize(x)
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ quantize = self.dequantize(embed_ind)
+
+ if self.training:
+ # We do the expiry of code at that point as buffers are in sync
+ # and all the workers will take the same decision.
+ self.expire_codes_(x)
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+ embed_sum = x.t() @ embed_onehot
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+ cluster_size = (
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+ * self.cluster_size.sum()
+ )
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+ self.embed.data.copy_(embed_normalized)
+
+ return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+ """Vector quantization implementation.
+ Currently supports only euclidean distance.
+ Args:
+ dim (int): Dimension
+ codebook_size (int): Codebook size
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ commitment_weight (float): Weight for commitment loss.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ codebook_dim: tp.Optional[int] = None,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 50,
+ threshold_ema_dead_code: int = 2,
+ commitment_weight: float = 1.0,
+ ):
+ super().__init__()
+ _codebook_dim: int = default(codebook_dim, dim)
+
+ requires_projection = _codebook_dim != dim
+ self.project_in = (
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
+ )
+ self.project_out = (
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
+ )
+
+ self.epsilon = epsilon
+ self.commitment_weight = commitment_weight
+
+ self._codebook = EuclideanCodebook(
+ dim=_codebook_dim,
+ codebook_size=codebook_size,
+ kmeans_init=kmeans_init,
+ kmeans_iters=kmeans_iters,
+ decay=decay,
+ epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code,
+ )
+ self.codebook_size = codebook_size
+
+ @property
+ def codebook(self):
+ return self._codebook.embed
+
+ def encode(self, x):
+ x = rearrange(x, "b d n -> b n d")
+ x = self.project_in(x)
+ embed_in = self._codebook.encode(x)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self._codebook.decode(embed_ind)
+ quantize = self.project_out(quantize)
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize
+
+ def forward(self, x):
+ device = x.device
+ x = rearrange(x, "b d n -> b n d")
+ x = self.project_in(x)
+
+ quantize, embed_ind = self._codebook(x)
+
+ if self.training:
+ quantize = x + (quantize - x).detach()
+
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+ if self.training:
+ if self.commitment_weight > 0:
+ commit_loss = F.mse_loss(quantize.detach(), x)
+ loss = loss + commit_loss * self.commitment_weight
+
+ quantize = self.project_out(quantize)
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+ """Residual vector quantization implementation.
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+
+ def __init__(self, *, num_quantizers, **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+ )
+
+ def forward(
+ self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
+ ):
+ quantized_out = 0.0
+ residual = x
+
+ all_losses = []
+ all_indices = []
+ out_quantized = []
+
+ n_q = n_q or len(self.layers)
+
+ for i, layer in enumerate(self.layers[:n_q]):
+ quantized, indices, loss = layer(residual)
+ residual = residual - quantized
+ quantized_out = quantized_out + quantized
+
+ all_indices.append(indices)
+ all_losses.append(loss)
+ if layers and i in layers:
+ out_quantized.append(quantized)
+
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, out_indices, out_losses, out_quantized
+
+ def encode(
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
+ ) -> torch.Tensor:
+ residual = x
+ all_indices = []
+ n_q = n_q or len(self.layers)
+ st = st or 0
+ for layer in self.layers[st:n_q]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
+ for i, indices in enumerate(q_indices):
+ layer = self.layers[st + i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
diff --git a/models/codec/speechtokenizer/modules/quantization/distrib.py b/models/codec/speechtokenizer/modules/quantization/distrib.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9a9b83e47cc3403354ce4c5e34eb0b279df5f2
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/distrib.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_rank()
+ else:
+ return 0
+
+
+def world_size():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_world_size()
+ else:
+ return 1
+
+
+def is_distributed():
+ return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+ if is_distributed():
+ return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+ # utility function to check that the number of params in all workers is the same,
+ # and thus avoid a deadlock with distributed all reduce.
+ if not is_distributed() or not params:
+ return
+ # print('params[0].device ', params[0].device)
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+ all_reduce(tensor)
+ if tensor.item() != len(params) * world_size():
+ # If not all the workers have the same number, for at least one of them,
+ # this inequality will be verified.
+ raise RuntimeError(
+ f"Mismatch in number of params: ours is {len(params)}, "
+ "at least one worker has a different one."
+ )
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+ """Broadcast the tensors from the given parameters to all workers.
+ This can be used to ensure that all workers have the same model to start with.
+ """
+ if not is_distributed():
+ return
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+ _check_number_of_params(tensors)
+ handles = []
+ for tensor in tensors:
+ # src = int(rank()) # added code
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+ handles.append(handle)
+ for handle in handles:
+ handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+ """
+ Sync grad for buffers. If average is False, broadcast instead of averaging.
+ """
+ if not is_distributed():
+ return
+ handles = []
+ for buffer in buffers:
+ if torch.is_floating_point(buffer.data):
+ if average:
+ handle = torch.distributed.all_reduce(
+ buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True
+ )
+ else:
+ handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
+ handles.append((buffer, handle))
+ for buffer, handle in handles:
+ handle.wait()
+ if average:
+ buffer.data /= world_size
+
+
+def sync_grad(params):
+ """
+ Simpler alternative to DistributedDataParallel, that doesn't rely
+ on any black magic. For simple models it can also be as fast.
+ Just call this on your model parameters after the call to backward!
+ """
+ if not is_distributed():
+ return
+ handles = []
+ for p in params:
+ if p.grad is not None:
+ handle = torch.distributed.all_reduce(
+ p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True
+ )
+ handles.append((p, handle))
+ for p, handle in handles:
+ handle.wait()
+ p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.0):
+ """Average a dictionary of metrics across all workers, using the optional
+ `count` as unormalized weight.
+ """
+ if not is_distributed():
+ return metrics
+ keys, values = zip(*metrics.items())
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+ tensor *= count
+ all_reduce(tensor)
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+ return dict(zip(keys, averaged))
diff --git a/models/codec/speechtokenizer/modules/quantization/vq.py b/models/codec/speechtokenizer/modules/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec7df0f9a0f58f2c757a710b682e10977ba58298
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/quantization/vq.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Residual vector quantizer implementation."""
+
+from dataclasses import dataclass, field
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from .core_vq import ResidualVectorQuantization
+
+
+@dataclass
+class QuantizedResult:
+ quantized: torch.Tensor
+ codes: torch.Tensor
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
+ penalty: tp.Optional[torch.Tensor] = None
+ metrics: dict = field(default_factory=dict)
+
+
+class ResidualVectorQuantizer(nn.Module):
+ """Residual Vector Quantizer.
+ Args:
+ dimension (int): Dimension of the codebooks.
+ n_q (int): Number of residual vector quantizers used.
+ bins (int): Codebook size.
+ decay (float): Decay for exponential moving average over the codebooks.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ """
+
+ def __init__(
+ self,
+ dimension: int = 256,
+ n_q: int = 8,
+ bins: int = 1024,
+ decay: float = 0.99,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 50,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.n_q = n_q
+ self.dimension = dimension
+ self.bins = bins
+ self.decay = decay
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+ self.vq = ResidualVectorQuantization(
+ dim=self.dimension,
+ codebook_size=self.bins,
+ num_quantizers=self.n_q,
+ decay=self.decay,
+ kmeans_init=self.kmeans_init,
+ kmeans_iters=self.kmeans_iters,
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ n_q: tp.Optional[int] = None,
+ layers: tp.Optional[list] = None,
+ ) -> QuantizedResult:
+ """Residual vector quantization on the given input tensor.
+ Args:
+ x (torch.Tensor): Input tensor.
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
+ layers (list): Layer that need to return quantized. Defalt: None.
+ Returns:
+ QuantizedResult:
+ The quantized (or approximately quantized) representation with
+ the associated numbert quantizers and layer quantized required to return.
+ """
+ n_q = n_q if n_q else self.n_q
+ if layers and max(layers) >= n_q:
+ raise ValueError(
+ f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
+ )
+ quantized, codes, commit_loss, quantized_list = self.vq(
+ x, n_q=n_q, layers=layers
+ )
+ return quantized, codes, torch.mean(commit_loss), quantized_list
+
+ def encode(
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
+ ) -> torch.Tensor:
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
+ The RVQ encode method sets the appropriate number of quantizer to use
+ and returns indices for each quantizer.
+ Args:
+ x (torch.Tensor): Input tensor.
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
+ st (int): Start to encode input from which layers. Default: 0.
+ """
+ n_q = n_q if n_q else self.n_q
+ st = st or 0
+ codes = self.vq.encode(x, n_q=n_q, st=st)
+ return codes
+
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ Args:
+ codes (torch.Tensor): Input indices for each quantizer.
+ st (int): Start to decode input codes from which layers. Default: 0.
+ """
+ quantized = self.vq.decode(codes, st=st)
+ return quantized
diff --git a/models/codec/speechtokenizer/modules/seanet.py b/models/codec/speechtokenizer/modules/seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..481de20c7ef05210e1bdf9092fe249060d06d686
--- /dev/null
+++ b/models/codec/speechtokenizer/modules/seanet.py
@@ -0,0 +1,414 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This source file is copied from https://github.com/facebookresearch/encodec
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Encodec SEANet-based encoder and decoder implementation."""
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+import torch
+
+from . import SConv1d, SConvTranspose1d, SLSTM
+
+
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
+
+
+class SEANetResnetBlock(nn.Module):
+ """Residual block from SEANet model.
+ Args:
+ dim (int): Dimension of the input/output
+ kernel_sizes (list): List of kernel sizes for the convolutions.
+ dilations (list): List of dilations for the convolutions.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3)
+ true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_sizes: tp.List[int] = [3, 1],
+ dilations: tp.List[int] = [1, 1],
+ activation: str = "ELU",
+ activation_params: dict = {"alpha": 1.0},
+ norm: str = "weight_norm",
+ norm_params: tp.Dict[str, tp.Any] = {},
+ causal: bool = False,
+ pad_mode: str = "reflect",
+ compress: int = 2,
+ true_skip: bool = True,
+ ):
+ super().__init__()
+ assert len(kernel_sizes) == len(
+ dilations
+ ), "Number of kernel sizes should match number of dilations"
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
+ hidden = dim // compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [
+ act(**activation_params) if activation != "Snake" else act(in_chs),
+ SConv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+ self.block = nn.Sequential(*block)
+ self.shortcut: nn.Module
+ if true_skip:
+ self.shortcut = nn.Identity()
+ else:
+ self.shortcut = SConv1d(
+ dim,
+ dim,
+ kernel_size=1,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ )
+
+ def forward(self, x):
+ return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+ """SEANet encoder.
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+ that must match the decoder order
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the initial convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ """
+
+ def __init__(
+ self,
+ channels: int = 1,
+ dimension: int = 128,
+ n_filters: int = 32,
+ n_residual_layers: int = 1,
+ ratios: tp.List[int] = [8, 5, 4, 2],
+ activation: str = "ELU",
+ activation_params: dict = {"alpha": 1.0},
+ norm: str = "weight_norm",
+ norm_params: tp.Dict[str, tp.Any] = {},
+ kernel_size: int = 7,
+ last_kernel_size: int = 7,
+ residual_kernel_size: int = 3,
+ dilation_base: int = 2,
+ causal: bool = False,
+ pad_mode: str = "reflect",
+ true_skip: bool = False,
+ compress: int = 2,
+ lstm: int = 2,
+ bidirectional: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.dimension = dimension
+ self.n_filters = n_filters
+ self.ratios = list(reversed(ratios))
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios) # 计算乘积
+
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
+ mult = 1
+ model: tp.List[nn.Module] = [
+ SConv1d(
+ channels,
+ mult * n_filters,
+ kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ )
+ ]
+ # Downsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(
+ mult * n_filters,
+ kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base**j, 1],
+ norm=norm,
+ norm_params=norm_params,
+ activation=activation,
+ activation_params=activation_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ compress=compress,
+ true_skip=true_skip,
+ )
+ ]
+
+ # Add downsampling layers
+ model += [
+ (
+ act(**activation_params)
+ if activation != "Snake"
+ else act(mult * n_filters)
+ ),
+ SConv1d(
+ mult * n_filters,
+ mult * n_filters * 2,
+ kernel_size=ratio * 2,
+ stride=ratio,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+ mult *= 2
+
+ if lstm:
+ model += [
+ SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
+ ]
+
+ mult = mult * 2 if bidirectional else mult
+ model += [
+ (
+ act(**activation_params)
+ if activation != "Snake"
+ else act(mult * n_filters)
+ ),
+ SConv1d(
+ mult * n_filters,
+ dimension,
+ last_kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+ """SEANet decoder.
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ final_activation (str): Final activation function after all convolutions.
+ final_activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the initial convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+ If equal to 1.0, it means that all the trimming is done at the right.
+ """
+
+ def __init__(
+ self,
+ channels: int = 1,
+ dimension: int = 128,
+ n_filters: int = 32,
+ n_residual_layers: int = 1,
+ ratios: tp.List[int] = [8, 5, 4, 2],
+ activation: str = "ELU",
+ activation_params: dict = {"alpha": 1.0},
+ final_activation: tp.Optional[str] = None,
+ final_activation_params: tp.Optional[dict] = None,
+ norm: str = "weight_norm",
+ norm_params: tp.Dict[str, tp.Any] = {},
+ kernel_size: int = 7,
+ last_kernel_size: int = 7,
+ residual_kernel_size: int = 3,
+ dilation_base: int = 2,
+ causal: bool = False,
+ pad_mode: str = "reflect",
+ true_skip: bool = False,
+ compress: int = 2,
+ lstm: int = 2,
+ trim_right_ratio: float = 1.0,
+ bidirectional: bool = False,
+ ):
+ super().__init__()
+ self.dimension = dimension
+ self.channels = channels
+ self.n_filters = n_filters
+ self.ratios = ratios
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
+ mult = int(2 ** len(self.ratios))
+ model: tp.List[nn.Module] = [
+ SConv1d(
+ dimension,
+ mult * n_filters,
+ kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ )
+ ]
+
+ if lstm:
+ model += [
+ SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
+ ]
+
+ # Upsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ # Add upsampling layers
+ model += [
+ (
+ act(**activation_params)
+ if activation != "Snake"
+ else act(mult * n_filters)
+ ),
+ SConvTranspose1d(
+ mult * n_filters,
+ mult * n_filters // 2,
+ kernel_size=ratio * 2,
+ stride=ratio,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ trim_right_ratio=trim_right_ratio,
+ ),
+ ]
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(
+ mult * n_filters // 2,
+ kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base**j, 1],
+ activation=activation,
+ activation_params=activation_params,
+ norm=norm,
+ norm_params=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ compress=compress,
+ true_skip=true_skip,
+ )
+ ]
+
+ mult //= 2
+
+ # Add final layers
+ model += [
+ act(**activation_params) if activation != "Snake" else act(n_filters),
+ SConv1d(
+ n_filters,
+ channels,
+ last_kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+ # Add optional final activation to decoder (eg. tanh)
+ if final_activation is not None:
+ final_act = getattr(nn, final_activation)
+ final_activation_params = final_activation_params or {}
+ model += [final_act(**final_activation_params)]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, z):
+ y = self.model(z)
+ return y
+
+
+def test():
+ import torch
+
+ encoder = SEANetEncoder()
+ decoder = SEANetDecoder()
+ x = torch.randn(1, 1, 24000)
+ z = encoder(x)
+ print("z ", z.shape)
+ assert 1 == 2
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+
+if __name__ == "__main__":
+ test()
diff --git a/models/svc/__init__.py b/models/svc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/base/__init__.py b/models/svc/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c2b1686db550b3b9892b8bc6e594cd847aafd1
--- /dev/null
+++ b/models/svc/base/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .svc_inference import SVCInference
+from .svc_trainer import SVCTrainer
diff --git a/models/svc/base/svc_dataset.py b/models/svc/base/svc_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8ef7d5adb557b84f3220fd1f45b1fc39de9a433
--- /dev/null
+++ b/models/svc/base/svc_dataset.py
@@ -0,0 +1,595 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+import json
+import os
+import numpy as np
+import librosa
+
+from utils.data_utils import *
+from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
+from processors.content_extractor import (
+ ContentvecExtractor,
+ WhisperExtractor,
+ WenetExtractor,
+)
+from models.base.base_dataset import (
+ BaseOfflineDataset,
+ BaseOfflineCollator,
+ BaseOnlineDataset,
+ BaseOnlineCollator,
+)
+from models.base.new_dataset import BaseTestDataset
+
+EPS = 1.0e-12
+
+
+class SVCOfflineDataset(BaseOfflineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
+
+ cfg = self.cfg
+
+ if cfg.model.condition_encoder.use_whisper:
+ self.whisper_aligner = WhisperExtractor(self.cfg)
+ self.utt2whisper_path = load_content_feature_path(
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
+ )
+
+ if cfg.model.condition_encoder.use_contentvec:
+ self.contentvec_aligner = ContentvecExtractor(self.cfg)
+ self.utt2contentVec_path = load_content_feature_path(
+ self.metadata,
+ cfg.preprocess.processed_dir,
+ cfg.preprocess.contentvec_dir,
+ )
+
+ if cfg.model.condition_encoder.use_mert:
+ self.utt2mert_path = load_content_feature_path(
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
+ )
+ if cfg.model.condition_encoder.use_wenet:
+ self.wenet_aligner = WenetExtractor(self.cfg)
+ self.utt2wenet_path = load_content_feature_path(
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
+ )
+
+ def __getitem__(self, index):
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if self.cfg.model.condition_encoder.use_whisper:
+ assert "target_len" in single_feature.keys()
+ aligned_whisper_feat = (
+ self.whisper_aligner.offline_resolution_transformation(
+ np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
+ )
+ )
+ single_feature["whisper_feat"] = aligned_whisper_feat
+
+ if self.cfg.model.condition_encoder.use_contentvec:
+ assert "target_len" in single_feature.keys()
+ aligned_contentvec = (
+ self.contentvec_aligner.offline_resolution_transformation(
+ np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
+ )
+ )
+ single_feature["contentvec_feat"] = aligned_contentvec
+
+ if self.cfg.model.condition_encoder.use_mert:
+ assert "target_len" in single_feature.keys()
+ aligned_mert_feat = align_content_feature_length(
+ np.load(self.utt2mert_path[utt]),
+ single_feature["target_len"],
+ source_hop=self.cfg.preprocess.mert_hop_size,
+ )
+ single_feature["mert_feat"] = aligned_mert_feat
+
+ if self.cfg.model.condition_encoder.use_wenet:
+ assert "target_len" in single_feature.keys()
+ aligned_wenet_feat = self.wenet_aligner.offline_resolution_transformation(
+ np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
+ )
+ single_feature["wenet_feat"] = aligned_wenet_feat
+
+ # print(single_feature.keys())
+ # for k, v in single_feature.items():
+ # if type(v) in [torch.Tensor, np.ndarray]:
+ # print(k, v.shape)
+ # else:
+ # print(k, v)
+ # exit()
+
+ return self.clip_if_too_long(single_feature)
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
+ """
+ ending_ts: to avoid invalid whisper features for over 30s audios
+ 2812 = 30 * 24000 // 256
+ """
+ ts = max(feature_seq_len - max_seq_len, 0)
+ ts = min(ts, ending_ts - max_seq_len)
+
+ start = random.randint(0, ts)
+ end = start + max_seq_len
+ return start, end
+
+ def clip_if_too_long(self, sample, max_seq_len=512):
+ """
+ sample :
+ {
+ 'spk_id': (1,),
+ 'target_len': int
+ 'mel': (seq_len, dim),
+ 'frame_pitch': (seq_len,)
+ 'frame_energy': (seq_len,)
+ 'content_vector_feat': (seq_len, dim)
+ }
+ """
+
+ if sample["target_len"] <= max_seq_len:
+ return sample
+
+ start, end = self.random_select(sample["target_len"], max_seq_len)
+ sample["target_len"] = end - start
+
+ for k in sample.keys():
+ if k == "audio":
+ # audio should be clipped in hop_size scale
+ sample[k] = sample[k][
+ start
+ * self.cfg.preprocess.hop_size : end
+ * self.cfg.preprocess.hop_size
+ ]
+ elif k == "audio_len":
+ sample[k] = (end - start) * self.cfg.preprocess.hop_size
+ elif k not in ["spk_id", "target_len"]:
+ sample[k] = sample[k][start:end]
+
+ return sample
+
+
+class SVCOnlineDataset(BaseOnlineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ super().__init__(cfg, dataset, is_valid=is_valid)
+
+ # Audio pretrained models' sample rates
+ self.all_sample_rates = {self.sample_rate}
+ if self.cfg.model.condition_encoder.use_whisper:
+ self.all_sample_rates.add(self.cfg.preprocess.whisper_sample_rate)
+ if self.cfg.model.condition_encoder.use_contentvec:
+ self.all_sample_rates.add(self.cfg.preprocess.contentvec_sample_rate)
+ if self.cfg.model.condition_encoder.use_wenet:
+ self.all_sample_rates.add(self.cfg.preprocess.wenet_sample_rate)
+
+ self.highest_sample_rate = max(list(self.all_sample_rates))
+
+ # The maximum duration (seconds) for one training sample
+ self.max_duration = 6.0
+ self.max_n_frames = int(self.max_duration * self.highest_sample_rate)
+
+ def random_select(self, wav, duration, wav_path):
+ """
+ wav: (T,)
+ """
+ if duration <= self.max_duration:
+ return wav
+
+ ts_frame = int((duration - self.max_duration) * self.highest_sample_rate)
+ start = random.randint(0, ts_frame)
+ end = start + self.max_n_frames
+
+ if (wav[start:end] == 0).all():
+ print("*" * 20)
+ print("Warning! The wav file {} has a lot of silience.".format(wav_path))
+
+ # There should be at least some frames that are not silience. Then we select them.
+ assert (wav != 0).any()
+ start = np.where(wav != 0)[0][0]
+ end = start + self.max_n_frames
+
+ return wav[start:end]
+
+ def __getitem__(self, index):
+ """
+ single_feature: dict,
+ wav: (T,)
+ wav_len: int
+ target_len: int
+ mask: (n_frames, 1)
+ spk_id
+
+ wav_{sr}: (T,)
+ wav_{sr}_len: int
+ """
+ single_feature = dict()
+
+ utt_item = self.metadata[index]
+ wav_path = utt_item["Path"]
+
+ ### Use the highest sampling rate to load and randomly select ###
+ highest_sr_wav, _ = librosa.load(wav_path, sr=self.highest_sample_rate)
+ highest_sr_wav = self.random_select(
+ highest_sr_wav, utt_item["Duration"], wav_path
+ )
+
+ ### Waveforms under all the sample rates ###
+ for sr in self.all_sample_rates:
+ # Resample to the required sample rate
+ if sr != self.highest_sample_rate:
+ wav_sr = librosa.resample(
+ highest_sr_wav, orig_sr=self.highest_sample_rate, target_sr=sr
+ )
+ else:
+ wav_sr = highest_sr_wav
+
+ wav_sr = torch.as_tensor(wav_sr, dtype=torch.float32)
+ single_feature["wav_{}".format(sr)] = wav_sr
+ single_feature["wav_{}_len".format(sr)] = len(wav_sr)
+
+ # For target sample rate
+ if sr == self.sample_rate:
+ wav_len = len(wav_sr)
+ frame_len = wav_len // self.hop_size
+
+ single_feature["wav"] = wav_sr
+ single_feature["wav_len"] = wav_len
+ single_feature["target_len"] = frame_len
+ single_feature["mask"] = torch.ones(frame_len, 1, dtype=torch.long)
+
+ ### Speaker ID ###
+ if self.cfg.preprocess.use_spkid:
+ utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
+ single_feature["spk_id"] = torch.tensor(
+ [self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
+ )
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class SVCOfflineCollator(BaseOfflineCollator):
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ def __call__(self, batch):
+ parsed_batch_features = super().__call__(batch)
+ return parsed_batch_features
+
+
+class SVCOnlineCollator(BaseOnlineCollator):
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ def __call__(self, batch):
+ """
+ SVCOnlineDataset.__getitem__:
+ wav: (T,)
+ wav_len: int
+ target_len: int
+ mask: (n_frames, 1)
+ spk_id: (1)
+
+ wav_{sr}: (T,)
+ wav_{sr}_len: int
+
+ Returns:
+ wav: (B, T), torch.float32
+ wav_len: (B), torch.long
+ target_len: (B), torch.long
+ mask: (B, n_frames, 1), torch.long
+ spk_id: (B, 1), torch.int32
+
+ wav_{sr}: (B, T)
+ wav_{sr}_len: (B), torch.long
+ """
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if "_len" in key:
+ packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
+ else:
+ packed_batch_features[key] = pad_sequence(
+ [b[key] for b in batch], batch_first=True, padding_value=0
+ )
+ return packed_batch_features
+
+
+class SVCTestDataset(BaseTestDataset):
+ def __init__(self, args, cfg, infer_type):
+ BaseTestDataset.__init__(self, args, cfg, infer_type)
+ self.metadata = self.get_metadata()
+
+ target_singer = args.target_singer
+ self.cfg = cfg
+ self.trans_key = args.trans_key
+ assert type(target_singer) == str
+
+ self.target_singer = target_singer.split("_")[-1]
+ self.target_dataset = target_singer.replace(
+ "_{}".format(self.target_singer), ""
+ )
+ if cfg.preprocess.mel_min_max_norm:
+ if self.cfg.preprocess.features_extraction_mode == "online":
+ # TODO: Change the hard code
+
+ # Using an empirical mel extrema to normalize
+ self.target_mel_extrema = load_mel_extrema(cfg.preprocess, "vctk")
+ else:
+ self.target_mel_extrema = load_mel_extrema(
+ cfg.preprocess, self.target_dataset
+ )
+
+ self.target_mel_extrema = torch.as_tensor(
+ self.target_mel_extrema[0]
+ ), torch.as_tensor(self.target_mel_extrema[1])
+
+ ######### Load source acoustic features #########
+ if cfg.preprocess.use_spkid:
+ spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)
+ # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk)
+
+ with open(spk2id_path, "r", encoding="utf-8") as f:
+ self.spk2id = json.load(f)
+ # print("self.spk2id", self.spk2id)
+
+ if cfg.preprocess.use_uv:
+ self.utt2uv_path = {
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+ cfg.preprocess.processed_dir,
+ utt_info["Dataset"],
+ cfg.preprocess.uv_dir,
+ utt_info["Uid"] + ".npy",
+ )
+ for utt_info in self.metadata
+ }
+
+ if cfg.preprocess.use_frame_pitch:
+ self.utt2frame_pitch_path = {
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+ cfg.preprocess.processed_dir,
+ utt_info["Dataset"],
+ cfg.preprocess.pitch_dir,
+ utt_info["Uid"] + ".npy",
+ )
+ for utt_info in self.metadata
+ }
+
+ # Target F0 median
+ target_f0_statistics_path = os.path.join(
+ cfg.preprocess.processed_dir,
+ self.target_dataset,
+ cfg.preprocess.pitch_dir,
+ "statistics.json",
+ )
+ self.target_pitch_median = json.load(
+ open(target_f0_statistics_path, "r", encoding="utf-8")
+ )[f"{self.target_dataset}_{self.target_singer}"]["voiced_positions"][
+ "median"
+ ]
+
+ # Source F0 median (if infer from file)
+ if infer_type == "from_file":
+ source_audio_name = cfg.inference.source_audio_name
+ source_f0_statistics_path = os.path.join(
+ cfg.preprocess.processed_dir,
+ source_audio_name,
+ cfg.preprocess.pitch_dir,
+ "statistics.json",
+ )
+ self.source_pitch_median = json.load(
+ open(source_f0_statistics_path, "r", encoding="utf-8")
+ )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][
+ "median"
+ ]
+ else:
+ self.source_pitch_median = None
+
+ if cfg.preprocess.use_frame_energy:
+ self.utt2frame_energy_path = {
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+ cfg.preprocess.processed_dir,
+ utt_info["Dataset"],
+ cfg.preprocess.energy_dir,
+ utt_info["Uid"] + ".npy",
+ )
+ for utt_info in self.metadata
+ }
+
+ if cfg.preprocess.use_mel:
+ self.utt2mel_path = {
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+ cfg.preprocess.processed_dir,
+ utt_info["Dataset"],
+ cfg.preprocess.mel_dir,
+ utt_info["Uid"] + ".npy",
+ )
+ for utt_info in self.metadata
+ }
+
+ ######### Load source content features' path #########
+ if cfg.model.condition_encoder.use_whisper:
+ self.whisper_aligner = WhisperExtractor(cfg)
+ self.utt2whisper_path = load_content_feature_path(
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
+ )
+
+ if cfg.model.condition_encoder.use_contentvec:
+ self.contentvec_aligner = ContentvecExtractor(cfg)
+ self.utt2contentVec_path = load_content_feature_path(
+ self.metadata,
+ cfg.preprocess.processed_dir,
+ cfg.preprocess.contentvec_dir,
+ )
+
+ if cfg.model.condition_encoder.use_mert:
+ self.utt2mert_path = load_content_feature_path(
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
+ )
+ if cfg.model.condition_encoder.use_wenet:
+ self.wenet_aligner = WenetExtractor(cfg)
+ self.utt2wenet_path = load_content_feature_path(
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
+ )
+
+ def __getitem__(self, index):
+ single_feature = {}
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ source_dataset = self.metadata[index]["Dataset"]
+
+ if self.cfg.preprocess.use_spkid:
+ single_feature["spk_id"] = np.array(
+ [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
+ dtype=np.int32,
+ )
+
+ ######### Get Acoustic Features Item #########
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
+ if self.cfg.preprocess.use_min_max_norm_mel:
+ # mel norm
+ mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess)
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+ single_feature["mel"] = mel.T # [T, n_mels]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
+ frame_pitch = np.load(frame_pitch_path)
+
+ if self.trans_key:
+ try:
+ self.trans_key = int(self.trans_key)
+ except:
+ pass
+ if type(self.trans_key) == int:
+ frame_pitch = transpose_key(frame_pitch, self.trans_key)
+ elif self.trans_key:
+ assert self.target_singer
+
+ frame_pitch = pitch_shift_to_target(
+ frame_pitch, self.target_pitch_median, self.source_pitch_median
+ )
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_uv:
+ frame_uv_path = self.utt2uv_path[utt]
+ frame_uv = np.load(frame_uv_path)
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
+ aligned_frame_uv = [
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
+ ]
+ aligned_frame_uv = np.array(aligned_frame_uv)
+ single_feature["frame_uv"] = aligned_frame_uv
+
+ if self.cfg.preprocess.use_frame_energy:
+ frame_energy_path = self.utt2frame_energy_path[utt]
+ frame_energy = np.load(frame_energy_path)
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_energy)
+ aligned_frame_energy = align_length(
+ frame_energy, single_feature["target_len"]
+ )
+ single_feature["frame_energy"] = aligned_frame_energy
+
+ ######### Get Content Features Item #########
+ if self.cfg.model.condition_encoder.use_whisper:
+ assert "target_len" in single_feature.keys()
+ aligned_whisper_feat = (
+ self.whisper_aligner.offline_resolution_transformation(
+ np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
+ )
+ )
+ single_feature["whisper_feat"] = aligned_whisper_feat
+
+ if self.cfg.model.condition_encoder.use_contentvec:
+ assert "target_len" in single_feature.keys()
+ aligned_contentvec = (
+ self.contentvec_aligner.offline_resolution_transformation(
+ np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
+ )
+ )
+ single_feature["contentvec_feat"] = aligned_contentvec
+
+ if self.cfg.model.condition_encoder.use_mert:
+ assert "target_len" in single_feature.keys()
+ aligned_mert_feat = align_content_feature_length(
+ np.load(self.utt2mert_path[utt]),
+ single_feature["target_len"],
+ source_hop=self.cfg.preprocess.mert_hop_size,
+ )
+ single_feature["mert_feat"] = aligned_mert_feat
+
+ if self.cfg.model.condition_encoder.use_wenet:
+ assert "target_len" in single_feature.keys()
+ aligned_wenet_feat = self.wenet_aligner.offline_resolution_transformation(
+ np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
+ )
+ single_feature["wenet_feat"] = aligned_wenet_feat
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class SVCTestCollator:
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/svc/base/svc_inference.py b/models/svc/base/svc_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f88d5d915e1616292c03927b4f51557351f58b
--- /dev/null
+++ b/models/svc/base/svc_inference.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from models.base.new_inference import BaseInference
+from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset
+
+
+class SVCInference(BaseInference):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ BaseInference.__init__(self, args, cfg, infer_type)
+
+ def _build_test_dataset(self):
+ return SVCTestDataset, SVCTestCollator
diff --git a/models/svc/base/svc_trainer.py b/models/svc/base/svc_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2d8f69c115dbb09ee4cde1968ece1c7084d7394
--- /dev/null
+++ b/models/svc/base/svc_trainer.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+from models.base.new_trainer import BaseTrainer
+from models.svc.base.svc_dataset import (
+ SVCOfflineCollator,
+ SVCOfflineDataset,
+ SVCOnlineCollator,
+ SVCOnlineDataset,
+)
+from processors.audio_features_extractor import AudioFeaturesExtractor
+from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
+
+EPS = 1.0e-12
+
+
+class SVCTrainer(BaseTrainer):
+ r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
+ ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
+ class, and implement ``_build_model``, ``_forward_step``.
+ """
+
+ def __init__(self, args=None, cfg=None):
+ self.args = args
+ self.cfg = cfg
+
+ self._init_accelerator()
+
+ # Only for SVC tasks
+ with self.accelerator.main_process_first():
+ self.singers = self._build_singer_lut()
+
+ # Super init
+ BaseTrainer.__init__(self, args, cfg)
+
+ # Only for SVC tasks
+ self.task_type = "SVC"
+ self.logger.info("Task type: {}".format(self.task_type))
+
+ ### Following are methods only for SVC tasks ###
+ def _build_dataset(self):
+ self.online_features_extraction = (
+ self.cfg.preprocess.features_extraction_mode == "online"
+ )
+
+ if not self.online_features_extraction:
+ return SVCOfflineDataset, SVCOfflineCollator
+ else:
+ self.audio_features_extractor = AudioFeaturesExtractor(self.cfg)
+ return SVCOnlineDataset, SVCOnlineCollator
+
+ def _extract_svc_features(self, batch):
+ """
+ Features extraction during training
+
+ Batch:
+ wav: (B, T)
+ wav_len: (B)
+ target_len: (B)
+ mask: (B, n_frames, 1)
+ spk_id: (B, 1)
+
+ wav_{sr}: (B, T)
+ wav_{sr}_len: (B)
+
+ Added elements when output:
+ mel: (B, n_frames, n_mels)
+ frame_pitch: (B, n_frames)
+ frame_uv: (B, n_frames)
+ frame_energy: (B, n_frames)
+ frame_{content}: (B, n_frames, D)
+ """
+
+ padded_n_frames = torch.max(batch["target_len"])
+ final_n_frames = padded_n_frames
+
+ ### Mel Spectrogram ###
+ if self.cfg.preprocess.use_mel:
+ # (B, n_mels, n_frames)
+ raw_mel = self.audio_features_extractor.get_mel_spectrogram(batch["wav"])
+ if self.cfg.preprocess.use_min_max_norm_mel:
+ # TODO: Change the hard code
+
+ # Using the empirical mel extrema to denormalize
+ if not hasattr(self, "mel_extrema"):
+ # (n_mels)
+ m, M = load_mel_extrema(self.cfg.preprocess, "vctk")
+ # (1, n_mels, 1)
+ m = (
+ torch.as_tensor(m, device=raw_mel.device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ )
+ M = (
+ torch.as_tensor(M, device=raw_mel.device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ )
+ self.mel_extrema = m, M
+
+ m, M = self.mel_extrema
+ mel = (raw_mel - m) / (M - m + EPS) * 2 - 1
+
+ else:
+ mel = raw_mel
+
+ final_n_frames = min(final_n_frames, mel.size(-1))
+
+ # (B, n_frames, n_mels)
+ batch["mel"] = mel.transpose(1, 2)
+ else:
+ raw_mel = None
+
+ ### F0 ###
+ if self.cfg.preprocess.use_frame_pitch:
+ # (B, n_frames)
+ raw_f0, raw_uv = self.audio_features_extractor.get_f0(
+ batch["wav"],
+ wav_lens=batch["wav_len"],
+ use_interpolate=self.cfg.preprocess.use_interpolation_for_uv,
+ return_uv=True,
+ )
+ final_n_frames = min(final_n_frames, raw_f0.size(-1))
+ batch["frame_pitch"] = raw_f0
+
+ if self.cfg.preprocess.use_uv:
+ batch["frame_uv"] = raw_uv
+
+ ### Energy ###
+ if self.cfg.preprocess.use_frame_energy:
+ # (B, n_frames)
+ raw_energy = self.audio_features_extractor.get_energy(
+ batch["wav"], mel_spec=raw_mel
+ )
+ final_n_frames = min(final_n_frames, raw_energy.size(-1))
+ batch["frame_energy"] = raw_energy
+
+ ### Semantic Features ###
+ if self.cfg.model.condition_encoder.use_whisper:
+ # (B, n_frames, D)
+ whisper_feats = self.audio_features_extractor.get_whisper_features(
+ wavs=batch["wav_{}".format(self.cfg.preprocess.whisper_sample_rate)],
+ target_frame_len=padded_n_frames,
+ )
+ final_n_frames = min(final_n_frames, whisper_feats.size(1))
+ batch["whisper_feat"] = whisper_feats
+
+ if self.cfg.model.condition_encoder.use_contentvec:
+ # (B, n_frames, D)
+ contentvec_feats = self.audio_features_extractor.get_contentvec_features(
+ wavs=batch["wav_{}".format(self.cfg.preprocess.contentvec_sample_rate)],
+ target_frame_len=padded_n_frames,
+ )
+ final_n_frames = min(final_n_frames, contentvec_feats.size(1))
+ batch["contentvec_feat"] = contentvec_feats
+
+ if self.cfg.model.condition_encoder.use_wenet:
+ # (B, n_frames, D)
+ wenet_feats = self.audio_features_extractor.get_wenet_features(
+ wavs=batch["wav_{}".format(self.cfg.preprocess.wenet_sample_rate)],
+ target_frame_len=padded_n_frames,
+ wav_lens=batch[
+ "wav_{}_len".format(self.cfg.preprocess.wenet_sample_rate)
+ ],
+ )
+ final_n_frames = min(final_n_frames, wenet_feats.size(1))
+ batch["wenet_feat"] = wenet_feats
+
+ ### Align all the audio features to the same frame length ###
+ frame_level_features = [
+ "mask",
+ "mel",
+ "frame_pitch",
+ "frame_uv",
+ "frame_energy",
+ "whisper_feat",
+ "contentvec_feat",
+ "wenet_feat",
+ ]
+ for k in frame_level_features:
+ if k in batch:
+ # (B, n_frames, ...)
+ batch[k] = batch[k][:, :final_n_frames].contiguous()
+
+ return batch
+
+ @staticmethod
+ def _build_criterion():
+ criterion = nn.MSELoss(reduction="none")
+ return criterion
+
+ @staticmethod
+ def _compute_loss(criterion, y_pred, y_gt, loss_mask):
+ """
+ Args:
+ criterion: MSELoss(reduction='none')
+ y_pred, y_gt: (B, seq_len, D)
+ loss_mask: (B, seq_len, 1)
+ Returns:
+ loss: Tensor of shape []
+ """
+
+ # (B, seq_len, D)
+ loss = criterion(y_pred, y_gt)
+ # expand loss_mask to (B, seq_len, D)
+ loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
+
+ loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
+ return loss
+
+ def _save_auxiliary_states(self):
+ """
+ To save the singer's look-up table in the checkpoint saving path
+ """
+ with open(
+ os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id),
+ "w",
+ encoding="utf-8",
+ ) as f:
+ json.dump(self.singers, f, indent=4, ensure_ascii=False)
+
+ def _build_singer_lut(self):
+ resumed_singer_path = None
+ if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
+ resumed_singer_path = os.path.join(
+ self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
+ )
+ if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
+ resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+
+ if resumed_singer_path:
+ with open(resumed_singer_path, "r") as f:
+ singers = json.load(f)
+ else:
+ singers = dict()
+
+ for dataset in self.cfg.dataset:
+ singer_lut_path = os.path.join(
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
+ )
+ with open(singer_lut_path, "r") as singer_lut_path:
+ singer_lut = json.load(singer_lut_path)
+ for singer in singer_lut.keys():
+ if singer not in singers:
+ singers[singer] = len(singers)
+
+ with open(
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
+ ) as singer_file:
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
+ print(
+ "singers have been dumped to {}".format(
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+ )
+ )
+ return singers
diff --git a/models/svc/comosvc/__init__.py b/models/svc/comosvc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f1cb162e95d8a992002beaa0c0d8bada9cddd5
--- /dev/null
+++ b/models/svc/comosvc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/models/svc/comosvc/comosvc.py b/models/svc/comosvc/comosvc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f637d1a5d78381335c35c1a750076b489dc284a2
--- /dev/null
+++ b/models/svc/comosvc/comosvc.py
@@ -0,0 +1,391 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import copy
+import numpy as np
+import math
+from tqdm.auto import tqdm
+
+from utils.ssim import SSIM
+
+from models.svc.transformer.conformer import Conformer, BaseModule
+from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
+
+
+class Consistency(nn.Module):
+ def __init__(self, cfg, distill=False):
+ super().__init__()
+ self.cfg = cfg
+ self.denoise_fn = DiffusionWrapper(self.cfg)
+ self.cfg = cfg.model.comosvc
+ self.teacher = not distill
+ self.P_mean = self.cfg.P_mean
+ self.P_std = self.cfg.P_std
+ self.sigma_data = self.cfg.sigma_data
+ self.sigma_min = self.cfg.sigma_min
+ self.sigma_max = self.cfg.sigma_max
+ self.rho = self.cfg.rho
+ self.N = self.cfg.n_timesteps
+ self.ssim_loss = SSIM()
+
+ # Time step discretization
+ step_indices = torch.arange(self.N)
+ # karras boundaries formula
+ t_steps = (
+ self.sigma_min ** (1 / self.rho)
+ + step_indices
+ / (self.N - 1)
+ * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
+ ) ** self.rho
+ self.t_steps = torch.cat(
+ [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
+ )
+
+ def init_consistency_training(self):
+ self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
+ self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
+
+ def EDMPrecond(self, x, sigma, cond, denoise_fn):
+ """
+ karras diffusion reverse process
+
+ Args:
+ x: noisy mel-spectrogram [B x n_mel x L]
+ sigma: noise level [B x 1 x 1]
+ cond: output of conformer encoder [B x n_mel x L]
+ denoise_fn: denoiser neural network e.g. DilatedCNN
+
+ Returns:
+ denoised mel-spectrogram [B x n_mel x L]
+ """
+ sigma = sigma.reshape(-1, 1, 1)
+
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = (
+ (sigma - self.sigma_min)
+ * self.sigma_data
+ / (sigma**2 + self.sigma_data**2).sqrt()
+ )
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+ c_noise = sigma.log() / 4
+
+ x_in = c_in * x
+ x_in = x_in.transpose(1, 2)
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2)
+ c_noise = c_noise.squeeze()
+ if c_noise.dim() == 0:
+ c_noise = c_noise.unsqueeze(0)
+ F_x = denoise_fn(x_in, c_noise, cond)
+ D_x = c_skip * x + c_out * (F_x)
+ D_x = D_x.transpose(1, 2)
+ return D_x
+
+ def EDMLoss(self, x_start, cond, mask):
+ """
+ compute loss for EDM model
+
+ Args:
+ x_start: ground truth mel-spectrogram [B x n_mel x L]
+ cond: output of conformer encoder [B x n_mel x L]
+ mask: mask of padded frames [B x n_mel x L]
+ """
+ rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp()
+ weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+ # follow Grad-TTS, start from Gaussian noise with mean cond and std I
+ noise = (torch.randn_like(x_start) + cond) * sigma
+ D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn)
+ loss = weight * ((D_yn - x_start) ** 2)
+ loss = torch.sum(loss * mask) / torch.sum(mask)
+ return loss
+
+ def round_sigma(self, sigma):
+ return torch.as_tensor(sigma)
+
+ def edm_sampler(
+ self,
+ latents,
+ cond,
+ nonpadding,
+ num_steps=50,
+ sigma_min=0.002,
+ sigma_max=80,
+ rho=7,
+ S_churn=0,
+ S_min=0,
+ S_max=float("inf"),
+ S_noise=1,
+ ):
+ """
+ karras diffusion sampler
+
+ Args:
+ latents: noisy mel-spectrogram [B x n_mel x L]
+ cond: output of conformer encoder [B x n_mel x L]
+ nonpadding: mask of padded frames [B x n_mel x L]
+ num_steps: number of steps for diffusion inference
+
+ Returns:
+ denoised mel-spectrogram [B x n_mel x L]
+ """
+ # Time step discretization.
+
+ num_steps = num_steps + 1
+ step_indices = torch.arange(num_steps, device=latents.device)
+ t_steps = (
+ sigma_max ** (1 / rho)
+ + step_indices
+ / (num_steps - 1)
+ * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
+ ) ** rho
+ t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
+
+ # Main sampling loop.
+ x_next = latents * t_steps[0]
+ # wrap in tqdm for progress bar
+ bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
+ for i, (t_cur, t_next) in bar:
+ x_cur = x_next
+ # Increase noise temporarily.
+ gamma = (
+ min(S_churn / num_steps, np.sqrt(2) - 1)
+ if S_min <= t_cur <= S_max
+ else 0
+ )
+ t_hat = self.round_sigma(t_cur + gamma * t_cur)
+ t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
+ t[:, 0, 0] = t_hat
+ t_hat = t
+ x_hat = x_cur + (
+ t_hat**2 - t_cur**2
+ ).sqrt() * S_noise * torch.randn_like(x_cur)
+ # Euler step.
+ denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn)
+ d_cur = (x_hat - denoised) / t_hat
+ x_next = x_hat + (t_next - t_hat) * d_cur
+
+ # add Heun’s 2nd order method
+ # if i < num_steps - 1:
+ # t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
+ # t[:, 0, 0] = t_next
+ # #t_next = t
+ # denoised = self.EDMPrecond(x_next, t, cond, self.denoise_fn, nonpadding)
+ # d_prime = (x_next - denoised) / t_next
+ # x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
+
+ return x_next
+
+ def CTLoss_D(self, y, cond, mask):
+ """
+ compute loss for consistency distillation
+
+ Args:
+ y: ground truth mel-spectrogram [B x n_mel x L]
+ cond: output of conformer encoder [B x n_mel x L]
+ mask: mask of padded frames [B x n_mel x L]
+ """
+ with torch.no_grad():
+ mu = 0.95
+ for p, ema_p in zip(
+ self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
+ ):
+ ema_p.mul_(mu).add_(p, alpha=1 - mu)
+
+ n = torch.randint(1, self.N, (y.shape[0],))
+ z = torch.randn_like(y) + cond
+
+ tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
+ f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn)
+
+ with torch.no_grad():
+ tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)
+
+ # euler step
+ x_hat = y + tn_1 * z
+ denoised = self.EDMPrecond(x_hat, tn_1, cond, self.denoise_fn_pretrained)
+ d_cur = (x_hat - denoised) / tn_1
+ y_tn = x_hat + (tn - tn_1) * d_cur
+
+ # Heun’s 2nd order method
+
+ denoised2 = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_pretrained)
+ d_prime = (y_tn - denoised2) / tn
+ y_tn = x_hat + (tn - tn_1) * (0.5 * d_cur + 0.5 * d_prime)
+
+ f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema)
+
+ loss = (f_theta - f_theta_ema.detach()) ** 2
+ loss = torch.sum(loss * mask) / torch.sum(mask)
+
+ # check nan
+ if torch.any(torch.isnan(loss)):
+ print("nan loss")
+ if torch.any(torch.isnan(f_theta)):
+ print("nan f_theta")
+ if torch.any(torch.isnan(f_theta_ema)):
+ print("nan f_theta_ema")
+
+ return loss
+
+ def get_t_steps(self, N):
+ N = N + 1
+ step_indices = torch.arange(N)
+ t_steps = (
+ self.sigma_min ** (1 / self.rho)
+ + step_indices
+ / (N - 1)
+ * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
+ ) ** self.rho
+
+ return t_steps.flip(0)
+
+ def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
+ """
+ consistency distillation sampler
+
+ Args:
+ latents: noisy mel-spectrogram [B x n_mel x L]
+ cond: output of conformer encoder [B x n_mel x L]
+ nonpadding: mask of padded frames [B x n_mel x L]
+ t_steps: number of steps for diffusion inference
+
+ Returns:
+ denoised mel-spectrogram [B x n_mel x L]
+ """
+ # one-step
+ if t_steps == 1:
+ t_steps = [80]
+ # multi-step
+ else:
+ t_steps = self.get_t_steps(t_steps)
+
+ t_steps = torch.as_tensor(t_steps).to(latents.device)
+ latents = latents * t_steps[0]
+ _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
+ _t[:, 0, 0] = t_steps[0]
+ x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema)
+
+ for t in t_steps[1:-1]:
+ z = torch.randn_like(x) + cond
+ x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
+ _t = torch.zeros((x.shape[0], 1, 1), device=x.device)
+ _t[:, 0, 0] = t
+ t = _t
+ x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema)
+ return x
+
+ def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
+ """
+ calculate loss or sample mel-spectrogram
+
+ Args:
+ x:
+ training: ground truth mel-spectrogram [B x n_mel x L]
+ inference: output of encoder [B x n_mel x L]
+ """
+ if self.teacher: # teacher model -- karras diffusion
+ if not infer:
+ loss = self.EDMLoss(x, cond, nonpadding)
+ return loss
+ else:
+ shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
+ x = torch.randn(shape, device=x.device) + cond
+ x = self.edm_sampler(x, cond, nonpadding, t_steps)
+
+ return x
+ else: # Consistency distillation
+ if not infer:
+ loss = self.CTLoss_D(x, cond, nonpadding)
+ return loss
+
+ else:
+ shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
+ x = torch.randn(shape, device=x.device) + cond
+ x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
+
+ return x
+
+
+class ComoSVC(BaseModule):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
+ self.distill = self.cfg.model.comosvc.distill
+ self.encoder = Conformer(self.cfg.model.comosvc)
+ self.decoder = Consistency(self.cfg, distill=self.distill)
+ self.ssim_loss = SSIM()
+
+ @torch.no_grad()
+ def forward(self, x_mask, x, n_timesteps, temperature=1.0):
+ """
+ Generates mel-spectrogram from pitch, content vector, energy. Returns:
+ 1. encoder outputs (from conformer)
+ 2. decoder outputs (from diffusion-based decoder)
+
+ Args:
+ x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
+ x : output of encoder framework. [B x L x d_condition]
+ n_timesteps : number of steps to use for reverse diffusion in decoder.
+ temperature : controls variance of terminal distribution.
+ """
+
+ # Get encoder_outputs `mu_x`
+ mu_x = self.encoder(x, x_mask)
+ encoder_outputs = mu_x
+
+ mu_x = mu_x.transpose(1, 2)
+ x_mask = x_mask.transpose(1, 2)
+
+ # Generate sample by performing reverse dynamics
+ decoder_outputs = self.decoder(
+ mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
+ )
+ decoder_outputs = decoder_outputs.transpose(1, 2)
+ return encoder_outputs, decoder_outputs
+
+ def compute_loss(self, x_mask, x, mel, skip_diff=False):
+ """
+ Computes 2 losses:
+ 1. prior loss: loss between mel-spectrogram and encoder outputs. (l2 and ssim loss)
+ 2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
+
+ Args:
+ x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
+ x : output of encoder framework. [B x L x d_condition]
+ mel : ground truth mel-spectrogram. [B x L x n_mel]
+ """
+
+ mu_x = self.encoder(x, x_mask)
+ # prior loss
+ x_mask = x_mask.repeat(1, 1, mel.shape[-1])
+ prior_loss = torch.sum(
+ 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
+ )
+
+ prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
+ # ssim loss
+ ssim_loss = self.ssim_loss(mu_x, mel)
+ ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)
+
+ x_mask = x_mask.transpose(1, 2)
+ mu_x = mu_x.transpose(1, 2)
+ mel = mel.transpose(1, 2)
+ if not self.distill and skip_diff:
+ diff_loss = prior_loss.clone()
+ diff_loss.fill_(0)
+
+ # Cut a small segment of mel-spectrogram in order to increase batch size
+ else:
+ mu_y = mu_x
+ mask_y = x_mask
+
+ diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
+
+ return ssim_loss, prior_loss, diff_loss
diff --git a/models/svc/comosvc/comosvc_inference.py b/models/svc/comosvc/comosvc_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2783ec7e468c367c7d2f5f8988ed1f7e272d4cb7
--- /dev/null
+++ b/models/svc/comosvc/comosvc_inference.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from models.svc.base import SVCInference
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.comosvc.comosvc import ComoSVC
+
+
+class ComoSVCInference(SVCInference):
+ def __init__(self, args, cfg, infer_type="from_dataset"):
+ SVCInference.__init__(self, args, cfg, infer_type)
+
+ def _build_model(self):
+ # TODO: sort out the config
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+ self.acoustic_mapper = ComoSVC(self.cfg)
+ if self.cfg.model.comosvc.distill:
+ self.acoustic_mapper.decoder.init_consistency_training()
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+ return model
+
+ def _inference_each_batch(self, batch_data):
+ device = self.accelerator.device
+ for k, v in batch_data.items():
+ batch_data[k] = v.to(device)
+
+ cond = self.condition_encoder(batch_data)
+ mask = batch_data["mask"]
+ encoder_pred, decoder_pred = self.acoustic_mapper(
+ mask, cond, self.cfg.inference.comosvc.inference_steps
+ )
+
+ return decoder_pred
diff --git a/models/svc/comosvc/comosvc_trainer.py b/models/svc/comosvc/comosvc_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a689faa61ad18c0bbfb2cf9dfc42209cfe113069
--- /dev/null
+++ b/models/svc/comosvc/comosvc_trainer.py
@@ -0,0 +1,319 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import os
+import json5
+from tqdm import tqdm
+import json
+import shutil
+
+from models.svc.base import SVCTrainer
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.comosvc.comosvc import ComoSVC
+
+
+class ComoSVCTrainer(SVCTrainer):
+ r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
+ implements ``_build_model`` and ``_forward_step`` methods.
+ """
+
+ def __init__(self, args=None, cfg=None):
+ SVCTrainer.__init__(self, args, cfg)
+ self.distill = cfg.model.comosvc.distill
+ self.skip_diff = True
+
+ ### Following are methods only for comoSVC models ###
+
+ def _load_teacher_model(self, model):
+ r"""Load teacher model from checkpoint file."""
+ self.checkpoint_file = self.teacher_model_path
+ self.logger.info(
+ "Load teacher acoustic model from {}".format(self.checkpoint_file)
+ )
+ raw_dict = torch.load(self.checkpoint_file)
+ model.load_state_dict(raw_dict)
+
+ def _build_model(self):
+ r"""Build the model for training. This function is called in ``__init__`` function."""
+
+ # TODO: sort out the config
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+ self.acoustic_mapper = ComoSVC(self.cfg)
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+ if self.cfg.model.comosvc.distill:
+ if not self.args.resume:
+ # do not load teacher model when resume
+ self.teacher_model_path = self.cfg.model.teacher_model_path
+ self._load_teacher_model(model)
+ # build teacher & target decoder and freeze teacher
+ self.acoustic_mapper.decoder.init_consistency_training()
+ self.freeze_net(self.condition_encoder)
+ self.freeze_net(self.acoustic_mapper.encoder)
+ self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_pretrained)
+ self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_ema)
+ return model
+
+ def freeze_net(self, model):
+ r"""Freeze the model for training."""
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+
+ def __build_optimizer(self):
+ r"""Build optimizer for training. This function is called in ``__init__`` function."""
+
+ if self.cfg.train.optimizer.lower() == "adamw":
+ optimizer = torch.optim.AdamW(
+ params=filter(lambda p: p.requires_grad, self.model.parameters()),
+ **self.cfg.train.adamw,
+ )
+
+ else:
+ raise NotImplementedError(
+ "Not support optimizer: {}".format(self.cfg.train.optimizer)
+ )
+
+ return optimizer
+
+ def _forward_step(self, batch):
+ r"""Forward step for training and inference. This function is called
+ in ``_train_step`` & ``_test_step`` function.
+ """
+ loss = {}
+ mask = batch["mask"]
+ mel_input = batch["mel"]
+ cond = self.condition_encoder(batch)
+ if self.distill:
+ cond = cond.detach()
+ self.skip_diff = True if self.step < self.cfg.train.fast_steps else False
+ ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
+ mask, cond, mel_input, skip_diff=self.skip_diff
+ )
+ if self.distill:
+ loss["distil_loss"] = diff_loss
+ else:
+ loss["ssim_loss_encoder"] = ssim_loss
+ loss["prior_loss_encoder"] = prior_loss
+ loss["diffusion_loss_decoder"] = diff_loss
+
+ return loss
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.train()
+ epoch_sum_loss: float = 0.0
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ loss = self._train_step(batch)
+ total_loss = 0
+ for k, v in loss.items():
+ total_loss += v
+ self.accelerator.backward(total_loss)
+ enc_grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.acoustic_mapper.encoder.parameters(), max_norm=1
+ )
+ dec_grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.acoustic_mapper.decoder.parameters(), max_norm=1
+ )
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.batch_count += 1
+
+ # Update info for each step
+ # TODO: step means BP counts or batch counts?
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += total_loss
+ log_info = {}
+ for k, v in loss.items():
+ key = "Step/Train Loss/{}".format(k)
+ log_info[key] = v
+ log_info["Step/Learning Rate"] = self.optimizer.param_groups[0]["lr"]
+ self.accelerator.log(
+ log_info,
+ step=self.step,
+ )
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+ return (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step,
+ loss,
+ )
+
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # dump config file
+ if self.accelerator.is_main_process:
+ self.__dump_cfg(self.config_save_path)
+ self.model.train()
+ self.optimizer.zero_grad()
+ # Wait to ensure good to go
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
+ ### It's inconvenient for the model with multiple losses
+ # Do training & validating epoch
+ train_loss, loss = self._train_epoch()
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
+ for k, v in loss.items():
+ self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v))
+ valid_loss = self._valid_epoch()
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
+ self.accelerator.log(
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
+ step=self.epoch,
+ )
+
+ self.accelerator.wait_for_everyone()
+ # TODO: what is scheduler?
+ self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
+
+ # Check if hit save_checkpoint_stride and run_eval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ hit_dix = []
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ hit_dix.append(i)
+ run_eval |= self.run_eval[i]
+
+ self.accelerator.wait_for_everyone()
+ if (
+ self.accelerator.is_main_process
+ and save_checkpoint
+ and (self.distill or not self.skip_diff)
+ ):
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, train_loss
+ ),
+ )
+ self.tmp_checkpoint_save_path = path
+ self.accelerator.save_state(path)
+ print(f"save checkpoint in {path}")
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+ self._save_auxiliary_states()
+
+ # Remove old checkpoints
+ to_remove = []
+ for idx in hit_dix:
+ self.checkpoints_path[idx].append(path)
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+ # Search conflicts
+ total = set()
+ for i in self.checkpoints_path:
+ total |= set(i)
+ do_remove = set()
+ for idx, path in to_remove[::-1]:
+ if path in total:
+ self.checkpoints_path[idx].insert(0, path)
+ else:
+ do_remove.add(path)
+
+ # Remove old checkpoints
+ for path in do_remove:
+ shutil.rmtree(path, ignore_errors=True)
+ self.logger.debug(f"Remove old checkpoint: {path}")
+
+ self.accelerator.wait_for_everyone()
+ if run_eval:
+ # TODO: run evaluation
+ pass
+
+ # Update info for each epoch
+ self.epoch += 1
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_loss
+ ),
+ )
+ )
+ self._save_auxiliary_states()
+ self.accelerator.end_training()
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.eval()
+ epoch_sum_loss = 0.0
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ batch_loss = self._valid_step(batch)
+ for k, v in batch_loss.items():
+ epoch_sum_loss += v
+
+ self.accelerator.wait_for_everyone()
+ return epoch_sum_loss / len(self.valid_dataloader)
+
+ @staticmethod
+ def __count_parameters(model):
+ model_param = 0.0
+ if isinstance(model, dict):
+ for key, value in model.items():
+ model_param += sum(p.numel() for p in model[key].parameters())
+ else:
+ model_param = sum(p.numel() for p in model.parameters())
+ return model_param
+
+ def __dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
diff --git a/models/svc/diffusion/__init__.py b/models/svc/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/diffusion/diffusion_inference.py b/models/svc/diffusion/diffusion_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a752ef8f195b59d0ac0ad402dc35ce5840626ab9
--- /dev/null
+++ b/models/svc/diffusion/diffusion_inference.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
+
+from models.svc.base import SVCInference
+from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
+from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
+from modules.encoder.condition_encoder import ConditionEncoder
+
+
+class DiffusionInference(SVCInference):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ SVCInference.__init__(self, args, cfg, infer_type)
+
+ settings = {
+ **cfg.model.diffusion.scheduler_settings,
+ **cfg.inference.diffusion.scheduler_settings,
+ }
+ settings.pop("num_inference_timesteps")
+
+ if cfg.inference.diffusion.scheduler.lower() == "ddpm":
+ self.scheduler = DDPMScheduler(**settings)
+ self.logger.info("Using DDPM scheduler.")
+ elif cfg.inference.diffusion.scheduler.lower() == "ddim":
+ self.scheduler = DDIMScheduler(**settings)
+ self.logger.info("Using DDIM scheduler.")
+ elif cfg.inference.diffusion.scheduler.lower() == "pndm":
+ self.scheduler = PNDMScheduler(**settings)
+ self.logger.info("Using PNDM scheduler.")
+ else:
+ raise NotImplementedError(
+ "Unsupported scheduler type: {}".format(
+ cfg.inference.diffusion.scheduler.lower()
+ )
+ )
+
+ self.pipeline = DiffusionInferencePipeline(
+ self.model[1],
+ self.scheduler,
+ args.diffusion_inference_steps,
+ )
+
+ def _build_model(self):
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+ self.acoustic_mapper = DiffusionWrapper(self.cfg)
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+ return model
+
+ def _inference_each_batch(self, batch_data):
+ device = self.accelerator.device
+ for k, v in batch_data.items():
+ batch_data[k] = v.to(device)
+
+ conditioner = self.model[0](batch_data)
+ noise = torch.randn_like(batch_data["mel"], device=device)
+ y_pred = self.pipeline(noise, conditioner)
+ return y_pred
diff --git a/models/svc/diffusion/diffusion_inference_pipeline.py b/models/svc/diffusion/diffusion_inference_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2461aada99179ac17a2aaffebdb24864af1f5ee
--- /dev/null
+++ b/models/svc/diffusion/diffusion_inference_pipeline.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from diffusers import DiffusionPipeline
+
+
+class DiffusionInferencePipeline(DiffusionPipeline):
+ def __init__(self, network, scheduler, num_inference_timesteps=1000):
+ super().__init__()
+
+ self.register_modules(network=network, scheduler=scheduler)
+ self.num_inference_timesteps = num_inference_timesteps
+
+ @torch.inference_mode()
+ def __call__(
+ self,
+ initial_noise: torch.Tensor,
+ conditioner: torch.Tensor = None,
+ ):
+ r"""
+ Args:
+ initial_noise: The initial noise to be denoised.
+ conditioner:The conditioner.
+ n_inference_steps: The number of denoising steps. More denoising steps
+ usually lead to a higher quality at the expense of slower inference.
+ """
+
+ mel = initial_noise
+ batch_size = mel.size(0)
+ self.scheduler.set_timesteps(self.num_inference_timesteps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)
+
+ # 1. predict noise model_output
+ model_output = self.network(mel, timestep, conditioner)
+
+ # 2. denoise, compute previous step: x_t -> x_t-1
+ mel = self.scheduler.step(model_output, t, mel).prev_sample
+
+ # 3. clamp
+ mel = mel.clamp(-1.0, 1.0)
+
+ return mel
diff --git a/models/svc/diffusion/diffusion_trainer.py b/models/svc/diffusion/diffusion_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c59f21106ed69911ccd5db26aa24184a79dd4c
--- /dev/null
+++ b/models/svc/diffusion/diffusion_trainer.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from diffusers import DDPMScheduler
+
+from models.svc.base import SVCTrainer
+from modules.encoder.condition_encoder import ConditionEncoder
+from .diffusion_wrapper import DiffusionWrapper
+
+
+class DiffusionTrainer(SVCTrainer):
+ r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
+ implements ``_build_model`` and ``_forward_step`` methods.
+ """
+
+ def __init__(self, args=None, cfg=None):
+ SVCTrainer.__init__(self, args, cfg)
+
+ # Only for SVC tasks using diffusion
+ self.noise_scheduler = DDPMScheduler(
+ **self.cfg.model.diffusion.scheduler_settings,
+ )
+ self.diffusion_timesteps = (
+ self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
+ )
+
+ ### Following are methods only for diffusion models ###
+ def _build_model(self):
+ r"""Build the model for training. This function is called in ``__init__`` function."""
+
+ # TODO: sort out the config
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+ self.acoustic_mapper = DiffusionWrapper(self.cfg)
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+
+ num_of_params_encoder = self.count_parameters(self.condition_encoder)
+ num_of_params_am = self.count_parameters(self.acoustic_mapper)
+ num_of_params = num_of_params_encoder + num_of_params_am
+ log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
+ num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
+ )
+ self.logger.info(log)
+
+ return model
+
+ def count_parameters(self, model):
+ model_param = 0.0
+ if isinstance(model, dict):
+ for key, value in model.items():
+ model_param += sum(p.numel() for p in model[key].parameters())
+ else:
+ model_param = sum(p.numel() for p in model.parameters())
+ return model_param
+
+ def _check_nan(self, batch, loss, y_pred, y_gt):
+ if torch.any(torch.isnan(loss)):
+ for k, v in batch.items():
+ self.logger.info(k)
+ self.logger.info(v)
+
+ super()._check_nan(loss, y_pred, y_gt)
+
+ def _forward_step(self, batch):
+ r"""Forward step for training and inference. This function is called
+ in ``_train_step`` & ``_test_step`` function.
+ """
+ device = self.accelerator.device
+
+ if self.online_features_extraction:
+ # On-the-fly features extraction
+ batch = self._extract_svc_features(batch)
+
+ # To debug
+ # for k, v in batch.items():
+ # print(k, v.shape, v)
+ # exit()
+
+ mel_input = batch["mel"]
+ noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
+ batch_size = mel_input.size(0)
+ timesteps = torch.randint(
+ 0,
+ self.diffusion_timesteps,
+ (batch_size,),
+ device=device,
+ dtype=torch.long,
+ )
+
+ noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
+ conditioner = self.condition_encoder(batch)
+
+ y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)
+
+ loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
+ self._check_nan(batch, loss, y_pred, noise)
+
+ return loss
diff --git a/models/svc/diffusion/diffusion_wrapper.py b/models/svc/diffusion/diffusion_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef66c2b6b85ceb8fe7a2cf9b53c62edc6b3ef6bc
--- /dev/null
+++ b/models/svc/diffusion/diffusion_wrapper.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+
+from modules.diffusion import BiDilConv
+from modules.encoder.position_encoder import PositionEncoder
+
+
+class DiffusionWrapper(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+ self.diff_cfg = cfg.model.diffusion
+
+ self.diff_encoder = PositionEncoder(
+ d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
+ d_out=self.diff_cfg.bidilconv.base_channel,
+ d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
+ activation_function=self.diff_cfg.step_encoder.activation,
+ n_layer=self.diff_cfg.step_encoder.num_layer,
+ max_period=self.diff_cfg.step_encoder.max_period,
+ )
+
+ # FIXME: Only support BiDilConv now for debug
+ if self.diff_cfg.model_type.lower() == "bidilconv":
+ self.neural_network = BiDilConv(
+ input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
+ )
+ else:
+ raise ValueError(
+ f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
+ )
+
+ def forward(self, x, t, c):
+ """
+ Args:
+ x: [N, T, mel_band] of mel spectrogram
+ t: Diffusion time step with shape of [N]
+ c: [N, T, conditioner_size] of conditioner
+
+ Returns:
+ [N, T, mel_band] of mel spectrogram
+ """
+
+ assert (
+ x.size()[:-1] == c.size()[:-1]
+ ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
+ assert x.size(0) == t.size(
+ 0
+ ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
+ assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
+
+ N, T, mel_band = x.size()
+
+ x = x.transpose(1, 2).contiguous() # [N, mel_band, T]
+ c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T]
+ t = self.diff_encoder(t).contiguous() # [N, base_channel]
+
+ h = self.neural_network(x, t, c)
+ h = h.transpose(1, 2).contiguous() # [N, T, mel_band]
+
+ assert h.size() == (
+ N,
+ T,
+ mel_band,
+ ), "h mismatch with input x, got \n h: {} \n x: {}".format(
+ h.size(), (N, T, mel_band)
+ )
+ return h
diff --git a/models/svc/transformer/__init__.py b/models/svc/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/transformer/conformer.py b/models/svc/transformer/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e48019cfc17d5f3825ce989f4852cec55fe1daa
--- /dev/null
+++ b/models/svc/transformer/conformer.py
@@ -0,0 +1,405 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+from utils.util import convert_pad_shape
+
+
+class BaseModule(torch.nn.Module):
+ def __init__(self):
+ super(BaseModule, self).__init__()
+
+ @property
+ def nparams(self):
+ """
+ Returns number of trainable parameters of the module.
+ """
+ num_params = 0
+ for name, param in self.named_parameters():
+ if param.requires_grad:
+ num_params += np.prod(param.detach().cpu().numpy().shape)
+ return num_params
+
+ def relocate_input(self, x: list):
+ """
+ Relocates provided tensors to the same device set for the module.
+ """
+ device = next(self.parameters()).device
+ for i in range(len(x)):
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
+ x[i] = x[i].to(device)
+ return x
+
+
+class LayerNorm(BaseModule):
+ def __init__(self, channels, eps=1e-4):
+ super(LayerNorm, self).__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(BaseModule):
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels,
+ out_channels,
+ kernel_size,
+ n_layers,
+ p_dropout,
+ eps=1e-5,
+ ):
+ super(ConvReluNorm, self).__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ self.eps = eps
+
+ self.conv_layers = torch.nn.ModuleList()
+ self.conv_layers.append(
+ torch.nn.Conv1d(
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+ )
+ )
+ self.relu_drop = torch.nn.Sequential(
+ torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
+ )
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(
+ torch.nn.Conv1d(
+ hidden_channels,
+ hidden_channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.instance_norm(x, x_mask)
+ x = self.relu_drop(x)
+ x = self.proj(x)
+ return x * x_mask
+
+ def instance_norm(self, x, mask, return_mean_std=False):
+ mean, std = self.calc_mean_std(x, mask)
+ x = (x - mean) / std
+ if return_mean_std:
+ return x, mean, std
+ else:
+ return x
+
+ def calc_mean_std(self, x, mask=None):
+ x = x * mask
+ B, C = x.shape[:2]
+ mn = x.view(B, C, -1).mean(-1)
+ sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
+ mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
+ sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
+ return mn, sd
+
+
+class MultiHeadAttention(BaseModule):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ window_size=None,
+ heads_share=True,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super(MultiHeadAttention, self).__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.proximal_bias = proximal_bias
+ self.p_dropout = p_dropout
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = torch.nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = torch.nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
+ self.drop = torch.nn.Dropout(p_dropout)
+
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
+ if proximal_init:
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+ if self.window_size is not None:
+ assert (
+ t_s == t_t
+ ), "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
+ scores_local = rel_logits / math.sqrt(self.k_channels)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(
+ self.emb_rel_v, t_s
+ )
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = torch.nn.functional.pad(
+ relative_embeddings,
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ batch, heads, length, _ = x.size()
+ x = torch.nn.functional.pad(
+ x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
+ )
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = torch.nn.functional.pad(
+ x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+ )
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ batch, heads, length, _ = x.size()
+ x = torch.nn.functional.pad(
+ x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+ )
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ x_flat = torch.nn.functional.pad(
+ x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])
+ )
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(BaseModule):
+ def __init__(
+ self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
+ ):
+ super(FFN, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.conv_1 = torch.nn.Conv1d(
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+ )
+ self.conv_2 = torch.nn.Conv1d(
+ filter_channels, out_channels, kernel_size, padding=kernel_size // 2
+ )
+ self.drop = torch.nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(x * x_mask)
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ return x * x_mask
+
+
+class Encoder(BaseModule):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads=2,
+ n_layers=6,
+ kernel_size=3,
+ p_dropout=0.1,
+ window_size=4,
+ **kwargs
+ ):
+ super(Encoder, self).__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = torch.nn.Dropout(p_dropout)
+ self.attn_layers = torch.nn.ModuleList()
+ self.norm_layers_1 = torch.nn.ModuleList()
+ self.ffn_layers = torch.nn.ModuleList()
+ self.norm_layers_2 = torch.nn.ModuleList()
+ for _ in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ window_size=window_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ for i in range(self.n_layers):
+ x = x * x_mask
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Conformer(BaseModule):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.n_heads = self.cfg.n_heads
+ self.n_layers = self.cfg.n_layers
+ self.hidden_channels = self.cfg.input_dim
+ self.filter_channels = self.cfg.filter_channels
+ self.output_dim = self.cfg.output_dim
+ self.dropout = self.cfg.dropout
+
+ self.conformer_encoder = Encoder(
+ self.hidden_channels,
+ self.filter_channels,
+ n_heads=self.n_heads,
+ n_layers=self.n_layers,
+ kernel_size=3,
+ p_dropout=self.dropout,
+ window_size=4,
+ )
+ self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1)
+
+ def forward(self, x, x_mask):
+ """
+ Args:
+ x: (N, seq_len, input_dim)
+ Returns:
+ output: (N, seq_len, output_dim)
+ """
+ # (N, seq_len, d_model)
+ x = x.transpose(1, 2)
+ x_mask = x_mask.transpose(1, 2)
+ output = self.conformer_encoder(x, x_mask)
+ # (N, seq_len, output_dim)
+ output = self.projection(output)
+ output = output.transpose(1, 2)
+ return output
diff --git a/models/svc/transformer/transformer.py b/models/svc/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3cdb6c2d0fc93534d005b9f67a3058c9185c60
--- /dev/null
+++ b/models/svc/transformer/transformer.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class Transformer(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ dropout = self.cfg.dropout
+ nhead = self.cfg.n_heads
+ nlayers = self.cfg.n_layers
+ input_dim = self.cfg.input_dim
+ output_dim = self.cfg.output_dim
+
+ d_model = input_dim
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
+ encoder_layers = TransformerEncoderLayer(
+ d_model, nhead, dropout=dropout, batch_first=True
+ )
+ self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+
+ self.output_mlp = nn.Linear(d_model, output_dim)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: (N, seq_len, input_dim)
+ Returns:
+ output: (N, seq_len, output_dim)
+ """
+ # (N, seq_len, d_model)
+ src = self.pos_encoder(x)
+ # model_stats["pos_embedding"] = x
+ # (N, seq_len, d_model)
+ output = self.transformer_encoder(src)
+ # (N, seq_len, output_dim)
+ output = self.output_mlp(output)
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+ )
+
+ # Assume that x is (seq_len, N, d)
+ # pe = torch.zeros(max_len, 1, d_model)
+ # pe[:, 0, 0::2] = torch.sin(position * div_term)
+ # pe[:, 0, 1::2] = torch.cos(position * div_term)
+
+ # Assume that x in (N, seq_len, d)
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor, shape [N, seq_len, d]
+ """
+ # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
+ # x = x + self.pe[: x.size(0)]
+
+ # Now: self.pe is (1, max_len, d)
+ x = x + self.pe[:, : x.size(1), :]
+
+ return self.dropout(x)
diff --git a/models/svc/transformer/transformer_inference.py b/models/svc/transformer/transformer_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6299c532aec6cb9283ee87ee9f0142f0b5c981b
--- /dev/null
+++ b/models/svc/transformer/transformer_inference.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+import numpy as np
+import torch
+from tqdm import tqdm
+import torch.nn as nn
+from collections import OrderedDict
+
+from models.svc.base import SVCInference
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.transformer.transformer import Transformer
+from models.svc.transformer.conformer import Conformer
+
+
+class TransformerInference(SVCInference):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ SVCInference.__init__(self, args, cfg, infer_type)
+
+ def _build_model(self):
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+ if self.cfg.model.transformer.type == "transformer":
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
+ elif self.cfg.model.transformer.type == "conformer":
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
+ else:
+ raise NotImplementedError
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+ return model
+
+ def _inference_each_batch(self, batch_data):
+ device = self.accelerator.device
+ for k, v in batch_data.items():
+ batch_data[k] = v.to(device)
+
+ condition = self.condition_encoder(batch_data)
+ y_pred = self.acoustic_mapper(condition, batch_data["mask"])
+
+ return y_pred
diff --git a/models/svc/transformer/transformer_trainer.py b/models/svc/transformer/transformer_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3633078475d26e708280bc354f091bb9ab01ae45
--- /dev/null
+++ b/models/svc/transformer/transformer_trainer.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from models.svc.base import SVCTrainer
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.transformer.transformer import Transformer
+from models.svc.transformer.conformer import Conformer
+from utils.ssim import SSIM
+
+
+class TransformerTrainer(SVCTrainer):
+ def __init__(self, args, cfg):
+ SVCTrainer.__init__(self, args, cfg)
+ self.ssim_loss = SSIM()
+
+ def _build_model(self):
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+ if self.cfg.model.transformer.type == "transformer":
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
+ elif self.cfg.model.transformer.type == "conformer":
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
+ else:
+ raise NotImplementedError
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+ return model
+
+ def _forward_step(self, batch):
+ total_loss = 0
+ device = self.accelerator.device
+ mel = batch["mel"]
+ mask = batch["mask"]
+
+ condition = self.condition_encoder(batch)
+ mel_pred = self.acoustic_mapper(condition, mask)
+
+ l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum(
+ batch["mask"]
+ )
+ self._check_nan(l1_loss, mel_pred, mel)
+ total_loss += l1_loss
+ ssim_loss = self.ssim_loss(mel_pred, mel)
+ ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"])
+ self._check_nan(ssim_loss, mel_pred, mel)
+ total_loss += ssim_loss
+
+ return total_loss
diff --git a/models/svc/vits/__init__.py b/models/svc/vits/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/vits/vits.py b/models/svc/vits/vits.py
new file mode 100644
index 0000000000000000000000000000000000000000..983a704eb3abc065b30b5766fdbd7035587eb373
--- /dev/null
+++ b/models/svc/vits/vits.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/models.py
+import copy
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from utils.util import *
+
+from modules.transformer.attentions import Encoder
+from models.tts.vits.vits import ResidualCouplingBlock, PosteriorEncoder
+from models.vocoders.gan.generator.bigvgan import BigVGAN
+from models.vocoders.gan.generator.hifigan import HiFiGAN
+from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN
+from models.vocoders.gan.generator.melgan import MelGAN
+from models.vocoders.gan.generator.apnet import APNet
+from modules.encoder.condition_encoder import ConditionEncoder
+
+
+def slice_pitch_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
+ return ret, ret_pitch, ids_str
+
+
+class ContentEncoder(nn.Module):
+ def __init__(
+ self,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ n_layers,
+ gin_channels=0,
+ filter_channels=None,
+ n_heads=None,
+ p_dropout=None,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+
+ self.f0_emb = nn.Embedding(256, hidden_channels)
+
+ self.enc_ = Encoder(
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+ )
+
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ # condition_encoder ver.
+ def forward(self, x, x_mask, noice_scale=1):
+ x = self.enc_(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
+
+ return z, m, logs, x_mask
+
+
+class SynthesizerTrn(nn.Module):
+ """
+ Synthesizer for Training
+ """
+
+ def __init__(self, spec_channels, segment_size, cfg):
+ super().__init__()
+ self.spec_channels = spec_channels
+ self.segment_size = segment_size
+ self.cfg = cfg
+ self.inter_channels = cfg.model.vits.inter_channels
+ self.hidden_channels = cfg.model.vits.hidden_channels
+ self.filter_channels = cfg.model.vits.filter_channels
+ self.n_heads = cfg.model.vits.n_heads
+ self.n_layers = cfg.model.vits.n_layers
+ self.kernel_size = cfg.model.vits.kernel_size
+ self.p_dropout = cfg.model.vits.p_dropout
+ self.n_flow_layer = cfg.model.vits.n_flow_layer
+ self.gin_channels = cfg.model.vits.gin_channels
+ self.n_speakers = cfg.model.vits.n_speakers
+
+ # f0
+ self.n_bins = cfg.preprocess.pitch_bin
+ self.f0_min = cfg.preprocess.f0_min
+ self.f0_max = cfg.preprocess.f0_max
+
+ # TODO: sort out the config
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+
+ self.emb_g = nn.Embedding(self.n_speakers, self.gin_channels)
+
+ self.enc_p = ContentEncoder(
+ self.inter_channels,
+ self.hidden_channels,
+ filter_channels=self.filter_channels,
+ n_heads=self.n_heads,
+ n_layers=self.n_layers,
+ kernel_size=self.kernel_size,
+ p_dropout=self.p_dropout,
+ )
+
+ assert cfg.model.generator in [
+ "bigvgan",
+ "hifigan",
+ "melgan",
+ "nsfhifigan",
+ "apnet",
+ ]
+ self.dec_name = cfg.model.generator
+ temp_cfg = copy.deepcopy(cfg)
+ temp_cfg.preprocess.n_mel = self.inter_channels
+ if cfg.model.generator == "bigvgan":
+ temp_cfg.model.bigvgan = cfg.model.generator_config.bigvgan
+ self.dec = BigVGAN(temp_cfg)
+ elif cfg.model.generator == "hifigan":
+ temp_cfg.model.hifigan = cfg.model.generator_config.hifigan
+ self.dec = HiFiGAN(temp_cfg)
+ elif cfg.model.generator == "melgan":
+ temp_cfg.model.melgan = cfg.model.generator_config.melgan
+ self.dec = MelGAN(temp_cfg)
+ elif cfg.model.generator == "nsfhifigan":
+ temp_cfg.model.nsfhifigan = cfg.model.generator_config.nsfhifigan
+ self.dec = NSFHiFiGAN(temp_cfg) # TODO: nsf need f0
+ elif cfg.model.generator == "apnet":
+ temp_cfg.model.apnet = cfg.model.generator_config.apnet
+ self.dec = APNet(temp_cfg)
+
+ self.enc_q = PosteriorEncoder(
+ self.spec_channels,
+ self.inter_channels,
+ self.hidden_channels,
+ 5,
+ 1,
+ 16,
+ gin_channels=self.gin_channels,
+ )
+
+ self.flow = ResidualCouplingBlock(
+ self.inter_channels,
+ self.hidden_channels,
+ 5,
+ 1,
+ self.n_flow_layer,
+ gin_channels=self.gin_channels,
+ )
+
+ def forward(self, data):
+ """VitsSVC forward function.
+
+ Args:
+ data (dict): condition data & audio data, including:
+ B: batch size, T: target length
+ {
+ "spk_id": [B, singer_table_size]
+ "target_len": [B]
+ "mask": [B, T, 1]
+ "mel": [B, T, n_mel]
+ "linear": [B, T, n_fft // 2 + 1]
+ "frame_pitch": [B, T]
+ "frame_uv": [B, T]
+ "audio": [B, audio_len]
+ "audio_len": [B]
+ "contentvec_feat": [B, T, contentvec_dim]
+ "whisper_feat": [B, T, whisper_dim]
+ ...
+ }
+ """
+
+ # TODO: elegantly handle the dimensions
+ spec = data["linear"].transpose(1, 2)
+
+ g = data["spk_id"]
+ g = self.emb_g(g).transpose(1, 2)
+
+ c_lengths = data["target_len"]
+ spec_lengths = data["target_len"]
+ f0 = data["frame_pitch"]
+
+ # condition_encoder ver.
+ x = self.condition_encoder(data).transpose(1, 2)
+ x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype)
+
+ # prior encoder
+ z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask)
+ # posterior encoder
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
+
+ # flow
+ z_p = self.flow(z, spec_mask, g=g)
+ z_slice, pitch_slice, ids_slice = rand_slice_segments_with_pitch(
+ z, f0, spec_lengths, self.segment_size
+ )
+
+ if self.dec_name == "nsfhifigan":
+ o = self.dec(z_slice, f0=f0.float())
+ elif self.dec_name == "apnet":
+ _, _, _, _, o = self.dec(z_slice)
+ else:
+ o = self.dec(z_slice)
+
+ outputs = {
+ "y_hat": o,
+ "ids_slice": ids_slice,
+ "x_mask": x_mask,
+ "z_mask": data["mask"].transpose(1, 2),
+ "z": z,
+ "z_p": z_p,
+ "m_p": m_p,
+ "logs_p": logs_p,
+ "m_q": m_q,
+ "logs_q": logs_q,
+ }
+ return outputs
+
+ @torch.no_grad()
+ def infer(self, data, noise_scale=0.35, seed=52468):
+ # c, f0, uv, g
+ f0 = data["frame_pitch"]
+ g = data["spk_id"]
+
+ if f0.device == torch.device("cuda"):
+ torch.cuda.manual_seed_all(seed)
+ else:
+ torch.manual_seed(seed)
+
+ c_lengths = (torch.ones(f0.size(0)) * f0.size(-1)).to(f0.device)
+
+ if g.dim() == 1:
+ g = g.unsqueeze(0)
+ g = self.emb_g(g).transpose(1, 2)
+
+ # condition_encoder ver.
+ x = self.condition_encoder(data).transpose(1, 2)
+ x_mask = torch.unsqueeze(sequence_mask(c_lengths, f0.size(1)), 1).to(x.dtype)
+
+ z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, noice_scale=noise_scale)
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
+
+ if self.dec_name == "nsfhifigan":
+ o = self.dec(z * c_mask, f0=f0.float())
+ elif self.dec_name == "apnet":
+ _, _, _, _, o = self.dec(z * c_mask)
+ else:
+ o = self.dec(z * c_mask)
+ return o, f0
diff --git a/models/svc/vits/vits_inference.py b/models/svc/vits/vits_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e4c1acc52d0fe18fb1d874234d042f90e5c03c
--- /dev/null
+++ b/models/svc/vits/vits_inference.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import time
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+
+from models.svc.base import SVCInference
+from models.svc.vits.vits import SynthesizerTrn
+
+from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator
+from utils.io import save_audio
+from utils.audio_slicer import is_silence
+
+
+class VitsInference(SVCInference):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ SVCInference.__init__(self, args, cfg)
+
+ def _build_model(self):
+ net_g = SynthesizerTrn(
+ self.cfg.preprocess.n_fft // 2 + 1,
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ self.cfg,
+ )
+ self.model = net_g
+ return net_g
+
+ def build_save_dir(self, dataset, speaker):
+ save_dir = os.path.join(
+ self.args.output_dir,
+ "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
+ )
+ if dataset is not None:
+ save_dir = os.path.join(save_dir, "data_{}".format(dataset))
+ if speaker != -1:
+ save_dir = os.path.join(
+ save_dir,
+ "spk_{}".format(speaker),
+ )
+ os.makedirs(save_dir, exist_ok=True)
+ print("Saving to ", save_dir)
+ return save_dir
+
+ def _build_dataloader(self):
+ datasets, collate = self._build_test_dataset()
+ self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
+ self.test_collate = collate(self.cfg)
+ self.test_batch_size = min(
+ self.cfg.inference.batch_size, len(self.test_dataset.metadata)
+ )
+ test_dataloader = DataLoader(
+ self.test_dataset,
+ collate_fn=self.test_collate,
+ num_workers=1,
+ batch_size=self.test_batch_size,
+ shuffle=False,
+ )
+ return test_dataloader
+
+ @torch.inference_mode()
+ def inference(self):
+ res = []
+ for i, batch in enumerate(self.test_dataloader):
+ pred_audio_list = self._inference_each_batch(batch)
+ for j, wav in enumerate(pred_audio_list):
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
+ print(f"Saving {file}")
+
+ wav = wav.numpy(force=True)
+ save_audio(
+ file,
+ wav,
+ self.cfg.preprocess.sample_rate,
+ add_silence=False,
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
+ )
+ res.append(file)
+ return res
+
+ def _inference_each_batch(self, batch_data, noise_scale=0.667):
+ device = self.accelerator.device
+ pred_res = []
+ self.model.eval()
+ with torch.no_grad():
+ # Put the data to device
+ # device = self.accelerator.device
+ for k, v in batch_data.items():
+ batch_data[k] = v.to(device)
+
+ audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)
+
+ pred_res.extend(audios)
+
+ return pred_res
diff --git a/models/svc/vits/vits_trainer.py b/models/svc/vits/vits_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8be6d374d5f49959231bbd45daae4a904e7fa15e
--- /dev/null
+++ b/models/svc/vits/vits_trainer.py
@@ -0,0 +1,704 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+from tqdm import tqdm
+from pathlib import Path
+import shutil
+import accelerate
+
+# from models.svc.base import SVCTrainer
+from models.svc.base.svc_dataset import SVCOfflineCollator, SVCOfflineDataset
+from models.svc.vits.vits import *
+from models.svc.base import SVCTrainer
+
+from utils.mel import mel_spectrogram_torch
+import json
+
+from models.vocoders.gan.discriminator.mpd import (
+ MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
+)
+
+
+class VitsSVCTrainer(SVCTrainer):
+ def __init__(self, args, cfg):
+ self.args = args
+ self.cfg = cfg
+ SVCTrainer.__init__(self, args, cfg)
+
+ def _accelerator_prepare(self):
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ )
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key] = self.accelerator.prepare(self.model[key])
+ else:
+ self.model = self.accelerator.prepare(self.model)
+
+ if isinstance(self.optimizer, dict):
+ for key in self.optimizer.keys():
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
+ else:
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
+ else:
+ self.scheduler = self.accelerator.prepare(self.scheduler)
+
+ def _load_model(
+ self,
+ checkpoint_dir: str = None,
+ checkpoint_path: str = None,
+ resume_type: str = "",
+ ):
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ self.logger.info("Resume from {}...".format(checkpoint_path))
+
+ if resume_type in ["resume", ""]:
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
+ self.accelerator.load_state(input_dir=checkpoint_path)
+
+ # set epoch and step
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+
+ elif resume_type == "finetune":
+ # Load only the model weights
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model["generator"]),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model["discriminator"]),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune...")
+
+ else:
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
+
+ return checkpoint_path
+
+ def _build_model(self):
+ net_g = SynthesizerTrn(
+ self.cfg.preprocess.n_fft // 2 + 1,
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ # directly use cfg
+ self.cfg,
+ )
+ net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
+ model = {"generator": net_g, "discriminator": net_d}
+
+ return model
+
+ def _build_dataset(self):
+ return SVCOfflineDataset, SVCOfflineCollator
+
+ def _build_optimizer(self):
+ optimizer_g = torch.optim.AdamW(
+ self.model["generator"].parameters(),
+ self.cfg.train.learning_rate,
+ betas=self.cfg.train.AdamW.betas,
+ eps=self.cfg.train.AdamW.eps,
+ )
+ optimizer_d = torch.optim.AdamW(
+ self.model["discriminator"].parameters(),
+ self.cfg.train.learning_rate,
+ betas=self.cfg.train.AdamW.betas,
+ eps=self.cfg.train.AdamW.eps,
+ )
+ optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
+
+ return optimizer
+
+ def _build_scheduler(self):
+ scheduler_g = ExponentialLR(
+ self.optimizer["optimizer_g"],
+ gamma=self.cfg.train.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+ scheduler_d = ExponentialLR(
+ self.optimizer["optimizer_d"],
+ gamma=self.cfg.train.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+
+ scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
+ return scheduler
+
+ def _build_criterion(self):
+ class GeneratorLoss(nn.Module):
+ def __init__(self, cfg):
+ super(GeneratorLoss, self).__init__()
+ self.cfg = cfg
+ self.l1_loss = nn.L1Loss()
+
+ def generator_loss(self, disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ dg = dg.float()
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+ def feature_loss(self, fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ rl = rl.float().detach()
+ gl = gl.float()
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+ def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
+ """
+ z_p, logs_q: [b, h, t_t]
+ m_p, logs_p: [b, h, t_t]
+ """
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
+
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ kl = torch.sum(kl * z_mask)
+ l = kl / torch.sum(z_mask)
+ return l
+
+ def forward(
+ self,
+ outputs_g,
+ outputs_d,
+ y_mel,
+ y_hat_mel,
+ ):
+ loss_g = {}
+
+ # mel loss
+ loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
+ loss_g["loss_mel"] = loss_mel
+
+ # kl loss
+ loss_kl = (
+ self.kl_loss(
+ outputs_g["z_p"],
+ outputs_g["logs_q"],
+ outputs_g["m_p"],
+ outputs_g["logs_p"],
+ outputs_g["z_mask"],
+ )
+ * self.cfg.train.c_kl
+ )
+ loss_g["loss_kl"] = loss_kl
+
+ # feature loss
+ loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
+ loss_g["loss_fm"] = loss_fm
+
+ # gan loss
+ loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
+ loss_g["loss_gen"] = loss_gen
+ loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen
+
+ return loss_g
+
+ class DiscriminatorLoss(nn.Module):
+ def __init__(self, cfg):
+ super(DiscriminatorLoss, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+
+ def __call__(self, disc_real_outputs, disc_generated_outputs):
+ loss_d = {}
+
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ dr = dr.float()
+ dg = dg.float()
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg**2)
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ loss_d["loss_disc_all"] = loss
+
+ return loss_d
+
+ criterion = {
+ "generator": GeneratorLoss(self.cfg),
+ "discriminator": DiscriminatorLoss(self.cfg),
+ }
+ return criterion
+
+ # Keep legacy unchanged
+ def write_summary(
+ self,
+ losses,
+ stats,
+ images={},
+ audios={},
+ audio_sampling_rate=24000,
+ tag="train",
+ ):
+ for key, value in losses.items():
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
+ self.sw.add_scalar(
+ "learning_rate",
+ self.optimizer["optimizer_g"].param_groups[0]["lr"],
+ self.step,
+ )
+
+ if len(images) != 0:
+ for key, value in images.items():
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
+ if len(audios) != 0:
+ for key, value in audios.items():
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
+
+ def write_valid_summary(
+ self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
+ ):
+ for key, value in losses.items():
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
+
+ if len(images) != 0:
+ for key, value in images.items():
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
+ if len(audios) != 0:
+ for key, value in audios.items():
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
+
+ def _get_state_dict(self):
+ state_dict = {
+ "generator": self.model["generator"].state_dict(),
+ "discriminator": self.model["discriminator"].state_dict(),
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def get_state_dict(self):
+ state_dict = {
+ "generator": self.model["generator"].state_dict(),
+ "discriminator": self.model["discriminator"].state_dict(),
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+ self.model["generator"].load_state_dict(checkpoint["generator"])
+ self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
+ self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
+ self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
+ self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
+ self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])
+
+ @torch.inference_mode()
+ def _valid_step(self, batch):
+ r"""Testing forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_test_epoch`` for usage.
+ """
+
+ valid_losses = {}
+ total_loss = 0
+ valid_stats = {}
+
+ # Discriminator
+ # Generator output
+ outputs_g = self.model["generator"](batch)
+
+ y_mel = slice_segments(
+ batch["mel"].transpose(1, 2),
+ outputs_g["ids_slice"],
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ )
+ y_hat_mel = mel_spectrogram_torch(
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
+ )
+ y = slice_segments(
+ batch["audio"].unsqueeze(1),
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
+ self.cfg.preprocess.segment_size,
+ )
+
+ # Discriminator output
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
+ ## Discriminator loss
+ loss_d = self.criterion["discriminator"](
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
+ )
+ valid_losses.update(loss_d)
+
+ ## Generator
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
+ valid_losses.update(loss_g)
+
+ for item in valid_losses:
+ valid_losses[item] = valid_losses[item].item()
+
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
+
+ return (
+ total_loss.item(),
+ valid_losses,
+ valid_stats,
+ )
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].eval()
+ else:
+ self.model.eval()
+
+ epoch_sum_loss = 0.0
+ epoch_losses = dict()
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
+ epoch_sum_loss += total_loss
+ if isinstance(valid_losses, dict):
+ for key, value in valid_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
+ for key in epoch_losses.keys():
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
+
+ self.accelerator.wait_for_everyone()
+
+ return epoch_sum_loss, epoch_losses
+
+ ### THIS IS MAIN ENTRY ###
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # dump config file
+ if self.accelerator.is_main_process:
+ self.__dump_cfg(self.config_save_path)
+
+ # self.optimizer.zero_grad()
+ # Wait to ensure good to go
+
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ # Do training & validating epoch
+ train_total_loss, train_losses = self._train_epoch()
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+
+ valid_total_loss, valid_losses = self._valid_epoch()
+ if isinstance(valid_losses, dict):
+ for key, loss in valid_losses.items():
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
+ self.accelerator.log(
+ {
+ "Epoch/Train Loss": train_total_loss,
+ "Epoch/Valid Loss": valid_total_loss,
+ },
+ step=self.epoch,
+ )
+
+ self.accelerator.wait_for_everyone()
+
+ # Check if hit save_checkpoint_stride and run_eval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ hit_dix = []
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ hit_dix.append(i)
+ run_eval |= self.run_eval[i]
+
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, train_total_loss
+ ),
+ )
+ self.tmp_checkpoint_save_path = path
+ self.accelerator.save_state(path)
+
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+ self._save_auxiliary_states()
+
+ # Remove old checkpoints
+ to_remove = []
+ for idx in hit_dix:
+ self.checkpoints_path[idx].append(path)
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+ # Search conflicts
+ total = set()
+ for i in self.checkpoints_path:
+ total |= set(i)
+ do_remove = set()
+ for idx, path in to_remove[::-1]:
+ if path in total:
+ self.checkpoints_path[idx].insert(0, path)
+ else:
+ do_remove.add(path)
+
+ # Remove old checkpoints
+ for path in do_remove:
+ shutil.rmtree(path, ignore_errors=True)
+ self.logger.debug(f"Remove old checkpoint: {path}")
+
+ self.accelerator.wait_for_everyone()
+ if run_eval:
+ # TODO: run evaluation
+ pass
+
+ # Update info for each epoch
+ self.epoch += 1
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ self.tmp_checkpoint_save_path = path
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ )
+
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+ self._save_auxiliary_states()
+
+ self.accelerator.end_training()
+
+ def _train_step(self, batch):
+ r"""Forward step for training and inference. This function is called
+ in ``_train_step`` & ``_test_step`` function.
+ """
+
+ train_losses = {}
+ total_loss = 0
+ training_stats = {}
+
+ ## Train Discriminator
+ # Generator output
+ outputs_g = self.model["generator"](batch)
+
+ y_mel = slice_segments(
+ batch["mel"].transpose(1, 2),
+ outputs_g["ids_slice"],
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ )
+ y_hat_mel = mel_spectrogram_torch(
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
+ )
+
+ y = slice_segments(
+ # [1, 168418] -> [1, 1, 168418]
+ batch["audio"].unsqueeze(1),
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
+ self.cfg.preprocess.segment_size,
+ )
+
+ # Discriminator output
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
+ # Discriminator loss
+ loss_d = self.criterion["discriminator"](
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
+ )
+ train_losses.update(loss_d)
+
+ # BP and Grad Updated
+ self.optimizer["optimizer_d"].zero_grad()
+ self.accelerator.backward(loss_d["loss_disc_all"])
+ self.optimizer["optimizer_d"].step()
+
+ ## Train Generator
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
+ train_losses.update(loss_g)
+
+ # BP and Grad Updated
+ self.optimizer["optimizer_g"].zero_grad()
+ self.accelerator.backward(loss_g["loss_gen_all"])
+ self.optimizer["optimizer_g"].step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
+
+ return (
+ total_loss.item(),
+ train_losses,
+ training_stats,
+ )
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ epoch_sum_loss: float = 0.0
+ epoch_losses: dict = {}
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ total_loss, train_losses, training_stats = self._train_step(batch)
+ self.batch_count += 1
+
+ # Update info for each step
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += total_loss
+ for key, value in train_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ self.accelerator.log(
+ {
+ "Step/Generator Loss": train_losses["loss_gen_all"],
+ "Step/Discriminator Loss": train_losses["loss_disc_all"],
+ "Step/Generator Learning Rate": self.optimizer[
+ "optimizer_d"
+ ].param_groups[0]["lr"],
+ "Step/Discriminator Learning Rate": self.optimizer[
+ "optimizer_g"
+ ].param_groups[0]["lr"],
+ },
+ step=self.step,
+ )
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ epoch_sum_loss = (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ return epoch_sum_loss, epoch_losses
+
+ def __dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
diff --git a/models/tta/autoencoder/__init__.py b/models/tta/autoencoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tta/autoencoder/autoencoder.py b/models/tta/autoencoder/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec677182a81e1149ad7736e80d135ee38bbbea9
--- /dev/null
+++ b/models/tta/autoencoder/autoencoder.py
@@ -0,0 +1,403 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample2d(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Upsample1d(Upsample2d):
+ def __init__(self, in_channels, with_conv):
+ super().__init__(in_channels, with_conv)
+ if self.with_conv:
+ self.conv = torch.nn.Conv1d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+
+class Downsample2d(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+ self.pad = (0, 1, 0, 1)
+ else:
+ self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+ def forward(self, x):
+ if self.with_conv: # bp: check self.avgpool and self.pad
+ x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = self.avg_pool(x)
+ return x
+
+
+class Downsample1d(Downsample2d):
+ def __init__(self, in_channels, with_conv):
+ super().__init__(in_channels, with_conv)
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ # TODO: can we replace it just with conv2d with padding 1?
+ self.conv = torch.nn.Conv1d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+ self.pad = (1, 1)
+ else:
+ self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class ResnetBlock1d(ResnetBlock):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512
+ ):
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ conv_shortcut=conv_shortcut,
+ dropout=dropout,
+ )
+
+ self.conv1 = torch.nn.Conv1d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.conv2 = torch.nn.Conv1d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv1d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv1d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+
+class Encoder2d(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ z_channels,
+ double_z=True,
+ **ignore_kwargs
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in, out_channels=block_out, dropout=dropout
+ )
+ )
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample2d(block_in, resamp_with_conv)
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, dropout=dropout
+ )
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, dropout=dropout
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h)
+ h = self.mid.block_2(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+# TODO: Encoder1d
+class Encoder1d(Encoder2d): ...
+
+
+class Decoder2d(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ z_channels,
+ give_pre_end=False,
+ **ignorekwargs
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ # self.z_shape = (1,z_channels,curr_res,curr_res)
+ # print("Working with z of shape {} = {} dimensions.".format(
+ # self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, dropout=dropout
+ )
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in, out_channels=block_in, dropout=dropout
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in, out_channels=block_out, dropout=dropout
+ )
+ )
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample2d(block_in, resamp_with_conv)
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, z):
+ self.last_z_shape = z.shape
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.block_2(h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+# TODO: decoder1d
+class Decoder1d(Decoder2d): ...
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.encoder = Encoder2d(
+ ch=cfg.ch,
+ ch_mult=cfg.ch_mult,
+ num_res_blocks=cfg.num_res_blocks,
+ in_channels=cfg.in_channels,
+ z_channels=cfg.z_channels,
+ double_z=cfg.double_z,
+ )
+ self.decoder = Decoder2d(
+ ch=cfg.ch,
+ ch_mult=cfg.ch_mult,
+ num_res_blocks=cfg.num_res_blocks,
+ out_ch=cfg.out_ch,
+ z_channels=cfg.z_channels,
+ in_channels=None,
+ )
+ assert self.cfg.double_z
+
+ self.quant_conv = torch.nn.Conv2d(2 * cfg.z_channels, 2 * cfg.z_channels, 1)
+ self.post_quant_conv = torch.nn.Conv2d(cfg.z_channels, cfg.z_channels, 1)
+ self.embed_dim = cfg.z_channels
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
diff --git a/models/tta/autoencoder/autoencoder_dataset.py b/models/tta/autoencoder/autoencoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9b4bdfe8d9eeae62dbc313547cdb056938104a
--- /dev/null
+++ b/models/tta/autoencoder/autoencoder_dataset.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.base.base_dataset import (
+ BaseOfflineCollator,
+ BaseOfflineDataset,
+ BaseTestDataset,
+ BaseTestCollator,
+)
+import librosa
+
+
+class AutoencoderKLDataset(BaseOfflineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
+
+ cfg = self.cfg
+
+ # utt2melspec
+ if cfg.preprocess.use_melspec:
+ self.utt2melspec_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2melspec_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.melspec_dir,
+ uid + ".npy",
+ )
+
+ # utt2wav
+ if cfg.preprocess.use_wav:
+ self.utt2wav_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2wav_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.wav_dir,
+ uid + ".wav",
+ )
+
+ def __getitem__(self, index):
+ # melspec: (n_mels, T)
+ # wav: (T,)
+
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if self.cfg.preprocess.use_melspec:
+ single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
+
+ if self.cfg.preprocess.use_wav:
+ wav, sr = librosa.load(
+ self.utt2wav_path[utt], sr=16000
+ ) # hard coding for 16KHz...
+ single_feature["wav"] = wav
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class AutoencoderKLCollator(BaseOfflineCollator):
+ def __init__(self, cfg):
+ BaseOfflineCollator.__init__(self, cfg)
+
+ def __call__(self, batch):
+ # mel: (B, n_mels, T)
+ # wav (option): (B, T)
+
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if key == "melspec":
+ packed_batch_features["melspec"] = torch.from_numpy(
+ np.array([b["melspec"][:, :624] for b in batch])
+ )
+
+ if key == "wav":
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
+
+
+class AutoencoderKLTestDataset(BaseTestDataset): ...
+
+
+class AutoencoderKLTestCollator(BaseTestCollator): ...
diff --git a/models/tta/autoencoder/autoencoder_loss.py b/models/tta/autoencoder/autoencoder_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5916aa36ebdfba6f2608514767f3e3761b57269f
--- /dev/null
+++ b/models/tta/autoencoder/autoencoder_loss.py
@@ -0,0 +1,305 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import functools
+import torch.nn.functional as F
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
+ )
+ return d_loss
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.0):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+class ActNorm(nn.Module):
+ def __init__(
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
+ ):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if (
+ type(norm_layer) == functools.partial
+ ): # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+ nn.LeakyReLU(0.2, True),
+ ]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=1,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
+
+
+class AutoencoderLossWithDiscriminator(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.kl_weight = cfg.kl_weight
+ self.logvar = nn.Parameter(torch.ones(size=()) * cfg.logvar_init)
+
+ self.discriminator = NLayerDiscriminator(
+ input_nc=cfg.disc_in_channels,
+ n_layers=cfg.disc_num_layers,
+ use_actnorm=cfg.use_actnorm,
+ ).apply(weights_init)
+
+ self.discriminator_iter_start = cfg.disc_start
+ self.discriminator_weight = cfg.disc_weight
+ self.disc_factor = cfg.disc_factor
+ self.disc_loss = hinge_d_loss
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(
+ d_weight, self.cfg.min_adapt_d_weight, self.cfg.max_adapt_d_weight
+ ).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs,
+ reconstructions,
+ posteriors,
+ optimizer_idx,
+ global_step,
+ last_layer,
+ split="train",
+ weights=None,
+ ):
+ rec_loss = torch.abs(
+ inputs.contiguous() - reconstructions.contiguous()
+ ) # l1 loss
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ # weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ weighted_nll_loss = torch.mean(weighted_nll_loss)
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ # ? kl_loss = torch.mean(kl_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+
+ total_loss = (
+ weighted_nll_loss
+ + self.kl_weight * kl_loss
+ + d_weight * disc_factor * g_loss
+ )
+
+ return {
+ "loss": total_loss,
+ "kl_loss": kl_loss,
+ "rec_loss": rec_loss.mean(),
+ "nll_loss": nll_loss,
+ "g_loss": g_loss,
+ "d_weight": d_weight,
+ "disc_factor": torch.tensor(disc_factor),
+ }
+
+ if optimizer_idx == 1:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ return {
+ "d_loss": d_loss,
+ "logits_real": logits_real.mean(),
+ "logits_fake": logits_fake.mean(),
+ }
diff --git a/models/tta/autoencoder/autoencoder_trainer.py b/models/tta/autoencoder/autoencoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1faf02fa26c9bdc69faf2344fc2f722336d68a71
--- /dev/null
+++ b/models/tta/autoencoder/autoencoder_trainer.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from models.base.base_trainer import BaseTrainer
+from models.tta.autoencoder.autoencoder_dataset import (
+ AutoencoderKLDataset,
+ AutoencoderKLCollator,
+)
+from models.tta.autoencoder.autoencoder import AutoencoderKL
+from models.tta.autoencoder.autoencoder_loss import AutoencoderLossWithDiscriminator
+from torch.optim import Adam, AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.nn import MSELoss, L1Loss
+import torch.nn.functional as F
+from torch.utils.data import ConcatDataset, DataLoader
+
+
+class AutoencoderKLTrainer(BaseTrainer):
+ def __init__(self, args, cfg):
+ BaseTrainer.__init__(self, args, cfg)
+ self.cfg = cfg
+ self.save_config_file()
+
+ def build_dataset(self):
+ return AutoencoderKLDataset, AutoencoderKLCollator
+
+ def build_optimizer(self):
+ opt_ae = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
+ opt_disc = torch.optim.AdamW(
+ self.criterion.discriminator.parameters(), **self.cfg.train.adam
+ )
+ optimizer = {"opt_ae": opt_ae, "opt_disc": opt_disc}
+ return optimizer
+
+ def build_data_loader(self):
+ Dataset, Collator = self.build_dataset()
+ # build dataset instance for each dataset and combine them by ConcatDataset
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = ConcatDataset(datasets_list)
+
+ train_collate = Collator(self.cfg)
+
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ num_workers=self.args.num_workers,
+ batch_size=self.cfg.train.batch_size,
+ pin_memory=False,
+ )
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = ConcatDataset(datasets_list)
+ valid_collate = Collator(self.cfg)
+
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ num_workers=1,
+ batch_size=self.cfg.train.batch_size,
+ )
+ else:
+ raise NotImplementedError("DDP is not supported yet.")
+ # valid_loader = None
+ data_loader = {"train": train_loader, "valid": valid_loader}
+ return data_loader
+
+ # TODO: check it...
+ def build_scheduler(self):
+ return None
+ # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
+
+ def write_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def write_valid_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def build_criterion(self):
+ return AutoencoderLossWithDiscriminator(self.cfg.model.loss)
+
+ def get_state_dict(self):
+ if self.scheduler != None:
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer_ae": self.optimizer["opt_ae"].state_dict(),
+ "optimizer_disc": self.optimizer["opt_disc"].state_dict(),
+ "scheduler": self.scheduler.state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ else:
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer_ae": self.optimizer["opt_ae"].state_dict(),
+ "optimizer_disc": self.optimizer["opt_disc"].state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+
+ self.model.load_state_dict(checkpoint["model"])
+ self.optimizer["opt_ae"].load_state_dict(checkpoint["optimizer_ae"])
+ self.optimizer["opt_disc"].load_state_dict(checkpoint["optimizer_disc"])
+ if self.scheduler != None:
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
+
+ def build_model(self):
+ self.model = AutoencoderKL(self.cfg.model.autoencoderkl)
+ return self.model
+
+ # TODO: train step
+ def train_step(self, data):
+ global_step = self.step
+ optimizer_idx = global_step % 2
+
+ train_losses = {}
+ total_loss = 0
+ train_states = {}
+
+ inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
+ reconstructions, posterior = self.model(inputs)
+ # train_stats.update(stat)
+
+ train_losses = self.criterion(
+ inputs=inputs,
+ reconstructions=reconstructions,
+ posteriors=posterior,
+ optimizer_idx=optimizer_idx,
+ global_step=global_step,
+ last_layer=self.model.get_last_layer(),
+ split="train",
+ )
+
+ if optimizer_idx == 0:
+ total_loss = train_losses["loss"]
+ self.optimizer["opt_ae"].zero_grad()
+ total_loss.backward()
+ self.optimizer["opt_ae"].step()
+
+ else:
+ total_loss = train_losses["d_loss"]
+ self.optimizer["opt_disc"].zero_grad()
+ total_loss.backward()
+ self.optimizer["opt_disc"].step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ return train_losses, train_states, total_loss.item()
+
+ # TODO: eval step
+ @torch.no_grad()
+ def eval_step(self, data, index):
+ valid_loss = {}
+ total_valid_loss = 0
+ valid_stats = {}
+
+ inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
+ reconstructions, posterior = self.model(inputs)
+
+ loss = F.l1_loss(inputs, reconstructions)
+ valid_loss["loss"] = loss
+
+ total_valid_loss += loss
+
+ for item in valid_loss:
+ valid_loss[item] = valid_loss[item].item()
+
+ return valid_loss, valid_stats, total_valid_loss.item()
diff --git a/models/tta/ldm/__init__.py b/models/tta/ldm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tta/ldm/attention.py b/models/tta/ldm/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e241fc8be6456a1c3bc3b8d8efe648ea4e42740
--- /dev/null
+++ b/models/tta/ldm/attention.py
@@ -0,0 +1,329 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(
+ self._forward, (x, context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+
+ def __init__(
+ self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
+ )
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/models/tta/ldm/audioldm.py b/models/tta/ldm/audioldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..c799fec50c9885c3a005d0975cc7f15ee3469b55
--- /dev/null
+++ b/models/tta/ldm/audioldm.py
@@ -0,0 +1,928 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from einops import repeat
+
+from models.tta.ldm.attention import SpatialTransformer
+
+# from attention import SpatialTransformer
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1) # [N x (H * C) x T]
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = torch.einsum(
+ "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
+ )
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = torch.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ "Learned 2x upsampling without padding"
+
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(
+ self.channels, self.out_channels, kernel_size=ks, stride=2
+ )
+
+ def forward(self, x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(
+ self._forward, (x,), self.parameters(), True
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ from omegaconf.listconfig import ListConfig
+
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ (
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ )
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ # print(h.shape, hs[-1].shape)
+ if h.shape != hs[-1].shape:
+ if h.shape[-1] > hs[-1].shape[-1]:
+ h = h[:, :, :, : hs[-1].shape[-1]]
+ if h.shape[-2] > hs[-1].shape[-2]:
+ h = h[:, :, : hs[-1].shape[-2], :]
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ # print(h.shape)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class AudioLDM(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.unet = UNetModel(
+ image_size=cfg.image_size,
+ in_channels=cfg.in_channels,
+ out_channels=cfg.out_channels,
+ model_channels=cfg.model_channels,
+ attention_resolutions=cfg.attention_resolutions,
+ num_res_blocks=cfg.num_res_blocks,
+ channel_mult=cfg.channel_mult,
+ num_heads=cfg.num_heads,
+ use_spatial_transformer=cfg.use_spatial_transformer,
+ transformer_depth=cfg.transformer_depth,
+ context_dim=cfg.context_dim,
+ use_checkpoint=cfg.use_checkpoint,
+ legacy=cfg.legacy,
+ )
+
+ def forward(self, x, timesteps=None, context=None, y=None):
+ x = self.unet(x=x, timesteps=timesteps, context=context, y=y)
+ return x
diff --git a/models/tta/ldm/audioldm_dataset.py b/models/tta/ldm/audioldm_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..96232536599c2d53495c23e5b810ce0e3c381a7e
--- /dev/null
+++ b/models/tta/ldm/audioldm_dataset.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+
+
+from models.base.base_dataset import (
+ BaseOfflineCollator,
+ BaseOfflineDataset,
+ BaseTestDataset,
+ BaseTestCollator,
+)
+import librosa
+
+from transformers import AutoTokenizer
+
+
+class AudioLDMDataset(BaseOfflineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
+
+ self.cfg = cfg
+
+ # utt2melspec
+ if cfg.preprocess.use_melspec:
+ self.utt2melspec_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2melspec_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.melspec_dir,
+ uid + ".npy",
+ )
+
+ # utt2wav
+ if cfg.preprocess.use_wav:
+ self.utt2wav_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2wav_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.wav_dir,
+ uid + ".wav",
+ )
+
+ # utt2caption
+ if cfg.preprocess.use_caption:
+ self.utt2caption = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2caption[utt] = utt_info["Caption"]
+
+ def __getitem__(self, index):
+ # melspec: (n_mels, T)
+ # wav: (T,)
+
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if self.cfg.preprocess.use_melspec:
+ single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
+
+ if self.cfg.preprocess.use_wav:
+ wav, sr = librosa.load(
+ self.utt2wav_path[utt], sr=16000
+ ) # hard coding for 16KHz...
+ single_feature["wav"] = wav
+
+ if self.cfg.preprocess.use_caption:
+ cond_mask = np.random.choice(
+ [1, 0],
+ p=[
+ self.cfg.preprocess.cond_mask_prob,
+ 1 - self.cfg.preprocess.cond_mask_prob,
+ ],
+ ) # (0.1, 0.9)
+ if cond_mask:
+ single_feature["caption"] = ""
+ else:
+ single_feature["caption"] = self.utt2caption[utt]
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class AudioLDMCollator(BaseOfflineCollator):
+ def __init__(self, cfg):
+ BaseOfflineCollator.__init__(self, cfg)
+
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
+
+ def __call__(self, batch):
+ # mel: (B, n_mels, T)
+ # wav (option): (B, T)
+ # text_input_ids: (B, L)
+ # text_attention_mask: (B, L)
+
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if key == "melspec":
+ packed_batch_features["melspec"] = torch.from_numpy(
+ np.array([b["melspec"][:, :624] for b in batch])
+ )
+
+ if key == "wav":
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ if key == "caption":
+ captions = [b[key] for b in batch]
+ text_input = self.tokenizer(
+ captions, return_tensors="pt", truncation=True, padding="longest"
+ )
+ text_input_ids = text_input["input_ids"]
+ text_attention_mask = text_input["attention_mask"]
+
+ packed_batch_features["text_input_ids"] = text_input_ids
+ packed_batch_features["text_attention_mask"] = text_attention_mask
+
+ return packed_batch_features
+
+
+class AudioLDMTestDataset(BaseTestDataset): ...
+
+
+class AudioLDMTestCollator(BaseTestCollator): ...
diff --git a/models/tta/ldm/audioldm_inference.py b/models/tta/ldm/audioldm_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a37b40639aeef4d4a8ade1324171e7be11009d8d
--- /dev/null
+++ b/models/tta/ldm/audioldm_inference.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+import numpy as np
+import torch
+from tqdm import tqdm
+import torch.nn as nn
+from collections import OrderedDict
+import json
+
+from models.tta.autoencoder.autoencoder import AutoencoderKL
+from models.tta.ldm.inference_utils.vocoder import Generator
+from models.tta.ldm.audioldm import AudioLDM
+from transformers import T5EncoderModel, AutoTokenizer
+from diffusers import PNDMScheduler
+
+import matplotlib.pyplot as plt
+from scipy.io.wavfile import write
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+class AudioLDMInference:
+ def __init__(self, args, cfg):
+ self.cfg = cfg
+ self.args = args
+
+ self.build_autoencoderkl()
+ self.build_textencoder()
+
+ self.model = self.build_model()
+ self.load_state_dict()
+
+ self.build_vocoder()
+
+ self.out_path = self.args.output_dir
+ self.out_mel_path = os.path.join(self.out_path, "mel")
+ self.out_wav_path = os.path.join(self.out_path, "wav")
+ os.makedirs(self.out_mel_path, exist_ok=True)
+ os.makedirs(self.out_wav_path, exist_ok=True)
+
+ def build_autoencoderkl(self):
+ self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
+ self.autoencoder_path = self.cfg.model.autoencoder_path
+ checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
+ self.autoencoderkl.load_state_dict(checkpoint["model"])
+ self.autoencoderkl.cuda(self.args.local_rank)
+ self.autoencoderkl.requires_grad_(requires_grad=False)
+ self.autoencoderkl.eval()
+
+ def build_textencoder(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
+ self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
+ self.text_encoder.cuda(self.args.local_rank)
+ self.text_encoder.requires_grad_(requires_grad=False)
+ self.text_encoder.eval()
+
+ def build_vocoder(self):
+ config_file = os.path.join(self.args.vocoder_config_path)
+ with open(config_file) as f:
+ data = f.read()
+ json_config = json.loads(data)
+ h = AttrDict(json_config)
+ self.vocoder = Generator(h).to(self.args.local_rank)
+ checkpoint_dict = torch.load(
+ self.args.vocoder_path, map_location=self.args.local_rank
+ )
+ self.vocoder.load_state_dict(checkpoint_dict["generator"])
+
+ def build_model(self):
+ self.model = AudioLDM(self.cfg.model.audioldm)
+ return self.model
+
+ def load_state_dict(self):
+ self.checkpoint_path = self.args.checkpoint_path
+ checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
+ self.model.load_state_dict(checkpoint["model"])
+ self.model.cuda(self.args.local_rank)
+
+ def get_text_embedding(self):
+ text = self.args.text
+
+ prompt = [text]
+
+ text_input = self.tokenizer(
+ prompt,
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ padding="do_not_pad",
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(
+ text_input.input_ids.to(self.args.local_rank)
+ )[0]
+
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(self.args.local_rank)
+ )[0]
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def inference(self):
+ text_embeddings = self.get_text_embedding()
+ print(text_embeddings.shape)
+
+ num_steps = self.args.num_steps
+ guidance_scale = self.args.guidance_scale
+
+ noise_scheduler = PNDMScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ skip_prk_steps=True,
+ set_alpha_to_one=False,
+ steps_offset=1,
+ prediction_type="epsilon",
+ )
+
+ noise_scheduler.set_timesteps(num_steps)
+
+ latents = torch.randn(
+ (
+ 1,
+ self.cfg.model.autoencoderkl.z_channels,
+ 80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
+ 624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
+ )
+ ).to(self.args.local_rank)
+
+ self.model.eval()
+ for t in tqdm(noise_scheduler.timesteps):
+ t = t.to(self.args.local_rank)
+
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+ latent_model_input = torch.cat([latents] * 2)
+
+ latent_model_input = noise_scheduler.scale_model_input(
+ latent_model_input, timestep=t
+ )
+ # print(latent_model_input.shape)
+
+ # predict the noise residual
+ with torch.no_grad():
+ noise_pred = self.model(
+ latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
+ )
+
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+ # print(latents.shape)
+
+ latents_out = latents
+ print(latents_out.shape)
+
+ with torch.no_grad():
+ mel_out = self.autoencoderkl.decode(latents_out)
+ print(mel_out.shape)
+
+ melspec = mel_out[0, 0].cpu().detach().numpy()
+ plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec)
+
+ self.vocoder.eval()
+ self.vocoder.remove_weight_norm()
+ with torch.no_grad():
+ melspec = np.expand_dims(melspec, 0)
+ melspec = torch.FloatTensor(melspec).to(self.args.local_rank)
+
+ y = self.vocoder(melspec)
+ audio = y.squeeze()
+ audio = audio * 32768.0
+ audio = audio.cpu().numpy().astype("int16")
+
+ write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio)
diff --git a/models/tta/ldm/audioldm_trainer.py b/models/tta/ldm/audioldm_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd4a241867b139f0ce314b9c78d053cf711a83df
--- /dev/null
+++ b/models/tta/ldm/audioldm_trainer.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from models.base.base_trainer import BaseTrainer
+from diffusers import DDPMScheduler
+from models.tta.ldm.audioldm_dataset import AudioLDMDataset, AudioLDMCollator
+from models.tta.autoencoder.autoencoder import AutoencoderKL
+from models.tta.ldm.audioldm import AudioLDM, UNetModel
+import torch
+import torch.nn as nn
+from torch.nn import MSELoss, L1Loss
+import torch.nn.functional as F
+from torch.utils.data import ConcatDataset, DataLoader
+
+from transformers import T5EncoderModel
+from diffusers import DDPMScheduler
+
+
+class AudioLDMTrainer(BaseTrainer):
+ def __init__(self, args, cfg):
+ BaseTrainer.__init__(self, args, cfg)
+ self.cfg = cfg
+
+ self.build_autoencoderkl()
+ self.build_textencoder()
+ self.nosie_scheduler = self.build_noise_scheduler()
+
+ self.save_config_file()
+
+ def build_autoencoderkl(self):
+ self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
+ self.autoencoder_path = self.cfg.model.autoencoder_path
+ checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
+ self.autoencoderkl.load_state_dict(checkpoint["model"])
+ self.autoencoderkl.cuda(self.args.local_rank)
+ self.autoencoderkl.requires_grad_(requires_grad=False)
+ self.autoencoderkl.eval()
+
+ def build_textencoder(self):
+ self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
+ self.text_encoder.cuda(self.args.local_rank)
+ self.text_encoder.requires_grad_(requires_grad=False)
+ self.text_encoder.eval()
+
+ def build_noise_scheduler(self):
+ nosie_scheduler = DDPMScheduler(
+ num_train_timesteps=self.cfg.model.noise_scheduler.num_train_timesteps,
+ beta_start=self.cfg.model.noise_scheduler.beta_start,
+ beta_end=self.cfg.model.noise_scheduler.beta_end,
+ beta_schedule=self.cfg.model.noise_scheduler.beta_schedule,
+ clip_sample=self.cfg.model.noise_scheduler.clip_sample,
+ # steps_offset=self.cfg.model.noise_scheduler.steps_offset,
+ # set_alpha_to_one=self.cfg.model.noise_scheduler.set_alpha_to_one,
+ # skip_prk_steps=self.cfg.model.noise_scheduler.skip_prk_steps,
+ prediction_type=self.cfg.model.noise_scheduler.prediction_type,
+ )
+ return nosie_scheduler
+
+ def build_dataset(self):
+ return AudioLDMDataset, AudioLDMCollator
+
+ def build_data_loader(self):
+ Dataset, Collator = self.build_dataset()
+ # build dataset instance for each dataset and combine them by ConcatDataset
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = ConcatDataset(datasets_list)
+
+ train_collate = Collator(self.cfg)
+
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ num_workers=self.args.num_workers,
+ batch_size=self.cfg.train.batch_size,
+ pin_memory=False,
+ )
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = ConcatDataset(datasets_list)
+ valid_collate = Collator(self.cfg)
+
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ num_workers=1,
+ batch_size=self.cfg.train.batch_size,
+ )
+ else:
+ raise NotImplementedError("DDP is not supported yet.")
+ # valid_loader = None
+ data_loader = {"train": train_loader, "valid": valid_loader}
+ return data_loader
+
+ def build_optimizer(self):
+ optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
+ return optimizer
+
+ # TODO: check it...
+ def build_scheduler(self):
+ return None
+ # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
+
+ def write_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def write_valid_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def build_criterion(self):
+ criterion = nn.MSELoss(reduction="mean")
+ return criterion
+
+ def get_state_dict(self):
+ if self.scheduler != None:
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "scheduler": self.scheduler.state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ else:
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+
+ self.model.load_state_dict(checkpoint["model"])
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if self.scheduler != None:
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
+
+ def build_model(self):
+ self.model = AudioLDM(self.cfg.model.audioldm)
+ return self.model
+
+ @torch.no_grad()
+ def mel_to_latent(self, melspec):
+ posterior = self.autoencoderkl.encode(melspec)
+ latent = posterior.sample() # (B, 4, 5, 78)
+ return latent
+
+ @torch.no_grad()
+ def get_text_embedding(self, text_input_ids, text_attention_mask):
+ text_embedding = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=text_attention_mask
+ ).last_hidden_state
+ return text_embedding # (B, T, 768)
+
+ def train_step(self, data):
+ train_losses = {}
+ total_loss = 0
+ train_stats = {}
+
+ melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
+ latents = self.mel_to_latent(melspec)
+
+ text_embedding = self.get_text_embedding(
+ data["text_input_ids"], data["text_attention_mask"]
+ )
+
+ noise = torch.randn_like(latents).float()
+
+ bsz = latents.shape[0]
+ timesteps = torch.randint(
+ 0,
+ self.cfg.model.noise_scheduler.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
+ )
+ timesteps = timesteps.long()
+
+ with torch.no_grad():
+ noisy_latents = self.nosie_scheduler.add_noise(latents, noise, timesteps)
+
+ model_pred = self.model(
+ noisy_latents, timesteps=timesteps, context=text_embedding
+ )
+
+ loss = self.criterion(model_pred, noise)
+
+ train_losses["loss"] = loss
+ total_loss += loss
+
+ self.optimizer.zero_grad()
+ total_loss.backward()
+ self.optimizer.step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ return train_losses, train_stats, total_loss.item()
+
+ # TODO: eval step
+ @torch.no_grad()
+ def eval_step(self, data, index):
+ valid_loss = {}
+ total_valid_loss = 0
+ valid_stats = {}
+
+ melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
+ latents = self.mel_to_latent(melspec)
+
+ text_embedding = self.get_text_embedding(
+ data["text_input_ids"], data["text_attention_mask"]
+ )
+
+ noise = torch.randn_like(latents).float()
+
+ bsz = latents.shape[0]
+ timesteps = torch.randint(
+ 0,
+ self.cfg.model.noise_scheduler.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
+ )
+ timesteps = timesteps.long()
+
+ noisy_latents = self.nosie_scheduler.add_noise(latents, noise, timesteps)
+
+ model_pred = self.model(noisy_latents, timesteps, text_embedding)
+
+ loss = self.criterion(model_pred, noise)
+ valid_loss["loss"] = loss
+
+ total_valid_loss += loss
+
+ for item in valid_loss:
+ valid_loss[item] = valid_loss[item].item()
+
+ return valid_loss, valid_stats, total_valid_loss.item()
diff --git a/models/tta/ldm/inference_utils/utils.py b/models/tta/ldm/inference_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd61c0c262be55a9b98f12c1ce1043eeddfcc739
--- /dev/null
+++ b/models/tta/ldm/inference_utils/utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import glob
+import os
+import matplotlib
+import torch
+from torch.nn.utils import weight_norm
+
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+def plot_spectrogram(spectrogram):
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+ plt.colorbar(im, ax=ax)
+
+ fig.canvas.draw()
+ plt.close()
+
+ return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def load_checkpoint(filepath, device):
+ assert os.path.isfile(filepath)
+ print("Loading '{}'".format(filepath))
+ checkpoint_dict = torch.load(filepath, map_location=device)
+ print("Complete.")
+ return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj):
+ print("Saving checkpoint to {}".format(filepath))
+ torch.save(obj, filepath)
+ print("Complete.")
+
+
+def scan_checkpoint(cp_dir, prefix):
+ pattern = os.path.join(cp_dir, prefix + "????????")
+ cp_list = glob.glob(pattern)
+ if len(cp_list) == 0:
+ return None
+ return sorted(cp_list)[-1]
diff --git a/models/tta/ldm/inference_utils/vocoder.py b/models/tta/ldm/inference_utils/vocoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e17c1e2b3e20154305180705ccbf8b5e49c346
--- /dev/null
+++ b/models/tta/ldm/inference_utils/vocoder.py
@@ -0,0 +1,408 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from models.tta.ldm.inference_utils.utils import get_padding, init_weights
+
+LRELU_SLOPE = 0.1
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.h = h
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.h = h
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(
+ Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
+ )
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print("Removing weight norm...")
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(
+ Conv2d(
+ 1,
+ 32,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 32,
+ 128,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 128,
+ 512,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 512,
+ 1024,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+ ]
+ )
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList(
+ [
+ DiscriminatorP(2),
+ DiscriminatorP(3),
+ DiscriminatorP(5),
+ DiscriminatorP(7),
+ DiscriminatorP(11),
+ ]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ]
+ )
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList(
+ [
+ DiscriminatorS(use_spectral_norm=True),
+ DiscriminatorS(),
+ DiscriminatorS(),
+ ]
+ )
+ self.meanpools = nn.ModuleList(
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ if i != 0:
+ y = self.meanpools[i - 1](y)
+ y_hat = self.meanpools[i - 1](y_hat)
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg**2)
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
diff --git a/models/tts/base/__init__.py b/models/tts/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..268a381ec731d684418beab0d60cad84b22c4533
--- /dev/null
+++ b/models/tts/base/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# from .tts_inferece import TTSInference
+from .tts_trainer import TTSTrainer
diff --git a/models/tts/base/tts_dataset.py b/models/tts/base/tts_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca129e29c6cdb3d3d79432ae4492b8234e56822
--- /dev/null
+++ b/models/tts/base/tts_dataset.py
@@ -0,0 +1,392 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import torchaudio
+import numpy as np
+import torch
+from utils.data_utils import *
+from torch.nn.utils.rnn import pad_sequence
+from text import text_to_sequence
+from text.text_token_collation import phoneIDCollation
+from processors.acoustic_extractor import cal_normalized_mel
+
+from models.base.base_dataset import (
+ BaseOfflineDataset,
+ BaseOfflineCollator,
+ BaseTestDataset,
+ BaseTestCollator,
+)
+
+from processors.content_extractor import (
+ ContentvecExtractor,
+ WenetExtractor,
+ WhisperExtractor,
+)
+
+
+class TTSDataset(BaseOfflineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+
+ assert isinstance(dataset, str)
+
+ self.cfg = cfg
+
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
+ self.metadata = self.get_metadata()
+
+ """
+ load spk2id and utt2spk from json file
+ spk2id: {spk1: 0, spk2: 1, ...}
+ utt2spk: {dataset_uid: spk1, ...}
+ """
+ if cfg.preprocess.use_spkid:
+ dataset = self.metadata[0]["Dataset"]
+
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
+ with open(spk2id_path, "r") as f:
+ self.spk2id = json.load(f)
+
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
+ self.utt2spk = dict()
+ with open(utt2spk_path, "r") as f:
+ for line in f.readlines():
+ utt, spk = line.strip().split("\t")
+ self.utt2spk[utt] = spk
+
+ if cfg.preprocess.use_uv:
+ self.utt2uv_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2uv_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.uv_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_frame_pitch:
+ self.utt2frame_pitch_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2frame_pitch_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.pitch_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_frame_energy:
+ self.utt2frame_energy_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2frame_energy_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.energy_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_mel:
+ self.utt2mel_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2mel_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.mel_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_linear:
+ self.utt2linear_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2linear_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.linear_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_audio:
+ self.utt2audio_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if cfg.preprocess.extract_audio:
+ self.utt2audio_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.audio_dir,
+ uid + ".wav",
+ )
+ else:
+ self.utt2audio_path[utt] = utt_info["Path"]
+
+ # self.utt2audio_path[utt] = os.path.join(
+ # cfg.preprocess.processed_dir,
+ # dataset,
+ # cfg.preprocess.audio_dir,
+ # uid + ".numpy",
+ # )
+
+ elif cfg.preprocess.use_label:
+ self.utt2label_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2label_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.label_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_one_hot:
+ self.utt2one_hot_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2one_hot_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.one_hot_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
+ self.utt2seq = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if cfg.preprocess.use_text:
+ text = utt_info["Text"]
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
+ elif cfg.preprocess.use_phone:
+ # load phoneme squence from phone file
+ phone_path = os.path.join(
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
+ )
+ with open(phone_path, "r") as fin:
+ phones = fin.readlines()
+ assert len(phones) == 1
+ phones = phones[0].strip()
+ phones_seq = phones.split(" ")
+
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
+
+ if cfg.preprocess.add_blank:
+ sequence = intersperse(sequence, 0)
+
+ self.utt2seq[utt] = sequence
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_spkid:
+ single_feature["spk_id"] = np.array(
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
+ )
+
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
+ if self.cfg.preprocess.use_min_max_norm_mel:
+ # do mel norm
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+ single_feature["mel"] = mel.T # [T, n_mels]
+
+ if self.cfg.preprocess.use_linear:
+ linear = np.load(self.utt2linear_path[utt])
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = linear.shape[1]
+ single_feature["linear"] = linear.T # [T, n_linear]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
+ frame_pitch = np.load(frame_pitch_path)
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_uv:
+ frame_uv_path = self.utt2uv_path[utt]
+ frame_uv = np.load(frame_uv_path)
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
+ aligned_frame_uv = [
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
+ ]
+ aligned_frame_uv = np.array(aligned_frame_uv)
+ single_feature["frame_uv"] = aligned_frame_uv
+
+ if self.cfg.preprocess.use_frame_energy:
+ frame_energy_path = self.utt2frame_energy_path[utt]
+ frame_energy = np.load(frame_energy_path)
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_energy)
+ aligned_frame_energy = align_length(
+ frame_energy, single_feature["target_len"]
+ )
+ single_feature["frame_energy"] = aligned_frame_energy
+
+ if self.cfg.preprocess.use_audio:
+ audio, sr = torchaudio.load(self.utt2audio_path[utt])
+ audio = audio.cpu().numpy().squeeze()
+ single_feature["audio"] = audio
+ single_feature["audio_len"] = audio.shape[0]
+
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
+ single_feature["phone_len"] = len(self.utt2seq[utt])
+
+ return single_feature
+
+ def __len__(self):
+ return super().__len__()
+
+ def get_metadata(self):
+ return super().get_metadata()
+
+
+class TTSCollator(BaseOfflineCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ def __call__(self, batch):
+ parsed_batch_features = super().__call__(batch)
+ return parsed_batch_features
+
+
+class TTSTestDataset(BaseTestDataset):
+ def __init__(self, args, cfg):
+ self.cfg = cfg
+
+ # inference from test list file
+ if args.test_list_file is not None:
+ # construst metadata
+ self.metadata = []
+
+ with open(args.test_list_file, "r") as fin:
+ for idx, line in enumerate(fin.readlines()):
+ utt_info = {}
+
+ utt_info["Dataset"] = "test"
+ utt_info["Text"] = line.strip()
+ utt_info["Uid"] = str(idx)
+ self.metadata.append(utt_info)
+
+ else:
+ assert args.testing_set
+ self.metafile_path = os.path.join(
+ cfg.preprocess.processed_dir,
+ args.dataset,
+ "{}.json".format(args.testing_set),
+ )
+ self.metadata = self.get_metadata()
+
+ def __getitem__(self, index):
+ single_feature = {}
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class TTSTestCollator(BaseTestCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "phone_len":
+ packed_batch_features["phone_len"] = torch.LongTensor(
+ [b["phone_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["phn_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "audio_len":
+ packed_batch_features["audio_len"] = torch.LongTensor(
+ [b["audio_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
+ ]
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ return packed_batch_features
diff --git a/models/tts/base/tts_inferece.py b/models/tts/base/tts_inferece.py
new file mode 100644
index 0000000000000000000000000000000000000000..f49ace0f1222c6cc203f5aa7ff4e320458709c01
--- /dev/null
+++ b/models/tts/base/tts_inferece.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import time
+import accelerate
+import random
+import numpy as np
+from tqdm import tqdm
+from accelerate.logging import get_logger
+from torch.utils.data import DataLoader
+from safetensors.torch import load_file
+
+
+from abc import abstractmethod
+from pathlib import Path
+from utils.io import save_audio
+from utils.util import load_config
+from models.vocoders.vocoder_inference import synthesis
+
+
+class TTSInference(object):
+ def __init__(self, args=None, cfg=None):
+ super().__init__()
+
+ start = time.monotonic_ns()
+ self.args = args
+ self.cfg = cfg
+ self.infer_type = args.mode
+
+ # get exp_dir
+ if self.args.acoustics_dir is not None:
+ self.exp_dir = self.args.acoustics_dir
+ elif self.args.checkpoint_path is not None:
+ self.exp_dir = os.path.dirname(os.path.dirname(self.args.checkpoint_path))
+
+ # Init accelerator
+ self.accelerator = accelerate.Accelerator()
+ self.accelerator.wait_for_everyone()
+ self.device = self.accelerator.device
+
+ # Get logger
+ with self.accelerator.main_process_first():
+ self.logger = get_logger("inference", log_level=args.log_level)
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+
+ self.acoustic_model_dir = args.acoustics_dir
+ self.logger.debug(f"Acoustic model dir: {args.acoustics_dir}")
+
+ if args.vocoder_dir is not None:
+ self.vocoder_dir = args.vocoder_dir
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # Setup data loader
+ if self.infer_type == "batch":
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.test_dataloader = self._build_test_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building dataset done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # Build model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
+
+ # Init with accelerate
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self.accelerator = accelerate.Accelerator()
+ self.model = self.accelerator.prepare(self.model)
+ if self.infer_type == "batch":
+ self.test_dataloader = self.accelerator.prepare(self.test_dataloader)
+ end = time.monotonic_ns()
+ self.accelerator.wait_for_everyone()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
+
+ with self.accelerator.main_process_first():
+ self.logger.info("Loading checkpoint...")
+ start = time.monotonic_ns()
+ if args.acoustics_dir is not None:
+ self._load_model(
+ checkpoint_dir=os.path.join(args.acoustics_dir, "checkpoint")
+ )
+ elif args.checkpoint_path is not None:
+ self._load_model(checkpoint_path=args.checkpoint_path)
+ else:
+ print("Either checkpoint dir or checkpoint path should be provided.")
+
+ end = time.monotonic_ns()
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
+
+ self.model.eval()
+ self.accelerator.wait_for_everyone()
+
+ def _build_test_dataset(self):
+ pass
+
+ def _build_model(self):
+ pass
+
+ # TODO: LEGACY CODE
+ def _build_test_dataloader(self):
+ datasets, collate = self._build_test_dataset()
+ self.test_dataset = datasets(self.args, self.cfg)
+ self.test_collate = collate(self.cfg)
+ self.test_batch_size = min(
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
+ )
+ test_dataloader = DataLoader(
+ self.test_dataset,
+ collate_fn=self.test_collate,
+ num_workers=1,
+ batch_size=self.test_batch_size,
+ shuffle=False,
+ )
+ return test_dataloader
+
+ def _load_model(
+ self,
+ checkpoint_dir: str = None,
+ checkpoint_path: str = None,
+ old_mode: bool = False,
+ ):
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+
+ if checkpoint_path is None:
+ assert checkpoint_dir is not None
+ # Load the latest accelerator state dicts
+ ls = [
+ str(i) for i in Path(checkpoint_dir).glob("*") if not "audio" in str(i)
+ ]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+
+ if (
+ Path(os.path.join(checkpoint_path, "model.safetensors")).exists()
+ and accelerate.__version__ < "0.25"
+ ):
+ self.model.load_state_dict(
+ load_file(os.path.join(checkpoint_path, "model.safetensors")),
+ strict=False,
+ )
+ else:
+ self.accelerator.load_state(str(checkpoint_path))
+ return str(checkpoint_path)
+
+ def inference(self):
+ if self.infer_type == "single":
+ out_dir = os.path.join(self.args.output_dir, "single")
+ os.makedirs(out_dir, exist_ok=True)
+
+ pred_audio = self.inference_for_single_utterance()
+ save_path = os.path.join(out_dir, "test_pred.wav")
+ save_audio(save_path, pred_audio, self.cfg.preprocess.sample_rate)
+
+ elif self.infer_type == "batch":
+ out_dir = os.path.join(self.args.output_dir, "batch")
+ os.makedirs(out_dir, exist_ok=True)
+
+ pred_audio_list = self.inference_for_batches()
+ for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
+ uid = it["Uid"]
+ save_audio(
+ os.path.join(out_dir, f"{uid}.wav"),
+ wav.numpy(),
+ self.cfg.preprocess.sample_rate,
+ add_silence=True,
+ turn_up=True,
+ )
+ tmp_file = os.path.join(out_dir, f"{uid}.pt")
+ if os.path.exists(tmp_file):
+ os.remove(tmp_file)
+ print("Saved to: ", out_dir)
+
+ @torch.inference_mode()
+ def inference_for_batches(self):
+ y_pred = []
+ for i, batch in tqdm(enumerate(self.test_dataloader)):
+ y_pred, mel_lens, _ = self._inference_each_batch(batch)
+ y_ls = y_pred.chunk(self.test_batch_size)
+ tgt_ls = mel_lens.chunk(self.test_batch_size)
+ j = 0
+ for it, l in zip(y_ls, tgt_ls):
+ l = l.item()
+ it = it.squeeze(0)[:l].detach().cpu()
+
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
+ j += 1
+
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+ res = synthesis(
+ cfg=vocoder_cfg,
+ vocoder_weight_file=vocoder_ckpt,
+ n_samples=None,
+ pred=[
+ torch.load(
+ os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
+ ).numpy()
+ for item in self.test_dataset.metadata
+ ],
+ )
+ for it, wav in zip(self.test_dataset.metadata, res):
+ uid = it["Uid"]
+ save_audio(
+ os.path.join(self.args.output_dir, f"{uid}.wav"),
+ wav.numpy(),
+ 22050,
+ add_silence=True,
+ turn_up=True,
+ )
+
+ @abstractmethod
+ @torch.inference_mode()
+ def _inference_each_batch(self, batch_data):
+ pass
+
+ def inference_for_single_utterance(self, text):
+ pass
+
+ def synthesis_by_vocoder(self, pred):
+ audios_pred = synthesis(
+ self.vocoder_cfg,
+ self.checkpoint_dir_vocoder,
+ len(pred),
+ pred,
+ )
+
+ return audios_pred
+
+ @staticmethod
+ def _parse_vocoder(vocoder_dir):
+ r"""Parse vocoder config"""
+ vocoder_dir = os.path.abspath(vocoder_dir)
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
+ ckpt_path = str(ckpt_list[0])
+ vocoder_cfg = load_config(
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
+ )
+ return vocoder_cfg, ckpt_path
+
+ def _set_random_seed(self, seed):
+ """Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
diff --git a/models/tts/base/tts_trainer.py b/models/tts/base/tts_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ebfca6c40530b81817b702574190ca8a08c0941
--- /dev/null
+++ b/models/tts/base/tts_trainer.py
@@ -0,0 +1,721 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import shutil
+import torch
+import time
+from pathlib import Path
+import torch
+from tqdm import tqdm
+import re
+import logging
+import json5
+import accelerate
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import ConcatDataset, DataLoader
+from accelerate import DistributedDataParallelKwargs
+from schedulers.scheduler import Eden
+from models.base.base_sampler import build_samplers
+from models.base.new_trainer import BaseTrainer
+
+
+class TTSTrainer(BaseTrainer):
+ r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements
+ ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
+ class, and implement ``_build_model``, ``_forward_step``.
+ """
+
+ def __init__(self, args=None, cfg=None):
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # init with accelerate
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ with self.accelerator.main_process_first():
+ self.logger = get_logger(args.exp_name, log_level="INFO")
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # init counts
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check values
+ if self.accelerator.is_main_process:
+ self.__check_basic_configs()
+ # Set runtime configs
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.keep_last = [
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # setup data_loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # save phone table to exp dir. Should be done before building model due to loading phone table in model
+ if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon":
+ self._save_phone_symbols_file_to_exp_path()
+
+ # setup model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
+ )
+
+ # optimizer & scheduler
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # create criterion
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterion = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume or Finetune
+ with self.accelerator.main_process_first():
+ self._check_resume()
+
+ # accelerate prepare
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self._accelerator_prepare()
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # save config file path
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+ self.device = self.accelerator.device
+
+ if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
+ self.speakers = self._build_speaker_lut()
+ self.utt2spk_dict = self._build_utt2spk_dict()
+
+ # Only for TTS tasks
+ self.task_type = "TTS"
+ self.logger.info("Task type: {}".format(self.task_type))
+
+ def _check_resume(self):
+ # if args.resume:
+ if self.args.resume or (
+ self.cfg.model_type == "VALLE" and self.args.train_stage == 2
+ ):
+ checkpoint_dir = self.checkpoint_dir
+ if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ if (
+ self.args.checkpoint_path is None or len(ls) == 0
+ ): # Train stage 2 from scratch using the checkpoint of stage 1
+ assert (
+ self.args.ar_model_ckpt_dir is not None
+ ), "Error: ar_model_ckpt_dir should be set to train nar model."
+ self.args.resume_type = "finetune"
+ checkpoint_dir = self.args.ar_model_ckpt_dir
+ self.logger.info(
+ f"Training NAR model at stage 2 using the checkpoint of AR model at stage 1."
+ )
+
+ self.logger.info(f"Resuming from checkpoint: {checkpoint_dir}")
+ start = time.monotonic_ns()
+ self.ckpt_path = self._load_model(
+ checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
+ )
+ self.logger.info(f"Checkpoint path: {self.ckpt_path}")
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
+ )
+
+ def _init_accelerator(self):
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir,
+ logging_dir=os.path.join(self.exp_dir, "log"),
+ )
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ kwargs_handlers=[kwargs],
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def _accelerator_prepare(self):
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ )
+
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key] = self.accelerator.prepare(self.model[key])
+ else:
+ self.model = self.accelerator.prepare(self.model)
+
+ if isinstance(self.optimizer, dict):
+ for key in self.optimizer.keys():
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
+ else:
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
+ else:
+ self.scheduler = self.accelerator.prepare(self.scheduler)
+
+ ### Following are methods only for TTS tasks ###
+ def _build_dataset(self):
+ pass
+
+ def _build_criterion(self):
+ pass
+
+ def _build_model(self):
+ pass
+
+ def _build_dataloader(self):
+ """Build dataloader which merges a series of datasets."""
+ # Build dataset instance for each dataset and combine them by ConcatDataset
+ Dataset, Collator = self._build_dataset()
+
+ # Build train set
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = ConcatDataset(datasets_list)
+ train_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+
+ # Build test set
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = ConcatDataset(datasets_list)
+ valid_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ return train_loader, valid_loader
+
+ def _build_optimizer(self):
+ pass
+
+ def _build_scheduler(self):
+ pass
+
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If a folder is given, it will
+ load the latest checkpoint in checkpoint_dir. If a path is given
+ it will load the checkpoint specified by checkpoint_path.
+ **Only use this method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None or checkpoint_path == "":
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ self.logger.info("Load model from {}".format(checkpoint_path))
+ print("Load model from {}".format(checkpoint_path))
+ if resume_type == "resume":
+ self.accelerator.load_state(checkpoint_path)
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+ elif resume_type == "finetune":
+ if isinstance(self.model, dict):
+ for idx, sub_model in enumerate(self.model.keys()):
+ if idx == 0:
+ ckpt_name = "pytorch_model.bin"
+ else:
+ ckpt_name = "pytorch_model_{}.bin".format(idx)
+
+ self.model[sub_model].load_state_dict(
+ torch.load(os.path.join(checkpoint_path, ckpt_name))
+ )
+ self.model[sub_model].cuda(self.accelerator.device)
+ else:
+ self.model.load_state_dict(
+ torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
+ )
+ self.model.cuda(self.accelerator.device)
+ self.logger.info("Load model weights for finetune SUCCESS!")
+
+ else:
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
+
+ return checkpoint_path
+
+ ### THIS IS MAIN ENTRY ###
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # dump config file
+ if self.accelerator.is_main_process:
+ self.__dump_cfg(self.config_save_path)
+
+ # self.optimizer.zero_grad()
+ # Wait to ensure good to go
+
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ # Do training & validating epoch
+ train_total_loss, train_losses = self._train_epoch()
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+
+ valid_total_loss, valid_losses = self._valid_epoch()
+ if isinstance(valid_losses, dict):
+ for key, loss in valid_losses.items():
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Valid {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
+ self.accelerator.log(
+ {
+ "Epoch/Train Loss": train_total_loss,
+ "Epoch/Valid Loss": valid_total_loss,
+ },
+ step=self.epoch,
+ )
+
+ self.accelerator.wait_for_everyone()
+
+ # Check if hit save_checkpoint_stride and run_eval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ hit_dix = []
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ hit_dix.append(i)
+ run_eval |= self.run_eval[i]
+
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, train_total_loss
+ ),
+ )
+ self.accelerator.save_state(path)
+
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+
+ # Remove old checkpoints
+ to_remove = []
+ for idx in hit_dix:
+ self.checkpoints_path[idx].append(path)
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+ # Search conflicts
+ total = set()
+ for i in self.checkpoints_path:
+ total |= set(i)
+ do_remove = set()
+ for idx, path in to_remove[::-1]:
+ if path in total:
+ self.checkpoints_path[idx].insert(0, path)
+ else:
+ do_remove.add(path)
+
+ # Remove old checkpoints
+ for path in do_remove:
+ shutil.rmtree(path, ignore_errors=True)
+ self.logger.debug(f"Remove old checkpoint: {path}")
+
+ self.accelerator.wait_for_everyone()
+ if run_eval:
+ # TODO: run evaluation
+ pass
+
+ # Update info for each epoch
+ self.epoch += 1
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ )
+
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+
+ self.accelerator.end_training()
+
+ ### Following are methods that can be used directly in child classes ###
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].train()
+ else:
+ self.model.train()
+
+ epoch_sum_loss: float = 0.0
+ epoch_losses: dict = {}
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ total_loss, train_losses, _ = self._train_step(batch)
+ self.batch_count += 1
+
+ # Update info for each step
+ # TODO: step means BP counts or batch counts?
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key].step()
+ else:
+ if isinstance(self.scheduler, Eden):
+ self.scheduler.step_batch(self.step)
+ else:
+ self.scheduler.step()
+
+ epoch_sum_loss += total_loss
+
+ if isinstance(train_losses, dict):
+ for key, value in train_losses.items():
+ epoch_losses[key] += value
+
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.step,
+ )
+
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ epoch_sum_loss = (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ return epoch_sum_loss, epoch_losses
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].eval()
+ else:
+ self.model.eval()
+
+ epoch_sum_loss = 0.0
+ epoch_losses = dict()
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
+ epoch_sum_loss += total_loss
+ if isinstance(valid_losses, dict):
+ for key, value in valid_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
+ for key in epoch_losses.keys():
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
+
+ self.accelerator.wait_for_everyone()
+
+ return epoch_sum_loss, epoch_losses
+
+ def _train_step(self):
+ pass
+
+ def _valid_step(self, batch):
+ pass
+
+ def _inference(self):
+ pass
+
+ def _is_valid_pattern(self, directory_name):
+ directory_name = str(directory_name)
+ pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
+ return re.match(pattern, directory_name) is not None
+
+ def _check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+
+ def __dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+ def __check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ # TODO: check other values
+
+ @staticmethod
+ def __count_parameters(model):
+ model_param = 0.0
+ if isinstance(model, dict):
+ for key, value in model.items():
+ model_param += sum(p.numel() for p in model[key].parameters())
+ else:
+ model_param = sum(p.numel() for p in model.parameters())
+ return model_param
+
+ def _build_speaker_lut(self):
+ # combine speakers
+ if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
+ speakers = {}
+ else:
+ with open(
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "r"
+ ) as speaker_file:
+ speakers = json.load(speaker_file)
+ for dataset in self.cfg.dataset:
+ speaker_lut_path = os.path.join(
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
+ )
+ with open(speaker_lut_path, "r") as speaker_lut_path:
+ singer_lut = json.load(speaker_lut_path)
+ for singer in singer_lut.keys():
+ if singer not in speakers:
+ speakers[singer] = len(speakers)
+ with open(
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
+ ) as speaker_file:
+ json.dump(speakers, speaker_file, indent=4, ensure_ascii=False)
+ print(
+ "speakers have been dumped to {}".format(
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+ )
+ )
+ return speakers
+
+ def _build_utt2spk_dict(self):
+ # combine speakers
+ utt2spk = {}
+ if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)):
+ utt2spk = {}
+ else:
+ with open(
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "r"
+ ) as utt2spk_file:
+ for line in utt2spk_file.readlines():
+ utt, spk = line.strip().split("\t")
+ utt2spk[utt] = spk
+ for dataset in self.cfg.dataset:
+ utt2spk_dict_path = os.path.join(
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.utt2spk
+ )
+ with open(utt2spk_dict_path, "r") as utt2spk_dict:
+ for line in utt2spk_dict.readlines():
+ utt, spk = line.strip().split("\t")
+ if utt not in utt2spk.keys():
+ utt2spk[utt] = spk
+ with open(
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "w"
+ ) as utt2spk_file:
+ for utt, spk in utt2spk.items():
+ utt2spk_file.write(utt + "\t" + spk + "\n")
+ print(
+ "utterance and speaker mapper have been dumped to {}".format(
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)
+ )
+ )
+ return utt2spk
+
+ def _save_phone_symbols_file_to_exp_path(self):
+ phone_symbols_file = os.path.join(
+ self.cfg.preprocess.processed_dir,
+ self.cfg.dataset[0],
+ self.cfg.preprocess.symbols_dict,
+ )
+ phone_symbols_file_to_exp_path = os.path.join(
+ self.exp_dir, self.cfg.preprocess.symbols_dict
+ )
+ shutil.copy(phone_symbols_file, phone_symbols_file_to_exp_path)
+ os.chmod(phone_symbols_file_to_exp_path, 0o666)
+ print(
+ "phone symbols been dumped to {}".format(
+ os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
+ )
+ )
diff --git a/models/tts/fastspeech2/__init__.py b/models/tts/fastspeech2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tts/fastspeech2/fs2.py b/models/tts/fastspeech2/fs2.py
new file mode 100644
index 0000000000000000000000000000000000000000..61cab7eb598a76d29388b4b3150fa7925a79f5fb
--- /dev/null
+++ b/models/tts/fastspeech2/fs2.py
@@ -0,0 +1,548 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/ming024/FastSpeech2/blob/master/model/fastspeech2.py
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+from modules.transformer.Models import Encoder, Decoder
+from modules.transformer.Layers import PostNet
+from collections import OrderedDict
+
+import os
+import json
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+ device = lengths.device
+ batch_size = lengths.shape[0]
+ if max_len is None:
+ max_len = torch.max(lengths).item()
+
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
+ mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
+
+ return mask
+
+
+def pad(input_ele, mel_max_length=None):
+ if mel_max_length:
+ max_len = mel_max_length
+ else:
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
+
+ out_list = list()
+ for i, batch in enumerate(input_ele):
+ if len(batch.shape) == 1:
+ one_batch_padded = F.pad(
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
+ )
+ elif len(batch.shape) == 2:
+ one_batch_padded = F.pad(
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
+ )
+ out_list.append(one_batch_padded)
+ out_padded = torch.stack(out_list)
+ return out_padded
+
+
+class VarianceAdaptor(nn.Module):
+ """Variance Adaptor"""
+
+ def __init__(self, cfg):
+ super(VarianceAdaptor, self).__init__()
+ self.duration_predictor = VariancePredictor(cfg)
+ self.length_regulator = LengthRegulator()
+ self.pitch_predictor = VariancePredictor(cfg)
+ self.energy_predictor = VariancePredictor(cfg)
+
+ # assign the pitch/energy feature level
+ if cfg.preprocess.use_frame_pitch:
+ self.pitch_feature_level = "frame_level"
+ self.pitch_dir = cfg.preprocess.pitch_dir
+ else:
+ self.pitch_feature_level = "phoneme_level"
+ self.pitch_dir = cfg.preprocess.phone_pitch_dir
+
+ if cfg.preprocess.use_frame_energy:
+ self.energy_feature_level = "frame_level"
+ self.energy_dir = cfg.preprocess.energy_dir
+ else:
+ self.energy_feature_level = "phoneme_level"
+ self.energy_dir = cfg.preprocess.phone_energy_dir
+
+ assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
+ assert self.energy_feature_level in ["phoneme_level", "frame_level"]
+
+ pitch_quantization = cfg.model.variance_embedding.pitch_quantization
+ energy_quantization = cfg.model.variance_embedding.energy_quantization
+ n_bins = cfg.model.variance_embedding.n_bins
+ assert pitch_quantization in ["linear", "log"]
+ assert energy_quantization in ["linear", "log"]
+
+ with open(
+ os.path.join(
+ cfg.preprocess.processed_dir,
+ cfg.dataset[0],
+ self.energy_dir,
+ "statistics.json",
+ )
+ ) as f:
+ stats = json.load(f)
+ stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
+ mean, std = (
+ stats["voiced_positions"]["mean"],
+ stats["voiced_positions"]["std"],
+ )
+ energy_min = (stats["total_positions"]["min"] - mean) / std
+ energy_max = (stats["total_positions"]["max"] - mean) / std
+
+ with open(
+ os.path.join(
+ cfg.preprocess.processed_dir,
+ cfg.dataset[0],
+ self.pitch_dir,
+ "statistics.json",
+ )
+ ) as f:
+ stats = json.load(f)
+ stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
+ mean, std = (
+ stats["voiced_positions"]["mean"],
+ stats["voiced_positions"]["std"],
+ )
+ pitch_min = (stats["total_positions"]["min"] - mean) / std
+ pitch_max = (stats["total_positions"]["max"] - mean) / std
+
+ if pitch_quantization == "log":
+ self.pitch_bins = nn.Parameter(
+ torch.exp(
+ torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
+ ),
+ requires_grad=False,
+ )
+ else:
+ self.pitch_bins = nn.Parameter(
+ torch.linspace(pitch_min, pitch_max, n_bins - 1),
+ requires_grad=False,
+ )
+ if energy_quantization == "log":
+ self.energy_bins = nn.Parameter(
+ torch.exp(
+ torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
+ ),
+ requires_grad=False,
+ )
+ else:
+ self.energy_bins = nn.Parameter(
+ torch.linspace(energy_min, energy_max, n_bins - 1),
+ requires_grad=False,
+ )
+
+ self.pitch_embedding = nn.Embedding(
+ n_bins, cfg.model.transformer.encoder_hidden
+ )
+ self.energy_embedding = nn.Embedding(
+ n_bins, cfg.model.transformer.encoder_hidden
+ )
+
+ def get_pitch_embedding(self, x, target, mask, control):
+ prediction = self.pitch_predictor(x, mask)
+ if target is not None:
+ embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
+ else:
+ prediction = prediction * control
+ embedding = self.pitch_embedding(
+ torch.bucketize(prediction, self.pitch_bins)
+ )
+ return prediction, embedding
+
+ def get_energy_embedding(self, x, target, mask, control):
+ prediction = self.energy_predictor(x, mask)
+ if target is not None:
+ embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
+ else:
+ prediction = prediction * control
+ embedding = self.energy_embedding(
+ torch.bucketize(prediction, self.energy_bins)
+ )
+ return prediction, embedding
+
+ def forward(
+ self,
+ x,
+ src_mask,
+ mel_mask=None,
+ max_len=None,
+ pitch_target=None,
+ energy_target=None,
+ duration_target=None,
+ p_control=1.0,
+ e_control=1.0,
+ d_control=1.0,
+ ):
+ log_duration_prediction = self.duration_predictor(x, src_mask)
+ if self.pitch_feature_level == "phoneme_level":
+ pitch_prediction, pitch_embedding = self.get_pitch_embedding(
+ x, pitch_target, src_mask, p_control
+ )
+ x = x + pitch_embedding
+ if self.energy_feature_level == "phoneme_level":
+ energy_prediction, energy_embedding = self.get_energy_embedding(
+ x, energy_target, src_mask, e_control
+ )
+ x = x + energy_embedding
+
+ if duration_target is not None:
+ x, mel_len = self.length_regulator(x, duration_target, max_len)
+ duration_rounded = duration_target
+ else:
+ duration_rounded = torch.clamp(
+ (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
+ min=0,
+ )
+ x, mel_len = self.length_regulator(x, duration_rounded, max_len)
+ mel_mask = get_mask_from_lengths(mel_len)
+
+ if self.pitch_feature_level == "frame_level":
+ pitch_prediction, pitch_embedding = self.get_pitch_embedding(
+ x, pitch_target, mel_mask, p_control
+ )
+ x = x + pitch_embedding
+ if self.energy_feature_level == "frame_level":
+ energy_prediction, energy_embedding = self.get_energy_embedding(
+ x, energy_target, mel_mask, e_control
+ )
+ x = x + energy_embedding
+
+ return (
+ x,
+ pitch_prediction,
+ energy_prediction,
+ log_duration_prediction,
+ duration_rounded,
+ mel_len,
+ mel_mask,
+ )
+
+
+class LengthRegulator(nn.Module):
+ """Length Regulator"""
+
+ def __init__(self):
+ super(LengthRegulator, self).__init__()
+
+ def LR(self, x, duration, max_len):
+ device = x.device
+ output = list()
+ mel_len = list()
+ for batch, expand_target in zip(x, duration):
+ expanded = self.expand(batch, expand_target)
+ output.append(expanded)
+ mel_len.append(expanded.shape[0])
+
+ if max_len is not None:
+ output = pad(output, max_len)
+ else:
+ output = pad(output)
+
+ return output, torch.LongTensor(mel_len).to(device)
+
+ def expand(self, batch, predicted):
+ out = list()
+
+ for i, vec in enumerate(batch):
+ expand_size = predicted[i].item()
+ out.append(vec.expand(max(int(expand_size), 0), -1))
+ out = torch.cat(out, 0)
+
+ return out
+
+ def forward(self, x, duration, max_len):
+ output, mel_len = self.LR(x, duration, max_len)
+ return output, mel_len
+
+
+class VariancePredictor(nn.Module):
+ """Duration, Pitch and Energy Predictor"""
+
+ def __init__(self, cfg):
+ super(VariancePredictor, self).__init__()
+
+ self.input_size = cfg.model.transformer.encoder_hidden
+ self.filter_size = cfg.model.variance_predictor.filter_size
+ self.kernel = cfg.model.variance_predictor.kernel_size
+ self.conv_output_size = cfg.model.variance_predictor.filter_size
+ self.dropout = cfg.model.variance_predictor.dropout
+
+ self.conv_layer = nn.Sequential(
+ OrderedDict(
+ [
+ (
+ "conv1d_1",
+ Conv(
+ self.input_size,
+ self.filter_size,
+ kernel_size=self.kernel,
+ padding=(self.kernel - 1) // 2,
+ ),
+ ),
+ ("relu_1", nn.ReLU()),
+ ("layer_norm_1", nn.LayerNorm(self.filter_size)),
+ ("dropout_1", nn.Dropout(self.dropout)),
+ (
+ "conv1d_2",
+ Conv(
+ self.filter_size,
+ self.filter_size,
+ kernel_size=self.kernel,
+ padding=1,
+ ),
+ ),
+ ("relu_2", nn.ReLU()),
+ ("layer_norm_2", nn.LayerNorm(self.filter_size)),
+ ("dropout_2", nn.Dropout(self.dropout)),
+ ]
+ )
+ )
+
+ self.linear_layer = nn.Linear(self.conv_output_size, 1)
+
+ def forward(self, encoder_output, mask):
+ out = self.conv_layer(encoder_output)
+ out = self.linear_layer(out)
+ out = out.squeeze(-1)
+
+ if mask is not None:
+ out = out.masked_fill(mask, 0.0)
+
+ return out
+
+
+class Conv(nn.Module):
+ """
+ Convolution Module
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ w_init="linear",
+ ):
+ """
+ :param in_channels: dimension of input
+ :param out_channels: dimension of output
+ :param kernel_size: size of kernel
+ :param stride: size of stride
+ :param padding: size of padding
+ :param dilation: dilation rate
+ :param bias: boolean. if True, bias is included.
+ :param w_init: str. weight inits with xavier initialization.
+ """
+ super(Conv, self).__init__()
+
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ def forward(self, x):
+ x = x.contiguous().transpose(1, 2)
+ x = self.conv(x)
+ x = x.contiguous().transpose(1, 2)
+
+ return x
+
+
+class FastSpeech2(nn.Module):
+ def __init__(self, cfg) -> None:
+ super(FastSpeech2, self).__init__()
+ self.cfg = cfg
+ self.encoder = Encoder(cfg.model)
+ self.variance_adaptor = VarianceAdaptor(cfg)
+ self.decoder = Decoder(cfg.model)
+ self.mel_linear = nn.Linear(
+ cfg.model.transformer.decoder_hidden,
+ cfg.preprocess.n_mel,
+ )
+ self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
+
+ self.speaker_emb = None
+ if cfg.train.multi_speaker_training:
+ with open(
+ os.path.join(
+ cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
+ ),
+ "r",
+ ) as f:
+ n_speaker = len(json.load(f))
+ self.speaker_emb = nn.Embedding(
+ n_speaker,
+ cfg.model.transformer.encoder_hidden,
+ )
+
+ def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
+ speakers = data["spk_id"]
+ texts = data["texts"]
+ src_lens = data["text_len"]
+ max_src_len = max(src_lens)
+ mel_lens = data["target_len"] if "target_len" in data else None
+ max_mel_len = max(mel_lens) if "target_len" in data else None
+ p_targets = data["pitch"] if "pitch" in data else None
+ e_targets = data["energy"] if "energy" in data else None
+ d_targets = data["durations"] if "durations" in data else None
+ src_masks = get_mask_from_lengths(src_lens, max_src_len)
+ mel_masks = (
+ get_mask_from_lengths(mel_lens, max_mel_len)
+ if mel_lens is not None
+ else None
+ )
+
+ output = self.encoder(texts, src_masks)
+
+ if self.speaker_emb is not None:
+ output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
+ -1, max_src_len, -1
+ )
+
+ (
+ output,
+ p_predictions,
+ e_predictions,
+ log_d_predictions,
+ d_rounded,
+ mel_lens,
+ mel_masks,
+ ) = self.variance_adaptor(
+ output,
+ src_masks,
+ mel_masks,
+ max_mel_len,
+ p_targets,
+ e_targets,
+ d_targets,
+ p_control,
+ e_control,
+ d_control,
+ )
+
+ output, mel_masks = self.decoder(output, mel_masks)
+ output = self.mel_linear(output)
+
+ postnet_output = self.postnet(output) + output
+
+ return {
+ "output": output,
+ "postnet_output": postnet_output,
+ "p_predictions": p_predictions,
+ "e_predictions": e_predictions,
+ "log_d_predictions": log_d_predictions,
+ "d_rounded": d_rounded,
+ "src_masks": src_masks,
+ "mel_masks": mel_masks,
+ "src_lens": src_lens,
+ "mel_lens": mel_lens,
+ }
+
+
+class FastSpeech2Loss(nn.Module):
+ """FastSpeech2 Loss"""
+
+ def __init__(self, cfg):
+ super(FastSpeech2Loss, self).__init__()
+ if cfg.preprocess.use_frame_pitch:
+ self.pitch_feature_level = "frame_level"
+ else:
+ self.pitch_feature_level = "phoneme_level"
+
+ if cfg.preprocess.use_frame_energy:
+ self.energy_feature_level = "frame_level"
+ else:
+ self.energy_feature_level = "phoneme_level"
+
+ self.mse_loss = nn.MSELoss()
+ self.mae_loss = nn.L1Loss()
+
+ def forward(self, data, predictions):
+ mel_targets = data["mel"]
+ pitch_targets = data["pitch"].float()
+ energy_targets = data["energy"].float()
+ duration_targets = data["durations"]
+
+ mel_predictions = predictions["output"]
+ postnet_mel_predictions = predictions["postnet_output"]
+ pitch_predictions = predictions["p_predictions"]
+ energy_predictions = predictions["e_predictions"]
+ log_duration_predictions = predictions["log_d_predictions"]
+ src_masks = predictions["src_masks"]
+ mel_masks = predictions["mel_masks"]
+
+ src_masks = ~src_masks
+ mel_masks = ~mel_masks
+
+ log_duration_targets = torch.log(duration_targets.float() + 1)
+ mel_targets = mel_targets[:, : mel_masks.shape[1], :]
+ mel_masks = mel_masks[:, : mel_masks.shape[1]]
+
+ log_duration_targets.requires_grad = False
+ pitch_targets.requires_grad = False
+ energy_targets.requires_grad = False
+ mel_targets.requires_grad = False
+
+ if self.pitch_feature_level == "phoneme_level":
+ pitch_predictions = pitch_predictions.masked_select(src_masks)
+ pitch_targets = pitch_targets.masked_select(src_masks)
+ elif self.pitch_feature_level == "frame_level":
+ pitch_predictions = pitch_predictions.masked_select(mel_masks)
+ pitch_targets = pitch_targets.masked_select(mel_masks)
+
+ if self.energy_feature_level == "phoneme_level":
+ energy_predictions = energy_predictions.masked_select(src_masks)
+ energy_targets = energy_targets.masked_select(src_masks)
+ if self.energy_feature_level == "frame_level":
+ energy_predictions = energy_predictions.masked_select(mel_masks)
+ energy_targets = energy_targets.masked_select(mel_masks)
+
+ log_duration_predictions = log_duration_predictions.masked_select(src_masks)
+ log_duration_targets = log_duration_targets.masked_select(src_masks)
+
+ mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
+ postnet_mel_predictions = postnet_mel_predictions.masked_select(
+ mel_masks.unsqueeze(-1)
+ )
+ mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
+
+ mel_loss = self.mae_loss(mel_predictions, mel_targets)
+ postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
+
+ pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
+ energy_loss = self.mse_loss(energy_predictions, energy_targets)
+ duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
+
+ total_loss = (
+ mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
+ )
+
+ return {
+ "loss": total_loss,
+ "mel_loss": mel_loss,
+ "postnet_mel_loss": postnet_mel_loss,
+ "pitch_loss": pitch_loss,
+ "energy_loss": energy_loss,
+ "duration_loss": duration_loss,
+ }
diff --git a/models/tts/fastspeech2/fs2_dataset.py b/models/tts/fastspeech2/fs2_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dae5b0458ae0c709d417609f7d579e3b67d1ea08
--- /dev/null
+++ b/models/tts/fastspeech2/fs2_dataset.py
@@ -0,0 +1,424 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.base.base_dataset import (
+ BaseOfflineCollator,
+ BaseOfflineDataset,
+ BaseTestDataset,
+ BaseTestCollator,
+)
+from text import text_to_sequence
+
+
+class FS2Dataset(BaseOfflineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
+ self.batch_size = cfg.train.batch_size
+ cfg = cfg.preprocess
+ # utt2duration
+ self.utt2duration_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2duration_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.duration_dir,
+ uid + ".npy",
+ )
+ self.utt2dur = self.read_duration()
+
+ if cfg.use_frame_energy:
+ self.frame_utt2energy, self.energy_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.energy_dir,
+ use_log_scale=cfg.use_log_scale_energy,
+ utt2spk=self.preprocess.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+ elif cfg.use_phone_energy:
+ self.phone_utt2energy, self.energy_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.phone_energy_dir,
+ use_log_scale=cfg.use_log_scale_energy,
+ utt2spk=self.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+
+ if cfg.use_frame_pitch:
+ self.frame_utt2pitch, self.pitch_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.pitch_dir,
+ use_log_scale=cfg.energy_extract_mode,
+ utt2spk=self.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+
+ elif cfg.use_phone_pitch:
+ self.phone_utt2pitch, self.pitch_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.phone_pitch_dir,
+ use_log_scale=cfg.use_log_scale_pitch,
+ utt2spk=self.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+
+ # utt2lab
+ self.utt2lab_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2lab_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.lab_dir,
+ uid + ".txt",
+ )
+
+ self.speaker_map = {}
+ if os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json")):
+ with open(
+ os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json"))
+ ) as f:
+ self.speaker_map = json.load(f)
+
+ self.metadata = self.check_metadata()
+
+ def __getitem__(self, index):
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ duration = self.utt2dur[utt]
+
+ # text
+ f = open(self.utt2lab_path[utt], "r")
+ phones = f.readlines()[0].strip()
+ f.close()
+ # todo: add cleaner(chenxi)
+ phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
+ text_len = len(phones_ids)
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch = self.frame_utt2pitch[utt]
+ elif self.cfg.preprocess.use_phone_pitch:
+ pitch = self.phone_utt2pitch[utt]
+
+ if self.cfg.preprocess.use_frame_energy:
+ energy = self.frame_utt2energy[utt]
+ elif self.cfg.preprocess.use_phone_energy:
+ energy = self.phone_utt2energy[utt]
+
+ # speaker
+ if len(self.speaker_map) > 0:
+ speaker_id = self.speaker_map[utt_info["Singer"]]
+ else:
+ speaker_id = 0
+
+ single_feature.update(
+ {
+ "durations": duration,
+ "texts": phones_ids,
+ "spk_id": speaker_id,
+ "text_len": text_len,
+ "pitch": pitch,
+ "energy": energy,
+ "uid": uid,
+ }
+ )
+ return self.clip_if_too_long(single_feature)
+
+ def read_duration(self):
+ # read duration
+ utt2dur = {}
+ for index in range(len(self.metadata)):
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
+ self.utt2duration_path[utt]
+ ):
+ continue
+
+ mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
+ duration = np.load(self.utt2duration_path[utt])
+ assert mel.shape[0] == sum(
+ duration
+ ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
+ utt2dur[utt] = duration
+ return utt2dur
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
+ """
+ ending_ts: to avoid invalid whisper features for over 30s audios
+ 2812 = 30 * 24000 // 256
+ """
+ ts = max(feature_seq_len - max_seq_len, 0)
+ ts = min(ts, ending_ts - max_seq_len)
+
+ start = random.randint(0, ts)
+ end = start + max_seq_len
+ return start, end
+
+ def clip_if_too_long(self, sample, max_seq_len=1000):
+ """
+ sample :
+ {
+ 'spk_id': (1,),
+ 'target_len': int
+ 'mel': (seq_len, dim),
+ 'frame_pitch': (seq_len,)
+ 'frame_energy': (seq_len,)
+ 'content_vector_feat': (seq_len, dim)
+ }
+ """
+ if sample["target_len"] <= max_seq_len:
+ return sample
+
+ start, end = self.random_select(sample["target_len"], max_seq_len)
+ sample["target_len"] = end - start
+
+ for k in sample.keys():
+ if k not in ["spk_id", "target_len"]:
+ sample[k] = sample[k][start:end]
+
+ return sample
+
+ def check_metadata(self):
+ new_metadata = []
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
+ self.utt2mel_path[utt]
+ ):
+ continue
+ else:
+ new_metadata.append(utt_info)
+ return new_metadata
+
+
+class FS2Collator(BaseOfflineCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ BaseOfflineCollator.__init__(self, cfg)
+ self.sort = cfg.train.sort_sample
+ self.batch_size = cfg.train.batch_size
+ self.drop_last = cfg.train.drop_last
+
+ def __call__(self, batch):
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "text_len":
+ packed_batch_features["text_len"] = torch.LongTensor(
+ [b["text_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["text_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "spk_id":
+ packed_batch_features["spk_id"] = torch.LongTensor(
+ [b["spk_id"] for b in batch]
+ )
+ elif key == "uid":
+ packed_batch_features[key] = [b["uid"] for b in batch]
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ return packed_batch_features
+
+
+class FS2TestDataset(BaseTestDataset):
+ def __init__(self, args, cfg, infer_type=None):
+ datasets = cfg.dataset
+ cfg = cfg.preprocess
+ is_bigdata = False
+
+ assert len(datasets) >= 1
+ if len(datasets) > 1:
+ datasets.sort()
+ bigdata_version = "_".join(datasets)
+ processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
+ is_bigdata = True
+ else:
+ processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)
+
+ if args.test_list_file:
+ self.metafile_path = args.test_list_file
+ self.metadata = self.get_metadata()
+ else:
+ assert args.testing_set
+ source_metafile_path = os.path.join(
+ cfg.processed_dir,
+ args.dataset,
+ "{}.json".format(args.testing_set),
+ )
+ with open(source_metafile_path, "r") as f:
+ self.metadata = json.load(f)
+
+ self.cfg = cfg
+ self.datasets = datasets
+ self.data_root = processed_data_dir
+ self.is_bigdata = is_bigdata
+ self.source_dataset = args.dataset
+
+ ######### Load source acoustic features #########
+ if cfg.use_spkid:
+ spk2id_path = os.path.join(self.data_root, cfg.spk2id)
+ utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
+ self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)
+
+ # utt2lab
+ self.utt2lab_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2lab_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.lab_dir,
+ uid + ".txt",
+ )
+
+ self.speaker_map = {}
+ if os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json")):
+ with open(
+ os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json"))
+ ) as f:
+ self.speaker_map = json.load(f)
+
+ def __getitem__(self, index):
+ single_feature = {}
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ # text
+ f = open(self.utt2lab_path[utt], "r")
+ phones = f.readlines()[0].strip()
+ f.close()
+
+ phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
+ text_len = len(phones_ids)
+
+ # speaker
+ if len(self.speaker_map) > 0:
+ speaker_id = self.speaker_map[utt_info["Singer"]]
+ else:
+ speaker_id = 0
+
+ single_feature.update(
+ {
+ "texts": phones_ids,
+ "spk_id": speaker_id,
+ "text_len": text_len,
+ }
+ )
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ return metadata
+
+
+class FS2TestCollator(BaseTestCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "text_len":
+ packed_batch_features["text_len"] = torch.LongTensor(
+ [b["text_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["text_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "spk_id":
+ packed_batch_features["spk_id"] = torch.LongTensor(
+ [b["spk_id"] for b in batch]
+ )
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/tts/fastspeech2/fs2_inference.py b/models/tts/fastspeech2/fs2_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a03afd0866e4ee5b3e257eb376a22558a9a055d
--- /dev/null
+++ b/models/tts/fastspeech2/fs2_inference.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+from tqdm import tqdm
+from collections import OrderedDict
+
+from models.tts.base.tts_inferece import TTSInference
+from models.tts.fastspeech2.fs2_dataset import FS2TestDataset, FS2TestCollator
+from utils.util import load_config
+from utils.io import save_audio
+from models.tts.fastspeech2.fs2 import FastSpeech2
+from models.vocoders.vocoder_inference import synthesis
+from pathlib import Path
+from processors.phone_extractor import phoneExtractor
+from text.text_token_collation import phoneIDCollation
+import numpy as np
+import json
+
+
+class FastSpeech2Inference(TTSInference):
+ def __init__(self, args, cfg):
+ TTSInference.__init__(self, args, cfg)
+ self.args = args
+ self.cfg = cfg
+ self.infer_type = args.mode
+
+ def _build_model(self):
+ self.model = FastSpeech2(self.cfg)
+ return self.model
+
+ def load_model(self, state_dict):
+ raw_dict = state_dict["model"]
+ clean_dict = OrderedDict()
+ for k, v in raw_dict.items():
+ if k.startswith("module."):
+ clean_dict[k[7:]] = v
+ else:
+ clean_dict[k] = v
+
+ self.model.load_state_dict(clean_dict)
+
+ def _build_test_dataset(self):
+ return FS2TestDataset, FS2TestCollator
+
+ @staticmethod
+ def _parse_vocoder(vocoder_dir):
+ r"""Parse vocoder config"""
+ vocoder_dir = os.path.abspath(vocoder_dir)
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+ # last step (different from the base *int(x.stem)*)
+ ckpt_list.sort(
+ key=lambda x: int(x.stem.split("_")[-2].split("-")[-1]), reverse=True
+ )
+ ckpt_path = str(ckpt_list[0])
+ vocoder_cfg = load_config(
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
+ )
+ return vocoder_cfg, ckpt_path
+
+ @torch.inference_mode()
+ def inference_for_batches(self):
+ y_pred = []
+ for i, batch in tqdm(enumerate(self.test_dataloader)):
+ y_pred, mel_lens, _ = self._inference_each_batch(batch)
+ y_ls = y_pred.chunk(self.test_batch_size)
+ tgt_ls = mel_lens.chunk(self.test_batch_size)
+ j = 0
+ for it, l in zip(y_ls, tgt_ls):
+ l = l.item()
+ it = it.squeeze(0)[:l].detach().cpu()
+
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
+ j += 1
+
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+ res = synthesis(
+ cfg=vocoder_cfg,
+ vocoder_weight_file=vocoder_ckpt,
+ n_samples=None,
+ pred=[
+ torch.load(
+ os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
+ ).numpy()
+ for item in self.test_dataset.metadata
+ ],
+ )
+ for it, wav in zip(self.test_dataset.metadata, res):
+ uid = it["Uid"]
+ save_audio(
+ os.path.join(self.args.output_dir, f"{uid}.wav"),
+ wav.numpy(),
+ self.cfg.preprocess.sample_rate,
+ add_silence=True,
+ turn_up=True,
+ )
+ os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
+
+ @torch.inference_mode()
+ def _inference_each_batch(self, batch_data):
+ device = self.accelerator.device
+ control_values = (
+ self.args.pitch_control,
+ self.args.energy_control,
+ self.args.duration_control,
+ )
+ for k, v in batch_data.items():
+ batch_data[k] = v.to(device)
+
+ pitch_control, energy_control, duration_control = control_values
+
+ output = self.model(
+ batch_data,
+ p_control=pitch_control,
+ e_control=energy_control,
+ d_control=duration_control,
+ )
+ pred_res = output["postnet_output"]
+ mel_lens = output["mel_lens"].cpu()
+ return pred_res, mel_lens, 0
+
+ def inference_for_single_utterance(self):
+ text = self.args.text
+ control_values = (
+ self.args.pitch_control,
+ self.args.energy_control,
+ self.args.duration_control,
+ )
+ pitch_control, energy_control, duration_control = control_values
+
+ # get phone symbol file
+ phone_symbol_file = None
+ if self.cfg.preprocess.phone_extractor != "lexicon":
+ phone_symbol_file = os.path.join(
+ self.exp_dir, self.cfg.preprocess.symbols_dict
+ )
+ assert os.path.exists(phone_symbol_file)
+ # convert text to phone sequence
+ phone_extractor = phoneExtractor(self.cfg)
+
+ phone_seq = phone_extractor.extract_phone(text) # phone_seq: list
+ # convert phone sequence to phone id sequence
+ phon_id_collator = phoneIDCollation(
+ self.cfg, symbols_dict_file=phone_symbol_file
+ )
+ phone_seq = ["{"] + phone_seq + ["}"]
+ phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
+
+ # convert phone sequence to phone id sequence
+ phone_id_seq = np.array(phone_id_seq)
+ phone_id_seq = torch.from_numpy(phone_id_seq)
+
+ # get speaker id if multi-speaker training and use speaker id
+ speaker_id = None
+ if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
+ spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+ with open(spk2id_file, "r") as f:
+ spk2id = json.load(f)
+ speaker_id = spk2id[self.args.speaker_name]
+ speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32))
+ else:
+ speaker_id = torch.Tensor(0).view(-1)
+
+ with torch.no_grad():
+ x_tst = phone_id_seq.to(self.device).unsqueeze(0)
+ x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
+ if speaker_id is not None:
+ speaker_id = speaker_id.to(self.device)
+
+ data = {}
+ data["texts"] = x_tst
+ data["text_len"] = x_tst_lengths
+ data["spk_id"] = speaker_id
+
+ output = self.model(
+ data,
+ p_control=pitch_control,
+ e_control=energy_control,
+ d_control=duration_control,
+ )
+ pred_res = output["postnet_output"]
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+ audio = synthesis(
+ cfg=vocoder_cfg,
+ vocoder_weight_file=vocoder_ckpt,
+ n_samples=None,
+ pred=pred_res,
+ )
+ return audio[0]
diff --git a/models/tts/fastspeech2/fs2_trainer.py b/models/tts/fastspeech2/fs2_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0bc1201b192e30339c8f6e877467cc6b1f7a938
--- /dev/null
+++ b/models/tts/fastspeech2/fs2_trainer.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from models.tts.base import TTSTrainer
+from models.tts.fastspeech2.fs2 import FastSpeech2, FastSpeech2Loss
+from models.tts.fastspeech2.fs2_dataset import FS2Dataset, FS2Collator
+from optimizer.optimizers import NoamLR
+
+
+class FastSpeech2Trainer(TTSTrainer):
+ def __init__(self, args, cfg):
+ TTSTrainer.__init__(self, args, cfg)
+ self.cfg = cfg
+
+ def _build_dataset(self):
+ return FS2Dataset, FS2Collator
+
+ def __build_scheduler(self):
+ return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
+
+ def _write_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar("train/" + key, value, self.step)
+ lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
+ self.sw.add_scalar("learning_rate", lr, self.step)
+
+ def _write_valid_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar("val/" + key, value, self.step)
+
+ def _build_criterion(self):
+ return FastSpeech2Loss(self.cfg)
+
+ def get_state_dict(self):
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "scheduler": self.scheduler.state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def _build_optimizer(self):
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
+ return optimizer
+
+ def _build_scheduler(self):
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
+ return scheduler
+
+ def _build_model(self):
+ self.model = FastSpeech2(self.cfg)
+ return self.model
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.train()
+ epoch_sum_loss: float = 0.0
+ epoch_step: int = 0
+ epoch_losses: dict = {}
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ loss, train_losses = self._train_step(batch)
+ self.accelerator.backward(loss)
+ grad_clip_thresh = self.cfg.train.grad_clip_thresh
+ nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip_thresh)
+ self.optimizer.step()
+ self.scheduler.step()
+ self.optimizer.zero_grad()
+ self.batch_count += 1
+
+ # Update info for each step
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += loss
+ for key, value in train_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ self.accelerator.log(
+ {
+ "Step/Train Loss": loss,
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
+ },
+ step=self.step,
+ )
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ epoch_sum_loss = (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+ return epoch_sum_loss, epoch_losses
+
+ def _train_step(self, data):
+ train_losses = {}
+ total_loss = 0
+ train_stats = {}
+
+ preds = self.model(data)
+
+ train_losses = self.criterion(data, preds)
+
+ total_loss = train_losses["loss"]
+ for key, value in train_losses.items():
+ train_losses[key] = value.item()
+
+ return total_loss, train_losses
+
+ @torch.no_grad()
+ def _valid_step(self, data):
+ valid_loss = {}
+ total_valid_loss = 0
+ valid_stats = {}
+
+ preds = self.model(data)
+
+ valid_losses = self.criterion(data, preds)
+
+ total_valid_loss = valid_losses["loss"]
+ for key, value in valid_losses.items():
+ valid_losses[key] = value.item()
+
+ return total_valid_loss, valid_losses, valid_stats
diff --git a/models/tts/jets/__init__.py b/models/tts/jets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tts/jets/alignments.py b/models/tts/jets/alignments.py
new file mode 100644
index 0000000000000000000000000000000000000000..a96a5000ca9d995a254049b1e54667ba0b67252b
--- /dev/null
+++ b/models/tts/jets/alignments.py
@@ -0,0 +1,487 @@
+# Copyright (c) 2024 Amphion.
+#
+# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/alignments.py
+# Licensed under Apache License 2.0
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Tuple
+from numba import jit
+from scipy.stats import betabinom
+
+
+class AlignmentModule(nn.Module):
+ """Alignment Learning Framework proposed for parallel TTS models in:
+
+ https://arxiv.org/abs/2108.10447
+
+ """
+
+ def __init__(self, adim, odim, cache_prior=True):
+ """Initialize AlignmentModule.
+
+ Args:
+ adim (int): Dimension of attention.
+ odim (int): Dimension of feats.
+ cache_prior (bool): Whether to cache beta-binomial prior.
+
+ """
+ super().__init__()
+ self.cache_prior = cache_prior
+ self._cache = {}
+
+ self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
+ self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)
+
+ self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1)
+ self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
+ self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)
+
+ def forward(self, text, feats, text_lengths, feats_lengths, x_masks=None):
+ """Calculate alignment loss.
+
+ Args:
+ text (Tensor): Batched text embedding (B, T_text, adim).
+ feats (Tensor): Batched acoustic feature (B, T_feats, odim).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ x_masks (Tensor): Mask tensor (B, T_text).
+
+ Returns:
+ Tensor: Log probability of attention matrix (B, T_feats, T_text).
+
+ """
+ text = text.transpose(1, 2)
+ text = F.relu(self.t_conv1(text))
+ text = self.t_conv2(text)
+ text = text.transpose(1, 2)
+
+ feats = feats.transpose(1, 2)
+ feats = F.relu(self.f_conv1(feats))
+ feats = F.relu(self.f_conv2(feats))
+ feats = self.f_conv3(feats)
+ feats = feats.transpose(1, 2)
+
+ dist = feats.unsqueeze(2) - text.unsqueeze(1)
+ dist = torch.norm(dist, p=2, dim=3)
+ score = -dist
+
+ if x_masks is not None:
+ x_masks = x_masks.unsqueeze(-2)
+ score = score.masked_fill(x_masks, -np.inf)
+
+ log_p_attn = F.log_softmax(score, dim=-1)
+
+ # add beta-binomial prior
+ bb_prior = self._generate_prior(
+ text_lengths,
+ feats_lengths,
+ ).to(dtype=log_p_attn.dtype, device=log_p_attn.device)
+ log_p_attn = log_p_attn + bb_prior
+
+ return log_p_attn
+
+ def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor:
+ """Generate alignment prior formulated as beta-binomial distribution
+
+ Args:
+ text_lengths (Tensor): Batch of the lengths of each input (B,).
+ feats_lengths (Tensor): Batch of the lengths of each target (B,).
+ w (float): Scaling factor; lower -> wider the width.
+
+ Returns:
+ Tensor: Batched 2d static prior matrix (B, T_feats, T_text).
+
+ """
+ B = len(text_lengths)
+ T_text = text_lengths.max()
+ T_feats = feats_lengths.max()
+
+ bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf)
+ for bidx in range(B):
+ T = feats_lengths[bidx].item()
+ N = text_lengths[bidx].item()
+
+ key = str(T) + "," + str(N)
+ if self.cache_prior and key in self._cache:
+ prob = self._cache[key]
+ else:
+ alpha = w * np.arange(1, T + 1, dtype=float) # (T,)
+ beta = w * np.array([T - t + 1 for t in alpha])
+ k = np.arange(N)
+ batched_k = k[..., None] # (N,1)
+ prob = betabinom.logpmf(batched_k, N, alpha, beta) # (N,T)
+
+ # store cache
+ if self.cache_prior and key not in self._cache:
+ self._cache[key] = prob
+
+ prob = torch.from_numpy(prob).transpose(0, 1) # -> (T,N)
+ bb_prior[bidx, :T, :N] = prob
+
+ return bb_prior
+
+
+@jit(nopython=True)
+def _monotonic_alignment_search(log_p_attn):
+ # https://arxiv.org/abs/2005.11129
+ T_mel = log_p_attn.shape[0]
+ T_inp = log_p_attn.shape[1]
+ Q = np.full((T_inp, T_mel), fill_value=-np.inf)
+
+ log_prob = log_p_attn.transpose(1, 0) # -> (T_inp,T_mel)
+ # 1. Q <- init first row for all j
+ for j in range(T_mel):
+ Q[0, j] = log_prob[0, : j + 1].sum()
+
+ # 2.
+ for j in range(1, T_mel):
+ for i in range(1, min(j + 1, T_inp)):
+ Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j]
+
+ # 3.
+ A = np.full((T_mel,), fill_value=T_inp - 1)
+ for j in range(T_mel - 2, -1, -1): # T_mel-2, ..., 0
+ # 'i' in {A[j+1]-1, A[j+1]}
+ i_a = A[j + 1] - 1
+ i_b = A[j + 1]
+ if i_b == 0:
+ argmax_i = 0
+ elif Q[i_a, j] >= Q[i_b, j]:
+ argmax_i = i_a
+ else:
+ argmax_i = i_b
+ A[j] = argmax_i
+ return A
+
+
+def viterbi_decode(log_p_attn, text_lengths, feats_lengths):
+ """Extract duration from an attention probability matrix
+
+ Args:
+ log_p_attn (Tensor): Batched log probability of attention
+ matrix (B, T_feats, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats_legnths (Tensor): Feature length tensor (B,).
+
+ Returns:
+ Tensor: Batched token duration extracted from `log_p_attn` (B, T_text).
+ Tensor: Binarization loss tensor ().
+
+ """
+ B = log_p_attn.size(0)
+ T_text = log_p_attn.size(2)
+ device = log_p_attn.device
+
+ bin_loss = 0
+ ds = torch.zeros((B, T_text), device=device)
+ for b in range(B):
+ cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]]
+ viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy())
+ _ds = np.bincount(viterbi)
+ ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device)
+
+ t_idx = torch.arange(feats_lengths[b])
+ bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean()
+ bin_loss = bin_loss / B
+ return ds, bin_loss
+
+
+@jit(nopython=True)
+def _average_by_duration(ds, xs, text_lengths, feats_lengths):
+ B = ds.shape[0]
+ xs_avg = np.zeros_like(ds)
+ ds = ds.astype(np.int32)
+ for b in range(B):
+ t_text = text_lengths[b]
+ t_feats = feats_lengths[b]
+ d = ds[b, :t_text]
+ d_cumsum = d.cumsum()
+ d_cumsum = [0] + list(d_cumsum)
+ x = xs[b, :t_feats]
+ for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])):
+ if len(x[start:end]) != 0:
+ xs_avg[b, n] = x[start:end].mean()
+ else:
+ xs_avg[b, n] = 0
+ return xs_avg
+
+
+def average_by_duration(ds, xs, text_lengths, feats_lengths):
+ """Average frame-level features into token-level according to durations
+
+ Args:
+ ds (Tensor): Batched token duration (B, T_text).
+ xs (Tensor): Batched feature sequences to be averaged (B, T_feats).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats_lengths (Tensor): Feature length tensor (B,).
+
+ Returns:
+ Tensor: Batched feature averaged according to the token duration (B, T_text).
+
+ """
+ device = ds.device
+ args = [ds, xs, text_lengths, feats_lengths]
+ args = [arg.detach().cpu().numpy() for arg in args]
+ xs_avg = _average_by_duration(*args)
+ xs_avg = torch.from_numpy(xs_avg).to(device)
+ return xs_avg
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+ """Make mask tensor containing indices of padded part.
+
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+
+ Returns:
+ Tensor: Mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ With only lengths.
+
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+
+ With the reference tensor.
+
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0],
+ [0, 0, 0, 0]],
+ [[0, 0, 0, 1],
+ [0, 0, 0, 1]],
+ [[0, 0, 1, 1],
+ [0, 0, 1, 1]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+ With the reference tensor and dimension indicator.
+
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_pad_mask(lengths, xs, 1)
+ tensor([[[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+ >>> make_pad_mask(lengths, xs, 2)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+ """
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if not isinstance(lengths, list):
+ lengths = lengths.tolist()
+ bs = int(len(lengths))
+ if maxlen is None:
+ if xs is None:
+ maxlen = int(max(lengths))
+ else:
+ maxlen = xs.size(length_dim)
+ else:
+ assert xs is None
+ assert maxlen >= int(max(lengths))
+
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ if xs is not None:
+ assert xs.size(0) == bs, (xs.size(0), bs)
+
+ if length_dim < 0:
+ length_dim = xs.dim() + length_dim
+ # ind = (:, None, ..., None, :, , None, ..., None)
+ ind = tuple(
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+ )
+ mask = mask[ind].expand_as(xs).to(xs.device)
+ return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+ """Make mask tensor containing indices of non-padded part.
+
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+
+ Returns:
+ ByteTensor: mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ With only lengths.
+
+ >>> lengths = [5, 3, 2]
+ >>> make_non_pad_mask(lengths)
+ masks = [[1, 1, 1, 1 ,1],
+ [1, 1, 1, 0, 0],
+ [1, 1, 0, 0, 0]]
+
+ With the reference tensor.
+
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1],
+ [1, 1, 1, 1]],
+ [[1, 1, 1, 0],
+ [1, 1, 1, 0]],
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+ With the reference tensor and dimension indicator.
+
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_non_pad_mask(lengths, xs, 1)
+ tensor([[[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+ >>> make_non_pad_mask(lengths, xs, 2)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+ """
+ return ~make_pad_mask(lengths, xs, length_dim)
+
+
+def get_random_segments(
+ x: torch.Tensor,
+ x_lengths: torch.Tensor,
+ segment_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get random segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ x_lengths (Tensor): Length tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+ Tensor: Start index tensor (B,).
+
+ """
+ b, c, t = x.size()
+ max_start_idx = x_lengths - segment_size
+ start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
+ dtype=torch.long,
+ )
+ segments = get_segments(x, start_idxs, segment_size)
+ return segments, start_idxs
+
+
+def get_segments(
+ x: torch.Tensor,
+ start_idxs: torch.Tensor,
+ segment_size: int,
+) -> torch.Tensor:
+ """Get segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ start_idxs (Tensor): Start index tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+
+ """
+ b, c, t = x.size()
+ segments = x.new_zeros(b, c, segment_size)
+ for i, start_idx in enumerate(start_idxs):
+ segments[i] = x[i, :, start_idx : start_idx + segment_size]
+ return segments
diff --git a/models/tts/jets/jets.py b/models/tts/jets/jets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3940b3722748030ba3f03d5e14f0577a82bb0f4e
--- /dev/null
+++ b/models/tts/jets/jets.py
@@ -0,0 +1,621 @@
+# Copyright (c) 2024 Amphion.
+#
+# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/generator.py
+# Licensed under Apache License 2.0
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+from modules.transformer.Models import Encoder, Decoder
+from modules.transformer.Layers import PostNet
+from collections import OrderedDict
+from models.tts.jets.alignments import (
+ AlignmentModule,
+ viterbi_decode,
+ average_by_duration,
+ make_pad_mask,
+ make_non_pad_mask,
+ get_random_segments,
+)
+from models.tts.jets.length_regulator import GaussianUpsampling
+from models.vocoders.gan.generator.hifigan import HiFiGAN
+import os
+import json
+
+from utils.util import load_config
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+ device = lengths.device
+ batch_size = lengths.shape[0]
+ if max_len is None:
+ max_len = torch.max(lengths).item()
+
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
+ mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
+
+ return mask
+
+
+def pad(input_ele, mel_max_length=None):
+ if mel_max_length:
+ max_len = mel_max_length
+ else:
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
+
+ out_list = list()
+ for i, batch in enumerate(input_ele):
+ if len(batch.shape) == 1:
+ one_batch_padded = F.pad(
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
+ )
+ elif len(batch.shape) == 2:
+ one_batch_padded = F.pad(
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
+ )
+ out_list.append(one_batch_padded)
+ out_padded = torch.stack(out_list)
+ return out_padded
+
+
+class VarianceAdaptor(nn.Module):
+ """Variance Adaptor"""
+
+ def __init__(self, cfg):
+ super(VarianceAdaptor, self).__init__()
+ self.duration_predictor = VariancePredictor(cfg)
+ self.length_regulator = LengthRegulator()
+ self.pitch_predictor = VariancePredictor(cfg)
+ self.energy_predictor = VariancePredictor(cfg)
+
+ # assign the pitch/energy feature level
+ if cfg.preprocess.use_frame_pitch:
+ self.pitch_feature_level = "frame_level"
+ self.pitch_dir = cfg.preprocess.pitch_dir
+ else:
+ self.pitch_feature_level = "phoneme_level"
+ self.pitch_dir = cfg.preprocess.phone_pitch_dir
+
+ if cfg.preprocess.use_frame_energy:
+ self.energy_feature_level = "frame_level"
+ self.energy_dir = cfg.preprocess.energy_dir
+ else:
+ self.energy_feature_level = "phoneme_level"
+ self.energy_dir = cfg.preprocess.phone_energy_dir
+
+ assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
+ assert self.energy_feature_level in ["phoneme_level", "frame_level"]
+
+ pitch_quantization = cfg.model.variance_embedding.pitch_quantization
+ energy_quantization = cfg.model.variance_embedding.energy_quantization
+ n_bins = cfg.model.variance_embedding.n_bins
+ assert pitch_quantization in ["linear", "log"]
+ assert energy_quantization in ["linear", "log"]
+
+ with open(
+ os.path.join(
+ cfg.preprocess.processed_dir,
+ cfg.dataset[0],
+ self.energy_dir,
+ "statistics.json",
+ )
+ ) as f:
+ stats = json.load(f)
+ stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
+ mean, std = (
+ stats["voiced_positions"]["mean"],
+ stats["voiced_positions"]["std"],
+ )
+ energy_min = (stats["total_positions"]["min"] - mean) / std
+ energy_max = (stats["total_positions"]["max"] - mean) / std
+
+ with open(
+ os.path.join(
+ cfg.preprocess.processed_dir,
+ cfg.dataset[0],
+ self.pitch_dir,
+ "statistics.json",
+ )
+ ) as f:
+ stats = json.load(f)
+ stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
+ mean, std = (
+ stats["voiced_positions"]["mean"],
+ stats["voiced_positions"]["std"],
+ )
+ pitch_min = (stats["total_positions"]["min"] - mean) / std
+ pitch_max = (stats["total_positions"]["max"] - mean) / std
+
+ if pitch_quantization == "log":
+ self.pitch_bins = nn.Parameter(
+ torch.exp(
+ torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
+ ),
+ requires_grad=False,
+ )
+ else:
+ self.pitch_bins = nn.Parameter(
+ torch.linspace(pitch_min, pitch_max, n_bins - 1),
+ requires_grad=False,
+ )
+ if energy_quantization == "log":
+ self.energy_bins = nn.Parameter(
+ torch.exp(
+ torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
+ ),
+ requires_grad=False,
+ )
+ else:
+ self.energy_bins = nn.Parameter(
+ torch.linspace(energy_min, energy_max, n_bins - 1),
+ requires_grad=False,
+ )
+
+ self.pitch_embedding = nn.Embedding(
+ n_bins, cfg.model.transformer.encoder_hidden
+ )
+ self.energy_embedding = nn.Embedding(
+ n_bins, cfg.model.transformer.encoder_hidden
+ )
+
+ def get_pitch_embedding(self, x, target, mask, control):
+ prediction = self.pitch_predictor(x, mask)
+ if target is not None:
+ embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
+ else:
+ prediction = prediction * control
+ embedding = self.pitch_embedding(
+ torch.bucketize(prediction, self.pitch_bins)
+ )
+ return prediction, embedding
+
+ def get_energy_embedding(self, x, target, mask, control):
+ prediction = self.energy_predictor(x, mask)
+ if target is not None:
+ embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
+ else:
+ prediction = prediction * control
+ embedding = self.energy_embedding(
+ torch.bucketize(prediction, self.energy_bins)
+ )
+ return prediction, embedding
+
+ def forward(
+ self,
+ x,
+ src_mask,
+ mel_mask=None,
+ max_len=None,
+ pitch_target=None,
+ energy_target=None,
+ duration_target=None,
+ p_control=1.0,
+ e_control=1.0,
+ d_control=1.0,
+ pitch_embedding=None,
+ energy_embedding=None,
+ ):
+ log_duration_prediction = self.duration_predictor(x, src_mask)
+
+ x = x + pitch_embedding
+ x = x + energy_embedding
+
+ pitch_prediction = self.pitch_predictor(x, src_mask)
+ energy_prediction = self.energy_predictor(x, src_mask)
+
+ if duration_target is not None:
+ x, mel_len = self.length_regulator(x, duration_target, max_len)
+ duration_rounded = duration_target
+ else:
+ duration_rounded = torch.clamp(
+ (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
+ min=0,
+ )
+ x, mel_len = self.length_regulator(x, duration_rounded, max_len)
+ mel_mask = get_mask_from_lengths(mel_len)
+
+ return (
+ x,
+ pitch_prediction,
+ energy_prediction,
+ log_duration_prediction,
+ duration_rounded,
+ mel_len,
+ mel_mask,
+ )
+
+ def inference(
+ self,
+ x,
+ src_mask,
+ mel_mask=None,
+ max_len=None,
+ pitch_target=None,
+ energy_target=None,
+ duration_target=None,
+ p_control=1.0,
+ e_control=1.0,
+ d_control=1.0,
+ pitch_embedding=None,
+ energy_embedding=None,
+ ):
+
+ p_outs = self.pitch_predictor(x, src_mask)
+ e_outs = self.energy_predictor(x, src_mask)
+ d_outs = self.duration_predictor(x, src_mask)
+
+ return p_outs, e_outs, d_outs
+
+
+class LengthRegulator(nn.Module):
+ """Length Regulator"""
+
+ def __init__(self):
+ super(LengthRegulator, self).__init__()
+
+ def LR(self, x, duration, max_len):
+ device = x.device
+ output = list()
+ mel_len = list()
+ for batch, expand_target in zip(x, duration):
+ expanded = self.expand(batch, expand_target)
+ output.append(expanded)
+ mel_len.append(expanded.shape[0])
+
+ if max_len is not None:
+ output = pad(output, max_len)
+ else:
+ output = pad(output)
+
+ return output, torch.LongTensor(mel_len).to(device)
+
+ def expand(self, batch, predicted):
+ out = list()
+
+ for i, vec in enumerate(batch):
+ expand_size = predicted[i].item()
+ out.append(vec.expand(max(int(expand_size), 0), -1))
+ out = torch.cat(out, 0)
+
+ return out
+
+ def forward(self, x, duration, max_len):
+ output, mel_len = self.LR(x, duration, max_len)
+ return output, mel_len
+
+
+class VariancePredictor(nn.Module):
+ """Duration, Pitch and Energy Predictor"""
+
+ def __init__(self, cfg):
+ super(VariancePredictor, self).__init__()
+
+ self.input_size = cfg.model.transformer.encoder_hidden
+ self.filter_size = cfg.model.variance_predictor.filter_size
+ self.kernel = cfg.model.variance_predictor.kernel_size
+ self.conv_output_size = cfg.model.variance_predictor.filter_size
+ self.dropout = cfg.model.variance_predictor.dropout
+
+ self.conv_layer = nn.Sequential(
+ OrderedDict(
+ [
+ (
+ "conv1d_1",
+ Conv(
+ self.input_size,
+ self.filter_size,
+ kernel_size=self.kernel,
+ padding=(self.kernel - 1) // 2,
+ ),
+ ),
+ ("relu_1", nn.ReLU()),
+ ("layer_norm_1", nn.LayerNorm(self.filter_size)),
+ ("dropout_1", nn.Dropout(self.dropout)),
+ (
+ "conv1d_2",
+ Conv(
+ self.filter_size,
+ self.filter_size,
+ kernel_size=self.kernel,
+ padding=1,
+ ),
+ ),
+ ("relu_2", nn.ReLU()),
+ ("layer_norm_2", nn.LayerNorm(self.filter_size)),
+ ("dropout_2", nn.Dropout(self.dropout)),
+ ]
+ )
+ )
+
+ self.linear_layer = nn.Linear(self.conv_output_size, 1)
+
+ def forward(self, encoder_output, mask):
+ out = self.conv_layer(encoder_output)
+ out = self.linear_layer(out)
+ out = out.squeeze(-1)
+
+ if mask is not None:
+ out = out.masked_fill(mask, 0.0)
+
+ return out
+
+
+class Conv(nn.Module):
+ """
+ Convolution Module
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ w_init="linear",
+ ):
+ """
+ :param in_channels: dimension of input
+ :param out_channels: dimension of output
+ :param kernel_size: size of kernel
+ :param stride: size of stride
+ :param padding: size of padding
+ :param dilation: dilation rate
+ :param bias: boolean. if True, bias is included.
+ :param w_init: str. weight inits with xavier initialization.
+ """
+ super(Conv, self).__init__()
+
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ def forward(self, x):
+ x = x.contiguous().transpose(1, 2)
+ x = self.conv(x)
+ x = x.contiguous().transpose(1, 2)
+
+ return x
+
+
+class Jets(nn.Module):
+ def __init__(self, cfg) -> None:
+ super(Jets, self).__init__()
+ self.cfg = cfg
+ self.encoder = Encoder(cfg.model)
+ self.variance_adaptor = VarianceAdaptor(cfg)
+ self.decoder = Decoder(cfg.model)
+ self.length_regulator_infer = LengthRegulator()
+ self.mel_linear = nn.Linear(
+ cfg.model.transformer.decoder_hidden,
+ cfg.preprocess.n_mel,
+ )
+ self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
+
+ self.speaker_emb = None
+ if cfg.train.multi_speaker_training:
+ with open(
+ os.path.join(
+ cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
+ ),
+ "r",
+ ) as f:
+ n_speaker = len(json.load(f))
+ self.speaker_emb = nn.Embedding(
+ n_speaker,
+ cfg.model.transformer.encoder_hidden,
+ )
+
+ output_dim = cfg.preprocess.n_mel
+ attention_dim = 256
+ self.alignment_module = AlignmentModule(attention_dim, output_dim)
+
+ # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
+ pitch_embed_kernel_size: int = 9
+ pitch_embed_dropout: float = 0.5
+ self.pitch_embed = torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels=1,
+ out_channels=attention_dim,
+ kernel_size=pitch_embed_kernel_size,
+ padding=(pitch_embed_kernel_size - 1) // 2,
+ ),
+ torch.nn.Dropout(pitch_embed_dropout),
+ )
+
+ # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg
+ energy_embed_kernel_size: int = 9
+ energy_embed_dropout: float = 0.5
+ self.energy_embed = torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels=1,
+ out_channels=attention_dim,
+ kernel_size=energy_embed_kernel_size,
+ padding=(energy_embed_kernel_size - 1) // 2,
+ ),
+ torch.nn.Dropout(energy_embed_dropout),
+ )
+
+ # define length regulator
+ self.length_regulator = GaussianUpsampling()
+
+ self.segment_size = cfg.train.segment_size
+
+ # Define HiFiGAN generator
+ hifi_cfg = load_config("egs/vocoder/gan/hifigan/exp_config.json")
+ # hifi_cfg.model.hifigan.resblock_kernel_sizes = [3, 7, 11]
+ hifi_cfg.preprocess.n_mel = attention_dim
+ self.generator = HiFiGAN(hifi_cfg)
+
+ def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
+ """Make masks for self-attention.
+
+ Args:
+ ilens (LongTensor): Batch of lengths (B,).
+
+ Returns:
+ Tensor: Mask tensor for self-attention.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ >>> ilens = [5, 3]
+ >>> self._source_mask(ilens)
+ tensor([[[1, 1, 1, 1, 1],
+ [1, 1, 1, 0, 0]]], dtype=torch.uint8)
+
+ """
+ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
+ return x_masks.unsqueeze(-2)
+
+ def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
+ speakers = data["spk_id"]
+ texts = data["texts"]
+ src_lens = data["text_len"]
+ max_src_len = max(src_lens)
+ feats = data["mel"]
+ mel_lens = data["target_len"] if "target_len" in data else None
+ feats_lengths = mel_lens
+ max_mel_len = max(mel_lens) if "target_len" in data else None
+ p_targets = data["pitch"] if "pitch" in data else None
+ e_targets = data["energy"] if "energy" in data else None
+ src_masks = get_mask_from_lengths(src_lens, max_src_len)
+ mel_masks = (
+ get_mask_from_lengths(mel_lens, max_mel_len)
+ if mel_lens is not None
+ else None
+ )
+
+ output = self.encoder(texts, src_masks)
+
+ if self.speaker_emb is not None:
+ output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
+ -1, max_src_len, -1
+ )
+
+ # Forward alignment module and obtain duration, averaged pitch, energy
+ h_masks = make_pad_mask(src_lens).to(output.device)
+ log_p_attn = self.alignment_module(
+ output, feats, src_lens, feats_lengths, h_masks
+ )
+ ds, bin_loss = viterbi_decode(log_p_attn, src_lens, feats_lengths)
+ ps = average_by_duration(
+ ds, p_targets.squeeze(-1), src_lens, feats_lengths
+ ).unsqueeze(-1)
+ es = average_by_duration(
+ ds, e_targets.squeeze(-1), src_lens, feats_lengths
+ ).unsqueeze(-1)
+ p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
+ e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
+
+ # FastSpeech2 variance adaptor
+ (
+ output,
+ p_predictions,
+ e_predictions,
+ log_d_predictions,
+ d_rounded,
+ mel_lens,
+ mel_masks,
+ ) = self.variance_adaptor(
+ output,
+ src_masks,
+ mel_masks,
+ max_mel_len,
+ p_targets,
+ e_targets,
+ ds,
+ p_control,
+ e_control,
+ d_control,
+ ps,
+ es,
+ )
+
+ # forward decoder
+ zs, _ = self.decoder(output, mel_masks) # (B, T_feats, adim)
+
+ # get random segments
+ z_segments, z_start_idxs = get_random_segments(
+ zs.transpose(1, 2),
+ feats_lengths,
+ self.segment_size,
+ )
+
+ # forward generator
+ wav = self.generator(z_segments)
+
+ return (
+ wav,
+ bin_loss,
+ log_p_attn,
+ z_start_idxs,
+ log_d_predictions,
+ ds,
+ p_predictions,
+ ps,
+ e_predictions,
+ es,
+ src_lens,
+ feats_lengths,
+ )
+
+ def inference(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
+ speakers = data["spk_id"]
+ texts = data["texts"]
+ src_lens = data["text_len"]
+ max_src_len = max(src_lens)
+ mel_lens = data["target_len"] if "target_len" in data else None
+ feats_lengths = mel_lens
+ max_mel_len = max(mel_lens) if "target_len" in data else None
+ p_targets = data["pitch"] if "pitch" in data else None
+ e_targets = data["energy"] if "energy" in data else None
+ d_targets = data["durations"] if "durations" in data else None
+ src_masks = get_mask_from_lengths(src_lens, max_src_len)
+ mel_masks = (
+ get_mask_from_lengths(mel_lens, max_mel_len)
+ if mel_lens is not None
+ else None
+ )
+
+ x_masks = self._source_mask(src_lens)
+ hs = self.encoder(texts, src_masks)
+
+ (
+ p_outs,
+ e_outs,
+ d_outs,
+ ) = self.variance_adaptor.inference(
+ hs,
+ src_masks,
+ )
+
+ p_embs = self.pitch_embed(p_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
+ e_embs = self.energy_embed(e_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
+ hs = hs + e_embs + p_embs
+
+ # Duration predictor inference mode: log_d_pred to d_pred
+ offset = 1.0
+ d_predictions = torch.clamp(
+ torch.round(d_outs.exp() - offset), min=0
+ ).long() # avoid negative value
+
+ # forward decoder
+ hs, mel_len = self.length_regulator_infer(hs, d_predictions, max_mel_len)
+ mel_mask = get_mask_from_lengths(mel_len)
+ zs, _ = self.decoder(hs, mel_mask) # (B, T_feats, adim)
+
+ # forward generator
+ wav = self.generator(zs.transpose(1, 2))
+
+ return wav, d_predictions
diff --git a/models/tts/jets/jets_dataset.py b/models/tts/jets/jets_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6909468db5668e436470ff74487713e7164953ff
--- /dev/null
+++ b/models/tts/jets/jets_dataset.py
@@ -0,0 +1,451 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.base.base_dataset import (
+ BaseOfflineCollator,
+ BaseOfflineDataset,
+ BaseTestDataset,
+ BaseTestCollator,
+)
+from text import text_to_sequence
+
+
+class JetsDataset(BaseOfflineDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
+ self.batch_size = cfg.train.batch_size
+ cfg = cfg.preprocess
+ # utt2duration
+ self.utt2duration_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2duration_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.duration_dir,
+ uid + ".npy",
+ )
+ self.utt2dur = self.read_duration()
+
+ if cfg.use_frame_energy:
+ self.frame_utt2energy, self.energy_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.energy_dir,
+ use_log_scale=cfg.use_log_scale_energy,
+ utt2spk=self.preprocess.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+ elif cfg.use_phone_energy:
+ self.phone_utt2energy, self.energy_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.phone_energy_dir,
+ use_log_scale=cfg.use_log_scale_energy,
+ utt2spk=self.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+
+ if cfg.use_frame_pitch:
+ self.frame_utt2pitch, self.pitch_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.pitch_dir,
+ use_log_scale=cfg.energy_extract_mode,
+ utt2spk=self.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+
+ elif cfg.use_phone_pitch:
+ self.phone_utt2pitch, self.pitch_statistic = load_energy(
+ self.metadata,
+ cfg.processed_dir,
+ cfg.phone_pitch_dir,
+ use_log_scale=cfg.use_log_scale_pitch,
+ utt2spk=self.utt2spk if cfg.use_spkid else None,
+ return_norm=True,
+ )
+
+ # utt2lab
+ self.utt2lab_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2lab_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.lab_dir,
+ uid + ".txt",
+ )
+
+ self.speaker_map = {}
+ if os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json")):
+ with open(
+ os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json"))
+ ) as f:
+ self.speaker_map = json.load(f)
+
+ self.metadata = self.check_metadata()
+ if cfg.use_audios:
+ self.utt2audio_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if cfg.extract_audio:
+ self.utt2audio_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.audio_dir,
+ uid + ".wav",
+ )
+ else:
+ self.utt2audio_path[utt] = utt_info["Path"]
+
+ def __getitem__(self, index):
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ duration = self.utt2dur[utt]
+
+ # text
+ f = open(self.utt2lab_path[utt], "r")
+ phones = f.readlines()[0].strip()
+ f.close()
+ # todo: add cleaner(chenxi)
+ phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
+ text_len = len(phones_ids)
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch = self.frame_utt2pitch[utt]
+ elif self.cfg.preprocess.use_phone_pitch:
+ pitch = self.phone_utt2pitch[utt]
+
+ if self.cfg.preprocess.use_frame_energy:
+ energy = self.frame_utt2energy[utt]
+ elif self.cfg.preprocess.use_phone_energy:
+ energy = self.phone_utt2energy[utt]
+
+ # speaker
+ if len(self.speaker_map) > 0:
+ speaker_id = self.speaker_map[utt_info["Singer"]]
+ else:
+ speaker_id = 0
+
+ single_feature.update(
+ {
+ "durations": duration,
+ "texts": phones_ids,
+ "spk_id": speaker_id,
+ "text_len": text_len,
+ "pitch": pitch,
+ "energy": energy,
+ "uid": uid,
+ }
+ )
+
+ if self.cfg.preprocess.use_audios:
+ audio, sr = torchaudio.load(self.utt2audio_path[utt])
+ audio = audio.cpu().numpy().squeeze()
+ single_feature["audio"] = audio
+ single_feature["audio_len"] = audio.shape[0]
+ return self.clip_if_too_long(single_feature)
+
+ def read_duration(self):
+ # read duration
+ utt2dur = {}
+ for index in range(len(self.metadata)):
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
+ self.utt2duration_path[utt]
+ ):
+ continue
+
+ mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
+ duration = np.load(self.utt2duration_path[utt])
+ assert mel.shape[0] == sum(
+ duration
+ ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
+ utt2dur[utt] = duration
+ return utt2dur
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
+ """
+ ending_ts: to avoid invalid whisper features for over 30s audios
+ 2812 = 30 * 24000 // 256
+ """
+ ts = max(feature_seq_len - max_seq_len, 0)
+ ts = min(ts, ending_ts - max_seq_len)
+
+ start = random.randint(0, ts)
+ end = start + max_seq_len
+ return start, end
+
+ def clip_if_too_long(self, sample, max_seq_len=1000):
+ """
+ sample :
+ {
+ 'spk_id': (1,),
+ 'target_len': int
+ 'mel': (seq_len, dim),
+ 'frame_pitch': (seq_len,)
+ 'frame_energy': (seq_len,)
+ 'content_vector_feat': (seq_len, dim)
+ }
+ """
+ if sample["target_len"] <= max_seq_len:
+ return sample
+
+ start, end = self.random_select(sample["target_len"], max_seq_len)
+ sample["target_len"] = end - start
+
+ for k in sample.keys():
+ if k not in ["spk_id", "target_len"]:
+ sample[k] = sample[k][start:end]
+
+ return sample
+
+ def check_metadata(self):
+ new_metadata = []
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
+ self.utt2mel_path[utt]
+ ):
+ continue
+ else:
+ new_metadata.append(utt_info)
+ return new_metadata
+
+
+class JetsCollator(BaseOfflineCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ BaseOfflineCollator.__init__(self, cfg)
+ self.sort = cfg.train.sort_sample
+ self.batch_size = cfg.train.batch_size
+ self.drop_last = cfg.train.drop_last
+
+ def __call__(self, batch):
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "text_len":
+ packed_batch_features["text_len"] = torch.LongTensor(
+ [b["text_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["text_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "spk_id":
+ packed_batch_features["spk_id"] = torch.LongTensor(
+ [b["spk_id"] for b in batch]
+ )
+ elif key == "uid":
+ packed_batch_features[key] = [b["uid"] for b in batch]
+ elif key == "audio_len":
+ packed_batch_features["audio_len"] = torch.LongTensor(
+ [b["audio_len"] for b in batch]
+ )
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ return packed_batch_features
+
+
+class JetsTestDataset(BaseTestDataset):
+ def __init__(self, args, cfg, infer_type=None):
+ datasets = cfg.dataset
+ cfg = cfg.preprocess
+ is_bigdata = False
+
+ assert len(datasets) >= 1
+ if len(datasets) > 1:
+ datasets.sort()
+ bigdata_version = "_".join(datasets)
+ processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
+ is_bigdata = True
+ else:
+ processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)
+
+ if args.test_list_file:
+ self.metafile_path = args.test_list_file
+ self.metadata = self.get_metadata()
+ else:
+ assert args.testing_set
+ source_metafile_path = os.path.join(
+ cfg.processed_dir,
+ args.dataset,
+ "{}.json".format(args.testing_set),
+ )
+ with open(source_metafile_path, "r") as f:
+ self.metadata = json.load(f)
+
+ self.cfg = cfg
+ self.datasets = datasets
+ self.data_root = processed_data_dir
+ self.is_bigdata = is_bigdata
+ self.source_dataset = args.dataset
+
+ ######### Load source acoustic features #########
+ if cfg.use_spkid:
+ spk2id_path = os.path.join(self.data_root, cfg.spk2id)
+ utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
+ self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)
+
+ # utt2lab
+ self.utt2lab_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2lab_path[utt] = os.path.join(
+ cfg.processed_dir,
+ dataset,
+ cfg.lab_dir,
+ uid + ".txt",
+ )
+
+ self.speaker_map = {}
+ if os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json")):
+ with open(
+ os.path.exists(os.path.join(cfg.processed_dir, "spk2id.json"))
+ ) as f:
+ self.speaker_map = json.load(f)
+
+ def __getitem__(self, index):
+ single_feature = {}
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ # text
+ f = open(self.utt2lab_path[utt], "r")
+ phones = f.readlines()[0].strip()
+ f.close()
+
+ phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
+ text_len = len(phones_ids)
+
+ # speaker
+ if len(self.speaker_map) > 0:
+ speaker_id = self.speaker_map[utt_info["Singer"]]
+ else:
+ speaker_id = 0
+
+ single_feature.update(
+ {
+ "texts": phones_ids,
+ "spk_id": speaker_id,
+ "text_len": text_len,
+ }
+ )
+
+ return single_feature
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ return metadata
+
+
+class JetsTestCollator(BaseTestCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, T, n_mels]
+ # frame_pitch, frame_energy: [1, T]
+ # target_len: [1]
+ # spk_id: [b, 1]
+ # mask: [b, T, 1]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "text_len":
+ packed_batch_features["text_len"] = torch.LongTensor(
+ [b["text_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["text_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "spk_id":
+ packed_batch_features["spk_id"] = torch.LongTensor(
+ [b["spk_id"] for b in batch]
+ )
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/tts/jets/jets_inference.py b/models/tts/jets/jets_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..35610c4b93b0c6c784f0b67443998d56aca76202
--- /dev/null
+++ b/models/tts/jets/jets_inference.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+from tqdm import tqdm
+from collections import OrderedDict
+
+from models.tts.base.tts_inferece import TTSInference
+from models.tts.jets.jets_dataset import JetsTestDataset, JetsTestCollator
+from utils.util import load_config
+from utils.io import save_audio
+from models.tts.jets.jets import Jets
+from models.vocoders.vocoder_inference import synthesis
+from pathlib import Path
+from processors.phone_extractor import phoneExtractor
+from text.text_token_collation import phoneIDCollation
+import numpy as np
+import json
+import time
+
+
+class JetsInference(TTSInference):
+ def __init__(self, args, cfg):
+ TTSInference.__init__(self, args, cfg)
+ self.args = args
+ self.cfg = cfg
+ self.infer_type = args.mode
+
+ def _build_model(self):
+ self.model = Jets(self.cfg)
+ return self.model
+
+ def _build_test_dataset(self):
+ return JetsTestDataset, JetsTestCollator
+
+ def inference_for_batches(self):
+ ###### Construct test_batch ######
+ n_batch = len(self.test_dataloader)
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
+ print(
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
+ now, self.test_batch_size, n_batch
+ )
+ )
+ self.model.eval()
+
+ ###### Inference for each batch ######
+ pred_res = []
+ with torch.no_grad():
+ for i, batch_data in enumerate(
+ self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
+ ):
+ outputs = self.model.inference(batch_data)
+
+ audios, d_predictions = outputs
+ d_predictions = d_predictions.unsqueeze(-1)
+
+ for idx in range(audios.size(0)):
+ audio = audios[idx, 0, :].data.cpu().float()
+ duration = d_predictions[idx, :, :]
+ audio_length = (
+ duration.sum([0, 1]).long() * self.cfg.preprocess.hop_size
+ )
+ audio_length = audio_length.cpu().numpy()
+ audio = audio[:audio_length]
+ pred_res.append(audio)
+
+ return pred_res
diff --git a/models/tts/jets/jets_loss.py b/models/tts/jets/jets_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3955734dd8d9e443c08bc07854b1e9fbbacccb08
--- /dev/null
+++ b/models/tts/jets/jets_loss.py
@@ -0,0 +1,537 @@
+# Copyright (c) 2024 Amphion.
+#
+# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/loss.py
+# Licensed under Apache License 2.0
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import librosa
+
+from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator
+from models.tts.jets.alignments import make_non_pad_mask, make_pad_mask
+
+
+class GeneratorAdversarialLoss(torch.nn.Module):
+ """Generator adversarial loss module."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, outputs) -> torch.Tensor:
+ if isinstance(outputs, (tuple, list)):
+ adv_loss = 0.0
+ for i, outputs_ in enumerate(outputs):
+ if isinstance(outputs_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ outputs_ = outputs_[-1]
+ adv_loss += F.mse_loss(outputs_, outputs_.new_ones(outputs_.size()))
+ else:
+ adv_loss = F.mse_loss(outputs, outputs.new_ones(outputs.size()))
+
+ return adv_loss
+
+
+class FeatureMatchLoss(torch.nn.Module):
+ """Feature matching loss module."""
+
+ def __init__(
+ self,
+ average_by_layers: bool = False,
+ average_by_discriminators: bool = False,
+ include_final_outputs: bool = True,
+ ):
+ """Initialize FeatureMatchLoss module.
+
+ Args:
+ average_by_layers (bool): Whether to average the loss by the number
+ of layers.
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ include_final_outputs (bool): Whether to include the final output of
+ each discriminator for loss calculation.
+
+ """
+ super().__init__()
+ self.average_by_layers = average_by_layers
+ self.average_by_discriminators = average_by_discriminators
+ self.include_final_outputs = include_final_outputs
+
+ def forward(
+ self,
+ feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ ) -> torch.Tensor:
+ """Calculate feature matching loss.
+
+ Args:
+ feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
+ discriminator outputs or list of discriminator outputs calcuated
+ from generator's outputs.
+ feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
+ discriminator outputs or list of discriminator outputs calcuated
+ from groundtruth..
+
+ Returns:
+ Tensor: Feature matching loss value.
+
+ """
+ feat_match_loss = 0.0
+ for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
+ feat_match_loss_ = 0.0
+ if not self.include_final_outputs:
+ feats_hat_ = feats_hat_[:-1]
+ feats_ = feats_[:-1]
+ for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
+ feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+ if self.average_by_layers:
+ feat_match_loss_ /= j + 1
+ feat_match_loss += feat_match_loss_
+ if self.average_by_discriminators:
+ feat_match_loss /= i + 1
+
+ return feat_match_loss
+
+
+class DurationPredictorLoss(torch.nn.Module):
+ """Loss function module for duration predictor.
+
+ The loss value is Calculated in log domain to make it Gaussian.
+
+ """
+
+ def __init__(self, offset=1.0, reduction="mean"):
+ """Initilize duration predictor loss module.
+
+ Args:
+ offset (float, optional): Offset value to avoid nan in log domain.
+ reduction (str): Reduction type in loss calculation.
+
+ """
+ super().__init__()
+ self.criterion = torch.nn.MSELoss(reduction=reduction)
+ self.offset = offset
+
+ def forward(self, outputs, targets):
+ targets = torch.log(targets.float() + self.offset)
+ loss = self.criterion(outputs, targets)
+
+ return loss
+
+
+class VarianceLoss(torch.nn.Module):
+ def __init__(self):
+ """Initialize JETS variance loss module."""
+ super().__init__()
+
+ # define criterions
+ reduction = "mean"
+ self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
+ self.duration_criterion = DurationPredictorLoss(reduction=reduction)
+
+ def forward(
+ self,
+ d_outs: torch.Tensor,
+ ds: torch.Tensor,
+ p_outs: torch.Tensor,
+ ps: torch.Tensor,
+ e_outs: torch.Tensor,
+ es: torch.Tensor,
+ ilens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
+ ds (LongTensor): Batch of durations (B, T_text).
+ p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
+ ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
+ e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
+ es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
+ ilens (LongTensor): Batch of the lengths of each input (B,).
+
+ Returns:
+ Tensor: Duration predictor loss value.
+ Tensor: Pitch predictor loss value.
+ Tensor: Energy predictor loss value.
+
+ """
+ # apply mask to remove padded part
+ duration_masks = make_non_pad_mask(ilens).to(ds.device)
+ d_outs = d_outs.masked_select(duration_masks)
+ ds = ds.masked_select(duration_masks)
+ pitch_masks = make_non_pad_mask(ilens).to(ds.device)
+ pitch_masks_ = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
+ p_outs = p_outs.masked_select(pitch_masks)
+ e_outs = e_outs.masked_select(pitch_masks)
+ ps = ps.masked_select(pitch_masks_)
+ es = es.masked_select(pitch_masks_)
+
+ # calculate loss
+ duration_loss = self.duration_criterion(d_outs, ds)
+ pitch_loss = self.mse_criterion(p_outs, ps)
+ energy_loss = self.mse_criterion(e_outs, es)
+
+ return duration_loss, pitch_loss, energy_loss
+
+
+class ForwardSumLoss(torch.nn.Module):
+ """Forwardsum loss described at https://openreview.net/forum?id=0NQwnnwAORi"""
+
+ def __init__(self):
+ """Initialize forwardsum loss module."""
+ super().__init__()
+
+ def forward(
+ self,
+ log_p_attn: torch.Tensor,
+ ilens: torch.Tensor,
+ olens: torch.Tensor,
+ blank_prob: float = np.e**-1,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ log_p_attn (Tensor): Batch of log probability of attention matrix
+ (B, T_feats, T_text).
+ ilens (Tensor): Batch of the lengths of each input (B,).
+ olens (Tensor): Batch of the lengths of each target (B,).
+ blank_prob (float): Blank symbol probability.
+
+ Returns:
+ Tensor: forwardsum loss value.
+
+ """
+ B = log_p_attn.size(0)
+
+ # a row must be added to the attention matrix to account for
+ # blank token of CTC loss
+ # (B,T_feats,T_text+1)
+ log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))
+
+ loss = 0
+ for bidx in range(B):
+ # construct target sequnece.
+ # Every text token is mapped to a unique sequnece number.
+ target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
+ cur_log_p_attn_pd = log_p_attn_pd[
+ bidx, : olens[bidx], : ilens[bidx] + 1
+ ].unsqueeze(
+ 1
+ ) # (T_feats,1,T_text+1)
+ cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1)
+ loss += F.ctc_loss(
+ log_probs=cur_log_p_attn_pd,
+ targets=target_seq,
+ input_lengths=olens[bidx : bidx + 1],
+ target_lengths=ilens[bidx : bidx + 1],
+ zero_infinity=True,
+ )
+ loss = loss / B
+ return loss
+
+
+class MelSpectrogramLoss(torch.nn.Module):
+ """Mel-spectrogram loss."""
+
+ def __init__(
+ self,
+ fs: int = 22050,
+ n_fft: int = 1024,
+ hop_length: int = 256,
+ win_length: Optional[int] = None,
+ window: str = "hann",
+ n_mels: int = 80,
+ fmin: Optional[int] = 0,
+ fmax: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = False,
+ onesided: bool = True,
+ htk: bool = False,
+ ):
+ """Initialize Mel-spectrogram loss.
+
+ Args:
+ fs (int): Sampling rate.
+ n_fft (int): FFT points.
+ hop_length (int): Hop length.
+ win_length (Optional[int]): Window length.
+ window (str): Window type.
+ n_mels (int): Number of Mel basis.
+ fmin (Optional[int]): Minimum frequency for Mel.
+ fmax (Optional[int]): Maximum frequency for Mel.
+ center (bool): Whether to use center window.
+ normalized (bool): Whether to use normalized one.
+ onesided (bool): Whether to use oneseded one.
+
+ """
+ super().__init__()
+
+ self.fs = fs
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = n_fft
+ self.window = window
+ self.n_mels = n_mels
+ self.fmin = 0 if fmin is None else fmin
+ self.fmax = fs / 2 if fmax is None else fmax
+ self.center = center
+ self.normalized = normalized
+ self.onesided = onesided
+ self.htk = htk
+
+ def logmel(self, feat, ilens):
+ mel_options = dict(
+ sr=self.fs,
+ n_fft=self.n_fft,
+ n_mels=self.n_mels,
+ fmin=self.fmin,
+ fmax=self.fmax,
+ htk=self.htk,
+ )
+ melmat = librosa.filters.mel(**mel_options)
+ melmat = torch.from_numpy(melmat.T).float().to(feat.device)
+ mel_feat = torch.matmul(feat, melmat)
+ mel_feat = torch.clamp(mel_feat, min=1e-10)
+ logmel_feat = mel_feat.log10()
+
+ # Zero padding
+ if ilens is not None:
+ logmel_feat = logmel_feat.masked_fill(
+ make_pad_mask(ilens, logmel_feat, 1), 0.0
+ )
+ else:
+ ilens = feat.new_full(
+ [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
+ )
+ return logmel_feat
+
+ def wav_to_mel(self, input, input_lengths=None):
+ if self.window is not None:
+ window_func = getattr(torch, f"{self.window}_window")
+ window = window_func(
+ self.win_length, dtype=input.dtype, device=input.device
+ )
+
+ stft_kwargs = dict(
+ n_fft=self.n_fft,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
+ center=self.center,
+ window=window,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ return_complex=True,
+ )
+
+ bs = input.size(0)
+ if input.dim() == 3:
+ multi_channel = True
+ # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
+ input = input.transpose(1, 2).reshape(-1, input.size(1))
+ else:
+ multi_channel = False
+
+ input_stft = torch.stft(input, **stft_kwargs)
+ input_stft = torch.view_as_real(input_stft)
+ input_stft = input_stft.transpose(1, 2)
+ if multi_channel:
+ input_stft = input_stft.view(
+ bs, -1, input_stft.size(1), input_stft.size(2), 2
+ ).transpose(1, 2)
+ if input_lengths is not None:
+ if self.center:
+ pad = self.n_fft // 2
+ input_lengths = input_lengths + 2 * pad
+
+ feats_lens = (input_lengths - self.n_fft) // self.hop_length + 1
+ input_stft.masked_fill_(make_pad_mask(feats_lens, input_stft, 1), 0.0)
+ else:
+ feats_lens = None
+ input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
+ input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
+ input_feats = self.logmel(input_amp, feats_lens)
+ return input_feats, feats_lens
+
+ def forward(
+ self,
+ y_hat: torch.Tensor,
+ y: torch.Tensor,
+ ) -> torch.Tensor:
+ mel_hat, _ = self.wav_to_mel(y_hat.squeeze(1))
+ mel, _ = self.wav_to_mel(y.squeeze(1))
+ mel_loss = F.l1_loss(mel_hat, mel)
+
+ return mel_loss
+
+
+class GeneratorLoss(nn.Module):
+ """The total loss of the generator"""
+
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ self.mel_loss = MelSpectrogramLoss()
+ self.generator_adv_loss = GeneratorAdversarialLoss()
+ self.feat_match_loss = FeatureMatchLoss()
+ self.var_loss = VarianceLoss()
+ self.forwardsum_loss = ForwardSumLoss()
+
+ self.lambda_adv = 1.0
+ self.lambda_mel = 45.0
+ self.lambda_feat_match = 2.0
+ self.lambda_var = 1.0
+ self.lambda_align = 2.0
+
+ def forward(self, outputs_g, outputs_d, speech_):
+ loss_g = {}
+
+ # parse generator output
+ (
+ speech_hat_,
+ bin_loss,
+ log_p_attn,
+ start_idxs,
+ d_outs,
+ ds,
+ p_outs,
+ ps,
+ e_outs,
+ es,
+ text_lengths,
+ feats_lengths,
+ ) = outputs_g
+
+ # parse discriminator output
+ (p_hat, p) = outputs_d
+
+ # calculate losses
+ mel_loss = self.mel_loss(speech_hat_, speech_)
+ adv_loss = self.generator_adv_loss(p_hat)
+ feat_match_loss = self.feat_match_loss(p_hat, p)
+ dur_loss, pitch_loss, energy_loss = self.var_loss(
+ d_outs, ds, p_outs, ps, e_outs, es, text_lengths
+ )
+ forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths)
+
+ # calculate total loss
+ mel_loss = mel_loss * self.lambda_mel
+ loss_g["mel_loss"] = mel_loss
+ adv_loss = adv_loss * self.lambda_adv
+ loss_g["adv_loss"] = adv_loss
+ feat_match_loss = feat_match_loss * self.lambda_feat_match
+ loss_g["feat_match_loss"] = feat_match_loss
+ g_loss = mel_loss + adv_loss + feat_match_loss
+ loss_g["g_loss"] = g_loss
+ var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var
+ loss_g["var_loss"] = var_loss
+ align_loss = (forwardsum_loss + bin_loss) * self.lambda_align
+ loss_g["align_loss"] = align_loss
+
+ g_total_loss = g_loss + var_loss + align_loss
+
+ loss_g["g_total_loss"] = g_total_loss
+
+ return loss_g
+
+
+class DiscriminatorAdversarialLoss(torch.nn.Module):
+ """Discriminator adversarial loss module."""
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize DiscriminatorAversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.fake_criterion = self._mse_fake_loss
+ self.real_criterion = self._mse_real_loss
+ else:
+ self.fake_criterion = self._hinge_fake_loss
+ self.real_criterion = self._hinge_real_loss
+
+ def forward(
+ self,
+ outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calcualate discriminator adversarial loss.
+
+ Args:
+ outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from generator.
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from groundtruth.
+
+ Returns:
+ Tensor: Discriminator real loss value.
+ Tensor: Discriminator fake loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ real_loss = 0.0
+ fake_loss = 0.0
+ for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+ if isinstance(outputs_hat_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ outputs_hat_ = outputs_hat_[-1]
+ outputs_ = outputs_[-1]
+ real_loss += self.real_criterion(outputs_)
+ fake_loss += self.fake_criterion(outputs_hat_)
+ if self.average_by_discriminators:
+ fake_loss /= i + 1
+ real_loss /= i + 1
+ else:
+ real_loss = self.real_criterion(outputs)
+ fake_loss = self.fake_criterion(outputs_hat)
+
+ return real_loss, fake_loss
+
+ def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, x.new_zeros(x.size()))
+
+ def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
+
+ def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
+
+
+class DiscriminatorLoss(torch.nn.Module):
+ """The total loss of the discriminator"""
+
+ def __init__(self, cfg):
+ super(DiscriminatorLoss, self).__init__()
+ self.cfg = cfg
+ self.discriminator = MultiScaleMultiPeriodDiscriminator()
+ self.discriminator_adv_loss = DiscriminatorAdversarialLoss()
+
+ def forward(self, speech_real, speech_generated):
+ loss_d = {}
+
+ real_loss, fake_loss = self.discriminator_adv_loss(
+ speech_generated, speech_real
+ )
+ loss_d["loss_disc_all"] = real_loss + fake_loss
+
+ return loss_d
diff --git a/models/tts/jets/jets_trainer.py b/models/tts/jets/jets_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..925dc9280611d6b45a8fcc43667458d5649c57a4
--- /dev/null
+++ b/models/tts/jets/jets_trainer.py
@@ -0,0 +1,361 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import os
+from utils.io import save_audio
+from tqdm import tqdm
+from models.tts.base import TTSTrainer
+from models.tts.jets.jets import Jets
+from models.tts.jets.jets_loss import GeneratorLoss, DiscriminatorLoss
+from models.tts.jets.jets_dataset import JetsDataset, JetsCollator
+from optimizer.optimizers import NoamLR
+from torch.optim.lr_scheduler import ExponentialLR
+from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator
+
+
+def get_segments(
+ x: torch.Tensor,
+ start_idxs: torch.Tensor,
+ segment_size: int,
+) -> torch.Tensor:
+ """Get segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ start_idxs (Tensor): Start index tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+
+ """
+ b, c, t = x.size()
+ segments = x.new_zeros(b, c, segment_size)
+ for i, start_idx in enumerate(start_idxs):
+ segments[i] = x[i, :, start_idx : start_idx + segment_size]
+ return segments
+
+
+class JetsTrainer(TTSTrainer):
+ def __init__(self, args, cfg):
+ TTSTrainer.__init__(self, args, cfg)
+ self.cfg = cfg
+
+ def _build_dataset(self):
+ return JetsDataset, JetsCollator
+
+ def __build_scheduler(self):
+ return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
+
+ def _write_summary(
+ self,
+ losses,
+ stats,
+ images={},
+ audios={},
+ audio_sampling_rate=24000,
+ tag="train",
+ ):
+ for key, value in losses.items():
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
+ self.sw.add_scalar(
+ "learning_rate",
+ self.optimizer["optimizer_g"].param_groups[0]["lr"],
+ self.step,
+ )
+
+ if len(images) != 0:
+ for key, value in images.items():
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
+ if len(audios) != 0:
+ for key, value in audios.items():
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
+
+ for key, value in losses.items():
+ self.sw.add_scalar("train/" + key, value, self.step)
+ lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
+ self.sw.add_scalar("learning_rate", lr, self.step)
+
+ def _write_valid_summary(
+ self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
+ ):
+ for key, value in losses.items():
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
+
+ if len(images) != 0:
+ for key, value in images.items():
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
+ if len(audios) != 0:
+ for key, value in audios.items():
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
+
+ def _build_criterion(self):
+ criterion = {
+ "generator": GeneratorLoss(self.cfg),
+ "discriminator": DiscriminatorLoss(self.cfg),
+ }
+ return criterion
+
+ def get_state_dict(self):
+ state_dict = {
+ "generator": self.model["generator"].state_dict(),
+ "discriminator": self.model["discriminator"].state_dict(),
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def _build_optimizer(self):
+ optimizer_g = torch.optim.AdamW(
+ self.model["generator"].parameters(),
+ self.cfg.train.learning_rate,
+ betas=self.cfg.train.AdamW.betas,
+ eps=self.cfg.train.AdamW.eps,
+ )
+ optimizer_d = torch.optim.AdamW(
+ self.model["discriminator"].parameters(),
+ self.cfg.train.learning_rate,
+ betas=self.cfg.train.AdamW.betas,
+ eps=self.cfg.train.AdamW.eps,
+ )
+ optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
+
+ return optimizer
+
+ def _build_scheduler(self):
+ scheduler_g = ExponentialLR(
+ self.optimizer["optimizer_g"],
+ gamma=self.cfg.train.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+ scheduler_d = ExponentialLR(
+ self.optimizer["optimizer_d"],
+ gamma=self.cfg.train.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+
+ scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
+ return scheduler
+
+ def _build_model(self):
+ net_g = Jets(self.cfg)
+ net_d = MultiScaleMultiPeriodDiscriminator()
+ self.model = {"generator": net_g, "discriminator": net_d}
+ return self.model
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model["generator"].train()
+ self.model["discriminator"].train()
+ epoch_sum_loss: float = 0.0
+ epoch_losses: dict = {}
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ with self.accelerator.accumulate(self.model):
+ if batch["target_len"].min() < self.cfg.train.segment_size:
+ continue
+ total_loss, train_losses, training_stats = self._train_step(batch)
+ self.batch_count += 1
+
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += total_loss
+ for key, value in train_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ self.accelerator.log(
+ {
+ "Step/Train {} Loss".format(key): value,
+ },
+ step=self.step,
+ )
+
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ epoch_sum_loss = (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ return epoch_sum_loss, epoch_losses
+
+ def _train_step(self, batch):
+ train_losses = {}
+ total_loss = 0
+ training_stats = {}
+
+ # Train Discriminator
+ # Generator output
+ outputs_g = self.model["generator"](batch)
+ speech_hat_, _, _, start_idxs, *_ = outputs_g
+
+ # Discriminator output
+ speech = batch["audio"].unsqueeze(1)
+ upsample_factor = self.cfg.train.upsample_factor
+ speech_ = get_segments(
+ x=speech,
+ start_idxs=start_idxs * upsample_factor,
+ segment_size=self.cfg.train.segment_size * upsample_factor,
+ )
+ p_hat = self.model["discriminator"](speech_hat_.detach())
+ p = self.model["discriminator"](speech_)
+
+ # Discriminator loss
+ loss_d = self.criterion["discriminator"](p, p_hat)
+ train_losses.update(loss_d)
+
+ # BP and Grad Updated
+ self.optimizer["optimizer_d"].zero_grad()
+ self.accelerator.backward(loss_d["loss_disc_all"])
+ self.optimizer["optimizer_d"].step()
+
+ # Train Generator
+ p_hat = self.model["discriminator"](speech_hat_)
+ with torch.no_grad():
+ p = self.model["discriminator"](speech_)
+
+ outputs_d = (p_hat, p)
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)
+ train_losses.update(loss_g)
+
+ # BP and Grad Updated
+ self.optimizer["optimizer_g"].zero_grad()
+ self.accelerator.backward(loss_g["g_total_loss"])
+ self.optimizer["optimizer_g"].step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]
+
+ return (
+ total_loss.item(),
+ train_losses,
+ training_stats,
+ )
+
+ @torch.inference_mode()
+ def _valid_step(self, batch):
+ valid_losses = {}
+ total_loss = 0
+ valid_stats = {}
+
+ # Discriminator
+ # Generator output
+ outputs_g = self.model["generator"](batch)
+ speech_hat_, _, _, start_idxs, *_ = outputs_g
+
+ # Discriminator output
+ speech = batch["audio"].unsqueeze(1)
+ upsample_factor = self.cfg.train.upsample_factor
+ speech_ = get_segments(
+ x=speech,
+ start_idxs=start_idxs * upsample_factor,
+ segment_size=self.cfg.train.segment_size * upsample_factor,
+ )
+ p_hat = self.model["discriminator"](speech_hat_.detach())
+ p = self.model["discriminator"](speech_)
+
+ # Discriminator loss
+ loss_d = self.criterion["discriminator"](p, p_hat)
+ valid_losses.update(loss_d)
+
+ # Generator loss
+ p_hat = self.model["discriminator"](speech_hat_)
+ with torch.no_grad():
+ p = self.model["discriminator"](speech_)
+
+ outputs_d = (p_hat, p)
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)
+ valid_losses.update(loss_g)
+
+ for item in valid_losses:
+ valid_losses[item] = valid_losses[item].item()
+
+ total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]
+
+ return (
+ total_loss.item(),
+ valid_losses,
+ valid_stats,
+ )
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].eval()
+ else:
+ self.model.eval()
+
+ epoch_sum_loss = 0.0
+ epoch_losses = dict()
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
+ epoch_sum_loss += total_loss
+ if isinstance(valid_losses, dict):
+ for key, value in valid_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ self.accelerator.log(
+ {
+ "Step/Valid {} Loss".format(key): value,
+ },
+ step=self.step,
+ )
+
+ epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
+ for key in epoch_losses.keys():
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
+
+ self.accelerator.wait_for_everyone()
+
+ return epoch_sum_loss, epoch_losses
diff --git a/models/tts/jets/length_regulator.py b/models/tts/jets/length_regulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce30f331d23afc07dc348f9a30aa118d370d9037
--- /dev/null
+++ b/models/tts/jets/length_regulator.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Amphion.
+#
+# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/length_regulator.py
+# Licensed under Apache License 2.0
+
+import torch
+
+
+class GaussianUpsampling(torch.nn.Module):
+ """
+ Gaussian upsampling with fixed temperature as in:
+ https://arxiv.org/abs/2010.04301
+ """
+
+ def __init__(self, delta=0.1):
+ super().__init__()
+ self.delta = delta
+
+ def forward(self, hs, ds, h_masks=None, d_masks=None):
+ """
+ Args:
+ hs (Tensor): Batched hidden state to be expanded (B, T_text, adim)
+ ds (Tensor): Batched token duration (B, T_text)
+ h_masks (Tensor): Mask tensor (B,T_feats)
+ d_masks (Tensor): Mask tensor (B,T_text)
+ Returns:
+ Tensor: Expanded hidden state (B, T_feat, adim)
+ """
+ B = ds.size(0)
+ device = ds.device
+
+ if h_masks is None:
+ T_feats = ds.sum().int()
+ else:
+ T_feats = h_masks.size(-1)
+ t = torch.arange(0, T_feats).unsqueeze(0).repeat(B, 1).to(device).float()
+ if h_masks is not None:
+ t = t * h_masks.float()
+
+ c = ds.cumsum(dim=-1) - ds / 2
+ energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2
+ if d_masks is not None:
+ energy = energy.masked_fill(
+ ~(d_masks.unsqueeze(1).repeat(1, T_feats, 1)), -float("inf")
+ )
+
+ p_attn = torch.softmax(energy, dim=2) # (B, T_feats, T_text)
+ hs = torch.matmul(p_attn, hs)
+ return hs
diff --git a/models/tts/maskgct/.DS_Store b/models/tts/maskgct/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..e7741526bf2e85d0c7b083184a8420ef6396a7e9
Binary files /dev/null and b/models/tts/maskgct/.DS_Store differ
diff --git a/models/tts/maskgct/README.md b/models/tts/maskgct/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..696fe7ce75a3b59710c82273be697ae659650ff7
--- /dev/null
+++ b/models/tts/maskgct/README.md
@@ -0,0 +1,183 @@
+## MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer
+
+[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750)
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct)
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct)
+[![readme](https://img.shields.io/badge/README-Key%20Features-blue)](../../../models/tts/maskgct/README.md)
+
+## Overview
+
+MaskGCT (**Mask**ed **G**enerative **C**odec **T**ransformer) is *a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision, as well as phone-level duration prediction*. MaskGCT is a two-stage model: in the first stage, the model uses text to predict semantic tokens extracted from a speech self-supervised learning (SSL) model, and in the second stage, the model predicts acoustic tokens conditioned on these semantic tokens. MaskGCT follows the *mask-and-predict* learning paradigm. During training, MaskGCT learns to predict masked semantic or acoustic tokens based on given conditions and prompts. During inference, the model generates tokens of a specified length in a parallel manner. Experiments with 100K hours of in-the-wild speech demonstrate that MaskGCT outperforms the current state-of-the-art zero-shot TTS systems in terms of quality, similarity, and intelligibility. Audio samples are available at [demo page](https://maskgct.github.io/).
+
+
+
+
+
+
+
+## News
+
+- **2024/10/19**: We release **MaskGCT**, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision. MaskGCT is trained on Emilia dataset and achieves SOTA zero-shot TTS perfermance.
+
+## Quickstart
+
+**Clone and install**
+
+```bash
+git clone https://github.com/open-mmlab/Amphion.git
+# create env
+bash ./models/tts/maskgct/env.sh
+```
+
+**Model download**
+
+We provide the following pretrained checkpoints:
+
+
+| Model Name | Description |
+|-------------------|-------------|
+| [Acoustic Codec](https://huggingface.co/amphion/MaskGCT/tree/main/acoustic_codec) | Converting speech to semantic tokens. |
+| [Semantic Codec](https://huggingface.co/amphion/MaskGCT/tree/main/semantic_codec) | Converting speech to acoustic tokens and reconstructing waveform from acoustic tokens. |
+| [MaskGCT-T2S](https://huggingface.co/amphion/MaskGCT/tree/main/t2s_model) | Predicting semantic tokens with text and prompt semantic tokens. |
+| [MaskGCT-S2A](https://huggingface.co/amphion/MaskGCT/tree/main/s2a_model) | Predicts acoustic tokens conditioned on semantic tokens. |
+
+You can download all pretrained checkpoints from [HuggingFace](https://huggingface.co/amphion/MaskGCT/tree/main) or use huggingface api.
+
+```python
+from huggingface_hub import hf_hub_download
+
+# download semantic codec ckpt
+semantic_code_ckpt = hf_hub_download("amphion/MaskGCT" filename="semantic_codec/model.safetensors")
+
+# download acoustic codec ckpt
+codec_encoder_ckpt = hf_hub_download("amphion/MaskGCT", filename="acoustic_codec/model.safetensors")
+codec_decoder_ckpt = hf_hub_download("amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors")
+
+# download t2s model ckpt
+t2s_model_ckpt = hf_hub_download("amphion/MaskGCT", filename="t2s_model/model.safetensors")
+
+# download s2a model ckpt
+s2a_1layer_ckpt = hf_hub_download("amphion/MaskGCT", filename="s2a_model/s2a_model_1layer/model.safetensors")
+s2a_full_ckpt = hf_hub_download("amphion/MaskGCT", filename="s2a_model/s2a_model_full/model.safetensors")
+```
+
+**Basic Usage**
+
+You can use the following code to generate speech from text and a prompt speech (the code is also provided in [inference.py](../../../models/tts/maskgct/maskgct_inference.py)).
+
+```python
+from models.tts.maskgct.maskgct_utils import *
+from huggingface_hub import hf_hub_download
+import safetensors
+import soundfile as sf
+
+if __name__ == "__main__":
+
+ # build model
+ device = torch.device("cuda:0")
+ cfg_path = "./models/tts/maskgct/config/maskgct.json"
+ cfg = load_config(cfg_path)
+ # 1. build semantic model (w2v-bert-2.0)
+ semantic_model, semantic_mean, semantic_std = build_semantic_model(device)
+ # 2. build semantic codec
+ semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)
+ # 3. build acoustic codec
+ codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)
+ # 4. build t2s model
+ t2s_model = build_t2s_model(cfg.model.t2s_model, device)
+ # 5. build s2a model
+ s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
+ s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)
+
+ # download checkpoint
+ ...
+
+ # load semantic codec
+ safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
+ # load acoustic codec
+ safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
+ safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
+ # load t2s model
+ safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
+ # load s2a model
+ safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
+ safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)
+
+ # inference
+ prompt_wav_path = "./models/tts/maskgct/wav/prompt.wav"
+ save_path = "[YOUR SAVE PATH]"
+ prompt_text = " We do not break. We never give in. We never back down."
+ target_text = "In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision."
+ # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.
+ target_len = 18
+
+ maskgct_inference_pipeline = MaskGCT_Inference_Pipeline(
+ semantic_model,
+ semantic_codec,
+ codec_encoder,
+ codec_decoder,
+ t2s_model,
+ s2a_model_1layer,
+ s2a_model_full,
+ semantic_mean,
+ semantic_std,
+ device,
+ )
+
+ recovered_audio = maskgct_inference_pipeline.maskgct_inference(
+ prompt_wav_path, prompt_text, target_text, "en", "en", target_len=target_len
+ )
+ sf.write(save_path, recovered_audio, 24000)
+```
+
+**Jupyter Notebook**
+
+We also provide a [jupyter notebook](../../../models/tts/maskgct/maskgct_demo.ipynb) to show more details of MaskGCT inference.
+
+
+## Evaluation Results of MaskGCT
+
+| System | SIM-O↑ | WER↓ | FSD↓ | SMOS↑ | CMOS↑ |
+| :--- | :---: | :---: | :---: | :---: | :---: |
+| | | **LibriSpeech test-clean** |
+| Ground Truth | 0.68 | 1.94 | | 4.05±0.12 | 0.00 |
+| VALL-E | 0.50 | 5.90 | - | 3.47 ±0.26 | -0.52±0.22 |
+| VoiceBox | 0.64 | 2.03 | 0.762 | 3.80±0.17 | -0.41±0.13 |
+| NaturalSpeech 3 | 0.67 | 1.94 | 0.786 | 4.26±0.10 | 0.16±0.14 |
+| VoiceCraft | 0.45 | 4.68 | 0.981 | 3.52±0.21 | -0.33 ±0.16 |
+| XTTS-v2 | 0.51 | 4.20 | 0.945 | 3.02±0.22 | -0.98 ±0.19 |
+| MaskGCT | 0.687(0.723) | 2.634(1.976) | 0.886 | 4.27±0.14 | 0.10±0.16 |
+| MaskGCT(gt length) | 0.697 | 2.012 | 0.746 | 4.33±0.11 | 0.13±0.13 |
+| | | **SeedTTS test-en** |
+| Ground Truth | 0.730 | 2.143 | | 3.92±0.15 | 0.00 |
+| CosyVoice | 0.643 | 4.079 | 0.316 | 3.52±0.17 | -0.41 ±0.18 |
+| XTTS-v2 | 0.463 | 3.248 | 0.484 | 3.15±0.22 | -0.86±0.19 |
+| VoiceCraft | 0.470 | 7.556 | 0.226 | 3.18±0.20 | -1.08 ±0.15 |
+| MaskGCT | 0.717(0.760) | 2.623(1.283) | 0.188 | 4.24 ±0.12 | 0.03 ±0.14 |
+| MaskGCT(gt length) | 0.728 | 2.466 | 0.159 | 4.13 ±0.17 | 0.12 ±0.15 |
+| | | **SeedTTS test-zh** |
+| Ground Truth | 0.750 | 1.254 | | 3.86 ±0.17 | 0.00 |
+| CosyVoice | 0.750 | 4.089 | 0.276 | 3.54 ±0.12 | -0.45 ±0.15 |
+| XTTS-v2 | 0.635 | 2.876 | 0.413 | 2.95 ±0.18 | -0.81 ±0.22 |
+| MaskGCT | 0.774(0.805) | 2.273(0.843) | 0.106 | 4.09 ±0.12 | 0.05 ±0.17 |
+| MaskGCT(gt length) | 0.777 | 2.183 | 0.101 | 4.11 ±0.12 | 0.08±0.18 |
+
+## Citations
+
+If you use MaskGCT in your research, please cite the following paper:
+
+```bibtex
+@article{wang2024maskgct,
+ title={MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer},
+ author={Wang, Yuancheng and Zhan, Haoyue and Liu, Liwei and Zeng, Ruihong and Guo, Haotian and Zheng, Jiachen and Zhang, Qiang and Zhang, Shunsi and Wu, Zhizheng},
+ journal={arXiv preprint arXiv:2409.00750},
+ year={2024}
+}
+
+@article{zhang2023amphion,
+ title={Amphion: An open-source audio, music and speech generation toolkit},
+ author={Zhang, Xueyao and Xue, Liumeng and Wang, Yuancheng and Gu, Yicheng and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zou, Lexiao and Wang, Chaoren and Han, Jun and others},
+ journal={arXiv preprint arXiv:2312.09911},
+ year={2023}
+}
+```
\ No newline at end of file
diff --git a/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt b/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt
new file mode 100644
index 0000000000000000000000000000000000000000..1a8ecb924668659f3b1a9c35b02b2f8839fd8c5a
--- /dev/null
+++ b/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c9c176c2b8850ab2e3ba828bbfa969deaf4566ce55db5f2687b8430b87526ad2
+size 9343
diff --git a/models/tts/maskgct/config/maskgct.json b/models/tts/maskgct/config/maskgct.json
new file mode 100644
index 0000000000000000000000000000000000000000..b4cf5d2caed6c5620d465e8d45cc2214871995ac
--- /dev/null
+++ b/models/tts/maskgct/config/maskgct.json
@@ -0,0 +1,81 @@
+{
+ "model": {
+ "t2s_model": {
+ "hidden_size": 1536,
+ "num_layers": 16,
+ "num_heads": 16,
+ "cfg_scale": 0.15,
+ "cond_codebook_size": 8192,
+ "cond_dim": 1024
+ },
+ "s2a_model": {
+ "s2a_1layer": {
+ "num_quantizer": 1,
+ "hidden_size": 1024,
+ "num_layers": 16,
+ "num_heads": 16,
+ "codebook_size": 1024,
+ "cfg_scale": 0.15,
+ "mask_layer_schedule": "linear",
+ "cond_codebook_size": 8192,
+ "cond_dim": 1024,
+ "predict_layer_1": true
+ },
+ "s2a_full": {
+ "num_quantizer": 12,
+ "hidden_size": 1024,
+ "num_layers": 16,
+ "num_heads": 16,
+ "codebook_size": 1024,
+ "cfg_scale": 0.15,
+ "mask_layer_schedule": "linear",
+ "cond_codebook_size": 8192,
+ "cond_dim": 1024,
+ "predict_layer_1": false
+ }
+ },
+ "semantic_codec": {
+ "codebook_size": 8192,
+ "hidden_size": 1024,
+ "codebook_dim": 8,
+ "vocos_dim": 384,
+ "vocos_intermediate_dim": 2048,
+ "vocos_num_layers": 12
+ },
+ "acoustic_codec": {
+ "encoder": {
+ "d_model": 96,
+ "up_ratios": [3, 4, 5, 8],
+ "out_channels": 256,
+ "use_tanh": false
+ },
+ "decoder": {
+ "in_channel": 256,
+ "upsample_initial_channel": 1536,
+ "up_ratios": [8, 5, 4, 3],
+ "num_quantizers": 12,
+ "codebook_size": 1024,
+ "codebook_dim": 8,
+ "quantizer_type": "fvq",
+ "quantizer_dropout": 0.5,
+ "commitment": 0.25,
+ "codebook_loss_weight": 1.0,
+ "use_l2_normlize": true,
+ "codebook_type": "euclidean",
+ "kmeans_init": false,
+ "kmeans_iters": 10,
+ "decay": 0.8,
+ "eps": 0.5,
+ "threshold_ema_dead_code": 2,
+ "weight_init": false,
+ "use_vocos": true,
+ "vocos_dim": 512,
+ "vocos_intermediate_dim": 4096,
+ "vocos_num_layers": 30,
+ "n_fft": 1920,
+ "hop_size": 480,
+ "padding": "same"
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/models/tts/maskgct/env.sh b/models/tts/maskgct/env.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ed59567365ac7ccee7b678a578683c846a646e5b
--- /dev/null
+++ b/models/tts/maskgct/env.sh
@@ -0,0 +1,25 @@
+pip install setuptools ruamel.yaml tqdm
+pip install tensorboard tensorboardX torch==2.0.1
+pip install transformers===4.41.1
+pip install -U encodec
+pip install black==24.1.1
+pip install oss2
+sudo apt-get install espeak-ng
+pip install phonemizer
+pip install g2p_en
+pip install accelerate==0.31.0
+pip install funasr zhconv zhon modelscope
+# pip install git+https://github.com/lhotse-speech/lhotse
+pip install timm
+pip install jieba cn2an
+pip install unidecode
+pip install -U cos-python-sdk-v5
+pip install pypinyin
+pip install jiwer
+pip install omegaconf
+pip install pyworld
+pip install py3langid==0.2.2 LangSegment
+pip install onnxruntime
+pip install pyopenjtalk
+pip install pykakasi
+pip install -U openai-whisper
\ No newline at end of file
diff --git a/models/tts/maskgct/g2p/g2p/__init__.py b/models/tts/maskgct/g2p/g2p/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca96b67e85e1e242768c94f992bd9d5fa3a3e9ef
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/__init__.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from models.tts.maskgct.g2p.g2p import cleaners
+from tokenizers import Tokenizer
+from models.tts.maskgct.g2p.g2p.text_tokenizers import TextTokenizer
+import LangSegment
+import json
+import re
+
+
+class PhonemeBpeTokenizer:
+
+ def __init__(self, vacab_path="./models/tts/maskgct/g2p/g2p/vocab.json"):
+ self.lang2backend = {
+ "zh": "cmn",
+ "ja": "ja",
+ "en": "en-us",
+ "fr": "fr-fr",
+ "ko": "ko",
+ "de": "de",
+ }
+ self.text_tokenizers = {}
+ self.int_text_tokenizers()
+
+ with open(vacab_path, "r") as f:
+ json_data = f.read()
+ data = json.loads(json_data)
+ self.vocab = data["vocab"]
+ LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])
+
+ def int_text_tokenizers(self):
+ for key, value in self.lang2backend.items():
+ self.text_tokenizers[key] = TextTokenizer(language=value)
+
+ def tokenize(self, text, sentence, language):
+
+ # 1. convert text to phoneme
+ phonemes = []
+ if language == "auto":
+ seglist = LangSegment.getTexts(text)
+ tmp_ph = []
+ for seg in seglist:
+ tmp_ph.append(
+ self._clean_text(
+ seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
+ )
+ )
+ phonemes = "|_|".join(tmp_ph)
+ else:
+ phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
+ # print('clean text: ', phonemes)
+
+ # 2. tokenize phonemes
+ phoneme_tokens = self.phoneme2token(phonemes)
+ # print('encode: ', phoneme_tokens)
+
+ # # 3. decode tokens [optional]
+ # decoded_text = self.tokenizer.decode(phoneme_tokens)
+ # print('decoded: ', decoded_text)
+
+ return phonemes, phoneme_tokens
+
+ def _clean_text(self, text, sentence, language, cleaner_names):
+ for name in cleaner_names:
+ cleaner = getattr(cleaners, name)
+ if not cleaner:
+ raise Exception("Unknown cleaner: %s" % name)
+ text = cleaner(text, sentence, language, self.text_tokenizers)
+ return text
+
+ def phoneme2token(self, phonemes):
+ tokens = []
+ if isinstance(phonemes, list):
+ for phone in phonemes:
+ phone = phone.split("\t")[0]
+ phonemes_split = phone.split("|")
+ tokens.append(
+ [self.vocab[p] for p in phonemes_split if p in self.vocab]
+ )
+ else:
+ phonemes = phonemes.split("\t")[0]
+ phonemes_split = phonemes.split("|")
+ tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
+ return tokens
diff --git a/models/tts/maskgct/g2p/g2p/chinese_model_g2p.py b/models/tts/maskgct/g2p/g2p/chinese_model_g2p.py
new file mode 100644
index 0000000000000000000000000000000000000000..862ebac671818731d98c7df79dfbf079172fa38d
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/chinese_model_g2p.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+import json
+from transformers import BertTokenizer
+from torch.utils.data import Dataset
+from transformers.models.bert.modeling_bert import *
+import torch.nn.functional as F
+from onnxruntime import InferenceSession, GraphOptimizationLevel, SessionOptions
+
+
+class PolyDataset(Dataset):
+ def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
+ self.dataset = self.preprocess(words, labels)
+ self.word_pad_idx = word_pad_idx
+ self.label_pad_idx = label_pad_idx
+
+ def preprocess(self, origin_sentences, origin_labels):
+ """
+ Maps tokens and tags to their indices and stores them in the dict data.
+ examples:
+ word:['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
+ sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
+ array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
+ label:[3, 13, 13, 13, 0, 0, 0, 0, 0]
+ """
+ data = []
+ labels = []
+ sentences = []
+ # tokenize
+ for line in origin_sentences:
+ # replace each token by its index
+ # we can not use encode_plus because our sentences are aligned to labels in list type
+ words = []
+ word_lens = []
+ for token in line:
+ words.append(token)
+ word_lens.append(1)
+ token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
+ sentences.append(((words, token_start_idxs), 0))
+ ###
+ for tag in origin_labels:
+ labels.append(tag)
+
+ for sentence, label in zip(sentences, labels):
+ data.append((sentence, label))
+ return data
+
+ def __getitem__(self, idx):
+ """sample data to get batch"""
+ word = self.dataset[idx][0]
+ label = self.dataset[idx][1]
+ return [word, label]
+
+ def __len__(self):
+ """get dataset size"""
+ return len(self.dataset)
+
+ def collate_fn(self, batch):
+
+ sentences = [x[0][0] for x in batch]
+ ori_sents = [x[0][1] for x in batch]
+ labels = [x[1] for x in batch]
+ batch_len = len(sentences)
+
+ # compute length of longest sentence in batch
+ max_len = max([len(s[0]) for s in sentences])
+ max_label_len = 0
+ batch_data = np.ones((batch_len, max_len))
+ batch_label_starts = []
+
+ # padding and aligning
+ for j in range(batch_len):
+ cur_len = len(sentences[j][0])
+ batch_data[j][:cur_len] = sentences[j][0]
+ label_start_idx = sentences[j][-1]
+ label_starts = np.zeros(max_len)
+ label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
+ batch_label_starts.append(label_starts)
+ max_label_len = max(int(sum(label_starts)), max_label_len)
+
+ # padding label
+ batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
+ batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
+ for j in range(batch_len):
+ cur_tags_len = len(labels[j])
+ batch_labels[j][:cur_tags_len] = labels[j]
+ batch_pmasks[j][:cur_tags_len] = [
+ 1 if item > 0 else 0 for item in labels[j]
+ ]
+
+ # convert data to torch LongTensors
+ batch_data = torch.tensor(batch_data, dtype=torch.long)
+ batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
+ batch_labels = torch.tensor(batch_labels, dtype=torch.long)
+ batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
+ return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
+
+
+class BertPolyPredict:
+ def __init__(self, bert_model, jsonr_file, json_file):
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
+ with open(jsonr_file, "r", encoding="utf8") as fp:
+ self.pron_dict = json.load(fp)
+ with open(json_file, "r", encoding="utf8") as fp:
+ self.pron_dict_id_2_pinyin = json.load(fp)
+ self.num_polyphone = len(self.pron_dict)
+ self.device = "cpu"
+ self.polydataset = PolyDataset
+ options = SessionOptions() # initialize session options
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+ print(os.path.join(bert_model, "poly_bert_model.onnx"))
+ self.session = InferenceSession(
+ os.path.join(bert_model, "poly_bert_model.onnx"),
+ sess_options=options,
+ providers=[
+ "CUDAExecutionProvider",
+ "CPUExecutionProvider",
+ ], # CPUExecutionProvider #CUDAExecutionProvider
+ )
+ # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])
+
+ # disable session.run() fallback mechanism, it prevents for a reset of the execution provider
+ self.session.disable_fallback()
+
+ def predict_process(self, txt_list):
+ word_test, label_test, texts_test = self.get_examples_po(txt_list)
+ data = self.polydataset(word_test, label_test)
+ predict_loader = DataLoader(
+ data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
+ )
+ pred_tags = self.predict_onnx(predict_loader)
+ return pred_tags
+
+ def predict_onnx(self, dev_loader):
+ pred_tags = []
+ with torch.no_grad():
+ for idx, batch_samples in enumerate(dev_loader):
+ # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
+ batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
+ batch_samples
+ )
+ # shift tensors to GPU if available
+ batch_data = batch_data.to(self.device)
+ batch_label_starts = batch_label_starts.to(self.device)
+ batch_labels = batch_labels.to(self.device)
+ batch_pmasks = batch_pmasks.to(self.device)
+ batch_data = np.asarray(batch_data, dtype=np.int32)
+ batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
+ # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids":batch_data, "input_pmasks": batch_pmasks})[0][0]
+ batch_output = self.session.run(
+ output_names=["outputs"], input_feed={"input_ids": batch_data}
+ )[0]
+ label_masks = batch_pmasks == 1
+ batch_labels = batch_labels.to("cpu").numpy()
+ for i, indices in enumerate(np.argmax(batch_output, axis=2)):
+ for j, idx in enumerate(indices):
+ if label_masks[i][j]:
+ # pred_tag.append(idx)
+ pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
+ return pred_tags
+
+ def get_examples_po(self, text_list):
+
+ word_list = []
+ label_list = []
+ sentence_list = []
+ id = 0
+ for line in [text_list]:
+ sentence = line[0]
+ words = []
+ tokens = line[0]
+ index = line[-1]
+ front = index
+ back = len(tokens) - index - 1
+ labels = [0] * front + [1] + [0] * back
+ words = ["[CLS]"] + [item for item in sentence]
+ words = self.tokenizer.convert_tokens_to_ids(words)
+ word_list.append(words)
+ label_list.append(labels)
+ sentence_list.append(sentence)
+
+ id += 1
+ # mask_list.append(masks)
+ assert len(labels) + 1 == len(words), print(
+ (
+ poly,
+ sentence,
+ words,
+ labels,
+ sentence,
+ len(sentence),
+ len(words),
+ len(labels),
+ )
+ )
+ assert len(labels) + 1 == len(
+ words
+ ), "Number of labels does not match number of words"
+ assert len(labels) == len(
+ sentence
+ ), "Number of labels does not match number of sentences"
+ assert len(word_list) == len(
+ label_list
+ ), "Number of label sentences does not match number of word sentences"
+ return word_list, label_list, text_list
diff --git a/models/tts/maskgct/g2p/g2p/cleaners.py b/models/tts/maskgct/g2p/g2p/cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..b25aabf7651902471c3caf00a012366ea649b90c
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/cleaners.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+from models.tts.maskgct.g2p.g2p.japanese import japanese_to_ipa
+from models.tts.maskgct.g2p.g2p.mandarin import chinese_to_ipa
+from models.tts.maskgct.g2p.g2p.english import english_to_ipa
+from models.tts.maskgct.g2p.g2p.french import french_to_ipa
+from models.tts.maskgct.g2p.g2p.korean import korean_to_ipa
+from models.tts.maskgct.g2p.g2p.german import german_to_ipa
+
+
+def cjekfd_cleaners(text, sentence, language, text_tokenizers):
+
+ if language == "zh":
+ return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
+ elif language == "ja":
+ return japanese_to_ipa(text, text_tokenizers["ja"])
+ elif language == "en":
+ return english_to_ipa(text, text_tokenizers["en"])
+ elif language == "fr":
+ return french_to_ipa(text, text_tokenizers["fr"])
+ elif language == "ko":
+ return korean_to_ipa(text, text_tokenizers["ko"])
+ elif language == "de":
+ return german_to_ipa(text, text_tokenizers["de"])
+ else:
+ raise Exception("Unknown language: %s" % language)
+ return None
diff --git a/models/tts/maskgct/g2p/g2p/english.py b/models/tts/maskgct/g2p/g2p/english.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f349fd621ba3d6aa110f447238249642d80326
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/english.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+from unidecode import unidecode
+import inflect
+
+"""
+ Text clean time
+"""
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
+ for x in [
+ ("mrs", "misess"),
+ ("mr", "mister"),
+ ("dr", "doctor"),
+ ("st", "saint"),
+ ("co", "company"),
+ ("jr", "junior"),
+ ("maj", "major"),
+ ("gen", "general"),
+ ("drs", "doctors"),
+ ("rev", "reverend"),
+ ("lt", "lieutenant"),
+ ("hon", "honorable"),
+ ("sgt", "sergeant"),
+ ("capt", "captain"),
+ ("esq", "esquire"),
+ ("ltd", "limited"),
+ ("col", "colonel"),
+ ("ft", "fort"),
+ ("etc", "et cetera"),
+ ("btw", "by the way"),
+ ]
+]
+
+_special_map = [
+ ("t|ɹ", "tɹ"),
+ ("d|ɹ", "dɹ"),
+ ("t|s", "ts"),
+ ("d|z", "dz"),
+ ("ɪ|ɹ", "ɪɹ"),
+ ("ɐ", "ɚ"),
+ ("ᵻ", "ɪ"),
+ ("əl", "l"),
+ ("x", "k"),
+ ("ɬ", "l"),
+ ("ʔ", "t"),
+ ("n̩", "n"),
+ ("oː|ɹ", "oːɹ"),
+]
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def _remove_commas(m):
+ return m.group(1).replace(",", "")
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace(".", " point ")
+
+
+def _expand_percent(m):
+ return m.group(1).replace("%", " percent ")
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split(".")
+ if len(parts) > 2:
+ return " " + match + " dollars " # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ cent_unit = "cent" if cents == 1 else "cents"
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ return " %s %s " % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = "cent" if cents == 1 else "cents"
+ return " %s %s " % (cents, cent_unit)
+ else:
+ return " zero dollars "
+
+
+def fraction_to_words(numerator, denominator):
+ if numerator == 1 and denominator == 2:
+ return " one half "
+ if numerator == 1 and denominator == 4:
+ return " one quarter "
+ if denominator == 2:
+ return " " + _inflect.number_to_words(numerator) + " halves "
+ if denominator == 4:
+ return " " + _inflect.number_to_words(numerator) + " quarters "
+ return (
+ " "
+ + _inflect.number_to_words(numerator)
+ + " "
+ + _inflect.ordinal(_inflect.number_to_words(denominator))
+ + " "
+ )
+
+
+def _expand_fraction(m):
+ numerator = int(m.group(1))
+ denominator = int(m.group(2))
+ return fraction_to_words(numerator, denominator)
+
+
+def _expand_ordinal(m):
+ return " " + _inflect.number_to_words(m.group(0)) + " "
+
+
+def _expand_number(m):
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return " two thousand "
+ elif num > 2000 and num < 2010:
+ return " two thousand " + _inflect.number_to_words(num % 100) + " "
+ elif num % 100 == 0:
+ return " " + _inflect.number_to_words(num // 100) + " hundred "
+ else:
+ return (
+ " "
+ + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
+ ", ", " "
+ )
+ + " "
+ )
+ else:
+ return " " + _inflect.number_to_words(num, andword="") + " "
+
+
+# Normalize numbers pronunciation
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r"\1 pounds", text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_fraction_re, _expand_fraction, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_percent_number_re, _expand_percent, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
+
+
+def _english_to_ipa(text):
+ # text = unidecode(text).lower()
+ text = expand_abbreviations(text)
+ text = normalize_numbers(text)
+ return text
+
+
+# special map
+def special_map(text):
+ for regex, replacement in _special_map:
+ regex = regex.replace("|", "\|")
+ while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
+ text = re.sub(
+ r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
+ )
+ # text = re.sub(r'([,.!?])', r'|\1', text)
+ return text
+
+
+# Add some special operation
+def english_to_ipa(text, text_tokenizer):
+ if type(text) == str:
+ text = _english_to_ipa(text)
+ else:
+ text = [_english_to_ipa(t) for t in text]
+ phonemes = text_tokenizer(text)
+ if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
+ phonemes += "|_"
+ if type(text) == str:
+ return special_map(phonemes)
+ else:
+ result_ph = []
+ for phone in phonemes:
+ result_ph.append(special_map(phone))
+ return result_ph
diff --git a/models/tts/maskgct/g2p/g2p/french.py b/models/tts/maskgct/g2p/g2p/french.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd9400cdfc6598e7d642480cbfc1f990fc78cddf
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/french.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+
+"""
+ Text clean time
+"""
+# List of (regular expression, replacement) pairs for abbreviations in french:
+_abbreviations = [
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ for x in [
+ ("M", "monsieur"),
+ ("Mlle", "mademoiselle"),
+ ("Mlles", "mesdemoiselles"),
+ ("Mme", "Madame"),
+ ("Mmes", "Mesdames"),
+ ("N.B", "nota bene"),
+ ("M", "monsieur"),
+ ("p.c.q", "parce que"),
+ ("Pr", "professeur"),
+ ("qqch", "quelque chose"),
+ ("rdv", "rendez-vous"),
+ ("max", "maximum"),
+ ("min", "minimum"),
+ ("no", "numéro"),
+ ("adr", "adresse"),
+ ("dr", "docteur"),
+ ("st", "saint"),
+ ("co", "companie"),
+ ("jr", "junior"),
+ ("sgt", "sergent"),
+ ("capt", "capitain"),
+ ("col", "colonel"),
+ ("av", "avenue"),
+ ("av. J.-C", "avant Jésus-Christ"),
+ ("apr. J.-C", "après Jésus-Christ"),
+ ("art", "article"),
+ ("boul", "boulevard"),
+ ("c.-à-d", "c’est-à-dire"),
+ ("etc", "et cetera"),
+ ("ex", "exemple"),
+ ("excl", "exclusivement"),
+ ("boul", "boulevard"),
+ ]
+] + [
+ (re.compile("\\b%s" % x[0]), x[1])
+ for x in [
+ ("Mlle", "mademoiselle"),
+ ("Mlles", "mesdemoiselles"),
+ ("Mme", "Madame"),
+ ("Mmes", "Mesdames"),
+ ]
+]
+
+rep_map = {
+ ":": ",",
+ ";": ",",
+ ",": ",",
+ "。": ".",
+ "!": "!",
+ "?": "?",
+ "\n": ".",
+ "·": ",",
+ "、": ",",
+ "...": ".",
+ "…": ".",
+ "$": ".",
+ "“": "",
+ "”": "",
+ "‘": "",
+ "’": "",
+ "(": "",
+ ")": "",
+ "(": "",
+ ")": "",
+ "《": "",
+ "》": "",
+ "【": "",
+ "】": "",
+ "[": "",
+ "]": "",
+ "—": "",
+ "~": "-",
+ "~": "-",
+ "「": "",
+ "」": "",
+ "¿": "",
+ "¡": "",
+}
+
+
+def collapse_whitespace(text):
+ # Regular expression matching whitespace:
+ _whitespace_re = re.compile(r"\s+")
+ return re.sub(_whitespace_re, " ", text).strip()
+
+
+def remove_punctuation_at_begin(text):
+ return re.sub(r"^[,.!?]+", "", text)
+
+
+def remove_aux_symbols(text):
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
+ return text
+
+
+def replace_symbols(text):
+ text = text.replace(";", ",")
+ text = text.replace("-", " ")
+ text = text.replace(":", ",")
+ text = text.replace("&", " et ")
+ return text
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def replace_punctuation(text):
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+ return replaced_text
+
+
+def text_normalize(text):
+ text = expand_abbreviations(text)
+ text = replace_punctuation(text)
+ text = replace_symbols(text)
+ text = remove_aux_symbols(text)
+ text = remove_punctuation_at_begin(text)
+ text = collapse_whitespace(text)
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
+ return text
+
+
+def french_to_ipa(text, text_tokenizer):
+ if type(text) == str:
+ text = text_normalize(text)
+ phonemes = text_tokenizer(text)
+ return phonemes
+ else:
+ for i, t in enumerate(text):
+ text[i] = text_normalize(t)
+ return text_tokenizer(text)
diff --git a/models/tts/maskgct/g2p/g2p/german.py b/models/tts/maskgct/g2p/g2p/german.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd82eeabc44cc891acd98daa982cd2be1e991e3a
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/german.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+
+"""
+ Text clean time
+"""
+rep_map = {
+ ":": ",",
+ ";": ",",
+ ",": ",",
+ "。": ".",
+ "!": "!",
+ "?": "?",
+ "\n": ".",
+ "·": ",",
+ "、": ",",
+ "...": ".",
+ "…": ".",
+ "$": ".",
+ "“": "",
+ "”": "",
+ "‘": "",
+ "’": "",
+ "(": "",
+ ")": "",
+ "(": "",
+ ")": "",
+ "《": "",
+ "》": "",
+ "【": "",
+ "】": "",
+ "[": "",
+ "]": "",
+ "—": "",
+ "~": "-",
+ "~": "-",
+ "「": "",
+ "」": "",
+ "¿": "",
+ "¡": "",
+}
+
+
+def collapse_whitespace(text):
+ # Regular expression matching whitespace:
+ _whitespace_re = re.compile(r"\s+")
+ return re.sub(_whitespace_re, " ", text).strip()
+
+
+def remove_punctuation_at_begin(text):
+ return re.sub(r"^[,.!?]+", "", text)
+
+
+def remove_aux_symbols(text):
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
+ return text
+
+
+def replace_symbols(text):
+ text = text.replace(";", ",")
+ text = text.replace("-", " ")
+ text = text.replace(":", ",")
+ return text
+
+
+def replace_punctuation(text):
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+ return replaced_text
+
+
+def text_normalize(text):
+ text = replace_punctuation(text)
+ text = replace_symbols(text)
+ text = remove_aux_symbols(text)
+ text = remove_punctuation_at_begin(text)
+ text = collapse_whitespace(text)
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
+ return text
+
+
+def german_to_ipa(text, text_tokenizer):
+ if type(text) == str:
+ text = text_normalize(text)
+ phonemes = text_tokenizer(text)
+ return phonemes
+ else:
+ for i, t in enumerate(text):
+ text[i] = text_normalize(t)
+ return text_tokenizer(text)
diff --git a/models/tts/maskgct/g2p/g2p/japanese.py b/models/tts/maskgct/g2p/g2p/japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef50ce2761aea755b9d9bada637bd3f30f331d9
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/japanese.py
@@ -0,0 +1,816 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io, re, os, sys, time, argparse, pdb, json
+from io import StringIO
+from typing import Optional
+import numpy as np
+import traceback
+import pyopenjtalk
+from pykakasi import kakasi
+
+punctuation = [",", ".", "!", "?", ":", ";", "'", "…"]
+
+jp_xphone2ipa = [
+ " a a",
+ " i i",
+ " u ɯ",
+ " e e",
+ " o o",
+ " a: aː",
+ " i: iː",
+ " u: ɯː",
+ " e: eː",
+ " o: oː",
+ " k k",
+ " s s",
+ " t t",
+ " n n",
+ " h ç",
+ " f ɸ",
+ " m m",
+ " y j",
+ " r ɾ",
+ " w ɰᵝ",
+ " N ɴ",
+ " g g",
+ " j d ʑ",
+ " z z",
+ " d d",
+ " b b",
+ " p p",
+ " q q",
+ " v v",
+ " : :",
+ " by b j",
+ " ch t ɕ",
+ " dy d e j",
+ " ty t e j",
+ " gy g j",
+ " gw g ɯ",
+ " hy ç j",
+ " ky k j",
+ " kw k ɯ",
+ " my m j",
+ " ny n j",
+ " py p j",
+ " ry ɾ j",
+ " sh ɕ",
+ " ts t s ɯ",
+]
+
+_mora_list_minimum: list[tuple[str, Optional[str], str]] = [
+ ("ヴォ", "v", "o"),
+ ("ヴェ", "v", "e"),
+ ("ヴィ", "v", "i"),
+ ("ヴァ", "v", "a"),
+ ("ヴ", "v", "u"),
+ ("ン", None, "N"),
+ ("ワ", "w", "a"),
+ ("ロ", "r", "o"),
+ ("レ", "r", "e"),
+ ("ル", "r", "u"),
+ ("リョ", "ry", "o"),
+ ("リュ", "ry", "u"),
+ ("リャ", "ry", "a"),
+ ("リェ", "ry", "e"),
+ ("リ", "r", "i"),
+ ("ラ", "r", "a"),
+ ("ヨ", "y", "o"),
+ ("ユ", "y", "u"),
+ ("ヤ", "y", "a"),
+ ("モ", "m", "o"),
+ ("メ", "m", "e"),
+ ("ム", "m", "u"),
+ ("ミョ", "my", "o"),
+ ("ミュ", "my", "u"),
+ ("ミャ", "my", "a"),
+ ("ミェ", "my", "e"),
+ ("ミ", "m", "i"),
+ ("マ", "m", "a"),
+ ("ポ", "p", "o"),
+ ("ボ", "b", "o"),
+ ("ホ", "h", "o"),
+ ("ペ", "p", "e"),
+ ("ベ", "b", "e"),
+ ("ヘ", "h", "e"),
+ ("プ", "p", "u"),
+ ("ブ", "b", "u"),
+ ("フォ", "f", "o"),
+ ("フェ", "f", "e"),
+ ("フィ", "f", "i"),
+ ("ファ", "f", "a"),
+ ("フ", "f", "u"),
+ ("ピョ", "py", "o"),
+ ("ピュ", "py", "u"),
+ ("ピャ", "py", "a"),
+ ("ピェ", "py", "e"),
+ ("ピ", "p", "i"),
+ ("ビョ", "by", "o"),
+ ("ビュ", "by", "u"),
+ ("ビャ", "by", "a"),
+ ("ビェ", "by", "e"),
+ ("ビ", "b", "i"),
+ ("ヒョ", "hy", "o"),
+ ("ヒュ", "hy", "u"),
+ ("ヒャ", "hy", "a"),
+ ("ヒェ", "hy", "e"),
+ ("ヒ", "h", "i"),
+ ("パ", "p", "a"),
+ ("バ", "b", "a"),
+ ("ハ", "h", "a"),
+ ("ノ", "n", "o"),
+ ("ネ", "n", "e"),
+ ("ヌ", "n", "u"),
+ ("ニョ", "ny", "o"),
+ ("ニュ", "ny", "u"),
+ ("ニャ", "ny", "a"),
+ ("ニェ", "ny", "e"),
+ ("ニ", "n", "i"),
+ ("ナ", "n", "a"),
+ ("ドゥ", "d", "u"),
+ ("ド", "d", "o"),
+ ("トゥ", "t", "u"),
+ ("ト", "t", "o"),
+ ("デョ", "dy", "o"),
+ ("デュ", "dy", "u"),
+ ("デャ", "dy", "a"),
+ # ("デェ", "dy", "e"),
+ ("ディ", "d", "i"),
+ ("デ", "d", "e"),
+ ("テョ", "ty", "o"),
+ ("テュ", "ty", "u"),
+ ("テャ", "ty", "a"),
+ ("ティ", "t", "i"),
+ ("テ", "t", "e"),
+ ("ツォ", "ts", "o"),
+ ("ツェ", "ts", "e"),
+ ("ツィ", "ts", "i"),
+ ("ツァ", "ts", "a"),
+ ("ツ", "ts", "u"),
+ ("ッ", None, "q"), # 「cl」から「q」に変更
+ ("チョ", "ch", "o"),
+ ("チュ", "ch", "u"),
+ ("チャ", "ch", "a"),
+ ("チェ", "ch", "e"),
+ ("チ", "ch", "i"),
+ ("ダ", "d", "a"),
+ ("タ", "t", "a"),
+ ("ゾ", "z", "o"),
+ ("ソ", "s", "o"),
+ ("ゼ", "z", "e"),
+ ("セ", "s", "e"),
+ ("ズィ", "z", "i"),
+ ("ズ", "z", "u"),
+ ("スィ", "s", "i"),
+ ("ス", "s", "u"),
+ ("ジョ", "j", "o"),
+ ("ジュ", "j", "u"),
+ ("ジャ", "j", "a"),
+ ("ジェ", "j", "e"),
+ ("ジ", "j", "i"),
+ ("ショ", "sh", "o"),
+ ("シュ", "sh", "u"),
+ ("シャ", "sh", "a"),
+ ("シェ", "sh", "e"),
+ ("シ", "sh", "i"),
+ ("ザ", "z", "a"),
+ ("サ", "s", "a"),
+ ("ゴ", "g", "o"),
+ ("コ", "k", "o"),
+ ("ゲ", "g", "e"),
+ ("ケ", "k", "e"),
+ ("グヮ", "gw", "a"),
+ ("グ", "g", "u"),
+ ("クヮ", "kw", "a"),
+ ("ク", "k", "u"),
+ ("ギョ", "gy", "o"),
+ ("ギュ", "gy", "u"),
+ ("ギャ", "gy", "a"),
+ ("ギェ", "gy", "e"),
+ ("ギ", "g", "i"),
+ ("キョ", "ky", "o"),
+ ("キュ", "ky", "u"),
+ ("キャ", "ky", "a"),
+ ("キェ", "ky", "e"),
+ ("キ", "k", "i"),
+ ("ガ", "g", "a"),
+ ("カ", "k", "a"),
+ ("オ", None, "o"),
+ ("エ", None, "e"),
+ ("ウォ", "w", "o"),
+ ("ウェ", "w", "e"),
+ ("ウィ", "w", "i"),
+ ("ウ", None, "u"),
+ ("イェ", "y", "e"),
+ ("イ", None, "i"),
+ ("ア", None, "a"),
+]
+
+_mora_list_additional: list[tuple[str, Optional[str], str]] = [
+ ("ヴョ", "by", "o"),
+ ("ヴュ", "by", "u"),
+ ("ヴャ", "by", "a"),
+ ("ヲ", None, "o"),
+ ("ヱ", None, "e"),
+ ("ヰ", None, "i"),
+ ("ヮ", "w", "a"),
+ ("ョ", "y", "o"),
+ ("ュ", "y", "u"),
+ ("ヅ", "z", "u"),
+ ("ヂ", "j", "i"),
+ ("ヶ", "k", "e"),
+ ("ャ", "y", "a"),
+ ("ォ", None, "o"),
+ ("ェ", None, "e"),
+ ("ゥ", None, "u"),
+ ("ィ", None, "i"),
+ ("ァ", None, "a"),
+]
+
+# 例: "vo" -> "ヴォ", "a" -> "ア"
+mora_phonemes_to_mora_kata: dict[str, str] = {
+ (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
+}
+
+# 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a")
+mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
+ kana: (consonant, vowel)
+ for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
+}
+
+
+# 正規化で記号を変換するための辞書
+rep_map = {
+ ":": ":",
+ ";": ";",
+ ",": ",",
+ "。": ".",
+ "!": "!",
+ "?": "?",
+ "\n": ".",
+ ".": ".",
+ "⋯": "…",
+ "···": "…",
+ "・・・": "…",
+ "·": ",",
+ "・": ",",
+ "•": ",",
+ "、": ",",
+ "$": ".",
+ # "“": "'",
+ # "”": "'",
+ # '"': "'",
+ "‘": "'",
+ "’": "'",
+ # "(": "'",
+ # ")": "'",
+ # "(": "'",
+ # ")": "'",
+ # "《": "'",
+ # "》": "'",
+ # "【": "'",
+ # "】": "'",
+ # "[": "'",
+ # "]": "'",
+ # "——": "-",
+ # "−": "-",
+ # "-": "-",
+ # "『": "'",
+ # "』": "'",
+ # "〈": "'",
+ # "〉": "'",
+ # "«": "'",
+ # "»": "'",
+ # # "~": "-", # これは長音記号「ー」として扱うよう変更
+ # # "~": "-", # これは長音記号「ー」として扱うよう変更
+ # "「": "'",
+ # "」": "'",
+}
+
+
+def _numeric_feature_by_regex(regex, s):
+ match = re.search(regex, s)
+ if match is None:
+ return -50
+ return int(match.group(1))
+
+
+def replace_punctuation(text: str) -> str:
+ """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す:
+ 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字
+ """
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+ # print("before: ", text)
+ # 句読点を辞書で置換
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+
+ replaced_text = re.sub(
+ # ↓ ひらがな、カタカナ、漢字
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
+ # ↓ 半角アルファベット(大文字と小文字)
+ + r"\u0041-\u005A\u0061-\u007A"
+ # ↓ 全角アルファベット(大文字と小文字)
+ + r"\uFF21-\uFF3A\uFF41-\uFF5A"
+ # ↓ ギリシャ文字
+ + r"\u0370-\u03FF\u1F00-\u1FFF"
+ # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている
+ + "".join(punctuation) + r"]+",
+ # 上述以外の文字を削除
+ "",
+ replaced_text,
+ )
+ # print("after: ", replaced_text)
+ return replaced_text
+
+
+def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
+ """
+ `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。
+ 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)]
+ """
+ tone_values = set(tone for _, tone in phone_tone_list)
+ if len(tone_values) == 1:
+ assert tone_values == {0}, tone_values
+ return phone_tone_list
+ elif len(tone_values) == 2:
+ if tone_values == {0, 1}:
+ return phone_tone_list
+ elif tone_values == {-1, 0}:
+ return [
+ (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
+ ]
+ else:
+ raise ValueError(f"Unexpected tone values: {tone_values}")
+ else:
+ raise ValueError(f"Unexpected tone values: {tone_values}")
+
+
+def fix_phone_tone_wplen(phone_tone_list, word_phone_length_list):
+ phones = []
+ tones = []
+ w_p_len = []
+ p_len = len(phone_tone_list)
+ idx = 0
+ w_idx = 0
+ while idx < p_len:
+ offset = 0
+ if phone_tone_list[idx] == "▁":
+ w_p_len.append(w_idx + 1)
+
+ curr_w_p_len = word_phone_length_list[w_idx]
+ for i in range(curr_w_p_len):
+ p, t = phone_tone_list[idx]
+ if p == ":" and len(phones) > 0:
+ if phones[-1][-1] != ":":
+ phones[-1] += ":"
+ offset -= 1
+ else:
+ phones.append(p)
+ tones.append(str(t))
+ idx += 1
+ if idx >= p_len:
+ break
+ w_p_len.append(curr_w_p_len + offset)
+ w_idx += 1
+ # print(w_p_len)
+ return phones, tones, w_p_len
+
+
+def g2phone_tone_wo_punct(prosodies) -> list[tuple[str, int]]:
+ """
+ テキストに対して、音素とアクセント(0か1)のペアのリストを返す。
+ ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。
+ 非音素記号を含める処理は`align_tones()`で行われる。
+ また「っ」は「cl」でなく「q」に変換される(「ん」は「N」のまま)。
+ 例: "こんにちは、世界ー。。元気?!" →
+ [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
+ """
+ result: list[tuple[str, int]] = []
+ current_phrase: list[tuple[str, int]] = []
+ current_tone = 0
+ last_accent = ""
+ for i, letter in enumerate(prosodies):
+ # 特殊記号の処理
+
+ # 文頭記号、無視する
+ if letter == "^":
+ assert i == 0, "Unexpected ^"
+ # アクセント句の終わりに来る記号
+ elif letter in ("$", "?", "_", "#"):
+ # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加
+ result.extend(fix_phone_tone(current_phrase))
+ # 末尾に来る終了記号、無視(文中の疑問文は`_`になる)
+ if letter in ("$", "?"):
+ assert i == len(prosodies) - 1, f"Unexpected {letter}"
+ # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ
+ # これらは残さず、次のアクセント句に備える。
+
+ current_phrase = []
+ # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る)
+ current_tone = 0
+ last_accent = ""
+ # アクセント上昇記号
+ elif letter == "[":
+ if last_accent != letter:
+ current_tone = current_tone + 1
+ last_accent = letter
+ # アクセント下降記号
+ elif letter == "]":
+ if last_accent != letter:
+ current_tone = current_tone - 1
+ last_accent = letter
+ # それ以外は通常の音素
+ else:
+ if letter == "cl": # 「っ」の処理
+ letter = "q"
+ current_phrase.append((letter, current_tone))
+ return result
+
+
+def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
+ for i in range(len(sep_phonemes)):
+ if sep_phonemes[i][0] == "ー":
+ # sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
+ sep_phonemes[i][0] = ":"
+ if "ー" in sep_phonemes[i]:
+ for j in range(len(sep_phonemes[i])):
+ if sep_phonemes[i][j] == "ー":
+ # sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
+ sep_phonemes[i][j] = ":"
+ return sep_phonemes
+
+
+def handle_long_word(sep_phonemes: list[list[str]]) -> list[list[str]]:
+ res = []
+ for i in range(len(sep_phonemes)):
+ if sep_phonemes[i][0] == "ー":
+ sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
+ # sep_phonemes[i][0] = ':'
+ if "ー" in sep_phonemes[i]:
+ for j in range(len(sep_phonemes[i])):
+ if sep_phonemes[i][j] == "ー":
+ sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
+ # sep_phonemes[i][j] = ':'
+ res.append(sep_phonemes[i])
+ res.append("▁")
+ return res
+
+
+def align_tones(
+ phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
+) -> list[tuple[str, int]]:
+ """
+ 例:
+ …私は、、そう思う。
+ phones_with_punct:
+ [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
+ phone_tone_list:
+ [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))]
+ Return:
+ [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
+ """
+ result: list[tuple[str, int]] = []
+ tone_index = 0
+ for phone in phones_with_punct:
+ if tone_index >= len(phone_tone_list):
+ # 余ったpunctuationがある場合 → (punctuation, 0)を追加
+ result.append((phone, 0))
+ elif phone == phone_tone_list[tone_index][0]:
+ # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加
+ result.append((phone, phone_tone_list[tone_index][1]))
+ # 探すindexを1つ進める
+ tone_index += 1
+ elif phone in punctuation or phone == "▁":
+ # phoneがpunctuationの場合 → (phone, 0)を追加
+ result.append((phone, 0))
+ else:
+ print(f"phones: {phones_with_punct}")
+ print(f"phone_tone_list: {phone_tone_list}")
+ print(f"result: {result}")
+ print(f"tone_index: {tone_index}")
+ print(f"phone: {phone}")
+ raise ValueError(f"Unexpected phone: {phone}")
+ return result
+
+
+def kata2phoneme_list(text: str) -> list[str]:
+ """
+ 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。
+ 注意点:
+ - punctuationが来た場合(punctuationが1文字の場合がありうる)、処理せず1文字のリストを返す
+ - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される)
+ - 文中の「ー」は前の音素記号の最後の音素記号に変換される。
+ 例:
+ `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
+ `?` → ["?"]
+ """
+ if text in punctuation:
+ return [text]
+ # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック
+ if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
+ raise ValueError(f"Input must be katakana only: {text}")
+ sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
+ pattern = "|".join(map(re.escape, sorted_keys))
+
+ def mora2phonemes(mora: str) -> str:
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
+ if cosonant is None:
+ return f" {vowel}"
+ return f" {cosonant} {vowel}"
+
+ spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
+
+ # 長音記号「ー」の処理
+ long_pattern = r"(\w)(ー*)"
+ long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
+ spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
+ # spaced_phonemes += ' ▁'
+ return spaced_phonemes.strip().split(" ")
+
+
+def frontend2phoneme(labels, drop_unvoiced_vowels=False):
+ N = len(labels)
+
+ phones = []
+ for n in range(N):
+ lab_curr = labels[n]
+ # print(lab_curr)
+ # current phoneme
+ p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
+
+ # deal unvoiced vowels as normal vowels
+ if drop_unvoiced_vowels and p3 in "AEIOU":
+ p3 = p3.lower()
+
+ # deal with sil at the beginning and the end of text
+ if p3 == "sil":
+ # assert n == 0 or n == N - 1
+ # if n == 0:
+ # phones.append("^")
+ # elif n == N - 1:
+ # # check question form or not
+ # e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
+ # if e3 == 0:
+ # phones.append("$")
+ # elif e3 == 1:
+ # phones.append("?")
+ continue
+ elif p3 == "pau":
+ phones.append("_")
+ continue
+ else:
+ phones.append(p3)
+
+ # accent type and position info (forward or backward)
+ a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
+ a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
+ a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
+
+ # number of mora in accent phrase
+ f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
+
+ a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
+ # accent phrase border
+ # print(p3, a1, a2, a3, f1, a2_next, lab_curr)
+ if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
+ phones.append("#")
+ # pitch falling
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
+ phones.append("]")
+ # pitch rising
+ elif a2 == 1 and a2_next == 2:
+ phones.append("[")
+
+ # phones = ' '.join(phones)
+ return phones
+
+
+class JapanesePhoneConverter(object):
+ def __init__(self, lexicon_path=None, ipa_dict_path=None):
+ # lexicon_lines = open(lexicon_path, 'r', encoding='utf-8').readlines()
+ # self.lexicon = {}
+ # self.single_dict = {}
+ # self.double_dict = {}
+ # for curr_line in lexicon_lines:
+ # k,v = curr_line.strip().split('+',1)
+ # self.lexicon[k] = v
+ # if len(k) == 2:
+ # self.double_dict[k] = v
+ # elif len(k) == 1:
+ # self.single_dict[k] = v
+ self.ipa_dict = {}
+ for curr_line in jp_xphone2ipa:
+ k, v = curr_line.strip().split(" ", 1)
+ self.ipa_dict[k] = re.sub("\s", "", v)
+ # kakasi1 = kakasi()
+ # kakasi1.setMode("H","K")
+ # kakasi1.setMode("J","K")
+ # kakasi1.setMode("r","Hepburn")
+ self.japan_JH2K = kakasi()
+ self.table = {ord(f): ord(t) for f, t in zip("67", "_¯")}
+
+ def text2sep_kata(self, parsed) -> tuple[list[str], list[str]]:
+ """
+ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、
+ 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。
+ 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。
+ 例:
+ `私はそう思う!って感じ?` →
+ ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
+ """
+ # parsed: OpenJTalkの解析結果
+ sep_text: list[str] = []
+ sep_kata: list[str] = []
+ fix_parsed = []
+ i = 0
+ while i <= len(parsed) - 1:
+ # word: 実際の単語の文字列
+ # yomi: その読み、但し無声化サインの`’`は除去
+ # print(parsed)
+ yomi = parsed[i]["pron"]
+ tmp_parsed = parsed[i]
+ if i != len(parsed) - 1 and parsed[i + 1]["string"] in [
+ "々",
+ "ゝ",
+ "ヽ",
+ "ゞ",
+ "ヾ",
+ "゛",
+ ]:
+ word = parsed[i]["string"] + parsed[i + 1]["string"]
+ i += 1
+ else:
+ word = parsed[i]["string"]
+ word, yomi = replace_punctuation(word), yomi.replace("’", "")
+ """
+ ここで`yomi`の取りうる値は以下の通りのはず。
+ - `word`が通常単語 → 通常の読み(カタカナ)
+ (カタカナからなり、長音記号も含みうる、`アー` 等)
+ - `word`が`ー` から始まる → `ーラー` や `ーーー` など
+ - `word`が句読点や空白等 → `、`
+ - `word`が`?` → `?`(全角になる)
+ 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。
+ また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。
+ 処理すべきは`yomi`が`、`の場合のみのはず。
+ """
+ assert yomi != "", f"Empty yomi: {word}"
+ if yomi == "、":
+ # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`のいずれか
+ if word not in (
+ ".",
+ ",",
+ "!",
+ "'",
+ "-",
+ "?",
+ ":",
+ ";",
+ "…",
+ "",
+ ):
+ # ここはpyopenjtalkが読めない文字等のときに起こる
+ print(
+ "Cannot read:{}, yomi:{}, new_word:{};".format(
+ word, yomi, self.japan_JH2K.convert(word)[0]["kana"]
+ )
+ )
+ # raise ValueError(word)
+ word = self.japan_JH2K.convert(word)[0]["kana"]
+ # print(word, self.japan_JH2K.convert(word)[0]['kana'], kata2phoneme_list(self.japan_JH2K.convert(word)[0]['kana']))
+ tmp_parsed["pron"] = word
+ # yomi = "-"
+ # word = ','
+ # yomiは元の記号のままに変更
+ # else:
+ # parsed[i]['pron'] = parsed[i]["string"]
+ yomi = word
+ elif yomi == "?":
+ assert word == "?", f"yomi `?` comes from: {word}"
+ yomi = "?"
+ if word == "":
+ i += 1
+ continue
+ sep_text.append(word)
+ sep_kata.append(yomi)
+ # print(word, yomi, parts)
+ fix_parsed.append(tmp_parsed)
+ i += 1
+ # print(sep_text, sep_kata)
+ return sep_text, sep_kata, fix_parsed
+
+ def getSentencePhone(self, sentence, blank_mode=True, phoneme_mode=False):
+ # print("origin:", sentence)
+ words = []
+ words_phone_len = []
+ short_char_flag = False
+ output_duration_flag = []
+ output_before_sil_flag = []
+ normed_text = []
+ sentence = sentence.strip().strip("'")
+ sentence = re.sub(r"\s+", "", sentence)
+ output_res = []
+ failed_words = []
+ last_long_pause = 4
+ last_word = None
+ frontend_text = pyopenjtalk.run_frontend(sentence)
+ # print("frontend_text: ", frontend_text)
+ try:
+ frontend_text = pyopenjtalk.estimate_accent(frontend_text)
+ except:
+ pass
+ # print("estimate_accent: ", frontend_text)
+ # sep_text: 単語単位の単語のリスト
+ # sep_kata: 単語単位の単語のカタカナ読みのリスト
+ sep_text, sep_kata, frontend_text = self.text2sep_kata(frontend_text)
+ # print("sep_text: ", sep_text)
+ # print("sep_kata: ", sep_kata)
+ # print("frontend_text: ", frontend_text)
+ # sep_phonemes: 各単語ごとの音素のリストのリスト
+ sep_phonemes = handle_long_word([kata2phoneme_list(i) for i in sep_kata])
+ # print("sep_phonemes: ", sep_phonemes)
+
+ pron_text = [x["pron"].strip().replace("’", "") for x in frontend_text]
+ # pdb.set_trace()
+ prosodys = pyopenjtalk.make_label(frontend_text)
+ prosodys = frontend2phoneme(prosodys, drop_unvoiced_vowels=True)
+ # print("prosodys: ", ' '.join(prosodys))
+ # print("pron_text: ", pron_text)
+ normed_text = [x["string"].strip() for x in frontend_text]
+ # punctuationがすべて消えた、音素とアクセントのタプルのリスト
+ phone_tone_list_wo_punct = g2phone_tone_wo_punct(prosodys)
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
+
+ # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列
+ phone_w_punct: list[str] = []
+ w_p_len = []
+ for i in sep_phonemes:
+ phone_w_punct += i
+ w_p_len.append(len(i))
+ phone_w_punct = phone_w_punct[:-1]
+ # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る
+ # print("phone_w_punct: ", phone_w_punct)
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
+ phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
+
+ jp_item = {}
+ jp_p = ""
+ jp_t = ""
+ # mye rye pye bye nye
+ # je she
+ # print(phone_tone_list)
+ for p, t in phone_tone_list:
+ if p in self.ipa_dict:
+ curr_p = self.ipa_dict[p]
+ jp_p += curr_p
+ jp_t += str(t + 6) * len(curr_p)
+ elif p in punctuation:
+ jp_p += p
+ jp_t += "0"
+ elif p == "▁":
+ jp_p += p
+ jp_t += " "
+ else:
+ print(p, t)
+ jp_p += "|"
+ jp_t += "0"
+ # return phones, tones, w_p_len
+ jp_p = jp_p.replace("▁", " ")
+ jp_t = jp_t.translate(self.table)
+ jp_l = ""
+ for t in jp_t:
+ if t == " ":
+ jp_l += " "
+ else:
+ jp_l += "2"
+ # print(jp_p)
+ # print(jp_t)
+ # print(jp_l)
+ # print(len(jp_p_len), sum(w_p_len), len(jp_p), sum(jp_p_len))
+ assert len(jp_p) == len(jp_t) and len(jp_p) == len(jp_l)
+
+ jp_item["jp_p"] = jp_p.replace("| |", "|").rstrip("|")
+ jp_item["jp_t"] = jp_t
+ jp_item["jp_l"] = jp_l
+ jp_item["jp_normed_text"] = " ".join(normed_text)
+ jp_item["jp_pron_text"] = " ".join(pron_text)
+ # jp_item['jp_ruoma'] = sep_phonemes
+ # print(len(normed_text), len(sep_phonemes))
+ # print(normed_text)
+ return jp_item
+
+
+jpc = JapanesePhoneConverter()
+
+
+def japanese_to_ipa(text, text_tokenizer):
+ # phonemes = text_tokenizer(text)
+ if type(text) == str:
+ return jpc.getSentencePhone(text)["jp_p"]
+ else:
+ result_ph = []
+ for t in text:
+ result_ph.append(jpc.getSentencePhone(t)["jp_p"])
+ return result_ph
diff --git a/models/tts/maskgct/g2p/g2p/korean.py b/models/tts/maskgct/g2p/g2p/korean.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c540b47d98ccf6e0db5f938e52834abf679b59
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/korean.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+
+"""
+ Text clean time
+"""
+english_dictionary = {
+ "KOREA": "코리아",
+ "IDOL": "아이돌",
+ "IT": "아이티",
+ "IQ": "아이큐",
+ "UP": "업",
+ "DOWN": "다운",
+ "PC": "피씨",
+ "CCTV": "씨씨티비",
+ "SNS": "에스엔에스",
+ "AI": "에이아이",
+ "CEO": "씨이오",
+ "A": "에이",
+ "B": "비",
+ "C": "씨",
+ "D": "디",
+ "E": "이",
+ "F": "에프",
+ "G": "지",
+ "H": "에이치",
+ "I": "아이",
+ "J": "제이",
+ "K": "케이",
+ "L": "엘",
+ "M": "엠",
+ "N": "엔",
+ "O": "오",
+ "P": "피",
+ "Q": "큐",
+ "R": "알",
+ "S": "에스",
+ "T": "티",
+ "U": "유",
+ "V": "브이",
+ "W": "더블유",
+ "X": "엑스",
+ "Y": "와이",
+ "Z": "제트",
+}
+
+
+def normalize(text):
+ text = text.strip()
+ text = re.sub(
+ "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text
+ )
+ text = normalize_english(text)
+ text = text.lower()
+ return text
+
+
+def normalize_english(text):
+ def fn(m):
+ word = m.group()
+ if word in english_dictionary:
+ return english_dictionary.get(word)
+ return word
+
+ text = re.sub("([A-Za-z]+)", fn, text)
+ return text
+
+
+def korean_to_ipa(text, text_tokenizer):
+ if type(text) == str:
+ text = normalize(text)
+ phonemes = text_tokenizer(text)
+ return phonemes
+ else:
+ for i, t in enumerate(text):
+ text[i] = normalize(t)
+ return text_tokenizer(text)
diff --git a/models/tts/maskgct/g2p/g2p/mandarin.py b/models/tts/maskgct/g2p/g2p/mandarin.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b107a1e22c2c8d93c118be0841e0ddf75e65164
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/mandarin.py
@@ -0,0 +1,595 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+import jieba
+import cn2an
+from pypinyin import lazy_pinyin, BOPOMOFO
+from typing import List
+from models.tts.maskgct.g2p.g2p.chinese_model_g2p import BertPolyPredict
+from models.tts.maskgct.g2p.utils.front_utils import *
+import os
+
+# from g2pw import G2PWConverter
+
+
+# set blank level, {0:"none",1:"char", 2:"word"}
+BLANK_LEVEL = 0
+
+# conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
+resource_path = r"./models/tts/maskgct/g2p"
+poly_all_class_path = os.path.join(
+ resource_path, "sources", "g2p_chinese_model", "polychar.txt"
+)
+if not os.path.exists(poly_all_class_path):
+ print(
+ "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
+ poly_all_class_path
+ )
+ )
+ exit()
+poly_dict = generate_poly_lexicon(poly_all_class_path)
+
+# Set up G2PW model parameters
+g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
+if not os.path.exists(g2pw_poly_model_path):
+ print(
+ "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
+ g2pw_poly_model_path
+ )
+ )
+ exit()
+
+json_file_path = os.path.join(
+ resource_path, "sources", "g2p_chinese_model", "polydict.json"
+)
+if not os.path.exists(json_file_path):
+ print(
+ "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
+ json_file_path
+ )
+ )
+ exit()
+
+jsonr_file_path = os.path.join(
+ resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
+)
+if not os.path.exists(jsonr_file_path):
+ print(
+ "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
+ jsonr_file_path
+ )
+ )
+ exit()
+
+g2pw_poly_predict = BertPolyPredict(
+ g2pw_poly_model_path, jsonr_file_path, json_file_path
+)
+
+
+"""
+ Text clean time
+"""
+# List of (Latin alphabet, bopomofo) pairs:
+_latin_to_bopomofo = [
+ (re.compile("%s" % x[0], re.IGNORECASE), x[1])
+ for x in [
+ ("a", "ㄟˉ"),
+ ("b", "ㄅㄧˋ"),
+ ("c", "ㄙㄧˉ"),
+ ("d", "ㄉㄧˋ"),
+ ("e", "ㄧˋ"),
+ ("f", "ㄝˊㄈㄨˋ"),
+ ("g", "ㄐㄧˋ"),
+ ("h", "ㄝˇㄑㄩˋ"),
+ ("i", "ㄞˋ"),
+ ("j", "ㄐㄟˋ"),
+ ("k", "ㄎㄟˋ"),
+ ("l", "ㄝˊㄛˋ"),
+ ("m", "ㄝˊㄇㄨˋ"),
+ ("n", "ㄣˉ"),
+ ("o", "ㄡˉ"),
+ ("p", "ㄆㄧˉ"),
+ ("q", "ㄎㄧㄡˉ"),
+ ("r", "ㄚˋ"),
+ ("s", "ㄝˊㄙˋ"),
+ ("t", "ㄊㄧˋ"),
+ ("u", "ㄧㄡˉ"),
+ ("v", "ㄨㄧˉ"),
+ ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
+ ("x", "ㄝˉㄎㄨˋㄙˋ"),
+ ("y", "ㄨㄞˋ"),
+ ("z", "ㄗㄟˋ"),
+ ]
+]
+
+# List of (bopomofo, ipa) pairs:
+_bopomofo_to_ipa = [
+ (re.compile("%s" % x[0]), x[1])
+ for x in [
+ ("ㄅㄛ", "p⁼wo"),
+ ("ㄆㄛ", "pʰwo"),
+ ("ㄇㄛ", "mwo"),
+ ("ㄈㄛ", "fwo"),
+ ("ㄧㄢ", "|jɛn"),
+ ("ㄩㄢ", "|ɥæn"),
+ ("ㄧㄣ", "|in"),
+ ("ㄩㄣ", "|ɥn"),
+ ("ㄧㄥ", "|iŋ"),
+ ("ㄨㄥ", "|ʊŋ"),
+ ("ㄩㄥ", "|jʊŋ"),
+ # Add
+ ("ㄧㄚ", "|ia"),
+ ("ㄧㄝ", "|iɛ"),
+ ("ㄧㄠ", "|iɑʊ"),
+ ("ㄧㄡ", "|ioʊ"),
+ ("ㄧㄤ", "|iɑŋ"),
+ ("ㄨㄚ", "|ua"),
+ ("ㄨㄛ", "|uo"),
+ ("ㄨㄞ", "|uaɪ"),
+ ("ㄨㄟ", "|ueɪ"),
+ ("ㄨㄢ", "|uan"),
+ ("ㄨㄣ", "|uən"),
+ ("ㄨㄤ", "|uɑŋ"),
+ ("ㄩㄝ", "|ɥɛ"),
+ # End
+ ("ㄅ", "p⁼"),
+ ("ㄆ", "pʰ"),
+ ("ㄇ", "m"),
+ ("ㄈ", "f"),
+ ("ㄉ", "t⁼"),
+ ("ㄊ", "tʰ"),
+ ("ㄋ", "n"),
+ ("ㄌ", "l"),
+ ("ㄍ", "k⁼"),
+ ("ㄎ", "kʰ"),
+ ("ㄏ", "x"),
+ ("ㄐ", "tʃ⁼"),
+ ("ㄑ", "tʃʰ"),
+ ("ㄒ", "ʃ"),
+ ("ㄓ", "ts`⁼"),
+ ("ㄔ", "ts`ʰ"),
+ ("ㄕ", "s`"),
+ ("ㄖ", "ɹ`"),
+ ("ㄗ", "ts⁼"),
+ ("ㄘ", "tsʰ"),
+ ("ㄙ", "|s"),
+ ("ㄚ", "|a"),
+ ("ㄛ", "|o"),
+ ("ㄜ", "|ə"),
+ ("ㄝ", "|ɛ"),
+ ("ㄞ", "|aɪ"),
+ ("ㄟ", "|eɪ"),
+ ("ㄠ", "|ɑʊ"),
+ ("ㄡ", "|oʊ"),
+ ("ㄢ", "|an"),
+ ("ㄣ", "|ən"),
+ ("ㄤ", "|ɑŋ"),
+ ("ㄥ", "|əŋ"),
+ ("ㄦ", "əɹ"),
+ ("ㄧ", "|i"),
+ ("ㄨ", "|u"),
+ ("ㄩ", "|ɥ"),
+ ("ˉ", "→|"),
+ ("ˊ", "↑|"),
+ ("ˇ", "↓↑|"),
+ ("ˋ", "↓|"),
+ ("˙", "|"),
+ ]
+]
+must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}
+
+word_pinyin_dict = {}
+with open(
+ r"./models/tts/maskgct/g2p/sources/chinese_lexicon.txt", "r", encoding="utf-8"
+) as fread:
+ txt_list = fread.readlines()
+ for txt in txt_list:
+ word, pinyin = txt.strip().split("\t")
+ word_pinyin_dict[word] = pinyin
+ fread.close()
+
+pinyin_2_bopomofo_dict = {}
+with open(
+ r"./models/tts/maskgct/g2p/sources/pinyin_2_bpmf.txt", "r", encoding="utf-8"
+) as fread:
+ txt_list = fread.readlines()
+ for txt in txt_list:
+ pinyin, bopomofo = txt.strip().split("\t")
+ pinyin_2_bopomofo_dict[pinyin] = bopomofo
+ fread.close()
+
+tone_dict = {
+ "0": "˙",
+ "5": "˙",
+ "1": "",
+ "2": "ˊ",
+ "3": "ˇ",
+ "4": "ˋ",
+}
+
+bopomofos2pinyin_dict = {}
+with open(
+ r"./models/tts/maskgct/g2p/sources/bpmf_2_pinyin.txt", "r", encoding="utf-8"
+) as fread:
+ txt_list = fread.readlines()
+ for txt in txt_list:
+ v, k = txt.strip().split("\t")
+ bopomofos2pinyin_dict[k] = v
+ fread.close()
+
+
+def bpmf_to_pinyin(text):
+ bopomofo_list = text.split("|")
+ pinyin_list = []
+ for info in bopomofo_list:
+ pinyin = ""
+ for c in info:
+ if c in bopomofos2pinyin_dict:
+ pinyin += bopomofos2pinyin_dict[c]
+ if len(pinyin) == 0:
+ continue
+ if pinyin[-1] not in "01234":
+ pinyin += "1"
+ if pinyin[:-1] == "ve":
+ pinyin = "y" + pinyin
+ if pinyin[:-1] == "sh":
+ pinyin = pinyin[:-1] + "i" + pinyin[-1]
+ if pinyin == "sh":
+ pinyin = pinyin[:-1] + "i"
+ if pinyin[:-1] == "s":
+ pinyin = "si" + pinyin[-1]
+ if pinyin[:-1] == "c":
+ pinyin = "ci" + pinyin[-1]
+ if pinyin[:-1] == "i":
+ pinyin = "yi" + pinyin[-1]
+ if pinyin[:-1] == "iou":
+ pinyin = "you" + pinyin[-1]
+ if pinyin[:-1] == "ien":
+ pinyin = "yin" + pinyin[-1]
+ if "iou" in pinyin and pinyin[-4:-1] == "iou":
+ pinyin = pinyin[:-4] + "iu" + pinyin[-1]
+ if "uei" in pinyin:
+ if pinyin[:-1] == "uei":
+ pinyin = "wei" + pinyin[-1]
+ elif pinyin[-4:-1] == "uei":
+ pinyin = pinyin[:-4] + "ui" + pinyin[-1]
+ if "uen" in pinyin and pinyin[-4:-1] == "uen":
+ if pinyin[:-1] == "uen":
+ pinyin = "wen" + pinyin[-1]
+ elif pinyin[-4:-1] == "uei":
+ pinyin = pinyin[:-4] + "un" + pinyin[-1]
+ if "van" in pinyin and pinyin[-4:-1] == "van":
+ if pinyin[:-1] == "van":
+ pinyin = "yuan" + pinyin[-1]
+ elif pinyin[-4:-1] == "van":
+ pinyin = pinyin[:-4] + "uan" + pinyin[-1]
+ if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
+ pinyin = pinyin[:-5] + "ong" + pinyin[-1]
+ if pinyin[:-1] == "veng":
+ pinyin = "yong" + pinyin[-1]
+ if "veng" in pinyin and pinyin[-5:-1] == "veng":
+ pinyin = pinyin[:-5] + "iong" + pinyin[-1]
+ if pinyin[:-1] == "ieng":
+ pinyin = "ying" + pinyin[-1]
+ if pinyin[:-1] == "u":
+ pinyin = "wu" + pinyin[-1]
+ if pinyin[:-1] == "v":
+ pinyin = "yv" + pinyin[-1]
+ if pinyin[:-1] == "ing":
+ pinyin = "ying" + pinyin[-1]
+ if pinyin[:-1] == "z":
+ pinyin = "zi" + pinyin[-1]
+ if pinyin[:-1] == "zh":
+ pinyin = "zhi" + pinyin[-1]
+ if pinyin[0] == "u":
+ pinyin = "w" + pinyin[1:]
+ if pinyin[0] == "i":
+ pinyin = "y" + pinyin[1:]
+ pinyin = pinyin.replace("ien", "in")
+
+ pinyin_list.append(pinyin)
+ return " ".join(pinyin_list)
+
+
+# Convert numbers to Chinese pronunciation
+def number_to_chinese(text):
+ # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
+ # for number in numbers:
+ # text = text.replace(number, cn2an.an2cn(number), 1)
+ text = cn2an.transform(text, "an2cn")
+ return text
+
+
+def normalization(text):
+ text = text.replace(",", ",")
+ text = text.replace("。", ".")
+ text = text.replace("!", "!")
+ text = text.replace("?", "?")
+ text = text.replace(";", ";")
+ text = text.replace(":", ":")
+ text = text.replace("、", ",")
+ text = text.replace("‘", "'")
+ text = text.replace("’", "'")
+ text = text.replace("⋯", "…")
+ text = text.replace("···", "…")
+ text = text.replace("・・・", "…")
+ text = text.replace("...", "…")
+ text = re.sub(r"\s+", "", text)
+ text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text)
+ text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text)
+ return text
+
+
+def change_tone(bopomofo: str, tone: str) -> str:
+ if bopomofo[-1] not in "˙ˊˇˋ":
+ bopomofo = bopomofo + tone
+ else:
+ bopomofo = bopomofo[:-1] + tone
+ return bopomofo
+
+
+def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+ if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
+ bopomofos[-1] = change_tone(bopomofos[-1], "˙")
+ return bopomofos
+
+
+def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+ valid_char = set(word)
+ if len(valid_char) == 1 and "不" in valid_char:
+ pass
+ elif word in ["不字"]:
+ pass
+ elif len(word) == 3 and word[1] == "不" and bopomofos[1][:-1] == "ㄅㄨ":
+ bopomofos[1] = bopomofos[1][:-1] + "˙"
+ else:
+ for i, char in enumerate(word):
+ if (
+ i + 1 < len(bopomofos)
+ and char == "不"
+ and i + 1 < len(word)
+ and 0 < len(bopomofos[i + 1])
+ and bopomofos[i + 1][-1] == "ˋ"
+ ):
+ bopomofos[i] = bopomofos[i][:-1] + "ˊ"
+ return bopomofos
+
+
+def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+ punc = ":,;。?!“”‘’':,;.?!()(){}【】[]-~`、 "
+ if word.find("一") != -1 and any(
+ [item.isnumeric() for item in word if item != "一"]
+ ):
+ for i in range(len(word)):
+ if (
+ i == 0
+ and word[0] == "一"
+ and len(word) > 1
+ and word[1]
+ not in [
+ "零",
+ "一",
+ "二",
+ "三",
+ "四",
+ "五",
+ "六",
+ "七",
+ "八",
+ "九",
+ "十",
+ ]
+ ):
+ if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
+ bopomofos[0] = change_tone(bopomofos[0], "ˊ")
+ else:
+ bopomofos[0] = change_tone(bopomofos[0], "ˋ")
+ elif word[i] == "一":
+ bopomofos[i] = change_tone(bopomofos[i], "")
+ return bopomofos
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
+ bopomofos[1] = change_tone(bopomofos[1], "˙")
+ elif word.startswith("第一"):
+ bopomofos[1] = change_tone(bopomofos[1], "")
+ elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
+ bopomofos[0] = change_tone(bopomofos[0], "")
+ else:
+ for i, char in enumerate(word):
+ if char == "一" and i + 1 < len(word):
+ if (
+ len(bopomofos) > i + 1
+ and len(bopomofos[i + 1]) > 0
+ and bopomofos[i + 1][-1] in {"ˋ"}
+ ):
+ bopomofos[i] = change_tone(bopomofos[i], "ˊ")
+ else:
+ if word[i + 1] not in punc:
+ bopomofos[i] = change_tone(bopomofos[i], "ˋ")
+ else:
+ pass
+ return bopomofos
+
+
+def merge_bu(seg: List) -> List:
+ new_seg = []
+ last_word = ""
+ for word in seg:
+ if word != "不":
+ if last_word == "不":
+ word = last_word + word
+ new_seg.append(word)
+ last_word = word
+ return new_seg
+
+
+def merge_er(seg: List) -> List:
+ new_seg = []
+ for i, word in enumerate(seg):
+ if i - 1 >= 0 and word == "儿":
+ new_seg[-1] = new_seg[-1] + seg[i]
+ else:
+ new_seg.append(word)
+ return new_seg
+
+
+def merge_yi(seg: List) -> List:
+ new_seg = []
+ # function 1
+ for i, word in enumerate(seg):
+ if (
+ i - 1 >= 0
+ and word == "一"
+ and i + 1 < len(seg)
+ and seg[i - 1] == seg[i + 1]
+ ):
+ if i - 1 < len(new_seg):
+ new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
+ else:
+ new_seg.append(word)
+ new_seg.append(seg[i + 1])
+ else:
+ if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
+ continue
+ else:
+ new_seg.append(word)
+ seg = new_seg
+ new_seg = []
+ isnumeric_flag = False
+ for i, word in enumerate(seg):
+ if all([item.isnumeric() for item in word]) and not isnumeric_flag:
+ isnumeric_flag = True
+ new_seg.append(word)
+ else:
+ new_seg.append(word)
+ seg = new_seg
+ new_seg = []
+ # function 2
+ for i, word in enumerate(seg):
+ if new_seg and new_seg[-1] == "一":
+ new_seg[-1] = new_seg[-1] + word
+ else:
+ new_seg.append(word)
+ return new_seg
+
+
+# Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
+def chinese_to_bopomofo(text_short, sentence):
+ # bopomofos = conv(text_short)
+ words = jieba.lcut(text_short, cut_all=False)
+ words = merge_yi(words)
+ words = merge_bu(words)
+ words = merge_er(words)
+ text = ""
+
+ char_index = 0
+ for word in words:
+ bopomofos = []
+ if word in word_pinyin_dict and word not in poly_dict:
+ pinyin = word_pinyin_dict[word]
+ for py in pinyin.split(" "):
+ if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
+ bopomofos.append(
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
+ )
+ if BLANK_LEVEL == 1:
+ bopomofos.append("_")
+ else:
+ bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
+ bopomofos += bopomofos_lazy
+ if BLANK_LEVEL == 1:
+ bopomofos.append("_")
+ else:
+ for i in range(len(word)):
+ c = word[i]
+ if c in poly_dict:
+ poly_pinyin = g2pw_poly_predict.predict_process(
+ [text_short, char_index + i]
+ )[0]
+ py = poly_pinyin[2:-1]
+ bopomofos.append(
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
+ )
+ if BLANK_LEVEL == 1:
+ bopomofos.append("_")
+ elif c in word_pinyin_dict:
+ py = word_pinyin_dict[c]
+ bopomofos.append(
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
+ )
+ if BLANK_LEVEL == 1:
+ bopomofos.append("_")
+ else:
+ bopomofos.append(c)
+ if BLANK_LEVEL == 1:
+ bopomofos.append("_")
+ if BLANK_LEVEL == 2:
+ bopomofos.append("_")
+ char_index += len(word)
+
+ if (
+ len(word) == 3
+ and bopomofos[0][-1] == "ˇ"
+ and bopomofos[1][-1] == "ˇ"
+ and bopomofos[-1][-1] == "ˇ"
+ ):
+ bopomofos[0] = bopomofos[0] + "ˊ"
+ bopomofos[1] = bopomofos[1] + "ˊ"
+ if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
+ bopomofos[0] = bopomofos[0][:-1] + "ˊ"
+ bopomofos = bu_sandhi(word, bopomofos)
+ bopomofos = yi_sandhi(word, bopomofos)
+ bopomofos = er_sandhi(word, bopomofos)
+ if not re.search("[\u4e00-\u9fff]", word):
+ text += "|" + word
+ continue
+ for i in range(len(bopomofos)):
+ bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i])
+ if text != "":
+ text += "|"
+ text += "|".join(bopomofos)
+ return text
+
+
+# Convert latin pronunciation to pinyin (bopomofo)
+def latin_to_bopomofo(text):
+ for regex, replacement in _latin_to_bopomofo:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+# Convert pinyin (bopomofo) to IPA
+def bopomofo_to_ipa(text):
+ for regex, replacement in _bopomofo_to_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def _chinese_to_ipa(text, sentence):
+ text = number_to_chinese(text.strip())
+ text = normalization(text)
+ text = chinese_to_bopomofo(text, sentence)
+ # pinyin = bpmf_to_pinyin(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_ipa(text)
+ text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
+ text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
+ text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text)
+ text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text)
+ text = re.sub(r"\|+", "|", text)
+ text = text.rstrip("|")
+ return text
+
+
+# Convert Chinese to IPA
+def chinese_to_ipa(text, sentence, text_tokenizer):
+ # phonemes = text_tokenizer(text.strip())
+ if type(text) == str:
+ return _chinese_to_ipa(text, sentence)
+ else:
+ result_ph = []
+ for t in text:
+ result_ph.append(_chinese_to_ipa(t, sentence))
+ return result_ph
diff --git a/models/tts/maskgct/g2p/g2p/text_tokenizers.py b/models/tts/maskgct/g2p/g2p/text_tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cb3481c4611839c190309e4b348801f40227af
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/text_tokenizers.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+import os
+from typing import List, Pattern, Union
+from phonemizer.utils import list2str, str2list
+from phonemizer.backend import EspeakBackend
+from phonemizer.backend.espeak.language_switch import LanguageSwitch
+from phonemizer.backend.espeak.words_mismatch import WordMismatch
+from phonemizer.punctuation import Punctuation
+from phonemizer.separator import Separator
+
+
+class TextTokenizer:
+ """Phonemize Text."""
+
+ def __init__(
+ self,
+ language="en-us",
+ backend="espeak",
+ separator=Separator(word="|_|", syllable="-", phone="|"),
+ preserve_punctuation=True,
+ with_stress: bool = False,
+ tie: Union[bool, str] = False,
+ language_switch: LanguageSwitch = "remove-flags",
+ words_mismatch: WordMismatch = "ignore",
+ ) -> None:
+ self.preserve_punctuation_marks = ",.?!;:'…"
+ self.backend = EspeakBackend(
+ language,
+ punctuation_marks=self.preserve_punctuation_marks,
+ preserve_punctuation=preserve_punctuation,
+ with_stress=with_stress,
+ tie=tie,
+ language_switch=language_switch,
+ words_mismatch=words_mismatch,
+ )
+
+ self.separator = separator
+
+ # convert chinese punctuation to english punctuation
+ def convert_chinese_punctuation(self, text: str) -> str:
+ text = text.replace(",", ",")
+ text = text.replace("。", ".")
+ text = text.replace("!", "!")
+ text = text.replace("?", "?")
+ text = text.replace(";", ";")
+ text = text.replace(":", ":")
+ text = text.replace("、", ",")
+ text = text.replace("‘", "'")
+ text = text.replace("’", "'")
+ text = text.replace("⋯", "…")
+ text = text.replace("···", "…")
+ text = text.replace("・・・", "…")
+ text = text.replace("...", "…")
+ return text
+
+ def __call__(self, text, strip=True) -> List[str]:
+
+ text_type = type(text)
+ normalized_text = []
+ for line in str2list(text):
+ line = self.convert_chinese_punctuation(line.strip())
+ line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line)
+ line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line)
+ line = re.sub(r"\s+", " ", line)
+ normalized_text.append(line)
+ # print("Normalized test: ", normalized_text[0])
+ phonemized = self.backend.phonemize(
+ normalized_text, separator=self.separator, strip=strip, njobs=1
+ )
+ if text_type == str:
+ phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized))
+ phonemized = re.sub(r"\|+", "|", phonemized)
+ phonemized = phonemized.rstrip("|")
+ else:
+ for i in range(len(phonemized)):
+ phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i])
+ phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
+ phonemized[i] = phonemized[i].rstrip("|")
+ return phonemized
diff --git a/models/tts/maskgct/g2p/g2p/vocab.json b/models/tts/maskgct/g2p/g2p/vocab.json
new file mode 100644
index 0000000000000000000000000000000000000000..28d32aaf01881c6ff5449aaaf942d94b753a4e91
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p/vocab.json
@@ -0,0 +1,372 @@
+{
+ "vocab": {
+ ",": 0,
+ ".": 1,
+ "?": 2,
+ "!": 3,
+ "_": 4,
+ "iː": 5,
+ "ɪ": 6,
+ "ɜː": 7,
+ "ɚ": 8,
+ "oːɹ": 9,
+ "ɔː": 10,
+ "ɔːɹ": 11,
+ "ɑː": 12,
+ "uː": 13,
+ "ʊ": 14,
+ "ɑːɹ": 15,
+ "ʌ": 16,
+ "ɛ": 17,
+ "æ": 18,
+ "eɪ": 19,
+ "aɪ": 20,
+ "ɔɪ": 21,
+ "aʊ": 22,
+ "oʊ": 23,
+ "ɪɹ": 24,
+ "ɛɹ": 25,
+ "ʊɹ": 26,
+ "p": 27,
+ "b": 28,
+ "t": 29,
+ "d": 30,
+ "k": 31,
+ "ɡ": 32,
+ "f": 33,
+ "v": 34,
+ "θ": 35,
+ "ð": 36,
+ "s": 37,
+ "z": 38,
+ "ʃ": 39,
+ "ʒ": 40,
+ "h": 41,
+ "tʃ": 42,
+ "dʒ": 43,
+ "m": 44,
+ "n": 45,
+ "ŋ": 46,
+ "j": 47,
+ "w": 48,
+ "ɹ": 49,
+ "l": 50,
+ "tɹ": 51,
+ "dɹ": 52,
+ "ts": 53,
+ "dz": 54,
+ "i": 55,
+ "ɔ": 56,
+ "ə": 57,
+ "ɾ": 58,
+ "iə": 59,
+ "r": 60,
+ "u": 61,
+ "oː": 62,
+ "ɛː": 63,
+ "ɪː": 64,
+ "aɪə": 65,
+ "aɪɚ": 66,
+ "ɑ̃": 67,
+ "ç": 68,
+ "ɔ̃": 69,
+ "ææ": 70,
+ "ɐɐ": 71,
+ "ɡʲ": 72,
+ "nʲ": 73,
+ "iːː": 74,
+
+ "p⁼": 75,
+ "pʰ": 76,
+ "t⁼": 77,
+ "tʰ": 78,
+ "k⁼": 79,
+ "kʰ": 80,
+ "x": 81,
+ "tʃ⁼": 82,
+ "tʃʰ": 83,
+ "ts`⁼": 84,
+ "ts`ʰ": 85,
+ "s`": 86,
+ "ɹ`": 87,
+ "ts⁼": 88,
+ "tsʰ": 89,
+ "p⁼wo": 90,
+ "p⁼wo→": 91,
+ "p⁼wo↑": 92,
+ "p⁼wo↓↑": 93,
+ "p⁼wo↓": 94,
+ "pʰwo": 95,
+ "pʰwo→": 96,
+ "pʰwo↑": 97,
+ "pʰwo↓↑": 98,
+ "pʰwo↓": 99,
+ "mwo": 100,
+ "mwo→": 101,
+ "mwo↑": 102,
+ "mwo↓↑": 103,
+ "mwo↓": 104,
+ "fwo": 105,
+ "fwo→": 106,
+ "fwo↑": 107,
+ "fwo↓↑": 108,
+ "fwo↓": 109,
+ "jɛn": 110,
+ "jɛn→": 111,
+ "jɛn↑": 112,
+ "jɛn↓↑": 113,
+ "jɛn↓": 114,
+ "ɥæn": 115,
+ "ɥæn→": 116,
+ "ɥæn↑": 117,
+ "ɥæn↓↑": 118,
+ "ɥæn↓": 119,
+ "in": 120,
+ "in→": 121,
+ "in↑": 122,
+ "in↓↑": 123,
+ "in↓": 124,
+ "ɥn": 125,
+ "ɥn→": 126,
+ "ɥn↑": 127,
+ "ɥn↓↑": 128,
+ "ɥn↓": 129,
+ "iŋ": 130,
+ "iŋ→": 131,
+ "iŋ↑": 132,
+ "iŋ↓↑": 133,
+ "iŋ↓": 134,
+ "ʊŋ": 135,
+ "ʊŋ→": 136,
+ "ʊŋ↑": 137,
+ "ʊŋ↓↑": 138,
+ "ʊŋ↓": 139,
+ "jʊŋ": 140,
+ "jʊŋ→": 141,
+ "jʊŋ↑": 142,
+ "jʊŋ↓↑": 143,
+ "jʊŋ↓": 144,
+ "ia": 145,
+ "ia→": 146,
+ "ia↑": 147,
+ "ia↓↑": 148,
+ "ia↓": 149,
+ "iɛ": 150,
+ "iɛ→": 151,
+ "iɛ↑": 152,
+ "iɛ↓↑": 153,
+ "iɛ↓": 154,
+ "iɑʊ": 155,
+ "iɑʊ→": 156,
+ "iɑʊ↑": 157,
+ "iɑʊ↓↑": 158,
+ "iɑʊ↓": 159,
+ "ioʊ": 160,
+ "ioʊ→": 161,
+ "ioʊ↑": 162,
+ "ioʊ↓↑": 163,
+ "ioʊ↓": 164,
+ "iɑŋ": 165,
+ "iɑŋ→": 166,
+ "iɑŋ↑": 167,
+ "iɑŋ↓↑": 168,
+ "iɑŋ↓": 169,
+ "ua": 170,
+ "ua→": 171,
+ "ua↑": 172,
+ "ua↓↑": 173,
+ "ua↓": 174,
+ "uo": 175,
+ "uo→": 176,
+ "uo↑": 177,
+ "uo↓↑": 178,
+ "uo↓": 179,
+ "uaɪ": 180,
+ "uaɪ→": 181,
+ "uaɪ↑": 182,
+ "uaɪ↓↑": 183,
+ "uaɪ↓": 184,
+ "ueɪ": 185,
+ "ueɪ→": 186,
+ "ueɪ↑": 187,
+ "ueɪ↓↑": 188,
+ "ueɪ↓": 189,
+ "uan": 190,
+ "uan→": 191,
+ "uan↑": 192,
+ "uan↓↑": 193,
+ "uan↓": 194,
+ "uən": 195,
+ "uən→": 196,
+ "uən↑": 197,
+ "uən↓↑": 198,
+ "uən↓": 199,
+ "uɑŋ": 200,
+ "uɑŋ→": 201,
+ "uɑŋ↑": 202,
+ "uɑŋ↓↑": 203,
+ "uɑŋ↓": 204,
+ "ɥɛ": 205,
+ "ɥɛ→": 206,
+ "ɥɛ↑": 207,
+ "ɥɛ↓↑": 208,
+ "ɥɛ↓": 209,
+ "a": 210,
+ "a→": 211,
+ "a↑": 212,
+ "a↓↑": 213,
+ "a↓": 214,
+ "o": 215,
+ "o→": 216,
+ "o↑": 217,
+ "o↓↑": 218,
+ "o↓": 219,
+ "ə→": 220,
+ "ə↑": 221,
+ "ə↓↑": 222,
+ "ə↓": 223,
+ "ɛ→": 224,
+ "ɛ↑": 225,
+ "ɛ↓↑": 226,
+ "ɛ↓": 227,
+ "aɪ→": 228,
+ "aɪ↑": 229,
+ "aɪ↓↑": 230,
+ "aɪ↓": 231,
+ "eɪ→": 232,
+ "eɪ↑": 233,
+ "eɪ↓↑": 234,
+ "eɪ↓": 235,
+ "ɑʊ": 236,
+ "ɑʊ→": 237,
+ "ɑʊ↑": 238,
+ "ɑʊ↓↑": 239,
+ "ɑʊ↓": 240,
+ "oʊ→": 241,
+ "oʊ↑": 242,
+ "oʊ↓↑": 243,
+ "oʊ↓": 244,
+ "an": 245,
+ "an→": 246,
+ "an↑": 247,
+ "an↓↑": 248,
+ "an↓": 249,
+ "ən": 250,
+ "ən→": 251,
+ "ən↑": 252,
+ "ən↓↑": 253,
+ "ən↓": 254,
+ "ɑŋ": 255,
+ "ɑŋ→": 256,
+ "ɑŋ↑": 257,
+ "ɑŋ↓↑": 258,
+ "ɑŋ↓": 259,
+ "əŋ": 260,
+ "əŋ→": 261,
+ "əŋ↑": 262,
+ "əŋ↓↑": 263,
+ "əŋ↓": 264,
+ "əɹ": 265,
+ "əɹ→": 266,
+ "əɹ↑": 267,
+ "əɹ↓↑": 268,
+ "əɹ↓": 269,
+ "i→": 270,
+ "i↑": 271,
+ "i↓↑": 272,
+ "i↓": 273,
+ "u→": 274,
+ "u↑": 275,
+ "u↓↑": 276,
+ "u↓": 277,
+ "ɥ": 278,
+ "ɥ→": 279,
+ "ɥ↑": 280,
+ "ɥ↓↑": 281,
+ "ɥ↓": 282,
+ "ts`⁼ɹ": 283,
+ "ts`⁼ɹ→": 284,
+ "ts`⁼ɹ↑": 285,
+ "ts`⁼ɹ↓↑": 286,
+ "ts`⁼ɹ↓": 287,
+ "ts`ʰɹ": 288,
+ "ts`ʰɹ→": 289,
+ "ts`ʰɹ↑": 290,
+ "ts`ʰɹ↓↑": 291,
+ "ts`ʰɹ↓": 292,
+ "s`ɹ": 293,
+ "s`ɹ→": 294,
+ "s`ɹ↑": 295,
+ "s`ɹ↓↑": 296,
+ "s`ɹ↓": 297,
+ "ɹ`ɹ": 298,
+ "ɹ`ɹ→": 299,
+ "ɹ`ɹ↑": 300,
+ "ɹ`ɹ↓↑": 301,
+ "ɹ`ɹ↓": 302,
+ "ts⁼ɹ": 303,
+ "ts⁼ɹ→": 304,
+ "ts⁼ɹ↑": 305,
+ "ts⁼ɹ↓↑": 306,
+ "ts⁼ɹ↓": 307,
+ "tsʰɹ": 308,
+ "tsʰɹ→": 309,
+ "tsʰɹ↑": 310,
+ "tsʰɹ↓↑": 311,
+ "tsʰɹ↓": 312,
+ "sɹ": 313,
+ "sɹ→": 314,
+ "sɹ↑": 315,
+ "sɹ↓↑": 316,
+ "sɹ↓": 317,
+
+ "ɯ": 318,
+ "e": 319,
+ "aː": 320,
+ "ɯː": 321,
+ "eː": 322,
+ "ç": 323,
+ "ɸ": 324,
+ "ɰᵝ": 325,
+ "ɴ": 326,
+ "g": 327,
+ "dʑ": 328,
+ "q": 329,
+ "ː": 330,
+ "bj": 331,
+ "tɕ": 332,
+ "dej": 333,
+ "tej": 334,
+ "gj": 335,
+ "gɯ": 336,
+ "çj": 337,
+ "kj": 338,
+ "kɯ": 339,
+ "mj": 340,
+ "nj": 341,
+ "pj": 342,
+ "ɾj": 343,
+ "ɕ": 344,
+ "tsɯ": 345,
+
+ "ɐ": 346,
+ "ɑ": 347,
+ "ɒ": 348,
+ "ɜ": 349,
+ "ɫ": 350,
+ "ʑ": 351,
+ "ʲ": 352,
+
+ "y": 353,
+ "ø": 354,
+ "œ": 355,
+ "ʁ": 356,
+ "̃": 357,
+ "ɲ": 358,
+
+ ":": 359,
+ ";": 360,
+ "'": 361,
+ "…": 362
+ }
+}
diff --git a/models/tts/maskgct/g2p/g2p_generation.py b/models/tts/maskgct/g2p/g2p_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3d984145e0d8a49ee20dde7b1536d2bece862a
--- /dev/null
+++ b/models/tts/maskgct/g2p/g2p_generation.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+
+from models.tts.maskgct.g2p.g2p import PhonemeBpeTokenizer
+from models.tts.maskgct.g2p.utils.g2p import phonemizer_g2p
+import tqdm
+from typing import List
+import json
+import os
+import re
+
+
+def ph_g2p(text, language):
+
+ return phonemizer_g2p(text=text, language=language)
+
+
+def g2p(text, sentence, language):
+
+ return text_tokenizer.tokenize(text=text, sentence=sentence, language=language)
+
+
+def is_chinese(char):
+ if char >= "\u4e00" and char <= "\u9fa5":
+ return True
+ else:
+ return False
+
+
+def is_alphabet(char):
+ if (char >= "\u0041" and char <= "\u005a") or (
+ char >= "\u0061" and char <= "\u007a"
+ ):
+ return True
+ else:
+ return False
+
+
+def is_other(char):
+ if not (is_chinese(char) or is_alphabet(char)):
+ return True
+ else:
+ return False
+
+
+def get_segment(text: str) -> List[str]:
+ # sentence --> [ch_part, en_part, ch_part, ...]
+ segments = []
+ types = []
+ flag = 0
+ temp_seg = ""
+ temp_lang = ""
+
+ # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
+ for i, ch in enumerate(text):
+ if is_chinese(ch):
+ types.append("zh")
+ elif is_alphabet(ch):
+ types.append("en")
+ else:
+ types.append("other")
+
+ assert len(types) == len(text)
+
+ for i in range(len(types)):
+ # find the first char of the seg
+ if flag == 0:
+ temp_seg += text[i]
+ temp_lang = types[i]
+ flag = 1
+ else:
+ if temp_lang == "other":
+ if types[i] == temp_lang:
+ temp_seg += text[i]
+ else:
+ temp_seg += text[i]
+ temp_lang = types[i]
+ else:
+ if types[i] == temp_lang:
+ temp_seg += text[i]
+ elif types[i] == "other":
+ temp_seg += text[i]
+ else:
+ segments.append((temp_seg, temp_lang))
+ temp_seg = text[i]
+ temp_lang = types[i]
+ flag = 1
+
+ segments.append((temp_seg, temp_lang))
+ return segments
+
+
+def chn_eng_g2p(text: str):
+ # now only en and ch
+ segments = get_segment(text)
+ all_phoneme = ""
+ all_tokens = []
+
+ for index in range(len(segments)):
+ seg = segments[index]
+ phoneme, token = g2p(seg[0], text, seg[1])
+ all_phoneme += phoneme + "|"
+ all_tokens += token
+
+ if seg[1] == "en" and index == len(segments) - 1 and all_phoneme[-2] == "_":
+ all_phoneme = all_phoneme[:-2]
+ all_tokens = all_tokens[:-1]
+ return all_phoneme, all_tokens
+
+
+text_tokenizer = PhonemeBpeTokenizer()
+with open("./models/tts/maskgct/g2p/g2p/vocab.json", "r") as f:
+ json_data = f.read()
+data = json.loads(json_data)
+vocab = data["vocab"]
diff --git a/models/tts/maskgct/g2p/sources/bpmf_2_pinyin.txt b/models/tts/maskgct/g2p/sources/bpmf_2_pinyin.txt
new file mode 100644
index 0000000000000000000000000000000000000000..474529e5d347b94a80e5052de0065347ff14b95e
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/bpmf_2_pinyin.txt
@@ -0,0 +1,41 @@
+b ㄅ
+p ㄆ
+m ㄇ
+f ㄈ
+d ㄉ
+t ㄊ
+n ㄋ
+l ㄌ
+g ㄍ
+k ㄎ
+h ㄏ
+j ㄐ
+q ㄑ
+x ㄒ
+zh ㄓ
+ch ㄔ
+sh ㄕ
+r ㄖ
+z ㄗ
+c ㄘ
+s ㄙ
+i ㄧ
+u ㄨ
+v ㄩ
+a ㄚ
+o ㄛ
+e ㄜ
+e ㄝ
+ai ㄞ
+ei ㄟ
+ao ㄠ
+ou ㄡ
+an ㄢ
+en ㄣ
+ang ㄤ
+eng ㄥ
+er ㄦ
+2 ˊ
+3 ˇ
+4 ˋ
+0 ˙
diff --git a/models/tts/maskgct/g2p/sources/chinese_lexicon.txt b/models/tts/maskgct/g2p/sources/chinese_lexicon.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4d7dbf347a29d3b87c199d0e56ef7f1dbf28a6ee
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/chinese_lexicon.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3a7685d1c3e68eb2fa304bfc63e90c90c3c1a1948839a5b1b507b2131b3e2fb
+size 14779443
diff --git a/models/tts/maskgct/g2p/sources/g2p_chinese_model/config.json b/models/tts/maskgct/g2p/sources/g2p_chinese_model/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5fb70ca91db27a4ad73b58a0c500a903be9bc1a9
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/g2p_chinese_model/config.json
@@ -0,0 +1,819 @@
+{
+ "_name_or_path": "/BERT-POLY-v2/pretrained_models/mini_bert",
+ "architectures": [
+ "BertPoly"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "directionality": "bidi",
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 384,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1",
+ "2": "LABEL_2",
+ "3": "LABEL_3",
+ "4": "LABEL_4",
+ "5": "LABEL_5",
+ "6": "LABEL_6",
+ "7": "LABEL_7",
+ "8": "LABEL_8",
+ "9": "LABEL_9",
+ "10": "LABEL_10",
+ "11": "LABEL_11",
+ "12": "LABEL_12",
+ "13": "LABEL_13",
+ "14": "LABEL_14",
+ "15": "LABEL_15",
+ "16": "LABEL_16",
+ "17": "LABEL_17",
+ "18": "LABEL_18",
+ "19": "LABEL_19",
+ "20": "LABEL_20",
+ "21": "LABEL_21",
+ "22": "LABEL_22",
+ "23": "LABEL_23",
+ "24": "LABEL_24",
+ "25": "LABEL_25",
+ "26": "LABEL_26",
+ "27": "LABEL_27",
+ "28": "LABEL_28",
+ "29": "LABEL_29",
+ "30": "LABEL_30",
+ "31": "LABEL_31",
+ "32": "LABEL_32",
+ "33": "LABEL_33",
+ "34": "LABEL_34",
+ "35": "LABEL_35",
+ "36": "LABEL_36",
+ "37": "LABEL_37",
+ "38": "LABEL_38",
+ "39": "LABEL_39",
+ "40": "LABEL_40",
+ "41": "LABEL_41",
+ "42": "LABEL_42",
+ "43": "LABEL_43",
+ "44": "LABEL_44",
+ "45": "LABEL_45",
+ "46": "LABEL_46",
+ "47": "LABEL_47",
+ "48": "LABEL_48",
+ "49": "LABEL_49",
+ "50": "LABEL_50",
+ "51": "LABEL_51",
+ "52": "LABEL_52",
+ "53": "LABEL_53",
+ "54": "LABEL_54",
+ "55": "LABEL_55",
+ "56": "LABEL_56",
+ "57": "LABEL_57",
+ "58": "LABEL_58",
+ "59": "LABEL_59",
+ "60": "LABEL_60",
+ "61": "LABEL_61",
+ "62": "LABEL_62",
+ "63": "LABEL_63",
+ "64": "LABEL_64",
+ "65": "LABEL_65",
+ "66": "LABEL_66",
+ "67": "LABEL_67",
+ "68": "LABEL_68",
+ "69": "LABEL_69",
+ "70": "LABEL_70",
+ "71": "LABEL_71",
+ "72": "LABEL_72",
+ "73": "LABEL_73",
+ "74": "LABEL_74",
+ "75": "LABEL_75",
+ "76": "LABEL_76",
+ "77": "LABEL_77",
+ "78": "LABEL_78",
+ "79": "LABEL_79",
+ "80": "LABEL_80",
+ "81": "LABEL_81",
+ "82": "LABEL_82",
+ "83": "LABEL_83",
+ "84": "LABEL_84",
+ "85": "LABEL_85",
+ "86": "LABEL_86",
+ "87": "LABEL_87",
+ "88": "LABEL_88",
+ "89": "LABEL_89",
+ "90": "LABEL_90",
+ "91": "LABEL_91",
+ "92": "LABEL_92",
+ "93": "LABEL_93",
+ "94": "LABEL_94",
+ "95": "LABEL_95",
+ "96": "LABEL_96",
+ "97": "LABEL_97",
+ "98": "LABEL_98",
+ "99": "LABEL_99",
+ "100": "LABEL_100",
+ "101": "LABEL_101",
+ "102": "LABEL_102",
+ "103": "LABEL_103",
+ "104": "LABEL_104",
+ "105": "LABEL_105",
+ "106": "LABEL_106",
+ "107": "LABEL_107",
+ "108": "LABEL_108",
+ "109": "LABEL_109",
+ "110": "LABEL_110",
+ "111": "LABEL_111",
+ "112": "LABEL_112",
+ "113": "LABEL_113",
+ "114": "LABEL_114",
+ "115": "LABEL_115",
+ "116": "LABEL_116",
+ "117": "LABEL_117",
+ "118": "LABEL_118",
+ "119": "LABEL_119",
+ "120": "LABEL_120",
+ "121": "LABEL_121",
+ "122": "LABEL_122",
+ "123": "LABEL_123",
+ "124": "LABEL_124",
+ "125": "LABEL_125",
+ "126": "LABEL_126",
+ "127": "LABEL_127",
+ "128": "LABEL_128",
+ "129": "LABEL_129",
+ "130": "LABEL_130",
+ "131": "LABEL_131",
+ "132": "LABEL_132",
+ "133": "LABEL_133",
+ "134": "LABEL_134",
+ "135": "LABEL_135",
+ "136": "LABEL_136",
+ "137": "LABEL_137",
+ "138": "LABEL_138",
+ "139": "LABEL_139",
+ "140": "LABEL_140",
+ "141": "LABEL_141",
+ "142": "LABEL_142",
+ "143": "LABEL_143",
+ "144": "LABEL_144",
+ "145": "LABEL_145",
+ "146": "LABEL_146",
+ "147": "LABEL_147",
+ "148": "LABEL_148",
+ "149": "LABEL_149",
+ "150": "LABEL_150",
+ "151": "LABEL_151",
+ "152": "LABEL_152",
+ "153": "LABEL_153",
+ "154": "LABEL_154",
+ "155": "LABEL_155",
+ "156": "LABEL_156",
+ "157": "LABEL_157",
+ "158": "LABEL_158",
+ "159": "LABEL_159",
+ "160": "LABEL_160",
+ "161": "LABEL_161",
+ "162": "LABEL_162",
+ "163": "LABEL_163",
+ "164": "LABEL_164",
+ "165": "LABEL_165",
+ "166": "LABEL_166",
+ "167": "LABEL_167",
+ "168": "LABEL_168",
+ "169": "LABEL_169",
+ "170": "LABEL_170",
+ "171": "LABEL_171",
+ "172": "LABEL_172",
+ "173": "LABEL_173",
+ "174": "LABEL_174",
+ "175": "LABEL_175",
+ "176": "LABEL_176",
+ "177": "LABEL_177",
+ "178": "LABEL_178",
+ "179": "LABEL_179",
+ "180": "LABEL_180",
+ "181": "LABEL_181",
+ "182": "LABEL_182",
+ "183": "LABEL_183",
+ "184": "LABEL_184",
+ "185": "LABEL_185",
+ "186": "LABEL_186",
+ "187": "LABEL_187",
+ "188": "LABEL_188",
+ "189": "LABEL_189",
+ "190": "LABEL_190",
+ "191": "LABEL_191",
+ "192": "LABEL_192",
+ "193": "LABEL_193",
+ "194": "LABEL_194",
+ "195": "LABEL_195",
+ "196": "LABEL_196",
+ "197": "LABEL_197",
+ "198": "LABEL_198",
+ "199": "LABEL_199",
+ "200": "LABEL_200",
+ "201": "LABEL_201",
+ "202": "LABEL_202",
+ "203": "LABEL_203",
+ "204": "LABEL_204",
+ "205": "LABEL_205",
+ "206": "LABEL_206",
+ "207": "LABEL_207",
+ "208": "LABEL_208",
+ "209": "LABEL_209",
+ "210": "LABEL_210",
+ "211": "LABEL_211",
+ "212": "LABEL_212",
+ "213": "LABEL_213",
+ "214": "LABEL_214",
+ "215": "LABEL_215",
+ "216": "LABEL_216",
+ "217": "LABEL_217",
+ "218": "LABEL_218",
+ "219": "LABEL_219",
+ "220": "LABEL_220",
+ "221": "LABEL_221",
+ "222": "LABEL_222",
+ "223": "LABEL_223",
+ "224": "LABEL_224",
+ "225": "LABEL_225",
+ "226": "LABEL_226",
+ "227": "LABEL_227",
+ "228": "LABEL_228",
+ "229": "LABEL_229",
+ "230": "LABEL_230",
+ "231": "LABEL_231",
+ "232": "LABEL_232",
+ "233": "LABEL_233",
+ "234": "LABEL_234",
+ "235": "LABEL_235",
+ "236": "LABEL_236",
+ "237": "LABEL_237",
+ "238": "LABEL_238",
+ "239": "LABEL_239",
+ "240": "LABEL_240",
+ "241": "LABEL_241",
+ "242": "LABEL_242",
+ "243": "LABEL_243",
+ "244": "LABEL_244",
+ "245": "LABEL_245",
+ "246": "LABEL_246",
+ "247": "LABEL_247",
+ "248": "LABEL_248",
+ "249": "LABEL_249",
+ "250": "LABEL_250",
+ "251": "LABEL_251",
+ "252": "LABEL_252",
+ "253": "LABEL_253",
+ "254": "LABEL_254",
+ "255": "LABEL_255",
+ "256": "LABEL_256",
+ "257": "LABEL_257",
+ "258": "LABEL_258",
+ "259": "LABEL_259",
+ "260": "LABEL_260",
+ "261": "LABEL_261",
+ "262": "LABEL_262",
+ "263": "LABEL_263",
+ "264": "LABEL_264",
+ "265": "LABEL_265",
+ "266": "LABEL_266",
+ "267": "LABEL_267",
+ "268": "LABEL_268",
+ "269": "LABEL_269",
+ "270": "LABEL_270",
+ "271": "LABEL_271",
+ "272": "LABEL_272",
+ "273": "LABEL_273",
+ "274": "LABEL_274",
+ "275": "LABEL_275",
+ "276": "LABEL_276",
+ "277": "LABEL_277",
+ "278": "LABEL_278",
+ "279": "LABEL_279",
+ "280": "LABEL_280",
+ "281": "LABEL_281",
+ "282": "LABEL_282",
+ "283": "LABEL_283",
+ "284": "LABEL_284",
+ "285": "LABEL_285",
+ "286": "LABEL_286",
+ "287": "LABEL_287",
+ "288": "LABEL_288",
+ "289": "LABEL_289",
+ "290": "LABEL_290",
+ "291": "LABEL_291",
+ "292": "LABEL_292",
+ "293": "LABEL_293",
+ "294": "LABEL_294",
+ "295": "LABEL_295",
+ "296": "LABEL_296",
+ "297": "LABEL_297",
+ "298": "LABEL_298",
+ "299": "LABEL_299",
+ "300": "LABEL_300",
+ "301": "LABEL_301",
+ "302": "LABEL_302",
+ "303": "LABEL_303",
+ "304": "LABEL_304",
+ "305": "LABEL_305",
+ "306": "LABEL_306",
+ "307": "LABEL_307",
+ "308": "LABEL_308",
+ "309": "LABEL_309",
+ "310": "LABEL_310",
+ "311": "LABEL_311",
+ "312": "LABEL_312",
+ "313": "LABEL_313",
+ "314": "LABEL_314",
+ "315": "LABEL_315",
+ "316": "LABEL_316",
+ "317": "LABEL_317",
+ "318": "LABEL_318",
+ "319": "LABEL_319",
+ "320": "LABEL_320",
+ "321": "LABEL_321",
+ "322": "LABEL_322",
+ "323": "LABEL_323",
+ "324": "LABEL_324",
+ "325": "LABEL_325",
+ "326": "LABEL_326",
+ "327": "LABEL_327",
+ "328": "LABEL_328",
+ "329": "LABEL_329",
+ "330": "LABEL_330",
+ "331": "LABEL_331",
+ "332": "LABEL_332",
+ "333": "LABEL_333",
+ "334": "LABEL_334",
+ "335": "LABEL_335",
+ "336": "LABEL_336",
+ "337": "LABEL_337",
+ "338": "LABEL_338",
+ "339": "LABEL_339",
+ "340": "LABEL_340",
+ "341": "LABEL_341",
+ "342": "LABEL_342",
+ "343": "LABEL_343",
+ "344": "LABEL_344",
+ "345": "LABEL_345",
+ "346": "LABEL_346",
+ "347": "LABEL_347",
+ "348": "LABEL_348",
+ "349": "LABEL_349",
+ "350": "LABEL_350",
+ "351": "LABEL_351",
+ "352": "LABEL_352",
+ "353": "LABEL_353",
+ "354": "LABEL_354",
+ "355": "LABEL_355",
+ "356": "LABEL_356",
+ "357": "LABEL_357",
+ "358": "LABEL_358",
+ "359": "LABEL_359",
+ "360": "LABEL_360",
+ "361": "LABEL_361",
+ "362": "LABEL_362",
+ "363": "LABEL_363",
+ "364": "LABEL_364",
+ "365": "LABEL_365",
+ "366": "LABEL_366",
+ "367": "LABEL_367",
+ "368": "LABEL_368",
+ "369": "LABEL_369",
+ "370": "LABEL_370",
+ "371": "LABEL_371",
+ "372": "LABEL_372",
+ "373": "LABEL_373",
+ "374": "LABEL_374",
+ "375": "LABEL_375",
+ "376": "LABEL_376",
+ "377": "LABEL_377",
+ "378": "LABEL_378",
+ "379": "LABEL_379",
+ "380": "LABEL_380",
+ "381": "LABEL_381",
+ "382": "LABEL_382",
+ "383": "LABEL_383",
+ "384": "LABEL_384",
+ "385": "LABEL_385",
+ "386": "LABEL_386",
+ "387": "LABEL_387",
+ "388": "LABEL_388",
+ "389": "LABEL_389",
+ "390": "LABEL_390"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 1536,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1,
+ "LABEL_10": 10,
+ "LABEL_100": 100,
+ "LABEL_101": 101,
+ "LABEL_102": 102,
+ "LABEL_103": 103,
+ "LABEL_104": 104,
+ "LABEL_105": 105,
+ "LABEL_106": 106,
+ "LABEL_107": 107,
+ "LABEL_108": 108,
+ "LABEL_109": 109,
+ "LABEL_11": 11,
+ "LABEL_110": 110,
+ "LABEL_111": 111,
+ "LABEL_112": 112,
+ "LABEL_113": 113,
+ "LABEL_114": 114,
+ "LABEL_115": 115,
+ "LABEL_116": 116,
+ "LABEL_117": 117,
+ "LABEL_118": 118,
+ "LABEL_119": 119,
+ "LABEL_12": 12,
+ "LABEL_120": 120,
+ "LABEL_121": 121,
+ "LABEL_122": 122,
+ "LABEL_123": 123,
+ "LABEL_124": 124,
+ "LABEL_125": 125,
+ "LABEL_126": 126,
+ "LABEL_127": 127,
+ "LABEL_128": 128,
+ "LABEL_129": 129,
+ "LABEL_13": 13,
+ "LABEL_130": 130,
+ "LABEL_131": 131,
+ "LABEL_132": 132,
+ "LABEL_133": 133,
+ "LABEL_134": 134,
+ "LABEL_135": 135,
+ "LABEL_136": 136,
+ "LABEL_137": 137,
+ "LABEL_138": 138,
+ "LABEL_139": 139,
+ "LABEL_14": 14,
+ "LABEL_140": 140,
+ "LABEL_141": 141,
+ "LABEL_142": 142,
+ "LABEL_143": 143,
+ "LABEL_144": 144,
+ "LABEL_145": 145,
+ "LABEL_146": 146,
+ "LABEL_147": 147,
+ "LABEL_148": 148,
+ "LABEL_149": 149,
+ "LABEL_15": 15,
+ "LABEL_150": 150,
+ "LABEL_151": 151,
+ "LABEL_152": 152,
+ "LABEL_153": 153,
+ "LABEL_154": 154,
+ "LABEL_155": 155,
+ "LABEL_156": 156,
+ "LABEL_157": 157,
+ "LABEL_158": 158,
+ "LABEL_159": 159,
+ "LABEL_16": 16,
+ "LABEL_160": 160,
+ "LABEL_161": 161,
+ "LABEL_162": 162,
+ "LABEL_163": 163,
+ "LABEL_164": 164,
+ "LABEL_165": 165,
+ "LABEL_166": 166,
+ "LABEL_167": 167,
+ "LABEL_168": 168,
+ "LABEL_169": 169,
+ "LABEL_17": 17,
+ "LABEL_170": 170,
+ "LABEL_171": 171,
+ "LABEL_172": 172,
+ "LABEL_173": 173,
+ "LABEL_174": 174,
+ "LABEL_175": 175,
+ "LABEL_176": 176,
+ "LABEL_177": 177,
+ "LABEL_178": 178,
+ "LABEL_179": 179,
+ "LABEL_18": 18,
+ "LABEL_180": 180,
+ "LABEL_181": 181,
+ "LABEL_182": 182,
+ "LABEL_183": 183,
+ "LABEL_184": 184,
+ "LABEL_185": 185,
+ "LABEL_186": 186,
+ "LABEL_187": 187,
+ "LABEL_188": 188,
+ "LABEL_189": 189,
+ "LABEL_19": 19,
+ "LABEL_190": 190,
+ "LABEL_191": 191,
+ "LABEL_192": 192,
+ "LABEL_193": 193,
+ "LABEL_194": 194,
+ "LABEL_195": 195,
+ "LABEL_196": 196,
+ "LABEL_197": 197,
+ "LABEL_198": 198,
+ "LABEL_199": 199,
+ "LABEL_2": 2,
+ "LABEL_20": 20,
+ "LABEL_200": 200,
+ "LABEL_201": 201,
+ "LABEL_202": 202,
+ "LABEL_203": 203,
+ "LABEL_204": 204,
+ "LABEL_205": 205,
+ "LABEL_206": 206,
+ "LABEL_207": 207,
+ "LABEL_208": 208,
+ "LABEL_209": 209,
+ "LABEL_21": 21,
+ "LABEL_210": 210,
+ "LABEL_211": 211,
+ "LABEL_212": 212,
+ "LABEL_213": 213,
+ "LABEL_214": 214,
+ "LABEL_215": 215,
+ "LABEL_216": 216,
+ "LABEL_217": 217,
+ "LABEL_218": 218,
+ "LABEL_219": 219,
+ "LABEL_22": 22,
+ "LABEL_220": 220,
+ "LABEL_221": 221,
+ "LABEL_222": 222,
+ "LABEL_223": 223,
+ "LABEL_224": 224,
+ "LABEL_225": 225,
+ "LABEL_226": 226,
+ "LABEL_227": 227,
+ "LABEL_228": 228,
+ "LABEL_229": 229,
+ "LABEL_23": 23,
+ "LABEL_230": 230,
+ "LABEL_231": 231,
+ "LABEL_232": 232,
+ "LABEL_233": 233,
+ "LABEL_234": 234,
+ "LABEL_235": 235,
+ "LABEL_236": 236,
+ "LABEL_237": 237,
+ "LABEL_238": 238,
+ "LABEL_239": 239,
+ "LABEL_24": 24,
+ "LABEL_240": 240,
+ "LABEL_241": 241,
+ "LABEL_242": 242,
+ "LABEL_243": 243,
+ "LABEL_244": 244,
+ "LABEL_245": 245,
+ "LABEL_246": 246,
+ "LABEL_247": 247,
+ "LABEL_248": 248,
+ "LABEL_249": 249,
+ "LABEL_25": 25,
+ "LABEL_250": 250,
+ "LABEL_251": 251,
+ "LABEL_252": 252,
+ "LABEL_253": 253,
+ "LABEL_254": 254,
+ "LABEL_255": 255,
+ "LABEL_256": 256,
+ "LABEL_257": 257,
+ "LABEL_258": 258,
+ "LABEL_259": 259,
+ "LABEL_26": 26,
+ "LABEL_260": 260,
+ "LABEL_261": 261,
+ "LABEL_262": 262,
+ "LABEL_263": 263,
+ "LABEL_264": 264,
+ "LABEL_265": 265,
+ "LABEL_266": 266,
+ "LABEL_267": 267,
+ "LABEL_268": 268,
+ "LABEL_269": 269,
+ "LABEL_27": 27,
+ "LABEL_270": 270,
+ "LABEL_271": 271,
+ "LABEL_272": 272,
+ "LABEL_273": 273,
+ "LABEL_274": 274,
+ "LABEL_275": 275,
+ "LABEL_276": 276,
+ "LABEL_277": 277,
+ "LABEL_278": 278,
+ "LABEL_279": 279,
+ "LABEL_28": 28,
+ "LABEL_280": 280,
+ "LABEL_281": 281,
+ "LABEL_282": 282,
+ "LABEL_283": 283,
+ "LABEL_284": 284,
+ "LABEL_285": 285,
+ "LABEL_286": 286,
+ "LABEL_287": 287,
+ "LABEL_288": 288,
+ "LABEL_289": 289,
+ "LABEL_29": 29,
+ "LABEL_290": 290,
+ "LABEL_291": 291,
+ "LABEL_292": 292,
+ "LABEL_293": 293,
+ "LABEL_294": 294,
+ "LABEL_295": 295,
+ "LABEL_296": 296,
+ "LABEL_297": 297,
+ "LABEL_298": 298,
+ "LABEL_299": 299,
+ "LABEL_3": 3,
+ "LABEL_30": 30,
+ "LABEL_300": 300,
+ "LABEL_301": 301,
+ "LABEL_302": 302,
+ "LABEL_303": 303,
+ "LABEL_304": 304,
+ "LABEL_305": 305,
+ "LABEL_306": 306,
+ "LABEL_307": 307,
+ "LABEL_308": 308,
+ "LABEL_309": 309,
+ "LABEL_31": 31,
+ "LABEL_310": 310,
+ "LABEL_311": 311,
+ "LABEL_312": 312,
+ "LABEL_313": 313,
+ "LABEL_314": 314,
+ "LABEL_315": 315,
+ "LABEL_316": 316,
+ "LABEL_317": 317,
+ "LABEL_318": 318,
+ "LABEL_319": 319,
+ "LABEL_32": 32,
+ "LABEL_320": 320,
+ "LABEL_321": 321,
+ "LABEL_322": 322,
+ "LABEL_323": 323,
+ "LABEL_324": 324,
+ "LABEL_325": 325,
+ "LABEL_326": 326,
+ "LABEL_327": 327,
+ "LABEL_328": 328,
+ "LABEL_329": 329,
+ "LABEL_33": 33,
+ "LABEL_330": 330,
+ "LABEL_331": 331,
+ "LABEL_332": 332,
+ "LABEL_333": 333,
+ "LABEL_334": 334,
+ "LABEL_335": 335,
+ "LABEL_336": 336,
+ "LABEL_337": 337,
+ "LABEL_338": 338,
+ "LABEL_339": 339,
+ "LABEL_34": 34,
+ "LABEL_340": 340,
+ "LABEL_341": 341,
+ "LABEL_342": 342,
+ "LABEL_343": 343,
+ "LABEL_344": 344,
+ "LABEL_345": 345,
+ "LABEL_346": 346,
+ "LABEL_347": 347,
+ "LABEL_348": 348,
+ "LABEL_349": 349,
+ "LABEL_35": 35,
+ "LABEL_350": 350,
+ "LABEL_351": 351,
+ "LABEL_352": 352,
+ "LABEL_353": 353,
+ "LABEL_354": 354,
+ "LABEL_355": 355,
+ "LABEL_356": 356,
+ "LABEL_357": 357,
+ "LABEL_358": 358,
+ "LABEL_359": 359,
+ "LABEL_36": 36,
+ "LABEL_360": 360,
+ "LABEL_361": 361,
+ "LABEL_362": 362,
+ "LABEL_363": 363,
+ "LABEL_364": 364,
+ "LABEL_365": 365,
+ "LABEL_366": 366,
+ "LABEL_367": 367,
+ "LABEL_368": 368,
+ "LABEL_369": 369,
+ "LABEL_37": 37,
+ "LABEL_370": 370,
+ "LABEL_371": 371,
+ "LABEL_372": 372,
+ "LABEL_373": 373,
+ "LABEL_374": 374,
+ "LABEL_375": 375,
+ "LABEL_376": 376,
+ "LABEL_377": 377,
+ "LABEL_378": 378,
+ "LABEL_379": 379,
+ "LABEL_38": 38,
+ "LABEL_380": 380,
+ "LABEL_381": 381,
+ "LABEL_382": 382,
+ "LABEL_383": 383,
+ "LABEL_384": 384,
+ "LABEL_385": 385,
+ "LABEL_386": 386,
+ "LABEL_387": 387,
+ "LABEL_388": 388,
+ "LABEL_389": 389,
+ "LABEL_39": 39,
+ "LABEL_390": 390,
+ "LABEL_4": 4,
+ "LABEL_40": 40,
+ "LABEL_41": 41,
+ "LABEL_42": 42,
+ "LABEL_43": 43,
+ "LABEL_44": 44,
+ "LABEL_45": 45,
+ "LABEL_46": 46,
+ "LABEL_47": 47,
+ "LABEL_48": 48,
+ "LABEL_49": 49,
+ "LABEL_5": 5,
+ "LABEL_50": 50,
+ "LABEL_51": 51,
+ "LABEL_52": 52,
+ "LABEL_53": 53,
+ "LABEL_54": 54,
+ "LABEL_55": 55,
+ "LABEL_56": 56,
+ "LABEL_57": 57,
+ "LABEL_58": 58,
+ "LABEL_59": 59,
+ "LABEL_6": 6,
+ "LABEL_60": 60,
+ "LABEL_61": 61,
+ "LABEL_62": 62,
+ "LABEL_63": 63,
+ "LABEL_64": 64,
+ "LABEL_65": 65,
+ "LABEL_66": 66,
+ "LABEL_67": 67,
+ "LABEL_68": 68,
+ "LABEL_69": 69,
+ "LABEL_7": 7,
+ "LABEL_70": 70,
+ "LABEL_71": 71,
+ "LABEL_72": 72,
+ "LABEL_73": 73,
+ "LABEL_74": 74,
+ "LABEL_75": 75,
+ "LABEL_76": 76,
+ "LABEL_77": 77,
+ "LABEL_78": 78,
+ "LABEL_79": 79,
+ "LABEL_8": 8,
+ "LABEL_80": 80,
+ "LABEL_81": 81,
+ "LABEL_82": 82,
+ "LABEL_83": 83,
+ "LABEL_84": 84,
+ "LABEL_85": 85,
+ "LABEL_86": 86,
+ "LABEL_87": 87,
+ "LABEL_88": 88,
+ "LABEL_89": 89,
+ "LABEL_9": 9,
+ "LABEL_90": 90,
+ "LABEL_91": 91,
+ "LABEL_92": 92,
+ "LABEL_93": 93,
+ "LABEL_94": 94,
+ "LABEL_95": 95,
+ "LABEL_96": 96,
+ "LABEL_97": 97,
+ "LABEL_98": 98,
+ "LABEL_99": 99
+ },
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 6,
+ "num_relation_heads": 32,
+ "pad_token_id": 0,
+ "pooler_fc_size": 768,
+ "pooler_num_attention_heads": 12,
+ "pooler_num_fc_layers": 3,
+ "pooler_size_per_head": 128,
+ "pooler_type": "first_token_transform",
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.44.1",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 21128
+}
diff --git a/models/tts/maskgct/g2p/sources/g2p_chinese_model/poly_bert_model.onnx b/models/tts/maskgct/g2p/sources/g2p_chinese_model/poly_bert_model.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..6b952b9717eb71bb5a7aa2492478095f117858dd
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/g2p_chinese_model/poly_bert_model.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
+size 76583333
diff --git a/models/tts/maskgct/g2p/sources/g2p_chinese_model/polychar.txt b/models/tts/maskgct/g2p/sources/g2p_chinese_model/polychar.txt
new file mode 100644
index 0000000000000000000000000000000000000000..819f6249a661134128c7a4bc72a1059ebe133d20
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/g2p_chinese_model/polychar.txt
@@ -0,0 +1,159 @@
+丧
+中
+为
+乌
+乐
+了
+什
+仔
+令
+任
+会
+传
+佛
+供
+便
+倒
+假
+兴
+冠
+冲
+几
+分
+切
+划
+创
+剥
+勒
+区
+华
+单
+卜
+占
+卡
+卷
+厦
+参
+发
+只
+号
+同
+吐
+和
+喝
+圈
+地
+塞
+壳
+处
+奇
+奔
+好
+宁
+宿
+将
+少
+尽
+岗
+差
+巷
+帖
+干
+应
+度
+弹
+强
+当
+待
+得
+恶
+扁
+扇
+扎
+扫
+担
+挑
+据
+撒
+教
+散
+数
+斗
+晃
+曝
+曲
+更
+曾
+朝
+朴
+杆
+查
+校
+模
+横
+没
+泡
+济
+混
+漂
+炸
+熟
+燕
+片
+率
+畜
+的
+盛
+相
+省
+看
+着
+矫
+禁
+种
+称
+空
+答
+粘
+糊
+系
+累
+纤
+结
+给
+缝
+肖
+背
+脏
+舍
+色
+落
+蒙
+薄
+藏
+血
+行
+要
+观
+觉
+角
+解
+说
+调
+踏
+车
+转
+载
+还
+遂
+都
+重
+量
+钻
+铺
+长
+间
+降
+难
+露
+鲜
diff --git a/models/tts/maskgct/g2p/sources/g2p_chinese_model/polydict.json b/models/tts/maskgct/g2p/sources/g2p_chinese_model/polydict.json
new file mode 100644
index 0000000000000000000000000000000000000000..903fd018067b185c8cb8cd8a5b6cf07822512989
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/g2p_chinese_model/polydict.json
@@ -0,0 +1,393 @@
+{
+ "1": "丧{sang1}",
+ "2": "丧{sang4}",
+ "3": "中{zhong1}",
+ "4": "中{zhong4}",
+ "5": "为{wei2}",
+ "6": "为{wei4}",
+ "7": "乌{wu1}",
+ "8": "乌{wu4}",
+ "9": "乐{lao4}",
+ "10": "乐{le4}",
+ "11": "乐{le5}",
+ "12": "乐{yao4}",
+ "13": "乐{yve4}",
+ "14": "了{le5}",
+ "15": "了{liao3}",
+ "16": "了{liao5}",
+ "17": "什{shen2}",
+ "18": "什{shi2}",
+ "19": "仔{zai3}",
+ "20": "仔{zai5}",
+ "21": "仔{zi3}",
+ "22": "仔{zi5}",
+ "23": "令{ling2}",
+ "24": "令{ling4}",
+ "25": "任{ren2}",
+ "26": "任{ren4}",
+ "27": "会{hui4}",
+ "28": "会{hui5}",
+ "29": "会{kuai4}",
+ "30": "传{chuan2}",
+ "31": "传{zhuan4}",
+ "32": "佛{fo2}",
+ "33": "佛{fu2}",
+ "34": "供{gong1}",
+ "35": "供{gong4}",
+ "36": "便{bian4}",
+ "37": "便{pian2}",
+ "38": "倒{dao3}",
+ "39": "倒{dao4}",
+ "40": "假{jia3}",
+ "41": "假{jia4}",
+ "42": "兴{xing1}",
+ "43": "兴{xing4}",
+ "44": "冠{guan1}",
+ "45": "冠{guan4}",
+ "46": "冲{chong1}",
+ "47": "冲{chong4}",
+ "48": "几{ji1}",
+ "49": "几{ji2}",
+ "50": "几{ji3}",
+ "51": "分{fen1}",
+ "52": "分{fen4}",
+ "53": "分{fen5}",
+ "54": "切{qie1}",
+ "55": "切{qie4}",
+ "56": "划{hua2}",
+ "57": "划{hua4}",
+ "58": "划{hua5}",
+ "59": "创{chuang1}",
+ "60": "创{chuang4}",
+ "61": "剥{bao1}",
+ "62": "剥{bo1}",
+ "63": "勒{le4}",
+ "64": "勒{le5}",
+ "65": "勒{lei1}",
+ "66": "区{ou1}",
+ "67": "区{qu1}",
+ "68": "华{hua2}",
+ "69": "华{hua4}",
+ "70": "单{chan2}",
+ "71": "单{dan1}",
+ "72": "单{shan4}",
+ "73": "卜{bo5}",
+ "74": "卜{bu3}",
+ "75": "占{zhan1}",
+ "76": "占{zhan4}",
+ "77": "卡{ka2}",
+ "78": "卡{ka3}",
+ "79": "卡{qia3}",
+ "80": "卷{jvan3}",
+ "81": "卷{jvan4}",
+ "82": "厦{sha4}",
+ "83": "厦{xia4}",
+ "84": "参{can1}",
+ "85": "参{cen1}",
+ "86": "参{shen1}",
+ "87": "发{fa1}",
+ "88": "发{fa4}",
+ "89": "发{fa5}",
+ "90": "只{zhi1}",
+ "91": "只{zhi3}",
+ "92": "号{hao2}",
+ "93": "号{hao4}",
+ "94": "号{hao5}",
+ "95": "同{tong2}",
+ "96": "同{tong4}",
+ "97": "同{tong5}",
+ "98": "吐{tu2}",
+ "99": "吐{tu3}",
+ "100": "吐{tu4}",
+ "101": "和{he2}",
+ "102": "和{he4}",
+ "103": "和{he5}",
+ "104": "和{huo2}",
+ "105": "和{huo4}",
+ "106": "和{huo5}",
+ "107": "喝{he1}",
+ "108": "喝{he4}",
+ "109": "圈{jvan4}",
+ "110": "圈{qvan1}",
+ "111": "圈{qvan5}",
+ "112": "地{de5}",
+ "113": "地{di4}",
+ "114": "地{di5}",
+ "115": "塞{sai1}",
+ "116": "塞{sai2}",
+ "117": "塞{sai4}",
+ "118": "塞{se4}",
+ "119": "壳{ke2}",
+ "120": "壳{qiao4}",
+ "121": "处{chu3}",
+ "122": "处{chu4}",
+ "123": "奇{ji1}",
+ "124": "奇{qi2}",
+ "125": "奔{ben1}",
+ "126": "奔{ben4}",
+ "127": "好{hao3}",
+ "128": "好{hao4}",
+ "129": "好{hao5}",
+ "130": "宁{ning2}",
+ "131": "宁{ning4}",
+ "132": "宁{ning5}",
+ "133": "宿{su4}",
+ "134": "宿{xiu3}",
+ "135": "宿{xiu4}",
+ "136": "将{jiang1}",
+ "137": "将{jiang4}",
+ "138": "少{shao3}",
+ "139": "少{shao4}",
+ "140": "尽{jin3}",
+ "141": "尽{jin4}",
+ "142": "岗{gang1}",
+ "143": "岗{gang3}",
+ "144": "差{cha1}",
+ "145": "差{cha4}",
+ "146": "差{chai1}",
+ "147": "差{ci1}",
+ "148": "巷{hang4}",
+ "149": "巷{xiang4}",
+ "150": "帖{tie1}",
+ "151": "帖{tie3}",
+ "152": "帖{tie4}",
+ "153": "干{gan1}",
+ "154": "干{gan4}",
+ "155": "应{ying1}",
+ "156": "应{ying4}",
+ "157": "应{ying5}",
+ "158": "度{du4}",
+ "159": "度{du5}",
+ "160": "度{duo2}",
+ "161": "弹{dan4}",
+ "162": "弹{tan2}",
+ "163": "弹{tan5}",
+ "164": "强{jiang4}",
+ "165": "强{qiang2}",
+ "166": "强{qiang3}",
+ "167": "当{dang1}",
+ "168": "当{dang4}",
+ "169": "当{dang5}",
+ "170": "待{dai1}",
+ "171": "待{dai4}",
+ "172": "得{de2}",
+ "173": "得{de5}",
+ "174": "得{dei3}",
+ "175": "得{dei5}",
+ "176": "恶{e3}",
+ "177": "恶{e4}",
+ "178": "恶{wu4}",
+ "179": "扁{bian3}",
+ "180": "扁{pian1}",
+ "181": "扇{shan1}",
+ "182": "扇{shan4}",
+ "183": "扎{za1}",
+ "184": "扎{zha1}",
+ "185": "扎{zha2}",
+ "186": "扫{sao3}",
+ "187": "扫{sao4}",
+ "188": "担{dan1}",
+ "189": "担{dan4}",
+ "190": "担{dan5}",
+ "191": "挑{tiao1}",
+ "192": "挑{tiao3}",
+ "193": "据{jv1}",
+ "194": "据{jv4}",
+ "195": "撒{sa1}",
+ "196": "撒{sa3}",
+ "197": "撒{sa5}",
+ "198": "教{jiao1}",
+ "199": "教{jiao4}",
+ "200": "散{san3}",
+ "201": "散{san4}",
+ "202": "散{san5}",
+ "203": "数{shu3}",
+ "204": "数{shu4}",
+ "205": "数{shu5}",
+ "206": "斗{dou3}",
+ "207": "斗{dou4}",
+ "208": "晃{huang3}",
+ "209": "曝{bao4}",
+ "210": "曲{qu1}",
+ "211": "曲{qu3}",
+ "212": "更{geng1}",
+ "213": "更{geng4}",
+ "214": "曾{ceng1}",
+ "215": "曾{ceng2}",
+ "216": "曾{zeng1}",
+ "217": "朝{chao2}",
+ "218": "朝{zhao1}",
+ "219": "朴{piao2}",
+ "220": "朴{pu2}",
+ "221": "朴{pu3}",
+ "222": "杆{gan1}",
+ "223": "杆{gan3}",
+ "224": "查{cha2}",
+ "225": "查{zha1}",
+ "226": "校{jiao4}",
+ "227": "校{xiao4}",
+ "228": "模{mo2}",
+ "229": "模{mu2}",
+ "230": "横{heng2}",
+ "231": "横{heng4}",
+ "232": "没{mei2}",
+ "233": "没{mo4}",
+ "234": "泡{pao1}",
+ "235": "泡{pao4}",
+ "236": "泡{pao5}",
+ "237": "济{ji3}",
+ "238": "济{ji4}",
+ "239": "混{hun2}",
+ "240": "混{hun3}",
+ "241": "混{hun4}",
+ "242": "混{hun5}",
+ "243": "漂{piao1}",
+ "244": "漂{piao3}",
+ "245": "漂{piao4}",
+ "246": "炸{zha2}",
+ "247": "炸{zha4}",
+ "248": "熟{shou2}",
+ "249": "熟{shu2}",
+ "250": "燕{yan1}",
+ "251": "燕{yan4}",
+ "252": "片{pian1}",
+ "253": "片{pian4}",
+ "254": "率{lv4}",
+ "255": "率{shuai4}",
+ "256": "畜{chu4}",
+ "257": "畜{xu4}",
+ "258": "的{de5}",
+ "259": "的{di1}",
+ "260": "的{di2}",
+ "261": "的{di4}",
+ "262": "的{di5}",
+ "263": "盛{cheng2}",
+ "264": "盛{sheng4}",
+ "265": "相{xiang1}",
+ "266": "相{xiang4}",
+ "267": "相{xiang5}",
+ "268": "省{sheng3}",
+ "269": "省{xing3}",
+ "270": "看{kan1}",
+ "271": "看{kan4}",
+ "272": "看{kan5}",
+ "273": "着{zhao1}",
+ "274": "着{zhao2}",
+ "275": "着{zhao5}",
+ "276": "着{zhe5}",
+ "277": "着{zhuo2}",
+ "278": "着{zhuo5}",
+ "279": "矫{jiao3}",
+ "280": "禁{jin1}",
+ "281": "禁{jin4}",
+ "282": "种{zhong3}",
+ "283": "种{zhong4}",
+ "284": "称{chen4}",
+ "285": "称{cheng1}",
+ "286": "空{kong1}",
+ "287": "空{kong4}",
+ "288": "答{da1}",
+ "289": "答{da2}",
+ "290": "粘{nian2}",
+ "291": "粘{zhan1}",
+ "292": "糊{hu2}",
+ "293": "糊{hu5}",
+ "294": "系{ji4}",
+ "295": "系{xi4}",
+ "296": "系{xi5}",
+ "297": "累{lei2}",
+ "298": "累{lei3}",
+ "299": "累{lei4}",
+ "300": "累{lei5}",
+ "301": "纤{qian4}",
+ "302": "纤{xian1}",
+ "303": "结{jie1}",
+ "304": "结{jie2}",
+ "305": "结{jie5}",
+ "306": "给{gei3}",
+ "307": "给{gei5}",
+ "308": "给{ji3}",
+ "309": "缝{feng2}",
+ "310": "缝{feng4}",
+ "311": "缝{feng5}",
+ "312": "肖{xiao1}",
+ "313": "肖{xiao4}",
+ "314": "背{bei1}",
+ "315": "背{bei4}",
+ "316": "脏{zang1}",
+ "317": "脏{zang4}",
+ "318": "舍{she3}",
+ "319": "舍{she4}",
+ "320": "色{se4}",
+ "321": "色{shai3}",
+ "322": "落{lao4}",
+ "323": "落{luo4}",
+ "324": "蒙{meng1}",
+ "325": "蒙{meng2}",
+ "326": "蒙{meng3}",
+ "327": "薄{bao2}",
+ "328": "薄{bo2}",
+ "329": "薄{bo4}",
+ "330": "藏{cang2}",
+ "331": "藏{zang4}",
+ "332": "血{xie3}",
+ "333": "血{xue4}",
+ "334": "行{hang2}",
+ "335": "行{hang5}",
+ "336": "行{heng5}",
+ "337": "行{xing2}",
+ "338": "行{xing4}",
+ "339": "要{yao1}",
+ "340": "要{yao4}",
+ "341": "观{guan1}",
+ "342": "观{guan4}",
+ "343": "觉{jiao4}",
+ "344": "觉{jiao5}",
+ "345": "觉{jve2}",
+ "346": "角{jiao3}",
+ "347": "角{jve2}",
+ "348": "解{jie3}",
+ "349": "解{jie4}",
+ "350": "解{xie4}",
+ "351": "说{shui4}",
+ "352": "说{shuo1}",
+ "353": "调{diao4}",
+ "354": "调{tiao2}",
+ "355": "踏{ta1}",
+ "356": "踏{ta4}",
+ "357": "车{che1}",
+ "358": "车{jv1}",
+ "359": "转{zhuan3}",
+ "360": "转{zhuan4}",
+ "361": "载{zai3}",
+ "362": "载{zai4}",
+ "363": "还{hai2}",
+ "364": "还{huan2}",
+ "365": "遂{sui2}",
+ "366": "遂{sui4}",
+ "367": "都{dou1}",
+ "368": "都{du1}",
+ "369": "重{chong2}",
+ "370": "重{zhong4}",
+ "371": "量{liang2}",
+ "372": "量{liang4}",
+ "373": "量{liang5}",
+ "374": "钻{zuan1}",
+ "375": "钻{zuan4}",
+ "376": "铺{pu1}",
+ "377": "铺{pu4}",
+ "378": "长{chang2}",
+ "379": "长{chang3}",
+ "380": "长{zhang3}",
+ "381": "间{jian1}",
+ "382": "间{jian4}",
+ "383": "降{jiang4}",
+ "384": "降{xiang2}",
+ "385": "难{nan2}",
+ "386": "难{nan4}",
+ "387": "难{nan5}",
+ "388": "露{lou4}",
+ "389": "露{lu4}",
+ "390": "鲜{xian1}",
+ "391": "鲜{xian3}"
+}
\ No newline at end of file
diff --git a/models/tts/maskgct/g2p/sources/g2p_chinese_model/polydict_r.json b/models/tts/maskgct/g2p/sources/g2p_chinese_model/polydict_r.json
new file mode 100644
index 0000000000000000000000000000000000000000..aabbe6257493eaee7d3f0b77f78f0cb006e89fb6
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/g2p_chinese_model/polydict_r.json
@@ -0,0 +1,393 @@
+{
+ "丧{sang1}": 1,
+ "丧{sang4}": 2,
+ "中{zhong1}": 3,
+ "中{zhong4}": 4,
+ "为{wei2}": 5,
+ "为{wei4}": 6,
+ "乌{wu1}": 7,
+ "乌{wu4}": 8,
+ "乐{lao4}": 9,
+ "乐{le4}": 10,
+ "乐{le5}": 11,
+ "乐{yao4}": 12,
+ "乐{yve4}": 13,
+ "了{le5}": 14,
+ "了{liao3}": 15,
+ "了{liao5}": 16,
+ "什{shen2}": 17,
+ "什{shi2}": 18,
+ "仔{zai3}": 19,
+ "仔{zai5}": 20,
+ "仔{zi3}": 21,
+ "仔{zi5}": 22,
+ "令{ling2}": 23,
+ "令{ling4}": 24,
+ "任{ren2}": 25,
+ "任{ren4}": 26,
+ "会{hui4}": 27,
+ "会{hui5}": 28,
+ "会{kuai4}": 29,
+ "传{chuan2}": 30,
+ "传{zhuan4}": 31,
+ "佛{fo2}": 32,
+ "佛{fu2}": 33,
+ "供{gong1}": 34,
+ "供{gong4}": 35,
+ "便{bian4}": 36,
+ "便{pian2}": 37,
+ "倒{dao3}": 38,
+ "倒{dao4}": 39,
+ "假{jia3}": 40,
+ "假{jia4}": 41,
+ "兴{xing1}": 42,
+ "兴{xing4}": 43,
+ "冠{guan1}": 44,
+ "冠{guan4}": 45,
+ "冲{chong1}": 46,
+ "冲{chong4}": 47,
+ "几{ji1}": 48,
+ "几{ji2}": 49,
+ "几{ji3}": 50,
+ "分{fen1}": 51,
+ "分{fen4}": 52,
+ "分{fen5}": 53,
+ "切{qie1}": 54,
+ "切{qie4}": 55,
+ "划{hua2}": 56,
+ "划{hua4}": 57,
+ "划{hua5}": 58,
+ "创{chuang1}": 59,
+ "创{chuang4}": 60,
+ "剥{bao1}": 61,
+ "剥{bo1}": 62,
+ "勒{le4}": 63,
+ "勒{le5}": 64,
+ "勒{lei1}": 65,
+ "区{ou1}": 66,
+ "区{qu1}": 67,
+ "华{hua2}": 68,
+ "华{hua4}": 69,
+ "单{chan2}": 70,
+ "单{dan1}": 71,
+ "单{shan4}": 72,
+ "卜{bo5}": 73,
+ "卜{bu3}": 74,
+ "占{zhan1}": 75,
+ "占{zhan4}": 76,
+ "卡{ka2}": 77,
+ "卡{ka3}": 78,
+ "卡{qia3}": 79,
+ "卷{jvan3}": 80,
+ "卷{jvan4}": 81,
+ "厦{sha4}": 82,
+ "厦{xia4}": 83,
+ "参{can1}": 84,
+ "参{cen1}": 85,
+ "参{shen1}": 86,
+ "发{fa1}": 87,
+ "发{fa4}": 88,
+ "发{fa5}": 89,
+ "只{zhi1}": 90,
+ "只{zhi3}": 91,
+ "号{hao2}": 92,
+ "号{hao4}": 93,
+ "号{hao5}": 94,
+ "同{tong2}": 95,
+ "同{tong4}": 96,
+ "同{tong5}": 97,
+ "吐{tu2}": 98,
+ "吐{tu3}": 99,
+ "吐{tu4}": 100,
+ "和{he2}": 101,
+ "和{he4}": 102,
+ "和{he5}": 103,
+ "和{huo2}": 104,
+ "和{huo4}": 105,
+ "和{huo5}": 106,
+ "喝{he1}": 107,
+ "喝{he4}": 108,
+ "圈{jvan4}": 109,
+ "圈{qvan1}": 110,
+ "圈{qvan5}": 111,
+ "地{de5}": 112,
+ "地{di4}": 113,
+ "地{di5}": 114,
+ "塞{sai1}": 115,
+ "塞{sai2}": 116,
+ "塞{sai4}": 117,
+ "塞{se4}": 118,
+ "壳{ke2}": 119,
+ "壳{qiao4}": 120,
+ "处{chu3}": 121,
+ "处{chu4}": 122,
+ "奇{ji1}": 123,
+ "奇{qi2}": 124,
+ "奔{ben1}": 125,
+ "奔{ben4}": 126,
+ "好{hao3}": 127,
+ "好{hao4}": 128,
+ "好{hao5}": 129,
+ "宁{ning2}": 130,
+ "宁{ning4}": 131,
+ "宁{ning5}": 132,
+ "宿{su4}": 133,
+ "宿{xiu3}": 134,
+ "宿{xiu4}": 135,
+ "将{jiang1}": 136,
+ "将{jiang4}": 137,
+ "少{shao3}": 138,
+ "少{shao4}": 139,
+ "尽{jin3}": 140,
+ "尽{jin4}": 141,
+ "岗{gang1}": 142,
+ "岗{gang3}": 143,
+ "差{cha1}": 144,
+ "差{cha4}": 145,
+ "差{chai1}": 146,
+ "差{ci1}": 147,
+ "巷{hang4}": 148,
+ "巷{xiang4}": 149,
+ "帖{tie1}": 150,
+ "帖{tie3}": 151,
+ "帖{tie4}": 152,
+ "干{gan1}": 153,
+ "干{gan4}": 154,
+ "应{ying1}": 155,
+ "应{ying4}": 156,
+ "应{ying5}": 157,
+ "度{du4}": 158,
+ "度{du5}": 159,
+ "度{duo2}": 160,
+ "弹{dan4}": 161,
+ "弹{tan2}": 162,
+ "弹{tan5}": 163,
+ "强{jiang4}": 164,
+ "强{qiang2}": 165,
+ "强{qiang3}": 166,
+ "当{dang1}": 167,
+ "当{dang4}": 168,
+ "当{dang5}": 169,
+ "待{dai1}": 170,
+ "待{dai4}": 171,
+ "得{de2}": 172,
+ "得{de5}": 173,
+ "得{dei3}": 174,
+ "得{dei5}": 175,
+ "恶{e3}": 176,
+ "恶{e4}": 177,
+ "恶{wu4}": 178,
+ "扁{bian3}": 179,
+ "扁{pian1}": 180,
+ "扇{shan1}": 181,
+ "扇{shan4}": 182,
+ "扎{za1}": 183,
+ "扎{zha1}": 184,
+ "扎{zha2}": 185,
+ "扫{sao3}": 186,
+ "扫{sao4}": 187,
+ "担{dan1}": 188,
+ "担{dan4}": 189,
+ "担{dan5}": 190,
+ "挑{tiao1}": 191,
+ "挑{tiao3}": 192,
+ "据{jv1}": 193,
+ "据{jv4}": 194,
+ "撒{sa1}": 195,
+ "撒{sa3}": 196,
+ "撒{sa5}": 197,
+ "教{jiao1}": 198,
+ "教{jiao4}": 199,
+ "散{san3}": 200,
+ "散{san4}": 201,
+ "散{san5}": 202,
+ "数{shu3}": 203,
+ "数{shu4}": 204,
+ "数{shu5}": 205,
+ "斗{dou3}": 206,
+ "斗{dou4}": 207,
+ "晃{huang3}": 208,
+ "曝{bao4}": 209,
+ "曲{qu1}": 210,
+ "曲{qu3}": 211,
+ "更{geng1}": 212,
+ "更{geng4}": 213,
+ "曾{ceng1}": 214,
+ "曾{ceng2}": 215,
+ "曾{zeng1}": 216,
+ "朝{chao2}": 217,
+ "朝{zhao1}": 218,
+ "朴{piao2}": 219,
+ "朴{pu2}": 220,
+ "朴{pu3}": 221,
+ "杆{gan1}": 222,
+ "杆{gan3}": 223,
+ "查{cha2}": 224,
+ "查{zha1}": 225,
+ "校{jiao4}": 226,
+ "校{xiao4}": 227,
+ "模{mo2}": 228,
+ "模{mu2}": 229,
+ "横{heng2}": 230,
+ "横{heng4}": 231,
+ "没{mei2}": 232,
+ "没{mo4}": 233,
+ "泡{pao1}": 234,
+ "泡{pao4}": 235,
+ "泡{pao5}": 236,
+ "济{ji3}": 237,
+ "济{ji4}": 238,
+ "混{hun2}": 239,
+ "混{hun3}": 240,
+ "混{hun4}": 241,
+ "混{hun5}": 242,
+ "漂{piao1}": 243,
+ "漂{piao3}": 244,
+ "漂{piao4}": 245,
+ "炸{zha2}": 246,
+ "炸{zha4}": 247,
+ "熟{shou2}": 248,
+ "熟{shu2}": 249,
+ "燕{yan1}": 250,
+ "燕{yan4}": 251,
+ "片{pian1}": 252,
+ "片{pian4}": 253,
+ "率{lv4}": 254,
+ "率{shuai4}": 255,
+ "畜{chu4}": 256,
+ "畜{xu4}": 257,
+ "的{de5}": 258,
+ "的{di1}": 259,
+ "的{di2}": 260,
+ "的{di4}": 261,
+ "的{di5}": 262,
+ "盛{cheng2}": 263,
+ "盛{sheng4}": 264,
+ "相{xiang1}": 265,
+ "相{xiang4}": 266,
+ "相{xiang5}": 267,
+ "省{sheng3}": 268,
+ "省{xing3}": 269,
+ "看{kan1}": 270,
+ "看{kan4}": 271,
+ "看{kan5}": 272,
+ "着{zhao1}": 273,
+ "着{zhao2}": 274,
+ "着{zhao5}": 275,
+ "着{zhe5}": 276,
+ "着{zhuo2}": 277,
+ "着{zhuo5}": 278,
+ "矫{jiao3}": 279,
+ "禁{jin1}": 280,
+ "禁{jin4}": 281,
+ "种{zhong3}": 282,
+ "种{zhong4}": 283,
+ "称{chen4}": 284,
+ "称{cheng1}": 285,
+ "空{kong1}": 286,
+ "空{kong4}": 287,
+ "答{da1}": 288,
+ "答{da2}": 289,
+ "粘{nian2}": 290,
+ "粘{zhan1}": 291,
+ "糊{hu2}": 292,
+ "糊{hu5}": 293,
+ "系{ji4}": 294,
+ "系{xi4}": 295,
+ "系{xi5}": 296,
+ "累{lei2}": 297,
+ "累{lei3}": 298,
+ "累{lei4}": 299,
+ "累{lei5}": 300,
+ "纤{qian4}": 301,
+ "纤{xian1}": 302,
+ "结{jie1}": 303,
+ "结{jie2}": 304,
+ "结{jie5}": 305,
+ "给{gei3}": 306,
+ "给{gei5}": 307,
+ "给{ji3}": 308,
+ "缝{feng2}": 309,
+ "缝{feng4}": 310,
+ "缝{feng5}": 311,
+ "肖{xiao1}": 312,
+ "肖{xiao4}": 313,
+ "背{bei1}": 314,
+ "背{bei4}": 315,
+ "脏{zang1}": 316,
+ "脏{zang4}": 317,
+ "舍{she3}": 318,
+ "舍{she4}": 319,
+ "色{se4}": 320,
+ "色{shai3}": 321,
+ "落{lao4}": 322,
+ "落{luo4}": 323,
+ "蒙{meng1}": 324,
+ "蒙{meng2}": 325,
+ "蒙{meng3}": 326,
+ "薄{bao2}": 327,
+ "薄{bo2}": 328,
+ "薄{bo4}": 329,
+ "藏{cang2}": 330,
+ "藏{zang4}": 331,
+ "血{xie3}": 332,
+ "血{xue4}": 333,
+ "行{hang2}": 334,
+ "行{hang5}": 335,
+ "行{heng5}": 336,
+ "行{xing2}": 337,
+ "行{xing4}": 338,
+ "要{yao1}": 339,
+ "要{yao4}": 340,
+ "观{guan1}": 341,
+ "观{guan4}": 342,
+ "觉{jiao4}": 343,
+ "觉{jiao5}": 344,
+ "觉{jve2}": 345,
+ "角{jiao3}": 346,
+ "角{jve2}": 347,
+ "解{jie3}": 348,
+ "解{jie4}": 349,
+ "解{xie4}": 350,
+ "说{shui4}": 351,
+ "说{shuo1}": 352,
+ "调{diao4}": 353,
+ "调{tiao2}": 354,
+ "踏{ta1}": 355,
+ "踏{ta4}": 356,
+ "车{che1}": 357,
+ "车{jv1}": 358,
+ "转{zhuan3}": 359,
+ "转{zhuan4}": 360,
+ "载{zai3}": 361,
+ "载{zai4}": 362,
+ "还{hai2}": 363,
+ "还{huan2}": 364,
+ "遂{sui2}": 365,
+ "遂{sui4}": 366,
+ "都{dou1}": 367,
+ "都{du1}": 368,
+ "重{chong2}": 369,
+ "重{zhong4}": 370,
+ "量{liang2}": 371,
+ "量{liang4}": 372,
+ "量{liang5}": 373,
+ "钻{zuan1}": 374,
+ "钻{zuan4}": 375,
+ "铺{pu1}": 376,
+ "铺{pu4}": 377,
+ "长{chang2}": 378,
+ "长{chang3}": 379,
+ "长{zhang3}": 380,
+ "间{jian1}": 381,
+ "间{jian4}": 382,
+ "降{jiang4}": 383,
+ "降{xiang2}": 384,
+ "难{nan2}": 385,
+ "难{nan4}": 386,
+ "难{nan5}": 387,
+ "露{lou4}": 388,
+ "露{lu4}": 389,
+ "鲜{xian1}": 390,
+ "鲜{xian3}": 391
+}
\ No newline at end of file
diff --git a/models/tts/maskgct/g2p/sources/g2p_chinese_model/vocab.txt b/models/tts/maskgct/g2p/sources/g2p_chinese_model/vocab.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca4f9781030019ab9b253c6dcb8c7878b6dc87a5
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/g2p_chinese_model/vocab.txt
@@ -0,0 +1,21128 @@
+[PAD]
+[unused1]
+[unused2]
+[unused3]
+[unused4]
+[unused5]
+[unused6]
+[unused7]
+[unused8]
+[unused9]
+[unused10]
+[unused11]
+[unused12]
+[unused13]
+[unused14]
+[unused15]
+[unused16]
+[unused17]
+[unused18]
+[unused19]
+[unused20]
+[unused21]
+[unused22]
+[unused23]
+[unused24]
+[unused25]
+[unused26]
+[unused27]
+[unused28]
+[unused29]
+[unused30]
+[unused31]
+[unused32]
+[unused33]
+[unused34]
+[unused35]
+[unused36]
+[unused37]
+[unused38]
+[unused39]
+[unused40]
+[unused41]
+[unused42]
+[unused43]
+[unused44]
+[unused45]
+[unused46]
+[unused47]
+[unused48]
+[unused49]
+[unused50]
+[unused51]
+[unused52]
+[unused53]
+[unused54]
+[unused55]
+[unused56]
+[unused57]
+[unused58]
+[unused59]
+[unused60]
+[unused61]
+[unused62]
+[unused63]
+[unused64]
+[unused65]
+[unused66]
+[unused67]
+[unused68]
+[unused69]
+[unused70]
+[unused71]
+[unused72]
+[unused73]
+[unused74]
+[unused75]
+[unused76]
+[unused77]
+[unused78]
+[unused79]
+[unused80]
+[unused81]
+[unused82]
+[unused83]
+[unused84]
+[unused85]
+[unused86]
+[unused87]
+[unused88]
+[unused89]
+[unused90]
+[unused91]
+[unused92]
+[unused93]
+[unused94]
+[unused95]
+[unused96]
+[unused97]
+[unused98]
+[unused99]
+[UNK]
+[CLS]
+[SEP]
+[MASK]
+
+
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+£
+¤
+¥
+§
+©
+«
+®
+°
+±
+²
+³
+µ
+·
+¹
+º
+»
+¼
+×
+ß
+æ
+÷
+ø
+đ
+ŋ
+ɔ
+ə
+ɡ
+ʰ
+ˇ
+ˈ
+ˊ
+ˋ
+ˍ
+ː
+˙
+˚
+ˢ
+α
+β
+γ
+δ
+ε
+η
+θ
+ι
+κ
+λ
+μ
+ν
+ο
+π
+ρ
+ς
+σ
+τ
+υ
+φ
+χ
+ψ
+ω
+а
+б
+в
+г
+д
+е
+ж
+з
+и
+к
+л
+м
+н
+о
+п
+р
+с
+т
+у
+ф
+х
+ц
+ч
+ш
+ы
+ь
+я
+і
+ا
+ب
+ة
+ت
+د
+ر
+س
+ع
+ل
+م
+ن
+ه
+و
+ي
+۩
+ก
+ง
+น
+ม
+ย
+ร
+อ
+า
+เ
+๑
+་
+ღ
+ᄀ
+ᄁ
+ᄂ
+ᄃ
+ᄅ
+ᄆ
+ᄇ
+ᄈ
+ᄉ
+ᄋ
+ᄌ
+ᄎ
+ᄏ
+ᄐ
+ᄑ
+ᄒ
+ᅡ
+ᅢ
+ᅣ
+ᅥ
+ᅦ
+ᅧ
+ᅨ
+ᅩ
+ᅪ
+ᅬ
+ᅭ
+ᅮ
+ᅯ
+ᅲ
+ᅳ
+ᅴ
+ᅵ
+ᆨ
+ᆫ
+ᆯ
+ᆷ
+ᆸ
+ᆺ
+ᆻ
+ᆼ
+ᗜ
+ᵃ
+ᵉ
+ᵍ
+ᵏ
+ᵐ
+ᵒ
+ᵘ
+‖
+„
+†
+•
+‥
+‧
+
+‰
+′
+″
+‹
+›
+※
+‿
+⁄
+ⁱ
+⁺
+ⁿ
+₁
+₂
+₃
+₄
+€
+℃
+№
+™
+ⅰ
+ⅱ
+ⅲ
+ⅳ
+ⅴ
+←
+↑
+→
+↓
+↔
+↗
+↘
+⇒
+∀
+−
+∕
+∙
+√
+∞
+∟
+∠
+∣
+∥
+∩
+∮
+∶
+∼
+∽
+≈
+≒
+≡
+≤
+≥
+≦
+≧
+≪
+≫
+⊙
+⋅
+⋈
+⋯
+⌒
+①
+②
+③
+④
+⑤
+⑥
+⑦
+⑧
+⑨
+⑩
+⑴
+⑵
+⑶
+⑷
+⑸
+⒈
+⒉
+⒊
+⒋
+ⓒ
+ⓔ
+ⓘ
+─
+━
+│
+┃
+┅
+┆
+┊
+┌
+└
+├
+┣
+═
+║
+╚
+╞
+╠
+╭
+╮
+╯
+╰
+╱
+╳
+▂
+▃
+▅
+▇
+█
+▉
+▋
+▌
+▍
+▎
+■
+□
+▪
+▫
+▬
+▲
+△
+▶
+►
+▼
+▽
+◆
+◇
+○
+◎
+●
+◕
+◠
+◢
+◤
+☀
+★
+☆
+☕
+☞
+☺
+☼
+♀
+♂
+♠
+♡
+♣
+♥
+♦
+♪
+♫
+♬
+✈
+✔
+✕
+✖
+✦
+✨
+✪
+✰
+✿
+❀
+❤
+➜
+➤
+⦿
+、
+。
+〃
+々
+〇
+〈
+〉
+《
+》
+「
+」
+『
+』
+【
+】
+〓
+〔
+〕
+〖
+〗
+〜
+〝
+〞
+ぁ
+あ
+ぃ
+い
+う
+ぇ
+え
+お
+か
+き
+く
+け
+こ
+さ
+し
+す
+せ
+そ
+た
+ち
+っ
+つ
+て
+と
+な
+に
+ぬ
+ね
+の
+は
+ひ
+ふ
+へ
+ほ
+ま
+み
+む
+め
+も
+ゃ
+や
+ゅ
+ゆ
+ょ
+よ
+ら
+り
+る
+れ
+ろ
+わ
+を
+ん
+゜
+ゝ
+ァ
+ア
+ィ
+イ
+ゥ
+ウ
+ェ
+エ
+ォ
+オ
+カ
+キ
+ク
+ケ
+コ
+サ
+シ
+ス
+セ
+ソ
+タ
+チ
+ッ
+ツ
+テ
+ト
+ナ
+ニ
+ヌ
+ネ
+ノ
+ハ
+ヒ
+フ
+ヘ
+ホ
+マ
+ミ
+ム
+メ
+モ
+ャ
+ヤ
+ュ
+ユ
+ョ
+ヨ
+ラ
+リ
+ル
+レ
+ロ
+ワ
+ヲ
+ン
+ヶ
+・
+ー
+ヽ
+ㄅ
+ㄆ
+ㄇ
+ㄉ
+ㄋ
+ㄌ
+ㄍ
+ㄎ
+ㄏ
+ㄒ
+ㄚ
+ㄛ
+ㄞ
+ㄟ
+ㄢ
+ㄤ
+ㄥ
+ㄧ
+ㄨ
+ㆍ
+㈦
+㊣
+㎡
+㗎
+一
+丁
+七
+万
+丈
+三
+上
+下
+不
+与
+丐
+丑
+专
+且
+丕
+世
+丘
+丙
+业
+丛
+东
+丝
+丞
+丟
+両
+丢
+两
+严
+並
+丧
+丨
+个
+丫
+中
+丰
+串
+临
+丶
+丸
+丹
+为
+主
+丼
+丽
+举
+丿
+乂
+乃
+久
+么
+义
+之
+乌
+乍
+乎
+乏
+乐
+乒
+乓
+乔
+乖
+乗
+乘
+乙
+乜
+九
+乞
+也
+习
+乡
+书
+乩
+买
+乱
+乳
+乾
+亀
+亂
+了
+予
+争
+事
+二
+于
+亏
+云
+互
+五
+井
+亘
+亙
+亚
+些
+亜
+亞
+亟
+亡
+亢
+交
+亥
+亦
+产
+亨
+亩
+享
+京
+亭
+亮
+亲
+亳
+亵
+人
+亿
+什
+仁
+仃
+仄
+仅
+仆
+仇
+今
+介
+仍
+从
+仏
+仑
+仓
+仔
+仕
+他
+仗
+付
+仙
+仝
+仞
+仟
+代
+令
+以
+仨
+仪
+们
+仮
+仰
+仲
+件
+价
+任
+份
+仿
+企
+伉
+伊
+伍
+伎
+伏
+伐
+休
+伕
+众
+优
+伙
+会
+伝
+伞
+伟
+传
+伢
+伤
+伦
+伪
+伫
+伯
+估
+伴
+伶
+伸
+伺
+似
+伽
+佃
+但
+佇
+佈
+位
+低
+住
+佐
+佑
+体
+佔
+何
+佗
+佘
+余
+佚
+佛
+作
+佝
+佞
+佟
+你
+佢
+佣
+佤
+佥
+佩
+佬
+佯
+佰
+佳
+併
+佶
+佻
+佼
+使
+侃
+侄
+來
+侈
+例
+侍
+侏
+侑
+侖
+侗
+供
+依
+侠
+価
+侣
+侥
+侦
+侧
+侨
+侬
+侮
+侯
+侵
+侶
+侷
+便
+係
+促
+俄
+俊
+俎
+俏
+俐
+俑
+俗
+俘
+俚
+保
+俞
+俟
+俠
+信
+俨
+俩
+俪
+俬
+俭
+修
+俯
+俱
+俳
+俸
+俺
+俾
+倆
+倉
+個
+倌
+倍
+倏
+們
+倒
+倔
+倖
+倘
+候
+倚
+倜
+借
+倡
+値
+倦
+倩
+倪
+倫
+倬
+倭
+倶
+债
+值
+倾
+偃
+假
+偈
+偉
+偌
+偎
+偏
+偕
+做
+停
+健
+側
+偵
+偶
+偷
+偻
+偽
+偿
+傀
+傅
+傍
+傑
+傘
+備
+傚
+傢
+傣
+傥
+储
+傩
+催
+傭
+傲
+傳
+債
+傷
+傻
+傾
+僅
+働
+像
+僑
+僕
+僖
+僚
+僥
+僧
+僭
+僮
+僱
+僵
+價
+僻
+儀
+儂
+億
+儆
+儉
+儋
+儒
+儕
+儘
+償
+儡
+優
+儲
+儷
+儼
+儿
+兀
+允
+元
+兄
+充
+兆
+兇
+先
+光
+克
+兌
+免
+児
+兑
+兒
+兔
+兖
+党
+兜
+兢
+入
+內
+全
+兩
+八
+公
+六
+兮
+兰
+共
+兲
+关
+兴
+兵
+其
+具
+典
+兹
+养
+兼
+兽
+冀
+内
+円
+冇
+冈
+冉
+冊
+册
+再
+冏
+冒
+冕
+冗
+写
+军
+农
+冠
+冢
+冤
+冥
+冨
+冪
+冬
+冯
+冰
+冲
+决
+况
+冶
+冷
+冻
+冼
+冽
+冾
+净
+凄
+准
+凇
+凈
+凉
+凋
+凌
+凍
+减
+凑
+凛
+凜
+凝
+几
+凡
+凤
+処
+凪
+凭
+凯
+凰
+凱
+凳
+凶
+凸
+凹
+出
+击
+函
+凿
+刀
+刁
+刃
+分
+切
+刈
+刊
+刍
+刎
+刑
+划
+列
+刘
+则
+刚
+创
+初
+删
+判
+別
+刨
+利
+刪
+别
+刮
+到
+制
+刷
+券
+刹
+刺
+刻
+刽
+剁
+剂
+剃
+則
+剉
+削
+剋
+剌
+前
+剎
+剐
+剑
+剔
+剖
+剛
+剜
+剝
+剣
+剤
+剥
+剧
+剩
+剪
+副
+割
+創
+剷
+剽
+剿
+劃
+劇
+劈
+劉
+劊
+劍
+劏
+劑
+力
+劝
+办
+功
+加
+务
+劣
+动
+助
+努
+劫
+劭
+励
+劲
+劳
+労
+劵
+効
+劾
+势
+勁
+勃
+勇
+勉
+勋
+勐
+勒
+動
+勖
+勘
+務
+勛
+勝
+勞
+募
+勢
+勤
+勧
+勳
+勵
+勸
+勺
+勻
+勾
+勿
+匀
+包
+匆
+匈
+匍
+匐
+匕
+化
+北
+匙
+匝
+匠
+匡
+匣
+匪
+匮
+匯
+匱
+匹
+区
+医
+匾
+匿
+區
+十
+千
+卅
+升
+午
+卉
+半
+卍
+华
+协
+卑
+卒
+卓
+協
+单
+卖
+南
+単
+博
+卜
+卞
+卟
+占
+卡
+卢
+卤
+卦
+卧
+卫
+卮
+卯
+印
+危
+即
+却
+卵
+卷
+卸
+卻
+卿
+厂
+厄
+厅
+历
+厉
+压
+厌
+厕
+厘
+厚
+厝
+原
+厢
+厥
+厦
+厨
+厩
+厭
+厮
+厲
+厳
+去
+县
+叁
+参
+參
+又
+叉
+及
+友
+双
+反
+収
+发
+叔
+取
+受
+变
+叙
+叛
+叟
+叠
+叡
+叢
+口
+古
+句
+另
+叨
+叩
+只
+叫
+召
+叭
+叮
+可
+台
+叱
+史
+右
+叵
+叶
+号
+司
+叹
+叻
+叼
+叽
+吁
+吃
+各
+吆
+合
+吉
+吊
+吋
+同
+名
+后
+吏
+吐
+向
+吒
+吓
+吕
+吖
+吗
+君
+吝
+吞
+吟
+吠
+吡
+否
+吧
+吨
+吩
+含
+听
+吭
+吮
+启
+吱
+吳
+吴
+吵
+吶
+吸
+吹
+吻
+吼
+吽
+吾
+呀
+呂
+呃
+呆
+呈
+告
+呋
+呎
+呐
+呓
+呕
+呗
+员
+呛
+呜
+呢
+呤
+呦
+周
+呱
+呲
+味
+呵
+呷
+呸
+呻
+呼
+命
+咀
+咁
+咂
+咄
+咆
+咋
+和
+咎
+咏
+咐
+咒
+咔
+咕
+咖
+咗
+咘
+咙
+咚
+咛
+咣
+咤
+咦
+咧
+咨
+咩
+咪
+咫
+咬
+咭
+咯
+咱
+咲
+咳
+咸
+咻
+咽
+咿
+哀
+品
+哂
+哄
+哆
+哇
+哈
+哉
+哋
+哌
+响
+哎
+哏
+哐
+哑
+哒
+哔
+哗
+哟
+員
+哥
+哦
+哧
+哨
+哩
+哪
+哭
+哮
+哲
+哺
+哼
+哽
+唁
+唄
+唆
+唇
+唉
+唏
+唐
+唑
+唔
+唠
+唤
+唧
+唬
+售
+唯
+唰
+唱
+唳
+唷
+唸
+唾
+啃
+啄
+商
+啉
+啊
+問
+啓
+啕
+啖
+啜
+啞
+啟
+啡
+啤
+啥
+啦
+啧
+啪
+啫
+啬
+啮
+啰
+啱
+啲
+啵
+啶
+啷
+啸
+啻
+啼
+啾
+喀
+喂
+喃
+善
+喆
+喇
+喉
+喊
+喋
+喎
+喏
+喔
+喘
+喙
+喚
+喜
+喝
+喟
+喧
+喪
+喫
+喬
+單
+喰
+喱
+喲
+喳
+喵
+営
+喷
+喹
+喺
+喻
+喽
+嗅
+嗆
+嗇
+嗎
+嗑
+嗒
+嗓
+嗔
+嗖
+嗚
+嗜
+嗝
+嗟
+嗡
+嗣
+嗤
+嗦
+嗨
+嗪
+嗬
+嗯
+嗰
+嗲
+嗳
+嗶
+嗷
+嗽
+嘀
+嘅
+嘆
+嘈
+嘉
+嘌
+嘍
+嘎
+嘔
+嘖
+嘗
+嘘
+嘚
+嘛
+嘜
+嘞
+嘟
+嘢
+嘣
+嘤
+嘧
+嘩
+嘭
+嘮
+嘯
+嘰
+嘱
+嘲
+嘴
+嘶
+嘸
+嘹
+嘻
+嘿
+噁
+噌
+噎
+噓
+噔
+噗
+噙
+噜
+噠
+噢
+噤
+器
+噩
+噪
+噬
+噱
+噴
+噶
+噸
+噹
+噻
+噼
+嚀
+嚇
+嚎
+嚏
+嚐
+嚓
+嚕
+嚟
+嚣
+嚥
+嚨
+嚮
+嚴
+嚷
+嚼
+囂
+囉
+囊
+囍
+囑
+囔
+囗
+囚
+四
+囝
+回
+囟
+因
+囡
+团
+団
+囤
+囧
+囪
+囫
+园
+困
+囱
+囲
+図
+围
+囹
+固
+国
+图
+囿
+圃
+圄
+圆
+圈
+國
+圍
+圏
+園
+圓
+圖
+團
+圜
+土
+圣
+圧
+在
+圩
+圭
+地
+圳
+场
+圻
+圾
+址
+坂
+均
+坊
+坍
+坎
+坏
+坐
+坑
+块
+坚
+坛
+坝
+坞
+坟
+坠
+坡
+坤
+坦
+坨
+坪
+坯
+坳
+坵
+坷
+垂
+垃
+垄
+型
+垒
+垚
+垛
+垠
+垢
+垣
+垦
+垩
+垫
+垭
+垮
+垵
+埂
+埃
+埋
+城
+埔
+埕
+埗
+域
+埠
+埤
+埵
+執
+埸
+培
+基
+埼
+堀
+堂
+堃
+堅
+堆
+堇
+堑
+堕
+堙
+堡
+堤
+堪
+堯
+堰
+報
+場
+堵
+堺
+堿
+塊
+塌
+塑
+塔
+塗
+塘
+塚
+塞
+塢
+塩
+填
+塬
+塭
+塵
+塾
+墀
+境
+墅
+墉
+墊
+墒
+墓
+増
+墘
+墙
+墜
+增
+墟
+墨
+墩
+墮
+墳
+墻
+墾
+壁
+壅
+壆
+壇
+壊
+壑
+壓
+壕
+壘
+壞
+壟
+壢
+壤
+壩
+士
+壬
+壮
+壯
+声
+売
+壳
+壶
+壹
+壺
+壽
+处
+备
+変
+复
+夏
+夔
+夕
+外
+夙
+多
+夜
+够
+夠
+夢
+夥
+大
+天
+太
+夫
+夭
+央
+夯
+失
+头
+夷
+夸
+夹
+夺
+夾
+奂
+奄
+奇
+奈
+奉
+奋
+奎
+奏
+奐
+契
+奔
+奕
+奖
+套
+奘
+奚
+奠
+奢
+奥
+奧
+奪
+奬
+奮
+女
+奴
+奶
+奸
+她
+好
+如
+妃
+妄
+妆
+妇
+妈
+妊
+妍
+妒
+妓
+妖
+妘
+妙
+妝
+妞
+妣
+妤
+妥
+妨
+妩
+妪
+妮
+妲
+妳
+妹
+妻
+妾
+姆
+姉
+姊
+始
+姍
+姐
+姑
+姒
+姓
+委
+姗
+姚
+姜
+姝
+姣
+姥
+姦
+姨
+姪
+姫
+姬
+姹
+姻
+姿
+威
+娃
+娄
+娅
+娆
+娇
+娉
+娑
+娓
+娘
+娛
+娜
+娟
+娠
+娣
+娥
+娩
+娱
+娲
+娴
+娶
+娼
+婀
+婁
+婆
+婉
+婊
+婕
+婚
+婢
+婦
+婧
+婪
+婭
+婴
+婵
+婶
+婷
+婺
+婿
+媒
+媚
+媛
+媞
+媧
+媲
+媳
+媽
+媾
+嫁
+嫂
+嫉
+嫌
+嫑
+嫔
+嫖
+嫘
+嫚
+嫡
+嫣
+嫦
+嫩
+嫲
+嫵
+嫻
+嬅
+嬉
+嬌
+嬗
+嬛
+嬢
+嬤
+嬪
+嬰
+嬴
+嬷
+嬸
+嬿
+孀
+孃
+子
+孑
+孔
+孕
+孖
+字
+存
+孙
+孚
+孛
+孜
+孝
+孟
+孢
+季
+孤
+学
+孩
+孪
+孫
+孬
+孰
+孱
+孳
+孵
+學
+孺
+孽
+孿
+宁
+它
+宅
+宇
+守
+安
+宋
+完
+宏
+宓
+宕
+宗
+官
+宙
+定
+宛
+宜
+宝
+实
+実
+宠
+审
+客
+宣
+室
+宥
+宦
+宪
+宫
+宮
+宰
+害
+宴
+宵
+家
+宸
+容
+宽
+宾
+宿
+寂
+寄
+寅
+密
+寇
+富
+寐
+寒
+寓
+寛
+寝
+寞
+察
+寡
+寢
+寥
+實
+寧
+寨
+審
+寫
+寬
+寮
+寰
+寵
+寶
+寸
+对
+寺
+寻
+导
+対
+寿
+封
+専
+射
+将
+將
+專
+尉
+尊
+尋
+對
+導
+小
+少
+尔
+尕
+尖
+尘
+尚
+尝
+尤
+尧
+尬
+就
+尴
+尷
+尸
+尹
+尺
+尻
+尼
+尽
+尾
+尿
+局
+屁
+层
+屄
+居
+屆
+屈
+屉
+届
+屋
+屌
+屍
+屎
+屏
+屐
+屑
+展
+屜
+属
+屠
+屡
+屢
+層
+履
+屬
+屯
+山
+屹
+屿
+岀
+岁
+岂
+岌
+岐
+岑
+岔
+岖
+岗
+岘
+岙
+岚
+岛
+岡
+岩
+岫
+岬
+岭
+岱
+岳
+岷
+岸
+峇
+峋
+峒
+峙
+峡
+峤
+峥
+峦
+峨
+峪
+峭
+峯
+峰
+峴
+島
+峻
+峽
+崁
+崂
+崆
+崇
+崎
+崑
+崔
+崖
+崗
+崙
+崛
+崧
+崩
+崭
+崴
+崽
+嵇
+嵊
+嵋
+嵌
+嵐
+嵘
+嵩
+嵬
+嵯
+嶂
+嶄
+嶇
+嶋
+嶙
+嶺
+嶼
+嶽
+巅
+巍
+巒
+巔
+巖
+川
+州
+巡
+巢
+工
+左
+巧
+巨
+巩
+巫
+差
+己
+已
+巳
+巴
+巷
+巻
+巽
+巾
+巿
+币
+市
+布
+帅
+帆
+师
+希
+帐
+帑
+帕
+帖
+帘
+帚
+帛
+帜
+帝
+帥
+带
+帧
+師
+席
+帮
+帯
+帰
+帳
+帶
+帷
+常
+帼
+帽
+幀
+幂
+幄
+幅
+幌
+幔
+幕
+幟
+幡
+幢
+幣
+幫
+干
+平
+年
+并
+幸
+幹
+幺
+幻
+幼
+幽
+幾
+广
+庁
+広
+庄
+庆
+庇
+床
+序
+庐
+库
+应
+底
+庖
+店
+庙
+庚
+府
+庞
+废
+庠
+度
+座
+庫
+庭
+庵
+庶
+康
+庸
+庹
+庾
+廁
+廂
+廃
+廈
+廉
+廊
+廓
+廖
+廚
+廝
+廟
+廠
+廢
+廣
+廬
+廳
+延
+廷
+建
+廿
+开
+弁
+异
+弃
+弄
+弈
+弊
+弋
+式
+弑
+弒
+弓
+弔
+引
+弗
+弘
+弛
+弟
+张
+弥
+弦
+弧
+弩
+弭
+弯
+弱
+張
+強
+弹
+强
+弼
+弾
+彅
+彆
+彈
+彌
+彎
+归
+当
+录
+彗
+彙
+彝
+形
+彤
+彥
+彦
+彧
+彩
+彪
+彫
+彬
+彭
+彰
+影
+彷
+役
+彻
+彼
+彿
+往
+征
+径
+待
+徇
+很
+徉
+徊
+律
+後
+徐
+徑
+徒
+従
+徕
+得
+徘
+徙
+徜
+從
+徠
+御
+徨
+復
+循
+徬
+微
+徳
+徴
+徵
+德
+徹
+徼
+徽
+心
+必
+忆
+忌
+忍
+忏
+忐
+忑
+忒
+忖
+志
+忘
+忙
+応
+忠
+忡
+忤
+忧
+忪
+快
+忱
+念
+忻
+忽
+忿
+怀
+态
+怂
+怅
+怆
+怎
+怏
+怒
+怔
+怕
+怖
+怙
+怜
+思
+怠
+怡
+急
+怦
+性
+怨
+怪
+怯
+怵
+总
+怼
+恁
+恃
+恆
+恋
+恍
+恐
+恒
+恕
+恙
+恚
+恢
+恣
+恤
+恥
+恨
+恩
+恪
+恫
+恬
+恭
+息
+恰
+恳
+恵
+恶
+恸
+恺
+恻
+恼
+恿
+悄
+悅
+悉
+悌
+悍
+悔
+悖
+悚
+悟
+悠
+患
+悦
+您
+悩
+悪
+悬
+悯
+悱
+悲
+悴
+悵
+悶
+悸
+悻
+悼
+悽
+情
+惆
+惇
+惊
+惋
+惑
+惕
+惘
+惚
+惜
+惟
+惠
+惡
+惦
+惧
+惨
+惩
+惫
+惬
+惭
+惮
+惯
+惰
+惱
+想
+惴
+惶
+惹
+惺
+愁
+愆
+愈
+愉
+愍
+意
+愕
+愚
+愛
+愜
+感
+愣
+愤
+愧
+愫
+愷
+愿
+慄
+慈
+態
+慌
+慎
+慑
+慕
+慘
+慚
+慟
+慢
+慣
+慧
+慨
+慫
+慮
+慰
+慳
+慵
+慶
+慷
+慾
+憂
+憊
+憋
+憎
+憐
+憑
+憔
+憚
+憤
+憧
+憨
+憩
+憫
+憬
+憲
+憶
+憾
+懂
+懇
+懈
+應
+懊
+懋
+懑
+懒
+懦
+懲
+懵
+懶
+懷
+懸
+懺
+懼
+懾
+懿
+戀
+戈
+戊
+戌
+戍
+戎
+戏
+成
+我
+戒
+戕
+或
+战
+戚
+戛
+戟
+戡
+戦
+截
+戬
+戮
+戰
+戲
+戳
+戴
+戶
+户
+戸
+戻
+戾
+房
+所
+扁
+扇
+扈
+扉
+手
+才
+扎
+扑
+扒
+打
+扔
+払
+托
+扛
+扣
+扦
+执
+扩
+扪
+扫
+扬
+扭
+扮
+扯
+扰
+扱
+扳
+扶
+批
+扼
+找
+承
+技
+抄
+抉
+把
+抑
+抒
+抓
+投
+抖
+抗
+折
+抚
+抛
+抜
+択
+抟
+抠
+抡
+抢
+护
+报
+抨
+披
+抬
+抱
+抵
+抹
+押
+抽
+抿
+拂
+拄
+担
+拆
+拇
+拈
+拉
+拋
+拌
+拍
+拎
+拐
+拒
+拓
+拔
+拖
+拗
+拘
+拙
+拚
+招
+拜
+拟
+拡
+拢
+拣
+拥
+拦
+拧
+拨
+择
+括
+拭
+拮
+拯
+拱
+拳
+拴
+拷
+拼
+拽
+拾
+拿
+持
+挂
+指
+挈
+按
+挎
+挑
+挖
+挙
+挚
+挛
+挝
+挞
+挟
+挠
+挡
+挣
+挤
+挥
+挨
+挪
+挫
+振
+挲
+挹
+挺
+挽
+挾
+捂
+捅
+捆
+捉
+捋
+捌
+捍
+捎
+捏
+捐
+捕
+捞
+损
+捡
+换
+捣
+捧
+捨
+捩
+据
+捱
+捲
+捶
+捷
+捺
+捻
+掀
+掂
+掃
+掇
+授
+掉
+掌
+掏
+掐
+排
+掖
+掘
+掙
+掛
+掠
+採
+探
+掣
+接
+控
+推
+掩
+措
+掬
+掰
+掲
+掳
+掴
+掷
+掸
+掺
+揀
+揃
+揄
+揆
+揉
+揍
+描
+提
+插
+揖
+揚
+換
+握
+揣
+揩
+揪
+揭
+揮
+援
+揶
+揸
+揹
+揽
+搀
+搁
+搂
+搅
+損
+搏
+搐
+搓
+搔
+搖
+搗
+搜
+搞
+搡
+搪
+搬
+搭
+搵
+搶
+携
+搽
+摀
+摁
+摄
+摆
+摇
+摈
+摊
+摒
+摔
+摘
+摞
+摟
+摧
+摩
+摯
+摳
+摸
+摹
+摺
+摻
+撂
+撃
+撅
+撇
+撈
+撐
+撑
+撒
+撓
+撕
+撚
+撞
+撤
+撥
+撩
+撫
+撬
+播
+撮
+撰
+撲
+撵
+撷
+撸
+撻
+撼
+撿
+擀
+擁
+擂
+擄
+擅
+擇
+擊
+擋
+操
+擎
+擒
+擔
+擘
+據
+擞
+擠
+擡
+擢
+擦
+擬
+擰
+擱
+擲
+擴
+擷
+擺
+擼
+擾
+攀
+攏
+攒
+攔
+攘
+攙
+攜
+攝
+攞
+攢
+攣
+攤
+攥
+攪
+攫
+攬
+支
+收
+攸
+改
+攻
+放
+政
+故
+效
+敌
+敍
+敎
+敏
+救
+敕
+敖
+敗
+敘
+教
+敛
+敝
+敞
+敢
+散
+敦
+敬
+数
+敲
+整
+敵
+敷
+數
+斂
+斃
+文
+斋
+斌
+斎
+斐
+斑
+斓
+斗
+料
+斛
+斜
+斟
+斡
+斤
+斥
+斧
+斩
+斫
+斬
+断
+斯
+新
+斷
+方
+於
+施
+旁
+旃
+旅
+旋
+旌
+旎
+族
+旖
+旗
+无
+既
+日
+旦
+旧
+旨
+早
+旬
+旭
+旮
+旱
+时
+旷
+旺
+旻
+昀
+昂
+昆
+昇
+昉
+昊
+昌
+明
+昏
+易
+昔
+昕
+昙
+星
+映
+春
+昧
+昨
+昭
+是
+昱
+昴
+昵
+昶
+昼
+显
+晁
+時
+晃
+晉
+晋
+晌
+晏
+晒
+晓
+晔
+晕
+晖
+晗
+晚
+晝
+晞
+晟
+晤
+晦
+晨
+晩
+普
+景
+晰
+晴
+晶
+晷
+智
+晾
+暂
+暄
+暇
+暈
+暉
+暌
+暐
+暑
+暖
+暗
+暝
+暢
+暧
+暨
+暫
+暮
+暱
+暴
+暸
+暹
+曄
+曆
+曇
+曉
+曖
+曙
+曜
+曝
+曠
+曦
+曬
+曰
+曲
+曳
+更
+書
+曹
+曼
+曾
+替
+最
+會
+月
+有
+朋
+服
+朐
+朔
+朕
+朗
+望
+朝
+期
+朦
+朧
+木
+未
+末
+本
+札
+朮
+术
+朱
+朴
+朵
+机
+朽
+杀
+杂
+权
+杆
+杈
+杉
+李
+杏
+材
+村
+杓
+杖
+杜
+杞
+束
+杠
+条
+来
+杨
+杭
+杯
+杰
+東
+杳
+杵
+杷
+杼
+松
+板
+极
+构
+枇
+枉
+枋
+析
+枕
+林
+枚
+果
+枝
+枢
+枣
+枪
+枫
+枭
+枯
+枰
+枱
+枳
+架
+枷
+枸
+柄
+柏
+某
+柑
+柒
+染
+柔
+柘
+柚
+柜
+柞
+柠
+柢
+查
+柩
+柬
+柯
+柱
+柳
+柴
+柵
+査
+柿
+栀
+栃
+栄
+栅
+标
+栈
+栉
+栋
+栎
+栏
+树
+栓
+栖
+栗
+校
+栩
+株
+样
+核
+根
+格
+栽
+栾
+桀
+桁
+桂
+桃
+桅
+框
+案
+桉
+桌
+桎
+桐
+桑
+桓
+桔
+桜
+桠
+桡
+桢
+档
+桥
+桦
+桧
+桨
+桩
+桶
+桿
+梁
+梅
+梆
+梏
+梓
+梗
+條
+梟
+梢
+梦
+梧
+梨
+梭
+梯
+械
+梳
+梵
+梶
+检
+棂
+棄
+棉
+棋
+棍
+棒
+棕
+棗
+棘
+棚
+棟
+棠
+棣
+棧
+森
+棱
+棲
+棵
+棹
+棺
+椁
+椅
+椋
+植
+椎
+椒
+検
+椪
+椭
+椰
+椹
+椽
+椿
+楂
+楊
+楓
+楔
+楚
+楝
+楞
+楠
+楣
+楨
+楫
+業
+楮
+極
+楷
+楸
+楹
+楼
+楽
+概
+榄
+榆
+榈
+榉
+榔
+榕
+榖
+榛
+榜
+榨
+榫
+榭
+榮
+榱
+榴
+榷
+榻
+槁
+槃
+構
+槌
+槍
+槎
+槐
+槓
+様
+槛
+槟
+槤
+槭
+槲
+槳
+槻
+槽
+槿
+樁
+樂
+樊
+樑
+樓
+標
+樞
+樟
+模
+樣
+権
+横
+樫
+樯
+樱
+樵
+樸
+樹
+樺
+樽
+樾
+橄
+橇
+橋
+橐
+橘
+橙
+機
+橡
+橢
+橫
+橱
+橹
+橼
+檀
+檄
+檎
+檐
+檔
+檗
+檜
+檢
+檬
+檯
+檳
+檸
+檻
+櫃
+櫚
+櫛
+櫥
+櫸
+櫻
+欄
+權
+欒
+欖
+欠
+次
+欢
+欣
+欧
+欲
+欸
+欺
+欽
+款
+歆
+歇
+歉
+歌
+歎
+歐
+歓
+歙
+歛
+歡
+止
+正
+此
+步
+武
+歧
+歩
+歪
+歯
+歲
+歳
+歴
+歷
+歸
+歹
+死
+歼
+殁
+殃
+殆
+殇
+殉
+殊
+残
+殒
+殓
+殖
+殘
+殞
+殡
+殤
+殭
+殯
+殲
+殴
+段
+殷
+殺
+殼
+殿
+毀
+毁
+毂
+毅
+毆
+毋
+母
+毎
+每
+毒
+毓
+比
+毕
+毗
+毘
+毙
+毛
+毡
+毫
+毯
+毽
+氈
+氏
+氐
+民
+氓
+气
+氖
+気
+氙
+氛
+氟
+氡
+氢
+氣
+氤
+氦
+氧
+氨
+氪
+氫
+氮
+氯
+氰
+氲
+水
+氷
+永
+氹
+氾
+汀
+汁
+求
+汆
+汇
+汉
+汎
+汐
+汕
+汗
+汙
+汛
+汝
+汞
+江
+池
+污
+汤
+汨
+汩
+汪
+汰
+汲
+汴
+汶
+汹
+決
+汽
+汾
+沁
+沂
+沃
+沅
+沈
+沉
+沌
+沏
+沐
+沒
+沓
+沖
+沙
+沛
+沟
+没
+沢
+沣
+沥
+沦
+沧
+沪
+沫
+沭
+沮
+沱
+河
+沸
+油
+治
+沼
+沽
+沾
+沿
+況
+泄
+泉
+泊
+泌
+泓
+法
+泗
+泛
+泞
+泠
+泡
+波
+泣
+泥
+注
+泪
+泫
+泮
+泯
+泰
+泱
+泳
+泵
+泷
+泸
+泻
+泼
+泽
+泾
+洁
+洄
+洋
+洒
+洗
+洙
+洛
+洞
+津
+洩
+洪
+洮
+洱
+洲
+洵
+洶
+洸
+洹
+活
+洼
+洽
+派
+流
+浃
+浄
+浅
+浆
+浇
+浊
+测
+济
+浏
+浑
+浒
+浓
+浔
+浙
+浚
+浜
+浣
+浦
+浩
+浪
+浬
+浮
+浯
+浴
+海
+浸
+涂
+涅
+涇
+消
+涉
+涌
+涎
+涓
+涔
+涕
+涙
+涛
+涝
+涞
+涟
+涠
+涡
+涣
+涤
+润
+涧
+涨
+涩
+涪
+涮
+涯
+液
+涵
+涸
+涼
+涿
+淀
+淄
+淅
+淆
+淇
+淋
+淌
+淑
+淒
+淖
+淘
+淙
+淚
+淞
+淡
+淤
+淦
+淨
+淩
+淪
+淫
+淬
+淮
+深
+淳
+淵
+混
+淹
+淺
+添
+淼
+清
+済
+渉
+渊
+渋
+渍
+渎
+渐
+渔
+渗
+渙
+渚
+減
+渝
+渠
+渡
+渣
+渤
+渥
+渦
+温
+測
+渭
+港
+渲
+渴
+游
+渺
+渾
+湃
+湄
+湊
+湍
+湖
+湘
+湛
+湟
+湧
+湫
+湮
+湯
+湳
+湾
+湿
+満
+溃
+溅
+溉
+溏
+源
+準
+溜
+溝
+溟
+溢
+溥
+溧
+溪
+溫
+溯
+溱
+溴
+溶
+溺
+溼
+滁
+滂
+滄
+滅
+滇
+滋
+滌
+滑
+滓
+滔
+滕
+滙
+滚
+滝
+滞
+滟
+满
+滢
+滤
+滥
+滦
+滨
+滩
+滬
+滯
+滲
+滴
+滷
+滸
+滾
+滿
+漁
+漂
+漆
+漉
+漏
+漓
+演
+漕
+漠
+漢
+漣
+漩
+漪
+漫
+漬
+漯
+漱
+漲
+漳
+漸
+漾
+漿
+潆
+潇
+潋
+潍
+潑
+潔
+潘
+潛
+潜
+潞
+潟
+潢
+潤
+潦
+潧
+潭
+潮
+潰
+潴
+潸
+潺
+潼
+澀
+澄
+澆
+澈
+澍
+澎
+澗
+澜
+澡
+澤
+澧
+澱
+澳
+澹
+激
+濁
+濂
+濃
+濑
+濒
+濕
+濘
+濛
+濟
+濠
+濡
+濤
+濫
+濬
+濮
+濯
+濱
+濺
+濾
+瀅
+瀆
+瀉
+瀋
+瀏
+瀑
+瀕
+瀘
+瀚
+瀛
+瀝
+瀞
+瀟
+瀧
+瀨
+瀬
+瀰
+瀾
+灌
+灏
+灑
+灘
+灝
+灞
+灣
+火
+灬
+灭
+灯
+灰
+灵
+灶
+灸
+灼
+災
+灾
+灿
+炀
+炁
+炅
+炉
+炊
+炎
+炒
+炔
+炕
+炖
+炙
+炜
+炫
+炬
+炭
+炮
+炯
+炳
+炷
+炸
+点
+為
+炼
+炽
+烁
+烂
+烃
+烈
+烊
+烏
+烘
+烙
+烛
+烟
+烤
+烦
+烧
+烨
+烩
+烫
+烬
+热
+烯
+烷
+烹
+烽
+焉
+焊
+焕
+焖
+焗
+焘
+焙
+焚
+焜
+無
+焦
+焯
+焰
+焱
+然
+焼
+煅
+煉
+煊
+煌
+煎
+煒
+煖
+煙
+煜
+煞
+煤
+煥
+煦
+照
+煨
+煩
+煮
+煲
+煸
+煽
+熄
+熊
+熏
+熒
+熔
+熙
+熟
+熠
+熨
+熬
+熱
+熵
+熹
+熾
+燁
+燃
+燄
+燈
+燉
+燊
+燎
+燒
+燔
+燕
+燙
+燜
+營
+燥
+燦
+燧
+燭
+燮
+燴
+燻
+燼
+燿
+爆
+爍
+爐
+爛
+爪
+爬
+爭
+爰
+爱
+爲
+爵
+父
+爷
+爸
+爹
+爺
+爻
+爽
+爾
+牆
+片
+版
+牌
+牍
+牒
+牙
+牛
+牝
+牟
+牠
+牡
+牢
+牦
+牧
+物
+牯
+牲
+牴
+牵
+特
+牺
+牽
+犀
+犁
+犄
+犊
+犍
+犒
+犢
+犧
+犬
+犯
+状
+犷
+犸
+犹
+狀
+狂
+狄
+狈
+狎
+狐
+狒
+狗
+狙
+狞
+狠
+狡
+狩
+独
+狭
+狮
+狰
+狱
+狸
+狹
+狼
+狽
+猎
+猕
+猖
+猗
+猙
+猛
+猜
+猝
+猥
+猩
+猪
+猫
+猬
+献
+猴
+猶
+猷
+猾
+猿
+獄
+獅
+獎
+獐
+獒
+獗
+獠
+獣
+獨
+獭
+獰
+獲
+獵
+獷
+獸
+獺
+獻
+獼
+獾
+玄
+率
+玉
+王
+玑
+玖
+玛
+玟
+玠
+玥
+玩
+玫
+玮
+环
+现
+玲
+玳
+玷
+玺
+玻
+珀
+珂
+珅
+珈
+珉
+珊
+珍
+珏
+珐
+珑
+珙
+珞
+珠
+珣
+珥
+珩
+珪
+班
+珮
+珲
+珺
+現
+球
+琅
+理
+琇
+琉
+琊
+琍
+琏
+琐
+琛
+琢
+琥
+琦
+琨
+琪
+琬
+琮
+琰
+琲
+琳
+琴
+琵
+琶
+琺
+琼
+瑀
+瑁
+瑄
+瑋
+瑕
+瑗
+瑙
+瑚
+瑛
+瑜
+瑞
+瑟
+瑠
+瑣
+瑤
+瑩
+瑪
+瑯
+瑰
+瑶
+瑾
+璀
+璁
+璃
+璇
+璉
+璋
+璎
+璐
+璜
+璞
+璟
+璧
+璨
+環
+璽
+璿
+瓊
+瓏
+瓒
+瓜
+瓢
+瓣
+瓤
+瓦
+瓮
+瓯
+瓴
+瓶
+瓷
+甄
+甌
+甕
+甘
+甙
+甚
+甜
+生
+產
+産
+甥
+甦
+用
+甩
+甫
+甬
+甭
+甯
+田
+由
+甲
+申
+电
+男
+甸
+町
+画
+甾
+畀
+畅
+界
+畏
+畑
+畔
+留
+畜
+畝
+畢
+略
+畦
+番
+畫
+異
+畲
+畳
+畴
+當
+畸
+畹
+畿
+疆
+疇
+疊
+疏
+疑
+疔
+疖
+疗
+疙
+疚
+疝
+疟
+疡
+疣
+疤
+疥
+疫
+疮
+疯
+疱
+疲
+疳
+疵
+疸
+疹
+疼
+疽
+疾
+痂
+病
+症
+痈
+痉
+痊
+痍
+痒
+痔
+痕
+痘
+痙
+痛
+痞
+痠
+痢
+痣
+痤
+痧
+痨
+痪
+痫
+痰
+痱
+痴
+痹
+痺
+痼
+痿
+瘀
+瘁
+瘋
+瘍
+瘓
+瘘
+瘙
+瘟
+瘠
+瘡
+瘢
+瘤
+瘦
+瘧
+瘩
+瘪
+瘫
+瘴
+瘸
+瘾
+療
+癇
+癌
+癒
+癖
+癜
+癞
+癡
+癢
+癣
+癥
+癫
+癬
+癮
+癱
+癲
+癸
+発
+登
+發
+白
+百
+皂
+的
+皆
+皇
+皈
+皋
+皎
+皑
+皓
+皖
+皙
+皚
+皮
+皰
+皱
+皴
+皺
+皿
+盂
+盃
+盅
+盆
+盈
+益
+盎
+盏
+盐
+监
+盒
+盔
+盖
+盗
+盘
+盛
+盜
+盞
+盟
+盡
+監
+盤
+盥
+盧
+盪
+目
+盯
+盱
+盲
+直
+相
+盹
+盼
+盾
+省
+眈
+眉
+看
+県
+眙
+眞
+真
+眠
+眦
+眨
+眩
+眯
+眶
+眷
+眸
+眺
+眼
+眾
+着
+睁
+睇
+睏
+睐
+睑
+睛
+睜
+睞
+睡
+睢
+督
+睥
+睦
+睨
+睪
+睫
+睬
+睹
+睽
+睾
+睿
+瞄
+瞅
+瞇
+瞋
+瞌
+瞎
+瞑
+瞒
+瞓
+瞞
+瞟
+瞠
+瞥
+瞧
+瞩
+瞪
+瞬
+瞭
+瞰
+瞳
+瞻
+瞼
+瞿
+矇
+矍
+矗
+矚
+矛
+矜
+矢
+矣
+知
+矩
+矫
+短
+矮
+矯
+石
+矶
+矽
+矾
+矿
+码
+砂
+砌
+砍
+砒
+研
+砖
+砗
+砚
+砝
+砣
+砥
+砧
+砭
+砰
+砲
+破
+砷
+砸
+砺
+砼
+砾
+础
+硅
+硐
+硒
+硕
+硝
+硫
+硬
+确
+硯
+硼
+碁
+碇
+碉
+碌
+碍
+碎
+碑
+碓
+碗
+碘
+碚
+碛
+碟
+碣
+碧
+碩
+碰
+碱
+碳
+碴
+確
+碼
+碾
+磁
+磅
+磊
+磋
+磐
+磕
+磚
+磡
+磨
+磬
+磯
+磲
+磷
+磺
+礁
+礎
+礙
+礡
+礦
+礪
+礫
+礴
+示
+礼
+社
+祀
+祁
+祂
+祇
+祈
+祉
+祎
+祐
+祕
+祖
+祗
+祚
+祛
+祜
+祝
+神
+祟
+祠
+祢
+祥
+票
+祭
+祯
+祷
+祸
+祺
+祿
+禀
+禁
+禄
+禅
+禍
+禎
+福
+禛
+禦
+禧
+禪
+禮
+禱
+禹
+禺
+离
+禽
+禾
+禿
+秀
+私
+秃
+秆
+秉
+秋
+种
+科
+秒
+秘
+租
+秣
+秤
+秦
+秧
+秩
+秭
+积
+称
+秸
+移
+秽
+稀
+稅
+程
+稍
+税
+稔
+稗
+稚
+稜
+稞
+稟
+稠
+稣
+種
+稱
+稲
+稳
+稷
+稹
+稻
+稼
+稽
+稿
+穀
+穂
+穆
+穌
+積
+穎
+穗
+穢
+穩
+穫
+穴
+究
+穷
+穹
+空
+穿
+突
+窃
+窄
+窈
+窍
+窑
+窒
+窓
+窕
+窖
+窗
+窘
+窜
+窝
+窟
+窠
+窥
+窦
+窨
+窩
+窪
+窮
+窯
+窺
+窿
+竄
+竅
+竇
+竊
+立
+竖
+站
+竜
+竞
+竟
+章
+竣
+童
+竭
+端
+競
+竹
+竺
+竽
+竿
+笃
+笆
+笈
+笋
+笏
+笑
+笔
+笙
+笛
+笞
+笠
+符
+笨
+第
+笹
+笺
+笼
+筆
+等
+筊
+筋
+筍
+筏
+筐
+筑
+筒
+答
+策
+筛
+筝
+筠
+筱
+筲
+筵
+筷
+筹
+签
+简
+箇
+箋
+箍
+箏
+箐
+箔
+箕
+算
+箝
+管
+箩
+箫
+箭
+箱
+箴
+箸
+節
+篁
+範
+篆
+篇
+築
+篑
+篓
+篙
+篝
+篠
+篡
+篤
+篩
+篪
+篮
+篱
+篷
+簇
+簌
+簍
+簡
+簦
+簧
+簪
+簫
+簷
+簸
+簽
+簾
+簿
+籁
+籃
+籌
+籍
+籐
+籟
+籠
+籤
+籬
+籮
+籲
+米
+类
+籼
+籽
+粄
+粉
+粑
+粒
+粕
+粗
+粘
+粟
+粤
+粥
+粧
+粪
+粮
+粱
+粲
+粳
+粵
+粹
+粼
+粽
+精
+粿
+糅
+糊
+糍
+糕
+糖
+糗
+糙
+糜
+糞
+糟
+糠
+糧
+糬
+糯
+糰
+糸
+系
+糾
+紀
+紂
+約
+紅
+紉
+紊
+紋
+納
+紐
+紓
+純
+紗
+紘
+紙
+級
+紛
+紜
+素
+紡
+索
+紧
+紫
+紮
+累
+細
+紳
+紹
+紺
+終
+絃
+組
+絆
+経
+結
+絕
+絞
+絡
+絢
+給
+絨
+絮
+統
+絲
+絳
+絵
+絶
+絹
+綁
+綏
+綑
+經
+継
+続
+綜
+綠
+綢
+綦
+綫
+綬
+維
+綱
+網
+綴
+綵
+綸
+綺
+綻
+綽
+綾
+綿
+緊
+緋
+総
+緑
+緒
+緘
+線
+緝
+緞
+締
+緣
+編
+緩
+緬
+緯
+練
+緹
+緻
+縁
+縄
+縈
+縛
+縝
+縣
+縫
+縮
+縱
+縴
+縷
+總
+績
+繁
+繃
+繆
+繇
+繋
+織
+繕
+繚
+繞
+繡
+繩
+繪
+繫
+繭
+繳
+繹
+繼
+繽
+纂
+續
+纍
+纏
+纓
+纔
+纖
+纜
+纠
+红
+纣
+纤
+约
+级
+纨
+纪
+纫
+纬
+纭
+纯
+纰
+纱
+纲
+纳
+纵
+纶
+纷
+纸
+纹
+纺
+纽
+纾
+线
+绀
+练
+组
+绅
+细
+织
+终
+绊
+绍
+绎
+经
+绑
+绒
+结
+绔
+绕
+绘
+给
+绚
+绛
+络
+绝
+绞
+统
+绡
+绢
+绣
+绥
+绦
+继
+绩
+绪
+绫
+续
+绮
+绯
+绰
+绳
+维
+绵
+绶
+绷
+绸
+绻
+综
+绽
+绾
+绿
+缀
+缄
+缅
+缆
+缇
+缈
+缉
+缎
+缓
+缔
+缕
+编
+缘
+缙
+缚
+缜
+缝
+缠
+缢
+缤
+缥
+缨
+缩
+缪
+缭
+缮
+缰
+缱
+缴
+缸
+缺
+缽
+罂
+罄
+罌
+罐
+网
+罔
+罕
+罗
+罚
+罡
+罢
+罩
+罪
+置
+罰
+署
+罵
+罷
+罹
+羁
+羅
+羈
+羊
+羌
+美
+羔
+羚
+羞
+羟
+羡
+羣
+群
+羥
+羧
+羨
+義
+羯
+羲
+羸
+羹
+羽
+羿
+翁
+翅
+翊
+翌
+翎
+習
+翔
+翘
+翟
+翠
+翡
+翦
+翩
+翰
+翱
+翳
+翹
+翻
+翼
+耀
+老
+考
+耄
+者
+耆
+耋
+而
+耍
+耐
+耒
+耕
+耗
+耘
+耙
+耦
+耨
+耳
+耶
+耷
+耸
+耻
+耽
+耿
+聂
+聆
+聊
+聋
+职
+聒
+联
+聖
+聘
+聚
+聞
+聪
+聯
+聰
+聲
+聳
+聴
+聶
+職
+聽
+聾
+聿
+肃
+肄
+肅
+肆
+肇
+肉
+肋
+肌
+肏
+肓
+肖
+肘
+肚
+肛
+肝
+肠
+股
+肢
+肤
+肥
+肩
+肪
+肮
+肯
+肱
+育
+肴
+肺
+肽
+肾
+肿
+胀
+胁
+胃
+胄
+胆
+背
+胍
+胎
+胖
+胚
+胛
+胜
+胝
+胞
+胡
+胤
+胥
+胧
+胫
+胭
+胯
+胰
+胱
+胳
+胴
+胶
+胸
+胺
+能
+脂
+脅
+脆
+脇
+脈
+脉
+脊
+脍
+脏
+脐
+脑
+脓
+脖
+脘
+脚
+脛
+脣
+脩
+脫
+脯
+脱
+脲
+脳
+脸
+脹
+脾
+腆
+腈
+腊
+腋
+腌
+腎
+腐
+腑
+腓
+腔
+腕
+腥
+腦
+腩
+腫
+腭
+腮
+腰
+腱
+腳
+腴
+腸
+腹
+腺
+腻
+腼
+腾
+腿
+膀
+膈
+膊
+膏
+膑
+膘
+膚
+膛
+膜
+膝
+膠
+膦
+膨
+膩
+膳
+膺
+膻
+膽
+膾
+膿
+臀
+臂
+臃
+臆
+臉
+臊
+臍
+臓
+臘
+臟
+臣
+臥
+臧
+臨
+自
+臬
+臭
+至
+致
+臺
+臻
+臼
+臾
+舀
+舂
+舅
+舆
+與
+興
+舉
+舊
+舌
+舍
+舎
+舐
+舒
+舔
+舖
+舗
+舛
+舜
+舞
+舟
+航
+舫
+般
+舰
+舱
+舵
+舶
+舷
+舸
+船
+舺
+舾
+艇
+艋
+艘
+艙
+艦
+艮
+良
+艰
+艱
+色
+艳
+艷
+艹
+艺
+艾
+节
+芃
+芈
+芊
+芋
+芍
+芎
+芒
+芙
+芜
+芝
+芡
+芥
+芦
+芩
+芪
+芫
+芬
+芭
+芮
+芯
+花
+芳
+芷
+芸
+芹
+芻
+芽
+芾
+苁
+苄
+苇
+苋
+苍
+苏
+苑
+苒
+苓
+苔
+苕
+苗
+苛
+苜
+苞
+苟
+苡
+苣
+若
+苦
+苫
+苯
+英
+苷
+苹
+苻
+茁
+茂
+范
+茄
+茅
+茉
+茎
+茏
+茗
+茜
+茧
+茨
+茫
+茬
+茭
+茯
+茱
+茲
+茴
+茵
+茶
+茸
+茹
+茼
+荀
+荃
+荆
+草
+荊
+荏
+荐
+荒
+荔
+荖
+荘
+荚
+荞
+荟
+荠
+荡
+荣
+荤
+荥
+荧
+荨
+荪
+荫
+药
+荳
+荷
+荸
+荻
+荼
+荽
+莅
+莆
+莉
+莊
+莎
+莒
+莓
+莖
+莘
+莞
+莠
+莢
+莧
+莪
+莫
+莱
+莲
+莴
+获
+莹
+莺
+莽
+莿
+菀
+菁
+菅
+菇
+菈
+菊
+菌
+菏
+菓
+菖
+菘
+菜
+菟
+菠
+菡
+菩
+華
+菱
+菲
+菸
+菽
+萁
+萃
+萄
+萊
+萋
+萌
+萍
+萎
+萘
+萝
+萤
+营
+萦
+萧
+萨
+萩
+萬
+萱
+萵
+萸
+萼
+落
+葆
+葉
+著
+葚
+葛
+葡
+董
+葦
+葩
+葫
+葬
+葭
+葯
+葱
+葳
+葵
+葷
+葺
+蒂
+蒋
+蒐
+蒔
+蒙
+蒜
+蒞
+蒟
+蒡
+蒨
+蒲
+蒸
+蒹
+蒻
+蒼
+蒿
+蓁
+蓄
+蓆
+蓉
+蓋
+蓑
+蓓
+蓖
+蓝
+蓟
+蓦
+蓬
+蓮
+蓼
+蓿
+蔑
+蔓
+蔔
+蔗
+蔘
+蔚
+蔡
+蔣
+蔥
+蔫
+蔬
+蔭
+蔵
+蔷
+蔺
+蔻
+蔼
+蔽
+蕁
+蕃
+蕈
+蕉
+蕊
+蕎
+蕙
+蕤
+蕨
+蕩
+蕪
+蕭
+蕲
+蕴
+蕻
+蕾
+薄
+薅
+薇
+薈
+薊
+薏
+薑
+薔
+薙
+薛
+薦
+薨
+薩
+薪
+薬
+薯
+薰
+薹
+藉
+藍
+藏
+藐
+藓
+藕
+藜
+藝
+藤
+藥
+藩
+藹
+藻
+藿
+蘆
+蘇
+蘊
+蘋
+蘑
+蘚
+蘭
+蘸
+蘼
+蘿
+虎
+虏
+虐
+虑
+虔
+處
+虚
+虛
+虜
+虞
+號
+虢
+虧
+虫
+虬
+虱
+虹
+虻
+虽
+虾
+蚀
+蚁
+蚂
+蚊
+蚌
+蚓
+蚕
+蚜
+蚝
+蚣
+蚤
+蚩
+蚪
+蚯
+蚱
+蚵
+蛀
+蛆
+蛇
+蛊
+蛋
+蛎
+蛐
+蛔
+蛙
+蛛
+蛟
+蛤
+蛭
+蛮
+蛰
+蛳
+蛹
+蛻
+蛾
+蜀
+蜂
+蜃
+蜆
+蜇
+蜈
+蜊
+蜍
+蜒
+蜓
+蜕
+蜗
+蜘
+蜚
+蜜
+蜡
+蜢
+蜥
+蜱
+蜴
+蜷
+蜻
+蜿
+蝇
+蝈
+蝉
+蝌
+蝎
+蝕
+蝗
+蝙
+蝟
+蝠
+蝦
+蝨
+蝴
+蝶
+蝸
+蝼
+螂
+螃
+融
+螞
+螢
+螨
+螯
+螳
+螺
+蟀
+蟄
+蟆
+蟋
+蟎
+蟑
+蟒
+蟠
+蟬
+蟲
+蟹
+蟻
+蟾
+蠅
+蠍
+蠔
+蠕
+蠛
+蠟
+蠡
+蠢
+蠣
+蠱
+蠶
+蠹
+蠻
+血
+衄
+衅
+衆
+行
+衍
+術
+衔
+街
+衙
+衛
+衝
+衞
+衡
+衢
+衣
+补
+表
+衩
+衫
+衬
+衮
+衰
+衲
+衷
+衹
+衾
+衿
+袁
+袂
+袄
+袅
+袈
+袋
+袍
+袒
+袖
+袜
+袞
+袤
+袪
+被
+袭
+袱
+裁
+裂
+装
+裆
+裊
+裏
+裔
+裕
+裘
+裙
+補
+裝
+裟
+裡
+裤
+裨
+裱
+裳
+裴
+裸
+裹
+製
+裾
+褂
+複
+褐
+褒
+褓
+褔
+褚
+褥
+褪
+褫
+褲
+褶
+褻
+襁
+襄
+襟
+襠
+襪
+襬
+襯
+襲
+西
+要
+覃
+覆
+覇
+見
+規
+覓
+視
+覚
+覦
+覧
+親
+覬
+観
+覷
+覺
+覽
+觀
+见
+观
+规
+觅
+视
+览
+觉
+觊
+觎
+觐
+觑
+角
+觞
+解
+觥
+触
+觸
+言
+訂
+計
+訊
+討
+訓
+訕
+訖
+託
+記
+訛
+訝
+訟
+訣
+訥
+訪
+設
+許
+訳
+訴
+訶
+診
+註
+証
+詆
+詐
+詔
+評
+詛
+詞
+詠
+詡
+詢
+詣
+試
+詩
+詫
+詬
+詭
+詮
+詰
+話
+該
+詳
+詹
+詼
+誅
+誇
+誉
+誌
+認
+誓
+誕
+誘
+語
+誠
+誡
+誣
+誤
+誥
+誦
+誨
+說
+説
+読
+誰
+課
+誹
+誼
+調
+諄
+談
+請
+諏
+諒
+論
+諗
+諜
+諡
+諦
+諧
+諫
+諭
+諮
+諱
+諳
+諷
+諸
+諺
+諾
+謀
+謁
+謂
+謄
+謊
+謎
+謐
+謔
+謗
+謙
+講
+謝
+謠
+謨
+謬
+謹
+謾
+譁
+證
+譎
+譏
+識
+譙
+譚
+譜
+警
+譬
+譯
+議
+譲
+譴
+護
+譽
+讀
+變
+讓
+讚
+讞
+计
+订
+认
+讥
+讧
+讨
+让
+讪
+讫
+训
+议
+讯
+记
+讲
+讳
+讴
+讶
+讷
+许
+讹
+论
+讼
+讽
+设
+访
+诀
+证
+诃
+评
+诅
+识
+诈
+诉
+诊
+诋
+词
+诏
+译
+试
+诗
+诘
+诙
+诚
+诛
+话
+诞
+诟
+诠
+诡
+询
+诣
+诤
+该
+详
+诧
+诩
+诫
+诬
+语
+误
+诰
+诱
+诲
+说
+诵
+诶
+请
+诸
+诺
+读
+诽
+课
+诿
+谀
+谁
+调
+谄
+谅
+谆
+谈
+谊
+谋
+谌
+谍
+谎
+谏
+谐
+谑
+谒
+谓
+谔
+谕
+谗
+谘
+谙
+谚
+谛
+谜
+谟
+谢
+谣
+谤
+谥
+谦
+谧
+谨
+谩
+谪
+谬
+谭
+谯
+谱
+谲
+谴
+谶
+谷
+豁
+豆
+豇
+豈
+豉
+豊
+豌
+豎
+豐
+豔
+豚
+象
+豢
+豪
+豫
+豬
+豹
+豺
+貂
+貅
+貌
+貓
+貔
+貘
+貝
+貞
+負
+財
+貢
+貧
+貨
+販
+貪
+貫
+責
+貯
+貰
+貳
+貴
+貶
+買
+貸
+費
+貼
+貽
+貿
+賀
+賁
+賂
+賃
+賄
+資
+賈
+賊
+賑
+賓
+賜
+賞
+賠
+賡
+賢
+賣
+賤
+賦
+質
+賬
+賭
+賴
+賺
+購
+賽
+贅
+贈
+贊
+贍
+贏
+贓
+贖
+贛
+贝
+贞
+负
+贡
+财
+责
+贤
+败
+账
+货
+质
+贩
+贪
+贫
+贬
+购
+贮
+贯
+贰
+贱
+贲
+贴
+贵
+贷
+贸
+费
+贺
+贻
+贼
+贾
+贿
+赁
+赂
+赃
+资
+赅
+赈
+赊
+赋
+赌
+赎
+赏
+赐
+赓
+赔
+赖
+赘
+赚
+赛
+赝
+赞
+赠
+赡
+赢
+赣
+赤
+赦
+赧
+赫
+赭
+走
+赳
+赴
+赵
+赶
+起
+趁
+超
+越
+趋
+趕
+趙
+趟
+趣
+趨
+足
+趴
+趵
+趸
+趺
+趾
+跃
+跄
+跆
+跋
+跌
+跎
+跑
+跖
+跚
+跛
+距
+跟
+跡
+跤
+跨
+跩
+跪
+路
+跳
+践
+跷
+跹
+跺
+跻
+踉
+踊
+踌
+踏
+踐
+踝
+踞
+踟
+踢
+踩
+踪
+踮
+踱
+踴
+踵
+踹
+蹂
+蹄
+蹇
+蹈
+蹉
+蹊
+蹋
+蹑
+蹒
+蹙
+蹟
+蹣
+蹤
+蹦
+蹩
+蹬
+蹭
+蹲
+蹴
+蹶
+蹺
+蹼
+蹿
+躁
+躇
+躉
+躊
+躋
+躍
+躏
+躪
+身
+躬
+躯
+躲
+躺
+軀
+車
+軋
+軌
+軍
+軒
+軟
+転
+軸
+軼
+軽
+軾
+較
+載
+輒
+輓
+輔
+輕
+輛
+輝
+輟
+輩
+輪
+輯
+輸
+輻
+輾
+輿
+轄
+轅
+轆
+轉
+轍
+轎
+轟
+车
+轧
+轨
+轩
+转
+轭
+轮
+软
+轰
+轲
+轴
+轶
+轻
+轼
+载
+轿
+较
+辄
+辅
+辆
+辇
+辈
+辉
+辊
+辍
+辐
+辑
+输
+辕
+辖
+辗
+辘
+辙
+辛
+辜
+辞
+辟
+辣
+辦
+辨
+辩
+辫
+辭
+辮
+辯
+辰
+辱
+農
+边
+辺
+辻
+込
+辽
+达
+迁
+迂
+迄
+迅
+过
+迈
+迎
+运
+近
+返
+还
+这
+进
+远
+违
+连
+迟
+迢
+迤
+迥
+迦
+迩
+迪
+迫
+迭
+述
+迴
+迷
+迸
+迹
+迺
+追
+退
+送
+适
+逃
+逅
+逆
+选
+逊
+逍
+透
+逐
+递
+途
+逕
+逗
+這
+通
+逛
+逝
+逞
+速
+造
+逢
+連
+逮
+週
+進
+逵
+逶
+逸
+逻
+逼
+逾
+遁
+遂
+遅
+遇
+遊
+運
+遍
+過
+遏
+遐
+遑
+遒
+道
+達
+違
+遗
+遙
+遛
+遜
+遞
+遠
+遢
+遣
+遥
+遨
+適
+遭
+遮
+遲
+遴
+遵
+遶
+遷
+選
+遺
+遼
+遽
+避
+邀
+邁
+邂
+邃
+還
+邇
+邈
+邊
+邋
+邏
+邑
+邓
+邕
+邛
+邝
+邢
+那
+邦
+邨
+邪
+邬
+邮
+邯
+邰
+邱
+邳
+邵
+邸
+邹
+邺
+邻
+郁
+郅
+郊
+郎
+郑
+郜
+郝
+郡
+郢
+郤
+郦
+郧
+部
+郫
+郭
+郴
+郵
+郷
+郸
+都
+鄂
+鄉
+鄒
+鄔
+鄙
+鄞
+鄢
+鄧
+鄭
+鄰
+鄱
+鄲
+鄺
+酉
+酊
+酋
+酌
+配
+酐
+酒
+酗
+酚
+酝
+酢
+酣
+酥
+酩
+酪
+酬
+酮
+酯
+酰
+酱
+酵
+酶
+酷
+酸
+酿
+醃
+醇
+醉
+醋
+醍
+醐
+醒
+醚
+醛
+醜
+醞
+醣
+醪
+醫
+醬
+醮
+醯
+醴
+醺
+釀
+釁
+采
+釉
+释
+釋
+里
+重
+野
+量
+釐
+金
+釗
+釘
+釜
+針
+釣
+釦
+釧
+釵
+鈀
+鈉
+鈍
+鈎
+鈔
+鈕
+鈞
+鈣
+鈦
+鈪
+鈴
+鈺
+鈾
+鉀
+鉄
+鉅
+鉉
+鉑
+鉗
+鉚
+鉛
+鉤
+鉴
+鉻
+銀
+銃
+銅
+銑
+銓
+銖
+銘
+銜
+銬
+銭
+銮
+銳
+銷
+銹
+鋁
+鋅
+鋒
+鋤
+鋪
+鋰
+鋸
+鋼
+錄
+錐
+錘
+錚
+錠
+錢
+錦
+錨
+錫
+錮
+錯
+録
+錳
+錶
+鍊
+鍋
+鍍
+鍛
+鍥
+鍰
+鍵
+鍺
+鍾
+鎂
+鎊
+鎌
+鎏
+鎔
+鎖
+鎗
+鎚
+鎧
+鎬
+鎮
+鎳
+鏈
+鏖
+鏗
+鏘
+鏞
+鏟
+鏡
+鏢
+鏤
+鏽
+鐘
+鐮
+鐲
+鐳
+鐵
+鐸
+鐺
+鑄
+鑊
+鑑
+鑒
+鑣
+鑫
+鑰
+鑲
+鑼
+鑽
+鑾
+鑿
+针
+钉
+钊
+钎
+钏
+钒
+钓
+钗
+钙
+钛
+钜
+钝
+钞
+钟
+钠
+钡
+钢
+钣
+钤
+钥
+钦
+钧
+钨
+钩
+钮
+钯
+钰
+钱
+钳
+钴
+钵
+钺
+钻
+钼
+钾
+钿
+铀
+铁
+铂
+铃
+铄
+铅
+铆
+铉
+铎
+铐
+铛
+铜
+铝
+铠
+铡
+铢
+铣
+铤
+铨
+铩
+铬
+铭
+铮
+铰
+铲
+铵
+银
+铸
+铺
+链
+铿
+销
+锁
+锂
+锄
+锅
+锆
+锈
+锉
+锋
+锌
+锏
+锐
+锑
+错
+锚
+锟
+锡
+锢
+锣
+锤
+锥
+锦
+锭
+键
+锯
+锰
+锲
+锵
+锹
+锺
+锻
+镀
+镁
+镂
+镇
+镉
+镌
+镍
+镐
+镑
+镕
+镖
+镗
+镛
+镜
+镣
+镭
+镯
+镰
+镳
+镶
+長
+长
+門
+閃
+閉
+開
+閎
+閏
+閑
+閒
+間
+閔
+閘
+閡
+関
+閣
+閥
+閨
+閩
+閱
+閲
+閹
+閻
+閾
+闆
+闇
+闊
+闌
+闍
+闔
+闕
+闖
+闘
+關
+闡
+闢
+门
+闪
+闫
+闭
+问
+闯
+闰
+闲
+间
+闵
+闷
+闸
+闹
+闺
+闻
+闽
+闾
+阀
+阁
+阂
+阅
+阆
+阇
+阈
+阉
+阎
+阐
+阑
+阔
+阕
+阖
+阙
+阚
+阜
+队
+阡
+阪
+阮
+阱
+防
+阳
+阴
+阵
+阶
+阻
+阿
+陀
+陂
+附
+际
+陆
+陇
+陈
+陋
+陌
+降
+限
+陕
+陛
+陝
+陞
+陟
+陡
+院
+陣
+除
+陨
+险
+陪
+陰
+陲
+陳
+陵
+陶
+陷
+陸
+険
+陽
+隅
+隆
+隈
+隊
+隋
+隍
+階
+随
+隐
+隔
+隕
+隘
+隙
+際
+障
+隠
+隣
+隧
+隨
+險
+隱
+隴
+隶
+隸
+隻
+隼
+隽
+难
+雀
+雁
+雄
+雅
+集
+雇
+雉
+雋
+雌
+雍
+雎
+雏
+雑
+雒
+雕
+雖
+雙
+雛
+雜
+雞
+離
+難
+雨
+雪
+雯
+雰
+雲
+雳
+零
+雷
+雹
+電
+雾
+需
+霁
+霄
+霆
+震
+霈
+霉
+霊
+霍
+霎
+霏
+霑
+霓
+霖
+霜
+霞
+霧
+霭
+霰
+露
+霸
+霹
+霽
+霾
+靂
+靄
+靈
+青
+靓
+靖
+静
+靚
+靛
+靜
+非
+靠
+靡
+面
+靥
+靦
+革
+靳
+靴
+靶
+靼
+鞅
+鞋
+鞍
+鞏
+鞑
+鞘
+鞠
+鞣
+鞦
+鞭
+韆
+韋
+韌
+韓
+韜
+韦
+韧
+韩
+韬
+韭
+音
+韵
+韶
+韻
+響
+頁
+頂
+頃
+項
+順
+須
+頌
+預
+頑
+頒
+頓
+頗
+領
+頜
+頡
+頤
+頫
+頭
+頰
+頷
+頸
+頹
+頻
+頼
+顆
+題
+額
+顎
+顏
+顔
+願
+顛
+類
+顧
+顫
+顯
+顱
+顴
+页
+顶
+顷
+项
+顺
+须
+顼
+顽
+顾
+顿
+颁
+颂
+预
+颅
+领
+颇
+颈
+颉
+颊
+颌
+颍
+颐
+频
+颓
+颔
+颖
+颗
+题
+颚
+颛
+颜
+额
+颞
+颠
+颡
+颢
+颤
+颦
+颧
+風
+颯
+颱
+颳
+颶
+颼
+飄
+飆
+风
+飒
+飓
+飕
+飘
+飙
+飚
+飛
+飞
+食
+飢
+飨
+飩
+飪
+飯
+飲
+飼
+飽
+飾
+餃
+餅
+餉
+養
+餌
+餐
+餒
+餓
+餘
+餚
+餛
+餞
+餡
+館
+餮
+餵
+餾
+饅
+饈
+饋
+饌
+饍
+饑
+饒
+饕
+饗
+饞
+饥
+饨
+饪
+饬
+饭
+饮
+饯
+饰
+饱
+饲
+饴
+饵
+饶
+饷
+饺
+饼
+饽
+饿
+馀
+馁
+馄
+馅
+馆
+馈
+馋
+馍
+馏
+馒
+馔
+首
+馗
+香
+馥
+馨
+馬
+馭
+馮
+馳
+馴
+駁
+駄
+駅
+駆
+駐
+駒
+駕
+駛
+駝
+駭
+駱
+駿
+騁
+騎
+騏
+験
+騙
+騨
+騰
+騷
+驀
+驅
+驊
+驍
+驒
+驕
+驗
+驚
+驛
+驟
+驢
+驥
+马
+驭
+驮
+驯
+驰
+驱
+驳
+驴
+驶
+驷
+驸
+驹
+驻
+驼
+驾
+驿
+骁
+骂
+骄
+骅
+骆
+骇
+骈
+骊
+骋
+验
+骏
+骐
+骑
+骗
+骚
+骛
+骜
+骞
+骠
+骡
+骤
+骥
+骧
+骨
+骯
+骰
+骶
+骷
+骸
+骼
+髂
+髅
+髋
+髏
+髒
+髓
+體
+髖
+高
+髦
+髪
+髮
+髯
+髻
+鬃
+鬆
+鬍
+鬓
+鬚
+鬟
+鬢
+鬣
+鬥
+鬧
+鬱
+鬼
+魁
+魂
+魄
+魅
+魇
+魍
+魏
+魔
+魘
+魚
+魯
+魷
+鮑
+鮨
+鮪
+鮭
+鮮
+鯉
+鯊
+鯖
+鯛
+鯨
+鯰
+鯽
+鰍
+鰓
+鰭
+鰲
+鰻
+鰾
+鱈
+鱉
+鱔
+鱗
+鱷
+鱸
+鱼
+鱿
+鲁
+鲈
+鲍
+鲑
+鲛
+鲜
+鲟
+鲢
+鲤
+鲨
+鲫
+鲱
+鲲
+鲶
+鲷
+鲸
+鳃
+鳄
+鳅
+鳌
+鳍
+鳕
+鳖
+鳗
+鳝
+鳞
+鳥
+鳩
+鳳
+鳴
+鳶
+鴉
+鴕
+鴛
+鴦
+鴨
+鴻
+鴿
+鵑
+鵜
+鵝
+鵡
+鵬
+鵰
+鵲
+鶘
+鶩
+鶯
+鶴
+鷗
+鷲
+鷹
+鷺
+鸚
+鸞
+鸟
+鸠
+鸡
+鸢
+鸣
+鸥
+鸦
+鸨
+鸪
+鸭
+鸯
+鸳
+鸵
+鸽
+鸾
+鸿
+鹂
+鹃
+鹄
+鹅
+鹈
+鹉
+鹊
+鹌
+鹏
+鹑
+鹕
+鹘
+鹜
+鹞
+鹤
+鹦
+鹧
+鹫
+鹭
+鹰
+鹳
+鹵
+鹹
+鹼
+鹽
+鹿
+麂
+麋
+麒
+麓
+麗
+麝
+麟
+麥
+麦
+麩
+麴
+麵
+麸
+麺
+麻
+麼
+麽
+麾
+黃
+黄
+黍
+黎
+黏
+黑
+黒
+黔
+默
+黛
+黜
+黝
+點
+黠
+黨
+黯
+黴
+鼋
+鼎
+鼐
+鼓
+鼠
+鼬
+鼹
+鼻
+鼾
+齁
+齊
+齋
+齐
+齒
+齡
+齢
+齣
+齦
+齿
+龄
+龅
+龈
+龊
+龋
+龌
+龍
+龐
+龔
+龕
+龙
+龚
+龛
+龜
+龟
+︰
+︱
+︶
+︿
+﹁
+﹂
+﹍
+﹏
+﹐
+﹑
+﹒
+﹔
+﹕
+﹖
+﹗
+﹙
+﹚
+﹝
+﹞
+﹡
+﹣
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+。
+「
+」
+、
+・
+ッ
+ー
+イ
+ク
+シ
+ス
+ト
+ノ
+フ
+ラ
+ル
+ン
+゙
+゚
+ ̄
+¥
+👍
+🔥
+😂
+😎
+...
+yam
+10
+2017
+12
+11
+2016
+20
+30
+15
+06
+lofter
+##s
+2015
+by
+16
+14
+18
+13
+24
+17
+2014
+21
+##0
+22
+19
+25
+23
+com
+100
+00
+05
+2013
+##a
+03
+09
+08
+28
+##2
+50
+01
+04
+##1
+27
+02
+2012
+##3
+26
+##e
+07
+##8
+##5
+##6
+##4
+##9
+##7
+29
+2011
+40
+##t
+2010
+##o
+##d
+##i
+2009
+##n
+app
+www
+the
+##m
+31
+##c
+##l
+##y
+##r
+##g
+2008
+60
+http
+200
+qq
+##p
+80
+##f
+google
+pixnet
+90
+cookies
+tripadvisor
+500
+##er
+##k
+35
+##h
+facebook
+2007
+2000
+70
+##b
+of
+##x
+##u
+45
+300
+iphone
+32
+1000
+2006
+48
+ip
+36
+in
+38
+3d
+##w
+##ing
+55
+ctrip
+##on
+##v
+33
+##の
+to
+34
+400
+id
+2005
+it
+37
+windows
+llc
+top
+99
+42
+39
+000
+led
+at
+##an
+41
+51
+52
+46
+49
+43
+53
+44
+##z
+android
+58
+and
+59
+2004
+56
+vr
+##か
+5000
+2003
+47
+blogthis
+twitter
+54
+##le
+150
+ok
+2018
+57
+75
+cn
+no
+ios
+##in
+##mm
+##00
+800
+on
+te
+3000
+65
+2001
+360
+95
+ig
+lv
+120
+##ng
+##を
+##us
+##に
+pc
+てす
+──
+600
+##te
+85
+2002
+88
+##ed
+html
+ncc
+wifi
+email
+64
+blog
+is
+##10
+##て
+mail
+online
+##al
+dvd
+##ic
+studio
+##は
+##℃
+##ia
+##と
+line
+vip
+72
+##q
+98
+##ce
+##en
+for
+##is
+##ra
+##es
+##j
+usb
+net
+cp
+1999
+asia
+4g
+##cm
+diy
+new
+3c
+##お
+ta
+66
+language
+vs
+apple
+tw
+86
+web
+##ne
+ipad
+62
+you
+##re
+101
+68
+##tion
+ps
+de
+bt
+pony
+atm
+##2017
+1998
+67
+##ch
+ceo
+##or
+go
+##na
+av
+pro
+cafe
+96
+pinterest
+97
+63
+pixstyleme3c
+##ta
+more
+said
+##2016
+1997
+mp3
+700
+##ll
+nba
+jun
+##20
+92
+tv
+1995
+pm
+61
+76
+nbsp
+250
+##ie
+linux
+##ma
+cd
+110
+hd
+##17
+78
+##ion
+77
+6000
+am
+##th
+##st
+94
+##se
+##et
+69
+180
+gdp
+my
+105
+81
+abc
+89
+flash
+79
+one
+93
+1990
+1996
+##ck
+gps
+##も
+##ly
+web885
+106
+2020
+91
+##ge
+4000
+1500
+xd
+boss
+isbn
+1994
+org
+##ry
+me
+love
+##11
+0fork
+73
+##12
+3g
+##ter
+##ar
+71
+82
+##la
+hotel
+130
+1970
+pk
+83
+87
+140
+ie
+##os
+##30
+##el
+74
+##50
+seo
+cpu
+##ml
+p2p
+84
+may
+##る
+sun
+tue
+internet
+cc
+posted
+youtube
+##at
+##ン
+##man
+ii
+##ル
+##15
+abs
+nt
+pdf
+yahoo
+ago
+1980
+##it
+news
+mac
+104
+##てす
+##me
+##り
+java
+1992
+spa
+##de
+##nt
+hk
+all
+plus
+la
+1993
+##mb
+##16
+##ve
+west
+##da
+160
+air
+##い
+##ps
+から
+##to
+1989
+logo
+htc
+php
+https
+fi
+momo
+##son
+sat
+##ke
+##80
+ebd
+suv
+wi
+day
+apk
+##88
+##um
+mv
+galaxy
+wiki
+or
+brake
+##ス
+1200
+する
+this
+1991
+mon
+##こ
+❤2017
+po
+##ない
+javascript
+life
+home
+june
+##ss
+system
+900
+##ー
+##0
+pp
+1988
+world
+fb
+4k
+br
+##as
+ic
+ai
+leonardo
+safari
+##60
+live
+free
+xx
+wed
+win7
+kiehl
+##co
+lg
+o2o
+##go
+us
+235
+1949
+mm
+しい
+vfm
+kanye
+##90
+##2015
+##id
+jr
+##ey
+123
+rss
+##sa
+##ro
+##am
+##no
+thu
+fri
+350
+##sh
+##ki
+103
+comments
+name
+##のて
+##pe
+##ine
+max
+1987
+8000
+uber
+##mi
+##ton
+wordpress
+office
+1986
+1985
+##ment
+107
+bd
+win10
+##ld
+##li
+gmail
+bb
+dior
+##rs
+##ri
+##rd
+##ます
+up
+cad
+##®
+dr
+して
+read
+##21
+をお
+##io
+##99
+url
+1984
+pvc
+paypal
+show
+policy
+##40
+##ty
+##18
+with
+##★
+##01
+txt
+102
+##ba
+dna
+from
+post
+mini
+ar
+taiwan
+john
+##ga
+privacy
+agoda
+##13
+##ny
+word
+##24
+##22
+##by
+##ur
+##hz
+1982
+##ang
+265
+cookie
+netscape
+108
+##ka
+##~
+##ad
+house
+share
+note
+ibm
+code
+hello
+nike
+sim
+survey
+##016
+1979
+1950
+wikia
+##32
+##017
+5g
+cbc
+##tor
+##kg
+1983
+##rt
+##14
+campaign
+store
+2500
+os
+##ct
+##ts
+##°
+170
+api
+##ns
+365
+excel
+##な
+##ao
+##ら
+##し
+~~
+##nd
+university
+163
+には
+518
+##70
+##ya
+##il
+##25
+pierre
+ipo
+0020
+897
+##23
+hotels
+##ian
+のお
+125
+years
+6606
+##ers
+##26
+high
+##day
+time
+##ay
+bug
+##line
+##く
+##す
+##be
+xp
+talk2yam
+yamservice
+10000
+coco
+##dy
+sony
+##ies
+1978
+microsoft
+david
+people
+##ha
+1960
+instagram
+intel
+その
+##ot
+iso
+1981
+##va
+115
+##mo
+##land
+xxx
+man
+co
+ltxsw
+##ation
+baby
+220
+##pa
+##ol
+1945
+7000
+tag
+450
+##ue
+msn
+##31
+oppo
+##ト
+##ca
+control
+##om
+st
+chrome
+##ure
+##ん
+be
+##き
+lol
+##19
+した
+##bo
+240
+lady
+##100
+##way
+##から
+4600
+##ko
+##do
+##un
+4s
+corporation
+168
+##ni
+herme
+##28
+cp
+978
+##up
+##06
+ui
+##ds
+ppt
+admin
+three
+します
+bbc
+re
+128
+##48
+ca
+##015
+##35
+hp
+##ee
+tpp
+##た
+##ive
+××
+root
+##cc
+##ました
+##ble
+##ity
+adobe
+park
+114
+et
+oled
+city
+##ex
+##ler
+##ap
+china
+##book
+20000
+view
+##ice
+global
+##km
+your
+hong
+##mg
+out
+##ms
+ng
+ebay
+##29
+menu
+ubuntu
+##cy
+rom
+##view
+open
+ktv
+do
+server
+##lo
+if
+english
+##ね
+##5
+##oo
+1600
+##02
+step1
+kong
+club
+135
+july
+inc
+1976
+mr
+hi
+##net
+touch
+##ls
+##ii
+michael
+lcd
+##05
+##33
+phone
+james
+step2
+1300
+ios9
+##box
+dc
+##2
+##ley
+samsung
+111
+280
+pokemon
+css
+##ent
+##les
+いいえ
+##1
+s8
+atom
+play
+bmw
+##said
+sa
+etf
+ctrl
+♥yoyo♥
+##55
+2025
+##2014
+##66
+adidas
+amazon
+1958
+##ber
+##ner
+visa
+##77
+##der
+1800
+connectivity
+##hi
+firefox
+109
+118
+hr
+so
+style
+mark
+pop
+ol
+skip
+1975
+as
+##27
+##ir
+##61
+190
+mba
+##う
+##ai
+le
+##ver
+1900
+cafe2017
+lte
+super
+113
+129
+##ron
+amd
+like
+##☆
+are
+##ster
+we
+##sk
+paul
+data
+international
+##ft
+longchamp
+ssd
+good
+##ート
+##ti
+reply
+##my
+↓↓↓
+apr
+star
+##ker
+source
+136
+js
+112
+get
+force
+photo
+##one
+126
+##2013
+##ow
+link
+bbs
+1972
+goods
+##lin
+python
+119
+##ip
+game
+##ics
+##ません
+blue
+##●
+520
+##45
+page
+itunes
+##03
+1955
+260
+1968
+gt
+gif
+618
+##ff
+##47
+group
+くたさい
+about
+bar
+ganji
+##nce
+music
+lee
+not
+1977
+1971
+1973
+##per
+an
+faq
+comment
+##って
+days
+##ock
+116
+##bs
+1974
+1969
+v1
+player
+1956
+xbox
+sql
+fm
+f1
+139
+##ah
+210
+##lv
+##mp
+##000
+melody
+1957
+##3
+550
+17life
+199
+1966
+xml
+market
+##au
+##71
+999
+##04
+what
+gl
+##95
+##age
+tips
+##68
+book
+##ting
+mysql
+can
+1959
+230
+##ung
+wonderland
+watch
+10℃
+##ction
+9000
+mar
+mobile
+1946
+1962
+article
+##db
+part
+▲top
+party
+って
+1967
+1964
+1948
+##07
+##ore
+##op
+この
+dj
+##78
+##38
+010
+main
+225
+1965
+##ong
+art
+320
+ad
+134
+020
+##73
+117
+pm2
+japan
+228
+##08
+ts
+1963
+##ica
+der
+sm
+##36
+2019
+##wa
+ct
+##7
+##や
+##64
+1937
+homemesh
+search
+##85
+##れは
+##tv
+##di
+macbook
+##9
+##くたさい
+service
+##♥
+type
+った
+750
+##ier
+##si
+##75
+##います
+##ok
+best
+##ット
+goris
+lock
+##った
+cf
+3m
+big
+##ut
+ftp
+carol
+##vi
+10
+1961
+happy
+sd
+##ac
+122
+anti
+pe
+cnn
+iii
+1920
+138
+##ラ
+1940
+esp
+jan
+tags
+##98
+##51
+august
+vol
+##86
+154
+##™
+##fs
+##れ
+##sion
+design
+ac
+##ム
+press
+jordan
+ppp
+that
+key
+check
+##6
+##tt
+##㎡
+1080p
+##lt
+power
+##42
+1952
+##bc
+vivi
+##ック
+he
+133
+121
+jpg
+##rry
+201
+175
+3500
+1947
+nb
+##ted
+##rn
+しています
+1954
+usd
+##t00
+master
+##ンク
+001
+model
+##58
+al
+##09
+1953
+##34
+ram
+goo
+ても
+##ui
+127
+1930
+red
+##ary
+rpg
+item
+##pm
+##41
+270
+##za
+project
+##2012
+hot
+td
+blogabstract
+##ger
+##62
+650
+##44
+gr2
+##します
+##m
+black
+electronic
+nfc
+year
+asus
+また
+html5
+cindy
+##hd
+m3
+132
+esc
+##od
+booking
+##53
+fed
+tvb
+##81
+##ina
+mit
+165
+##いる
+chan
+192
+distribution
+next
+になる
+peter
+bios
+steam
+cm
+1941
+にも
+pk10
+##ix
+##65
+##91
+dec
+nasa
+##ana
+icecat
+00z
+b1
+will
+##46
+li
+se
+##ji
+##み
+##ard
+oct
+##ain
+jp
+##ze
+##bi
+cio
+##56
+smart
+h5
+##39
+##port
+curve
+vpn
+##nm
+##dia
+utc
+##あり
+12345678910
+##52
+rmvb
+chanel
+a4
+miss
+##and
+##im
+media
+who
+##63
+she
+girl
+5s
+124
+vera
+##して
+class
+vivo
+king
+##フ
+##ei
+national
+ab
+1951
+5cm
+888
+145
+ipod
+ap
+1100
+5mm
+211
+ms
+2756
+##69
+mp4
+msci
+##po
+##89
+131
+mg
+index
+380
+##bit
+##out
+##zz
+##97
+##67
+158
+apec
+##8
+photoshop
+opec
+¥799
+ては
+##96
+##tes
+##ast
+2g
+○○
+##ール
+¥2899
+##ling
+##よ
+##ory
+1938
+##ical
+kitty
+content
+##43
+step3
+##cn
+win8
+155
+vc
+1400
+iphone7
+robert
+##した
+tcl
+137
+beauty
+##87
+en
+dollars
+##ys
+##oc
+step
+pay
+yy
+a1
+##2011
+##lly
+##ks
+##♪
+1939
+188
+download
+1944
+sep
+exe
+ph
+います
+school
+gb
+center
+pr
+street
+##board
+uv
+##37
+##lan
+winrar
+##que
+##ua
+##com
+1942
+1936
+480
+gpu
+##4
+ettoday
+fu
+tom
+##54
+##ren
+##via
+149
+##72
+b2b
+144
+##79
+##tch
+rose
+arm
+mb
+##49
+##ial
+##nn
+nvidia
+step4
+mvp
+00㎡
+york
+156
+##イ
+how
+cpi
+591
+2765
+gov
+kg
+joe
+##xx
+mandy
+pa
+##ser
+copyright
+fashion
+1935
+don
+##け
+ecu
+##ist
+##art
+erp
+wap
+have
+##lm
+talk
+##ek
+##ning
+##if
+ch
+##ite
+video
+1943
+cs
+san
+iot
+look
+##84
+##2010
+##ku
+october
+##ux
+trump
+##hs
+##ide
+box
+141
+first
+##ins
+april
+##ight
+##83
+185
+angel
+protected
+aa
+151
+162
+x1
+m2
+##fe
+##×
+##ho
+size
+143
+min
+ofo
+fun
+gomaji
+ex
+hdmi
+food
+dns
+march
+chris
+kevin
+##のか
+##lla
+##pp
+##ec
+ag
+ems
+6s
+720p
+##rm
+##ham
+off
+##92
+asp
+team
+fandom
+ed
+299
+▌♥
+##ell
+info
+されています
+##82
+sina
+4066
+161
+##able
+##ctor
+330
+399
+315
+dll
+rights
+ltd
+idc
+jul
+3kg
+1927
+142
+ma
+surface
+##76
+##ク
+~~~
+304
+mall
+eps
+146
+green
+##59
+map
+space
+donald
+v2
+sodu
+##light
+1931
+148
+1700
+まて
+310
+reserved
+htm
+##han
+##57
+2d
+178
+mod
+##ise
+##tions
+152
+ti
+##shi
+doc
+1933
+icp
+055
+wang
+##ram
+shopping
+aug
+##pi
+##well
+now
+wam
+b2
+からお
+##hu
+236
+1928
+##gb
+266
+f2
+##93
+153
+mix
+##ef
+##uan
+bwl
+##plus
+##res
+core
+##ess
+tea
+5℃
+hktvmall
+nhk
+##ate
+list
+##ese
+301
+feb
+4m
+inn
+ての
+nov
+159
+12345
+daniel
+##ci
+pass
+##bet
+##nk
+coffee
+202
+ssl
+airbnb
+##ute
+fbi
+woshipm
+skype
+ea
+cg
+sp
+##fc
+##www
+yes
+edge
+alt
+007
+##94
+fpga
+##ght
+##gs
+iso9001
+さい
+##ile
+##wood
+##uo
+image
+lin
+icon
+american
+##em
+1932
+set
+says
+##king
+##tive
+blogger
+##74
+なと
+256
+147
+##ox
+##zy
+##red
+##ium
+##lf
+nokia
+claire
+##リ
+##ding
+november
+lohas
+##500
+##tic
+##マ
+##cs
+##ある
+##che
+##ire
+##gy
+##ult
+db
+january
+win
+##カ
+166
+road
+ptt
+##ま
+##つ
+198
+##fa
+##mer
+anna
+pchome
+はい
+udn
+ef
+420
+##time
+##tte
+2030
+##ア
+g20
+white
+かかります
+1929
+308
+garden
+eleven
+di
+##おります
+chen
+309b
+777
+172
+young
+cosplay
+ちてない
+4500
+bat
+##123
+##tra
+##ては
+kindle
+npc
+steve
+etc
+##ern
+##|
+call
+xperia
+ces
+travel
+sk
+s7
+##ous
+1934
+##int
+みいたたけます
+183
+edu
+file
+cho
+qr
+##car
+##our
+186
+##ant
+##d
+eric
+1914
+rends
+##jo
+##する
+mastercard
+##2000
+kb
+##min
+290
+##ino
+vista
+##ris
+##ud
+jack
+2400
+##set
+169
+pos
+1912
+##her
+##ou
+taipei
+しく
+205
+beta
+##ませんか
+232
+##fi
+express
+255
+body
+##ill
+aphojoy
+user
+december
+meiki
+##ick
+tweet
+richard
+##av
+##ᆫ
+iphone6
+##dd
+ちてすか
+views
+##mark
+321
+pd
+##00
+times
+##▲
+level
+##ash
+10g
+point
+5l
+##ome
+208
+koreanmall
+##ak
+george
+q2
+206
+wma
+tcp
+##200
+スタッフ
+full
+mlb
+##lle
+##watch
+tm
+run
+179
+911
+smith
+business
+##und
+1919
+color
+##tal
+222
+171
+##less
+moon
+4399
+##rl
+update
+pcb
+shop
+499
+157
+little
+なし
+end
+##mhz
+van
+dsp
+easy
+660
+##house
+##key
+history
+##o
+oh
+##001
+##hy
+##web
+oem
+let
+was
+##2009
+##gg
+review
+##wan
+182
+##°c
+203
+uc
+title
+##val
+united
+233
+2021
+##ons
+doi
+trivago
+overdope
+sbs
+##ance
+##ち
+grand
+special
+573032185
+imf
+216
+wx17house
+##so
+##ーム
+audi
+##he
+london
+william
+##rp
+##ake
+science
+beach
+cfa
+amp
+ps4
+880
+##800
+##link
+##hp
+crm
+ferragamo
+bell
+make
+##eng
+195
+under
+zh
+photos
+2300
+##style
+##ント
+via
+176
+da
+##gi
+company
+i7
+##ray
+thomas
+370
+ufo
+i5
+##max
+plc
+ben
+back
+research
+8g
+173
+mike
+##pc
+##ッフ
+september
+189
+##ace
+vps
+february
+167
+pantos
+wp
+lisa
+1921
+★★
+jquery
+night
+long
+offer
+##berg
+##news
+1911
+##いて
+ray
+fks
+wto
+せます
+over
+164
+340
+##all
+##rus
+1924
+##888
+##works
+blogtitle
+loftpermalink
+##→
+187
+martin
+test
+ling
+km
+##め
+15000
+fda
+v3
+##ja
+##ロ
+wedding
+かある
+outlet
+family
+##ea
+をこ
+##top
+story
+##ness
+salvatore
+##lu
+204
+swift
+215
+room
+している
+oracle
+##ul
+1925
+sam
+b2c
+week
+pi
+rock
+##のは
+##a
+##けと
+##ean
+##300
+##gle
+cctv
+after
+chinese
+##back
+powered
+x2
+##tan
+1918
+##nes
+##イン
+canon
+only
+181
+##zi
+##las
+say
+##oe
+184
+##sd
+221
+##bot
+##world
+##zo
+sky
+made
+top100
+just
+1926
+pmi
+802
+234
+gap
+##vr
+177
+les
+174
+▲topoct
+ball
+vogue
+vi
+ing
+ofweek
+cos
+##list
+##ort
+▲topmay
+##なら
+##lon
+として
+last
+##tc
+##of
+##bus
+##gen
+real
+eva
+##コ
+a3
+nas
+##lie
+##ria
+##coin
+##bt
+▲topapr
+his
+212
+cat
+nata
+vive
+health
+⋯⋯
+drive
+sir
+▲topmar
+du
+cup
+##カー
+##ook
+##よう
+##sy
+alex
+msg
+tour
+しました
+3ce
+##word
+193
+ebooks
+r8
+block
+318
+##より
+2200
+nice
+pvp
+207
+months
+1905
+rewards
+##ther
+1917
+0800
+##xi
+##チ
+##sc
+micro
+850
+gg
+blogfp
+op
+1922
+daily
+m1
+264
+true
+##bb
+ml
+##tar
+##のお
+##ky
+anthony
+196
+253
+##yo
+state
+218
+##ara
+##aa
+##rc
+##tz
+##ston
+より
+gear
+##eo
+##ade
+ge
+see
+1923
+##win
+##ura
+ss
+heart
+##den
+##ita
+down
+##sm
+el
+png
+2100
+610
+rakuten
+whatsapp
+bay
+dream
+add
+##use
+680
+311
+pad
+gucci
+mpv
+##ode
+##fo
+island
+▲topjun
+##▼
+223
+jason
+214
+chicago
+##❤
+しの
+##hone
+io
+##れる
+##ことか
+sogo
+be2
+##ology
+990
+cloud
+vcd
+##con
+2~3
+##ford
+##joy
+##kb
+##こさいます
+##rade
+but
+##ach
+docker
+##ful
+rfid
+ul
+##ase
+hit
+ford
+##star
+580
+##○
+11
+a2
+sdk
+reading
+edited
+##are
+cmos
+##mc
+238
+siri
+light
+##ella
+##ため
+bloomberg
+##read
+pizza
+##ison
+jimmy
+##vm
+college
+node
+journal
+ba
+18k
+##play
+245
+##cer
+20
+magic
+##yu
+191
+jump
+288
+tt
+##ings
+asr
+##lia
+3200
+step5
+network
+##cd
+mc
+いします
+1234
+pixstyleme
+273
+##600
+2800
+money
+★★★★★
+1280
+12
+430
+bl
+みの
+act
+##tus
+tokyo
+##rial
+##life
+emba
+##ae
+saas
+tcs
+##rk
+##wang
+summer
+##sp
+ko
+##ving
+390
+premium
+##その
+netflix
+##ヒ
+uk
+mt
+##lton
+right
+frank
+two
+209
+える
+##ple
+##cal
+021
+##んな
+##sen
+##ville
+hold
+nexus
+dd
+##ius
+てお
+##mah
+##なく
+tila
+zero
+820
+ce
+##tin
+resort
+##ws
+charles
+old
+p10
+5d
+report
+##360
+##ru
+##には
+bus
+vans
+lt
+##est
+pv
+##レ
+links
+rebecca
+##ツ
+##dm
+azure
+##365
+きな
+limited
+bit
+4gb
+##mon
+1910
+moto
+##eam
+213
+1913
+var
+eos
+なとの
+226
+blogspot
+された
+699
+e3
+dos
+dm
+fc
+##ments
+##ik
+##kw
+boy
+##bin
+##ata
+960
+er
+##せ
+219
+##vin
+##tu
+##ula
+194
+##∥
+station
+##ろ
+##ature
+835
+files
+zara
+hdr
+top10
+nature
+950
+magazine
+s6
+marriott
+##シ
+avira
+case
+##っと
+tab
+##ran
+tony
+##home
+oculus
+im
+##ral
+jean
+saint
+cry
+307
+rosie
+##force
+##ini
+ice
+##bert
+のある
+##nder
+##mber
+pet
+2600
+##◆
+plurk
+▲topdec
+##sis
+00kg
+▲topnov
+720
+##ence
+tim
+##ω
+##nc
+##ても
+##name
+log
+ips
+great
+ikea
+malaysia
+unix
+##イト
+3600
+##ncy
+##nie
+12000
+akb48
+##ye
+##oid
+404
+##chi
+##いた
+oa
+xuehai
+##1000
+##orm
+##rf
+275
+さん
+##ware
+##リー
+980
+ho
+##pro
+text
+##era
+560
+bob
+227
+##ub
+##2008
+8891
+scp
+avi
+##zen
+2022
+mi
+wu
+museum
+qvod
+apache
+lake
+jcb
+▲topaug
+★★★
+ni
+##hr
+hill
+302
+ne
+weibo
+490
+ruby
+##ーシ
+##ヶ
+##row
+4d
+▲topjul
+iv
+##ish
+github
+306
+mate
+312
+##スト
+##lot
+##ane
+andrew
+のハイト
+##tina
+t1
+rf
+ed2k
+##vel
+##900
+way
+final
+りの
+ns
+5a
+705
+197
+##メ
+sweet
+bytes
+##ene
+▲topjan
+231
+##cker
+##2007
+##px
+100g
+topapp
+229
+helpapp
+rs
+low
+14k
+g4g
+care
+630
+ldquo
+あり
+##fork
+leave
+rm
+edition
+##gan
+##zon
+##qq
+▲topsep
+##google
+##ism
+gold
+224
+explorer
+##zer
+toyota
+category
+select
+visual
+##labels
+restaurant
+##md
+posts
+s1
+##ico
+もっと
+angelababy
+123456
+217
+sports
+s3
+mbc
+1915
+してくたさい
+shell
+x86
+candy
+##new
+kbs
+face
+xl
+470
+##here
+4a
+swissinfo
+v8
+▲topfeb
+dram
+##ual
+##vice
+3a
+##wer
+sport
+q1
+ios10
+public
+int
+card
+##c
+ep
+au
+rt
+##れた
+1080
+bill
+##mll
+kim
+30
+460
+wan
+##uk
+##ミ
+x3
+298
+0t
+scott
+##ming
+239
+e5
+##3d
+h7n9
+worldcat
+brown
+##あります
+##vo
+##led
+##580
+##ax
+249
+410
+##ert
+paris
+##~6
+polo
+925
+##lr
+599
+##ナ
+capital
+##hing
+bank
+cv
+1g
+##chat
+##s
+##たい
+adc
+##ule
+2m
+##e
+digital
+hotmail
+268
+##pad
+870
+bbq
+quot
+##ring
+before
+wali
+##まて
+mcu
+2k
+2b
+という
+costco
+316
+north
+333
+switch
+##city
+##p
+philips
+##mann
+management
+panasonic
+##cl
+##vd
+##ping
+##rge
+alice
+##lk
+##ましょう
+css3
+##ney
+vision
+alpha
+##ular
+##400
+##tter
+lz
+にお
+##ありません
+mode
+gre
+1916
+pci
+##tm
+237
+1~2
+##yan
+##そ
+について
+##let
+##キ
+work
+war
+coach
+ah
+mary
+##ᅵ
+huang
+##pt
+a8
+pt
+follow
+##berry
+1895
+##ew
+a5
+ghost
+##ション
+##wn
+##og
+south
+##code
+girls
+##rid
+action
+villa
+git
+r11
+table
+games
+##cket
+error
+##anonymoussaid
+##ag
+here
+##ame
+##gc
+qa
+##■
+##lis
+gmp
+##gin
+vmalife
+##cher
+yu
+wedding
+##tis
+demo
+dragon
+530
+soho
+social
+bye
+##rant
+river
+orz
+acer
+325
+##↑
+##ース
+##ats
+261
+del
+##ven
+440
+ups
+##ように
+##ター
+305
+value
+macd
+yougou
+##dn
+661
+##ano
+ll
+##urt
+##rent
+continue
+script
+##wen
+##ect
+paper
+263
+319
+shift
+##chel
+##フト
+##cat
+258
+x5
+fox
+243
+##さん
+car
+aaa
+##blog
+loading
+##yn
+##tp
+kuso
+799
+si
+sns
+イカせるテンマ
+ヒンクテンマ3
+rmb
+vdc
+forest
+central
+prime
+help
+ultra
+##rmb
+##ような
+241
+square
+688
+##しい
+のないフロクに
+##field
+##reen
+##ors
+##ju
+c1
+start
+510
+##air
+##map
+cdn
+##wo
+cba
+stephen
+m8
+100km
+##get
+opera
+##base
+##ood
+vsa
+com™
+##aw
+##ail
+251
+なのて
+count
+t2
+##ᅡ
+##een
+2700
+hop
+##gp
+vsc
+tree
+##eg
+##ose
+816
+285
+##ories
+##shop
+alphago
+v4
+1909
+simon
+##ᆼ
+fluke62max
+zip
+スホンサー
+##sta
+louis
+cr
+bas
+##~10
+bc
+##yer
+hadoop
+##ube
+##wi
+1906
+0755
+hola
+##low
+place
+centre
+5v
+d3
+##fer
+252
+##750
+##media
+281
+540
+0l
+exchange
+262
+series
+##ハー
+##san
+eb
+##bank
+##k
+q3
+##nge
+##mail
+take
+##lp
+259
+1888
+client
+east
+cache
+event
+vincent
+##ールを
+きを
+##nse
+sui
+855
+adchoice
+##и
+##stry
+##なたの
+246
+##zone
+ga
+apps
+sea
+##ab
+248
+cisco
+##タ
+##rner
+kymco
+##care
+dha
+##pu
+##yi
+minkoff
+royal
+p1
+への
+annie
+269
+collection
+kpi
+playstation
+257
+になります
+866
+bh
+##bar
+queen
+505
+radio
+1904
+andy
+armani
+##xy
+manager
+iherb
+##ery
+##share
+spring
+raid
+johnson
+1908
+##ob
+volvo
+hall
+##ball
+v6
+our
+taylor
+##hk
+bi
+242
+##cp
+kate
+bo
+water
+technology
+##rie
+サイトは
+277
+##ona
+##sl
+hpv
+303
+gtx
+hip
+rdquo
+jayz
+stone
+##lex
+##rum
+namespace
+##やり
+620
+##ale
+##atic
+des
+##erson
+##ql
+##ves
+##type
+enter
+##この
+##てきます
+d2
+##168
+##mix
+##bian
+との
+a9
+jj
+ky
+##lc
+access
+movie
+##hc
+リストに
+tower
+##ration
+##mit
+ます
+##nch
+ua
+tel
+prefix
+##o2
+1907
+##point
+1901
+ott
+~10
+##http
+##ury
+baidu
+##ink
+member
+##logy
+bigbang
+nownews
+##js
+##shot
+##tb
+##こと
+247
+eba
+##tics
+##lus
+ける
+v5
+spark
+##ama
+there
+##ions
+god
+##lls
+##down
+hiv
+##ress
+burberry
+day2
+##kv
+◆◆
+jeff
+related
+film
+edit
+joseph
+283
+##ark
+cx
+32gb
+order
+g9
+30000
+##ans
+##tty
+s5
+##bee
+かあります
+thread
+xr
+buy
+sh
+005
+land
+spotify
+mx
+##ari
+276
+##verse
+×email
+sf
+why
+##ことて
+244
+7headlines
+nego
+sunny
+dom
+exo
+401
+666
+positioning
+fit
+rgb
+##tton
+278
+kiss
+alexa
+adam
+lp
+みリストを
+##g
+mp
+##ties
+##llow
+amy
+##du
+np
+002
+institute
+271
+##rth
+##lar
+2345
+590
+##des
+sidebar
+15
+imax
+site
+##cky
+##kit
+##ime
+##009
+season
+323
+##fun
+##ンター
+##ひ
+gogoro
+a7
+pu
+lily
+fire
+twd600
+##ッセーシを
+いて
+##vis
+30ml
+##cture
+##をお
+information
+##オ
+close
+friday
+##くれる
+yi
+nick
+てすか
+##tta
+##tel
+6500
+##lock
+cbd
+economy
+254
+かお
+267
+tinker
+double
+375
+8gb
+voice
+##app
+oops
+channel
+today
+985
+##right
+raw
+xyz
+##+
+jim
+edm
+##cent
+7500
+supreme
+814
+ds
+##its
+##asia
+dropbox
+##てすか
+##tti
+books
+272
+100ml
+##tle
+##ller
+##ken
+##more
+##boy
+sex
+309
+##dom
+t3
+##ider
+##なります
+##unch
+1903
+810
+feel
+5500
+##かった
+##put
+により
+s2
+mo
+##gh
+men
+ka
+amoled
+div
+##tr
+##n1
+port
+howard
+##tags
+ken
+dnf
+##nus
+adsense
+##а
+ide
+##へ
+buff
+thunder
+##town
+##ique
+has
+##body
+auto
+pin
+##erry
+tee
+てした
+295
+number
+##the
+##013
+object
+psp
+cool
+udnbkk
+16gb
+##mic
+miui
+##tro
+most
+r2
+##alk
+##nity
+1880
+±0
+##いました
+428
+s4
+law
+version
+##oa
+n1
+sgs
+docomo
+##tf
+##ack
+henry
+fc2
+##ded
+##sco
+##014
+##rite
+286
+0mm
+linkedin
+##ada
+##now
+wii
+##ndy
+ucbug
+##◎
+sputniknews
+legalminer
+##ika
+##xp
+2gb
+##bu
+q10
+oo
+b6
+come
+##rman
+cheese
+ming
+maker
+##gm
+nikon
+##fig
+ppi
+kelly
+##ります
+jchere
+てきます
+ted
+md
+003
+fgo
+tech
+##tto
+dan
+soc
+##gl
+##len
+hair
+earth
+640
+521
+img
+##pper
+##a1
+##てきる
+##ロク
+acca
+##ition
+##ference
+suite
+##ig
+outlook
+##mond
+##cation
+398
+##pr
+279
+101vip
+358
+##999
+282
+64gb
+3800
+345
+airport
+##over
+284
+##おり
+jones
+##ith
+lab
+##su
+##いるのて
+co2
+town
+piece
+##llo
+no1
+vmware
+24h
+##qi
+focus
+reader
+##admin
+##ora
+tb
+false
+##log
+1898
+know
+lan
+838
+##ces
+f4
+##ume
+motel
+stop
+##oper
+na
+flickr
+netcomponents
+##af
+##─
+pose
+williams
+local
+##ound
+##cg
+##site
+##iko
+いお
+274
+5m
+gsm
+con
+##ath
+1902
+friends
+##hip
+cell
+317
+##rey
+780
+cream
+##cks
+012
+##dp
+facebooktwitterpinterestgoogle
+sso
+324
+shtml
+song
+swiss
+##mw
+##キンク
+lumia
+xdd
+string
+tiffany
+522
+marc
+られた
+insee
+russell
+sc
+dell
+##ations
+ok
+camera
+289
+##vs
+##flow
+##late
+classic
+287
+##nter
+stay
+g1
+mtv
+512
+##ever
+##lab
+##nger
+qe
+sata
+ryan
+d1
+50ml
+cms
+##cing
+su
+292
+3300
+editor
+296
+##nap
+security
+sunday
+association
+##ens
+##700
+##bra
+acg
+##かり
+sofascore
+とは
+mkv
+##ign
+jonathan
+gary
+build
+labels
+##oto
+tesla
+moba
+qi
+gohappy
+general
+ajax
+1024
+##かる
+サイト
+society
+##test
+##urs
+wps
+fedora
+##ich
+mozilla
+328
+##480
+##dr
+usa
+urn
+##lina
+##r
+grace
+##die
+##try
+##ader
+1250
+##なり
+elle
+570
+##chen
+##ᆯ
+price
+##ten
+uhz
+##ough
+eq
+##hen
+states
+push
+session
+balance
+wow
+506
+##cus
+##py
+when
+##ward
+##ep
+34e
+wong
+library
+prada
+##サイト
+##cle
+running
+##ree
+313
+ck
+date
+q4
+##ctive
+##ool
+##>
+mk
+##ira
+##163
+388
+die
+secret
+rq
+dota
+buffet
+は1ヶ
+e6
+##ez
+pan
+368
+ha
+##card
+##cha
+2a
+##さ
+alan
+day3
+eye
+f3
+##end
+france
+keep
+adi
+rna
+tvbs
+##ala
+solo
+nova
+##え
+##tail
+##ょう
+support
+##ries
+##なる
+##ved
+base
+copy
+iis
+fps
+##ways
+hero
+hgih
+profile
+fish
+mu
+ssh
+entertainment
+chang
+##wd
+click
+cake
+##ond
+pre
+##tom
+kic
+pixel
+##ov
+##fl
+product
+6a
+##pd
+dear
+##gate
+es
+yumi
+audio
+##²
+##sky
+echo
+bin
+where
+##ture
+329
+##ape
+find
+sap
+isis
+##なと
+nand
+##101
+##load
+##ream
+band
+a6
+525
+never
+##post
+festival
+50cm
+##we
+555
+guide
+314
+zenfone
+##ike
+335
+gd
+forum
+jessica
+strong
+alexander
+##ould
+software
+allen
+##ious
+program
+360°
+else
+lohasthree
+##gar
+することかてきます
+please
+##れます
+rc
+##ggle
+##ric
+bim
+50000
+##own
+eclipse
+355
+brian
+3ds
+##side
+061
+361
+##other
+##ける
+##tech
+##ator
+485
+engine
+##ged
+##t
+plaza
+##fit
+cia
+ngo
+westbrook
+shi
+tbs
+50mm
+##みませんか
+sci
+291
+reuters
+##ily
+contextlink
+##hn
+af
+##cil
+bridge
+very
+##cel
+1890
+cambridge
+##ize
+15g
+##aid
+##data
+790
+frm
+##head
+award
+butler
+##sun
+meta
+##mar
+america
+ps3
+puma
+pmid
+##すか
+lc
+670
+kitchen
+##lic
+オーフン5
+きなしソフトサーヒス
+そして
+day1
+future
+★★★★
+##text
+##page
+##rris
+pm1
+##ket
+fans
+##っています
+1001
+christian
+bot
+kids
+trackback
+##hai
+c3
+display
+##hl
+n2
+1896
+idea
+さんも
+##sent
+airmail
+##ug
+##men
+pwm
+けます
+028
+##lution
+369
+852
+awards
+schemas
+354
+asics
+wikipedia
+font
+##tional
+##vy
+c2
+293
+##れている
+##dget
+##ein
+っている
+contact
+pepper
+スキル
+339
+##~5
+294
+##uel
+##ument
+730
+##hang
+みてす
+q5
+##sue
+rain
+##ndi
+wei
+swatch
+##cept
+わせ
+331
+popular
+##ste
+##tag
+p2
+501
+trc
+1899
+##west
+##live
+justin
+honda
+ping
+messenger
+##rap
+v9
+543
+##とは
+unity
+appqq
+はすへて
+025
+leo
+##tone
+##テ
+##ass
+uniqlo
+##010
+502
+her
+jane
+memory
+moneydj
+##tical
+human
+12306
+していると
+##m2
+coc
+miacare
+##mn
+tmt
+##core
+vim
+kk
+##may
+fan
+target
+use
+too
+338
+435
+2050
+867
+737
+fast
+##2c
+services
+##ope
+omega
+energy
+##わ
+pinkoi
+1a
+##なから
+##rain
+jackson
+##ement
+##シャンルの
+374
+366
+そんな
+p9
+rd
+##ᆨ
+1111
+##tier
+##vic
+zone
+##│
+385
+690
+dl
+isofix
+cpa
+m4
+322
+kimi
+めて
+davis
+##lay
+lulu
+##uck
+050
+weeks
+qs
+##hop
+920
+##n
+ae
+##ear
+~5
+eia
+405
+##fly
+korea
+jpeg
+boost
+##ship
+small
+##リア
+1860
+eur
+297
+425
+valley
+##iel
+simple
+##ude
+rn
+k2
+##ena
+されます
+non
+patrick
+しているから
+##ナー
+feed
+5757
+30g
+process
+well
+qqmei
+##thing
+they
+aws
+lu
+pink
+##ters
+##kin
+または
+board
+##vertisement
+wine
+##ien
+unicode
+##dge
+r1
+359
+##tant
+いを
+##twitter
+##3c
+cool1
+される
+##れて
+##l
+isp
+##012
+standard
+45㎡2
+402
+##150
+matt
+##fu
+326
+##iner
+googlemsn
+pixnetfacebookyahoo
+##ラン
+x7
+886
+##uce
+メーカー
+sao
+##ev
+##きました
+##file
+9678
+403
+xddd
+shirt
+6l
+##rio
+##hat
+3mm
+givenchy
+ya
+bang
+##lio
+monday
+crystal
+ロクイン
+##abc
+336
+head
+890
+ubuntuforumwikilinuxpastechat
+##vc
+##~20
+##rity
+cnc
+7866
+ipv6
+null
+1897
+##ost
+yang
+imsean
+tiger
+##fet
+##ンス
+352
+##=
+dji
+327
+ji
+maria
+##come
+##んて
+foundation
+3100
+##beth
+##なった
+1m
+601
+active
+##aft
+##don
+3p
+sr
+349
+emma
+##khz
+living
+415
+353
+1889
+341
+709
+457
+sas
+x6
+##face
+pptv
+x4
+##mate
+han
+sophie
+##jing
+337
+fifa
+##mand
+other
+sale
+inwedding
+##gn
+てきちゃいます
+##mmy
+##pmlast
+bad
+nana
+nbc
+してみてくたさいね
+なとはお
+##wu
+##かあります
+##あ
+note7
+single
+##340
+せからこ
+してくたさい♪この
+しにはとんとんワークケートを
+するとあなたにもっとマッチした
+ならワークケートへ
+もみつかっちゃうかも
+ワークケートの
+##bel
+window
+##dio
+##ht
+union
+age
+382
+14
+##ivity
+##y
+コメント
+domain
+neo
+##isa
+##lter
+5k
+f5
+steven
+##cts
+powerpoint
+tft
+self
+g2
+ft
+##テル
+zol
+##act
+mwc
+381
+343
+もう
+nbapop
+408
+てある
+eds
+ace
+##room
+previous
+author
+tomtom
+il
+##ets
+hu
+financial
+☆☆☆
+っています
+bp
+5t
+chi
+1gb
+##hg
+fairmont
+cross
+008
+gay
+h2
+function
+##けて
+356
+also
+1b
+625
+##ータ
+##raph
+1894
+3~5
+##ils
+i3
+334
+avenue
+##host
+による
+##bon
+##tsu
+message
+navigation
+50g
+fintech
+h6
+##ことを
+8cm
+##ject
+##vas
+##firm
+credit
+##wf
+xxxx
+form
+##nor
+##space
+huawei
+plan
+json
+sbl
+##dc
+machine
+921
+392
+wish
+##120
+##sol
+windows7
+edward
+##ために
+development
+washington
+##nsis
+lo
+818
+##sio
+##ym
+##bor
+planet
+##~8
+##wt
+ieee
+gpa
+##めて
+camp
+ann
+gm
+##tw
+##oka
+connect
+##rss
+##work
+##atus
+wall
+chicken
+soul
+2mm
+##times
+fa
+##ather
+##cord
+009
+##eep
+hitachi
+gui
+harry
+##pan
+e1
+disney
+##press
+##ーション
+wind
+386
+frigidaire
+##tl
+liu
+hsu
+332
+basic
+von
+ev
+いた
+てきる
+スホンサーサイト
+learning
+##ull
+expedia
+archives
+change
+##wei
+santa
+cut
+ins
+6gb
+turbo
+brand
+cf1
+508
+004
+return
+747
+##rip
+h1
+##nis
+##をこ
+128gb
+##にお
+3t
+application
+しており
+emc
+rx
+##oon
+384
+quick
+412
+15058
+wilson
+wing
+chapter
+##bug
+beyond
+##cms
+##dar
+##oh
+zoom
+e2
+trip
+sb
+##nba
+rcep
+342
+aspx
+ci
+080
+gc
+gnu
+める
+##count
+advanced
+dance
+dv
+##url
+##ging
+367
+8591
+am09
+shadow
+battle
+346
+##i
+##cia
+##という
+emily
+##のてす
+##tation
+host
+ff
+techorz
+sars
+##mini
+##mporary
+##ering
+nc
+4200
+798
+##next
+cma
+##mbps
+##gas
+##ift
+##dot
+##ィ
+455
+##~17
+amana
+##りの
+426
+##ros
+ir
+00㎡1
+##eet
+##ible
+##↓
+710
+ˋ▽ˊ
+##aka
+dcs
+iq
+##v
+l1
+##lor
+maggie
+##011
+##iu
+588
+##~1
+830
+##gt
+1tb
+articles
+create
+##burg
+##iki
+database
+fantasy
+##rex
+##cam
+dlc
+dean
+##you
+hard
+path
+gaming
+victoria
+maps
+cb
+##lee
+##itor
+overchicstoretvhome
+systems
+##xt
+416
+p3
+sarah
+760
+##nan
+407
+486
+x9
+install
+second
+626
+##ann
+##ph
+##rcle
+##nic
+860
+##nar
+ec
+##とう
+768
+metro
+chocolate
+##rian
+~4
+##table
+##しています
+skin
+##sn
+395
+mountain
+##0mm
+inparadise
+6m
+7x24
+ib
+4800
+##jia
+eeworld
+creative
+g5
+g3
+357
+parker
+ecfa
+village
+からの
+18000
+sylvia
+サーヒス
+hbl
+##ques
+##onsored
+##x2
+##きます
+##v4
+##tein
+ie6
+383
+##stack
+389
+ver
+##ads
+##baby
+sound
+bbe
+##110
+##lone
+##uid
+ads
+022
+gundam
+351
+thinkpad
+006
+scrum
+match
+##ave
+mems
+##470
+##oy
+##なりました
+##talk
+glass
+lamigo
+span
+##eme
+job
+##a5
+jay
+wade
+kde
+498
+##lace
+ocean
+tvg
+##covery
+##r3
+##ners
+##rea
+junior
+think
+##aine
+cover
+##ision
+##sia
+↓↓
+##bow
+msi
+413
+458
+406
+##love
+711
+801
+soft
+z2
+##pl
+456
+1840
+mobil
+mind
+##uy
+427
+nginx
+##oi
+めた
+##rr
+6221
+##mple
+##sson
+##ーシてす
+371
+##nts
+91tv
+comhd
+crv3000
+##uard
+1868
+397
+deep
+lost
+field
+gallery
+##bia
+rate
+spf
+redis
+traction
+930
+icloud
+011
+なら
+fe
+jose
+372
+##tory
+into
+sohu
+fx
+899
+379
+kicstart2
+##hia
+すく
+##~3
+##sit
+ra
+24
+##walk
+##xure
+500g
+##pact
+pacific
+xa
+natural
+carlo
+##250
+##walker
+1850
+##can
+cto
+gigi
+516
+##サー
+pen
+##hoo
+ob
+matlab
+##b
+##yy
+13913459
+##iti
+mango
+##bbs
+sense
+c5
+oxford
+##ニア
+walker
+jennifer
+##ola
+course
+##bre
+701
+##pus
+##rder
+lucky
+075
+##ぁ
+ivy
+なお
+##nia
+sotheby
+side
+##ugh
+joy
+##orage
+##ush
+##bat
+##dt
+364
+r9
+##2d
+##gio
+511
+country
+wear
+##lax
+##~7
+##moon
+393
+seven
+study
+411
+348
+lonzo
+8k
+##ェ
+evolution
+##イフ
+##kk
+gs
+kd
+##レス
+arduino
+344
+b12
+##lux
+arpg
+##rdon
+cook
+##x5
+dark
+five
+##als
+##ida
+とても
+sign
+362
+##ちの
+something
+20mm
+##nda
+387
+##posted
+fresh
+tf
+1870
+422
+cam
+##mine
+##skip
+##form
+##ssion
+education
+394
+##tee
+dyson
+stage
+##jie
+want
+##night
+epson
+pack
+あります
+##ppy
+テリヘル
+##█
+wd
+##eh
+##rence
+left
+##lvin
+golden
+mhz
+discovery
+##trix
+##n2
+loft
+##uch
+##dra
+##sse
+speed
+~1
+1mdb
+sorry
+welcome
+##urn
+wave
+gaga
+##lmer
+teddy
+##160
+トラックハック
+せよ
+611
+##f2016
+378
+rp
+##sha
+rar
+##あなたに
+##きた
+840
+holiday
+##ュー
+373
+074
+##vg
+##nos
+##rail
+gartner
+gi
+6p
+##dium
+kit
+488
+b3
+eco
+##ろう
+20g
+sean
+##stone
+autocad
+nu
+##np
+f16
+write
+029
+m5
+##ias
+images
+atp
+##dk
+fsm
+504
+1350
+ve
+52kb
+##xxx
+##のに
+##cake
+414
+unit
+lim
+ru
+1v
+##ification
+published
+angela
+16g
+analytics
+ak
+##q
+##nel
+gmt
+##icon
+again
+##₂
+##bby
+ios11
+445
+かこさいます
+waze
+いてす
+##ハ
+9985
+##ust
+##ティー
+framework
+##007
+iptv
+delete
+52sykb
+cl
+wwdc
+027
+30cm
+##fw
+##ての
+1389
+##xon
+brandt
+##ses
+##dragon
+tc
+vetements
+anne
+monte
+modern
+official
+##へて
+##ere
+##nne
+##oud
+もちろん
+50
+etnews
+##a2
+##graphy
+421
+863
+##ちゃん
+444
+##rtex
+##てお
+l2
+##gma
+mount
+ccd
+たと
+archive
+morning
+tan
+ddos
+e7
+##ホ
+day4
+##ウ
+gis
+453
+its
+495
+factory
+bruce
+pg
+##ito
+ってくたさい
+guest
+cdma
+##lling
+536
+n3
+しかし
+3~4
+mega
+eyes
+ro
+13
+women
+dac
+church
+##jun
+singapore
+##facebook
+6991
+starbucks
+##tos
+##stin
+##shine
+zen
+##mu
+tina
+20℃
+1893
+##たけて
+503
+465
+request
+##gence
+qt
+##っ
+1886
+347
+363
+q7
+##zzi
+diary
+##tore
+409
+##ead
+468
+cst
+##osa
+canada
+agent
+va
+##jiang
+##ちは
+##ーク
+##lam
+sg
+##nix
+##sday
+##よって
+g6
+##master
+bing
+##zl
+charlie
+16
+8mm
+nb40
+##ーン
+thai
+##ルフ
+ln284ct
+##itz
+##2f
+bonnie
+##food
+##lent
+originals
+##stro
+##lts
+418
+∟∣
+##bscribe
+children
+ntd
+yesstyle
+##かも
+hmv
+##tment
+d5
+2cm
+arts
+sms
+##pn
+##я
+##いい
+topios9
+539
+lifestyle
+virtual
+##ague
+xz
+##deo
+muji
+024
+unt
+##nnis
+##ᅩ
+faq1
+1884
+396
+##ette
+fly
+64㎡
+はしめまして
+441
+curry
+##pop
+のこ
+release
+##←
+##◆◆
+##cast
+073
+ありな
+500ml
+##ews
+5c
+##stle
+ios7
+##ima
+787
+dog
+lenovo
+##r4
+roger
+013
+cbs
+vornado
+100m
+417
+##desk
+##クok
+##ald
+1867
+9595
+2900
+##van
+oil
+##x
+some
+break
+common
+##jy
+##lines
+g7
+twice
+419
+ella
+nano
+belle
+にこ
+##mes
+##self
+##note
+jb
+##ことかてきます
+benz
+##との
+##ova
+451
+save
+##wing
+##ますのて
+kai
+りは
+##hua
+##rect
+rainer
+##unge
+448
+##0m
+adsl
+##かな
+guestname
+##uma
+##kins
+##zu
+tokichoi
+##price
+county
+##med
+##mus
+rmk
+391
+address
+vm
+えて
+openload
+##group
+##hin
+##iginal
+amg
+urban
+##oz
+jobs
+emi
+##public
+beautiful
+##sch
+album
+##dden
+##bell
+jerry
+works
+hostel
+miller
+##drive
+##rmin
+##10
+376
+boot
+828
+##370
+##fx
+##cm~
+1885
+##nome
+##ctionary
+##oman
+##lish
+##cr
+##hm
+433
+##how
+432
+francis
+xi
+c919
+b5
+evernote
+##uc
+vga
+##3000
+coupe
+##urg
+##cca
+##uality
+019
+6g
+れる
+multi
+##また
+##ett
+em
+hey
+##ani
+##tax
+##rma
+inside
+than
+740
+leonnhurt
+##jin
+ict
+れた
+bird
+notes
+200mm
+くの
+##dical
+##lli
+result
+442
+iu
+ee
+438
+smap
+gopro
+##last
+yin
+pure
+998
+32g
+けた
+5kg
+##dan
+##rame
+mama
+##oot
+bean
+marketing
+##hur
+2l
+bella
+sync
+xuite
+##ground
+515
+discuz
+##getrelax
+##ince
+##bay
+##5s
+cj
+##イス
+gmat
+apt
+##pass
+jing
+##rix
+c4
+rich
+##とても
+niusnews
+##ello
+bag
+770
+##eting
+##mobile
+18
+culture
+015
+##のてすか
+377
+1020
+area
+##ience
+616
+details
+gp
+universal
+silver
+dit
+はお
+private
+ddd
+u11
+kanshu
+##ified
+fung
+##nny
+dx
+##520
+tai
+475
+023
+##fr
+##lean
+3s
+##pin
+429
+##rin
+25000
+ly
+rick
+##bility
+usb3
+banner
+##baru
+##gion
+metal
+dt
+vdf
+1871
+karl
+qualcomm
+bear
+1010
+oldid
+ian
+jo
+##tors
+population
+##ernel
+1882
+mmorpg
+##mv
+##bike
+603
+##©
+ww
+friend
+##ager
+exhibition
+##del
+##pods
+fpx
+structure
+##free
+##tings
+kl
+##rley
+##copyright
+##mma
+california
+3400
+orange
+yoga
+4l
+canmake
+honey
+##anda
+##コメント
+595
+nikkie
+##ルハイト
+dhl
+publishing
+##mall
+##gnet
+20cm
+513
+##クセス
+##┅
+e88
+970
+##dog
+fishbase
+##!
+##"
+###
+##$
+##%
+##&
+##'
+##(
+##)
+##*
+##+
+##,
+##-
+##.
+##/
+##:
+##;
+##<
+##=
+##>
+##?
+##@
+##[
+##\
+##]
+##^
+##_
+##{
+##|
+##}
+##~
+##£
+##¤
+##¥
+##§
+##«
+##±
+##³
+##µ
+##·
+##¹
+##º
+##»
+##¼
+##ß
+##æ
+##÷
+##ø
+##đ
+##ŋ
+##ɔ
+##ə
+##ɡ
+##ʰ
+##ˇ
+##ˈ
+##ˊ
+##ˋ
+##ˍ
+##ː
+##˙
+##˚
+##ˢ
+##α
+##β
+##γ
+##δ
+##ε
+##η
+##θ
+##ι
+##κ
+##λ
+##μ
+##ν
+##ο
+##π
+##ρ
+##ς
+##σ
+##τ
+##υ
+##φ
+##χ
+##ψ
+##б
+##в
+##г
+##д
+##е
+##ж
+##з
+##к
+##л
+##м
+##н
+##о
+##п
+##р
+##с
+##т
+##у
+##ф
+##х
+##ц
+##ч
+##ш
+##ы
+##ь
+##і
+##ا
+##ب
+##ة
+##ت
+##د
+##ر
+##س
+##ع
+##ل
+##م
+##ن
+##ه
+##و
+##ي
+##۩
+##ก
+##ง
+##น
+##ม
+##ย
+##ร
+##อ
+##า
+##เ
+##๑
+##་
+##ღ
+##ᄀ
+##ᄁ
+##ᄂ
+##ᄃ
+##ᄅ
+##ᄆ
+##ᄇ
+##ᄈ
+##ᄉ
+##ᄋ
+##ᄌ
+##ᄎ
+##ᄏ
+##ᄐ
+##ᄑ
+##ᄒ
+##ᅢ
+##ᅣ
+##ᅥ
+##ᅦ
+##ᅧ
+##ᅨ
+##ᅪ
+##ᅬ
+##ᅭ
+##ᅮ
+##ᅯ
+##ᅲ
+##ᅳ
+##ᅴ
+##ᆷ
+##ᆸ
+##ᆺ
+##ᆻ
+##ᗜ
+##ᵃ
+##ᵉ
+##ᵍ
+##ᵏ
+##ᵐ
+##ᵒ
+##ᵘ
+##‖
+##„
+##†
+##•
+##‥
+##‧
+##
+##‰
+##′
+##″
+##‹
+##›
+##※
+##‿
+##⁄
+##ⁱ
+##⁺
+##ⁿ
+##₁
+##₃
+##₄
+##€
+##№
+##ⅰ
+##ⅱ
+##ⅲ
+##ⅳ
+##ⅴ
+##↔
+##↗
+##↘
+##⇒
+##∀
+##−
+##∕
+##∙
+##√
+##∞
+##∟
+##∠
+##∣
+##∩
+##∮
+##∶
+##∼
+##∽
+##≈
+##≒
+##≡
+##≤
+##≥
+##≦
+##≧
+##≪
+##≫
+##⊙
+##⋅
+##⋈
+##⋯
+##⌒
+##①
+##②
+##③
+##④
+##⑤
+##⑥
+##⑦
+##⑧
+##⑨
+##⑩
+##⑴
+##⑵
+##⑶
+##⑷
+##⑸
+##⒈
+##⒉
+##⒊
+##⒋
+##ⓒ
+##ⓔ
+##ⓘ
+##━
+##┃
+##┆
+##┊
+##┌
+##└
+##├
+##┣
+##═
+##║
+##╚
+##╞
+##╠
+##╭
+##╮
+##╯
+##╰
+##╱
+##╳
+##▂
+##▃
+##▅
+##▇
+##▉
+##▋
+##▌
+##▍
+##▎
+##□
+##▪
+##▫
+##▬
+##△
+##▶
+##►
+##▽
+##◇
+##◕
+##◠
+##◢
+##◤
+##☀
+##☕
+##☞
+##☺
+##☼
+##♀
+##♂
+##♠
+##♡
+##♣
+##♦
+##♫
+##♬
+##✈
+##✔
+##✕
+##✖
+##✦
+##✨
+##✪
+##✰
+##✿
+##❀
+##➜
+##➤
+##⦿
+##、
+##。
+##〃
+##々
+##〇
+##〈
+##〉
+##《
+##》
+##「
+##」
+##『
+##』
+##【
+##】
+##〓
+##〔
+##〕
+##〖
+##〗
+##〜
+##〝
+##〞
+##ぃ
+##ぇ
+##ぬ
+##ふ
+##ほ
+##む
+##ゃ
+##ゅ
+##ゆ
+##ょ
+##゜
+##ゝ
+##ァ
+##ゥ
+##エ
+##ォ
+##ケ
+##サ
+##セ
+##ソ
+##ッ
+##ニ
+##ヌ
+##ネ
+##ノ
+##ヘ
+##モ
+##ャ
+##ヤ
+##ュ
+##ユ
+##ョ
+##ヨ
+##ワ
+##ヲ
+##・
+##ヽ
+##ㄅ
+##ㄆ
+##ㄇ
+##ㄉ
+##ㄋ
+##ㄌ
+##ㄍ
+##ㄎ
+##ㄏ
+##ㄒ
+##ㄚ
+##ㄛ
+##ㄞ
+##ㄟ
+##ㄢ
+##ㄤ
+##ㄥ
+##ㄧ
+##ㄨ
+##ㆍ
+##㈦
+##㊣
+##㗎
+##一
+##丁
+##七
+##万
+##丈
+##三
+##上
+##下
+##不
+##与
+##丐
+##丑
+##专
+##且
+##丕
+##世
+##丘
+##丙
+##业
+##丛
+##东
+##丝
+##丞
+##丟
+##両
+##丢
+##两
+##严
+##並
+##丧
+##丨
+##个
+##丫
+##中
+##丰
+##串
+##临
+##丶
+##丸
+##丹
+##为
+##主
+##丼
+##丽
+##举
+##丿
+##乂
+##乃
+##久
+##么
+##义
+##之
+##乌
+##乍
+##乎
+##乏
+##乐
+##乒
+##乓
+##乔
+##乖
+##乗
+##乘
+##乙
+##乜
+##九
+##乞
+##也
+##习
+##乡
+##书
+##乩
+##买
+##乱
+##乳
+##乾
+##亀
+##亂
+##了
+##予
+##争
+##事
+##二
+##于
+##亏
+##云
+##互
+##五
+##井
+##亘
+##亙
+##亚
+##些
+##亜
+##亞
+##亟
+##亡
+##亢
+##交
+##亥
+##亦
+##产
+##亨
+##亩
+##享
+##京
+##亭
+##亮
+##亲
+##亳
+##亵
+##人
+##亿
+##什
+##仁
+##仃
+##仄
+##仅
+##仆
+##仇
+##今
+##介
+##仍
+##从
+##仏
+##仑
+##仓
+##仔
+##仕
+##他
+##仗
+##付
+##仙
+##仝
+##仞
+##仟
+##代
+##令
+##以
+##仨
+##仪
+##们
+##仮
+##仰
+##仲
+##件
+##价
+##任
+##份
+##仿
+##企
+##伉
+##伊
+##伍
+##伎
+##伏
+##伐
+##休
+##伕
+##众
+##优
+##伙
+##会
+##伝
+##伞
+##伟
+##传
+##伢
+##伤
+##伦
+##伪
+##伫
+##伯
+##估
+##伴
+##伶
+##伸
+##伺
+##似
+##伽
+##佃
+##但
+##佇
+##佈
+##位
+##低
+##住
+##佐
+##佑
+##体
+##佔
+##何
+##佗
+##佘
+##余
+##佚
+##佛
+##作
+##佝
+##佞
+##佟
+##你
+##佢
+##佣
+##佤
+##佥
+##佩
+##佬
+##佯
+##佰
+##佳
+##併
+##佶
+##佻
+##佼
+##使
+##侃
+##侄
+##來
+##侈
+##例
+##侍
+##侏
+##侑
+##侖
+##侗
+##供
+##依
+##侠
+##価
+##侣
+##侥
+##侦
+##侧
+##侨
+##侬
+##侮
+##侯
+##侵
+##侶
+##侷
+##便
+##係
+##促
+##俄
+##俊
+##俎
+##俏
+##俐
+##俑
+##俗
+##俘
+##俚
+##保
+##俞
+##俟
+##俠
+##信
+##俨
+##俩
+##俪
+##俬
+##俭
+##修
+##俯
+##俱
+##俳
+##俸
+##俺
+##俾
+##倆
+##倉
+##個
+##倌
+##倍
+##倏
+##們
+##倒
+##倔
+##倖
+##倘
+##候
+##倚
+##倜
+##借
+##倡
+##値
+##倦
+##倩
+##倪
+##倫
+##倬
+##倭
+##倶
+##债
+##值
+##倾
+##偃
+##假
+##偈
+##偉
+##偌
+##偎
+##偏
+##偕
+##做
+##停
+##健
+##側
+##偵
+##偶
+##偷
+##偻
+##偽
+##偿
+##傀
+##傅
+##傍
+##傑
+##傘
+##備
+##傚
+##傢
+##傣
+##傥
+##储
+##傩
+##催
+##傭
+##傲
+##傳
+##債
+##傷
+##傻
+##傾
+##僅
+##働
+##像
+##僑
+##僕
+##僖
+##僚
+##僥
+##僧
+##僭
+##僮
+##僱
+##僵
+##價
+##僻
+##儀
+##儂
+##億
+##儆
+##儉
+##儋
+##儒
+##儕
+##儘
+##償
+##儡
+##優
+##儲
+##儷
+##儼
+##儿
+##兀
+##允
+##元
+##兄
+##充
+##兆
+##兇
+##先
+##光
+##克
+##兌
+##免
+##児
+##兑
+##兒
+##兔
+##兖
+##党
+##兜
+##兢
+##入
+##內
+##全
+##兩
+##八
+##公
+##六
+##兮
+##兰
+##共
+##兲
+##关
+##兴
+##兵
+##其
+##具
+##典
+##兹
+##养
+##兼
+##兽
+##冀
+##内
+##円
+##冇
+##冈
+##冉
+##冊
+##册
+##再
+##冏
+##冒
+##冕
+##冗
+##写
+##军
+##农
+##冠
+##冢
+##冤
+##冥
+##冨
+##冪
+##冬
+##冯
+##冰
+##冲
+##决
+##况
+##冶
+##冷
+##冻
+##冼
+##冽
+##冾
+##净
+##凄
+##准
+##凇
+##凈
+##凉
+##凋
+##凌
+##凍
+##减
+##凑
+##凛
+##凜
+##凝
+##几
+##凡
+##凤
+##処
+##凪
+##凭
+##凯
+##凰
+##凱
+##凳
+##凶
+##凸
+##凹
+##出
+##击
+##函
+##凿
+##刀
+##刁
+##刃
+##分
+##切
+##刈
+##刊
+##刍
+##刎
+##刑
+##划
+##列
+##刘
+##则
+##刚
+##创
+##初
+##删
+##判
+##別
+##刨
+##利
+##刪
+##别
+##刮
+##到
+##制
+##刷
+##券
+##刹
+##刺
+##刻
+##刽
+##剁
+##剂
+##剃
+##則
+##剉
+##削
+##剋
+##剌
+##前
+##剎
+##剐
+##剑
+##剔
+##剖
+##剛
+##剜
+##剝
+##剣
+##剤
+##剥
+##剧
+##剩
+##剪
+##副
+##割
+##創
+##剷
+##剽
+##剿
+##劃
+##劇
+##劈
+##劉
+##劊
+##劍
+##劏
+##劑
+##力
+##劝
+##办
+##功
+##加
+##务
+##劣
+##动
+##助
+##努
+##劫
+##劭
+##励
+##劲
+##劳
+##労
+##劵
+##効
+##劾
+##势
+##勁
+##勃
+##勇
+##勉
+##勋
+##勐
+##勒
+##動
+##勖
+##勘
+##務
+##勛
+##勝
+##勞
+##募
+##勢
+##勤
+##勧
+##勳
+##勵
+##勸
+##勺
+##勻
+##勾
+##勿
+##匀
+##包
+##匆
+##匈
+##匍
+##匐
+##匕
+##化
+##北
+##匙
+##匝
+##匠
+##匡
+##匣
+##匪
+##匮
+##匯
+##匱
+##匹
+##区
+##医
+##匾
+##匿
+##區
+##十
+##千
+##卅
+##升
+##午
+##卉
+##半
+##卍
+##华
+##协
+##卑
+##卒
+##卓
+##協
+##单
+##卖
+##南
+##単
+##博
+##卜
+##卞
+##卟
+##占
+##卡
+##卢
+##卤
+##卦
+##卧
+##卫
+##卮
+##卯
+##印
+##危
+##即
+##却
+##卵
+##卷
+##卸
+##卻
+##卿
+##厂
+##厄
+##厅
+##历
+##厉
+##压
+##厌
+##厕
+##厘
+##厚
+##厝
+##原
+##厢
+##厥
+##厦
+##厨
+##厩
+##厭
+##厮
+##厲
+##厳
+##去
+##县
+##叁
+##参
+##參
+##又
+##叉
+##及
+##友
+##双
+##反
+##収
+##发
+##叔
+##取
+##受
+##变
+##叙
+##叛
+##叟
+##叠
+##叡
+##叢
+##口
+##古
+##句
+##另
+##叨
+##叩
+##只
+##叫
+##召
+##叭
+##叮
+##可
+##台
+##叱
+##史
+##右
+##叵
+##叶
+##号
+##司
+##叹
+##叻
+##叼
+##叽
+##吁
+##吃
+##各
+##吆
+##合
+##吉
+##吊
+##吋
+##同
+##名
+##后
+##吏
+##吐
+##向
+##吒
+##吓
+##吕
+##吖
+##吗
+##君
+##吝
+##吞
+##吟
+##吠
+##吡
+##否
+##吧
+##吨
+##吩
+##含
+##听
+##吭
+##吮
+##启
+##吱
+##吳
+##吴
+##吵
+##吶
+##吸
+##吹
+##吻
+##吼
+##吽
+##吾
+##呀
+##呂
+##呃
+##呆
+##呈
+##告
+##呋
+##呎
+##呐
+##呓
+##呕
+##呗
+##员
+##呛
+##呜
+##呢
+##呤
+##呦
+##周
+##呱
+##呲
+##味
+##呵
+##呷
+##呸
+##呻
+##呼
+##命
+##咀
+##咁
+##咂
+##咄
+##咆
+##咋
+##和
+##咎
+##咏
+##咐
+##咒
+##咔
+##咕
+##咖
+##咗
+##咘
+##咙
+##咚
+##咛
+##咣
+##咤
+##咦
+##咧
+##咨
+##咩
+##咪
+##咫
+##咬
+##咭
+##咯
+##咱
+##咲
+##咳
+##咸
+##咻
+##咽
+##咿
+##哀
+##品
+##哂
+##哄
+##哆
+##哇
+##哈
+##哉
+##哋
+##哌
+##响
+##哎
+##哏
+##哐
+##哑
+##哒
+##哔
+##哗
+##哟
+##員
+##哥
+##哦
+##哧
+##哨
+##哩
+##哪
+##哭
+##哮
+##哲
+##哺
+##哼
+##哽
+##唁
+##唄
+##唆
+##唇
+##唉
+##唏
+##唐
+##唑
+##唔
+##唠
+##唤
+##唧
+##唬
+##售
+##唯
+##唰
+##唱
+##唳
+##唷
+##唸
+##唾
+##啃
+##啄
+##商
+##啉
+##啊
+##問
+##啓
+##啕
+##啖
+##啜
+##啞
+##啟
+##啡
+##啤
+##啥
+##啦
+##啧
+##啪
+##啫
+##啬
+##啮
+##啰
+##啱
+##啲
+##啵
+##啶
+##啷
+##啸
+##啻
+##啼
+##啾
+##喀
+##喂
+##喃
+##善
+##喆
+##喇
+##喉
+##喊
+##喋
+##喎
+##喏
+##喔
+##喘
+##喙
+##喚
+##喜
+##喝
+##喟
+##喧
+##喪
+##喫
+##喬
+##單
+##喰
+##喱
+##喲
+##喳
+##喵
+##営
+##喷
+##喹
+##喺
+##喻
+##喽
+##嗅
+##嗆
+##嗇
+##嗎
+##嗑
+##嗒
+##嗓
+##嗔
+##嗖
+##嗚
+##嗜
+##嗝
+##嗟
+##嗡
+##嗣
+##嗤
+##嗦
+##嗨
+##嗪
+##嗬
+##嗯
+##嗰
+##嗲
+##嗳
+##嗶
+##嗷
+##嗽
+##嘀
+##嘅
+##嘆
+##嘈
+##嘉
+##嘌
+##嘍
+##嘎
+##嘔
+##嘖
+##嘗
+##嘘
+##嘚
+##嘛
+##嘜
+##嘞
+##嘟
+##嘢
+##嘣
+##嘤
+##嘧
+##嘩
+##嘭
+##嘮
+##嘯
+##嘰
+##嘱
+##嘲
+##嘴
+##嘶
+##嘸
+##嘹
+##嘻
+##嘿
+##噁
+##噌
+##噎
+##噓
+##噔
+##噗
+##噙
+##噜
+##噠
+##噢
+##噤
+##器
+##噩
+##噪
+##噬
+##噱
+##噴
+##噶
+##噸
+##噹
+##噻
+##噼
+##嚀
+##嚇
+##嚎
+##嚏
+##嚐
+##嚓
+##嚕
+##嚟
+##嚣
+##嚥
+##嚨
+##嚮
+##嚴
+##嚷
+##嚼
+##囂
+##囉
+##囊
+##囍
+##囑
+##囔
+##囗
+##囚
+##四
+##囝
+##回
+##囟
+##因
+##囡
+##团
+##団
+##囤
+##囧
+##囪
+##囫
+##园
+##困
+##囱
+##囲
+##図
+##围
+##囹
+##固
+##国
+##图
+##囿
+##圃
+##圄
+##圆
+##圈
+##國
+##圍
+##圏
+##園
+##圓
+##圖
+##團
+##圜
+##土
+##圣
+##圧
+##在
+##圩
+##圭
+##地
+##圳
+##场
+##圻
+##圾
+##址
+##坂
+##均
+##坊
+##坍
+##坎
+##坏
+##坐
+##坑
+##块
+##坚
+##坛
+##坝
+##坞
+##坟
+##坠
+##坡
+##坤
+##坦
+##坨
+##坪
+##坯
+##坳
+##坵
+##坷
+##垂
+##垃
+##垄
+##型
+##垒
+##垚
+##垛
+##垠
+##垢
+##垣
+##垦
+##垩
+##垫
+##垭
+##垮
+##垵
+##埂
+##埃
+##埋
+##城
+##埔
+##埕
+##埗
+##域
+##埠
+##埤
+##埵
+##執
+##埸
+##培
+##基
+##埼
+##堀
+##堂
+##堃
+##堅
+##堆
+##堇
+##堑
+##堕
+##堙
+##堡
+##堤
+##堪
+##堯
+##堰
+##報
+##場
+##堵
+##堺
+##堿
+##塊
+##塌
+##塑
+##塔
+##塗
+##塘
+##塚
+##塞
+##塢
+##塩
+##填
+##塬
+##塭
+##塵
+##塾
+##墀
+##境
+##墅
+##墉
+##墊
+##墒
+##墓
+##増
+##墘
+##墙
+##墜
+##增
+##墟
+##墨
+##墩
+##墮
+##墳
+##墻
+##墾
+##壁
+##壅
+##壆
+##壇
+##壊
+##壑
+##壓
+##壕
+##壘
+##壞
+##壟
+##壢
+##壤
+##壩
+##士
+##壬
+##壮
+##壯
+##声
+##売
+##壳
+##壶
+##壹
+##壺
+##壽
+##处
+##备
+##変
+##复
+##夏
+##夔
+##夕
+##外
+##夙
+##多
+##夜
+##够
+##夠
+##夢
+##夥
+##大
+##天
+##太
+##夫
+##夭
+##央
+##夯
+##失
+##头
+##夷
+##夸
+##夹
+##夺
+##夾
+##奂
+##奄
+##奇
+##奈
+##奉
+##奋
+##奎
+##奏
+##奐
+##契
+##奔
+##奕
+##奖
+##套
+##奘
+##奚
+##奠
+##奢
+##奥
+##奧
+##奪
+##奬
+##奮
+##女
+##奴
+##奶
+##奸
+##她
+##好
+##如
+##妃
+##妄
+##妆
+##妇
+##妈
+##妊
+##妍
+##妒
+##妓
+##妖
+##妘
+##妙
+##妝
+##妞
+##妣
+##妤
+##妥
+##妨
+##妩
+##妪
+##妮
+##妲
+##妳
+##妹
+##妻
+##妾
+##姆
+##姉
+##姊
+##始
+##姍
+##姐
+##姑
+##姒
+##姓
+##委
+##姗
+##姚
+##姜
+##姝
+##姣
+##姥
+##姦
+##姨
+##姪
+##姫
+##姬
+##姹
+##姻
+##姿
+##威
+##娃
+##娄
+##娅
+##娆
+##娇
+##娉
+##娑
+##娓
+##娘
+##娛
+##娜
+##娟
+##娠
+##娣
+##娥
+##娩
+##娱
+##娲
+##娴
+##娶
+##娼
+##婀
+##婁
+##婆
+##婉
+##婊
+##婕
+##婚
+##婢
+##婦
+##婧
+##婪
+##婭
+##婴
+##婵
+##婶
+##婷
+##婺
+##婿
+##媒
+##媚
+##媛
+##媞
+##媧
+##媲
+##媳
+##媽
+##媾
+##嫁
+##嫂
+##嫉
+##嫌
+##嫑
+##嫔
+##嫖
+##嫘
+##嫚
+##嫡
+##嫣
+##嫦
+##嫩
+##嫲
+##嫵
+##嫻
+##嬅
+##嬉
+##嬌
+##嬗
+##嬛
+##嬢
+##嬤
+##嬪
+##嬰
+##嬴
+##嬷
+##嬸
+##嬿
+##孀
+##孃
+##子
+##孑
+##孔
+##孕
+##孖
+##字
+##存
+##孙
+##孚
+##孛
+##孜
+##孝
+##孟
+##孢
+##季
+##孤
+##学
+##孩
+##孪
+##孫
+##孬
+##孰
+##孱
+##孳
+##孵
+##學
+##孺
+##孽
+##孿
+##宁
+##它
+##宅
+##宇
+##守
+##安
+##宋
+##完
+##宏
+##宓
+##宕
+##宗
+##官
+##宙
+##定
+##宛
+##宜
+##宝
+##实
+##実
+##宠
+##审
+##客
+##宣
+##室
+##宥
+##宦
+##宪
+##宫
+##宮
+##宰
+##害
+##宴
+##宵
+##家
+##宸
+##容
+##宽
+##宾
+##宿
+##寂
+##寄
+##寅
+##密
+##寇
+##富
+##寐
+##寒
+##寓
+##寛
+##寝
+##寞
+##察
+##寡
+##寢
+##寥
+##實
+##寧
+##寨
+##審
+##寫
+##寬
+##寮
+##寰
+##寵
+##寶
+##寸
+##对
+##寺
+##寻
+##导
+##対
+##寿
+##封
+##専
+##射
+##将
+##將
+##專
+##尉
+##尊
+##尋
+##對
+##導
+##小
+##少
+##尔
+##尕
+##尖
+##尘
+##尚
+##尝
+##尤
+##尧
+##尬
+##就
+##尴
+##尷
+##尸
+##尹
+##尺
+##尻
+##尼
+##尽
+##尾
+##尿
+##局
+##屁
+##层
+##屄
+##居
+##屆
+##屈
+##屉
+##届
+##屋
+##屌
+##屍
+##屎
+##屏
+##屐
+##屑
+##展
+##屜
+##属
+##屠
+##屡
+##屢
+##層
+##履
+##屬
+##屯
+##山
+##屹
+##屿
+##岀
+##岁
+##岂
+##岌
+##岐
+##岑
+##岔
+##岖
+##岗
+##岘
+##岙
+##岚
+##岛
+##岡
+##岩
+##岫
+##岬
+##岭
+##岱
+##岳
+##岷
+##岸
+##峇
+##峋
+##峒
+##峙
+##峡
+##峤
+##峥
+##峦
+##峨
+##峪
+##峭
+##峯
+##峰
+##峴
+##島
+##峻
+##峽
+##崁
+##崂
+##崆
+##崇
+##崎
+##崑
+##崔
+##崖
+##崗
+##崙
+##崛
+##崧
+##崩
+##崭
+##崴
+##崽
+##嵇
+##嵊
+##嵋
+##嵌
+##嵐
+##嵘
+##嵩
+##嵬
+##嵯
+##嶂
+##嶄
+##嶇
+##嶋
+##嶙
+##嶺
+##嶼
+##嶽
+##巅
+##巍
+##巒
+##巔
+##巖
+##川
+##州
+##巡
+##巢
+##工
+##左
+##巧
+##巨
+##巩
+##巫
+##差
+##己
+##已
+##巳
+##巴
+##巷
+##巻
+##巽
+##巾
+##巿
+##币
+##市
+##布
+##帅
+##帆
+##师
+##希
+##帐
+##帑
+##帕
+##帖
+##帘
+##帚
+##帛
+##帜
+##帝
+##帥
+##带
+##帧
+##師
+##席
+##帮
+##帯
+##帰
+##帳
+##帶
+##帷
+##常
+##帼
+##帽
+##幀
+##幂
+##幄
+##幅
+##幌
+##幔
+##幕
+##幟
+##幡
+##幢
+##幣
+##幫
+##干
+##平
+##年
+##并
+##幸
+##幹
+##幺
+##幻
+##幼
+##幽
+##幾
+##广
+##庁
+##広
+##庄
+##庆
+##庇
+##床
+##序
+##庐
+##库
+##应
+##底
+##庖
+##店
+##庙
+##庚
+##府
+##庞
+##废
+##庠
+##度
+##座
+##庫
+##庭
+##庵
+##庶
+##康
+##庸
+##庹
+##庾
+##廁
+##廂
+##廃
+##廈
+##廉
+##廊
+##廓
+##廖
+##廚
+##廝
+##廟
+##廠
+##廢
+##廣
+##廬
+##廳
+##延
+##廷
+##建
+##廿
+##开
+##弁
+##异
+##弃
+##弄
+##弈
+##弊
+##弋
+##式
+##弑
+##弒
+##弓
+##弔
+##引
+##弗
+##弘
+##弛
+##弟
+##张
+##弥
+##弦
+##弧
+##弩
+##弭
+##弯
+##弱
+##張
+##強
+##弹
+##强
+##弼
+##弾
+##彅
+##彆
+##彈
+##彌
+##彎
+##归
+##当
+##录
+##彗
+##彙
+##彝
+##形
+##彤
+##彥
+##彦
+##彧
+##彩
+##彪
+##彫
+##彬
+##彭
+##彰
+##影
+##彷
+##役
+##彻
+##彼
+##彿
+##往
+##征
+##径
+##待
+##徇
+##很
+##徉
+##徊
+##律
+##後
+##徐
+##徑
+##徒
+##従
+##徕
+##得
+##徘
+##徙
+##徜
+##從
+##徠
+##御
+##徨
+##復
+##循
+##徬
+##微
+##徳
+##徴
+##徵
+##德
+##徹
+##徼
+##徽
+##心
+##必
+##忆
+##忌
+##忍
+##忏
+##忐
+##忑
+##忒
+##忖
+##志
+##忘
+##忙
+##応
+##忠
+##忡
+##忤
+##忧
+##忪
+##快
+##忱
+##念
+##忻
+##忽
+##忿
+##怀
+##态
+##怂
+##怅
+##怆
+##怎
+##怏
+##怒
+##怔
+##怕
+##怖
+##怙
+##怜
+##思
+##怠
+##怡
+##急
+##怦
+##性
+##怨
+##怪
+##怯
+##怵
+##总
+##怼
+##恁
+##恃
+##恆
+##恋
+##恍
+##恐
+##恒
+##恕
+##恙
+##恚
+##恢
+##恣
+##恤
+##恥
+##恨
+##恩
+##恪
+##恫
+##恬
+##恭
+##息
+##恰
+##恳
+##恵
+##恶
+##恸
+##恺
+##恻
+##恼
+##恿
+##悄
+##悅
+##悉
+##悌
+##悍
+##悔
+##悖
+##悚
+##悟
+##悠
+##患
+##悦
+##您
+##悩
+##悪
+##悬
+##悯
+##悱
+##悲
+##悴
+##悵
+##悶
+##悸
+##悻
+##悼
+##悽
+##情
+##惆
+##惇
+##惊
+##惋
+##惑
+##惕
+##惘
+##惚
+##惜
+##惟
+##惠
+##惡
+##惦
+##惧
+##惨
+##惩
+##惫
+##惬
+##惭
+##惮
+##惯
+##惰
+##惱
+##想
+##惴
+##惶
+##惹
+##惺
+##愁
+##愆
+##愈
+##愉
+##愍
+##意
+##愕
+##愚
+##愛
+##愜
+##感
+##愣
+##愤
+##愧
+##愫
+##愷
+##愿
+##慄
+##慈
+##態
+##慌
+##慎
+##慑
+##慕
+##慘
+##慚
+##慟
+##慢
+##慣
+##慧
+##慨
+##慫
+##慮
+##慰
+##慳
+##慵
+##慶
+##慷
+##慾
+##憂
+##憊
+##憋
+##憎
+##憐
+##憑
+##憔
+##憚
+##憤
+##憧
+##憨
+##憩
+##憫
+##憬
+##憲
+##憶
+##憾
+##懂
+##懇
+##懈
+##應
+##懊
+##懋
+##懑
+##懒
+##懦
+##懲
+##懵
+##懶
+##懷
+##懸
+##懺
+##懼
+##懾
+##懿
+##戀
+##戈
+##戊
+##戌
+##戍
+##戎
+##戏
+##成
+##我
+##戒
+##戕
+##或
+##战
+##戚
+##戛
+##戟
+##戡
+##戦
+##截
+##戬
+##戮
+##戰
+##戲
+##戳
+##戴
+##戶
+##户
+##戸
+##戻
+##戾
+##房
+##所
+##扁
+##扇
+##扈
+##扉
+##手
+##才
+##扎
+##扑
+##扒
+##打
+##扔
+##払
+##托
+##扛
+##扣
+##扦
+##执
+##扩
+##扪
+##扫
+##扬
+##扭
+##扮
+##扯
+##扰
+##扱
+##扳
+##扶
+##批
+##扼
+##找
+##承
+##技
+##抄
+##抉
+##把
+##抑
+##抒
+##抓
+##投
+##抖
+##抗
+##折
+##抚
+##抛
+##抜
+##択
+##抟
+##抠
+##抡
+##抢
+##护
+##报
+##抨
+##披
+##抬
+##抱
+##抵
+##抹
+##押
+##抽
+##抿
+##拂
+##拄
+##担
+##拆
+##拇
+##拈
+##拉
+##拋
+##拌
+##拍
+##拎
+##拐
+##拒
+##拓
+##拔
+##拖
+##拗
+##拘
+##拙
+##拚
+##招
+##拜
+##拟
+##拡
+##拢
+##拣
+##拥
+##拦
+##拧
+##拨
+##择
+##括
+##拭
+##拮
+##拯
+##拱
+##拳
+##拴
+##拷
+##拼
+##拽
+##拾
+##拿
+##持
+##挂
+##指
+##挈
+##按
+##挎
+##挑
+##挖
+##挙
+##挚
+##挛
+##挝
+##挞
+##挟
+##挠
+##挡
+##挣
+##挤
+##挥
+##挨
+##挪
+##挫
+##振
+##挲
+##挹
+##挺
+##挽
+##挾
+##捂
+##捅
+##捆
+##捉
+##捋
+##捌
+##捍
+##捎
+##捏
+##捐
+##捕
+##捞
+##损
+##捡
+##换
+##捣
+##捧
+##捨
+##捩
+##据
+##捱
+##捲
+##捶
+##捷
+##捺
+##捻
+##掀
+##掂
+##掃
+##掇
+##授
+##掉
+##掌
+##掏
+##掐
+##排
+##掖
+##掘
+##掙
+##掛
+##掠
+##採
+##探
+##掣
+##接
+##控
+##推
+##掩
+##措
+##掬
+##掰
+##掲
+##掳
+##掴
+##掷
+##掸
+##掺
+##揀
+##揃
+##揄
+##揆
+##揉
+##揍
+##描
+##提
+##插
+##揖
+##揚
+##換
+##握
+##揣
+##揩
+##揪
+##揭
+##揮
+##援
+##揶
+##揸
+##揹
+##揽
+##搀
+##搁
+##搂
+##搅
+##損
+##搏
+##搐
+##搓
+##搔
+##搖
+##搗
+##搜
+##搞
+##搡
+##搪
+##搬
+##搭
+##搵
+##搶
+##携
+##搽
+##摀
+##摁
+##摄
+##摆
+##摇
+##摈
+##摊
+##摒
+##摔
+##摘
+##摞
+##摟
+##摧
+##摩
+##摯
+##摳
+##摸
+##摹
+##摺
+##摻
+##撂
+##撃
+##撅
+##撇
+##撈
+##撐
+##撑
+##撒
+##撓
+##撕
+##撚
+##撞
+##撤
+##撥
+##撩
+##撫
+##撬
+##播
+##撮
+##撰
+##撲
+##撵
+##撷
+##撸
+##撻
+##撼
+##撿
+##擀
+##擁
+##擂
+##擄
+##擅
+##擇
+##擊
+##擋
+##操
+##擎
+##擒
+##擔
+##擘
+##據
+##擞
+##擠
+##擡
+##擢
+##擦
+##擬
+##擰
+##擱
+##擲
+##擴
+##擷
+##擺
+##擼
+##擾
+##攀
+##攏
+##攒
+##攔
+##攘
+##攙
+##攜
+##攝
+##攞
+##攢
+##攣
+##攤
+##攥
+##攪
+##攫
+##攬
+##支
+##收
+##攸
+##改
+##攻
+##放
+##政
+##故
+##效
+##敌
+##敍
+##敎
+##敏
+##救
+##敕
+##敖
+##敗
+##敘
+##教
+##敛
+##敝
+##敞
+##敢
+##散
+##敦
+##敬
+##数
+##敲
+##整
+##敵
+##敷
+##數
+##斂
+##斃
+##文
+##斋
+##斌
+##斎
+##斐
+##斑
+##斓
+##斗
+##料
+##斛
+##斜
+##斟
+##斡
+##斤
+##斥
+##斧
+##斩
+##斫
+##斬
+##断
+##斯
+##新
+##斷
+##方
+##於
+##施
+##旁
+##旃
+##旅
+##旋
+##旌
+##旎
+##族
+##旖
+##旗
+##无
+##既
+##日
+##旦
+##旧
+##旨
+##早
+##旬
+##旭
+##旮
+##旱
+##时
+##旷
+##旺
+##旻
+##昀
+##昂
+##昆
+##昇
+##昉
+##昊
+##昌
+##明
+##昏
+##易
+##昔
+##昕
+##昙
+##星
+##映
+##春
+##昧
+##昨
+##昭
+##是
+##昱
+##昴
+##昵
+##昶
+##昼
+##显
+##晁
+##時
+##晃
+##晉
+##晋
+##晌
+##晏
+##晒
+##晓
+##晔
+##晕
+##晖
+##晗
+##晚
+##晝
+##晞
+##晟
+##晤
+##晦
+##晨
+##晩
+##普
+##景
+##晰
+##晴
+##晶
+##晷
+##智
+##晾
+##暂
+##暄
+##暇
+##暈
+##暉
+##暌
+##暐
+##暑
+##暖
+##暗
+##暝
+##暢
+##暧
+##暨
+##暫
+##暮
+##暱
+##暴
+##暸
+##暹
+##曄
+##曆
+##曇
+##曉
+##曖
+##曙
+##曜
+##曝
+##曠
+##曦
+##曬
+##曰
+##曲
+##曳
+##更
+##書
+##曹
+##曼
+##曾
+##替
+##最
+##會
+##月
+##有
+##朋
+##服
+##朐
+##朔
+##朕
+##朗
+##望
+##朝
+##期
+##朦
+##朧
+##木
+##未
+##末
+##本
+##札
+##朮
+##术
+##朱
+##朴
+##朵
+##机
+##朽
+##杀
+##杂
+##权
+##杆
+##杈
+##杉
+##李
+##杏
+##材
+##村
+##杓
+##杖
+##杜
+##杞
+##束
+##杠
+##条
+##来
+##杨
+##杭
+##杯
+##杰
+##東
+##杳
+##杵
+##杷
+##杼
+##松
+##板
+##极
+##构
+##枇
+##枉
+##枋
+##析
+##枕
+##林
+##枚
+##果
+##枝
+##枢
+##枣
+##枪
+##枫
+##枭
+##枯
+##枰
+##枱
+##枳
+##架
+##枷
+##枸
+##柄
+##柏
+##某
+##柑
+##柒
+##染
+##柔
+##柘
+##柚
+##柜
+##柞
+##柠
+##柢
+##查
+##柩
+##柬
+##柯
+##柱
+##柳
+##柴
+##柵
+##査
+##柿
+##栀
+##栃
+##栄
+##栅
+##标
+##栈
+##栉
+##栋
+##栎
+##栏
+##树
+##栓
+##栖
+##栗
+##校
+##栩
+##株
+##样
+##核
+##根
+##格
+##栽
+##栾
+##桀
+##桁
+##桂
+##桃
+##桅
+##框
+##案
+##桉
+##桌
+##桎
+##桐
+##桑
+##桓
+##桔
+##桜
+##桠
+##桡
+##桢
+##档
+##桥
+##桦
+##桧
+##桨
+##桩
+##桶
+##桿
+##梁
+##梅
+##梆
+##梏
+##梓
+##梗
+##條
+##梟
+##梢
+##梦
+##梧
+##梨
+##梭
+##梯
+##械
+##梳
+##梵
+##梶
+##检
+##棂
+##棄
+##棉
+##棋
+##棍
+##棒
+##棕
+##棗
+##棘
+##棚
+##棟
+##棠
+##棣
+##棧
+##森
+##棱
+##棲
+##棵
+##棹
+##棺
+##椁
+##椅
+##椋
+##植
+##椎
+##椒
+##検
+##椪
+##椭
+##椰
+##椹
+##椽
+##椿
+##楂
+##楊
+##楓
+##楔
+##楚
+##楝
+##楞
+##楠
+##楣
+##楨
+##楫
+##業
+##楮
+##極
+##楷
+##楸
+##楹
+##楼
+##楽
+##概
+##榄
+##榆
+##榈
+##榉
+##榔
+##榕
+##榖
+##榛
+##榜
+##榨
+##榫
+##榭
+##榮
+##榱
+##榴
+##榷
+##榻
+##槁
+##槃
+##構
+##槌
+##槍
+##槎
+##槐
+##槓
+##様
+##槛
+##槟
+##槤
+##槭
+##槲
+##槳
+##槻
+##槽
+##槿
+##樁
+##樂
+##樊
+##樑
+##樓
+##標
+##樞
+##樟
+##模
+##樣
+##権
+##横
+##樫
+##樯
+##樱
+##樵
+##樸
+##樹
+##樺
+##樽
+##樾
+##橄
+##橇
+##橋
+##橐
+##橘
+##橙
+##機
+##橡
+##橢
+##橫
+##橱
+##橹
+##橼
+##檀
+##檄
+##檎
+##檐
+##檔
+##檗
+##檜
+##檢
+##檬
+##檯
+##檳
+##檸
+##檻
+##櫃
+##櫚
+##櫛
+##櫥
+##櫸
+##櫻
+##欄
+##權
+##欒
+##欖
+##欠
+##次
+##欢
+##欣
+##欧
+##欲
+##欸
+##欺
+##欽
+##款
+##歆
+##歇
+##歉
+##歌
+##歎
+##歐
+##歓
+##歙
+##歛
+##歡
+##止
+##正
+##此
+##步
+##武
+##歧
+##歩
+##歪
+##歯
+##歲
+##歳
+##歴
+##歷
+##歸
+##歹
+##死
+##歼
+##殁
+##殃
+##殆
+##殇
+##殉
+##殊
+##残
+##殒
+##殓
+##殖
+##殘
+##殞
+##殡
+##殤
+##殭
+##殯
+##殲
+##殴
+##段
+##殷
+##殺
+##殼
+##殿
+##毀
+##毁
+##毂
+##毅
+##毆
+##毋
+##母
+##毎
+##每
+##毒
+##毓
+##比
+##毕
+##毗
+##毘
+##毙
+##毛
+##毡
+##毫
+##毯
+##毽
+##氈
+##氏
+##氐
+##民
+##氓
+##气
+##氖
+##気
+##氙
+##氛
+##氟
+##氡
+##氢
+##氣
+##氤
+##氦
+##氧
+##氨
+##氪
+##氫
+##氮
+##氯
+##氰
+##氲
+##水
+##氷
+##永
+##氹
+##氾
+##汀
+##汁
+##求
+##汆
+##汇
+##汉
+##汎
+##汐
+##汕
+##汗
+##汙
+##汛
+##汝
+##汞
+##江
+##池
+##污
+##汤
+##汨
+##汩
+##汪
+##汰
+##汲
+##汴
+##汶
+##汹
+##決
+##汽
+##汾
+##沁
+##沂
+##沃
+##沅
+##沈
+##沉
+##沌
+##沏
+##沐
+##沒
+##沓
+##沖
+##沙
+##沛
+##沟
+##没
+##沢
+##沣
+##沥
+##沦
+##沧
+##沪
+##沫
+##沭
+##沮
+##沱
+##河
+##沸
+##油
+##治
+##沼
+##沽
+##沾
+##沿
+##況
+##泄
+##泉
+##泊
+##泌
+##泓
+##法
+##泗
+##泛
+##泞
+##泠
+##泡
+##波
+##泣
+##泥
+##注
+##泪
+##泫
+##泮
+##泯
+##泰
+##泱
+##泳
+##泵
+##泷
+##泸
+##泻
+##泼
+##泽
+##泾
+##洁
+##洄
+##洋
+##洒
+##洗
+##洙
+##洛
+##洞
+##津
+##洩
+##洪
+##洮
+##洱
+##洲
+##洵
+##洶
+##洸
+##洹
+##活
+##洼
+##洽
+##派
+##流
+##浃
+##浄
+##浅
+##浆
+##浇
+##浊
+##测
+##济
+##浏
+##浑
+##浒
+##浓
+##浔
+##浙
+##浚
+##浜
+##浣
+##浦
+##浩
+##浪
+##浬
+##浮
+##浯
+##浴
+##海
+##浸
+##涂
+##涅
+##涇
+##消
+##涉
+##涌
+##涎
+##涓
+##涔
+##涕
+##涙
+##涛
+##涝
+##涞
+##涟
+##涠
+##涡
+##涣
+##涤
+##润
+##涧
+##涨
+##涩
+##涪
+##涮
+##涯
+##液
+##涵
+##涸
+##涼
+##涿
+##淀
+##淄
+##淅
+##淆
+##淇
+##淋
+##淌
+##淑
+##淒
+##淖
+##淘
+##淙
+##淚
+##淞
+##淡
+##淤
+##淦
+##淨
+##淩
+##淪
+##淫
+##淬
+##淮
+##深
+##淳
+##淵
+##混
+##淹
+##淺
+##添
+##淼
+##清
+##済
+##渉
+##渊
+##渋
+##渍
+##渎
+##渐
+##渔
+##渗
+##渙
+##渚
+##減
+##渝
+##渠
+##渡
+##渣
+##渤
+##渥
+##渦
+##温
+##測
+##渭
+##港
+##渲
+##渴
+##游
+##渺
+##渾
+##湃
+##湄
+##湊
+##湍
+##湖
+##湘
+##湛
+##湟
+##湧
+##湫
+##湮
+##湯
+##湳
+##湾
+##湿
+##満
+##溃
+##溅
+##溉
+##溏
+##源
+##準
+##溜
+##溝
+##溟
+##溢
+##溥
+##溧
+##溪
+##溫
+##溯
+##溱
+##溴
+##溶
+##溺
+##溼
+##滁
+##滂
+##滄
+##滅
+##滇
+##滋
+##滌
+##滑
+##滓
+##滔
+##滕
+##滙
+##滚
+##滝
+##滞
+##滟
+##满
+##滢
+##滤
+##滥
+##滦
+##滨
+##滩
+##滬
+##滯
+##滲
+##滴
+##滷
+##滸
+##滾
+##滿
+##漁
+##漂
+##漆
+##漉
+##漏
+##漓
+##演
+##漕
+##漠
+##漢
+##漣
+##漩
+##漪
+##漫
+##漬
+##漯
+##漱
+##漲
+##漳
+##漸
+##漾
+##漿
+##潆
+##潇
+##潋
+##潍
+##潑
+##潔
+##潘
+##潛
+##潜
+##潞
+##潟
+##潢
+##潤
+##潦
+##潧
+##潭
+##潮
+##潰
+##潴
+##潸
+##潺
+##潼
+##澀
+##澄
+##澆
+##澈
+##澍
+##澎
+##澗
+##澜
+##澡
+##澤
+##澧
+##澱
+##澳
+##澹
+##激
+##濁
+##濂
+##濃
+##濑
+##濒
+##濕
+##濘
+##濛
+##濟
+##濠
+##濡
+##濤
+##濫
+##濬
+##濮
+##濯
+##濱
+##濺
+##濾
+##瀅
+##瀆
+##瀉
+##瀋
+##瀏
+##瀑
+##瀕
+##瀘
+##瀚
+##瀛
+##瀝
+##瀞
+##瀟
+##瀧
+##瀨
+##瀬
+##瀰
+##瀾
+##灌
+##灏
+##灑
+##灘
+##灝
+##灞
+##灣
+##火
+##灬
+##灭
+##灯
+##灰
+##灵
+##灶
+##灸
+##灼
+##災
+##灾
+##灿
+##炀
+##炁
+##炅
+##炉
+##炊
+##炎
+##炒
+##炔
+##炕
+##炖
+##炙
+##炜
+##炫
+##炬
+##炭
+##炮
+##炯
+##炳
+##炷
+##炸
+##点
+##為
+##炼
+##炽
+##烁
+##烂
+##烃
+##烈
+##烊
+##烏
+##烘
+##烙
+##烛
+##烟
+##烤
+##烦
+##烧
+##烨
+##烩
+##烫
+##烬
+##热
+##烯
+##烷
+##烹
+##烽
+##焉
+##焊
+##焕
+##焖
+##焗
+##焘
+##焙
+##焚
+##焜
+##無
+##焦
+##焯
+##焰
+##焱
+##然
+##焼
+##煅
+##煉
+##煊
+##煌
+##煎
+##煒
+##煖
+##煙
+##煜
+##煞
+##煤
+##煥
+##煦
+##照
+##煨
+##煩
+##煮
+##煲
+##煸
+##煽
+##熄
+##熊
+##熏
+##熒
+##熔
+##熙
+##熟
+##熠
+##熨
+##熬
+##熱
+##熵
+##熹
+##熾
+##燁
+##燃
+##燄
+##燈
+##燉
+##燊
+##燎
+##燒
+##燔
+##燕
+##燙
+##燜
+##營
+##燥
+##燦
+##燧
+##燭
+##燮
+##燴
+##燻
+##燼
+##燿
+##爆
+##爍
+##爐
+##爛
+##爪
+##爬
+##爭
+##爰
+##爱
+##爲
+##爵
+##父
+##爷
+##爸
+##爹
+##爺
+##爻
+##爽
+##爾
+##牆
+##片
+##版
+##牌
+##牍
+##牒
+##牙
+##牛
+##牝
+##牟
+##牠
+##牡
+##牢
+##牦
+##牧
+##物
+##牯
+##牲
+##牴
+##牵
+##特
+##牺
+##牽
+##犀
+##犁
+##犄
+##犊
+##犍
+##犒
+##犢
+##犧
+##犬
+##犯
+##状
+##犷
+##犸
+##犹
+##狀
+##狂
+##狄
+##狈
+##狎
+##狐
+##狒
+##狗
+##狙
+##狞
+##狠
+##狡
+##狩
+##独
+##狭
+##狮
+##狰
+##狱
+##狸
+##狹
+##狼
+##狽
+##猎
+##猕
+##猖
+##猗
+##猙
+##猛
+##猜
+##猝
+##猥
+##猩
+##猪
+##猫
+##猬
+##献
+##猴
+##猶
+##猷
+##猾
+##猿
+##獄
+##獅
+##獎
+##獐
+##獒
+##獗
+##獠
+##獣
+##獨
+##獭
+##獰
+##獲
+##獵
+##獷
+##獸
+##獺
+##獻
+##獼
+##獾
+##玄
+##率
+##玉
+##王
+##玑
+##玖
+##玛
+##玟
+##玠
+##玥
+##玩
+##玫
+##玮
+##环
+##现
+##玲
+##玳
+##玷
+##玺
+##玻
+##珀
+##珂
+##珅
+##珈
+##珉
+##珊
+##珍
+##珏
+##珐
+##珑
+##珙
+##珞
+##珠
+##珣
+##珥
+##珩
+##珪
+##班
+##珮
+##珲
+##珺
+##現
+##球
+##琅
+##理
+##琇
+##琉
+##琊
+##琍
+##琏
+##琐
+##琛
+##琢
+##琥
+##琦
+##琨
+##琪
+##琬
+##琮
+##琰
+##琲
+##琳
+##琴
+##琵
+##琶
+##琺
+##琼
+##瑀
+##瑁
+##瑄
+##瑋
+##瑕
+##瑗
+##瑙
+##瑚
+##瑛
+##瑜
+##瑞
+##瑟
+##瑠
+##瑣
+##瑤
+##瑩
+##瑪
+##瑯
+##瑰
+##瑶
+##瑾
+##璀
+##璁
+##璃
+##璇
+##璉
+##璋
+##璎
+##璐
+##璜
+##璞
+##璟
+##璧
+##璨
+##環
+##璽
+##璿
+##瓊
+##瓏
+##瓒
+##瓜
+##瓢
+##瓣
+##瓤
+##瓦
+##瓮
+##瓯
+##瓴
+##瓶
+##瓷
+##甄
+##甌
+##甕
+##甘
+##甙
+##甚
+##甜
+##生
+##產
+##産
+##甥
+##甦
+##用
+##甩
+##甫
+##甬
+##甭
+##甯
+##田
+##由
+##甲
+##申
+##电
+##男
+##甸
+##町
+##画
+##甾
+##畀
+##畅
+##界
+##畏
+##畑
+##畔
+##留
+##畜
+##畝
+##畢
+##略
+##畦
+##番
+##畫
+##異
+##畲
+##畳
+##畴
+##當
+##畸
+##畹
+##畿
+##疆
+##疇
+##疊
+##疏
+##疑
+##疔
+##疖
+##疗
+##疙
+##疚
+##疝
+##疟
+##疡
+##疣
+##疤
+##疥
+##疫
+##疮
+##疯
+##疱
+##疲
+##疳
+##疵
+##疸
+##疹
+##疼
+##疽
+##疾
+##痂
+##病
+##症
+##痈
+##痉
+##痊
+##痍
+##痒
+##痔
+##痕
+##痘
+##痙
+##痛
+##痞
+##痠
+##痢
+##痣
+##痤
+##痧
+##痨
+##痪
+##痫
+##痰
+##痱
+##痴
+##痹
+##痺
+##痼
+##痿
+##瘀
+##瘁
+##瘋
+##瘍
+##瘓
+##瘘
+##瘙
+##瘟
+##瘠
+##瘡
+##瘢
+##瘤
+##瘦
+##瘧
+##瘩
+##瘪
+##瘫
+##瘴
+##瘸
+##瘾
+##療
+##癇
+##癌
+##癒
+##癖
+##癜
+##癞
+##癡
+##癢
+##癣
+##癥
+##癫
+##癬
+##癮
+##癱
+##癲
+##癸
+##発
+##登
+##發
+##白
+##百
+##皂
+##的
+##皆
+##皇
+##皈
+##皋
+##皎
+##皑
+##皓
+##皖
+##皙
+##皚
+##皮
+##皰
+##皱
+##皴
+##皺
+##皿
+##盂
+##盃
+##盅
+##盆
+##盈
+##益
+##盎
+##盏
+##盐
+##监
+##盒
+##盔
+##盖
+##盗
+##盘
+##盛
+##盜
+##盞
+##盟
+##盡
+##監
+##盤
+##盥
+##盧
+##盪
+##目
+##盯
+##盱
+##盲
+##直
+##相
+##盹
+##盼
+##盾
+##省
+##眈
+##眉
+##看
+##県
+##眙
+##眞
+##真
+##眠
+##眦
+##眨
+##眩
+##眯
+##眶
+##眷
+##眸
+##眺
+##眼
+##眾
+##着
+##睁
+##睇
+##睏
+##睐
+##睑
+##睛
+##睜
+##睞
+##睡
+##睢
+##督
+##睥
+##睦
+##睨
+##睪
+##睫
+##睬
+##睹
+##睽
+##睾
+##睿
+##瞄
+##瞅
+##瞇
+##瞋
+##瞌
+##瞎
+##瞑
+##瞒
+##瞓
+##瞞
+##瞟
+##瞠
+##瞥
+##瞧
+##瞩
+##瞪
+##瞬
+##瞭
+##瞰
+##瞳
+##瞻
+##瞼
+##瞿
+##矇
+##矍
+##矗
+##矚
+##矛
+##矜
+##矢
+##矣
+##知
+##矩
+##矫
+##短
+##矮
+##矯
+##石
+##矶
+##矽
+##矾
+##矿
+##码
+##砂
+##砌
+##砍
+##砒
+##研
+##砖
+##砗
+##砚
+##砝
+##砣
+##砥
+##砧
+##砭
+##砰
+##砲
+##破
+##砷
+##砸
+##砺
+##砼
+##砾
+##础
+##硅
+##硐
+##硒
+##硕
+##硝
+##硫
+##硬
+##确
+##硯
+##硼
+##碁
+##碇
+##碉
+##碌
+##碍
+##碎
+##碑
+##碓
+##碗
+##碘
+##碚
+##碛
+##碟
+##碣
+##碧
+##碩
+##碰
+##碱
+##碳
+##碴
+##確
+##碼
+##碾
+##磁
+##磅
+##磊
+##磋
+##磐
+##磕
+##磚
+##磡
+##磨
+##磬
+##磯
+##磲
+##磷
+##磺
+##礁
+##礎
+##礙
+##礡
+##礦
+##礪
+##礫
+##礴
+##示
+##礼
+##社
+##祀
+##祁
+##祂
+##祇
+##祈
+##祉
+##祎
+##祐
+##祕
+##祖
+##祗
+##祚
+##祛
+##祜
+##祝
+##神
+##祟
+##祠
+##祢
+##祥
+##票
+##祭
+##祯
+##祷
+##祸
+##祺
+##祿
+##禀
+##禁
+##禄
+##禅
+##禍
+##禎
+##福
+##禛
+##禦
+##禧
+##禪
+##禮
+##禱
+##禹
+##禺
+##离
+##禽
+##禾
+##禿
+##秀
+##私
+##秃
+##秆
+##秉
+##秋
+##种
+##科
+##秒
+##秘
+##租
+##秣
+##秤
+##秦
+##秧
+##秩
+##秭
+##积
+##称
+##秸
+##移
+##秽
+##稀
+##稅
+##程
+##稍
+##税
+##稔
+##稗
+##稚
+##稜
+##稞
+##稟
+##稠
+##稣
+##種
+##稱
+##稲
+##稳
+##稷
+##稹
+##稻
+##稼
+##稽
+##稿
+##穀
+##穂
+##穆
+##穌
+##積
+##穎
+##穗
+##穢
+##穩
+##穫
+##穴
+##究
+##穷
+##穹
+##空
+##穿
+##突
+##窃
+##窄
+##窈
+##窍
+##窑
+##窒
+##窓
+##窕
+##窖
+##窗
+##窘
+##窜
+##窝
+##窟
+##窠
+##窥
+##窦
+##窨
+##窩
+##窪
+##窮
+##窯
+##窺
+##窿
+##竄
+##竅
+##竇
+##竊
+##立
+##竖
+##站
+##竜
+##竞
+##竟
+##章
+##竣
+##童
+##竭
+##端
+##競
+##竹
+##竺
+##竽
+##竿
+##笃
+##笆
+##笈
+##笋
+##笏
+##笑
+##笔
+##笙
+##笛
+##笞
+##笠
+##符
+##笨
+##第
+##笹
+##笺
+##笼
+##筆
+##等
+##筊
+##筋
+##筍
+##筏
+##筐
+##筑
+##筒
+##答
+##策
+##筛
+##筝
+##筠
+##筱
+##筲
+##筵
+##筷
+##筹
+##签
+##简
+##箇
+##箋
+##箍
+##箏
+##箐
+##箔
+##箕
+##算
+##箝
+##管
+##箩
+##箫
+##箭
+##箱
+##箴
+##箸
+##節
+##篁
+##範
+##篆
+##篇
+##築
+##篑
+##篓
+##篙
+##篝
+##篠
+##篡
+##篤
+##篩
+##篪
+##篮
+##篱
+##篷
+##簇
+##簌
+##簍
+##簡
+##簦
+##簧
+##簪
+##簫
+##簷
+##簸
+##簽
+##簾
+##簿
+##籁
+##籃
+##籌
+##籍
+##籐
+##籟
+##籠
+##籤
+##籬
+##籮
+##籲
+##米
+##类
+##籼
+##籽
+##粄
+##粉
+##粑
+##粒
+##粕
+##粗
+##粘
+##粟
+##粤
+##粥
+##粧
+##粪
+##粮
+##粱
+##粲
+##粳
+##粵
+##粹
+##粼
+##粽
+##精
+##粿
+##糅
+##糊
+##糍
+##糕
+##糖
+##糗
+##糙
+##糜
+##糞
+##糟
+##糠
+##糧
+##糬
+##糯
+##糰
+##糸
+##系
+##糾
+##紀
+##紂
+##約
+##紅
+##紉
+##紊
+##紋
+##納
+##紐
+##紓
+##純
+##紗
+##紘
+##紙
+##級
+##紛
+##紜
+##素
+##紡
+##索
+##紧
+##紫
+##紮
+##累
+##細
+##紳
+##紹
+##紺
+##終
+##絃
+##組
+##絆
+##経
+##結
+##絕
+##絞
+##絡
+##絢
+##給
+##絨
+##絮
+##統
+##絲
+##絳
+##絵
+##絶
+##絹
+##綁
+##綏
+##綑
+##經
+##継
+##続
+##綜
+##綠
+##綢
+##綦
+##綫
+##綬
+##維
+##綱
+##網
+##綴
+##綵
+##綸
+##綺
+##綻
+##綽
+##綾
+##綿
+##緊
+##緋
+##総
+##緑
+##緒
+##緘
+##線
+##緝
+##緞
+##締
+##緣
+##編
+##緩
+##緬
+##緯
+##練
+##緹
+##緻
+##縁
+##縄
+##縈
+##縛
+##縝
+##縣
+##縫
+##縮
+##縱
+##縴
+##縷
+##總
+##績
+##繁
+##繃
+##繆
+##繇
+##繋
+##織
+##繕
+##繚
+##繞
+##繡
+##繩
+##繪
+##繫
+##繭
+##繳
+##繹
+##繼
+##繽
+##纂
+##續
+##纍
+##纏
+##纓
+##纔
+##纖
+##纜
+##纠
+##红
+##纣
+##纤
+##约
+##级
+##纨
+##纪
+##纫
+##纬
+##纭
+##纯
+##纰
+##纱
+##纲
+##纳
+##纵
+##纶
+##纷
+##纸
+##纹
+##纺
+##纽
+##纾
+##线
+##绀
+##练
+##组
+##绅
+##细
+##织
+##终
+##绊
+##绍
+##绎
+##经
+##绑
+##绒
+##结
+##绔
+##绕
+##绘
+##给
+##绚
+##绛
+##络
+##绝
+##绞
+##统
+##绡
+##绢
+##绣
+##绥
+##绦
+##继
+##绩
+##绪
+##绫
+##续
+##绮
+##绯
+##绰
+##绳
+##维
+##绵
+##绶
+##绷
+##绸
+##绻
+##综
+##绽
+##绾
+##绿
+##缀
+##缄
+##缅
+##缆
+##缇
+##缈
+##缉
+##缎
+##缓
+##缔
+##缕
+##编
+##缘
+##缙
+##缚
+##缜
+##缝
+##缠
+##缢
+##缤
+##缥
+##缨
+##缩
+##缪
+##缭
+##缮
+##缰
+##缱
+##缴
+##缸
+##缺
+##缽
+##罂
+##罄
+##罌
+##罐
+##网
+##罔
+##罕
+##罗
+##罚
+##罡
+##罢
+##罩
+##罪
+##置
+##罰
+##署
+##罵
+##罷
+##罹
+##羁
+##羅
+##羈
+##羊
+##羌
+##美
+##羔
+##羚
+##羞
+##羟
+##羡
+##羣
+##群
+##羥
+##羧
+##羨
+##義
+##羯
+##羲
+##羸
+##羹
+##羽
+##羿
+##翁
+##翅
+##翊
+##翌
+##翎
+##習
+##翔
+##翘
+##翟
+##翠
+##翡
+##翦
+##翩
+##翰
+##翱
+##翳
+##翹
+##翻
+##翼
+##耀
+##老
+##考
+##耄
+##者
+##耆
+##耋
+##而
+##耍
+##耐
+##耒
+##耕
+##耗
+##耘
+##耙
+##耦
+##耨
+##耳
+##耶
+##耷
+##耸
+##耻
+##耽
+##耿
+##聂
+##聆
+##聊
+##聋
+##职
+##聒
+##联
+##聖
+##聘
+##聚
+##聞
+##聪
+##聯
+##聰
+##聲
+##聳
+##聴
+##聶
+##職
+##聽
+##聾
+##聿
+##肃
+##肄
+##肅
+##肆
+##肇
+##肉
+##肋
+##肌
+##肏
+##肓
+##肖
+##肘
+##肚
+##肛
+##肝
+##肠
+##股
+##肢
+##肤
+##肥
+##肩
+##肪
+##肮
+##肯
+##肱
+##育
+##肴
+##肺
+##肽
+##肾
+##肿
+##胀
+##胁
+##胃
+##胄
+##胆
+##背
+##胍
+##胎
+##胖
+##胚
+##胛
+##胜
+##胝
+##胞
+##胡
+##胤
+##胥
+##胧
+##胫
+##胭
+##胯
+##胰
+##胱
+##胳
+##胴
+##胶
+##胸
+##胺
+##能
+##脂
+##脅
+##脆
+##脇
+##脈
+##脉
+##脊
+##脍
+##脏
+##脐
+##脑
+##脓
+##脖
+##脘
+##脚
+##脛
+##脣
+##脩
+##脫
+##脯
+##脱
+##脲
+##脳
+##脸
+##脹
+##脾
+##腆
+##腈
+##腊
+##腋
+##腌
+##腎
+##腐
+##腑
+##腓
+##腔
+##腕
+##腥
+##腦
+##腩
+##腫
+##腭
+##腮
+##腰
+##腱
+##腳
+##腴
+##腸
+##腹
+##腺
+##腻
+##腼
+##腾
+##腿
+##膀
+##膈
+##膊
+##膏
+##膑
+##膘
+##膚
+##膛
+##膜
+##膝
+##膠
+##膦
+##膨
+##膩
+##膳
+##膺
+##膻
+##膽
+##膾
+##膿
+##臀
+##臂
+##臃
+##臆
+##臉
+##臊
+##臍
+##臓
+##臘
+##臟
+##臣
+##臥
+##臧
+##臨
+##自
+##臬
+##臭
+##至
+##致
+##臺
+##臻
+##臼
+##臾
+##舀
+##舂
+##舅
+##舆
+##與
+##興
+##舉
+##舊
+##舌
+##舍
+##舎
+##舐
+##舒
+##舔
+##舖
+##舗
+##舛
+##舜
+##舞
+##舟
+##航
+##舫
+##般
+##舰
+##舱
+##舵
+##舶
+##舷
+##舸
+##船
+##舺
+##舾
+##艇
+##艋
+##艘
+##艙
+##艦
+##艮
+##良
+##艰
+##艱
+##色
+##艳
+##艷
+##艹
+##艺
+##艾
+##节
+##芃
+##芈
+##芊
+##芋
+##芍
+##芎
+##芒
+##芙
+##芜
+##芝
+##芡
+##芥
+##芦
+##芩
+##芪
+##芫
+##芬
+##芭
+##芮
+##芯
+##花
+##芳
+##芷
+##芸
+##芹
+##芻
+##芽
+##芾
+##苁
+##苄
+##苇
+##苋
+##苍
+##苏
+##苑
+##苒
+##苓
+##苔
+##苕
+##苗
+##苛
+##苜
+##苞
+##苟
+##苡
+##苣
+##若
+##苦
+##苫
+##苯
+##英
+##苷
+##苹
+##苻
+##茁
+##茂
+##范
+##茄
+##茅
+##茉
+##茎
+##茏
+##茗
+##茜
+##茧
+##茨
+##茫
+##茬
+##茭
+##茯
+##茱
+##茲
+##茴
+##茵
+##茶
+##茸
+##茹
+##茼
+##荀
+##荃
+##荆
+##草
+##荊
+##荏
+##荐
+##荒
+##荔
+##荖
+##荘
+##荚
+##荞
+##荟
+##荠
+##荡
+##荣
+##荤
+##荥
+##荧
+##荨
+##荪
+##荫
+##药
+##荳
+##荷
+##荸
+##荻
+##荼
+##荽
+##莅
+##莆
+##莉
+##莊
+##莎
+##莒
+##莓
+##莖
+##莘
+##莞
+##莠
+##莢
+##莧
+##莪
+##莫
+##莱
+##莲
+##莴
+##获
+##莹
+##莺
+##莽
+##莿
+##菀
+##菁
+##菅
+##菇
+##菈
+##菊
+##菌
+##菏
+##菓
+##菖
+##菘
+##菜
+##菟
+##菠
+##菡
+##菩
+##華
+##菱
+##菲
+##菸
+##菽
+##萁
+##萃
+##萄
+##萊
+##萋
+##萌
+##萍
+##萎
+##萘
+##萝
+##萤
+##营
+##萦
+##萧
+##萨
+##萩
+##萬
+##萱
+##萵
+##萸
+##萼
+##落
+##葆
+##葉
+##著
+##葚
+##葛
+##葡
+##董
+##葦
+##葩
+##葫
+##葬
+##葭
+##葯
+##葱
+##葳
+##葵
+##葷
+##葺
+##蒂
+##蒋
+##蒐
+##蒔
+##蒙
+##蒜
+##蒞
+##蒟
+##蒡
+##蒨
+##蒲
+##蒸
+##蒹
+##蒻
+##蒼
+##蒿
+##蓁
+##蓄
+##蓆
+##蓉
+##蓋
+##蓑
+##蓓
+##蓖
+##蓝
+##蓟
+##蓦
+##蓬
+##蓮
+##蓼
+##蓿
+##蔑
+##蔓
+##蔔
+##蔗
+##蔘
+##蔚
+##蔡
+##蔣
+##蔥
+##蔫
+##蔬
+##蔭
+##蔵
+##蔷
+##蔺
+##蔻
+##蔼
+##蔽
+##蕁
+##蕃
+##蕈
+##蕉
+##蕊
+##蕎
+##蕙
+##蕤
+##蕨
+##蕩
+##蕪
+##蕭
+##蕲
+##蕴
+##蕻
+##蕾
+##薄
+##薅
+##薇
+##薈
+##薊
+##薏
+##薑
+##薔
+##薙
+##薛
+##薦
+##薨
+##薩
+##薪
+##薬
+##薯
+##薰
+##薹
+##藉
+##藍
+##藏
+##藐
+##藓
+##藕
+##藜
+##藝
+##藤
+##藥
+##藩
+##藹
+##藻
+##藿
+##蘆
+##蘇
+##蘊
+##蘋
+##蘑
+##蘚
+##蘭
+##蘸
+##蘼
+##蘿
+##虎
+##虏
+##虐
+##虑
+##虔
+##處
+##虚
+##虛
+##虜
+##虞
+##號
+##虢
+##虧
+##虫
+##虬
+##虱
+##虹
+##虻
+##虽
+##虾
+##蚀
+##蚁
+##蚂
+##蚊
+##蚌
+##蚓
+##蚕
+##蚜
+##蚝
+##蚣
+##蚤
+##蚩
+##蚪
+##蚯
+##蚱
+##蚵
+##蛀
+##蛆
+##蛇
+##蛊
+##蛋
+##蛎
+##蛐
+##蛔
+##蛙
+##蛛
+##蛟
+##蛤
+##蛭
+##蛮
+##蛰
+##蛳
+##蛹
+##蛻
+##蛾
+##蜀
+##蜂
+##蜃
+##蜆
+##蜇
+##蜈
+##蜊
+##蜍
+##蜒
+##蜓
+##蜕
+##蜗
+##蜘
+##蜚
+##蜜
+##蜡
+##蜢
+##蜥
+##蜱
+##蜴
+##蜷
+##蜻
+##蜿
+##蝇
+##蝈
+##蝉
+##蝌
+##蝎
+##蝕
+##蝗
+##蝙
+##蝟
+##蝠
+##蝦
+##蝨
+##蝴
+##蝶
+##蝸
+##蝼
+##螂
+##螃
+##融
+##螞
+##螢
+##螨
+##螯
+##螳
+##螺
+##蟀
+##蟄
+##蟆
+##蟋
+##蟎
+##蟑
+##蟒
+##蟠
+##蟬
+##蟲
+##蟹
+##蟻
+##蟾
+##蠅
+##蠍
+##蠔
+##蠕
+##蠛
+##蠟
+##蠡
+##蠢
+##蠣
+##蠱
+##蠶
+##蠹
+##蠻
+##血
+##衄
+##衅
+##衆
+##行
+##衍
+##術
+##衔
+##街
+##衙
+##衛
+##衝
+##衞
+##衡
+##衢
+##衣
+##补
+##表
+##衩
+##衫
+##衬
+##衮
+##衰
+##衲
+##衷
+##衹
+##衾
+##衿
+##袁
+##袂
+##袄
+##袅
+##袈
+##袋
+##袍
+##袒
+##袖
+##袜
+##袞
+##袤
+##袪
+##被
+##袭
+##袱
+##裁
+##裂
+##装
+##裆
+##裊
+##裏
+##裔
+##裕
+##裘
+##裙
+##補
+##裝
+##裟
+##裡
+##裤
+##裨
+##裱
+##裳
+##裴
+##裸
+##裹
+##製
+##裾
+##褂
+##複
+##褐
+##褒
+##褓
+##褔
+##褚
+##褥
+##褪
+##褫
+##褲
+##褶
+##褻
+##襁
+##襄
+##襟
+##襠
+##襪
+##襬
+##襯
+##襲
+##西
+##要
+##覃
+##覆
+##覇
+##見
+##規
+##覓
+##視
+##覚
+##覦
+##覧
+##親
+##覬
+##観
+##覷
+##覺
+##覽
+##觀
+##见
+##观
+##规
+##觅
+##视
+##览
+##觉
+##觊
+##觎
+##觐
+##觑
+##角
+##觞
+##解
+##觥
+##触
+##觸
+##言
+##訂
+##計
+##訊
+##討
+##訓
+##訕
+##訖
+##託
+##記
+##訛
+##訝
+##訟
+##訣
+##訥
+##訪
+##設
+##許
+##訳
+##訴
+##訶
+##診
+##註
+##証
+##詆
+##詐
+##詔
+##評
+##詛
+##詞
+##詠
+##詡
+##詢
+##詣
+##試
+##詩
+##詫
+##詬
+##詭
+##詮
+##詰
+##話
+##該
+##詳
+##詹
+##詼
+##誅
+##誇
+##誉
+##誌
+##認
+##誓
+##誕
+##誘
+##語
+##誠
+##誡
+##誣
+##誤
+##誥
+##誦
+##誨
+##說
+##説
+##読
+##誰
+##課
+##誹
+##誼
+##調
+##諄
+##談
+##請
+##諏
+##諒
+##論
+##諗
+##諜
+##諡
+##諦
+##諧
+##諫
+##諭
+##諮
+##諱
+##諳
+##諷
+##諸
+##諺
+##諾
+##謀
+##謁
+##謂
+##謄
+##謊
+##謎
+##謐
+##謔
+##謗
+##謙
+##講
+##謝
+##謠
+##謨
+##謬
+##謹
+##謾
+##譁
+##證
+##譎
+##譏
+##識
+##譙
+##譚
+##譜
+##警
+##譬
+##譯
+##議
+##譲
+##譴
+##護
+##譽
+##讀
+##變
+##讓
+##讚
+##讞
+##计
+##订
+##认
+##讥
+##讧
+##讨
+##让
+##讪
+##讫
+##训
+##议
+##讯
+##记
+##讲
+##讳
+##讴
+##讶
+##讷
+##许
+##讹
+##论
+##讼
+##讽
+##设
+##访
+##诀
+##证
+##诃
+##评
+##诅
+##识
+##诈
+##诉
+##诊
+##诋
+##词
+##诏
+##译
+##试
+##诗
+##诘
+##诙
+##诚
+##诛
+##话
+##诞
+##诟
+##诠
+##诡
+##询
+##诣
+##诤
+##该
+##详
+##诧
+##诩
+##诫
+##诬
+##语
+##误
+##诰
+##诱
+##诲
+##说
+##诵
+##诶
+##请
+##诸
+##诺
+##读
+##诽
+##课
+##诿
+##谀
+##谁
+##调
+##谄
+##谅
+##谆
+##谈
+##谊
+##谋
+##谌
+##谍
+##谎
+##谏
+##谐
+##谑
+##谒
+##谓
+##谔
+##谕
+##谗
+##谘
+##谙
+##谚
+##谛
+##谜
+##谟
+##谢
+##谣
+##谤
+##谥
+##谦
+##谧
+##谨
+##谩
+##谪
+##谬
+##谭
+##谯
+##谱
+##谲
+##谴
+##谶
+##谷
+##豁
+##豆
+##豇
+##豈
+##豉
+##豊
+##豌
+##豎
+##豐
+##豔
+##豚
+##象
+##豢
+##豪
+##豫
+##豬
+##豹
+##豺
+##貂
+##貅
+##貌
+##貓
+##貔
+##貘
+##貝
+##貞
+##負
+##財
+##貢
+##貧
+##貨
+##販
+##貪
+##貫
+##責
+##貯
+##貰
+##貳
+##貴
+##貶
+##買
+##貸
+##費
+##貼
+##貽
+##貿
+##賀
+##賁
+##賂
+##賃
+##賄
+##資
+##賈
+##賊
+##賑
+##賓
+##賜
+##賞
+##賠
+##賡
+##賢
+##賣
+##賤
+##賦
+##質
+##賬
+##賭
+##賴
+##賺
+##購
+##賽
+##贅
+##贈
+##贊
+##贍
+##贏
+##贓
+##贖
+##贛
+##贝
+##贞
+##负
+##贡
+##财
+##责
+##贤
+##败
+##账
+##货
+##质
+##贩
+##贪
+##贫
+##贬
+##购
+##贮
+##贯
+##贰
+##贱
+##贲
+##贴
+##贵
+##贷
+##贸
+##费
+##贺
+##贻
+##贼
+##贾
+##贿
+##赁
+##赂
+##赃
+##资
+##赅
+##赈
+##赊
+##赋
+##赌
+##赎
+##赏
+##赐
+##赓
+##赔
+##赖
+##赘
+##赚
+##赛
+##赝
+##赞
+##赠
+##赡
+##赢
+##赣
+##赤
+##赦
+##赧
+##赫
+##赭
+##走
+##赳
+##赴
+##赵
+##赶
+##起
+##趁
+##超
+##越
+##趋
+##趕
+##趙
+##趟
+##趣
+##趨
+##足
+##趴
+##趵
+##趸
+##趺
+##趾
+##跃
+##跄
+##跆
+##跋
+##跌
+##跎
+##跑
+##跖
+##跚
+##跛
+##距
+##跟
+##跡
+##跤
+##跨
+##跩
+##跪
+##路
+##跳
+##践
+##跷
+##跹
+##跺
+##跻
+##踉
+##踊
+##踌
+##踏
+##踐
+##踝
+##踞
+##踟
+##踢
+##踩
+##踪
+##踮
+##踱
+##踴
+##踵
+##踹
+##蹂
+##蹄
+##蹇
+##蹈
+##蹉
+##蹊
+##蹋
+##蹑
+##蹒
+##蹙
+##蹟
+##蹣
+##蹤
+##蹦
+##蹩
+##蹬
+##蹭
+##蹲
+##蹴
+##蹶
+##蹺
+##蹼
+##蹿
+##躁
+##躇
+##躉
+##躊
+##躋
+##躍
+##躏
+##躪
+##身
+##躬
+##躯
+##躲
+##躺
+##軀
+##車
+##軋
+##軌
+##軍
+##軒
+##軟
+##転
+##軸
+##軼
+##軽
+##軾
+##較
+##載
+##輒
+##輓
+##輔
+##輕
+##輛
+##輝
+##輟
+##輩
+##輪
+##輯
+##輸
+##輻
+##輾
+##輿
+##轄
+##轅
+##轆
+##轉
+##轍
+##轎
+##轟
+##车
+##轧
+##轨
+##轩
+##转
+##轭
+##轮
+##软
+##轰
+##轲
+##轴
+##轶
+##轻
+##轼
+##载
+##轿
+##较
+##辄
+##辅
+##辆
+##辇
+##辈
+##辉
+##辊
+##辍
+##辐
+##辑
+##输
+##辕
+##辖
+##辗
+##辘
+##辙
+##辛
+##辜
+##辞
+##辟
+##辣
+##辦
+##辨
+##辩
+##辫
+##辭
+##辮
+##辯
+##辰
+##辱
+##農
+##边
+##辺
+##辻
+##込
+##辽
+##达
+##迁
+##迂
+##迄
+##迅
+##过
+##迈
+##迎
+##运
+##近
+##返
+##还
+##这
+##进
+##远
+##违
+##连
+##迟
+##迢
+##迤
+##迥
+##迦
+##迩
+##迪
+##迫
+##迭
+##述
+##迴
+##迷
+##迸
+##迹
+##迺
+##追
+##退
+##送
+##适
+##逃
+##逅
+##逆
+##选
+##逊
+##逍
+##透
+##逐
+##递
+##途
+##逕
+##逗
+##這
+##通
+##逛
+##逝
+##逞
+##速
+##造
+##逢
+##連
+##逮
+##週
+##進
+##逵
+##逶
+##逸
+##逻
+##逼
+##逾
+##遁
+##遂
+##遅
+##遇
+##遊
+##運
+##遍
+##過
+##遏
+##遐
+##遑
+##遒
+##道
+##達
+##違
+##遗
+##遙
+##遛
+##遜
+##遞
+##遠
+##遢
+##遣
+##遥
+##遨
+##適
+##遭
+##遮
+##遲
+##遴
+##遵
+##遶
+##遷
+##選
+##遺
+##遼
+##遽
+##避
+##邀
+##邁
+##邂
+##邃
+##還
+##邇
+##邈
+##邊
+##邋
+##邏
+##邑
+##邓
+##邕
+##邛
+##邝
+##邢
+##那
+##邦
+##邨
+##邪
+##邬
+##邮
+##邯
+##邰
+##邱
+##邳
+##邵
+##邸
+##邹
+##邺
+##邻
+##郁
+##郅
+##郊
+##郎
+##郑
+##郜
+##郝
+##郡
+##郢
+##郤
+##郦
+##郧
+##部
+##郫
+##郭
+##郴
+##郵
+##郷
+##郸
+##都
+##鄂
+##鄉
+##鄒
+##鄔
+##鄙
+##鄞
+##鄢
+##鄧
+##鄭
+##鄰
+##鄱
+##鄲
+##鄺
+##酉
+##酊
+##酋
+##酌
+##配
+##酐
+##酒
+##酗
+##酚
+##酝
+##酢
+##酣
+##酥
+##酩
+##酪
+##酬
+##酮
+##酯
+##酰
+##酱
+##酵
+##酶
+##酷
+##酸
+##酿
+##醃
+##醇
+##醉
+##醋
+##醍
+##醐
+##醒
+##醚
+##醛
+##醜
+##醞
+##醣
+##醪
+##醫
+##醬
+##醮
+##醯
+##醴
+##醺
+##釀
+##釁
+##采
+##釉
+##释
+##釋
+##里
+##重
+##野
+##量
+##釐
+##金
+##釗
+##釘
+##釜
+##針
+##釣
+##釦
+##釧
+##釵
+##鈀
+##鈉
+##鈍
+##鈎
+##鈔
+##鈕
+##鈞
+##鈣
+##鈦
+##鈪
+##鈴
+##鈺
+##鈾
+##鉀
+##鉄
+##鉅
+##鉉
+##鉑
+##鉗
+##鉚
+##鉛
+##鉤
+##鉴
+##鉻
+##銀
+##銃
+##銅
+##銑
+##銓
+##銖
+##銘
+##銜
+##銬
+##銭
+##銮
+##銳
+##銷
+##銹
+##鋁
+##鋅
+##鋒
+##鋤
+##鋪
+##鋰
+##鋸
+##鋼
+##錄
+##錐
+##錘
+##錚
+##錠
+##錢
+##錦
+##錨
+##錫
+##錮
+##錯
+##録
+##錳
+##錶
+##鍊
+##鍋
+##鍍
+##鍛
+##鍥
+##鍰
+##鍵
+##鍺
+##鍾
+##鎂
+##鎊
+##鎌
+##鎏
+##鎔
+##鎖
+##鎗
+##鎚
+##鎧
+##鎬
+##鎮
+##鎳
+##鏈
+##鏖
+##鏗
+##鏘
+##鏞
+##鏟
+##鏡
+##鏢
+##鏤
+##鏽
+##鐘
+##鐮
+##鐲
+##鐳
+##鐵
+##鐸
+##鐺
+##鑄
+##鑊
+##鑑
+##鑒
+##鑣
+##鑫
+##鑰
+##鑲
+##鑼
+##鑽
+##鑾
+##鑿
+##针
+##钉
+##钊
+##钎
+##钏
+##钒
+##钓
+##钗
+##钙
+##钛
+##钜
+##钝
+##钞
+##钟
+##钠
+##钡
+##钢
+##钣
+##钤
+##钥
+##钦
+##钧
+##钨
+##钩
+##钮
+##钯
+##钰
+##钱
+##钳
+##钴
+##钵
+##钺
+##钻
+##钼
+##钾
+##钿
+##铀
+##铁
+##铂
+##铃
+##铄
+##铅
+##铆
+##铉
+##铎
+##铐
+##铛
+##铜
+##铝
+##铠
+##铡
+##铢
+##铣
+##铤
+##铨
+##铩
+##铬
+##铭
+##铮
+##铰
+##铲
+##铵
+##银
+##铸
+##铺
+##链
+##铿
+##销
+##锁
+##锂
+##锄
+##锅
+##锆
+##锈
+##锉
+##锋
+##锌
+##锏
+##锐
+##锑
+##错
+##锚
+##锟
+##锡
+##锢
+##锣
+##锤
+##锥
+##锦
+##锭
+##键
+##锯
+##锰
+##锲
+##锵
+##锹
+##锺
+##锻
+##镀
+##镁
+##镂
+##镇
+##镉
+##镌
+##镍
+##镐
+##镑
+##镕
+##镖
+##镗
+##镛
+##镜
+##镣
+##镭
+##镯
+##镰
+##镳
+##镶
+##長
+##长
+##門
+##閃
+##閉
+##開
+##閎
+##閏
+##閑
+##閒
+##間
+##閔
+##閘
+##閡
+##関
+##閣
+##閥
+##閨
+##閩
+##閱
+##閲
+##閹
+##閻
+##閾
+##闆
+##闇
+##闊
+##闌
+##闍
+##闔
+##闕
+##闖
+##闘
+##關
+##闡
+##闢
+##门
+##闪
+##闫
+##闭
+##问
+##闯
+##闰
+##闲
+##间
+##闵
+##闷
+##闸
+##闹
+##闺
+##闻
+##闽
+##闾
+##阀
+##阁
+##阂
+##阅
+##阆
+##阇
+##阈
+##阉
+##阎
+##阐
+##阑
+##阔
+##阕
+##阖
+##阙
+##阚
+##阜
+##队
+##阡
+##阪
+##阮
+##阱
+##防
+##阳
+##阴
+##阵
+##阶
+##阻
+##阿
+##陀
+##陂
+##附
+##际
+##陆
+##陇
+##陈
+##陋
+##陌
+##降
+##限
+##陕
+##陛
+##陝
+##陞
+##陟
+##陡
+##院
+##陣
+##除
+##陨
+##险
+##陪
+##陰
+##陲
+##陳
+##陵
+##陶
+##陷
+##陸
+##険
+##陽
+##隅
+##隆
+##隈
+##隊
+##隋
+##隍
+##階
+##随
+##隐
+##隔
+##隕
+##隘
+##隙
+##際
+##障
+##隠
+##隣
+##隧
+##隨
+##險
+##隱
+##隴
+##隶
+##隸
+##隻
+##隼
+##隽
+##难
+##雀
+##雁
+##雄
+##雅
+##集
+##雇
+##雉
+##雋
+##雌
+##雍
+##雎
+##雏
+##雑
+##雒
+##雕
+##雖
+##雙
+##雛
+##雜
+##雞
+##離
+##難
+##雨
+##雪
+##雯
+##雰
+##雲
+##雳
+##零
+##雷
+##雹
+##電
+##雾
+##需
+##霁
+##霄
+##霆
+##震
+##霈
+##霉
+##霊
+##霍
+##霎
+##霏
+##霑
+##霓
+##霖
+##霜
+##霞
+##霧
+##霭
+##霰
+##露
+##霸
+##霹
+##霽
+##霾
+##靂
+##靄
+##靈
+##青
+##靓
+##靖
+##静
+##靚
+##靛
+##靜
+##非
+##靠
+##靡
+##面
+##靥
+##靦
+##革
+##靳
+##靴
+##靶
+##靼
+##鞅
+##鞋
+##鞍
+##鞏
+##鞑
+##鞘
+##鞠
+##鞣
+##鞦
+##鞭
+##韆
+##韋
+##韌
+##韓
+##韜
+##韦
+##韧
+##韩
+##韬
+##韭
+##音
+##韵
+##韶
+##韻
+##響
+##頁
+##頂
+##頃
+##項
+##順
+##須
+##頌
+##預
+##頑
+##頒
+##頓
+##頗
+##領
+##頜
+##頡
+##頤
+##頫
+##頭
+##頰
+##頷
+##頸
+##頹
+##頻
+##頼
+##顆
+##題
+##額
+##顎
+##顏
+##顔
+##願
+##顛
+##類
+##顧
+##顫
+##顯
+##顱
+##顴
+##页
+##顶
+##顷
+##项
+##顺
+##须
+##顼
+##顽
+##顾
+##顿
+##颁
+##颂
+##预
+##颅
+##领
+##颇
+##颈
+##颉
+##颊
+##颌
+##颍
+##颐
+##频
+##颓
+##颔
+##颖
+##颗
+##题
+##颚
+##颛
+##颜
+##额
+##颞
+##颠
+##颡
+##颢
+##颤
+##颦
+##颧
+##風
+##颯
+##颱
+##颳
+##颶
+##颼
+##飄
+##飆
+##风
+##飒
+##飓
+##飕
+##飘
+##飙
+##飚
+##飛
+##飞
+##食
+##飢
+##飨
+##飩
+##飪
+##飯
+##飲
+##飼
+##飽
+##飾
+##餃
+##餅
+##餉
+##養
+##餌
+##餐
+##餒
+##餓
+##餘
+##餚
+##餛
+##餞
+##餡
+##館
+##餮
+##餵
+##餾
+##饅
+##饈
+##饋
+##饌
+##饍
+##饑
+##饒
+##饕
+##饗
+##饞
+##饥
+##饨
+##饪
+##饬
+##饭
+##饮
+##饯
+##饰
+##饱
+##饲
+##饴
+##饵
+##饶
+##饷
+##饺
+##饼
+##饽
+##饿
+##馀
+##馁
+##馄
+##馅
+##馆
+##馈
+##馋
+##馍
+##馏
+##馒
+##馔
+##首
+##馗
+##香
+##馥
+##馨
+##馬
+##馭
+##馮
+##馳
+##馴
+##駁
+##駄
+##駅
+##駆
+##駐
+##駒
+##駕
+##駛
+##駝
+##駭
+##駱
+##駿
+##騁
+##騎
+##騏
+##験
+##騙
+##騨
+##騰
+##騷
+##驀
+##驅
+##驊
+##驍
+##驒
+##驕
+##驗
+##驚
+##驛
+##驟
+##驢
+##驥
+##马
+##驭
+##驮
+##驯
+##驰
+##驱
+##驳
+##驴
+##驶
+##驷
+##驸
+##驹
+##驻
+##驼
+##驾
+##驿
+##骁
+##骂
+##骄
+##骅
+##骆
+##骇
+##骈
+##骊
+##骋
+##验
+##骏
+##骐
+##骑
+##骗
+##骚
+##骛
+##骜
+##骞
+##骠
+##骡
+##骤
+##骥
+##骧
+##骨
+##骯
+##骰
+##骶
+##骷
+##骸
+##骼
+##髂
+##髅
+##髋
+##髏
+##髒
+##髓
+##體
+##髖
+##高
+##髦
+##髪
+##髮
+##髯
+##髻
+##鬃
+##鬆
+##鬍
+##鬓
+##鬚
+##鬟
+##鬢
+##鬣
+##鬥
+##鬧
+##鬱
+##鬼
+##魁
+##魂
+##魄
+##魅
+##魇
+##魍
+##魏
+##魔
+##魘
+##魚
+##魯
+##魷
+##鮑
+##鮨
+##鮪
+##鮭
+##鮮
+##鯉
+##鯊
+##鯖
+##鯛
+##鯨
+##鯰
+##鯽
+##鰍
+##鰓
+##鰭
+##鰲
+##鰻
+##鰾
+##鱈
+##鱉
+##鱔
+##鱗
+##鱷
+##鱸
+##鱼
+##鱿
+##鲁
+##鲈
+##鲍
+##鲑
+##鲛
+##鲜
+##鲟
+##鲢
+##鲤
+##鲨
+##鲫
+##鲱
+##鲲
+##鲶
+##鲷
+##鲸
+##鳃
+##鳄
+##鳅
+##鳌
+##鳍
+##鳕
+##鳖
+##鳗
+##鳝
+##鳞
+##鳥
+##鳩
+##鳳
+##鳴
+##鳶
+##鴉
+##鴕
+##鴛
+##鴦
+##鴨
+##鴻
+##鴿
+##鵑
+##鵜
+##鵝
+##鵡
+##鵬
+##鵰
+##鵲
+##鶘
+##鶩
+##鶯
+##鶴
+##鷗
+##鷲
+##鷹
+##鷺
+##鸚
+##鸞
+##鸟
+##鸠
+##鸡
+##鸢
+##鸣
+##鸥
+##鸦
+##鸨
+##鸪
+##鸭
+##鸯
+##鸳
+##鸵
+##鸽
+##鸾
+##鸿
+##鹂
+##鹃
+##鹄
+##鹅
+##鹈
+##鹉
+##鹊
+##鹌
+##鹏
+##鹑
+##鹕
+##鹘
+##鹜
+##鹞
+##鹤
+##鹦
+##鹧
+##鹫
+##鹭
+##鹰
+##鹳
+##鹵
+##鹹
+##鹼
+##鹽
+##鹿
+##麂
+##麋
+##麒
+##麓
+##麗
+##麝
+##麟
+##麥
+##麦
+##麩
+##麴
+##麵
+##麸
+##麺
+##麻
+##麼
+##麽
+##麾
+##黃
+##黄
+##黍
+##黎
+##黏
+##黑
+##黒
+##黔
+##默
+##黛
+##黜
+##黝
+##點
+##黠
+##黨
+##黯
+##黴
+##鼋
+##鼎
+##鼐
+##鼓
+##鼠
+##鼬
+##鼹
+##鼻
+##鼾
+##齁
+##齊
+##齋
+##齐
+##齒
+##齡
+##齢
+##齣
+##齦
+##齿
+##龄
+##龅
+##龈
+##龊
+##龋
+##龌
+##龍
+##龐
+##龔
+##龕
+##龙
+##龚
+##龛
+##龜
+##龟
+##︰
+##︱
+##︶
+##︿
+##﹁
+##﹂
+##﹍
+##﹏
+##﹐
+##﹑
+##﹒
+##﹔
+##﹕
+##﹖
+##﹗
+##﹙
+##﹚
+##﹝
+##﹞
+##﹡
+##﹣
+##!
+##"
+###
+##$
+##%
+##&
+##'
+##(
+##)
+##*
+##,
+##-
+##.
+##/
+##:
+##;
+##<
+##?
+##@
+##[
+##\
+##]
+##^
+##_
+##`
+##f
+##h
+##j
+##u
+##w
+##z
+##{
+##}
+##。
+##「
+##」
+##、
+##・
+##ッ
+##ー
+##イ
+##ク
+##シ
+##ス
+##ト
+##ノ
+##フ
+##ラ
+##ル
+##ン
+##゙
+##゚
+## ̄
+##¥
+##👍
+##🔥
+##😂
+##😎
diff --git a/models/tts/maskgct/g2p/sources/pinyin_2_bpmf.txt b/models/tts/maskgct/g2p/sources/pinyin_2_bpmf.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af74dc687a547ed7822dacc77b7491924a8dcf1b
--- /dev/null
+++ b/models/tts/maskgct/g2p/sources/pinyin_2_bpmf.txt
@@ -0,0 +1,429 @@
+a ㄚ
+ai ㄞ
+an ㄢ
+ang ㄤ
+ao ㄠ
+ba ㄅㄚ
+bai ㄅㄞ
+ban ㄅㄢ
+bang ㄅㄤ
+bao ㄅㄠ
+bei ㄅㄟ
+ben ㄅㄣ
+beng ㄅㄥ
+bi ㄅㄧ
+bian ㄅㄧㄢ
+biang ㄅㄧㄤ
+biao ㄅㄧㄠ
+bie ㄅㄧㄝ
+bin ㄅㄧㄣ
+bing ㄅㄧㄥ
+bo ㄅㄛ
+bu ㄅㄨ
+ca ㄘㄚ
+cai ㄘㄞ
+can ㄘㄢ
+cang ㄘㄤ
+cao ㄘㄠ
+ce ㄘㄜ
+cen ㄘㄣ
+ceng ㄘㄥ
+cha ㄔㄚ
+chai ㄔㄞ
+chan ㄔㄢ
+chang ㄔㄤ
+chao ㄔㄠ
+che ㄔㄜ
+chen ㄔㄣ
+cheng ㄔㄥ
+chi ㄔ
+chong ㄔㄨㄥ
+chou ㄔㄡ
+chu ㄔㄨ
+chua ㄔㄨㄚ
+chuai ㄔㄨㄞ
+chuan ㄔㄨㄢ
+chuang ㄔㄨㄤ
+chui ㄔㄨㄟ
+chun ㄔㄨㄣ
+chuo ㄔㄨㄛ
+ci ㄘ
+cong ㄘㄨㄥ
+cou ㄘㄡ
+cu ㄘㄨ
+cuan ㄘㄨㄢ
+cui ㄘㄨㄟ
+cun ㄘㄨㄣ
+cuo ㄘㄨㄛ
+da ㄉㄚ
+dai ㄉㄞ
+dan ㄉㄢ
+dang ㄉㄤ
+dao ㄉㄠ
+de ㄉㄜ
+dei ㄉㄟ
+den ㄉㄣ
+deng ㄉㄥ
+di ㄉㄧ
+dia ㄉㄧㄚ
+dian ㄉㄧㄢ
+diao ㄉㄧㄠ
+die ㄉㄧㄝ
+din ㄉㄧㄣ
+ding ㄉㄧㄥ
+diu ㄉㄧㄡ
+dong ㄉㄨㄥ
+dou ㄉㄡ
+du ㄉㄨ
+duan ㄉㄨㄢ
+dui ㄉㄨㄟ
+dun ㄉㄨㄣ
+duo ㄉㄨㄛ
+e ㄜ
+ei ㄟ
+en ㄣ
+eng ㄥ
+er ㄦ
+fa ㄈㄚ
+fan ㄈㄢ
+fang ㄈㄤ
+fei ㄈㄟ
+fen ㄈㄣ
+feng ㄈㄥ
+fo ㄈㄛ
+fou ㄈㄡ
+fu ㄈㄨ
+ga ㄍㄚ
+gai ㄍㄞ
+gan ㄍㄢ
+gang ㄍㄤ
+gao ㄍㄠ
+ge ㄍㄜ
+gei ㄍㄟ
+gen ㄍㄣ
+geng ㄍㄥ
+gong ㄍㄨㄥ
+gou ㄍㄡ
+gu ㄍㄨ
+gua ㄍㄨㄚ
+guai ㄍㄨㄞ
+guan ㄍㄨㄢ
+guang ㄍㄨㄤ
+gui ㄍㄨㄟ
+gun ㄍㄨㄣ
+guo ㄍㄨㄛ
+ha ㄏㄚ
+hai ㄏㄞ
+han ㄏㄢ
+hang ㄏㄤ
+hao ㄏㄠ
+he ㄏㄜ
+hei ㄏㄟ
+hen ㄏㄣ
+heng ㄏㄥ
+hm ㄏㄇ
+hong ㄏㄨㄥ
+hou ㄏㄡ
+hu ㄏㄨ
+hua ㄏㄨㄚ
+huai ㄏㄨㄞ
+huan ㄏㄨㄢ
+huang ㄏㄨㄤ
+hui ㄏㄨㄟ
+hun ㄏㄨㄣ
+huo ㄏㄨㄛ
+ji ㄐㄧ
+jia ㄐㄧㄚ
+jian ㄐㄧㄢ
+jiang ㄐㄧㄤ
+jiao ㄐㄧㄠ
+jie ㄐㄧㄝ
+jin ㄐㄧㄣ
+jing ㄐㄧㄥ
+jiong ㄐㄩㄥ
+jiu ㄐㄧㄡ
+ju ㄐㄩ
+jv ㄐㄩ
+juan ㄐㄩㄢ
+jvan ㄐㄩㄢ
+jue ㄐㄩㄝ
+jve ㄐㄩㄝ
+jun ㄐㄩㄣ
+ka ㄎㄚ
+kai ㄎㄞ
+kan ㄎㄢ
+kang ㄎㄤ
+kao ㄎㄠ
+ke ㄎㄜ
+kei ㄎㄟ
+ken ㄎㄣ
+keng ㄎㄥ
+kong ㄎㄨㄥ
+kou ㄎㄡ
+ku ㄎㄨ
+kua ㄎㄨㄚ
+kuai ㄎㄨㄞ
+kuan ㄎㄨㄢ
+kuang ㄎㄨㄤ
+kui ㄎㄨㄟ
+kun ㄎㄨㄣ
+kuo ㄎㄨㄛ
+la ㄌㄚ
+lai ㄌㄞ
+lan ㄌㄢ
+lang ㄌㄤ
+lao ㄌㄠ
+le ㄌㄜ
+lei ㄌㄟ
+leng ㄌㄥ
+li ㄌㄧ
+lia ㄌㄧㄚ
+lian ㄌㄧㄢ
+liang ㄌㄧㄤ
+liao ㄌㄧㄠ
+lie ㄌㄧㄝ
+lin ㄌㄧㄣ
+ling ㄌㄧㄥ
+liu ㄌㄧㄡ
+lo ㄌㄛ
+long ㄌㄨㄥ
+lou ㄌㄡ
+lu ㄌㄨ
+luan ㄌㄨㄢ
+lue ㄌㄩㄝ
+lun ㄌㄨㄣ
+luo ㄌㄨㄛ
+lv ㄌㄩ
+lve ㄌㄩㄝ
+m ㄇㄨ
+ma ㄇㄚ
+mai ㄇㄞ
+man ㄇㄢ
+mang ㄇㄤ
+mao ㄇㄠ
+me ㄇㄜ
+mei ㄇㄟ
+men ㄇㄣ
+meng ㄇㄥ
+mi ㄇㄧ
+mian ㄇㄧㄢ
+miao ㄇㄧㄠ
+mie ㄇㄧㄝ
+min ㄇㄧㄣ
+ming ㄇㄧㄥ
+miu ㄇㄧㄡ
+mo ㄇㄛ
+mou ㄇㄡ
+mu ㄇㄨ
+n ㄣ
+na ㄋㄚ
+nai ㄋㄞ
+nan ㄋㄢ
+nang ㄋㄤ
+nao ㄋㄠ
+ne ㄋㄜ
+nei ㄋㄟ
+nen ㄋㄣ
+neng ㄋㄥ
+ng ㄣ
+ni ㄋㄧ
+nian ㄋㄧㄢ
+niang ㄋㄧㄤ
+niao ㄋㄧㄠ
+nie ㄋㄧㄝ
+nin ㄋㄧㄣ
+ning ㄋㄧㄥ
+niu ㄋㄧㄡ
+nong ㄋㄨㄥ
+nou ㄋㄡ
+nu ㄋㄨ
+nuan ㄋㄨㄢ
+nue ㄋㄩㄝ
+nun ㄋㄨㄣ
+nuo ㄋㄨㄛ
+nv ㄋㄩ
+nve ㄋㄩㄝ
+o ㄛ
+ou ㄡ
+pa ㄆㄚ
+pai ㄆㄞ
+pan ㄆㄢ
+pang ㄆㄤ
+pao ㄆㄠ
+pei ㄆㄟ
+pen ㄆㄣ
+peng ㄆㄥ
+pi ㄆㄧ
+pian ㄆㄧㄢ
+piao ㄆㄧㄠ
+pie ㄆㄧㄝ
+pin ㄆㄧㄣ
+ping ㄆㄧㄥ
+po ㄆㄛ
+pou ㄆㄡ
+pu ㄆㄨ
+qi ㄑㄧ
+qia ㄑㄧㄚ
+qian ㄑㄧㄢ
+qiang ㄑㄧㄤ
+qiao ㄑㄧㄠ
+qie ㄑㄧㄝ
+qin ㄑㄧㄣ
+qing ㄑㄧㄥ
+qiong ㄑㄩㄥ
+qiu ㄑㄧㄡ
+qu ㄑㄩ
+quan ㄑㄩㄢ
+qvan ㄑㄩㄢ
+que ㄑㄩㄝ
+qun ㄑㄩㄣ
+ran ㄖㄢ
+rang ㄖㄤ
+rao ㄖㄠ
+re ㄖㄜ
+ren ㄖㄣ
+reng ㄖㄥ
+ri ㄖ
+rong ㄖㄨㄥ
+rou ㄖㄡ
+ru ㄖㄨ
+rua ㄖㄨㄚ
+ruan ㄖㄨㄢ
+rui ㄖㄨㄟ
+run ㄖㄨㄣ
+ruo ㄖㄨㄛ
+sa ㄙㄚ
+sai ㄙㄞ
+san ㄙㄢ
+sang ㄙㄤ
+sao ㄙㄠ
+se ㄙㄜ
+sen ㄙㄣ
+seng ㄙㄥ
+sha ㄕㄚ
+shai ㄕㄞ
+shan ㄕㄢ
+shang ㄕㄤ
+shao ㄕㄠ
+she ㄕㄜ
+shei ㄕㄟ
+shen ㄕㄣ
+sheng ㄕㄥ
+shi ㄕ
+shou ㄕㄡ
+shu ㄕㄨ
+shua ㄕㄨㄚ
+shuai ㄕㄨㄞ
+shuan ㄕㄨㄢ
+shuang ㄕㄨㄤ
+shui ㄕㄨㄟ
+shun ㄕㄨㄣ
+shuo ㄕㄨㄛ
+si ㄙ
+song ㄙㄨㄥ
+sou ㄙㄡ
+su ㄙㄨ
+suan ㄙㄨㄢ
+sui ㄙㄨㄟ
+sun ㄙㄨㄣ
+suo ㄙㄨㄛ
+ta ㄊㄚ
+tai ㄊㄞ
+tan ㄊㄢ
+tang ㄊㄤ
+tao ㄊㄠ
+te ㄊㄜ
+tei ㄊㄟ
+teng ㄊㄥ
+ti ㄊㄧ
+tian ㄊㄧㄢ
+tiao ㄊㄧㄠ
+tie ㄊㄧㄝ
+ting ㄊㄧㄥ
+tong ㄊㄨㄥ
+tou ㄊㄡ
+tsuo ㄘㄨㄛ
+tu ㄊㄨ
+tuan ㄊㄨㄢ
+tui ㄊㄨㄟ
+tun ㄊㄨㄣ
+tuo ㄊㄨㄛ
+tzan ㄗㄢ
+wa ㄨㄚ
+wai ㄨㄞ
+wan ㄨㄢ
+wang ㄨㄤ
+wei ㄨㄟ
+wen ㄨㄣ
+weng ㄨㄥ
+wo ㄨㄛ
+wong ㄨㄥ
+wu ㄨ
+xi ㄒㄧ
+xia ㄒㄧㄚ
+xian ㄒㄧㄢ
+xiang ㄒㄧㄤ
+xiao ㄒㄧㄠ
+xie ㄒㄧㄝ
+xin ㄒㄧㄣ
+xing ㄒㄧㄥ
+xiong ㄒㄩㄥ
+xiu ㄒㄧㄡ
+xu ㄒㄩ
+xuan ㄒㄩㄢ
+xue ㄒㄩㄝ
+xun ㄒㄩㄣ
+ya ㄧㄚ
+yai ㄧㄞ
+yan ㄧㄢ
+yang ㄧㄤ
+yao ㄧㄠ
+ye ㄧㄝ
+yi ㄧ
+yin ㄧㄣ
+ying ㄧㄥ
+yo ㄧㄛ
+yong ㄩㄥ
+you ㄧㄡ
+yu ㄩ
+yuan ㄩㄢ
+yue ㄩㄝ
+yve ㄩㄝ
+yun ㄩㄣ
+za ㄗㄚ
+zai ㄗㄞ
+zan ㄗㄢ
+zang ㄗㄤ
+zao ㄗㄠ
+ze ㄗㄜ
+zei ㄗㄟ
+zen ㄗㄣ
+zeng ㄗㄥ
+zha ㄓㄚ
+zhai ㄓㄞ
+zhan ㄓㄢ
+zhang ㄓㄤ
+zhao ㄓㄠ
+zhe ㄓㄜ
+zhei ㄓㄟ
+zhen ㄓㄣ
+zheng ㄓㄥ
+zhi ㄓ
+zhong ㄓㄨㄥ
+zhou ㄓㄡ
+zhu ㄓㄨ
+zhua ㄓㄨㄚ
+zhuai ㄓㄨㄞ
+zhuan ㄓㄨㄢ
+zhuang ㄓㄨㄤ
+zhui ㄓㄨㄟ
+zhun ㄓㄨㄣ
+zhuo ㄓㄨㄛ
+zi ㄗ
+zong ㄗㄨㄥ
+zou ㄗㄡ
+zu ㄗㄨ
+zuan ㄗㄨㄢ
+zui ㄗㄨㄟ
+zun ㄗㄨㄣ
+zuo ㄗㄨㄛ
diff --git a/models/tts/maskgct/g2p/utils/front_utils.py b/models/tts/maskgct/g2p/utils/front_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9f878b5ea87868aee62b3eed5c29e3e95776b7
--- /dev/null
+++ b/models/tts/maskgct/g2p/utils/front_utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+
+def generate_poly_lexicon(file_path: str):
+ """Generate poly char lexicon for Mandarin Chinese."""
+ poly_dict = {}
+
+ with open(file_path, "r", encoding="utf-8") as readf:
+ txt_list = readf.readlines()
+ for txt in txt_list:
+ word = txt.strip("\n")
+ if word not in poly_dict:
+ poly_dict[word] = 1
+ readf.close()
+ return poly_dict
diff --git a/models/tts/maskgct/g2p/utils/g2p.py b/models/tts/maskgct/g2p/utils/g2p.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71e0c8dca08657857a7ec56a290561ff6bc083b
--- /dev/null
+++ b/models/tts/maskgct/g2p/utils/g2p.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from phonemizer.backend import EspeakBackend
+from phonemizer.separator import Separator
+from phonemizer.utils import list2str, str2list
+from typing import List, Union
+import os
+import json
+import sys
+
+# separator=Separator(phone=' ', word=' _ ', syllable='|'),
+separator = Separator(word=" _ ", syllable="|", phone=" ")
+
+phonemizer_zh = EspeakBackend(
+ "cmn", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
+)
+# phonemizer_zh.separator = separator
+
+phonemizer_en = EspeakBackend(
+ "en-us",
+ preserve_punctuation=False,
+ with_stress=False,
+ language_switch="remove-flags",
+)
+# phonemizer_en.separator = separator
+
+phonemizer_ja = EspeakBackend(
+ "ja", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
+)
+# phonemizer_ja.separator = separator
+
+phonemizer_ko = EspeakBackend(
+ "ko", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
+)
+# phonemizer_ko.separator = separator
+
+phonemizer_fr = EspeakBackend(
+ "fr-fr",
+ preserve_punctuation=False,
+ with_stress=False,
+ language_switch="remove-flags",
+)
+# phonemizer_fr.separator = separator
+
+phonemizer_de = EspeakBackend(
+ "de", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
+)
+# phonemizer_de.separator = separator
+
+
+lang2backend = {
+ "zh": phonemizer_zh,
+ "ja": phonemizer_ja,
+ "en": phonemizer_en,
+ "fr": phonemizer_fr,
+ "ko": phonemizer_ko,
+ "de": phonemizer_de,
+}
+
+with open("./models/tts/maskgct/g2p/utils/mls_en.json", "r") as f:
+ json_data = f.read()
+token = json.loads(json_data)
+
+
+def phonemizer_g2p(text, language):
+ langbackend = lang2backend[language]
+ phonemes = _phonemize(
+ langbackend,
+ text,
+ separator,
+ strip=True,
+ njobs=1,
+ prepend_text=False,
+ preserve_empty_lines=False,
+ )
+ token_id = []
+ if isinstance(phonemes, list):
+ for phone in phonemes:
+ phonemes_split = phone.split(" ")
+ token_id.append([token[p] for p in phonemes_split if p in token])
+ else:
+ phonemes_split = phonemes.split(" ")
+ token_id = [token[p] for p in phonemes_split if p in token]
+ return phonemes, token_id
+
+
+def _phonemize( # pylint: disable=too-many-arguments
+ backend,
+ text: Union[str, List[str]],
+ separator: Separator,
+ strip: bool,
+ njobs: int,
+ prepend_text: bool,
+ preserve_empty_lines: bool,
+):
+ """Auxiliary function to phonemize()
+
+ Does the phonemization and returns the phonemized text. Raises a
+ RuntimeError on error.
+
+ """
+ # remember the text type for output (either list or string)
+ text_type = type(text)
+
+ # force the text as a list
+ text = [line.strip(os.linesep) for line in str2list(text)]
+
+ # if preserving empty lines, note the index of each empty line
+ if preserve_empty_lines:
+ empty_lines = [n for n, line in enumerate(text) if not line.strip()]
+
+ # ignore empty lines
+ text = [line for line in text if line.strip()]
+
+ if text:
+ # phonemize the text
+ phonemized = backend.phonemize(
+ text, separator=separator, strip=strip, njobs=njobs
+ )
+ else:
+ phonemized = []
+
+ # if preserving empty lines, reinsert them into text and phonemized lists
+ if preserve_empty_lines:
+ for i in empty_lines: # noqa
+ if prepend_text:
+ text.insert(i, "")
+ phonemized.insert(i, "")
+
+ # at that point, the phonemized text is a list of str. Format it as
+ # expected by the parameters
+ if prepend_text:
+ return list(zip(text, phonemized))
+ if text_type == str:
+ return list2str(phonemized)
+ return phonemized
diff --git a/models/tts/maskgct/g2p/utils/log.py b/models/tts/maskgct/g2p/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10b887ef2e9292bd79c628e9ed7881c7a91bf52
--- /dev/null
+++ b/models/tts/maskgct/g2p/utils/log.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import functools
+import logging
+
+__all__ = [
+ "logger",
+]
+
+
+class Logger(object):
+ def __init__(self, name: str = None):
+ name = "PaddleSpeech" if not name else name
+ self.logger = logging.getLogger(name)
+
+ log_config = {
+ "DEBUG": 10,
+ "INFO": 20,
+ "TRAIN": 21,
+ "EVAL": 22,
+ "WARNING": 30,
+ "ERROR": 40,
+ "CRITICAL": 50,
+ "EXCEPTION": 100,
+ }
+ for key, level in log_config.items():
+ logging.addLevelName(level, key)
+ if key == "EXCEPTION":
+ self.__dict__[key.lower()] = self.logger.exception
+ else:
+ self.__dict__[key.lower()] = functools.partial(self.__call__, level)
+
+ self.format = logging.Formatter(
+ fmt="[%(asctime)-15s] [%(levelname)8s] - %(message)s"
+ )
+
+ self.handler = logging.StreamHandler()
+ self.handler.setFormatter(self.format)
+
+ self.logger.addHandler(self.handler)
+ self.logger.setLevel(logging.INFO)
+ self.logger.propagate = False
+
+ def __call__(self, log_level: str, msg: str):
+ self.logger.log(log_level, msg)
+
+
+logger = Logger()
diff --git a/models/tts/maskgct/g2p/utils/mls_en.json b/models/tts/maskgct/g2p/utils/mls_en.json
new file mode 100644
index 0000000000000000000000000000000000000000..f3aadbf144427af10ec06ca3cab8c4a2c461925d
--- /dev/null
+++ b/models/tts/maskgct/g2p/utils/mls_en.json
@@ -0,0 +1,335 @@
+{
+ "[UNK]": 0,
+ "_": 1,
+ "b": 2,
+ "d": 3,
+ "f": 4,
+ "h": 5,
+ "i": 6,
+ "j": 7,
+ "k": 8,
+ "l": 9,
+ "m": 10,
+ "n": 11,
+ "p": 12,
+ "r": 13,
+ "s": 14,
+ "t": 15,
+ "v": 16,
+ "w": 17,
+ "x": 18,
+ "z": 19,
+ "æ": 20,
+ "ç": 21,
+ "ð": 22,
+ "ŋ": 23,
+ "ɐ": 24,
+ "ɔ": 25,
+ "ə": 26,
+ "ɚ": 27,
+ "ɛ": 28,
+ "ɡ": 29,
+ "ɪ": 30,
+ "ɬ": 31,
+ "ɹ": 32,
+ "ɾ": 33,
+ "ʃ": 34,
+ "ʊ": 35,
+ "ʌ": 36,
+ "ʒ": 37,
+ "ʔ": 38,
+ "θ": 39,
+ "ᵻ": 40,
+ "aɪ": 41,
+ "aʊ": 42,
+ "dʒ": 43,
+ "eɪ": 44,
+ "iə": 45,
+ "iː": 46,
+ "n̩": 47,
+ "oʊ": 48,
+ "oː": 49,
+ "tʃ": 50,
+ "uː": 51,
+ "ææ": 52,
+ "ɐɐ": 53,
+ "ɑː": 54,
+ "ɑ̃": 55,
+ "ɔɪ": 56,
+ "ɔː": 57,
+ "ɔ̃": 58,
+ "əl": 59,
+ "ɛɹ": 60,
+ "ɜː": 61,
+ "ɡʲ": 62,
+ "ɪɹ": 63,
+ "ʊɹ": 64,
+ "aɪə": 65,
+ "aɪɚ": 66,
+ "iːː": 67,
+ "oːɹ": 68,
+ "ɑːɹ": 69,
+ "ɔːɹ": 70,
+
+ "1": 71,
+ "a": 72,
+ "e": 73,
+ "o": 74,
+ "q": 75,
+ "u": 76,
+ "y": 77,
+ "ɑ": 78,
+ "ɒ": 79,
+ "ɕ": 80,
+ "ɣ": 81,
+ "ɫ": 82,
+ "ɯ": 83,
+ "ʐ": 84,
+ "ʲ": 85,
+ "a1": 86,
+ "a2": 87,
+ "a5": 88,
+ "ai": 89,
+ "aɜ": 90,
+ "aː": 91,
+ "ei": 92,
+ "eə": 93,
+ "i.": 94,
+ "i1": 95,
+ "i2": 96,
+ "i5": 97,
+ "io": 98,
+ "iɑ": 99,
+ "iɛ": 100,
+ "iɜ": 101,
+ "i̪": 102,
+ "kh": 103,
+ "nʲ": 104,
+ "o1": 105,
+ "o2": 106,
+ "o5": 107,
+ "ou": 108,
+ "oɜ": 109,
+ "ph": 110,
+ "s.": 111,
+ "th": 112,
+ "ts": 113,
+ "tɕ": 114,
+ "u1": 115,
+ "u2": 116,
+ "u5": 117,
+ "ua": 118,
+ "uo": 119,
+ "uə": 120,
+ "uɜ": 121,
+ "y1": 122,
+ "y2": 123,
+ "y5": 124,
+ "yu": 125,
+ "yæ": 126,
+ "yə": 127,
+ "yɛ": 128,
+ "yɜ": 129,
+ "ŋɜ": 130,
+ "ŋʲ": 131,
+ "ɑ1": 132,
+ "ɑ2": 133,
+ "ɑ5": 134,
+ "ɑu": 135,
+ "ɑɜ": 136,
+ "ɑʲ": 137,
+ "ə1": 138,
+ "ə2": 139,
+ "ə5": 140,
+ "ər": 141,
+ "əɜ": 142,
+ "əʊ": 143,
+ "ʊə": 144,
+ "ai1": 145,
+ "ai2": 146,
+ "ai5": 147,
+ "aiɜ": 148,
+ "ei1": 149,
+ "ei2": 150,
+ "ei5": 151,
+ "eiɜ": 152,
+ "i.1": 153,
+ "i.2": 154,
+ "i.5": 155,
+ "i.ɜ": 156,
+ "io5": 157,
+ "iou": 158,
+ "iɑ1": 159,
+ "iɑ2": 160,
+ "iɑ5": 161,
+ "iɑɜ": 162,
+ "iɛ1": 163,
+ "iɛ2": 164,
+ "iɛ5": 165,
+ "iɛɜ": 166,
+ "i̪1": 167,
+ "i̪2": 168,
+ "i̪5": 169,
+ "i̪ɜ": 170,
+ "onɡ": 171,
+ "ou1": 172,
+ "ou2": 173,
+ "ou5": 174,
+ "ouɜ": 175,
+ "ts.": 176,
+ "tsh": 177,
+ "tɕh": 178,
+ "u5ʲ": 179,
+ "ua1": 180,
+ "ua2": 181,
+ "ua5": 182,
+ "uai": 183,
+ "uaɜ": 184,
+ "uei": 185,
+ "uo1": 186,
+ "uo2": 187,
+ "uo5": 188,
+ "uoɜ": 189,
+ "uə1": 190,
+ "uə2": 191,
+ "uə5": 192,
+ "uəɜ": 193,
+ "yiɜ": 194,
+ "yu2": 195,
+ "yu5": 196,
+ "yæ2": 197,
+ "yæ5": 198,
+ "yæɜ": 199,
+ "yə2": 200,
+ "yə5": 201,
+ "yəɜ": 202,
+ "yɛ1": 203,
+ "yɛ2": 204,
+ "yɛ5": 205,
+ "yɛɜ": 206,
+ "ɑu1": 207,
+ "ɑu2": 208,
+ "ɑu5": 209,
+ "ɑuɜ": 210,
+ "ər1": 211,
+ "ər2": 212,
+ "ər5": 213,
+ "ərɜ": 214,
+ "əː1": 215,
+ "iou1": 216,
+ "iou2": 217,
+ "iou5": 218,
+ "iouɜ": 219,
+ "onɡ1": 220,
+ "onɡ2": 221,
+ "onɡ5": 222,
+ "onɡɜ": 223,
+ "ts.h": 224,
+ "uai2": 225,
+ "uai5": 226,
+ "uaiɜ": 227,
+ "uei1": 228,
+ "uei2": 229,
+ "uei5": 230,
+ "ueiɜ": 231,
+ "uoɜʲ": 232,
+ "yɛ5ʲ": 233,
+ "ɑu2ʲ": 234,
+
+ "2": 235,
+ "5": 236,
+ "ɜ": 237,
+ "ʂ": 238,
+ "dʑ": 239,
+ "iɪ": 240,
+ "uɪ": 241,
+ "xʲ": 242,
+ "ɑt": 243,
+ "ɛɜ": 244,
+ "ɛː": 245,
+ "ɪː": 246,
+ "phʲ": 247,
+ "ɑ5ʲ": 248,
+ "ɑuʲ": 249,
+ "ərə": 250,
+ "uozʰ": 251,
+ "ər1ʲ": 252,
+ "tɕhtɕh": 253,
+
+ "c": 254,
+ "ʋ": 255,
+ "ʍ": 256,
+ "ʑ": 257,
+ "ː": 258,
+ "aə": 259,
+ "eː": 260,
+ "hʲ": 261,
+ "iʊ": 262,
+ "kʲ": 263,
+ "lʲ": 264,
+ "oə": 265,
+ "oɪ": 266,
+ "oʲ": 267,
+ "pʲ": 268,
+ "sʲ": 269,
+ "u4": 270,
+ "uʲ": 271,
+ "yi": 272,
+ "yʲ": 273,
+ "ŋ2": 274,
+ "ŋ5": 275,
+ "ŋ̩": 276,
+ "ɑɪ": 277,
+ "ɑʊ": 278,
+ "ɕʲ": 279,
+ "ət": 280,
+ "əə": 281,
+ "əɪ": 282,
+ "əʲ": 283,
+ "ɛ1": 284,
+ "ɛ5": 285,
+ "aiə": 286,
+ "aiɪ": 287,
+ "azʰ": 288,
+ "eiə": 289,
+ "eiɪ": 290,
+ "eiʊ": 291,
+ "i.ə": 292,
+ "i.ɪ": 293,
+ "i.ʊ": 294,
+ "ioɜ": 295,
+ "izʰ": 296,
+ "iɑə": 297,
+ "iɑʊ": 298,
+ "iɑʲ": 299,
+ "iɛə": 300,
+ "iɛɪ": 301,
+ "iɛʊ": 302,
+ "i̪ə": 303,
+ "i̪ʊ": 304,
+ "khʲ": 305,
+ "ouʲ": 306,
+ "tsʲ": 307,
+ "u2ʲ": 308,
+ "uoɪ": 309,
+ "uzʰ": 310,
+ "uɜʲ": 311,
+ "yæɪ": 312,
+ "yəʊ": 313,
+ "ərt": 314,
+ "ərɪ": 315,
+ "ərʲ": 316,
+ "əːt": 317,
+ "iouə": 318,
+ "iouʊ": 319,
+ "iouʲ": 320,
+ "iɛzʰ": 321,
+ "onɡə": 322,
+ "onɡɪ": 323,
+ "onɡʊ": 324,
+ "ouzʰ": 325,
+ "uai1": 326,
+ "ueiɪ": 327,
+ "ɑuzʰ": 328,
+ "iouzʰ": 329
+}
\ No newline at end of file
diff --git a/models/tts/maskgct/llama_nar.py b/models/tts/maskgct/llama_nar.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c8ae842b5952b6c7ba3aa6e442dbb0277e2e87
--- /dev/null
+++ b/models/tts/maskgct/llama_nar.py
@@ -0,0 +1,650 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+import torch
+import torch.nn.functional as F
+import numpy as np
+import os
+import torch.nn as nn
+from typing import List, Optional, Tuple, Union
+import math
+
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
+
+
+# sinusoidal positional encoding
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :] * 1.0
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class LlamaAdaptiveRMSNorm(nn.Module):
+ def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
+ super().__init__()
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
+ nn.init.zeros_(self.to_weight.weight)
+ nn.init.ones_(self.to_weight.bias)
+ self.variance_epsilon = eps
+ self._is_hf_initialized = True # disable automatic init
+
+ def forward(self, hidden_states, cond_embedding):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ weight = self.to_weight(cond_embedding)
+ if len(weight.shape) == 2:
+ weight = weight.unsqueeze(1)
+
+ return (weight * hidden_states).to(input_dtype)
+
+
+class LlamaNARDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ """Override to adaptive layer norm"""
+ super().__init__(config, layer_idx) # init attention, mlp, etc.
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+
+ # add `cond` in forward function
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cond_embedding: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ """Override to adaptive layer norm"""
+ super().__init__(config, layer_idx) # init attention, mlp, etc.
+ self.layer_idx = layer_idx
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cond_embedding: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class DiffLlama(LlamaModel):
+ def __init__(
+ self,
+ hidden_size=1024,
+ num_heads=16,
+ num_layers=16,
+ config=LlamaConfig(0, 256, 1024, 1, 1),
+ ):
+ super().__init__(config)
+
+ self.layers = nn.ModuleList(
+ [
+ LlamaNARDecoderLayer(
+ LlamaConfig(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ max_position_embeddings=4096,
+ intermediate_size=hidden_size * 4,
+ ),
+ layer_idx=i,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+ self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+ self.diff_step_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ # self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0)
+
+ self.cond_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ for layer in self.layers:
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
+ hidden_size, dim_cond=hidden_size
+ )
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ hidden_size, dim_cond=hidden_size
+ )
+
+ self.post_init()
+
+ # self.reset_parameters()
+
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create noncausal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+
+ def _expand_mask(
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+ ):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = (
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+ )
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ x,
+ diffusion_step,
+ cond,
+ x_mask,
+ input_ids: torch.LongTensor = None, # [num_quant, B, T]
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ # retrieve some shape info
+ batch_size, seq_length, _ = x.shape
+
+ # condtion mlp
+ cond_embedding = self.cond_mlp(cond) # (B, T, C)
+
+ # diffusion step embedding
+ diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+ diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
+ x = x + cond_embedding
+
+ inputs_embeds = x
+ attention_mask = x_mask
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cond_embedding=diffusion_step,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ return hidden_states
+
+
+class DiffLlamaPrefix(LlamaModel):
+ def __init__(
+ self,
+ hidden_size=1024,
+ num_heads=16,
+ num_layers=16,
+ config=LlamaConfig(0, 256, 1024, 1, 1),
+ ):
+ super().__init__(config)
+
+ self.layers = nn.ModuleList(
+ [
+ LlamaNARDecoderLayer(
+ LlamaConfig(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ max_position_embeddings=4096,
+ intermediate_size=hidden_size * 4,
+ ),
+ layer_idx=i,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+ self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+ self.diff_step_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ self.cond_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ for layer in self.layers:
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
+ hidden_size, dim_cond=hidden_size
+ )
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ hidden_size, dim_cond=hidden_size
+ )
+
+ self.embed_tokens = None
+
+ self.post_init()
+
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create noncausal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+
+ def _expand_mask(
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+ ):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = (
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+ )
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ x,
+ diffusion_step,
+ x_mask,
+ phone_embedding: Optional[torch.LongTensor] = None,
+ phone_mask: Optional[torch.FloatTensor] = None,
+ input_ids: torch.LongTensor = None, # [num_quant, B, T]
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ # retrieve some shape info
+
+ phone_embedding = self.cond_mlp(phone_embedding) # (B, T, C)
+ phone_length = phone_embedding.shape[1]
+ inputs_embeds = torch.cat([phone_embedding, x], dim=1)
+ attention_mask = torch.cat([phone_mask, x_mask], dim=1)
+
+ # diffusion step embedding
+ diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+ diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
+
+ batch_size, seq_length, _ = inputs_embeds.shape
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cond_embedding=diffusion_step,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ return hidden_states[
+ :,
+ phone_length:,
+ ]
diff --git a/models/tts/maskgct/maskgct_inference.py b/models/tts/maskgct/maskgct_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..631ad2ceb85781cfcb10dca840280fb51321b64d
--- /dev/null
+++ b/models/tts/maskgct/maskgct_inference.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from models.tts.maskgct.maskgct_utils import *
+from huggingface_hub import hf_hub_download
+import safetensors
+import soundfile as sf
+
+if __name__ == "__main__":
+
+ # build model
+ device = torch.device("cuda:0")
+ cfg_path = "./models/tts/maskgct/config/maskgct.json"
+ cfg = load_config(cfg_path)
+ # 1. build semantic model (w2v-bert-2.0)
+ semantic_model, semantic_mean, semantic_std = build_semantic_model(device)
+ # 2. build semantic codec
+ semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)
+ # 3. build acoustic codec
+ codec_encoder, codec_decoder = build_acoustic_codec(
+ cfg.model.acoustic_codec, device
+ )
+ # 4. build t2s model
+ t2s_model = build_t2s_model(cfg.model.t2s_model, device)
+ # 5. build s2a model
+ s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
+ s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)
+
+ # download checkpoint
+ # download semantic codec ckpt
+ semantic_code_ckpt = hf_hub_download(
+ "amphion/MaskGCT", filename="semantic_codec/model.safetensors"
+ )
+ # download acoustic codec ckpt
+ codec_encoder_ckpt = hf_hub_download(
+ "amphion/MaskGCT", filename="acoustic_codec/model.safetensors"
+ )
+ codec_decoder_ckpt = hf_hub_download(
+ "amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors"
+ )
+ # download t2s model ckpt
+ t2s_model_ckpt = hf_hub_download(
+ "amphion/MaskGCT", filename="t2s_model/model.safetensors"
+ )
+ # download s2a model ckpt
+ s2a_1layer_ckpt = hf_hub_download(
+ "amphion/MaskGCT", filename="s2a_model/s2a_model_1layer/model.safetensors"
+ )
+ s2a_full_ckpt = hf_hub_download(
+ "amphion/MaskGCT", filename="s2a_model/s2a_model_full/model.safetensors"
+ )
+
+ # load semantic codec
+ safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
+ # load acoustic codec
+ safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
+ safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
+ # load t2s model
+ safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
+ # load s2a model
+ safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
+ safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)
+
+ # inference
+ prompt_wav_path = "./models/tts/maskgct/wav/prompt.wav"
+ save_path = "[YOUR SAVE PATH]"
+ prompt_text = " We do not break. We never give in. We never back down."
+ target_text = "In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision."
+ # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.
+ target_len = 18
+ maskgct_inference_pipeline = MaskGCT_Inference_Pipeline(
+ semantic_model,
+ semantic_codec,
+ codec_encoder,
+ codec_decoder,
+ t2s_model,
+ s2a_model_1layer,
+ s2a_model_full,
+ semantic_mean,
+ semantic_std,
+ device,
+ )
+
+ recovered_audio = maskgct_inference_pipeline.maskgct_inference(
+ prompt_wav_path, prompt_text, target_text, "en", "en", target_len=target_len
+ )
+
+ sf.write(save_path, recovered_audio, 24000)
diff --git a/models/tts/maskgct/maskgct_s2a.py b/models/tts/maskgct/maskgct_s2a.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad9fd09947e5638002e47e4f381157fa0dac7992
--- /dev/null
+++ b/models/tts/maskgct/maskgct_s2a.py
@@ -0,0 +1,503 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+import torch.nn as nn
+import math
+from einops import rearrange
+from models.tts.maskgct.llama_nar import DiffLlama
+
+
+def top_k(logits, thres=0.9):
+ k = math.ceil((1 - thres) * logits.shape[-1])
+ val, ind = logits.topk(k, dim=-1)
+ probs = torch.full_like(logits, float("-inf"))
+ probs.scatter_(2, ind, val)
+ return probs
+
+
+def log(t, eps=1e-10):
+ return torch.log(t + eps)
+
+
+def gumbel_noise(t):
+ noise = torch.zeros_like(t).uniform_(0, 1)
+ return -log(-log(noise))
+
+
+def gumbel_sample(t, temperature=1.0, dim=-1):
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
+
+
+def top_k(logits, thres=0.9):
+ k = math.ceil((1 - thres) * logits.shape[-1])
+ val, ind = logits.topk(k, dim=-1)
+ probs = torch.full_like(logits, float("-inf"))
+ probs.scatter_(2, ind, val)
+ return probs
+
+
+def log(t, eps=1e-10):
+ return torch.log(t + eps)
+
+
+def gumbel_noise(t):
+ noise = torch.zeros_like(t).uniform_(0, 1)
+ return -log(-log(noise))
+
+
+def gumbel_sample(t, temperature=1.0, dim=-1):
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
+
+
+class MaskGCT_S2A(nn.Module):
+ def __init__(
+ self,
+ num_quantizer=12,
+ hidden_size=1024,
+ num_layers=16,
+ num_heads=16,
+ codebook_size=1024,
+ cfg_scale=0.15,
+ mask_layer_schedule="linear",
+ cond_codebook_size=1024,
+ cond_dim=1024,
+ predict_layer_1=True,
+ cfg=None,
+ ):
+ super().__init__()
+
+ num_quantizer = (
+ cfg.num_quantizer
+ if cfg is not None and hasattr(cfg, "num_quantizer")
+ else num_quantizer
+ )
+ hidden_size = (
+ cfg.hidden_size
+ if cfg is not None and hasattr(cfg, "hidden_size")
+ else hidden_size
+ )
+ num_layers = (
+ cfg.num_layers
+ if cfg is not None and hasattr(cfg, "num_layers")
+ else num_layers
+ )
+ num_heads = (
+ cfg.num_heads
+ if cfg is not None and hasattr(cfg, "num_heads")
+ else num_heads
+ )
+ codebook_size = (
+ cfg.codebook_size
+ if cfg is not None and hasattr(cfg, "codebook_size")
+ else codebook_size
+ )
+ cfg_scale = (
+ cfg.cfg_scale
+ if cfg is not None and hasattr(cfg, "cfg_scale")
+ else cfg_scale
+ )
+ mask_layer_schedule = (
+ cfg.mask_layer_schedule
+ if cfg is not None and hasattr(cfg, "mask_layer_schedule")
+ else mask_layer_schedule
+ )
+ cond_codebook_size = (
+ cfg.cond_codebook_size
+ if cfg is not None and hasattr(cfg, "cond_codebook_size")
+ else cond_codebook_size
+ )
+ cond_dim = (
+ cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
+ )
+ predict_layer_1 = (
+ cfg.predict_layer_1
+ if cfg is not None and hasattr(cfg, "predict_layer_1")
+ else predict_layer_1
+ )
+
+ self.num_quantizer = num_quantizer
+ self.hidden_size = hidden_size
+ self.codebook_size = codebook_size
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.cfg_scale = cfg_scale
+ self.mask_layer_schedule = mask_layer_schedule
+ self.cond_codebook_size = cond_codebook_size
+ self.cond_dim = cond_dim
+ self.predict_layer_1 = predict_layer_1
+
+ self.layer_emb = nn.Embedding(self.num_quantizer, self.hidden_size)
+ self.mask_emb = nn.Embedding(1, self.hidden_size)
+
+ self.token_emb = torch.nn.ModuleList(
+ [
+ nn.Embedding(self.codebook_size, self.hidden_size)
+ for _ in range(self.num_quantizer)
+ ]
+ )
+
+ self.to_logits = torch.nn.ModuleList(
+ [
+ nn.Linear(self.hidden_size, self.codebook_size)
+ for _ in range(self.num_quantizer)
+ ]
+ )
+
+ self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
+
+ self.reset_parameters()
+
+ self.diff_estimator = DiffLlama(
+ hidden_size=hidden_size,
+ num_heads=self.num_heads,
+ num_layers=num_layers,
+ )
+
+ def mask_prob(self, t):
+ return torch.sin(t * np.pi / 2).to(t.device)
+
+ def mask_layer(self, t):
+ # print(self.predict_layer_1)
+ if self.mask_layer_schedule == "uniform":
+ if self.predict_layer_1:
+ mask_layer = torch.randint(0, self.num_quantizer, (1,)).to(t.device)
+ else:
+ mask_layer = torch.randint(1, self.num_quantizer, (1,)).to(t.device)
+ elif self.mask_layer_schedule == "cosine":
+ if self.predict_layer_1:
+ weights = torch.tensor(
+ [
+ np.cos(i / self.num_quantizer * np.pi / 2)
+ for i in range(self.num_quantizer)
+ ]
+ )
+ else:
+ weights = torch.tensor(
+ [0]
+ + [
+ np.cos((i - 1) / self.num_quantizer * np.pi / 2)
+ for i in range(1, self.num_quantizer)
+ ]
+ )
+ mask_layer = torch.multinomial(weights, 1).to(t.device)
+ elif self.mask_layer_schedule == "linear":
+ if self.predict_layer_1:
+ weights = torch.tensor(
+ [self.num_quantizer - i for i in range(self.num_quantizer)]
+ )
+ else:
+ weights = torch.tensor(
+ [0]
+ + [
+ self.num_quantizer - (i - 1)
+ for i in range(1, self.num_quantizer)
+ ]
+ )
+ weights = weights / weights.sum()
+ mask_layer = torch.multinomial(weights, 1).to(t.device)
+ # print(mask_layer)
+ new_t = t
+
+ return mask_layer, new_t
+
+ def forward_diffusion(self, x0, t):
+ # x0: (B, T, num_quantizer)
+ mask_layer, new_t = self.mask_layer(t) # (1,)
+ mask_prob = self.mask_prob(new_t) # (B,)
+ mask_token = self.mask_emb(torch.zeros_like(mask_layer)) # (1, hidden_size)
+
+ xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)
+
+ cfg_scale = self.cfg_scale
+
+ # get prompt len
+ if torch.rand(1) > cfg_scale:
+ prompt_len = torch.randint(
+ min(x0.shape[1] // 4, 5), x0.shape[1] // 2, (x0.shape[0],)
+ ).to(
+ x0.device
+ ) # (B,)
+ else:
+ prompt_len = torch.zeros(x0.shape[0]).to(x0) # (B,)
+
+ # get is prompt
+ is_prompt = torch.zeros_like(x0[:, :, 0]) # (B, T)
+ col_indices = (
+ torch.arange(is_prompt.shape[1])
+ .repeat(is_prompt.shape[0], 1)
+ .to(prompt_len)
+ ) # (B, T)
+ is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt
+
+ for idx, token_emb_idx in enumerate(self.token_emb):
+ if idx < mask_layer:
+ xt = xt + token_emb_idx(x0[:, :, idx]) # (B, T, hidden_size)
+
+ elif idx == mask_layer:
+ mask = torch.bernoulli(
+ torch.ones_like(x0[:, :, idx]) * mask_prob[..., None]
+ ) # mask if 1, not mask if 0
+ # prompt part don't need to be masked
+ mask[is_prompt.bool()] = 0
+ # Ensure at least one token is masked
+ mask_num = mask[:,].sum(dim=1, keepdim=False)
+ all_zero_mask = (mask_num == 0).bool()
+ row_indices_to_modify = torch.nonzero(all_zero_mask)
+ # mask the first token if all tokens are not masked (may mask pad if random indices)
+ mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1
+
+ mask = mask[..., None] # (B, T, 1)
+ xt = (
+ xt
+ + mask * mask_token[:, None, :]
+ + (1 - mask) * token_emb_idx(x0[:, :, idx])
+ ) # (B, T, hidden_size)
+
+ else:
+ # prompt part don't need to be masked
+ xt = (
+ xt
+ + token_emb_idx(x0[:, :, idx]) * is_prompt[..., None]
+ + mask_token * (1 - is_prompt[..., None])
+ )
+
+ return xt, new_t, mask_layer, mask, prompt_len, mask_prob
+
+ def loss_t(self, x0, x_mask, t, cond=None):
+ xt, new_t, mask_layer, mask, prompt_len, mask_prob = self.forward_diffusion(
+ x0, t
+ )
+ # xt: (B, T, hidden_size)
+ # new_t: (B,)
+ # mask_layer: (1,)
+ # mask: (B, T, 1) mask if 1, not mask if 0
+ # prompt_len: (B,)
+ # mask_prob: (B,)
+
+ mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(1) # (1, 1, hidden_size)
+ cond = cond + mask_layer_cond # (B, T, hidden_size)
+
+ embeds = self.diff_estimator(xt, new_t, cond, x_mask) # (B, T, hidden_size)
+
+ logits = self.to_logits[mask_layer.item()](embeds) # (B, T, codebook_size)
+
+ # final mask used for loss calculation
+ final_mask = mask * x_mask[..., None] # (B, T, 1)
+
+ return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
+
+ def compute_loss(self, x0, x_mask, cond=None):
+ # x0: (B, T, num_quantizer)
+ # x_mask: (B, T) mask is 0 for padding
+ t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
+ t = torch.clamp(t, 1e-5, 1.0)
+ return self.loss_t(x0, x_mask, t, cond)
+
+ def reset_parameters(self):
+ def _reset_parameters(m):
+ if isinstance(m, nn.MultiheadAttention):
+ if m._qkv_same_embed_dim:
+ nn.init.normal_(m.in_proj_weight, std=0.02)
+ else:
+ nn.init.normal_(m.q_proj_weight, std=0.02)
+ nn.init.normal_(m.k_proj_weight, std=0.02)
+ nn.init.normal_(m.v_proj_weight, std=0.02)
+
+ if m.in_proj_bias is not None:
+ nn.init.constant_(m.in_proj_bias, 0.0)
+ nn.init.constant_(m.out_proj.bias, 0.0)
+ if m.bias_k is not None:
+ nn.init.xavier_normal_(m.bias_k)
+ if m.bias_v is not None:
+ nn.init.xavier_normal_(m.bias_v)
+
+ elif (
+ isinstance(m, nn.Conv1d)
+ or isinstance(m, nn.ConvTranspose1d)
+ or isinstance(m, nn.Conv2d)
+ or isinstance(m, nn.ConvTranspose2d)
+ ):
+ m.weight.data.normal_(0.0, 0.02)
+
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ elif isinstance(m, nn.Embedding):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+
+ self.apply(_reset_parameters)
+
+ @torch.no_grad()
+ def reverse_diffusion(
+ self,
+ cond,
+ prompt,
+ x_mask=None,
+ prompt_mask=None,
+ temp=1.5,
+ filter_thres=0.98,
+ max_layer=None,
+ gt_code=None,
+ n_timesteps=[10, 4, 4, 4, 4, 4, 4, 4],
+ cfg=1.0,
+ rescale_cfg=1.0,
+ ):
+
+ assert (
+ len(n_timesteps) == self.num_quantizer
+ ) # each layer has a number of steps
+
+ prompt_code = prompt # (B, prompt_len, num_quantizer)
+ prompt_len = prompt_code.shape[1]
+ target_len = cond.shape[1] - prompt_len
+
+ if x_mask == None:
+ x_mask = torch.ones(cond.shape[0], target_len).to(cond.device) # (B, T)
+ if prompt_mask == None:
+ prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
+ cond.device
+ ) # (B, prompt_len)
+
+ cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
+ x_mask.device
+ ) # (B, T, hidden_size)
+
+ bsz, seq_len, _ = cum.shape
+
+ choice_temp = 1.0
+ start_temp = temp # temperature for sampling
+ start_choice_temp = choice_temp # temperature for choicing mask tokens
+
+ if max_layer is None:
+ max_layer = self.num_quantizer
+
+ xt = torch.LongTensor(bsz, seq_len, max_layer).to(x_mask.device)
+
+ if gt_code is not None:
+ gt_layer = gt_code.shape[-1]
+ xt[:, :, :gt_layer] = gt_code
+ for i in range(gt_layer):
+ cum += self.token_emb[i](xt[:, :, i])
+ else:
+ gt_layer = 0
+
+ for mask_layer in range(gt_layer, max_layer):
+ steps = n_timesteps[mask_layer]
+ to_logits = self.to_logits[mask_layer]
+ token_emb = self.token_emb[mask_layer]
+ mask_layer = torch.tensor(mask_layer).to(x_mask.device).long().unsqueeze(0)
+ mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(
+ 1
+ ) # (1,) -> (1, 1, hidden_size)
+ temp_cond = cond + mask_layer_cond # (B, T, hidden_size)
+
+ mask_token = self.mask_emb(torch.zeros_like(mask_layer)) # (1, hidden_size)
+ mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device) # (B, T, 1)
+ seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
+
+ h = 1.0 / steps
+
+ # prompt_code: (B, prompt_len, num_quantizer)
+ cur_prompt = 0
+ for idx, emb in enumerate(self.token_emb):
+ cur_prompt = cur_prompt + emb(
+ prompt_code[:, :, idx]
+ ) # (B, prompt_len, hidden_size)
+
+ t_list = [1.0 - i * h for i in range(steps)]
+ t_list.append(0.0)
+ for i in range(steps):
+ t = t_list[i] * torch.ones(bsz).to(x_mask.device)
+ token = token_emb(seq) # (B, T, hidden_size)
+ cur = cum + mask * mask_token[:, None, :] + (~mask) * token
+ cur = cur + mask_token[:, None, :] * (max_layer - 1 - mask_layer)
+
+ xt_input = torch.cat([cur_prompt, cur], dim=1) # (B, T, hidden_size)
+ xt_mask = torch.cat(
+ [prompt_mask, x_mask], dim=1
+ ) # (B, T), mask is 0 for padding
+
+ embeds = self.diff_estimator(xt_input, t, temp_cond, xt_mask)
+ embeds = embeds[:, prompt_len:, :]
+
+ # cfg
+ if cfg > 0:
+ mask_embeds = self.diff_estimator(
+ cur, t, temp_cond[:, prompt_len:, :], x_mask
+ )
+ pos_emb_std = embeds.std() # std(g_cond)
+ embeds = embeds + cfg * (embeds - mask_embeds) # g_cfg
+ rescale_embeds = embeds * pos_emb_std / embeds.std() # g_final
+ embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds
+
+ logits = to_logits(embeds) # (B, T, codebook_size)
+ annealing_scale = t_list[i]
+
+ choice_temp = start_choice_temp * annealing_scale
+ temp = start_temp * annealing_scale
+ logits = top_k(logits, filter_thres)
+
+ if i == steps - 1:
+ # greedy
+ if steps == 1:
+ temp = 0.2
+ sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
+ else:
+ sampled_ids = logits.argmax(dim=-1)
+
+ else:
+ # sampling
+ sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
+
+ seq = torch.where(mask.squeeze(-1), sampled_ids, seq)
+
+ scores = logits.softmax(dim=-1)
+ scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
+ scores = rearrange(scores, "b n 1 -> b n")
+
+ scores = choice_temp * gumbel_noise(scores) + scores
+ scores = 1 - scores
+
+ next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)
+
+ next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()
+
+ if next_mask_num == 0:
+ break
+ scores = scores.masked_fill(
+ ~mask.squeeze(-1), -torch.finfo(scores.dtype).max
+ )
+
+ mask_indices = scores.topk(next_mask_num, dim=-1).indices
+ mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
+ 1, mask_indices, True
+ )
+ seq = seq.masked_fill(mask, 0)
+
+ mask = mask.unsqueeze(-1)
+
+ cum = cum + token_emb(seq)
+ xt[..., mask_layer.squeeze(0).item()] = seq
+
+ return xt
+
+ def forward(self, x0, x_mask, cond_code=None):
+ # x0: (B, T, num_quantizer)
+ # x_mask: (B, T) mask is 0 for padding
+ # cond_code: semantic token (B, T)
+ cond = self.cond_emb(cond_code)
+
+ logits, mask_layer, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
+ x0,
+ x_mask,
+ cond,
+ )
+ return logits, mask_layer, final_mask, x0, prompt_len, mask_prob
diff --git a/models/tts/maskgct/maskgct_t2s.py b/models/tts/maskgct/maskgct_t2s.py
new file mode 100644
index 0000000000000000000000000000000000000000..8088531b53fba59d47267aae39fbd4576e4338e4
--- /dev/null
+++ b/models/tts/maskgct/maskgct_t2s.py
@@ -0,0 +1,364 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+import torch.nn as nn
+import math
+from einops import rearrange
+from models.tts.maskgct.llama_nar import DiffLlamaPrefix
+
+
+def top_k(logits, thres=0.9):
+ k = math.ceil((1 - thres) * logits.shape[-1])
+ val, ind = logits.topk(k, dim=-1)
+ probs = torch.full_like(logits, float("-inf"))
+ probs.scatter_(2, ind, val)
+ return probs
+
+
+def log(t, eps=1e-10):
+ return torch.log(t + eps)
+
+
+def gumbel_noise(t):
+ noise = torch.zeros_like(t).uniform_(0, 1)
+ return -log(-log(noise))
+
+
+def gumbel_sample(t, temperature=1.0, dim=-1):
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
+
+
+class MaskGCT_T2S(nn.Module):
+ def __init__(
+ self,
+ hidden_size=1024,
+ num_layers=16,
+ num_heads=16,
+ cfg_scale=0.2,
+ cond_codebook_size=8192,
+ cond_dim=1024,
+ cfg=None,
+ ):
+ super().__init__()
+
+ hidden_size = (
+ cfg.hidden_size
+ if cfg is not None and hasattr(cfg, "hidden_size")
+ else hidden_size
+ )
+ num_layers = (
+ cfg.num_layers
+ if cfg is not None and hasattr(cfg, "num_layers")
+ else num_layers
+ )
+ num_heads = (
+ cfg.num_heads
+ if cfg is not None and hasattr(cfg, "num_heads")
+ else num_heads
+ )
+ cfg_scale = (
+ cfg.cfg_scale
+ if cfg is not None and hasattr(cfg, "cfg_scale")
+ else cfg_scale
+ )
+ cond_codebook_size = (
+ cfg.cond_codebook_size
+ if cfg is not None and hasattr(cfg, "cond_codebook_size")
+ else cond_codebook_size
+ )
+ cond_dim = (
+ cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
+ )
+
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.cfg_scale = cfg_scale
+ self.cond_codebook_size = cond_codebook_size
+ self.cond_dim = cond_dim
+
+ self.mask_emb = nn.Embedding(1, self.hidden_size)
+
+ self.to_logit = nn.Linear(self.hidden_size, self.cond_codebook_size)
+
+ self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
+
+ self.phone_emb = nn.Embedding(1024, hidden_size, padding_idx=1023)
+
+ self.reset_parameters()
+
+ self.diff_estimator = DiffLlamaPrefix(
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ num_layers=num_layers,
+ )
+
+ def mask_prob(self, t):
+ return torch.sin(t * np.pi / 2).to(t.device)
+
+ def forward_diffusion(self, x0, t):
+ # x0: semantic tokens (B, T)
+ new_t = t
+ mask_prob = self.mask_prob(new_t) # (B,)
+ # if mask_prob[i] < 0.2, mask_prob[i] = 0.2
+ mask_prob = torch.where(
+ mask_prob < 0.2, torch.ones_like(mask_prob) * 0.2, mask_prob
+ )
+ mask_token = self.mask_emb(
+ torch.LongTensor([0]).to(x0.device)
+ ) # (1, hidden_size)
+
+ xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)
+
+ cfg_scale = self.cfg_scale
+
+ # a segment of r% sequence length is masked, where r ~ U[60, 100]
+ if torch.rand(1) > cfg_scale:
+ prompt_len = torch.randint(
+ min(x0.shape[1] // 4, 5), int(x0.shape[1] * 0.4), (x0.shape[0],)
+ ).to(
+ x0.device
+ ) # (B,)
+ else:
+ prompt_len = torch.zeros(x0.shape[0]).to(x0) # (B,)
+
+ # get is prompt
+ is_prompt = torch.zeros_like(x0[:, :]) # (B, T)
+ col_indices = (
+ torch.arange(is_prompt.shape[1])
+ .repeat(is_prompt.shape[0], 1)
+ .to(prompt_len)
+ ) # (B, T)
+ is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt
+
+ # Add mask
+ mask = torch.bernoulli(torch.ones_like(x0[:, :]) * mask_prob[..., None])
+ mask[is_prompt.bool()] = 0
+ mask_num = mask[:,].sum(dim=1, keepdim=False)
+ all_zero_mask = (mask_num == 0).bool()
+ row_indices_to_modify = torch.nonzero(all_zero_mask)
+ mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1
+ mask = mask[..., None] # (B, T, 1)
+ xt = (
+ xt + mask * mask_token[:, None, :] + (1 - mask) * self.cond_emb(x0[:, :])
+ ) # (B, T, hidden_size)
+
+ return xt, new_t, mask, prompt_len, mask_prob
+
+ def loss_t(self, x0, x_mask, t, phone_embedding=None, phone_mask=None):
+ xt, new_t, mask, prompt_len, mask_prob = self.forward_diffusion(x0, t)
+ # xt: (B, T, hidden_size)
+ # new_t: (B,)
+ # mask: (B, T, 1) mask if 1, not mask if 0
+ # prompt_len: (B,)
+ # mask_prob: (B,)
+
+ embeds = self.diff_estimator(
+ xt, new_t, x_mask, phone_embedding=phone_embedding, phone_mask=phone_mask
+ ) # (B, T, hidden_size)
+ logits = self.to_logit(embeds) # (B, T, codebook_size)
+
+ # final mask used for loss calculation
+ final_mask = mask * x_mask[..., None] # (B, T, 1)
+
+ return logits, final_mask, x0, prompt_len, mask_prob
+
+ def compute_loss(self, x0, x_mask, phone_embedding=None, phone_mask=None):
+ # x0: (B, T)
+ # x_mask: (B, T) mask is 0 for padding
+ t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
+ t = torch.clamp(t, 1e-5, 1.0)
+ return self.loss_t(x0, x_mask, t, phone_embedding, phone_mask)
+
+ def reset_parameters(self):
+ def _reset_parameters(m):
+ if isinstance(m, nn.MultiheadAttention):
+ if m._qkv_same_embed_dim:
+ nn.init.normal_(m.in_proj_weight, std=0.02)
+ else:
+ nn.init.normal_(m.q_proj_weight, std=0.02)
+ nn.init.normal_(m.k_proj_weight, std=0.02)
+ nn.init.normal_(m.v_proj_weight, std=0.02)
+
+ if m.in_proj_bias is not None:
+ nn.init.constant_(m.in_proj_bias, 0.0)
+ nn.init.constant_(m.out_proj.bias, 0.0)
+ if m.bias_k is not None:
+ nn.init.xavier_normal_(m.bias_k)
+ if m.bias_v is not None:
+ nn.init.xavier_normal_(m.bias_v)
+
+ elif (
+ isinstance(m, nn.Conv1d)
+ or isinstance(m, nn.ConvTranspose1d)
+ or isinstance(m, nn.Conv2d)
+ or isinstance(m, nn.ConvTranspose2d)
+ ):
+ m.weight.data.normal_(0.0, 0.02)
+
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ elif isinstance(m, nn.Embedding):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+
+ self.apply(_reset_parameters)
+
+ @torch.no_grad()
+ def reverse_diffusion(
+ self,
+ prompt,
+ target_len,
+ phone_id,
+ prompt_mask=None,
+ temp=0.9,
+ filter_thres=0.98,
+ n_timesteps=40,
+ cfg=1.0,
+ rescale_cfg=1.0,
+ ):
+ # prompt: (B, T)
+ phone_embedding = self.phone_emb(phone_id)
+
+ prompt_code = prompt # (B, prompt_len)
+ prompt_len = prompt_code.shape[1]
+
+ x_mask = torch.ones(prompt_code.shape[0], target_len).to(
+ prompt_code.device
+ ) # (B, target_len)
+ phone_mask = torch.ones_like(phone_id)
+
+ if prompt_mask == None:
+ prompt_mask = torch.ones(prompt_code.shape[0], prompt_len).to(
+ prompt_code.device
+ ) # (B, prompt_len)
+
+ cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
+ x_mask.device
+ ) # (B, T, hidden_size)
+
+ bsz, seq_len, _ = cum.shape
+
+ choice_temp = 1.0
+ start_temp = temp # temperature for sampling
+ start_choice_temp = choice_temp # temperature for choicing mask tokens
+
+ xt = torch.LongTensor(bsz, seq_len).to(x_mask.device)
+
+ steps = n_timesteps
+ to_logit = self.to_logit
+ cond_emb = self.cond_emb
+
+ mask_token = self.mask_emb(torch.LongTensor([0]).to(xt.device))
+ mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device) # (B, T, 1)
+ seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
+ h = 1.0 / steps
+
+ cur_prompt = 0
+ cur_prompt = cur_prompt + cond_emb(prompt_code)
+
+ t_list = [1.0 - i * h for i in range(steps)]
+ t_list.append(0.0)
+ for i in range(steps):
+ t = t_list[i] * torch.ones(bsz).to(x_mask.device)
+ token = cond_emb(seq) # (B, T, hidden_size)
+ cur = cum + mask * mask_token[:, None, :] + (~mask) * token
+
+ xt_input = torch.cat([cur_prompt, cur], dim=1) # (B, T, hidden_size)
+ xt_mask = torch.cat(
+ [prompt_mask, x_mask], dim=1
+ ) # (B, T), mask is 0 for padding
+
+ embeds = self.diff_estimator(
+ xt_input,
+ t,
+ xt_mask,
+ phone_embedding=phone_embedding,
+ phone_mask=phone_mask,
+ )
+ embeds = embeds[:, prompt_len:, :]
+
+ # classifier free guidance
+ # phone_embedding=phone_embedding[:,phone_embedding.shape[1]:,:] means phone_embedding is None
+ if cfg > 0:
+ mask_embeds = self.diff_estimator(
+ cur,
+ t,
+ x_mask,
+ phone_embedding=phone_embedding[:, phone_embedding.shape[1] :, :],
+ phone_mask=phone_mask[:, prompt_len:],
+ )
+ pos_emb_std = embeds.std() # std(g_cond)
+ embeds = embeds + cfg * (embeds - mask_embeds) # g_cfg
+ rescale_embeds = embeds * pos_emb_std / embeds.std() # g_final
+ embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds
+
+ logits = to_logit(embeds) # (B, T, codebook_size)
+ annealing_scale = t_list[i]
+
+ choice_temp = start_choice_temp * annealing_scale
+ temp = start_temp * annealing_scale
+ logits = top_k(logits, filter_thres)
+
+ if i == steps - 1:
+ # greedy
+ if steps == 1:
+ temp = 0.2
+ sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
+ else:
+ sampled_ids = logits.argmax(dim=-1)
+
+ else:
+ # sampling
+ sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
+
+ seq = torch.where(mask.squeeze(-1), sampled_ids, seq)
+
+ scores = logits.softmax(dim=-1)
+ scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
+ scores = rearrange(scores, "b n 1 -> b n")
+
+ scores = choice_temp * gumbel_noise(scores) + scores
+ scores = 1 - scores
+
+ next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)
+
+ next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()
+
+ if next_mask_num == 0:
+ break
+ scores = scores.masked_fill(
+ ~mask.squeeze(-1), -torch.finfo(scores.dtype).max
+ )
+
+ mask_indices = scores.topk(next_mask_num, dim=-1).indices
+ mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
+ 1, mask_indices, True
+ )
+ seq = seq.masked_fill(mask, 0)
+
+ mask = mask.unsqueeze(-1)
+
+ cum = cum + cond_emb(seq)
+ xt = seq
+
+ return xt
+
+ def forward(self, x0, x_mask, phone_id=None, phone_mask=None):
+ # x0: (B, T)
+ # x_mask: (B, T) mask is 0 for padding
+
+ phone_embedding = self.phone_emb(phone_id)
+
+ logits, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
+ x0, x_mask, phone_embedding, phone_mask=phone_mask
+ )
+ return logits, final_mask, x0, prompt_len, mask_prob
diff --git a/models/tts/maskgct/maskgct_utils.py b/models/tts/maskgct/maskgct_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35217c4cbe9b6fc4f8835f347ed6f5bbdf5264ce
--- /dev/null
+++ b/models/tts/maskgct/maskgct_utils.py
@@ -0,0 +1,283 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import librosa
+import os
+import pickle
+import math
+import json
+import accelerate
+import safetensors
+from utils.util import load_config
+from tqdm import tqdm
+
+from models.codec.kmeans.repcodec_model import RepCodec
+from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
+from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S
+from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
+from transformers import Wav2Vec2BertModel
+
+from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
+
+from transformers import SeamlessM4TFeatureExtractor
+
+
+def g2p_(text, language):
+ if language in ["zh", "en"]:
+ return chn_eng_g2p(text)
+ else:
+ return g2p(text, sentence=None, language=language)
+
+
+def build_t2s_model(cfg, device):
+ t2s_model = MaskGCT_T2S(cfg=cfg)
+ t2s_model.eval()
+ t2s_model.to(device)
+ return t2s_model
+
+
+def build_s2a_model(cfg, device):
+ soundstorm_model = MaskGCT_S2A(cfg=cfg)
+ soundstorm_model.eval()
+ soundstorm_model.to(device)
+ return soundstorm_model
+
+
+def build_semantic_model(device):
+ semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+ semantic_model.eval()
+ semantic_model.to(device)
+ stat_mean_var = torch.load("./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt")
+ semantic_mean = stat_mean_var["mean"]
+ semantic_std = torch.sqrt(stat_mean_var["var"])
+ semantic_mean = semantic_mean.to(device)
+ semantic_std = semantic_std.to(device)
+ return semantic_model, semantic_mean, semantic_std
+
+
+def build_semantic_codec(cfg, device):
+ semantic_codec = RepCodec(cfg=cfg)
+ semantic_codec.eval()
+ semantic_codec.to(device)
+ return semantic_codec
+
+
+def build_acoustic_codec(cfg, device):
+ codec_encoder = CodecEncoder(cfg=cfg.encoder)
+ codec_decoder = CodecDecoder(cfg=cfg.decoder)
+ codec_encoder.eval()
+ codec_decoder.eval()
+ codec_encoder.to(device)
+ codec_decoder.to(device)
+ return codec_encoder, codec_decoder
+
+
+class MaskGCT_Inference_Pipeline:
+ def __init__(
+ self,
+ semantic_model,
+ semantic_codec,
+ codec_encoder,
+ codec_decoder,
+ t2s_model,
+ s2a_model_1layer,
+ s2a_model_full,
+ semantic_mean,
+ semantic_std,
+ device,
+ ):
+ self.processor = SeamlessM4TFeatureExtractor.from_pretrained(
+ "facebook/w2v-bert-2.0"
+ )
+ self.semantic_model = semantic_model
+ self.semantic_codec = semantic_codec
+ self.codec_encoder = codec_encoder
+ self.codec_decoder = codec_decoder
+ self.t2s_model = t2s_model
+ self.s2a_model_1layer = s2a_model_1layer
+ self.s2a_model_full = s2a_model_full
+ self.semantic_mean = semantic_mean
+ self.semantic_std = semantic_std
+ self.device = device
+
+ @torch.no_grad()
+ def extract_features(self, speech):
+ inputs = self.processor(speech, sampling_rate=16000, return_tensors="pt")
+ input_features = inputs["input_features"][0]
+ attention_mask = inputs["attention_mask"][0]
+ return input_features, attention_mask
+
+ @torch.no_grad()
+ def extract_semantic_code(self, input_features, attention_mask):
+ vq_emb = self.semantic_model(
+ input_features=input_features,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ feat = vq_emb.hidden_states[17] # (B, T, C)
+ feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat)
+
+ semantic_code, rec_feat = self.semantic_codec.quantize(feat) # (B, T)
+ return semantic_code, rec_feat
+
+ @torch.no_grad()
+ def extract_acoustic_code(self, speech):
+ vq_emb = self.codec_encoder(speech.unsqueeze(1))
+ _, vq, _, _, _ = self.codec_decoder.quantizer(vq_emb)
+ acoustic_code = vq.permute(1, 2, 0)
+ return acoustic_code
+
+ @torch.no_grad()
+ def text2semantic(
+ self,
+ prompt_speech,
+ prompt_text,
+ prompt_language,
+ target_text,
+ target_language,
+ target_len=None,
+ n_timesteps=50,
+ cfg=2.5,
+ rescale_cfg=0.75,
+ ):
+ prompt_phone_id = g2p_(prompt_text, prompt_language)[1]
+ target_phone_id = g2p_(target_text, target_language)[1]
+
+ if target_len is None:
+ target_len = int(
+ (len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id))
+ / 16000
+ * 50
+ )
+ else:
+ target_len = int(target_len * 50)
+
+ prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(
+ self.device
+ )
+ target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(
+ self.device
+ )
+
+ phone_id = torch.cat([prompt_phone_id, target_phone_id])
+
+ input_features, attention_mask = self.extract_features(prompt_speech)
+ input_features = input_features.unsqueeze(0).to(self.device)
+ attention_mask = attention_mask.unsqueeze(0).to(self.device)
+ semantic_code, _ = self.extract_semantic_code(input_features, attention_mask)
+
+ predict_semantic = self.t2s_model.reverse_diffusion(
+ semantic_code[:, :],
+ target_len,
+ phone_id.unsqueeze(0),
+ n_timesteps=n_timesteps,
+ cfg=cfg,
+ rescale_cfg=rescale_cfg,
+ )
+
+ print("predict semantic shape", predict_semantic.shape)
+
+ combine_semantic_code = torch.cat(
+ [semantic_code[:, :], predict_semantic], dim=-1
+ )
+ prompt_semantic_code = semantic_code
+
+ return combine_semantic_code, prompt_semantic_code
+
+ @torch.no_grad()
+ def semantic2acoustic(
+ self,
+ combine_semantic_code,
+ acoustic_code,
+ n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ cfg=2.5,
+ rescale_cfg=0.75,
+ ):
+ semantic_code = combine_semantic_code
+
+ cond = self.s2a_model_1layer.cond_emb(semantic_code)
+ prompt = acoustic_code[:, :, :]
+ predict_1layer = self.s2a_model_1layer.reverse_diffusion(
+ cond=cond,
+ prompt=prompt,
+ temp=1.5,
+ filter_thres=0.98,
+ n_timesteps=n_timesteps[:1],
+ cfg=cfg,
+ rescale_cfg=rescale_cfg,
+ )
+
+ cond = self.s2a_model_full.cond_emb(semantic_code)
+ prompt = acoustic_code[:, :, :]
+ predict_full = self.s2a_model_full.reverse_diffusion(
+ cond=cond,
+ prompt=prompt,
+ temp=1.5,
+ filter_thres=0.98,
+ n_timesteps=n_timesteps,
+ cfg=cfg,
+ rescale_cfg=rescale_cfg,
+ gt_code=predict_1layer,
+ )
+
+ vq_emb = self.codec_decoder.vq2emb(
+ predict_full.permute(2, 0, 1), n_quantizers=12
+ )
+ recovered_audio = self.codec_decoder(vq_emb)
+ prompt_vq_emb = self.codec_decoder.vq2emb(
+ prompt.permute(2, 0, 1), n_quantizers=12
+ )
+ recovered_prompt_audio = self.codec_decoder(prompt_vq_emb)
+ recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
+ recovered_audio = recovered_audio[0][0].cpu().numpy()
+ combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
+
+ return combine_audio, recovered_audio
+
+ def maskgct_inference(
+ self,
+ prompt_speech_path,
+ prompt_text,
+ target_text,
+ language="en",
+ target_language="en",
+ target_len=None,
+ n_timesteps=25,
+ cfg=2.5,
+ rescale_cfg=0.75,
+ n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ cfg_s2a=2.5,
+ rescale_cfg_s2a=0.75,
+ ):
+ speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
+ speech = librosa.load(prompt_speech_path, sr=24000)[0]
+
+ combine_semantic_code, _ = self.text2semantic(
+ speech_16k,
+ prompt_text,
+ language,
+ target_text,
+ target_language,
+ target_len,
+ n_timesteps,
+ cfg,
+ rescale_cfg,
+ )
+ acoustic_code = self.extract_acoustic_code(
+ torch.tensor(speech).unsqueeze(0).to(self.device)
+ )
+ _, recovered_audio = self.semantic2acoustic(
+ combine_semantic_code,
+ acoustic_code,
+ n_timesteps=n_timesteps_s2a,
+ cfg=cfg_s2a,
+ rescale_cfg=rescale_cfg_s2a,
+ )
+
+ return recovered_audio
diff --git a/models/tts/maskgct/wav/prompt.wav b/models/tts/maskgct/wav/prompt.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f3d41ea9b87c7a60f4f3aa4ec32f38671ec177eb
--- /dev/null
+++ b/models/tts/maskgct/wav/prompt.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b083e2553dc932eeb4f877237ea3d3c5cdb408ee71da629470e2326d2165bb9
+size 1111542
diff --git a/models/tts/naturalspeech2/__init__.py b/models/tts/naturalspeech2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tts/naturalspeech2/diffusion.py b/models/tts/naturalspeech2/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74c20ca7e69170cee189bd3942808ba9b1b743c
--- /dev/null
+++ b/models/tts/naturalspeech2/diffusion.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from models.tts.naturalspeech2.wavenet import WaveNet
+
+
+class Diffusion(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+
+ self.diff_estimator = WaveNet(cfg.wavenet)
+ self.beta_min = cfg.beta_min
+ self.beta_max = cfg.beta_max
+ self.sigma = cfg.sigma
+ self.noise_factor = cfg.noise_factor
+
+ def forward(self, x, x_mask, cond, spk_query_emb, offset=1e-5):
+ """
+ x: (B, 128, T)
+ x_mask: (B, T), mask is 0
+ cond: (B, T, 512)
+ spk_query_emb: (B, 32, 512)
+ """
+ diffusion_step = torch.rand(
+ x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
+ )
+ diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)
+ xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)
+
+ cum_beta = self.get_cum_beta(diffusion_step.unsqueeze(-1).unsqueeze(-1))
+ x0_pred = self.diff_estimator(xt, x_mask, cond, diffusion_step, spk_query_emb)
+ mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+ variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
+ noise_pred = (xt - mean_pred) / (torch.sqrt(variance) * self.noise_factor)
+ noise = z
+ diff_out = {"x0_pred": x0_pred, "noise_pred": noise_pred, "noise": noise}
+ return diff_out
+
+ @torch.no_grad()
+ def get_cum_beta(self, time_step):
+ return self.beta_min * time_step + 0.5 * (self.beta_max - self.beta_min) * (
+ time_step**2
+ )
+
+ @torch.no_grad()
+ def get_beta_t(self, time_step):
+ return self.beta_min + (self.beta_max - self.beta_min) * time_step
+
+ @torch.no_grad()
+ def forward_diffusion(self, x0, diffusion_step):
+ """
+ x0: (B, 128, T)
+ time_step: (B,)
+ """
+ time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+ cum_beta = self.get_cum_beta(time_step)
+ mean = x0 * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+ variance = (self.sigma**2) * (1 - torch.exp(-cum_beta / (self.sigma**2)))
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
+ xt = mean + z * torch.sqrt(variance) * self.noise_factor
+ return xt, z
+
+ @torch.no_grad()
+ def cal_dxt(self, xt, x_mask, cond, spk_query_emb, diffusion_step, h):
+ time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+ cum_beta = self.get_cum_beta(time_step=time_step)
+ beta_t = self.get_beta_t(time_step=time_step)
+ x0_pred = self.diff_estimator(xt, x_mask, cond, diffusion_step, spk_query_emb)
+ mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
+ noise_pred = xt - mean_pred
+ variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
+ logp = -noise_pred / (variance + 1e-8)
+ dxt = -0.5 * h * beta_t * (logp + xt / (self.sigma**2))
+ return dxt
+
+ @torch.no_grad()
+ def reverse_diffusion(self, z, x_mask, cond, n_timesteps, spk_query_emb):
+ h = 1.0 / max(n_timesteps, 1)
+ xt = z
+ for i in range(n_timesteps):
+ t = (1.0 - (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ dxt = self.cal_dxt(xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h)
+ xt_ = xt - dxt
+ if self.cfg.ode_solver == "midpoint":
+ x_mid = 0.5 * (xt_ + xt)
+ dxt = self.cal_dxt(
+ x_mid, x_mask, cond, spk_query_emb, diffusion_step=t + 0.5 * h, h=h
+ )
+ xt = xt - dxt
+ elif self.cfg.ode_solver == "euler":
+ xt = xt_
+ return xt
+
+ @torch.no_grad()
+ def reverse_diffusion_from_t(
+ self, z, x_mask, cond, n_timesteps, spk_query_emb, t_start
+ ):
+ h = t_start / max(n_timesteps, 1)
+ xt = z
+ for i in range(n_timesteps):
+ t = (t_start - (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ dxt = self.cal_dxt(xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h)
+ xt_ = xt - dxt
+ if self.cfg.ode_solver == "midpoint":
+ x_mid = 0.5 * (xt_ + xt)
+ dxt = self.cal_dxt(
+ x_mid, x_mask, cond, spk_query_emb, diffusion_step=t + 0.5 * h, h=h
+ )
+ xt = xt - dxt
+ elif self.cfg.ode_solver == "euler":
+ xt = xt_
+ return xt
diff --git a/models/tts/naturalspeech2/diffusion_flow.py b/models/tts/naturalspeech2/diffusion_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..505fecc28d0602b3104a9befbbfd3df0796b8f94
--- /dev/null
+++ b/models/tts/naturalspeech2/diffusion_flow.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from models.tts.naturalspeech2.wavenet import WaveNet
+
+
+class DiffusionFlow(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.diff_estimator = WaveNet(cfg.wavenet)
+ self.beta_min = cfg.beta_min
+ self.beta_max = cfg.beta_max
+ self.sigma = cfg.sigma
+ self.noise_factor = cfg.noise_factor
+
+ def forward(self, x, x_mask, cond, spk_query_emb, offset=1e-5):
+ """
+ x: (B, 128, T)
+ x_mask: (B, T), mask is 0
+ cond: (B, T, 512)
+ spk_query_emb: (B, 32, 512)
+ """
+ diffusion_step = torch.rand(
+ x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
+ )
+ diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)
+ xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)
+
+ flow_pred = self.diff_estimator(
+ xt, x_mask, cond, diffusion_step, spk_query_emb
+ ) # noise - x0_pred, noise_pred - x0
+ noise = z
+ x0_pred = noise - flow_pred
+ noise_pred = x + flow_pred
+ diff_out = {
+ "x0_pred": x0_pred,
+ "noise_pred": noise_pred,
+ "noise": noise,
+ "flow_pred": flow_pred,
+ }
+ return diff_out
+
+ @torch.no_grad()
+ def forward_diffusion(self, x0, diffusion_step):
+ """
+ x0: (B, 128, T)
+ time_step: (B,)
+ """
+ time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
+ z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
+ xt = (1 - time_step) * x0 + time_step * z
+ return xt, z
+
+ @torch.no_grad()
+ def cal_dxt(self, xt, x_mask, cond, spk_query_emb, diffusion_step, h):
+ flow_pred = self.diff_estimator(
+ xt, x_mask, cond, diffusion_step, spk_query_emb
+ ) # z - x0 = x1 - x0
+ dxt = h * flow_pred
+ return dxt
+
+ @torch.no_grad()
+ def reverse_diffusion(self, z, x_mask, cond, n_timesteps, spk_query_emb):
+ h = 1.0 / n_timesteps
+ xt = z
+ for i in range(n_timesteps):
+ t = (1.0 - (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ dxt = self.cal_dxt(xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h)
+ xt = xt - dxt
+ return xt
diff --git a/models/tts/naturalspeech2/ns2.py b/models/tts/naturalspeech2/ns2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9531644625ac031024bc614b5e66f155360f3baf
--- /dev/null
+++ b/models/tts/naturalspeech2/ns2.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from models.tts.naturalspeech2.diffusion import Diffusion
+from models.tts.naturalspeech2.diffusion_flow import DiffusionFlow
+from models.tts.naturalspeech2.wavenet import WaveNet
+from models.tts.naturalspeech2.prior_encoder import PriorEncoder
+from modules.naturalpseech2.transformers import TransformerEncoder
+from encodec import EncodecModel
+from einops import rearrange, repeat
+
+import os
+import json
+
+
+class NaturalSpeech2(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ self.latent_dim = cfg.latent_dim
+ self.query_emb_num = cfg.query_emb.query_token_num
+
+ self.prior_encoder = PriorEncoder(cfg.prior_encoder)
+ if cfg.diffusion.diffusion_type == "diffusion":
+ self.diffusion = Diffusion(cfg.diffusion)
+ elif cfg.diffusion.diffusion_type == "flow":
+ self.diffusion = DiffusionFlow(cfg.diffusion)
+
+ self.prompt_encoder = TransformerEncoder(cfg=cfg.prompt_encoder)
+ if self.latent_dim != cfg.prompt_encoder.encoder_hidden:
+ self.prompt_lin = nn.Linear(
+ self.latent_dim, cfg.prompt_encoder.encoder_hidden
+ )
+ self.prompt_lin.weight.data.normal_(0.0, 0.02)
+ else:
+ self.prompt_lin = None
+
+ self.query_emb = nn.Embedding(self.query_emb_num, cfg.query_emb.hidden_size)
+ self.query_attn = nn.MultiheadAttention(
+ cfg.query_emb.hidden_size, cfg.query_emb.head_num, batch_first=True
+ )
+
+ codec_model = EncodecModel.encodec_model_24khz()
+ codec_model.set_target_bandwidth(12.0)
+ codec_model.requires_grad_(False)
+ self.quantizer = codec_model.quantizer
+
+ @torch.no_grad()
+ def code_to_latent(self, code):
+ latent = self.quantizer.decode(code.transpose(0, 1))
+ return latent
+
+ def latent_to_code(self, latent, nq=16):
+ residual = latent
+ all_indices = []
+ all_dist = []
+ for i in range(nq):
+ layer = self.quantizer.vq.layers[i]
+ x = rearrange(residual, "b d n -> b n d")
+ x = layer.project_in(x)
+ shape = x.shape
+ x = layer._codebook.preprocess(x)
+ embed = layer._codebook.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ indices = dist.max(dim=-1).indices
+ indices = layer._codebook.postprocess_emb(indices, shape)
+ dist = dist.reshape(*shape[:-1], dist.shape[-1])
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ all_dist.append(dist)
+
+ out_indices = torch.stack(all_indices)
+ out_dist = torch.stack(all_dist)
+
+ return out_indices, out_dist # (nq, B, T); (nq, B, T, 1024)
+
+ @torch.no_grad()
+ def latent_to_latent(self, latent, nq=16):
+ codes, _ = self.latent_to_code(latent, nq)
+ latent = self.quantizer.vq.decode(codes)
+ return latent
+
+ def forward(
+ self,
+ code=None,
+ pitch=None,
+ duration=None,
+ phone_id=None,
+ phone_id_frame=None,
+ frame_nums=None,
+ ref_code=None,
+ ref_frame_nums=None,
+ phone_mask=None,
+ mask=None,
+ ref_mask=None,
+ ):
+ ref_latent = self.code_to_latent(ref_code)
+ latent = self.code_to_latent(code)
+
+ if self.latent_dim is not None:
+ ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
+
+ ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
+ spk_emb = ref_latent.transpose(1, 2) # (B, d, T')
+
+ spk_query_emb = self.query_emb(
+ torch.arange(self.query_emb_num).to(latent.device)
+ ).repeat(
+ latent.shape[0], 1, 1
+ ) # (B, query_emb_num, d)
+ spk_query_emb, _ = self.query_attn(
+ spk_query_emb,
+ spk_emb.transpose(1, 2),
+ spk_emb.transpose(1, 2),
+ key_padding_mask=~(ref_mask.bool()),
+ ) # (B, query_emb_num, d)
+
+ prior_out = self.prior_encoder(
+ phone_id=phone_id,
+ duration=duration,
+ pitch=pitch,
+ phone_mask=phone_mask,
+ mask=mask,
+ ref_emb=spk_emb,
+ ref_mask=ref_mask,
+ is_inference=False,
+ )
+ prior_condition = prior_out["prior_out"] # (B, T, d)
+
+ diff_out = self.diffusion(latent, mask, prior_condition, spk_query_emb)
+
+ return diff_out, prior_out
+
+ @torch.no_grad()
+ def inference(
+ self, ref_code=None, phone_id=None, ref_mask=None, inference_steps=1000
+ ):
+ ref_latent = self.code_to_latent(ref_code)
+
+ if self.latent_dim is not None:
+ ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
+
+ ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
+ spk_emb = ref_latent.transpose(1, 2) # (B, d, T')
+
+ spk_query_emb = self.query_emb(
+ torch.arange(self.query_emb_num).to(ref_latent.device)
+ ).repeat(
+ ref_latent.shape[0], 1, 1
+ ) # (B, query_emb_num, d)
+ spk_query_emb, _ = self.query_attn(
+ spk_query_emb,
+ spk_emb.transpose(1, 2),
+ spk_emb.transpose(1, 2),
+ key_padding_mask=~(ref_mask.bool()),
+ ) # (B, query_emb_num, d)
+
+ prior_out = self.prior_encoder(
+ phone_id=phone_id,
+ duration=None,
+ pitch=None,
+ phone_mask=None,
+ mask=None,
+ ref_emb=spk_emb,
+ ref_mask=ref_mask,
+ is_inference=True,
+ )
+ prior_condition = prior_out["prior_out"] # (B, T, d)
+
+ z = torch.randn(
+ prior_condition.shape[0], self.latent_dim, prior_condition.shape[1]
+ ).to(ref_latent.device) / (1.20)
+ x0 = self.diffusion.reverse_diffusion(
+ z, None, prior_condition, inference_steps, spk_query_emb
+ )
+
+ return x0, prior_out
+
+ @torch.no_grad()
+ def reverse_diffusion_from_t(
+ self,
+ code=None,
+ pitch=None,
+ duration=None,
+ phone_id=None,
+ ref_code=None,
+ phone_mask=None,
+ mask=None,
+ ref_mask=None,
+ n_timesteps=None,
+ t=None,
+ ):
+ # o Only for debug
+
+ ref_latent = self.code_to_latent(ref_code)
+ latent = self.code_to_latent(code)
+
+ if self.latent_dim is not None:
+ ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
+
+ ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
+ spk_emb = ref_latent.transpose(1, 2) # (B, d, T')
+
+ spk_query_emb = self.query_emb(
+ torch.arange(self.query_emb_num).to(latent.device)
+ ).repeat(
+ latent.shape[0], 1, 1
+ ) # (B, query_emb_num, d)
+ spk_query_emb, _ = self.query_attn(
+ spk_query_emb,
+ spk_emb.transpose(1, 2),
+ spk_emb.transpose(1, 2),
+ key_padding_mask=~(ref_mask.bool()),
+ ) # (B, query_emb_num, d)
+
+ prior_out = self.prior_encoder(
+ phone_id=phone_id,
+ duration=duration,
+ pitch=pitch,
+ phone_mask=phone_mask,
+ mask=mask,
+ ref_emb=spk_emb,
+ ref_mask=ref_mask,
+ is_inference=False,
+ )
+ prior_condition = prior_out["prior_out"] # (B, T, d)
+
+ diffusion_step = (
+ torch.ones(
+ latent.shape[0],
+ dtype=latent.dtype,
+ device=latent.device,
+ requires_grad=False,
+ )
+ * t
+ )
+ diffusion_step = torch.clamp(diffusion_step, 1e-5, 1.0 - 1e-5)
+ xt, _ = self.diffusion.forward_diffusion(
+ x0=latent, diffusion_step=diffusion_step
+ )
+ # print(torch.abs(xt-latent).max(), torch.abs(xt-latent).mean(), torch.abs(xt-latent).std())
+
+ x0 = self.diffusion.reverse_diffusion_from_t(
+ xt, mask, prior_condition, n_timesteps, spk_query_emb, t_start=t
+ )
+
+ return x0, prior_out, xt
diff --git a/models/tts/naturalspeech2/ns2_dataset.py b/models/tts/naturalspeech2/ns2_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..df10f3fa884b4e31a21a48e9092c8741f3cb932d
--- /dev/null
+++ b/models/tts/naturalspeech2/ns2_dataset.py
@@ -0,0 +1,524 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from processors.acoustic_extractor import cal_normalized_mel
+from processors.acoustic_extractor import load_normalized
+from models.base.base_dataset import (
+ BaseOfflineCollator,
+ BaseOfflineDataset,
+ BaseTestDataset,
+ BaseTestCollator,
+)
+from text import text_to_sequence
+from text.cmudict import valid_symbols
+from tqdm import tqdm
+import pickle
+
+
+class NS2Dataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ assert isinstance(dataset, str)
+
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+ # train.json
+
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
+
+ self.metadata = self.get_metadata()
+
+ self.cfg = cfg
+
+ assert cfg.preprocess.use_mel == False
+ if cfg.preprocess.use_mel:
+ self.utt2melspec_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2melspec_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.melspec_dir, # mel
+ utt_info["speaker"],
+ uid + ".npy",
+ )
+
+ assert cfg.preprocess.use_code == True
+ if cfg.preprocess.use_code:
+ self.utt2code_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2code_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.code_dir, # code
+ utt_info["speaker"],
+ uid + ".npy",
+ )
+
+ assert cfg.preprocess.use_spkid == True
+ if cfg.preprocess.use_spkid:
+ self.utt2spkid = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2spkid[utt] = utt_info["speaker"]
+
+ assert cfg.preprocess.use_pitch == True
+ if cfg.preprocess.use_pitch:
+ self.utt2pitch_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2pitch_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.pitch_dir, # pitch
+ utt_info["speaker"],
+ uid + ".npy",
+ )
+
+ assert cfg.preprocess.use_duration == True
+ if cfg.preprocess.use_duration:
+ self.utt2duration_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2duration_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.duration_dir, # duration
+ utt_info["speaker"],
+ uid + ".npy",
+ )
+
+ assert cfg.preprocess.use_phone == True
+ if cfg.preprocess.use_phone:
+ self.utt2phone = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2phone[utt] = utt_info["phones"]
+
+ assert cfg.preprocess.use_len == True
+ if cfg.preprocess.use_len:
+ self.utt2len = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2len[utt] = utt_info["num_frames"]
+
+ # for cross reference
+ if cfg.preprocess.use_cross_reference:
+ self.spkid2utt = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ spkid = utt_info["speaker"]
+ if spkid not in self.spkid2utt:
+ self.spkid2utt[spkid] = []
+ self.spkid2utt[spkid].append(utt)
+
+ # get phone to id / id to phone map
+ self.phone2id, self.id2phone = self.get_phone_map()
+
+ self.all_num_frames = []
+ for i in range(len(self.metadata)):
+ self.all_num_frames.append(self.metadata[i]["num_frames"])
+ self.num_frame_sorted = np.array(sorted(self.all_num_frames))
+ self.num_frame_indices = np.array(
+ sorted(
+ range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
+ )
+ )
+
+ def __len__(self):
+ return len(self.metadata)
+
+ def get_dataset_name(self):
+ return self.metadata[0]["Dataset"]
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ print("metadata len: ", len(metadata))
+
+ return metadata
+
+ def get_phone_map(self):
+ symbols = valid_symbols + ["sp", "spn", "sil"] + ["", ""]
+ phone2id = {s: i for i, s in enumerate(symbols)}
+ id2phone = {i: s for s, i in phone2id.items()}
+ return phone2id, id2phone
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.read_metadata:
+ metadata_uid_path = os.path.join(
+ self.cfg.preprocess.processed_dir,
+ self.cfg.preprocess.metadata_dir,
+ dataset,
+ # utt_info["speaker"],
+ uid + ".pkl",
+ )
+ with open(metadata_uid_path, "rb") as f:
+ metadata_uid = pickle.load(f)
+ # code
+ code = metadata_uid["code"]
+ # frame_nums
+ frame_nums = code.shape[1]
+ # pitch
+ pitch = metadata_uid["pitch"]
+ # duration
+ duration = metadata_uid["duration"]
+ # phone_id
+ phone_id = np.array(
+ [
+ *map(
+ self.phone2id.get,
+ self.utt2phone[utt].replace("{", "").replace("}", "").split(),
+ )
+ ]
+ )
+
+ else:
+ # code
+ code = np.load(self.utt2code_path[utt])
+ # frame_nums
+ frame_nums = code.shape[1]
+ # pitch
+ pitch = np.load(self.utt2pitch_path[utt])
+ # duration
+ duration = np.load(self.utt2duration_path[utt])
+ # phone_id
+ phone_id = np.array(
+ [
+ *map(
+ self.phone2id.get,
+ self.utt2phone[utt].replace("{", "").replace("}", "").split(),
+ )
+ ]
+ )
+
+ # align length
+ code, pitch, duration, phone_id, frame_nums = self.align_length(
+ code, pitch, duration, phone_id, frame_nums
+ )
+
+ # spkid
+ spkid = self.utt2spkid[utt]
+
+ # get target and reference
+ out = self.get_target_and_reference(code, pitch, duration, phone_id, frame_nums)
+ code, ref_code = out["code"], out["ref_code"]
+ pitch, ref_pitch = out["pitch"], out["ref_pitch"]
+ duration, ref_duration = out["duration"], out["ref_duration"]
+ phone_id, ref_phone_id = out["phone_id"], out["ref_phone_id"]
+ frame_nums, ref_frame_nums = out["frame_nums"], out["ref_frame_nums"]
+
+ # phone_id_frame
+ assert len(phone_id) == len(duration)
+ phone_id_frame = []
+ for i in range(len(phone_id)):
+ phone_id_frame.extend([phone_id[i] for _ in range(duration[i])])
+ phone_id_frame = np.array(phone_id_frame)
+
+ # ref_phone_id_frame
+ assert len(ref_phone_id) == len(ref_duration)
+ ref_phone_id_frame = []
+ for i in range(len(ref_phone_id)):
+ ref_phone_id_frame.extend([ref_phone_id[i] for _ in range(ref_duration[i])])
+ ref_phone_id_frame = np.array(ref_phone_id_frame)
+
+ single_feature.update(
+ {
+ "code": code,
+ "frame_nums": frame_nums,
+ "pitch": pitch,
+ "duration": duration,
+ "phone_id": phone_id,
+ "phone_id_frame": phone_id_frame,
+ "ref_code": ref_code,
+ "ref_frame_nums": ref_frame_nums,
+ "ref_pitch": ref_pitch,
+ "ref_duration": ref_duration,
+ "ref_phone_id": ref_phone_id,
+ "ref_phone_id_frame": ref_phone_id_frame,
+ "spkid": spkid,
+ }
+ )
+
+ return single_feature
+
+ def get_num_frames(self, index):
+ utt_info = self.metadata[index]
+ return utt_info["num_frames"]
+
+ def align_length(self, code, pitch, duration, phone_id, frame_nums):
+ # aligh lenght of code, pitch, duration, phone_id, and frame nums
+ code_len = code.shape[1]
+ pitch_len = len(pitch)
+ dur_sum = sum(duration)
+ min_len = min(code_len, dur_sum)
+ code = code[:, :min_len]
+ if pitch_len >= min_len:
+ pitch = pitch[:min_len]
+ else:
+ pitch = np.pad(pitch, (0, min_len - pitch_len), mode="edge")
+ frame_nums = min_len
+ if dur_sum > min_len:
+ assert (duration[-1] - (dur_sum - min_len)) >= 0
+ duration[-1] = duration[-1] - (dur_sum - min_len)
+ assert duration[-1] >= 0
+
+ return code, pitch, duration, phone_id, frame_nums
+
+ def get_target_and_reference(self, code, pitch, duration, phone_id, frame_nums):
+ phone_nums = len(phone_id)
+ clip_phone_nums = np.random.randint(
+ int(phone_nums * 0.1), int(phone_nums * 0.5) + 1
+ )
+ clip_phone_nums = max(clip_phone_nums, 1)
+ assert clip_phone_nums < phone_nums and clip_phone_nums >= 1
+ if self.cfg.preprocess.clip_mode == "mid":
+ start_idx = np.random.randint(0, phone_nums - clip_phone_nums)
+ elif self.cfg.preprocess.clip_mode == "start":
+ if duration[0] == 0 and clip_phone_nums == 1:
+ start_idx = 1
+ else:
+ start_idx = 0
+ else:
+ assert self.cfg.preprocess.clip_mode in ["mid", "start"]
+ end_idx = start_idx + clip_phone_nums
+ start_frames = sum(duration[:start_idx])
+ end_frames = sum(duration[:end_idx])
+
+ new_code = np.concatenate(
+ (code[:, :start_frames], code[:, end_frames:]), axis=1
+ )
+ ref_code = code[:, start_frames:end_frames]
+
+ new_pitch = np.append(pitch[:start_frames], pitch[end_frames:])
+ ref_pitch = pitch[start_frames:end_frames]
+
+ new_duration = np.append(duration[:start_idx], duration[end_idx:])
+ ref_duration = duration[start_idx:end_idx]
+
+ new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])
+ ref_phone_id = phone_id[start_idx:end_idx]
+
+ new_frame_nums = frame_nums - (end_frames - start_frames)
+ ref_frame_nums = end_frames - start_frames
+
+ return {
+ "code": new_code,
+ "ref_code": ref_code,
+ "pitch": new_pitch,
+ "ref_pitch": ref_pitch,
+ "duration": new_duration,
+ "ref_duration": ref_duration,
+ "phone_id": new_phone_id,
+ "ref_phone_id": ref_phone_id,
+ "frame_nums": new_frame_nums,
+ "ref_frame_nums": ref_frame_nums,
+ }
+
+
+class NS2Collator(BaseOfflineCollator):
+ def __init__(self, cfg):
+ BaseOfflineCollator.__init__(self, cfg)
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # code: (B, 16, T)
+ # frame_nums: (B,) not used
+ # pitch: (B, T)
+ # duration: (B, N)
+ # phone_id: (B, N)
+ # phone_id_frame: (B, T)
+ # ref_code: (B, 16, T')
+ # ref_frame_nums: (B,) not used
+ # ref_pitch: (B, T) not used
+ # ref_duration: (B, N') not used
+ # ref_phone_id: (B, N') not used
+ # ref_phone_frame: (B, T') not used
+ # spkid: (B,) not used
+ # phone_mask: (B, N)
+ # mask: (B, T)
+ # ref_mask: (B, T')
+
+ for key in batch[0].keys():
+ if key == "phone_id":
+ phone_ids = [torch.LongTensor(b["phone_id"]) for b in batch]
+ phone_masks = [torch.ones(len(b["phone_id"])) for b in batch]
+ packed_batch_features["phone_id"] = pad_sequence(
+ phone_ids,
+ batch_first=True,
+ padding_value=0,
+ )
+ packed_batch_features["phone_mask"] = pad_sequence(
+ phone_masks,
+ batch_first=True,
+ padding_value=0,
+ )
+ elif key == "phone_id_frame":
+ phone_id_frames = [torch.LongTensor(b["phone_id_frame"]) for b in batch]
+ masks = [torch.ones(len(b["phone_id_frame"])) for b in batch]
+ packed_batch_features["phone_id_frame"] = pad_sequence(
+ phone_id_frames,
+ batch_first=True,
+ padding_value=0,
+ )
+ packed_batch_features["mask"] = pad_sequence(
+ masks,
+ batch_first=True,
+ padding_value=0,
+ )
+ elif key == "ref_code":
+ ref_codes = [
+ torch.from_numpy(b["ref_code"]).transpose(0, 1) for b in batch
+ ]
+ ref_masks = [torch.ones(max(b["ref_code"].shape[1], 1)) for b in batch]
+ packed_batch_features["ref_code"] = pad_sequence(
+ ref_codes,
+ batch_first=True,
+ padding_value=0,
+ ).transpose(1, 2)
+ packed_batch_features["ref_mask"] = pad_sequence(
+ ref_masks,
+ batch_first=True,
+ padding_value=0,
+ )
+ elif key == "code":
+ codes = [torch.from_numpy(b["code"]).transpose(0, 1) for b in batch]
+ masks = [torch.ones(max(b["code"].shape[1], 1)) for b in batch]
+ packed_batch_features["code"] = pad_sequence(
+ codes,
+ batch_first=True,
+ padding_value=0,
+ ).transpose(1, 2)
+ packed_batch_features["mask"] = pad_sequence(
+ masks,
+ batch_first=True,
+ padding_value=0,
+ )
+ elif key == "pitch":
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=50.0
+ )
+ elif key == "duration":
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ elif key == "frame_nums":
+ packed_batch_features["frame_nums"] = torch.LongTensor(
+ [b["frame_nums"] for b in batch]
+ )
+ elif key == "ref_frame_nums":
+ packed_batch_features["ref_frame_nums"] = torch.LongTensor(
+ [b["ref_frame_nums"] for b in batch]
+ )
+ else:
+ pass
+
+ return packed_batch_features
+
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ if len(batch) == 0:
+ return 0
+ if len(batch) == max_sentences:
+ return 1
+ if num_tokens > max_tokens:
+ return 1
+ return 0
+
+
+def batch_by_size(
+ indices,
+ num_tokens_fn,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+):
+ """
+ Yield mini-batches of indices bucketed by size. Batches may contain
+ sequences of different lengths.
+
+ Args:
+ indices (List[int]): ordered list of dataset indices
+ num_tokens_fn (callable): function that returns the number of tokens at
+ a given index
+ max_tokens (int, optional): max number of tokens in each batch
+ (default: None).
+ max_sentences (int, optional): max number of sentences in each
+ batch (default: None).
+ required_batch_size_multiple (int, optional): require batch size to
+ be a multiple of N (default: 1).
+ """
+ bsz_mult = required_batch_size_multiple
+
+ sample_len = 0
+ sample_lens = []
+ batch = []
+ batches = []
+ for i in range(len(indices)):
+ idx = indices[i]
+ num_tokens = num_tokens_fn(idx)
+ sample_lens.append(num_tokens)
+ sample_len = max(sample_len, num_tokens)
+
+ assert (
+ sample_len <= max_tokens
+ ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
+ idx, sample_len, max_tokens
+ )
+ num_tokens = (len(batch) + 1) * sample_len
+
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ mod_len = max(
+ bsz_mult * (len(batch) // bsz_mult),
+ len(batch) % bsz_mult,
+ )
+ batches.append(batch[:mod_len])
+ batch = batch[mod_len:]
+ sample_lens = sample_lens[mod_len:]
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+ batch.append(idx)
+ if len(batch) > 0:
+ batches.append(batch)
+ return batches
diff --git a/models/tts/naturalspeech2/ns2_inference.py b/models/tts/naturalspeech2/ns2_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2605d216d226a4d3e0e24a01f6b8cf2a1dcd447
--- /dev/null
+++ b/models/tts/naturalspeech2/ns2_inference.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import torch
+import soundfile as sf
+import numpy as np
+
+from models.tts.naturalspeech2.ns2 import NaturalSpeech2
+from encodec import EncodecModel
+from encodec.utils import convert_audio
+from utils.util import load_config
+
+from text import text_to_sequence
+from text.cmudict import valid_symbols
+from text.g2p import preprocess_english, read_lexicon
+
+import torchaudio
+
+
+class NS2Inference:
+ def __init__(self, args, cfg):
+ self.cfg = cfg
+ self.args = args
+
+ self.model = self.build_model()
+ self.codec = self.build_codec()
+
+ self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["", ""]
+ self.phone2id = {s: i for i, s in enumerate(self.symbols)}
+ self.id2phone = {i: s for s, i in self.phone2id.items()}
+
+ def build_model(self):
+ model = NaturalSpeech2(self.cfg.model)
+ model.load_state_dict(
+ torch.load(
+ os.path.join(self.args.checkpoint_path, "pytorch_model.bin"),
+ map_location="cpu",
+ )
+ )
+ model = model.to(self.args.device)
+ return model
+
+ def build_codec(self):
+ encodec_model = EncodecModel.encodec_model_24khz()
+ encodec_model = encodec_model.to(device=self.args.device)
+ encodec_model.set_target_bandwidth(12.0)
+ return encodec_model
+
+ def get_ref_code(self):
+ ref_wav_path = self.args.ref_audio
+ ref_wav, sr = torchaudio.load(ref_wav_path)
+ ref_wav = convert_audio(
+ ref_wav, sr, self.codec.sample_rate, self.codec.channels
+ )
+ ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)
+
+ with torch.no_grad():
+ encoded_frames = self.codec.encode(ref_wav)
+ ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
+ # print(ref_code.shape)
+
+ ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)
+ # print(ref_mask.shape)
+
+ return ref_code, ref_mask
+
+ def inference(self):
+ ref_code, ref_mask = self.get_ref_code()
+
+ lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
+ phone_seq = preprocess_english(self.args.text, lexicon)
+ print(phone_seq)
+
+ phone_id = np.array(
+ [
+ *map(
+ self.phone2id.get,
+ phone_seq.replace("{", "").replace("}", "").split(),
+ )
+ ]
+ )
+ phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
+ print(phone_id)
+
+ x0, prior_out = self.model.inference(
+ ref_code, phone_id, ref_mask, self.args.inference_step
+ )
+ print(prior_out["dur_pred"])
+ print(prior_out["dur_pred_round"])
+ print(torch.sum(prior_out["dur_pred_round"]))
+
+ latent_ref = self.codec.quantizer.vq.decode(ref_code.transpose(0, 1))
+
+ rec_wav = self.codec.decoder(x0)
+ # ref_wav = self.codec.decoder(latent_ref)
+
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ sf.write(
+ "{}/{}.wav".format(
+ self.args.output_dir, self.args.text.replace(" ", "_", 100)
+ ),
+ rec_wav[0, 0].detach().cpu().numpy(),
+ samplerate=24000,
+ )
+
+ def add_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--ref_audio",
+ type=str,
+ default="",
+ help="Reference audio path",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ )
+ parser.add_argument(
+ "--inference_step",
+ type=int,
+ default=200,
+ help="Total inference steps for the diffusion model",
+ )
diff --git a/models/tts/naturalspeech2/ns2_loss.py b/models/tts/naturalspeech2/ns2_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..1743edb1284549e28e14582734ef84b78fa0214a
--- /dev/null
+++ b/models/tts/naturalspeech2/ns2_loss.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+
+def log_dur_loss(dur_pred_log, dur_target, mask, loss_type="l1"):
+ # dur_pred_log: (B, N)
+ # dur_target: (B, N)
+ # mask: (B, N) mask is 0
+ dur_target_log = torch.log(1 + dur_target)
+ if loss_type == "l1":
+ loss = F.l1_loss(
+ dur_pred_log, dur_target_log, reduction="none"
+ ).float() * mask.to(dur_target.dtype)
+ elif loss_type == "l2":
+ loss = F.mse_loss(
+ dur_pred_log, dur_target_log, reduction="none"
+ ).float() * mask.to(dur_target.dtype)
+ else:
+ raise NotImplementedError()
+ loss = loss.sum() / (mask.to(dur_target.dtype).sum())
+ return loss
+
+
+def log_pitch_loss(pitch_pred_log, pitch_target, mask, loss_type="l1"):
+ pitch_target_log = torch.log(pitch_target)
+ if loss_type == "l1":
+ loss = F.l1_loss(
+ pitch_pred_log, pitch_target_log, reduction="none"
+ ).float() * mask.to(pitch_target.dtype)
+ elif loss_type == "l2":
+ loss = F.mse_loss(
+ pitch_pred_log, pitch_target_log, reduction="none"
+ ).float() * mask.to(pitch_target.dtype)
+ else:
+ raise NotImplementedError()
+ loss = loss.sum() / (mask.to(pitch_target.dtype).sum() + 1e-8)
+ return loss
+
+
+def diff_loss(pred, target, mask, loss_type="l1"):
+ # pred: (B, d, T)
+ # target: (B, d, T)
+ # mask: (B, T)
+ if loss_type == "l1":
+ loss = F.l1_loss(pred, target, reduction="none").float() * (
+ mask.to(pred.dtype).unsqueeze(1)
+ )
+ elif loss_type == "l2":
+ loss = F.mse_loss(pred, target, reduction="none").float() * (
+ mask.to(pred.dtype).unsqueeze(1)
+ )
+ else:
+ raise NotImplementedError()
+ loss = (torch.mean(loss, dim=1)).sum() / (mask.to(pred.dtype).sum())
+ return loss
+
+
+def diff_ce_loss(pred_dist, gt_indices, mask):
+ # pred_dist: (nq, B, T, 1024)
+ # gt_indices: (nq, B, T)
+ pred_dist = pred_dist.permute(1, 3, 0, 2) # (B, 1024, nq, T)
+ gt_indices = gt_indices.permute(1, 0, 2).long() # (B, nq, T)
+ loss = F.cross_entropy(
+ pred_dist, gt_indices, reduction="none"
+ ).float() # (B, nq, T)
+ loss = loss * mask.to(loss.dtype).unsqueeze(1)
+ loss = (torch.mean(loss, dim=1)).sum() / (mask.to(loss.dtype).sum())
+ return loss
diff --git a/models/tts/naturalspeech2/ns2_trainer.py b/models/tts/naturalspeech2/ns2_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c4353eee6eaafceb8a7b2ff1542587e767d99d
--- /dev/null
+++ b/models/tts/naturalspeech2/ns2_trainer.py
@@ -0,0 +1,798 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import shutil
+import json
+import time
+import torch
+import numpy as np
+from utils.util import Logger, ValueWindow
+from torch.utils.data import ConcatDataset, DataLoader
+from models.tts.base.tts_trainer import TTSTrainer
+from models.base.base_trainer import BaseTrainer
+from models.base.base_sampler import VariableSampler
+from models.tts.naturalspeech2.ns2_dataset import NS2Dataset, NS2Collator, batch_by_size
+from models.tts.naturalspeech2.ns2_loss import (
+ log_pitch_loss,
+ log_dur_loss,
+ diff_loss,
+ diff_ce_loss,
+)
+from torch.utils.data.sampler import BatchSampler, SequentialSampler
+from models.tts.naturalspeech2.ns2 import NaturalSpeech2
+from torch.optim import Adam, AdamW
+from torch.nn import MSELoss, L1Loss
+import torch.nn.functional as F
+from diffusers import get_scheduler
+
+import accelerate
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+
+
+class NS2Trainer(TTSTrainer):
+ def __init__(self, args, cfg):
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Init logger
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ os.makedirs(os.path.join(self.exp_dir, "checkpoint"), exist_ok=True)
+ self.log_file = os.path.join(
+ os.path.join(self.exp_dir, "checkpoint"), "train.log"
+ )
+ self.logger = Logger(self.log_file, level=self.args.log_level).logger
+
+ self.time_window = ValueWindow(50)
+
+ if self.accelerator.is_main_process:
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+
+ if self.accelerator.is_main_process:
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # init counts
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ if self.accelerator.is_main_process:
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check values
+ if self.accelerator.is_main_process:
+ self._check_basic_configs()
+ # Set runtime configs
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.keep_last = [
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # setup data_loader
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.info(
+ f"Building dataset done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # setup model
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(
+ f"Model parameters: {self._count_parameters(self.model)/1e6:.2f}M"
+ )
+
+ # optimizer & scheduler
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # accelerate prepare
+ if not self.cfg.train.use_dynamic_batchsize:
+ if self.accelerator.is_main_process:
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ )
+
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key] = self.accelerator.prepare(self.model[key])
+ else:
+ self.model = self.accelerator.prepare(self.model)
+
+ if isinstance(self.optimizer, dict):
+ for key in self.optimizer.keys():
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
+ else:
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
+ else:
+ self.scheduler = self.accelerator.prepare(self.scheduler)
+
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.info(
+ f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # create criterion
+ with self.accelerator.main_process_first():
+ if self.accelerator.is_main_process:
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterion = self._build_criterion()
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.info(
+ f"Building criterion done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # TODO: Resume from ckpt need test/debug
+ with self.accelerator.main_process_first():
+ if args.resume:
+ if self.accelerator.is_main_process:
+ self.logger.info("Resuming from checkpoint...")
+ start = time.monotonic_ns()
+ ckpt_path = self._load_model(
+ self.checkpoint_dir,
+ args.checkpoint_path,
+ resume_type=args.resume_type,
+ )
+ end = time.monotonic_ns()
+ if self.accelerator.is_main_process:
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
+ )
+
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ if self.accelerator.is_main_process:
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # save config file path
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+ # Only for TTS tasks
+ self.task_type = "TTS"
+ if self.accelerator.is_main_process:
+ self.logger.info("Task type: {}".format(self.task_type))
+
+ def _init_accelerator(self):
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir,
+ logging_dir=os.path.join(self.exp_dir, "log"),
+ )
+ # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ # kwargs_handlers=[ddp_kwargs]
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def _build_model(self):
+ model = NaturalSpeech2(cfg=self.cfg.model)
+ return model
+
+ def _build_dataset(self):
+ return NS2Dataset, NS2Collator
+
+ def _build_dataloader(self):
+ if self.cfg.train.use_dynamic_batchsize:
+ print("Use Dynamic Batchsize......")
+ Dataset, Collator = self._build_dataset()
+ train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
+ train_collate = Collator(self.cfg)
+ batch_sampler = batch_by_size(
+ train_dataset.num_frame_indices,
+ train_dataset.get_num_frames,
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
+ max_sentences=self.cfg.train.max_sentences
+ * self.accelerator.num_processes,
+ required_batch_size_multiple=self.accelerator.num_processes,
+ )
+ np.random.seed(980205)
+ np.random.shuffle(batch_sampler)
+ print(batch_sampler[:1])
+ batches = [
+ x[
+ self.accelerator.local_process_index :: self.accelerator.num_processes
+ ]
+ for x in batch_sampler
+ if len(x) % self.accelerator.num_processes == 0
+ ]
+
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ batch_sampler=VariableSampler(
+ batches, drop_last=False, use_random_sampler=True
+ ),
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ self.accelerator.wait_for_everyone()
+
+ valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
+ valid_collate = Collator(self.cfg)
+ batch_sampler = batch_by_size(
+ valid_dataset.num_frame_indices,
+ valid_dataset.get_num_frames,
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
+ max_sentences=self.cfg.train.max_sentences
+ * self.accelerator.num_processes,
+ required_batch_size_multiple=self.accelerator.num_processes,
+ )
+ batches = [
+ x[
+ self.accelerator.local_process_index :: self.accelerator.num_processes
+ ]
+ for x in batch_sampler
+ if len(x) % self.accelerator.num_processes == 0
+ ]
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ batch_sampler=VariableSampler(batches, drop_last=False),
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ self.accelerator.wait_for_everyone()
+
+ else:
+ print("Use Normal Batchsize......")
+ Dataset, Collator = self._build_dataset()
+ train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
+ train_collate = Collator(self.cfg)
+
+ train_loader = DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=train_collate,
+ batch_size=self.cfg.train.batch_size,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+
+ valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
+ valid_collate = Collator(self.cfg)
+
+ valid_loader = DataLoader(
+ valid_dataset,
+ shuffle=True,
+ collate_fn=valid_collate,
+ batch_size=self.cfg.train.batch_size,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ self.accelerator.wait_for_everyone()
+
+ return train_loader, valid_loader
+
+ def _build_optimizer(self):
+ optimizer = torch.optim.AdamW(
+ filter(lambda p: p.requires_grad, self.model.parameters()),
+ **self.cfg.train.adam,
+ )
+ return optimizer
+
+ def _build_scheduler(self):
+ lr_scheduler = get_scheduler(
+ self.cfg.train.lr_scheduler,
+ optimizer=self.optimizer,
+ num_warmup_steps=self.cfg.train.lr_warmup_steps,
+ num_training_steps=self.cfg.train.num_train_steps,
+ )
+ return lr_scheduler
+
+ def _build_criterion(self):
+ criterion = torch.nn.L1Loss(reduction="mean")
+ return criterion
+
+ def write_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def write_valid_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def get_state_dict(self):
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "scheduler": self.scheduler.state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+
+ self.model.load_state_dict(checkpoint["model"])
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
+
+ def _train_step(self, batch):
+ train_losses = {}
+ total_loss = 0
+ train_stats = {}
+
+ code = batch["code"] # (B, 16, T)
+ pitch = batch["pitch"] # (B, T)
+ duration = batch["duration"] # (B, N)
+ phone_id = batch["phone_id"] # (B, N)
+ ref_code = batch["ref_code"] # (B, 16, T')
+ phone_mask = batch["phone_mask"] # (B, N)
+ mask = batch["mask"] # (B, T)
+ ref_mask = batch["ref_mask"] # (B, T')
+
+ diff_out, prior_out = self.model(
+ code=code,
+ pitch=pitch,
+ duration=duration,
+ phone_id=phone_id,
+ ref_code=ref_code,
+ phone_mask=phone_mask,
+ mask=mask,
+ ref_mask=ref_mask,
+ )
+
+ # pitch loss
+ pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
+ total_loss += pitch_loss
+ train_losses["pitch_loss"] = pitch_loss
+
+ # duration loss
+ dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
+ total_loss += dur_loss
+ train_losses["dur_loss"] = dur_loss
+
+ x0 = self.model.module.code_to_latent(code)
+ if self.cfg.model.diffusion.diffusion_type == "diffusion":
+ # diff loss x0
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
+ total_loss += diff_loss_x0
+ train_losses["diff_loss_x0"] = diff_loss_x0
+
+ # diff loss noise
+ diff_loss_noise = diff_loss(
+ diff_out["noise_pred"], diff_out["noise"], mask=mask
+ )
+ total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
+ train_losses["diff_loss_noise"] = diff_loss_noise
+
+ elif self.cfg.model.diffusion.diffusion_type == "flow":
+ # diff flow matching loss
+ flow_gt = diff_out["noise"] - x0
+ diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
+ total_loss += diff_loss_flow
+ train_losses["diff_loss_flow"] = diff_loss_flow
+
+ # diff loss ce
+
+ # (nq, B, T); (nq, B, T, 1024)
+ if self.cfg.train.diff_ce_loss_lambda > 0:
+ pred_indices, pred_dist = self.model.module.latent_to_code(
+ diff_out["x0_pred"], nq=code.shape[1]
+ )
+ gt_indices, _ = self.model.module.latent_to_code(x0, nq=code.shape[1])
+ diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
+ total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
+ train_losses["diff_loss_ce"] = diff_loss_ce
+
+ self.optimizer.zero_grad()
+ # total_loss.backward()
+ self.accelerator.backward(total_loss)
+ if self.accelerator.sync_gradients:
+ self.accelerator.clip_grad_norm_(
+ filter(lambda p: p.requires_grad, self.model.parameters()), 0.5
+ )
+ self.optimizer.step()
+ self.scheduler.step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ if self.cfg.train.diff_ce_loss_lambda > 0:
+ pred_indices_list = pred_indices.long().detach().cpu().numpy()
+ gt_indices_list = gt_indices.long().detach().cpu().numpy()
+ mask_list = batch["mask"].detach().cpu().numpy()
+
+ for i in range(pred_indices_list.shape[0]):
+ pred_acc = np.sum(
+ (pred_indices_list[i] == gt_indices_list[i]) * mask_list
+ ) / np.sum(mask_list)
+ train_losses["pred_acc_{}".format(str(i))] = pred_acc
+
+ train_losses["batch_size"] = code.shape[0]
+ train_losses["max_frame_nums"] = np.max(
+ batch["frame_nums"].detach().cpu().numpy()
+ )
+
+ return (total_loss.item(), train_losses, train_stats)
+
+ @torch.inference_mode()
+ def _valid_step(self, batch):
+ valid_losses = {}
+ total_loss = 0
+ valid_stats = {}
+
+ code = batch["code"] # (B, 16, T)
+ pitch = batch["pitch"] # (B, T)
+ duration = batch["duration"] # (B, N)
+ phone_id = batch["phone_id"] # (B, N)
+ ref_code = batch["ref_code"] # (B, 16, T')
+ phone_mask = batch["phone_mask"] # (B, N)
+ mask = batch["mask"] # (B, T)
+ ref_mask = batch["ref_mask"] # (B, T')
+
+ diff_out, prior_out = self.model(
+ code=code,
+ pitch=pitch,
+ duration=duration,
+ phone_id=phone_id,
+ ref_code=ref_code,
+ phone_mask=phone_mask,
+ mask=mask,
+ ref_mask=ref_mask,
+ )
+
+ # pitch loss
+ pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
+ total_loss += pitch_loss
+ valid_losses["pitch_loss"] = pitch_loss
+
+ # duration loss
+ dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
+ total_loss += dur_loss
+ valid_losses["dur_loss"] = dur_loss
+
+ x0 = self.model.module.code_to_latent(code)
+ if self.cfg.model.diffusion.diffusion_type == "diffusion":
+ # diff loss x0
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
+ total_loss += diff_loss_x0
+ valid_losses["diff_loss_x0"] = diff_loss_x0
+
+ # diff loss noise
+ diff_loss_noise = diff_loss(
+ diff_out["noise_pred"], diff_out["noise"], mask=mask
+ )
+ total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
+ valid_losses["diff_loss_noise"] = diff_loss_noise
+
+ elif self.cfg.model.diffusion.diffusion_type == "flow":
+ # diff flow matching loss
+ flow_gt = diff_out["noise"] - x0
+ diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
+ total_loss += diff_loss_flow
+ valid_losses["diff_loss_flow"] = diff_loss_flow
+
+ # diff loss ce
+
+ # (nq, B, T); (nq, B, T, 1024)
+ if self.cfg.train.diff_ce_loss_lambda > 0:
+ pred_indices, pred_dist = self.model.module.latent_to_code(
+ diff_out["x0_pred"], nq=code.shape[1]
+ )
+ gt_indices, _ = self.model.module.latent_to_code(x0, nq=code.shape[1])
+ diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
+ total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
+ valid_losses["diff_loss_ce"] = diff_loss_ce
+
+ for item in valid_losses:
+ valid_losses[item] = valid_losses[item].item()
+
+ if self.cfg.train.diff_ce_loss_lambda > 0:
+ pred_indices_list = pred_indices.long().detach().cpu().numpy()
+ gt_indices_list = gt_indices.long().detach().cpu().numpy()
+ mask_list = batch["mask"].detach().cpu().numpy()
+
+ for i in range(pred_indices_list.shape[0]):
+ pred_acc = np.sum(
+ (pred_indices_list[i] == gt_indices_list[i]) * mask_list
+ ) / np.sum(mask_list)
+ valid_losses["pred_acc_{}".format(str(i))] = pred_acc
+
+ return (total_loss.item(), valid_losses, valid_stats)
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].eval()
+ else:
+ self.model.eval()
+
+ epoch_sum_loss = 0.0
+ epoch_losses = dict()
+
+ for batch in self.valid_dataloader:
+ # Put the data to cuda device
+ device = self.accelerator.device
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(device)
+
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
+ epoch_sum_loss = total_loss
+ for key, value in valid_losses.items():
+ epoch_losses[key] = value
+
+ self.accelerator.wait_for_everyone()
+
+ return epoch_sum_loss, epoch_losses
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].train()
+ else:
+ self.model.train()
+
+ epoch_sum_loss: float = 0.0
+ epoch_losses: dict = {}
+ epoch_step: int = 0
+
+ for batch in self.train_dataloader:
+ # Put the data to cuda device
+ device = self.accelerator.device
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(device)
+
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ total_loss, train_losses, training_stats = self._train_step(batch)
+ self.batch_count += 1
+
+ # Update info for each step
+ # TODO: step means BP counts or batch counts?
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss = total_loss
+ for key, value in train_losses.items():
+ epoch_losses[key] = value
+
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.step,
+ )
+
+ if (
+ self.accelerator.is_main_process
+ and self.batch_count
+ % (1 * self.cfg.train.gradient_accumulation_step)
+ == 0
+ ):
+ self.echo_log(train_losses, mode="Training")
+
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ return epoch_sum_loss, epoch_losses
+
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # dump config file
+ if self.accelerator.is_main_process:
+ self._dump_cfg(self.config_save_path)
+
+ # self.optimizer.zero_grad()
+
+ # Wait to ensure good to go
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ if self.accelerator.is_main_process:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ # Do training & validating epoch
+ train_total_loss, train_losses = self._train_epoch()
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ if self.accelerator.is_main_process:
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+
+ valid_total_loss, valid_losses = self._valid_epoch()
+ if isinstance(valid_losses, dict):
+ for key, loss in valid_losses.items():
+ if self.accelerator.is_main_process:
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+
+ if self.accelerator.is_main_process:
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
+ self.accelerator.log(
+ {
+ "Epoch/Train Loss": train_total_loss,
+ "Epoch/Valid Loss": valid_total_loss,
+ },
+ step=self.epoch,
+ )
+
+ self.accelerator.wait_for_everyone()
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key].step()
+ else:
+ self.scheduler.step()
+
+ # Check if hit save_checkpoint_stride and run_eval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ hit_dix = []
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ hit_dix.append(i)
+ run_eval |= self.run_eval[i]
+
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, train_total_loss
+ ),
+ )
+ print("save state......")
+ self.accelerator.save_state(path)
+ print("finish saving state......")
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+ # Remove old checkpoints
+ to_remove = []
+ for idx in hit_dix:
+ self.checkpoints_path[idx].append(path)
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+ # Search conflicts
+ total = set()
+ for i in self.checkpoints_path:
+ total |= set(i)
+ do_remove = set()
+ for idx, path in to_remove[::-1]:
+ if path in total:
+ self.checkpoints_path[idx].insert(0, path)
+ else:
+ do_remove.add(path)
+
+ # Remove old checkpoints
+ for path in do_remove:
+ shutil.rmtree(path, ignore_errors=True)
+ if self.accelerator.is_main_process:
+ self.logger.debug(f"Remove old checkpoint: {path}")
+
+ self.accelerator.wait_for_everyone()
+ if run_eval:
+ # TODO: run evaluation
+ pass
+
+ # Update info for each epoch
+ self.epoch += 1
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ )
+ self.accelerator.end_training()
diff --git a/models/tts/naturalspeech2/prior_encoder.py b/models/tts/naturalspeech2/prior_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..84cf78c391bf08696b4a2217a12732f36095ddea
--- /dev/null
+++ b/models/tts/naturalspeech2/prior_encoder.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from modules.naturalpseech2.transformers import (
+ TransformerEncoder,
+ DurationPredictor,
+ PitchPredictor,
+ LengthRegulator,
+)
+
+
+class PriorEncoder(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.enc_emb_tokens = nn.Embedding(
+ cfg.vocab_size, cfg.encoder.encoder_hidden, padding_idx=0
+ )
+ self.enc_emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
+ self.encoder = TransformerEncoder(
+ enc_emb_tokens=self.enc_emb_tokens, cfg=cfg.encoder
+ )
+
+ self.duration_predictor = DurationPredictor(cfg.duration_predictor)
+ self.pitch_predictor = PitchPredictor(cfg.pitch_predictor)
+ self.length_regulator = LengthRegulator()
+
+ self.pitch_min = cfg.pitch_min
+ self.pitch_max = cfg.pitch_max
+ self.pitch_bins_num = cfg.pitch_bins_num
+
+ pitch_bins = torch.exp(
+ torch.linspace(
+ np.log(self.pitch_min), np.log(self.pitch_max), self.pitch_bins_num - 1
+ )
+ )
+ self.register_buffer("pitch_bins", pitch_bins)
+
+ self.pitch_embedding = nn.Embedding(
+ self.pitch_bins_num, cfg.encoder.encoder_hidden
+ )
+
+ def forward(
+ self,
+ phone_id,
+ duration=None,
+ pitch=None,
+ phone_mask=None,
+ mask=None,
+ ref_emb=None,
+ ref_mask=None,
+ is_inference=False,
+ ):
+ """
+ input:
+ phone_id: (B, N)
+ duration: (B, N)
+ pitch: (B, T)
+ phone_mask: (B, N); mask is 0
+ mask: (B, T); mask is 0
+ ref_emb: (B, d, T')
+ ref_mask: (B, T'); mask is 0
+
+ output:
+ prior_embedding: (B, d, T)
+ pred_dur: (B, N)
+ pred_pitch: (B, T)
+ """
+
+ x = self.encoder(phone_id, phone_mask, ref_emb.transpose(1, 2))
+ # print(torch.min(x), torch.max(x))
+ dur_pred_out = self.duration_predictor(x, phone_mask, ref_emb, ref_mask)
+ # dur_pred_out: {dur_pred_log, dur_pred, dur_pred_round}
+
+ if is_inference or duration is None:
+ x, mel_len = self.length_regulator(
+ x,
+ dur_pred_out["dur_pred_round"],
+ max_len=torch.max(torch.sum(dur_pred_out["dur_pred_round"], dim=1)),
+ )
+ else:
+ x, mel_len = self.length_regulator(x, duration, max_len=pitch.shape[1])
+
+ pitch_pred_log = self.pitch_predictor(x, mask, ref_emb, ref_mask)
+
+ if is_inference or pitch is None:
+ pitch_tokens = torch.bucketize(pitch_pred_log.exp(), self.pitch_bins)
+ pitch_embedding = self.pitch_embedding(pitch_tokens)
+ else:
+ pitch_tokens = torch.bucketize(pitch, self.pitch_bins)
+ pitch_embedding = self.pitch_embedding(pitch_tokens)
+
+ x = x + pitch_embedding
+
+ if (not is_inference) and (mask is not None):
+ x = x * mask.to(x.dtype)[:, :, None]
+
+ prior_out = {
+ "dur_pred_round": dur_pred_out["dur_pred_round"],
+ "dur_pred_log": dur_pred_out["dur_pred_log"],
+ "dur_pred": dur_pred_out["dur_pred"],
+ "pitch_pred_log": pitch_pred_log,
+ "pitch_token": pitch_tokens,
+ "mel_len": mel_len,
+ "prior_out": x,
+ }
+
+ return prior_out
diff --git a/models/tts/naturalspeech2/wavenet.py b/models/tts/naturalspeech2/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc964ea6e41df541bcd5d5cea2cc83ab53c29d1f
--- /dev/null
+++ b/models/tts/naturalspeech2/wavenet.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import math
+
+
+class FiLM(nn.Module):
+ def __init__(self, in_dim, cond_dim):
+ super().__init__()
+
+ self.gain = Linear(cond_dim, in_dim)
+ self.bias = Linear(cond_dim, in_dim)
+
+ nn.init.xavier_uniform_(self.gain.weight)
+ nn.init.constant_(self.gain.bias, 1)
+
+ nn.init.xavier_uniform_(self.bias.weight)
+ nn.init.constant_(self.bias.bias, 0)
+
+ def forward(self, x, condition):
+ gain = self.gain(condition)
+ bias = self.bias(condition)
+ if gain.dim() == 2:
+ gain = gain.unsqueeze(-1)
+ if bias.dim() == 2:
+ bias = bias.unsqueeze(-1)
+ return x * gain + bias
+
+
+class Mish(nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+def Conv1d(*args, **kwargs):
+ layer = nn.Conv1d(*args, **kwargs)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+def Linear(*args, **kwargs):
+ layer = nn.Linear(*args, **kwargs)
+ layer.weight.data.normal_(0.0, 0.02)
+ return layer
+
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, hidden_dim, attn_head, dilation, drop_out, has_cattn=False):
+ super().__init__()
+
+ self.hidden_dim = hidden_dim
+ self.dilation = dilation
+ self.has_cattn = has_cattn
+ self.attn_head = attn_head
+ self.drop_out = drop_out
+
+ self.dilated_conv = Conv1d(
+ hidden_dim, 2 * hidden_dim, 3, padding=dilation, dilation=dilation
+ )
+ self.diffusion_proj = Linear(hidden_dim, hidden_dim)
+
+ self.cond_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
+ self.out_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
+
+ if self.has_cattn:
+ self.attn = nn.MultiheadAttention(
+ hidden_dim, attn_head, 0.1, batch_first=True
+ )
+ self.film = FiLM(hidden_dim * 2, hidden_dim)
+
+ self.ln = nn.LayerNorm(hidden_dim)
+
+ self.dropout = nn.Dropout(self.drop_out)
+
+ def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
+ diffusion_step = self.diffusion_proj(diffusion_step).unsqueeze(-1) # (B, d, 1)
+ cond = self.cond_proj(cond) # (B, 2*d, T)
+
+ y = x + diffusion_step
+ if x_mask != None:
+ y = y * x_mask.to(y.dtype)[:, None, :] # (B, 2*d, T)
+
+ if self.has_cattn:
+ y_ = y.transpose(1, 2)
+ y_ = self.ln(y_)
+
+ y_, _ = self.attn(y_, spk_query_emb, spk_query_emb) # (B, T, d)
+
+ y = self.dilated_conv(y) + cond # (B, 2*d, T)
+
+ if self.has_cattn:
+ y = self.film(y.transpose(1, 2), y_) # (B, T, 2*d)
+ y = y.transpose(1, 2) # (B, 2*d, T)
+
+ gate, filter_ = torch.chunk(y, 2, dim=1)
+ y = torch.sigmoid(gate) * torch.tanh(filter_)
+
+ y = self.out_proj(y)
+
+ residual, skip = torch.chunk(y, 2, dim=1)
+
+ if x_mask != None:
+ residual = residual * x_mask.to(y.dtype)[:, None, :]
+ skip = skip * x_mask.to(y.dtype)[:, None, :]
+
+ return (x + residual) / math.sqrt(2.0), skip
+
+
+class WaveNet(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+ self.in_dim = cfg.input_size
+ self.hidden_dim = cfg.hidden_size
+ self.out_dim = cfg.out_size
+ self.num_layers = cfg.num_layers
+ self.cross_attn_per_layer = cfg.cross_attn_per_layer
+ self.dilation_cycle = cfg.dilation_cycle
+ self.attn_head = cfg.attn_head
+ self.drop_out = cfg.drop_out
+
+ self.in_proj = Conv1d(self.in_dim, self.hidden_dim, 1)
+ self.diffusion_embedding = SinusoidalPosEmb(self.hidden_dim)
+
+ self.mlp = nn.Sequential(
+ Linear(self.hidden_dim, self.hidden_dim * 4),
+ Mish(),
+ Linear(self.hidden_dim * 4, self.hidden_dim),
+ )
+
+ self.cond_ln = nn.LayerNorm(self.hidden_dim)
+
+ self.layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ self.hidden_dim,
+ self.attn_head,
+ 2 ** (i % self.dilation_cycle),
+ self.drop_out,
+ has_cattn=(i % self.cross_attn_per_layer == 0),
+ )
+ for i in range(self.num_layers)
+ ]
+ )
+
+ self.skip_proj = Conv1d(self.hidden_dim, self.hidden_dim, 1)
+ self.out_proj = Conv1d(self.hidden_dim, self.out_dim, 1)
+
+ nn.init.zeros_(self.out_proj.weight)
+
+ def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
+ """
+ x: (B, 128, T)
+ x_mask: (B, T), mask is 0
+ cond: (B, T, 512)
+ diffusion_step: (B,)
+ spk_query_emb: (B, 32, 512)
+ """
+ cond = self.cond_ln(cond)
+ cond_input = cond.transpose(1, 2)
+
+ x_input = self.in_proj(x)
+
+ x_input = F.relu(x_input)
+
+ diffusion_step = self.diffusion_embedding(diffusion_step).to(x.dtype)
+ diffusion_step = self.mlp(diffusion_step)
+
+ skip = []
+ for _, layer in enumerate(self.layers):
+ x_input, skip_connection = layer(
+ x_input, x_mask, cond_input, diffusion_step, spk_query_emb
+ )
+ skip.append(skip_connection)
+
+ x_input = torch.sum(torch.stack(skip), dim=0) / math.sqrt(self.num_layers)
+
+ x_out = self.skip_proj(x_input)
+
+ x_out = F.relu(x_out)
+
+ x_out = self.out_proj(x_out) # (B, 128, T)
+
+ return x_out
diff --git a/models/tts/valle/__init__.py b/models/tts/valle/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tts/valle/valle.py b/models/tts/valle/valle.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbcf43c88c6492d221f49893394ec1a8f94e9ca8
--- /dev/null
+++ b/models/tts/valle/valle.py
@@ -0,0 +1,794 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/lifeiteng/vall-e/blob/main/valle/models/valle.py
+
+import random
+from typing import Dict, Iterator, List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchmetrics.classification import MulticlassAccuracy
+from utils.util import make_pad_mask
+from utils.topk_sampling import topk_sampling
+from modules.general import Transpose
+from modules.encoder import TokenEmbedding
+from modules.general import PromptedFeatures
+from modules.transformer import SinePositionalEmbedding
+from modules.norms import AdaptiveLayerNorm, LayerNorm
+from modules.transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+
+
+class VALLE(nn.Module):
+ def __init__(
+ self,
+ cfg,
+ decoder_cls=TransformerEncoder,
+ decoder_layer_cls=TransformerEncoderLayer,
+ ):
+ super().__init__()
+ decoder_dim = cfg.decoder_dim
+ nhead = cfg.nhead
+ nar_scale_factor = cfg.nar_scale_factor
+ num_quantizers = cfg.num_quantizers
+ num_decoder_layers = cfg.num_decoder_layers
+ nar_decoder_dim = int(decoder_dim * nar_scale_factor)
+
+ self.ar_text_embedding = TokenEmbedding(decoder_dim, cfg.text_token_num)
+ self.nar_text_embedding = TokenEmbedding(nar_decoder_dim, cfg.text_token_num)
+
+ self.ar_audio_prepend_bos = cfg.prepend_bos
+ self.ar_audio_embedding = TokenEmbedding(
+ decoder_dim, cfg.audio_token_num + 1 + int(cfg.prepend_bos)
+ )
+ self.audio_token_num = cfg.audio_token_num
+
+ # PreNet of AR
+ if cfg.add_prenet:
+ self.ar_text_prenet = nn.Sequential(
+ Transpose(),
+ nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
+ nn.BatchNorm1d(decoder_dim),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
+ nn.BatchNorm1d(decoder_dim),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
+ nn.BatchNorm1d(decoder_dim),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ Transpose(),
+ nn.Linear(decoder_dim, decoder_dim),
+ )
+
+ self.ar_audio_prenet = nn.Sequential(
+ nn.Linear(decoder_dim, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, decoder_dim),
+ )
+ else:
+ self.ar_text_prenet = nn.Identity()
+ self.ar_audio_prenet = nn.Identity()
+
+ self.ar_text_position = SinePositionalEmbedding(
+ decoder_dim,
+ dropout=0.1,
+ scale=False,
+ alpha=True,
+ )
+ self.ar_audio_position = SinePositionalEmbedding(
+ decoder_dim,
+ dropout=0.1,
+ scale=False,
+ alpha=True,
+ )
+
+ self.ar_decoder = decoder_cls(
+ decoder_layer_cls(
+ decoder_dim,
+ nhead,
+ dim_feedforward=decoder_dim * 4, # *4?
+ dropout=0.1,
+ batch_first=True,
+ norm_first=cfg.norm_first,
+ ),
+ num_layers=num_decoder_layers,
+ norm=LayerNorm(decoder_dim) if cfg.norm_first else None,
+ )
+ self.ar_predict_layer = nn.Linear(
+ decoder_dim, cfg.audio_token_num + 1, bias=False
+ )
+
+ self.ar_accuracy_metric = MulticlassAccuracy(
+ cfg.audio_token_num + 1,
+ top_k=10,
+ average="micro",
+ multidim_average="global",
+ ignore_index=cfg.audio_token_num,
+ )
+
+ self.rng = random.Random(0)
+ self.num_heads = nhead
+ self.prefix_mode = cfg.prefix_mode
+ self.num_quantizers = num_quantizers
+
+ assert num_quantizers >= 1
+ if num_quantizers > 1:
+ self.nar_audio_embeddings = nn.ModuleList(
+ [
+ TokenEmbedding(nar_decoder_dim, cfg.audio_token_num + 1)
+ ] # Why the first layer is audio_token_num + 1?
+ + [
+ TokenEmbedding(nar_decoder_dim, cfg.audio_token_num)
+ for i in range(num_quantizers - 1)
+ ]
+ )
+
+ if cfg.add_prenet:
+ self.nar_text_prenet = nn.Sequential(
+ Transpose(),
+ nn.Conv1d(
+ nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
+ ),
+ nn.BatchNorm1d(nar_decoder_dim),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(
+ nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
+ ),
+ nn.BatchNorm1d(nar_decoder_dim),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(
+ nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
+ ),
+ nn.BatchNorm1d(nar_decoder_dim),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ Transpose(),
+ nn.Linear(nar_decoder_dim, nar_decoder_dim),
+ )
+ self.nar_audio_prenet = nn.Sequential(
+ nn.Linear(nar_decoder_dim, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, nar_decoder_dim),
+ )
+ else:
+ self.nar_text_prenet = nn.Identity()
+ self.nar_audio_prenet = nn.Identity()
+
+ self.nar_text_position = SinePositionalEmbedding(
+ nar_decoder_dim,
+ dropout=0.0,
+ scale=False,
+ alpha=False,
+ )
+ self.nar_audio_position = SinePositionalEmbedding(
+ nar_decoder_dim,
+ dropout=0.1,
+ scale=False,
+ alpha=False,
+ )
+
+ self.nar_decoder = decoder_cls(
+ decoder_layer_cls(
+ nar_decoder_dim,
+ int(nhead * nar_scale_factor),
+ dim_feedforward=nar_decoder_dim * 4,
+ dropout=0.1,
+ batch_first=True,
+ norm_first=cfg.norm_first,
+ adaptive_layer_norm=True,
+ ),
+ num_layers=int(num_decoder_layers * nar_scale_factor),
+ norm=(
+ AdaptiveLayerNorm(
+ nar_decoder_dim, norm=nn.LayerNorm(nar_decoder_dim)
+ )
+ if cfg.norm_first
+ else None
+ ),
+ )
+ self.nar_predict_layers = nn.ModuleList(
+ [
+ nn.Linear(nar_decoder_dim, cfg.audio_token_num, bias=False)
+ for i in range(num_quantizers - 1)
+ ]
+ )
+ self.nar_stage_embeddings = nn.ModuleList(
+ [TokenEmbedding(nar_decoder_dim, 1) for i in range(num_quantizers - 1)]
+ )
+
+ if cfg.share_embedding:
+ for j in range(0, num_quantizers - 2):
+ self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
+ j + 2
+ ].weight
+
+ self.nar_accuracy_metric = MulticlassAccuracy(
+ cfg.audio_token_num + 1,
+ top_k=10,
+ average="micro",
+ multidim_average="global",
+ ignore_index=cfg.audio_token_num,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: Union[torch.Tensor, PromptedFeatures],
+ y_lens: Union[torch.Tensor, PromptedFeatures],
+ reduction: str = "sum",
+ train_stage: int = 0,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (N, S).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
+ before padding.
+ y:
+ A 3-D tensor of shape (N, T, 8).
+ y_lens:
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
+ before padding.
+ train_stage:
+ 0: AR & NAR modules, 1: AR modules, 2: NAR modules
+ Returns:
+ Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
+ """
+ assert x.ndim == 2, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+
+ y_prompts_codes = None
+ if isinstance(y, PromptedFeatures):
+ y_prompts_codes, y = y.data
+ prompts_len, y_lens = y_lens.data
+ assert prompts_len.min() == prompts_len.max()
+ assert self.prefix_mode == 4
+ y_prompts_codes = y_prompts_codes.type(torch.int64)
+
+ assert y.ndim == 3, y.shape
+ assert y_lens.ndim == 1, y_lens.shape
+
+ x_mask = make_pad_mask(x_lens).to(x.device)
+ y_mask = make_pad_mask(y_lens).to(y.device)
+ y_mask_int = y_mask.type(torch.int64)
+
+ text = x
+ codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
+
+ y, targets = self.pad_y_eos(
+ codes[..., 0], y_mask_int, eos_id=self.audio_token_num
+ )
+ self.y_mask_int = y_mask_int
+
+ metrics = {}
+ total_loss = 0.0
+
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+ if self.ar_audio_prepend_bos:
+ ar_xy_padding_mask = torch.concat(
+ [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
+ )
+ else:
+ ar_xy_padding_mask = xy_padding_mask
+ self.xy_padding_mask = xy_padding_mask
+ self.ar_xy_padding_mask = ar_xy_padding_mask
+
+ # AR Decoder
+ if train_stage in [0, 1]:
+ ar_loss, ar_metrics = self._forward_ar_decoder(
+ text, x_lens.max(), y, y_lens.max(), targets, x_mask, y_mask, reduction
+ )
+ total_loss += ar_loss
+ metrics["AR_Top100Acc"] = ar_metrics
+
+ # NAR Decoder
+ if self.ar_audio_prepend_bos:
+ y = y[:, 1:]
+
+ if self.num_quantizers > 1 and train_stage in [0, 2]:
+ nar_loss, nar_metrics = self._forward_nar_decoder(
+ text,
+ x_lens,
+ y,
+ y_lens,
+ codes,
+ y_prompts_codes,
+ x_mask,
+ y_mask,
+ reduction,
+ )
+ total_loss += nar_loss
+ metrics["NAR_Top100Acc"] = nar_metrics
+
+ if train_stage == 0:
+ total_loss = total_loss / 2.0
+
+ return total_loss, metrics
+
+ def _forward_ar_decoder(
+ self, x, x_len, y, y_lens, targets, x_mask, y_mask, reduction
+ ):
+ x = self.ar_text_embedding(x)
+ x = self.ar_text_prenet(x)
+ x = self.ar_text_position(x)
+
+ y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
+
+ x_attn_mask = F.pad(
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
+ (0, y_len),
+ value=True,
+ )
+ y_attn_mask = F.pad(
+ torch.triu(
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
+ diagonal=1,
+ ),
+ (x_len, 0),
+ value=False,
+ )
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
+
+ bsz, src_len = x.shape[0], x_len + y_len
+ _xy_padding_mask = (
+ self.ar_xy_padding_mask.view(bsz, 1, 1, src_len)
+ .expand(-1, self.num_heads, -1, -1)
+ .reshape(bsz * self.num_heads, 1, src_len)
+ )
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
+ xy_attn_mask = new_attn_mask
+
+ y_emb = self.ar_audio_embedding(y)
+ y_emb = self.ar_audio_prenet(y_emb)
+ y_pos = self.ar_audio_position(y_emb)
+
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.ar_decoder(
+ (xy_pos, None),
+ mask=xy_attn_mask,
+ )
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
+ ar_loss = F.cross_entropy(logits, targets, reduction=reduction)
+
+ ar_metrics = self.ar_accuracy_metric(
+ logits.detach(), targets
+ ).item() * y_lens.sum().type(torch.float32)
+
+ return ar_loss, ar_metrics
+
+ def _forward_nar_decoder(
+ self, x, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
+ ):
+ num_nar_layers = self.num_quantizers - 1
+ nar_stage = self.rng.choices(
+ [_k for _k in range(1, self.num_quantizers)],
+ weights=[1.0 / num_nar_layers] * num_nar_layers,
+ k=1,
+ )[0]
+
+ x = self.nar_text_embedding(x)
+ x = self.nar_text_prenet(x)
+ x = self.nar_text_position(x)
+
+ y_emb, prefix_len = self._prepare_prompts(
+ y, y_lens, codes, nar_stage, y_prompts_codes
+ )
+
+ y_len = y_lens.max()
+ targets = codes[..., nar_stage] + self.audio_token_num * self.y_mask_int
+ if self.prefix_mode in [2, 4]:
+ xy_padding_mask = torch.concat(
+ [
+ x_mask,
+ F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
+ ],
+ dim=1,
+ )
+ elif self.prefix_mode == 1:
+ targets = targets[:, prefix_len:]
+
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
+ src_key_padding_mask=self.xy_padding_mask,
+ )
+ xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
+ if self.prefix_mode == 4:
+ prefix_len = 0
+ logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)
+
+ total_length = (y_lens).sum().type(torch.float32)
+ nar_loss = F.cross_entropy(
+ logits,
+ targets,
+ ignore_index=self.audio_token_num,
+ reduction=reduction,
+ ) * (total_length / (total_length - prefix_len * x.shape[0]))
+ nar_metrics = (
+ self.nar_accuracy_metric(
+ F.pad(
+ logits.detach(),
+ (0, 0, 0, 1, 0, 0),
+ value=logits.min().cpu().item(),
+ ),
+ targets,
+ ).item()
+ * total_length
+ )
+ return nar_loss, nar_metrics
+
+ def inference(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: torch.Tensor,
+ enroll_x_lens: torch.Tensor,
+ top_k: int = -100,
+ temperature: float = 1.0,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (1, S).
+ x_lens:
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
+ before padding.
+ y:
+ A 3-D tensor of shape (1, T, 8).
+ top_k: (`optional`) int
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
+ temperature: (`optional`) float
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
+ Returns:
+ Return the predicted audio code matrix.
+ """
+ assert x.ndim == 2, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.ndim == 3, y.shape
+ assert y.shape[0] == 1, y.shape
+
+ assert torch.all(x_lens > 0)
+
+ text = x
+ x = self.ar_text_embedding(text)
+ x = self.ar_text_prenet(x)
+ x = self.ar_text_position(x)
+
+ text_len = x_lens.max()
+ prompts = y
+ prefix_len = y.shape[1]
+
+ # AR Decoder
+ y = prompts[..., 0]
+ if self.ar_audio_prepend_bos:
+ y = F.pad(y, (1, 0), value=self.audio_token_num + 1)
+
+ x_len = x_lens.max()
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+
+ while True:
+ y_emb = self.ar_audio_embedding(y)
+ y_emb = self.ar_audio_prenet(y_emb)
+ y_pos = self.ar_audio_position(y_emb)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ y_len = y.shape[1]
+ x_attn_mask_pad = F.pad(
+ x_attn_mask,
+ (0, y_len),
+ value=True,
+ )
+ y_attn_mask = F.pad(
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+ (x_len, 0),
+ value=False,
+ )
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
+ y.device
+ )
+
+ xy_dec, _ = self.ar_decoder(
+ (xy_pos, None),
+ mask=xy_attn_mask,
+ )
+ logits = self.ar_predict_layer(xy_dec[:, -1])
+ samples = topk_sampling(
+ logits, top_k=top_k, top_p=1.0, temperature=temperature
+ )
+
+ if (
+ torch.argmax(logits, dim=-1)[0] == self.audio_token_num
+ or samples[0, 0] == self.audio_token_num
+ or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
+ ):
+ if prompts.shape[1] == y.shape[1]:
+ raise SyntaxError("well trained model shouldn't reach here.")
+
+ break
+
+ y = torch.concat([y, samples], dim=1)
+
+ codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
+ if self.num_quantizers == 1:
+ return torch.stack(codes, dim=-1)
+
+ # Non-AR Decoders
+ y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])
+
+ if self.prefix_mode in [2, 4]:
+ enrolled_len = enroll_x_lens.max().item()
+ # SOS + Synthesis Text + EOS
+ text = torch.concat(
+ [
+ text[:, :1],
+ text[:, enrolled_len - 1 :],
+ ],
+ dim=1,
+ )
+ text_len = text_len - (enrolled_len - 2)
+ assert text.shape[0] == 1
+
+ x = self.nar_text_embedding(text)
+ x = self.nar_text_prenet(x)
+ x = self.nar_text_position(x)
+
+ if self.prefix_mode == 0:
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < self.num_quantizers - 2:
+ y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+ else:
+ for j in range(1, self.num_quantizers):
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
+
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < self.num_quantizers - 2:
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+
+ assert len(codes) == self.num_quantizers
+ return torch.stack(codes, dim=-1)
+
+ def continual(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (1, S).
+ x_lens:
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
+ before padding.
+ y:
+ A 3-D tensor of shape (1, T, 8).
+ Returns:
+ Return the predicted audio code matrix.
+ """
+ assert x.ndim == 2, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.ndim == 3, y.shape
+ assert y.shape[0] == 1, y.shape
+
+ assert torch.all(x_lens > 0)
+ assert self.num_quantizers == 8
+
+ text = x
+ x = self.ar_text_embedding(text)
+ x = self.ar_text_prenet(x)
+ x = self.ar_text_position(x)
+
+ text_len = x_lens.max()
+
+ prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
+
+ # AR Decoder
+ prompts = y[:, :prefix_len]
+
+ codes = [y[:, prefix_len:, 0]]
+ # Non-AR Decoders
+ x = self.nar_text_embedding(text)
+ x = self.nar_text_prenet(x)
+ x = self.nar_text_position(x)
+
+ y_emb = self.nar_audio_embeddings[0](y[..., 0])
+
+ if self.prefix_mode == 0:
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_position(y_emb)
+ y_pos = self.nar_audio_prenet(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < 6:
+ y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+ else:
+ for j in range(1, 8):
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
+
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < 6:
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+
+ assert len(codes) == 8
+ return torch.stack(codes, dim=-1)
+
+ def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
+ assert stage > 0
+ if stage == 1:
+ for name, param in self.named_parameters():
+ if name.startswith("ar_"):
+ yield param
+
+ if stage == 2:
+ for name, param in self.named_parameters():
+ if name.startswith("nar_"):
+ yield param
+
+ def stage_named_parameters(
+ self, stage: int = 1
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
+ assert stage > 0
+ if stage == 1:
+ for pair in self.named_parameters():
+ if pair[0].startswith("ar_"):
+ yield pair
+
+ if stage == 2:
+ for pair in self.named_parameters():
+ if pair[0].startswith("nar_"):
+ yield pair
+
+ def pad_y_eos(self, y, y_mask_int, eos_id):
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
+ y_mask_int, (0, 1), value=1
+ )
+ if self.ar_audio_prepend_bos:
+ return (
+ F.pad(targets[:, :-1], (1, 0), value=self.audio_token_num + 1),
+ targets,
+ )
+
+ return targets[:, :-1], targets[:, 1:]
+
+ def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
+ # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
+ # from the same utterance.
+ # We implement this differently.
+ if self.prefix_mode == 0:
+ # no prefix
+ prefix_len = 0
+ y_emb = self.nar_audio_embeddings[0](y)
+ for j in range(1, nar_stage):
+ # Formula (4) (5)
+ y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
+ elif self.prefix_mode == 1:
+ # prefix at begining
+ int_low = (0.25 * y_lens.min()).type(torch.int64).item()
+ prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
+ prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
+
+ y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
+ y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
+ for j in range(1, self.num_quantizers):
+ y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
+ if j < nar_stage:
+ y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
+ elif self.prefix_mode in [2, 4]:
+ if self.prefix_mode == 2:
+ # random prefix
+ prefix_len = min(225, int(0.25 * y_lens.min().item()))
+
+ y_prompts_codes = []
+ for b in range(codes.shape[0]):
+ start = self.rng.randint(0, y_lens[b].item() - prefix_len)
+ y_prompts_codes.append(
+ torch.clone(codes[b, start : start + prefix_len])
+ )
+ codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
+ y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
+ else:
+ prefix_len = y_prompts_codes.shape[1]
+
+ y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
+ y_emb = self.nar_audio_embeddings[0](y)
+ for j in range(1, self.num_quantizers):
+ y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
+ if j < nar_stage:
+ y_emb += self.nar_audio_embeddings[j](codes[..., j])
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
+ else:
+ raise ValueError
+
+ return y_emb, prefix_len
diff --git a/models/tts/valle/valle_dataset.py b/models/tts/valle/valle_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e6055089b44097a3574279ee2f50327bc09fe65
--- /dev/null
+++ b/models/tts/valle/valle_dataset.py
@@ -0,0 +1,291 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.tts.base.tts_dataset import (
+ TTSDataset,
+ TTSCollator,
+ TTSTestDataset,
+ TTSTestCollator,
+)
+
+from torch.utils.data.sampler import (
+ BatchSampler,
+ RandomSampler,
+ SequentialSampler,
+)
+
+from utils.tokenizer import tokenize_audio
+
+
+class VALLEDataset(TTSDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ super().__init__(cfg, dataset, is_valid=is_valid)
+
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+
+ assert isinstance(dataset, str)
+
+ assert cfg.preprocess.use_acoustic_token == True
+ if cfg.preprocess.use_acoustic_token:
+ self.utt2acousticToken_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2acousticToken_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.acoustic_token_dir, # code
+ uid + ".npy",
+ )
+
+ self.all_num_frames = []
+ for i in range(len(self.metadata)):
+ self.all_num_frames.append(self.metadata[i]["Duration"])
+ self.num_frame_sorted = np.array(sorted(self.all_num_frames))
+ self.num_frame_indices = np.array(
+ sorted(
+ range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
+ )
+ )
+
+ def __len__(self):
+ return super().__len__()
+
+ def get_metadata(self):
+ metadata_filter = []
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+ for utt_info in metadata:
+ duration = utt_info["Duration"]
+ if (
+ duration >= self.cfg.preprocess.max_duration
+ or duration <= self.cfg.preprocess.min_duration
+ ):
+ continue
+ metadata_filter.append(utt_info)
+
+ return metadata_filter
+
+ def get_dur(self, idx):
+ utt_info = self.metadata[idx]
+ return utt_info["Duration"]
+
+ def __getitem__(self, index):
+ single_feature = super().__getitem__(index)
+
+ utt_info = self.metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ # acoustic token
+ if self.cfg.preprocess.use_acoustic_token:
+ acoustic_token = np.load(self.utt2acousticToken_path[utt])
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = acoustic_token.shape[0]
+ single_feature["acoustic_token"] = acoustic_token # [T, 8]
+
+ return single_feature
+
+ def get_num_frames(self, index):
+ utt_info = self.metadata[index]
+ return int(
+ utt_info["Duration"]
+ * (self.cfg.preprocess.sample_rate // self.cfg.preprocess.codec_hop_size)
+ )
+
+
+class VALLECollator(TTSCollator):
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ def __call__(self, batch):
+ parsed_batch_features = super().__call__(batch)
+ return parsed_batch_features
+
+
+class VALLETestDataset(TTSTestDataset):
+ def __init__(self, args, cfg):
+ super().__init__(args, cfg)
+
+ # prepare data
+ assert cfg.preprocess.use_acoustic_token == True
+ if cfg.preprocess.use_acoustic_token:
+ self.utt2acousticToken = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ # extract acoustic token
+ audio_file = utt_info["Audio_pormpt_path"]
+ encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
+ audio_prompt_token = (
+ encoded_frames[0][0].transpose(2, 1).squeeze(0).cpu().numpy()
+ )
+ self.utt2acousticToken[utt] = audio_prompt_token
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ # acoustic token
+ if self.cfg.preprocess.use_acoustic_token:
+ acoustic_token = self.utt2acousticToken[utt]
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = acoustic_token.shape[0]
+ single_feature["acoustic_token"] = acoustic_token # [T, 8]
+
+ # phone sequence todo
+ if self.cfg.preprocess.use_phone:
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
+ single_feature["phone_len"] = len(self.utt2seq[utt])
+ single_feature["pmt_phone_seq"] = np.array(self.utt2pmtseq[utt])
+ single_feature["pmt_phone_len"] = len(self.utt2pmtseq[utt])
+
+ return single_feature
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+ return metadata
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class VALLETestCollator(TTSTestCollator):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "phone_len":
+ packed_batch_features["phone_len"] = torch.LongTensor(
+ [b["phone_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["phn_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "pmt_phone_len":
+ packed_batch_features["pmt_phone_len"] = torch.LongTensor(
+ [b["pmt_phone_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["pmt_phone_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["pmt_phone_len_mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "audio_len":
+ packed_batch_features["audio_len"] = torch.LongTensor(
+ [b["audio_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
+ ]
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
+
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ if len(batch) == 0:
+ return 0
+ if len(batch) == max_sentences:
+ return 1
+ if num_tokens > max_tokens:
+ return 1
+ return 0
+
+
+def batch_by_size(
+ indices,
+ num_tokens_fn,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+):
+ """
+ Yield mini-batches of indices bucketed by size. Batches may contain
+ sequences of different lengths.
+
+ Args:
+ indices (List[int]): ordered list of dataset indices
+ num_tokens_fn (callable): function that returns the number of tokens at
+ a given index
+ max_tokens (int, optional): max number of tokens in each batch
+ (default: None).
+ max_sentences (int, optional): max number of sentences in each
+ batch (default: None).
+ required_batch_size_multiple (int, optional): require batch size to
+ be a multiple of N (default: 1).
+ """
+ bsz_mult = required_batch_size_multiple
+
+ sample_len = 0
+ sample_lens = []
+ batch = []
+ batches = []
+ for i in range(len(indices)):
+ idx = indices[i]
+ num_tokens = num_tokens_fn(idx)
+ sample_lens.append(num_tokens)
+ sample_len = max(sample_len, num_tokens)
+
+ assert (
+ sample_len <= max_tokens
+ ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
+ idx, sample_len, max_tokens
+ )
+ num_tokens = (len(batch) + 1) * sample_len
+
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ mod_len = max(
+ bsz_mult * (len(batch) // bsz_mult),
+ len(batch) % bsz_mult,
+ )
+ batches.append(batch[:mod_len])
+ batch = batch[mod_len:]
+ sample_lens = sample_lens[mod_len:]
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+ batch.append(idx)
+ if len(batch) > 0:
+ batches.append(batch)
+ return batches
diff --git a/models/tts/valle/valle_inference.py b/models/tts/valle/valle_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2735057b87302c2132b5ef94c67d4df45ac3fcbb
--- /dev/null
+++ b/models/tts/valle/valle_inference.py
@@ -0,0 +1,237 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import torch
+import torchaudio
+import argparse
+
+
+from text.g2p_module import G2PModule
+from utils.tokenizer import AudioTokenizer, tokenize_audio
+from models.tts.valle.valle import VALLE
+from models.tts.base.tts_inferece import TTSInference
+from models.tts.valle.valle_dataset import VALLETestDataset, VALLETestCollator
+from processors.phone_extractor import phoneExtractor
+from text.text_token_collation import phoneIDCollation
+
+
+class VALLEInference(TTSInference):
+ def __init__(self, args=None, cfg=None):
+ TTSInference.__init__(self, args, cfg)
+
+ self.g2p_module = G2PModule(backend=self.cfg.preprocess.phone_extractor)
+ text_token_path = os.path.join(
+ cfg.preprocess.processed_dir, cfg.dataset[0], cfg.preprocess.symbols_dict
+ )
+ self.audio_tokenizer = AudioTokenizer()
+
+ def _build_model(self):
+ model = VALLE(self.cfg.model)
+ return model
+
+ def _build_test_dataset(self):
+ return VALLETestDataset, VALLETestCollator
+
+ def inference_one_clip(self, text, text_prompt, audio_file, save_name="pred"):
+ # get phone symbol file
+ phone_symbol_file = None
+ if self.cfg.preprocess.phone_extractor != "lexicon":
+ phone_symbol_file = os.path.join(
+ self.exp_dir, self.cfg.preprocess.symbols_dict
+ )
+ assert os.path.exists(phone_symbol_file)
+ # convert text to phone sequence
+ phone_extractor = phoneExtractor(self.cfg)
+ # convert phone sequence to phone id sequence
+ phon_id_collator = phoneIDCollation(
+ self.cfg, symbols_dict_file=phone_symbol_file
+ )
+
+ text = f"{text_prompt} {text}".strip()
+ phone_seq = phone_extractor.extract_phone(text) # phone_seq: list
+ phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
+ phone_id_seq_len = torch.IntTensor([len(phone_id_seq)]).to(self.device)
+
+ # convert phone sequence to phone id sequence
+ phone_id_seq = np.array([phone_id_seq])
+ phone_id_seq = torch.from_numpy(phone_id_seq).to(self.device)
+
+ # extract acoustic token
+ encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
+ audio_prompt_token = encoded_frames[0][0].transpose(2, 1).to(self.device)
+
+ # copysyn
+ if self.args.copysyn:
+ samples = self.audio_tokenizer.decode(encoded_frames)
+ audio_copysyn = samples[0].cpu().detach()
+
+ out_path = os.path.join(
+ self.args.output_dir, self.infer_type, f"{save_name}_copysyn.wav"
+ )
+ torchaudio.save(out_path, audio_copysyn, self.cfg.preprocess.sampling_rate)
+
+ if self.args.continual:
+ encoded_frames = self.model.continual(
+ phone_id_seq,
+ phone_id_seq_len,
+ audio_prompt_token,
+ )
+ else:
+ enroll_x_lens = None
+ if text_prompt:
+ # prompt_phone_seq = tokenize_text(self.g2p_module, text=f"{text_prompt}".strip())
+ # _, enroll_x_lens = self.text_tokenizer.get_token_id_seq(prompt_phone_seq)
+
+ text = f"{text_prompt}".strip()
+ prompt_phone_seq = phone_extractor.extract_phone(
+ text
+ ) # phone_seq: list
+ prompt_phone_id_seq = phon_id_collator.get_phone_id_sequence(
+ self.cfg, prompt_phone_seq
+ )
+ prompt_phone_id_seq_len = torch.IntTensor(
+ [len(prompt_phone_id_seq)]
+ ).to(self.device)
+
+ encoded_frames = self.model.inference(
+ phone_id_seq,
+ phone_id_seq_len,
+ audio_prompt_token,
+ enroll_x_lens=prompt_phone_id_seq_len,
+ top_k=self.args.top_k,
+ temperature=self.args.temperature,
+ )
+
+ samples = self.audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
+
+ audio = samples[0].squeeze(0).cpu().detach()
+
+ return audio
+
+ def inference_for_single_utterance(self):
+ text = self.args.text
+ text_prompt = self.args.text_prompt
+ audio_file = self.args.audio_prompt
+
+ if not self.args.continual:
+ assert text != ""
+ else:
+ text = ""
+ assert text_prompt != ""
+ assert audio_file != ""
+
+ audio = self.inference_one_clip(text, text_prompt, audio_file)
+
+ return audio
+
+ def inference_for_batches(self):
+ test_list_file = self.args.test_list_file
+ assert test_list_file is not None
+
+ pred_res = []
+ with open(test_list_file, "r") as fin:
+ for idx, line in enumerate(fin.readlines()):
+ fields = line.strip().split("|")
+ if self.args.continual:
+ assert len(fields) == 2
+ text_prompt, audio_prompt_path = fields
+ text = ""
+ else:
+ assert len(fields) == 3
+ text_prompt, audio_prompt_path, text = fields
+
+ audio = self.inference_one_clip(
+ text, text_prompt, audio_prompt_path, str(idx)
+ )
+ pred_res.append(audio)
+
+ return pred_res
+
+ """
+ TODO: batch inference
+ ###### Construct test_batch ######
+ n_batch = len(self.test_dataloader)
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
+ print(
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
+ now, self.test_batch_size, n_batch
+ )
+ )
+
+ ###### Inference for each batch ######
+ pred_res = []
+ with torch.no_grad():
+ for i, batch_data in enumerate(
+ self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
+ ):
+ if self.args.continual:
+ encoded_frames = self.model.continual(
+ batch_data["phone_seq"],
+ batch_data["phone_len"],
+ batch_data["acoustic_token"],
+ )
+ else:
+ encoded_frames = self.model.inference(
+ batch_data["phone_seq"],
+ batch_data["phone_len"],
+ batch_data["acoustic_token"],
+ enroll_x_lens=batch_data["pmt_phone_len"],
+ top_k=self.args.top_k,
+ temperature=self.args.temperature
+ )
+
+ samples = self.audio_tokenizer.decode(
+ [(encoded_frames.transpose(2, 1), None)]
+ )
+
+
+ for idx in range(samples.size(0)):
+ audio = samples[idx].cpu()
+ pred_res.append(audio)
+
+ return pred_res
+ """
+
+ def add_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--text_prompt",
+ type=str,
+ default="",
+ help="Text prompt that should be aligned with --audio_prompt.",
+ )
+
+ parser.add_argument(
+ "--audio_prompt",
+ type=str,
+ default="",
+ help="Audio prompt that should be aligned with --text_prompt.",
+ )
+ parser.add_argument(
+ "--top-k",
+ type=int,
+ default=-100,
+ help="Whether AR Decoder do top_k(if > 0) sampling.",
+ )
+
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=1.0,
+ help="The temperature of AR Decoder top_k sampling.",
+ )
+
+ parser.add_argument(
+ "--continual",
+ action="store_true",
+ help="Inference for continual task.",
+ )
+
+ parser.add_argument(
+ "--copysyn",
+ action="store_true",
+ help="Copysyn: generate audio by decoder of the original audio tokenizer.",
+ )
diff --git a/models/tts/valle/valle_trainer.py b/models/tts/valle/valle_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc3b363539ba63d951d6f054b2b597a0eaec637a
--- /dev/null
+++ b/models/tts/valle/valle_trainer.py
@@ -0,0 +1,367 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+from tqdm import tqdm
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+from torch.nn.parallel import DistributedDataParallel
+from optimizer.optimizers import Eve, ScaledAdam
+from schedulers.scheduler import NoamScheduler, Eden
+from models.tts.valle.valle_dataset import (
+ VALLEDataset,
+ VALLECollator,
+ batch_by_size,
+)
+from models.base.base_sampler import VariableSampler
+from models.tts.base import TTSTrainer
+from models.tts.valle.valle import VALLE
+import diffusers
+
+
+class VALLETrainer(TTSTrainer):
+ def __init__(self, args, cfg):
+ TTSTrainer.__init__(self, args, cfg)
+
+ def _build_model(self):
+ model = VALLE(self.cfg.model)
+
+ return model
+
+ def _build_dataset(self):
+ return VALLEDataset, VALLECollator
+
+ def _build_optimizer(self):
+ if self.args.train_stage:
+ if isinstance(self.model, DistributedDataParallel):
+ model = self.model.module
+ else:
+ model = self.model
+ model_parameters = model.stage_parameters(self.args.train_stage)
+ else:
+ model_parameters = self.model.parameters()
+
+ if self.cfg.train.optimizer == "ScaledAdam":
+ parameters_names = []
+ if self.args.train_stage != 0:
+ parameters_names.append(
+ [
+ name_param_pair[0]
+ for name_param_pair in model.stage_named_parameters(
+ self.args.train_stage
+ )
+ ]
+ )
+ else:
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+
+ optimizer = ScaledAdam(
+ model_parameters,
+ lr=self.cfg.train.base_lr,
+ betas=(0.9, 0.95),
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ show_dominant_parameters=False,
+ clipping_update_period=1000,
+ )
+ elif self.cfg.train.optimizer == "Eve":
+ optimizer = Eve(
+ model_parameters,
+ lr=self.cfg.train.base_lr,
+ betas=(0.9, 0.98),
+ target_rms=0.1,
+ )
+ elif self.cfg.train.optimizer == "AdamW":
+ optimizer = torch.optim.AdamW(
+ model_parameters,
+ lr=self.cfg.train.base_lr,
+ betas=(0.9, 0.95),
+ weight_decay=1e-2,
+ eps=1e-8,
+ )
+ elif self.cfg.train.optimizer == "Adam":
+ optimizer = torch.optim.Adam(
+ model_parameters,
+ lr=self.cfg.train.base_lr,
+ betas=(0.9, 0.95),
+ eps=1e-8,
+ )
+ else:
+ raise NotImplementedError()
+
+ return optimizer
+
+ def _build_scheduler(self):
+ if self.cfg.train.scheduler.lower() == "eden":
+ scheduler = Eden(
+ self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps
+ )
+ elif self.cfg.train.scheduler.lower() == "noam":
+ scheduler = NoamScheduler(
+ self.cfg.train.base_lr,
+ self.optimizer,
+ self.cfg.model.decoder_dim,
+ warmup_steps=self.cfg.train.warmup_steps,
+ )
+ elif self.cfg.train.scheduler.lower() == "cosine":
+ from diffusers.optimization import get_cosine_schedule_with_warmup
+
+ scheduler = get_cosine_schedule_with_warmup(
+ self.optimizer,
+ num_warmup_steps=self.cfg.train.warmup_steps
+ * self.accelerator.num_processes,
+ num_training_steps=self.cfg.train.total_training_steps
+ * self.accelerator.num_processes,
+ )
+ else:
+ raise NotImplementedError(f"{self.cfg.train.scheduler}")
+
+ return scheduler
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key].train()
+ else:
+ self.model.train()
+
+ epoch_sum_loss: float = 0.0
+ epoch_losses: dict = {}
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ total_loss, train_losses = self._train_step(batch)
+ self.accelerator.backward(total_loss)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.batch_count += 1
+
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
+ for k in range(self.cfg.train.gradient_accumulation_step):
+ if isinstance(self.scheduler, Eden):
+ self.scheduler.step_batch(self.step)
+ else:
+ self.scheduler.step()
+
+ epoch_sum_loss += total_loss.detach().cpu().item()
+
+ if isinstance(train_losses, dict):
+ for key, value in train_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ if isinstance(train_losses, dict):
+ for key, loss in train_losses.items():
+ self.accelerator.log(
+ {"Step/Train {}".format(key): "{:.6f}".format(loss)},
+ step=self.step,
+ )
+ else:
+ self.accelerator.log(
+ {"Step/Train Loss": loss},
+ step=self.step,
+ )
+
+ self.accelerator.log(
+ {"Step/lr": self.scheduler.get_last_lr()[0]},
+ step=self.step,
+ )
+
+ # print loss every log_epoch_step steps
+ # if epoch_step % self.cfg.train.log_epoch_step == 0:
+ # for key, loss in train_losses.items():
+ # self.logger.info("Step/Train {}: {:.6f}".format(key, loss))
+ # print("Step/Train {}: {:.6f}".format(key, loss))
+
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ epoch_sum_loss = (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ return epoch_sum_loss, epoch_losses
+
+ def _train_step(self, batch, is_training=True):
+ text_tokens = batch["phone_seq"].to(self.device)
+ text_tokens_lens = batch["phone_len"].to(self.device)
+ assert text_tokens.ndim == 2
+
+ audio_features = batch["acoustic_token"].to(self.device)
+ audio_features_lens = batch["target_len"].to(self.device)
+ assert audio_features.ndim == 3
+
+ with torch.set_grad_enabled(is_training):
+ loss, losses = self.model(
+ x=text_tokens,
+ x_lens=text_tokens_lens,
+ y=audio_features,
+ y_lens=audio_features_lens,
+ train_stage=self.args.train_stage,
+ )
+
+ assert loss.requires_grad == is_training
+
+ loss_dict = {}
+ frames_sum = (audio_features_lens).sum()
+
+ avg_loss = loss / frames_sum
+
+ loss_dict["loss"] = avg_loss.detach().cpu().item()
+ for l in losses:
+ loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item()
+
+ return avg_loss, loss_dict
+
+ def _valid_step(self, batch):
+ valid_losses = {}
+ total_loss = 0
+ valid_stats = {}
+
+ total_loss, valid_losses = self._train_step(
+ batch=batch,
+ is_training=False,
+ )
+ assert total_loss.requires_grad is False
+
+ total_loss = total_loss.detach().cpu().item()
+
+ return total_loss, valid_losses, valid_stats
+
+ def _build_dataloader(self):
+ if not self.cfg.train.use_dynamic_batchsize:
+ return super()._build_dataloader()
+ if len(self.cfg.dataset) > 1:
+ raise Exception("use_dynamic_batchsize only supports single dataset now.")
+ Dataset, Collator = self._build_dataset()
+ train_dataset = Dataset(
+ self.cfg, self.cfg.dataset[0], is_valid=False
+ ) # TODO: support use_dynamic_batchsize for more than one datasets.
+ train_collate = Collator(self.cfg)
+ batch_sampler = batch_by_size(
+ train_dataset.num_frame_indices,
+ train_dataset.get_num_frames,
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
+ max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
+ required_batch_size_multiple=self.accelerator.num_processes,
+ )
+ np.random.seed(1234)
+ np.random.shuffle(batch_sampler)
+ print(batch_sampler[:1])
+ batches = [
+ x[self.accelerator.local_process_index :: self.accelerator.num_processes]
+ for x in batch_sampler
+ if len(x) % self.accelerator.num_processes == 0
+ ]
+
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ batch_sampler=VariableSampler(
+ batches, drop_last=False, use_random_sampler=True
+ ),
+ pin_memory=False,
+ )
+ self.accelerator.wait_for_everyone()
+
+ valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
+ valid_collate = Collator(self.cfg)
+ batch_sampler = batch_by_size(
+ valid_dataset.num_frame_indices,
+ valid_dataset.get_num_frames,
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
+ max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
+ required_batch_size_multiple=self.accelerator.num_processes,
+ )
+ batches = [
+ x[self.accelerator.local_process_index :: self.accelerator.num_processes]
+ for x in batch_sampler
+ if len(x) % self.accelerator.num_processes == 0
+ ]
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ batch_sampler=VariableSampler(batches, drop_last=False),
+ pin_memory=False,
+ )
+ self.accelerator.wait_for_everyone()
+
+ return train_loader, valid_loader
+
+ def _accelerator_prepare(self):
+ if not self.cfg.train.use_dynamic_batchsize:
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ )
+
+ if isinstance(self.model, dict):
+ for key in self.model.keys():
+ self.model[key] = self.accelerator.prepare(self.model[key])
+ else:
+ self.model = self.accelerator.prepare(self.model)
+
+ if isinstance(self.optimizer, dict):
+ for key in self.optimizer.keys():
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
+ else:
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if isinstance(self.scheduler, dict):
+ for key in self.scheduler.keys():
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
+ else:
+ self.scheduler = self.accelerator.prepare(self.scheduler)
+
+ def add_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--train_stage",
+ type=int,
+ default="1",
+ help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
+ )
+ parser.add_argument(
+ "--ar_model_ckpt_dir",
+ type=str,
+ default=None,
+ help="Checkpoint for ar model ckeckpoint in the first training stage.",
+ )
diff --git a/models/tts/valle_v2/base_trainer.py b/models/tts/valle_v2/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f7f97b034495a0c861739073e5e405ba487831d
--- /dev/null
+++ b/models/tts/valle_v2/base_trainer.py
@@ -0,0 +1,810 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import random
+import shutil
+import time
+from abc import abstractmethod
+from pathlib import Path
+import math
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import ConcatDataset, DataLoader
+from tqdm import tqdm
+
+from models.base.base_sampler import build_samplers
+from optimizer.optimizers import NoamLR
+
+
+class MainProcessLogger:
+ def __init__(self, is_main_process=True, name=None, **kwargs):
+ import logging
+
+ if name is None:
+ logger = logging.getLogger(__name__)
+ else:
+ logger = logging.getLogger(name)
+ self.logger = logger
+ self.is_main_process = is_main_process
+
+ def info(self, msg):
+ if self.is_main_process:
+ print(msg)
+ # self.logger.info(msg)
+
+ def debug(self, msg):
+ if self.is_main_process:
+ print(msg)
+ # self.logger.debug(msg)
+
+ def warning(self, msg):
+ if self.is_main_process:
+ print(msg)
+ # self.logger.warning(msg)
+
+
+class BaseTrainer(object):
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
+
+ def __init__(self, args=None, cfg=None):
+ super().__init__()
+
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # init with accelerate
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Use accelerate logger for distributed training
+ with self.accelerator.main_process_first():
+ self.logger = MainProcessLogger(
+ self.accelerator.is_main_process,
+ name=args.exp_name,
+ log_level=args.log_level,
+ )
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # init counts
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check values
+ if self.accelerator.is_main_process:
+ self.__check_basic_configs()
+ # Set runtime configs
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.keep_last = [
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(args.seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {args.seed}")
+
+ # setup data_loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # setup model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
+ )
+ # optimizer & scheduler
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # accelerate prepare
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self._accelerator_prepare()
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # create criterion
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterion = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume or Finetune
+ with self.accelerator.main_process_first():
+ if args.resume:
+ if args.resume_from_ckpt_path == "":
+ ## Automatically resume according to the current exprimental name
+ self.logger.info(
+ "Automatically resuming from latest checkpoint in {}...".format(
+ self.checkpoint_dir
+ )
+ )
+ start = time.monotonic_ns()
+ ckpt_path = self._load_model(
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ else:
+ ## Resume from the given checkpoint path
+ if not os.path.exists(args.resume_from_ckpt_path):
+ raise ValueError(
+ "[Error] The resumed checkpoint path {} don't exist.".format(
+ args.resume_from_ckpt_path
+ )
+ )
+ self.logger.info(
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
+ )
+ start = time.monotonic_ns()
+ ckpt_path = self._load_model(
+ checkpoint_path=args.resume_from_ckpt_path,
+ resume_type=args.resume_type,
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # save config file path
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+ def _accelerator_prepare(self):
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.model,
+ self.optimizer,
+ self.scheduler,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.model,
+ self.optimizer,
+ self.scheduler,
+ )
+
+ ### Following are abstract methods that should be implemented in child classes ###
+ @abstractmethod
+ def _build_dataset(self):
+ r"""Build dataset for model training/validating/evaluating."""
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _build_criterion():
+ r"""Build criterion function for model loss calculation."""
+ pass
+
+ @abstractmethod
+ def _build_model(self):
+ r"""Build model for training/validating/evaluating."""
+ pass
+
+ @abstractmethod
+ def _forward_step(self, batch):
+ r"""One forward step of the neural network. This abstract method is trying to
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
+ However, for special case that using different forward step pattern for
+ training and validating, you could just override this method with ``pass`` and
+ implement ``_train_step`` and ``_valid_step`` separately.
+ """
+ pass
+
+ def save_checkpoint(self):
+ if self.accelerator.is_main_process:
+ keep_last = self.keep_last[0]
+ # 读取self.checkpoint_dir所有的folder
+ all_ckpts = os.listdir(self.checkpoint_dir)
+ all_ckpts = filter(lambda x: x.startswith("epoch"), all_ckpts)
+ all_ckpts = list(all_ckpts)
+ if len(all_ckpts) > keep_last:
+ # 只保留keep_last个的folder in self.checkpoint_dir, sort by step "epoch-{:04d}_step-{:07d}_loss-{:.6f}"
+ all_ckpts = sorted(
+ all_ckpts, key=lambda x: int(x.split("_")[1].split("-")[1])
+ )
+ for ckpt in all_ckpts[:-keep_last]:
+ shutil.rmtree(os.path.join(self.checkpoint_dir, ckpt))
+ checkpoint_filename = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, self.current_loss
+ )
+ path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+ self.logger.info("Saving state to {}...".format(path))
+ self.accelerator.save_state(path)
+ self.logger.info("Finished saving state.")
+
+ @abstractmethod
+ def _save_auxiliary_states(self):
+ r"""To save some auxiliary states when saving model's ckpt"""
+ pass
+
+ def echo_log(self, losses, mode="Training"):
+ message = [
+ "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
+ mode, self.epoch + 1, self.step, self.time_window.average
+ )
+ ]
+
+ for key in sorted(losses.keys()):
+ if isinstance(losses[key], dict):
+ for k, v in losses[key].items():
+ message.append(
+ str(k).split("/")[-1] + "=" + str(round(float(v), 5))
+ )
+ else:
+ message.append(
+ str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
+ )
+ self.logger.info(", ".join(message))
+
+ ### Abstract methods end ###
+
+ ### THIS IS MAIN ENTRY ###
+ def train_loop(self):
+ r"""Training loop. The public entry of training process."""
+ # Wait everyone to prepare before we move on
+ self.accelerator.wait_for_everyone()
+ # dump config file
+ if self.accelerator.is_main_process:
+ self.__dump_cfg(self.config_save_path)
+ self.model.train()
+ self.optimizer.zero_grad()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
+ ### It's inconvenient for the model with multiple losses
+ # Do training & validating epoch
+ train_loss = self._train_epoch()
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
+ valid_loss = self._valid_epoch()
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
+ self.accelerator.log(
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
+ step=self.epoch,
+ )
+
+ self.accelerator.wait_for_everyone()
+
+ # Update info for each epoch
+ self.epoch += 1
+
+ # Finish training and save final checkpoint
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ self.accelerator.save_state(
+ os.path.join(
+ self.checkpoint_dir,
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_loss
+ ),
+ )
+ )
+ self._save_auxiliary_states()
+
+ self.accelerator.end_training()
+
+ def get_lr(self, it):
+ # 1) linear warmup for warmup_iters steps
+ if it < self.cfg.train.scheduler.warmup_steps:
+ return self.cfg.train.adamw.lr * it / self.cfg.train.scheduler.warmup_steps
+ # 2) if it > lr_decay_iters, return min learning rate
+ if it > self.cfg.train.scheduler.total_steps:
+ return self.cfg.train.scheduler.min_lr
+ # 3) in between, use cosine decay down to min learning rate
+ decay_ratio = (it - self.cfg.train.scheduler.warmup_steps) / (
+ self.cfg.train.scheduler.total_steps - self.cfg.train.scheduler.warmup_steps
+ )
+ assert 0 <= decay_ratio <= 1
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
+ return self.cfg.train.scheduler.min_lr + coeff * (
+ self.cfg.train.adamw.lr - self.cfg.train.scheduler.min_lr
+ )
+
+ ### Following are methods that can be used directly in child classes ###
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.train()
+ epoch_sum_loss: float = 0.0
+ ema_loss = None
+
+ # profiler
+ start_this_step_time = time.time()
+ finish_last_step_time = time.time()
+
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ assert batch is not None
+
+ # start_this_step_time = time.time()
+ # print(f'load batch took: {start_this_step_time - finish_last_step_time:.6f}s')
+
+ # update learning rate
+ lr = self.get_lr(self.step)
+ for param_group in self.optimizer.param_groups:
+ param_group["lr"] = lr
+
+ # Do training step and BP
+ with self.accelerator.accumulate(self.model):
+ loss = self._train_step(batch)
+ self.current_loss = loss.item()
+ ema_loss = (
+ 0.99 * ema_loss + 0.01 * self.current_loss
+ if ema_loss is not None
+ else self.current_loss
+ )
+ self.accelerator.backward(loss)
+ if self.accelerator.sync_gradients:
+ self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.batch_count += 1
+
+ # if self.accelerator.is_main_process:
+ # print(self.current_loss)
+
+ if self.accelerator.sync_gradients:
+ if self.step % self.cfg.train.save_checkpoint_stride[0] == 0:
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process:
+ try:
+ self.save_checkpoint()
+ except:
+ self.logger.info("Failed to save checkpoint, resuming...")
+ if self.accelerator.is_main_process:
+ if self.step % 100 == 0:
+ self.logger.info(f"EMA Loss: {ema_loss:.6f}")
+ self.accelerator.log(
+ {
+ "Step/Train Loss": loss,
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
+ },
+ step=self.step,
+ )
+ epoch_sum_loss += loss
+ self.step += 1
+
+ # finish_last_step_time = time.time()
+ # print(f'load took: {finish_last_step_time - start_this_step_time:.6f}s')
+ return (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.eval()
+ epoch_sum_loss = 0.0
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ batch_loss = self._valid_step(batch)
+ epoch_sum_loss += batch_loss.item()
+
+ return epoch_sum_loss / len(self.valid_dataloader)
+
+ def _train_step(self, batch):
+ r"""Training forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_train_epoch`` for usage.
+ """
+ return self._forward_step(batch)
+
+ @torch.inference_mode()
+ def _valid_step(self, batch):
+ r"""Testing forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_test_epoch`` for usage.
+ """
+ return self._forward_step(batch)
+
+ def _load_model(
+ self,
+ checkpoint_dir: str = None,
+ checkpoint_path: str = None,
+ resume_type: str = "",
+ ):
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ try:
+ all_ckpts = os.listdir(checkpoint_dir)
+ all_ckpts = filter(lambda x: x.startswith("epoch"), all_ckpts)
+ ls = list(all_ckpts)
+ ls = [os.path.join(checkpoint_dir, i) for i in ls]
+ ls.sort(
+ key=lambda x: int(x.split("_")[-2].split("-")[-1]), reverse=True
+ )
+ checkpoint_path = ls[0]
+ self.logger.info("Resume from {}".format(checkpoint_path))
+ except Exception as e:
+ print(
+ "Failed to load checkpoint from {}, starting FROM SCRATCH...".format(
+ checkpoint_dir
+ )
+ )
+ return None
+
+ if resume_type in ["resume", ""]:
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
+ self.accelerator.load_state(input_dir=checkpoint_path)
+
+ # set epoch and step
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+
+ elif resume_type == "finetune":
+ # Load only the model weights
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune...")
+
+ else:
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
+
+ return checkpoint_path
+
+ # TODO: LEGACY CODE
+ def _build_dataloader(self):
+ Dataset, Collator = self._build_dataset()
+
+ # build dataset instance for each dataset and combine them by ConcatDataset
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = ConcatDataset(datasets_list)
+ train_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+
+ # Build valid dataloader
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = ConcatDataset(datasets_list)
+ valid_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ return train_loader, valid_loader
+
+ @staticmethod
+ def _set_random_seed(seed):
+ r"""Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ def _check_nan(self, loss, y_pred, y_gt):
+ if torch.any(torch.isnan(loss)):
+ self.logger.fatal("Fatal Error: Training is down since loss has Nan!")
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
+ if torch.any(torch.isnan(y_pred)):
+ self.logger.error(
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
+ )
+ else:
+ self.logger.debug(
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
+ )
+ if torch.any(torch.isnan(y_gt)):
+ self.logger.error(
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
+ )
+ else:
+ self.logger.debug(
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
+ )
+ if torch.any(torch.isnan(y_pred)):
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
+ else:
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
+ if torch.any(torch.isnan(y_gt)):
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
+ else:
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
+
+ # TODO: still OK to save tracking?
+ self.accelerator.end_training()
+ raise RuntimeError("Loss has Nan! See log for more info.")
+
+ ### Protected methods end ###
+
+ ## Following are private methods ##
+ ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
+ def _build_optimizer(self):
+ r"""Build optimizer for model."""
+ # Make case-insensitive matching
+ if self.cfg.train.optimizer.lower() == "adadelta":
+ optimizer = torch.optim.Adadelta(
+ self.model.parameters(), **self.cfg.train.adadelta
+ )
+ self.logger.info("Using Adadelta optimizer.")
+ elif self.cfg.train.optimizer.lower() == "adagrad":
+ optimizer = torch.optim.Adagrad(
+ self.model.parameters(), **self.cfg.train.adagrad
+ )
+ self.logger.info("Using Adagrad optimizer.")
+ elif self.cfg.train.optimizer.lower() == "adam":
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
+ self.logger.info("Using Adam optimizer.")
+ elif self.cfg.train.optimizer.lower() == "adamw":
+ optimizer = torch.optim.AdamW(
+ self.model.parameters(), **self.cfg.train.adamw
+ )
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
+ optimizer = torch.optim.SparseAdam(
+ self.model.parameters(), **self.cfg.train.sparseadam
+ )
+ elif self.cfg.train.optimizer.lower() == "adamax":
+ optimizer = torch.optim.Adamax(
+ self.model.parameters(), **self.cfg.train.adamax
+ )
+ elif self.cfg.train.optimizer.lower() == "asgd":
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
+ optimizer = torch.optim.LBFGS(
+ self.model.parameters(), **self.cfg.train.lbfgs
+ )
+ elif self.cfg.train.optimizer.lower() == "nadam":
+ optimizer = torch.optim.NAdam(
+ self.model.parameters(), **self.cfg.train.nadam
+ )
+ elif self.cfg.train.optimizer.lower() == "radam":
+ optimizer = torch.optim.RAdam(
+ self.model.parameters(), **self.cfg.train.radam
+ )
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
+ optimizer = torch.optim.RMSprop(
+ self.model.parameters(), **self.cfg.train.rmsprop
+ )
+ elif self.cfg.train.optimizer.lower() == "rprop":
+ optimizer = torch.optim.Rprop(
+ self.model.parameters(), **self.cfg.train.rprop
+ )
+ elif self.cfg.train.optimizer.lower() == "sgd":
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
+ else:
+ raise NotImplementedError(
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
+ )
+ return optimizer
+
+ def _build_scheduler(self):
+ r"""Build scheduler for optimizer."""
+ # Make case-insensitive matching
+ if self.cfg.train.scheduler.lower() == "lambdalr":
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self.optimizer, **self.cfg.train.lambdalr
+ )
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
+ self.optimizer, **self.cfg.train.multiplicativelr
+ )
+ elif self.cfg.train.scheduler.lower() == "steplr":
+ scheduler = torch.optim.lr_scheduler.StepLR(
+ self.optimizer, **self.cfg.train.steplr
+ )
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ self.optimizer, **self.cfg.train.multisteplr
+ )
+ elif self.cfg.train.scheduler.lower() == "constantlr":
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
+ self.optimizer, **self.cfg.train.constantlr
+ )
+ elif self.cfg.train.scheduler.lower() == "linearlr":
+ scheduler = torch.optim.lr_scheduler.LinearLR(
+ self.optimizer, **self.cfg.train.linearlr
+ )
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
+ self.optimizer, **self.cfg.train.exponentiallr
+ )
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
+ self.optimizer, **self.cfg.train.polynomiallr
+ )
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ self.optimizer, **self.cfg.train.cosineannealinglr
+ )
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
+ self.optimizer, **self.cfg.train.sequentiallr
+ )
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ self.optimizer, **self.cfg.train.reducelronplateau
+ )
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
+ self.optimizer, **self.cfg.train.cycliclr
+ )
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ self.optimizer, **self.cfg.train.onecyclelr
+ )
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
+ )
+ elif self.cfg.train.scheduler.lower() == "noamlr":
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
+ else:
+ raise NotImplementedError(
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
+ )
+ return scheduler
+
+ def _init_accelerator(self):
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir,
+ logging_dir=os.path.join(self.exp_dir, "log"),
+ )
+ from accelerate import DistributedDataParallelKwargs
+
+ kwargs = DistributedDataParallelKwargs(
+ find_unused_parameters=self.cfg.train.find_unused_parameters
+ )
+
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ kwargs_handlers=[kwargs],
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def __check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ # TODO: check other values
+
+ @staticmethod
+ def __count_parameters(model):
+ model_param = 0.0
+ if isinstance(model, dict):
+ for key, value in model.items():
+ model_param += sum(p.numel() for p in model[key].parameters())
+ else:
+ model_param = sum(p.numel() for p in model.parameters())
+ return model_param
+
+ def __dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+ @torch.inference_mode()
+ def test_loop(self):
+ pass
+
+ ### Private methods end ###
diff --git a/models/tts/valle_v2/g2p_processor.py b/models/tts/valle_v2/g2p_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..43807fb1236b342f524690489a9ac927d1b97594
--- /dev/null
+++ b/models/tts/valle_v2/g2p_processor.py
@@ -0,0 +1,363 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import numpy as np
+import os
+import torch
+import copy
+from g2p_en import G2p
+import re
+import unicodedata
+from g2p_en import G2p
+from g2p_en.expand import normalize_numbers
+
+g2p = G2p()
+
+PHONE_SET = [
+ "!",
+ ",",
+ ".",
+ ".B",
+ ":",
+ "",
+ "",
+ "",
+ "",
+ "?",
+ "AA0B",
+ "AA0E",
+ "AA0I",
+ "AA1B",
+ "AA1E",
+ "AA1I",
+ "AA2B",
+ "AA2E",
+ "AA2I",
+ "AE0B",
+ "AE0E",
+ "AE0I",
+ "AE1B",
+ "AE1E",
+ "AE1I",
+ "AE2B",
+ "AE2E",
+ "AE2I",
+ "AH0B",
+ "AH0E",
+ "AH0I",
+ "AH1B",
+ "AH1E",
+ "AH1I",
+ "AH2B",
+ "AH2E",
+ "AH2I",
+ "AO0B",
+ "AO0E",
+ "AO0I",
+ "AO1",
+ "AO1B",
+ "AO1E",
+ "AO1I",
+ "AO2B",
+ "AO2E",
+ "AO2I",
+ "AW0B",
+ "AW0E",
+ "AW0I",
+ "AW1B",
+ "AW1E",
+ "AW1I",
+ "AW2B",
+ "AW2E",
+ "AW2I",
+ "AY0B",
+ "AY0E",
+ "AY0I",
+ "AY1B",
+ "AY1E",
+ "AY1I",
+ "AY2B",
+ "AY2E",
+ "AY2I",
+ "BB",
+ "BE",
+ "BI",
+ "CHB",
+ "CHE",
+ "CHI",
+ "DB",
+ "DE",
+ "DHB",
+ "DHE",
+ "DHI",
+ "DI",
+ "EH0B",
+ "EH0E",
+ "EH0I",
+ "EH1B",
+ "EH1E",
+ "EH1I",
+ "EH2B",
+ "EH2E",
+ "EH2I",
+ "ER0B",
+ "ER0E",
+ "ER0I",
+ "ER1B",
+ "ER1E",
+ "ER1I",
+ "ER2B",
+ "ER2E",
+ "ER2I",
+ "EY0B",
+ "EY0E",
+ "EY0I",
+ "EY1B",
+ "EY1E",
+ "EY1I",
+ "EY2B",
+ "EY2E",
+ "EY2I",
+ "FB",
+ "FE",
+ "FI",
+ "GB",
+ "GE",
+ "GI",
+ "HHB",
+ "HHE",
+ "HHI",
+ "IH0B",
+ "IH0E",
+ "IH0I",
+ "IH1B",
+ "IH1E",
+ "IH1I",
+ "IH2B",
+ "IH2E",
+ "IH2I",
+ "IY0B",
+ "IY0E",
+ "IY0I",
+ "IY1B",
+ "IY1E",
+ "IY1I",
+ "IY2B",
+ "IY2E",
+ "IY2I",
+ "JHB",
+ "JHE",
+ "JHI",
+ "KB",
+ "KE",
+ "KI",
+ "L",
+ "LB",
+ "LE",
+ "LI",
+ "MB",
+ "ME",
+ "MI",
+ "NB",
+ "NE",
+ "NGB",
+ "NGE",
+ "NGI",
+ "NI",
+ "OW0B",
+ "OW0E",
+ "OW0I",
+ "OW1B",
+ "OW1E",
+ "OW1I",
+ "OW2B",
+ "OW2E",
+ "OW2I",
+ "OY0B",
+ "OY0E",
+ "OY0I",
+ "OY1B",
+ "OY1E",
+ "OY1I",
+ "OY2B",
+ "OY2E",
+ "OY2I",
+ "PB",
+ "PE",
+ "PI",
+ "RB",
+ "RE",
+ "RI",
+ "SB",
+ "SE",
+ "SHB",
+ "SHE",
+ "SHI",
+ "SI",
+ "TB",
+ "TE",
+ "THB",
+ "THE",
+ "THI",
+ "TI",
+ "UH0B",
+ "UH0E",
+ "UH0I",
+ "UH1B",
+ "UH2B",
+ "UH1E",
+ "UH1I",
+ "UH2E",
+ "UH2I",
+ "UW0B",
+ "UW0E",
+ "UW0I",
+ "UW1B",
+ "UW1E",
+ "UW1I",
+ "UW2B",
+ "UW2E",
+ "UW2I",
+ "VB",
+ "VE",
+ "VI",
+ "WB",
+ "WE",
+ "WI",
+ "YB",
+ "YE",
+ "YI",
+ "ZB",
+ "ZE",
+ "ZHB",
+ "ZHE",
+ "ZHI",
+ "ZI",
+ "|",
+]
+PHPONE2ID = {PHONE_SET[i]: i for i in range(len(PHONE_SET))}
+
+PUNCS = "!,.?;:"
+
+
+def is_sil_phoneme(p):
+ return p == "" or not p[0].isalpha()
+
+
+def add_bdr(txt_struct):
+ txt_struct_ = []
+ for i, ts in enumerate(txt_struct):
+ txt_struct_.append(ts)
+ if (
+ i != len(txt_struct) - 1
+ and not is_sil_phoneme(txt_struct[i][0])
+ and not is_sil_phoneme(txt_struct[i + 1][0])
+ ):
+ txt_struct_.append(["|", ["|"]])
+ return txt_struct_
+
+
+def preprocess_text(text):
+ text = normalize_numbers(text)
+ text = "".join(
+ char
+ for char in unicodedata.normalize("NFD", text)
+ if unicodedata.category(char) != "Mn"
+ ) # Strip accents
+ text = text.lower()
+ text = re.sub("['\"()]+", "", text)
+ text = re.sub("[-]+", " ", text)
+ text = re.sub(f"[^ a-z{PUNCS}]", "", text)
+ text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> !
+ text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
+ text = text.replace("i.e.", "that is")
+ text = text.replace("i.e.", "that is")
+ text = text.replace("etc.", "etc")
+ text = re.sub(f"([{PUNCS}])", r" ", text) # remove punctuations for now
+ text = re.sub(rf"\s+", r" ", text)
+ return text
+
+
+def postprocess(txt_struct):
+ while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[0][0]):
+ txt_struct = txt_struct[1:]
+ while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[-1][0]):
+ txt_struct = txt_struct[:-1]
+ txt_struct = add_bdr(txt_struct)
+ txt_struct = [["", [""]]] + txt_struct + [["", [""]]]
+ return txt_struct
+
+
+def process(txt, g2p):
+ txt = preprocess_text(txt).strip()
+ phs = g2p(txt)
+ txt_struct = [[w, []] for w in txt.split(" ")]
+ i_word = 0
+ for p in phs:
+ if p == " ":
+ i_word += 1
+ else:
+ txt_struct[i_word][1].append(p)
+
+ txt_struct_ret = copy.deepcopy(txt_struct)
+
+ for i_word in range(len(txt_struct)):
+ if not is_sil_phoneme(txt_struct[i_word][0]):
+ if len(txt_struct[i_word][1]) > 1:
+ txt_struct_ret[i_word][1][0] += "B"
+ for i in range(1, len(txt_struct[i_word][1]) - 1):
+ txt_struct_ret[i_word][1][i] += "I"
+ txt_struct_ret[i_word][1][-1] += "E"
+ else:
+ txt_struct_ret[i_word][1][0] += "B"
+
+ txt_struct_ret = postprocess(txt_struct_ret)
+
+ return txt_struct_ret, txt
+
+
+def test():
+ g2p = G2p()
+ txt = "This is a test sentence."
+ txt_struct, txt = process(txt, g2p)
+ print(txt_struct)
+ print(txt)
+ phone_seq = [p for w in txt_struct for p in w[1]]
+ print(phone_seq)
+ phone_id = [PHPONE2ID[p] for p in phone_seq]
+ print(phone_id)
+
+
+class G2pProcessor:
+ def __init__(self):
+ self.g2p = G2p()
+
+ def __call__(self, txt, lang="en"):
+ return self.txt2phoneid(txt)
+
+ def txt2phoneid(self, txt):
+ txt_struct, txt = process(txt, self.g2p)
+ phone_seq = [p for w in txt_struct for p in w[1]]
+ phone_id = [PHPONE2ID[p] for p in phone_seq]
+ return None, phone_id
+
+ def phoneid2txt(self, phone_id):
+ txt = []
+ for i in phone_id:
+ txt.append(PHONE_SET[i])
+ return txt
+
+
+if __name__ == "__main__":
+ g2p = G2pProcessor()
+ txt = "This is a test sentence."
+ phoneid = g2p.txt2phoneid(txt)[1]
+ # output: [5, 73, 118, 175, 218, 116, 213, 218, 28, 218, 180, 82, 179, 181, 218, 174, 82, 149, 185, 30, 149, 175, 6]
+ # print(phoneid)
+ print(g2p.phoneid2txt(phoneid))
+ # output: ['', 'DHB', 'IH1I', 'SE', '|', 'IH1B', 'ZE', '|', 'AH0B', '|', 'TB', 'EH1I', 'SI', 'TE', '|', 'SB', 'EH1I', 'NI', 'TI', 'AH0I', 'NI', 'SE', '']
+ print(len(PHONE_SET))
+ # output: 219
diff --git a/models/tts/valle_v2/libritts_dataset.py b/models/tts/valle_v2/libritts_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b40e6f1648071b036a7f3d80a7791944518e81
--- /dev/null
+++ b/models/tts/valle_v2/libritts_dataset.py
@@ -0,0 +1,271 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from tqdm import tqdm
+from g2p_en import G2p
+import librosa
+from torch.utils.data import Dataset
+import pandas as pd
+import time
+import io
+
+SAMPLE_RATE = 16000
+# g2p
+from .g2p_processor import G2pProcessor
+
+phonemizer_g2p = G2pProcessor()
+
+
+class VALLEDataset(Dataset):
+ def __init__(self, args):
+ print(f"Initializing VALLEDataset")
+ self.dataset_list = args.dataset_list
+
+ print(f"using sampling rate {SAMPLE_RATE}")
+
+ # set dataframe clumn name
+ book_col_name = [
+ "ID",
+ "Original_text",
+ "Normalized_text",
+ "Aligned_or_not",
+ "Start_time",
+ "End_time",
+ "Signal_to_noise_ratio",
+ ]
+ trans_col_name = [
+ "ID",
+ "Original_text",
+ "Normalized_text",
+ "Dir_path",
+ "Duration",
+ ]
+ self.metadata_cache = pd.DataFrame(columns=book_col_name)
+ self.trans_cache = pd.DataFrame(columns=trans_col_name)
+ # dataset_cache_dir = args.cache_dir # cache_dir
+ # print(f"args.cache_dir = ", args.cache_dir)
+ # os.makedirs(dataset_cache_dir, exist_ok=True)
+
+ ######## add data dir to dataset2dir ##########
+ self.dataset2dir = {
+ "dev-clean": f"{args.data_dir}/dev-clean",
+ "dev-other": f"{args.data_dir}/dev-other",
+ "test-clean": f"{args.data_dir}/test-clean",
+ "test-other": f"{args.data_dir}/test-other",
+ "train-clean-100": f"{args.data_dir}/train-clean-100",
+ "train-clean-360": f"{args.data_dir}/train-clean-360",
+ "train-other-500": f"{args.data_dir}/train-other-500",
+ }
+
+ ###### load metadata and transcripts #####
+ for dataset_name in self.dataset_list:
+ print("Initializing dataset: ", dataset_name)
+ # get [book,transcripts,audio] files list
+ self.book_files_list = self.get_metadata_files(
+ self.dataset2dir[dataset_name]
+ )
+ self.trans_files_list = self.get_trans_files(self.dataset2dir[dataset_name])
+
+ ## create metadata_cache (book.tsv file is not filtered, some file is not exist, but contain Duration and Signal_to_noise_ratio)
+ print("reading paths for dataset...")
+ for book_path in tqdm(self.book_files_list):
+ tmp_cache = pd.read_csv(
+ book_path, sep="\t", names=book_col_name, quoting=3
+ )
+ self.metadata_cache = pd.concat(
+ [self.metadata_cache, tmp_cache], ignore_index=True
+ )
+ self.metadata_cache.set_index("ID", inplace=True)
+
+ ## create transcripts (the trans.tsv file)
+ print("creating transcripts for dataset...")
+ for trans_path in tqdm(self.trans_files_list):
+ tmp_cache = pd.read_csv(
+ trans_path, sep="\t", names=trans_col_name, quoting=3
+ )
+ tmp_cache["Dir_path"] = os.path.dirname(trans_path)
+ self.trans_cache = pd.concat(
+ [self.trans_cache, tmp_cache], ignore_index=True
+ )
+ self.trans_cache.set_index("ID", inplace=True)
+
+ ## calc duration
+ self.trans_cache["Duration"] = (
+ self.metadata_cache.End_time[self.trans_cache.index]
+ - self.metadata_cache.Start_time[self.trans_cache.index]
+ )
+ ## add fullpath
+ # self.trans_cache['Full_path'] = os.path.join(self.dataset2dir[dataset_name],self.trans_cache['ID'])
+
+ # filter_by_duration: filter_out files with duration < 3.0 or > 15.0
+ print(f"Filtering files with duration between 3.0 and 15.0 seconds")
+ print(f"Before filtering: {len(self.trans_cache)}")
+ self.trans_cache = self.trans_cache[
+ (self.trans_cache["Duration"] >= 3.0)
+ & (self.trans_cache["Duration"] <= 15.0)
+ ]
+ print(f"After filtering: {len(self.trans_cache)}")
+
+ def get_metadata_files(self, directory):
+ book_files = []
+ for root, _, files in os.walk(directory):
+ for file in files:
+ if file.endswith(".book.tsv") and file[0] != ".":
+ rel_path = os.path.join(root, file)
+ book_files.append(rel_path)
+ return book_files
+
+ def get_trans_files(self, directory):
+ trans_files = []
+ for root, _, files in os.walk(directory):
+ for file in files:
+ if file.endswith(".trans.tsv") and file[0] != ".":
+ rel_path = os.path.join(root, file)
+ trans_files.append(rel_path)
+ return trans_files
+
+ def get_audio_files(self, directory):
+ audio_files = []
+ for root, _, files in os.walk(directory):
+ for file in files:
+ if file.endswith((".flac", ".wav", ".opus")):
+ rel_path = os.path.relpath(os.path.join(root, file), directory)
+ audio_files.append(rel_path)
+ return audio_files
+
+ def get_num_frames(self, index):
+ # get_num_frames(durations) by index
+ duration = self.meta_data_cache["Duration"][index]
+ # num_frames = duration * SAMPLE_RATE
+ num_frames = int(duration * 75)
+
+ # file_rel_path = self.meta_data_cache['relpath'][index]
+ # uid = file_rel_path.rstrip('.flac').split('/')[-1]
+ # num_frames += len(self.transcripts[uid])
+ return num_frames
+
+ def __len__(self):
+ return len(self.trans_cache)
+
+ def __getitem__(self, idx):
+ # Get the file rel path
+ file_dir_path = self.trans_cache["Dir_path"].iloc[idx]
+ # Get uid
+ uid = self.trans_cache.index[idx]
+ # Get the file name from cache uid
+ file_name = uid + ".wav"
+ # Get the full file path
+ full_file_path = os.path.join(file_dir_path, file_name)
+
+ # get phone
+ phone = self.trans_cache["Normalized_text"][uid]
+ phone = phonemizer_g2p(phone, "en")[1]
+ # load speech
+ speech, _ = librosa.load(full_file_path, sr=SAMPLE_RATE)
+ # if self.resample_to_24k:
+ # speech = librosa.resample(speech, orig_sr=SAMPLE_RATE, target_sr=24000)
+ # speech = torch.tensor(speech, dtype=torch.float32)
+ # pad speech to multiples of 200
+
+ # remainder = speech.size(0) % 200
+ # if remainder > 0:
+ # pad = 200 - remainder
+ # speech = torch.cat([speech, torch.zeros(pad, dtype=torch.float32)], dim=0)
+
+ # inputs = self._get_reference_vc(speech, hop_length=200)
+ inputs = {}
+ # Get the speaker id
+ # speaker = self.meta_data_cache['speaker'][idx]
+ # speaker_id = self.speaker2id[speaker]
+ # inputs["speaker_id"] = speaker_id
+ inputs["speech"] = speech # 24khz speech, [T]
+ inputs["phone"] = phone # [T]
+ return inputs
+
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ if len(batch) == 0:
+ return 0
+ if len(batch) == max_sentences:
+ return 1
+ if num_tokens > max_tokens:
+ return 1
+ return 0
+
+
+def batch_by_size(
+ indices,
+ num_tokens_fn,
+ max_tokens=None,
+ max_sentences=None,
+ required_batch_size_multiple=1,
+):
+ """
+ Yield mini-batches of indices bucketed by size. Batches may contain
+ sequences of different lengths.
+
+ Args:
+ indices (List[int]): ordered list of dataset indices
+ num_tokens_fn (callable): function that returns the number of tokens at
+ a given index
+ max_tokens (int, optional): max number of tokens in each batch
+ (default: None).
+ max_sentences (int, optional): max number of sentences in each
+ batch (default: None).
+ required_batch_size_multiple (int, optional): require batch size to
+ be a multiple of N (default: 1).
+ """
+ bsz_mult = required_batch_size_multiple
+
+ sample_len = 0
+ sample_lens = []
+ batch = []
+ batches = []
+ for i in range(len(indices)):
+ idx = indices[i]
+ num_tokens = num_tokens_fn(idx)
+ sample_lens.append(num_tokens)
+ sample_len = max(sample_len, num_tokens)
+
+ assert (
+ sample_len <= max_tokens
+ ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
+ idx, sample_len, max_tokens
+ )
+ num_tokens = (len(batch) + 1) * sample_len
+
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ mod_len = max(
+ bsz_mult * (len(batch) // bsz_mult),
+ len(batch) % bsz_mult,
+ )
+ batches.append(batch[:mod_len])
+ batch = batch[mod_len:]
+ sample_lens = sample_lens[mod_len:]
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+ batch.append(idx)
+ if len(batch) > 0:
+ batches.append(batch)
+ return batches
+
+
+def test():
+ from utils.util import load_config
+
+ cfg = load_config("./egs/tts/VALLE_V2/exp_ar_libritts.json")
+ dataset = VALLEDataset(cfg.dataset)
+ metadata_cache = dataset.metadata_cache
+ trans_cache = dataset.trans_cache
+ print(trans_cache.head(10))
+ # print(dataset.book_files_list)
+ breakpoint()
+
+
+if __name__ == "__main__":
+ test()
diff --git a/models/tts/valle_v2/modeling_llama.py b/models/tts/valle_v2/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b71b6d5aff3d8bc5005ae06b32c935e4eddedb
--- /dev/null
+++ b/models/tts/valle_v2/modeling_llama.py
@@ -0,0 +1,1043 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# This code is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+
+# Original work copyright
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.models.llama.modeling_llama import ACT2FN
+from transformers.models.llama.modeling_llama import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.models.llama.modeling_llama import PreTrainedModel
+from transformers.models.llama.modeling_llama import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.llama.modeling_llama import LlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full(
+ (tgt_len, tgt_len),
+ torch.tensor(torch.finfo(dtype).min, device=device),
+ device=device,
+ )
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat(
+ [
+ torch.zeros(
+ tgt_len, past_key_values_length, dtype=dtype, device=device
+ ),
+ mask,
+ ],
+ dim=-1,
+ )
+ return mask[None, None, :, :].expand(
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
+ )
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ return (self.weight * hidden_states).to(input_dtype)
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ # Build here to make `torch.jit.trace` work.
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(
+ self.max_seq_len_cached,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype,
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer(
+ "cos_cached", emb.cos()[None, None, :, :], persistent=False
+ )
+ self.register_buffer(
+ "sin_cached", emb.sin()[None, None, :, :], persistent=False
+ )
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+ if seq_len > self.max_seq_len_cached:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(
+ self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.register_buffer(
+ "cos_cached", emb.cos()[None, None, :, :], persistent=False
+ )
+ self.register_buffer(
+ "sin_cached", emb.sin()[None, None, :, :], persistent=False
+ )
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ ):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig, **kwargs):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
+ )
+ self.o_proj = nn.Linear(
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
+ )
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim, max_position_embeddings=self.max_position_embeddings
+ )
+
+ if "layer_idx" in kwargs:
+ self.layer_idx = kwargs["layer_idx"]
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return (
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights,
+ torch.tensor(
+ torch.finfo(attn_weights.dtype).min, device=attn_weights.device
+ ),
+ )
+
+ unnormed_attn_weights = attn_weights
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, unnormed_attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, **kwargs):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlamaModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
+ self.layers = nn.ModuleList(
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError(
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ **kwargs,
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError(
+ "Cannot handle batch sizes > 1 if no padding token is defined."
+ )
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ ).to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[
+ torch.arange(batch_size, device=logits.device), sequence_lengths
+ ]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (
+ labels.dtype == torch.long or labels.dtype == torch.int
+ ):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
+ )
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/models/tts/valle_v2/valle_ar.py b/models/tts/valle_v2/valle_ar.py
new file mode 100644
index 0000000000000000000000000000000000000000..f50820fb4fb154ffab029d78c7b9a71d2c21fe95
--- /dev/null
+++ b/models/tts/valle_v2/valle_ar.py
@@ -0,0 +1,302 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel
+import torch
+import torch.nn.functional as F
+import numpy as np
+import os
+import torch.nn as nn
+
+
+class ValleAR(nn.Module):
+ def __init__(
+ self,
+ phone_vocab_size=256,
+ target_vocab_size=1024,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ pad_token_id=1281,
+ bos_target_id=1282,
+ eos_target_id=1283,
+ bos_phone_id=1284,
+ eos_phone_id=1285,
+ use_input_embeds=False,
+ emb_dim=256,
+ **kwargs,
+ ):
+ super(ValleAR, self).__init__()
+ self.config = LlamaConfig(
+ vocab_size=phone_vocab_size + target_vocab_size + 10,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_target_id,
+ eos_token_id=eos_target_id,
+ )
+ self.phone_vocab_size = phone_vocab_size
+ self.target_vocab_size = target_vocab_size
+ self.pad_token_id = pad_token_id
+ self.bos_target_id = bos_target_id
+ self.eos_target_id = eos_target_id
+ self.bos_phone_id = bos_phone_id
+ self.eos_phone_id = eos_phone_id
+ self.model = LlamaForCausalLM(self.config)
+
+ self.use_input_embeds = use_input_embeds
+
+ # no input embedding is used to provide speaker information
+ if self.use_input_embeds:
+ self.emb_linear = nn.Linear(emb_dim, hidden_size)
+ self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
+ self.emb_linear.bias.data.zero_()
+
+ def forward(
+ self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None
+ ):
+ if input_embeds is not None:
+ input_embeds = self.emb_linear(input_embeds)
+ phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
+ phone_ids,
+ phone_mask,
+ self.eos_phone_id,
+ self.bos_phone_id,
+ self.pad_token_id,
+ )
+ target_ids, target_mask, target_label = self.add_target_eos_bos_label(
+ target_ids,
+ target_mask,
+ self.eos_target_id,
+ self.bos_target_id,
+ self.pad_token_id,
+ )
+ input_token_ids = torch.cat([phone_ids, target_ids], dim=-1)
+ attention_mask = torch.cat([phone_mask, target_mask], dim=-1)
+ # breakpoint()
+ if input_embeds is not None:
+ raise NotImplementedError
+ attention_mask = torch.cat(
+ [
+ torch.ones(
+ (input_embeds.shape[0], input_embeds.shape[1]),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+ labels = torch.cat([phone_label, target_label], dim=-1)
+ if input_embeds is not None:
+ raise NotImplementedError
+ labels = torch.cat(
+ [
+ -100
+ * torch.ones(
+ (input_embeds.shape[0], input_embeds.shape[1]),
+ dtype=labels.dtype,
+ device=labels.device,
+ ),
+ labels,
+ ],
+ dim=-1,
+ )
+
+ if input_embeds is not None:
+ raise NotImplementedError
+ inputs_embeds = torch.cat(
+ [input_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
+ )
+ out = self.model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ labels=labels,
+ return_dict=True,
+ )
+ return out
+
+ out = self.model(
+ input_token_ids,
+ attention_mask=attention_mask,
+ labels=labels,
+ return_dict=True,
+ )
+
+ # calcualte top1, top5, top10 accuracy
+ logits = out.logits
+ logits = logits[:, -target_ids.shape[1] :]
+ top1_acc = logits.argmax(-1)[..., :-1] == target_ids[:, 1:]
+ top1_acc = (top1_acc * target_mask[..., :-1]).sum() / target_mask.sum()
+
+ top5_acc = torch.topk(logits[..., :-1, :], 5, dim=-1)[1]
+ top5_acc = top5_acc == target_ids[:, 1:].unsqueeze(-1)
+ top5_acc = (
+ top5_acc * target_mask[..., :-1].unsqueeze(-1)
+ ).sum() / target_mask.sum()
+
+ top10_acc = torch.topk(logits[..., :-1, :], 10, dim=-1)[1]
+ top10_acc = top10_acc == target_ids[:, 1:].unsqueeze(-1)
+ top10_acc = (
+ top10_acc * target_mask[..., :-1].unsqueeze(-1)
+ ).sum() / target_mask.sum()
+
+ out.top1_acc = top1_acc
+ out.top5_acc = top5_acc
+ out.top10_acc = top10_acc
+
+ return out
+
+ def add_phone_eos_bos_label(
+ self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
+ ):
+ # phone_ids: [B, T]
+ # phone_mask: [B, T]
+
+ phone_ids = phone_ids + self.target_vocab_size * phone_mask
+
+ phone_ids = phone_ids * phone_mask
+ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
+ 1 - phone_mask, (0, 1), value=1
+ ) # make pad token eos token, add eos token at the end
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
+ phone_ids = phone_ids * phone_mask + pad_token_id * (
+ 1 - phone_mask
+ ) # restore pad token ids
+ phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
+ phone_label = -100 * torch.ones_like(
+ phone_ids
+ ) # loss for entire phone is not computed (passed to llama)
+ return phone_ids, phone_mask, phone_label
+
+ def add_target_eos_bos_label(
+ self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
+ ):
+ # target_ids: [B, T]
+ # target_mask: [B, T]
+ target_ids = target_ids * target_mask
+ target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
+ 1 - target_mask, (0, 1), value=1
+ )
+ target_mask = F.pad(target_mask, (1, 0), value=1)
+ target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
+ target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)
+ target_mask = F.pad(target_mask, (1, 0), value=1)
+ target_label = target_ids * target_mask + (-100) * (
+ 1 - target_mask
+ ) # loss for target is computed on unmasked tokens
+ return target_ids, target_mask, target_label
+
+ def sample_hf(
+ self,
+ phone_ids, # the phones of prompt and target should be concatenated together
+ prompt_ids,
+ inputs_embeds=None,
+ max_length=2000,
+ temperature=1.0,
+ top_k=100,
+ top_p=0.9,
+ repeat_penalty=1.0,
+ num_beams=1,
+ ):
+ if inputs_embeds is not None:
+ inputs_embeds = self.emb_linear(inputs_embeds)
+ phone_mask = torch.ones_like(phone_ids)
+ prompt_mask = torch.ones_like(prompt_ids)
+ phone_ids, _, _ = self.add_phone_eos_bos_label(
+ phone_ids,
+ phone_mask,
+ self.eos_phone_id,
+ self.bos_phone_id,
+ self.pad_token_id,
+ )
+ prompt_ids, _, _ = self.add_target_eos_bos_label(
+ prompt_ids,
+ prompt_mask,
+ self.eos_target_id,
+ self.bos_target_id,
+ self.pad_token_id,
+ )
+ prompt_ids = prompt_ids[:, :-1] # remove end token. Make it continue mode
+
+ input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)
+
+ if inputs_embeds is not None:
+ raise NotImplementedError
+ inputs_embeds = torch.cat(
+ [inputs_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
+ )
+ generated_ids = self.model.generate(
+ inputs_embeds=inputs_embeds,
+ do_sample=True,
+ max_length=max_length,
+ pad_token_id=self.pad_token_id,
+ eos_token_id=self.eos_target_id,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repeat_penalty,
+ )
+ gen_tokens = generated_ids[:, :-1]
+ return gen_tokens
+
+ input_length = input_token_ids.shape[1]
+ generated_ids = self.model.generate(
+ input_token_ids,
+ do_sample=True,
+ max_length=max_length,
+ pad_token_id=self.pad_token_id,
+ eos_token_id=self.eos_target_id,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ repetition_penalty=repeat_penalty,
+ num_beams=num_beams,
+ )
+
+ gen_tokens = generated_ids[:, input_length:-1]
+
+ return gen_tokens
+
+
+def test():
+ model = ValleAR()
+
+ phone_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]])
+ phone_mask = torch.LongTensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
+ target_ids = torch.LongTensor([765, 234, 123, 234, 123, 599]).expand(2, -1)
+ target_mask = torch.LongTensor([1, 1, 1, 1, 0, 0]).expand(2, -1)
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
+
+ for i in range(15):
+ optimizer.zero_grad()
+ out = model(
+ phone_ids=phone_ids,
+ phone_mask=phone_mask,
+ target_ids=target_ids,
+ target_mask=target_mask,
+ )
+ loss = out.loss
+
+ loss.backward()
+
+ optimizer.step()
+
+ print(f"iter={i}, {loss}.")
+
+ phone_ids = torch.LongTensor([1, 2, 3]).reshape(1, -1)
+ target_ids = torch.LongTensor([765, 234]).reshape(1, -1)
+ sampled = model.sample_hf(phone_ids, target_ids)
+
+ breakpoint()
+
+
+if __name__ == "__main__":
+ test()
diff --git a/models/tts/valle_v2/valle_ar_trainer.py b/models/tts/valle_v2/valle_ar_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc1aa07f55e9845e60a99877af7940335c9b99c
--- /dev/null
+++ b/models/tts/valle_v2/valle_ar_trainer.py
@@ -0,0 +1,371 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import shutil
+import torch
+import time
+from pathlib import Path
+import torch
+from tqdm import tqdm
+import torch.nn as nn
+from .base_trainer import BaseTrainer
+
+
+def make_pad_mask(
+ lengths: torch.Tensor, max_len: int = 0, left_pad=False
+) -> torch.Tensor:
+ """
+ Args:
+ lengths:
+ A 1-D tensor containing sentence lengths.
+ max_len:
+ The length of masks.
+ left_pad:
+ A boolean indicating whether to left pad the mask.
+ Returns:
+ Return a 2-D bool tensor, where masked positions
+ are filled with `True` and non-masked positions are
+ filled with `False`.
+
+ >>> lengths = torch.tensor([1, 3, 2, 5])
+ >>> make_pad_mask(lengths)
+ tensor([[False, True, True, True, True],
+ [False, False, False, True, True],
+ [False, False, True, True, True],
+ [False, False, False, False, False]])
+ """
+ assert lengths.ndim == 1, lengths.ndim
+ max_len = max(max_len, lengths.max())
+ n = lengths.size(0)
+ seq_range = torch.arange(0, max_len, device=lengths.device)
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+ mask = expaned_lengths >= lengths.unsqueeze(-1)
+
+ if left_pad:
+ mask = mask.flip(dims=[1])
+
+ return mask
+
+
+class ValleARTrainer(BaseTrainer):
+ def __init__(self, args=None, cfg=None):
+ super().__init__(args, cfg)
+ if self.cfg.use_speechtokenizer:
+ from models.codec.speechtokenizer.model import SpeechTokenizer
+
+ config_path = "./ckpts/speechtokenizer_hubert_avg/config.json"
+ ckpt_path = "./ckpts/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
+ assert os.path.isfile(
+ config_path
+ ), f"codec model {config_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
+ assert os.path.isfile(
+ ckpt_path
+ ), f"codec model {ckpt_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
+ self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
+ config_path, ckpt_path
+ )
+ self.codec_encoder.eval()
+ self.codec_encoder.to(self.accelerator.device)
+ print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
+ else:
+ from encodec import EncodecModel
+
+ with self.accelerator.main_process_first():
+ self.codec_encoder = EncodecModel.encodec_model_24khz()
+ self.codec_encoder.set_target_bandwidth(6.0)
+ self.codec_encoder.to(self.accelerator.device)
+ self.codec_decoder = None
+ print("Loaded EncodecModel")
+ self.top1_accuracies = []
+ self.top5_accuracies = []
+ self.top10_accuracies = []
+
+ if hasattr(self.cfg, "flatten_first_2_layers"):
+ self.flatten_first_2_layers = self.cfg.flatten_first_2_layers
+ print("flattened:", self.flatten_first_2_layers)
+ else:
+ self.flatten_first_2_layers = False
+
+ if hasattr(self.cfg, "num_prediction_heads"):
+ self.num_prediction_heads = self.cfg.num_prediction_heads
+ print("num_prediction_heads:", self.num_prediction_heads)
+
+ def _accelerator_prepare(self):
+ # if self.accelerator.is_main_process:
+ # breakpoint()
+ # self.accelerator.wait_for_everyone()
+
+ (
+ self.model,
+ self.optimizer,
+ ) = self.accelerator.prepare(
+ self.model,
+ self.optimizer,
+ )
+
+ def _build_criterion(self):
+ pass # loss is directly returned from model
+
+ def _build_scheduler(self):
+ from transformers import (
+ get_cosine_schedule_with_warmup,
+ get_constant_schedule_with_warmup,
+ )
+
+ return get_cosine_schedule_with_warmup(
+ self.optimizer,
+ num_warmup_steps=self.cfg.train.scheduler.warmup_steps,
+ num_training_steps=self.cfg.train.scheduler.total_steps,
+ )
+
+ def _build_model(self):
+ if hasattr(self.cfg.model, "num_prediction_heads"):
+ from .valle_ar_multihead import ValleAR
+ else:
+ from .valle_ar import ValleAR
+ return ValleAR(**self.cfg.model)
+
+ def _train_step(self, batch):
+ # inference codec
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
+ speech: [B, T]
+ speech_len: [B]
+ phone_ids: [B, T]
+ phone_lens: [B]
+ """
+ device = self.accelerator.device
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(device)
+ with torch.no_grad():
+ if self.cfg.use_speechtokenizer:
+ # Extract discrete codes from SpeechTokenizer
+ vq_id = self.codec_encoder.encode(
+ batch["speech"].unsqueeze(1)
+ ) # [B,1,T] -> (n_q, B, T)
+ else:
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
+ 0, 1
+ )
+
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
+ # vq_id: [8, B, T//320]
+ if self.flatten_first_2_layers:
+ first_layer = vq_id[0]
+ second_layer = vq_id[1]
+ # flatten the first two layers
+ batch["speech"] = torch.stack(
+ [first_layer, second_layer], dim=-1
+ ).flatten(-2, -1)
+ batch["speech_len"] = batch["speech_len"] // 160
+ elif hasattr(self.cfg.model, "num_prediction_heads"):
+ batch["speech"] = vq_id[:2] # first two layers
+ batch["speech_len"] = (
+ batch["speech_len"] // 320
+ ) # our codec downsamples 320x
+ else:
+ batch["speech"] = vq_id[0] # use first layer
+ batch["speech_len"] = (
+ batch["speech_len"] // 320
+ ) # our codec downsamples 320x
+ assert batch["speech_len"].max() <= batch["speech"].shape[-1]
+
+ phone_mask = 1 - make_pad_mask(
+ batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
+ ).to(torch.long)
+ speech_mask = 1 - make_pad_mask(
+ batch["speech_len"], max_len=batch["speech"].size(1)
+ ).to(torch.long)
+
+ out = self.model(
+ phone_ids=batch["phone_ids"],
+ phone_mask=phone_mask,
+ target_ids=batch["speech"],
+ target_mask=speech_mask,
+ )
+ loss = out.loss
+ # if self.accelerator.is_main_process:
+ # print(loss)
+ # if hasattr(out, 'top1_acc'):
+ # self.top1_accuracies.append(out.top1_acc)
+ # self.top5_accuracies.append(out.top5_acc)
+ # self.top10_accuracies.append(out.top10_acc)
+ # print(f'avgs: top1: {sum(self.top1_accuracies)/len(self.top1_accuracies)}, top5: {sum(self.top5_accuracies)/len(self.top5_accuracies)}, top10: {sum(self.top10_accuracies)/len(self.top10_accuracies)}')
+ # breakpoint()
+ return loss
+
+ ##########add your own dataloader to the trainer#############
+ def _build_dataloader(self):
+ from torch.utils.data import ConcatDataset, DataLoader
+
+ if self.cfg.train.dataset.name == "emilia":
+ from .emilia_dataset import EmiliaDataset as VALLEDataset
+
+ train_dataset = VALLEDataset()
+ elif self.cfg.train.dataset.name == "mls":
+ from .mls_dataset import VALLEDataset as VALLEDataset
+
+ train_dataset = VALLEDataset(self.cfg.dataset, resample_to_24k=False)
+ elif self.cfg.train.dataset.name == "libritts":
+ from .libritts_dataset import VALLEDataset as VALLEDataset
+
+ train_dataset = VALLEDataset(self.cfg.dataset)
+
+ from .valle_collator import VALLECollator
+ import numpy as np
+
+ print("length of train_dataset:", len(train_dataset))
+
+ collator = VALLECollator()
+
+ if self.cfg.train.dataset.use_dynamic_batchsize:
+ if self.accelerator.is_main_process:
+ self.logger.info("Use Dynamic Batchsize......")
+ from .mls_dataset import batch_by_size
+
+ batch_sampler = batch_by_size(
+ train_dataset.num_frame_indices,
+ train_dataset.get_num_frames,
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
+ max_sentences=self.cfg.train.max_sentences
+ * self.accelerator.num_processes,
+ required_batch_size_multiple=self.accelerator.num_processes,
+ )
+ np.random.shuffle(batch_sampler)
+ print(batch_sampler[0])
+ batches = [
+ x[
+ self.accelerator.local_process_index :: self.accelerator.num_processes
+ ]
+ for x in batch_sampler
+ if len(x) % self.accelerator.num_processes == 0
+ ]
+ from models.base.base_sampler import VariableSampler
+
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=collator,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ batch_sampler=VariableSampler(
+ batches, drop_last=True, use_random_sampler=True
+ ),
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ persistent_workers=self.cfg.train.dataloader.persistent_workers,
+ prefetch_factor=4,
+ )
+ print(
+ f"process {self.accelerator.local_process_index} has {len(batches)} batches"
+ )
+ self.accelerator.wait_for_everyone()
+
+ else:
+ sampler = torch.utils.data.distributed.DistributedSampler(
+ train_dataset,
+ num_replicas=self.accelerator.num_processes,
+ rank=self.accelerator.local_process_index,
+ shuffle=True,
+ )
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=self.cfg.train.batch_size,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ collate_fn=collator,
+ sampler=sampler,
+ )
+ print(
+ f"process {self.accelerator.local_process_index} has {len(train_loader)} batches"
+ )
+
+ return train_loader, None
+
+ def _test_step(self, batch):
+ # inference codec
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
+ speech: [B, T]
+ speech_len: [B]
+ phone_ids: [B, T]
+ phone_lens: [B]
+ """
+ import torchaudio
+
+ device = self.accelerator.device
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(device)
+ with torch.no_grad():
+ if self.cfg.use_speechtokenizer:
+ # Extract discrete codes from SpeechTokenizer
+ vq_id = self.codec_encoder.encode(
+ batch["speech"].unsqueeze(1)
+ ) # [B,1,T] -> (n_q, B, T)
+ else:
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
+ 0, 1
+ )
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
+ # vq_id: [8, B, T//200]
+
+ # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
+ # recovered_audio.shape: torch.Size([1, 1, 50200])
+
+ if self.flatten_first_2_layers:
+ first_layer = vq_id[0]
+ second_layer = vq_id[1]
+ # flatten the first two layers
+ batch["speech"] = torch.stack(
+ [first_layer, second_layer], dim=-1
+ ).flatten(-2, -1)
+ batch["speech_len"] = batch["speech_len"] // 160
+ elif hasattr(self.cfg.model, "num_prediction_heads"):
+ batch["speech"] = vq_id[:2] # first two layers
+ batch["speech_len"] = (
+ batch["speech_len"] // 320
+ ) # our codec downsamples 320x
+ else:
+ batch["speech"] = vq_id[0] # use first layer
+ batch["speech_len"] = (
+ batch["speech_len"] // 320
+ ) # our codec downsamples 320x
+
+ # save gt
+ breakpoint()
+ recovered_audio = self.codec_encoder.decode(vq_id[:1, :1])
+ # recovered_audio = self.codec_encoder.decode([(vq_id[:1].transpose(0,1), None)])
+ torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)
+ out_vq_ids = self.model.sample_hf(
+ batch["phone_ids"][:1, ...], batch["speech"][:1, :225], temperature=0.9
+ )
+ # out_vq_ids = torch.cat([batch['speech'][:1, :225], out_vq_ids[:1, ...]], dim=1)
+
+ # reconstruct form tokens
+ recovered_audio = self.codec_encoder.decode(out_vq_ids.unsqueeze(0))
+ # recovered_audio = self.codec_encoder.decode([(out_vq_ids, None)])
+ torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
+ breakpoint()
+ print()
+
+ @torch.inference_mode()
+ def _valid_epoch(self):
+ r"""Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ epoch_sum_loss = 0.0
+ return epoch_sum_loss
+
+ def _inference(self):
+ pass
+
+ def test_loop(self):
+ self.model.eval()
+ for batch in self.train_dataloader:
+ self._test_step(batch)
diff --git a/models/tts/valle_v2/valle_collator.py b/models/tts/valle_v2/valle_collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..29db1b32a4b0c9ef7ba807fc60dc1e37acb685ff
--- /dev/null
+++ b/models/tts/valle_v2/valle_collator.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+class VALLECollator:
+ def __init__(self, cfg=None):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
+ speech: [B, T]
+ speech_len: [B]
+ phone_ids: [B, T]
+ phone_lens: [B]
+ """
+ assert len(batch) != 0, "batch is empty before None checking"
+ batch = [b for b in batch if b is not None]
+ assert len(batch) != 0, "batch is empty after None checking"
+ packed_batch_features = {}
+
+ # Function to handle tensor copying
+ def process_tensor(data, dtype=torch.float32):
+ if isinstance(data, torch.Tensor):
+ return data.detach()
+ else:
+ return torch.tensor(data, dtype=dtype)
+
+ # Process 'speech' data
+ speeches = [process_tensor(b["speech"]) for b in batch]
+ packed_batch_features["speech_len"] = torch.tensor(
+ [len(s) for s in speeches], dtype=torch.long
+ )
+ packed_batch_features["speech"] = pad_sequence(
+ speeches, batch_first=True, padding_value=0
+ )
+
+ # right-padding 'phone' data
+ phones = [process_tensor(b["phone"], dtype=torch.long) for b in batch]
+ packed_batch_features["phone_lens"] = torch.tensor(
+ [len(phone) for phone in phones], dtype=torch.long
+ )
+ packed_batch_features["phone_ids"] = pad_sequence(
+ phones, batch_first=True, padding_value=0
+ )
+
+ # # Process 'phone' data, with left padding
+ # phones = [process_tensor(b['phone'], dtype=torch.long).flip(0) for b in batch] # first reverse the whole sequence
+ # packed_batch_features['phone_lens'] = torch.tensor([len(phone) for phone in phones], dtype=torch.long)
+ # packed_batch_features['phone_ids'] = pad_sequence(phones, batch_first=True, padding_value=0) # do the right padding
+ # packed_batch_features['phone_ids'] = packed_batch_features['phone_ids'].flip(1) # flip back to original order (left padding)
+
+ return packed_batch_features
diff --git a/models/tts/valle_v2/valle_inference.py b/models/tts/valle_v2/valle_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..39efc6f2d486d7e396bdad9b6f098c85b1e1bf0c
--- /dev/null
+++ b/models/tts/valle_v2/valle_inference.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torchaudio
+
+
+class ValleInference(torch.nn.Module):
+ def __init__(
+ self,
+ use_vocos=False,
+ use_speechtokenizer=True,
+ ar_path=None,
+ nar_path=None,
+ speechtokenizer_path=None,
+ device="cuda",
+ ):
+ super().__init__()
+
+ self.device = device
+
+ # prepare pretrained VALLE AR model
+ from .valle_ar import ValleAR
+
+ self.ar_model = ValleAR(
+ phone_vocab_size=300,
+ target_vocab_size=1024,
+ pad_token_id=1324,
+ bos_target_id=1325,
+ eos_target_id=1326,
+ bos_phone_id=1327,
+ eos_phone_id=1328,
+ bos_prompt_id=1329,
+ eos_prompt_id=1330,
+ num_hidden_layers=16,
+ )
+ # change the following path to your trained model path
+ assert ar_path is not None
+ self.ar_model.load_state_dict(torch.load(ar_path, map_location="cpu"))
+ self.ar_model.eval().to(self.device)
+
+ # prepare pretrained VALLE NAR model
+ from .valle_nar import ValleNAR
+
+ self.nar_model = ValleNAR(
+ phone_vocab_size=300,
+ target_vocab_size=1024,
+ pad_token_id=1324,
+ bos_target_id=1325,
+ eos_target_id=1326,
+ bos_phone_id=1327,
+ eos_phone_id=1328,
+ bos_prompt_id=1329,
+ eos_prompt_id=1330,
+ num_hidden_layers=16,
+ )
+ assert nar_path is not None
+ self.nar_model.load_state_dict(torch.load(nar_path, map_location="cpu"))
+ self.nar_model.eval().to(self.device)
+
+ # prepare codec encoder
+ assert not (
+ use_speechtokenizer and use_vocos
+ ), "Only one of use_speechtokenizer and use_vocos can be True"
+ self.use_speechtokenizer = use_speechtokenizer
+ if use_speechtokenizer:
+ from models.codec.speechtokenizer.model import SpeechTokenizer
+
+ # download from https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg
+ config_path = speechtokenizer_path + "/config.json"
+ ckpt_path = speechtokenizer_path + "/SpeechTokenizer.pt"
+ self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
+ config_path, ckpt_path
+ )
+ self.codec_encoder.eval()
+ self.codec_encoder.to(device)
+ print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
+ else:
+ # use Encodec
+ from encodec import EncodecModel
+
+ self.codec_encoder = EncodecModel.encodec_model_24khz()
+ self.codec_encoder.set_target_bandwidth(6.0)
+ self.codec_encoder.to(self.device)
+ if use_vocos:
+ from vocos import Vocos
+
+ self.codec_decoder = Vocos.from_pretrained(
+ "charactr/vocos-encodec-24khz"
+ )
+ self.codec_decoder.to(self.device)
+ print("Loaded Vocos")
+ print("Loaded EncodecModel")
+
+ self.use_vocos = use_vocos
+
+ def decode(self, vq_ids):
+ """vq_ids.shape: [8, B, T],
+ returns: [B, 1, T]"""
+ if self.use_speechtokenizer:
+ # infer speechtokenizer
+ return self.codec_encoder.decode(vq_ids) # [B, 1, T]
+ else:
+ if not self.use_vocos:
+ # vocos decoder
+ return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
+ else:
+ # encodec decoder
+ features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
+ bandwidth_id = torch.tensor([2], device=vq_ids.device)
+ return self.codec_decoder.decode(
+ features, bandwidth_id=bandwidth_id
+ ).unsqueeze(0)
+
+ def forward(self, batch, chunk_configs: list, return_prompt=False, prompt_len=None):
+ """batch: dict(
+ speech: [B, T]
+ phone_ids: [B, T]
+ )
+ returns: [B, 1, T] audio
+ """
+ if prompt_len is None:
+ prompt_len = 100000 # no prompt length limiting
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(self.device)
+ with torch.no_grad():
+ if self.use_speechtokenizer:
+ vq_id = self.codec_encoder.encode(
+ batch["speech"].unsqueeze(1)
+ ) # [B,1,T] -> (n_q, B, T)
+ else:
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
+ 0, 1
+ )
+
+ # typically we only require one config in the chunk,
+ # but we can also use multiple configs to, for example, use different sampling temperature at different positions
+ for chunk in chunk_configs:
+ ar_vq_ids = self.ar_model.sample_hf(
+ batch["phone_ids"],
+ vq_id[0, :, :prompt_len],
+ top_p=chunk["top_p"],
+ top_k=chunk["top_k"],
+ temperature=chunk["temperature"],
+ num_beams=chunk["num_beams"],
+ repeat_penalty=chunk["repeat_penalty"],
+ max_length=chunk["max_length"],
+ )
+ # recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
+ # torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000)
+
+ nar_vq_ids = self.nar_model.sample_hf(
+ phone_ids=batch["phone_ids"],
+ prompt_ids=vq_id[:, :, :prompt_len],
+ first_stage_ids=ar_vq_ids,
+ # first_stage_ids=vq_id[0, :, prompt_len:],
+ )
+
+ if return_prompt:
+ nar_vq_ids = torch.cat(
+ [vq_id[..., :prompt_len], nar_vq_ids], dim=-1
+ )
+
+ recovered_audio = self.decode(nar_vq_ids)
+ return recovered_audio # [B, 1, T]
diff --git a/models/tts/valle_v2/valle_nar.py b/models/tts/valle_v2/valle_nar.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d2ff0218ff0449f72b13e92f7dcb28b2a333483
--- /dev/null
+++ b/models/tts/valle_v2/valle_nar.py
@@ -0,0 +1,801 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+import torch
+import torch.nn.functional as F
+import numpy as np
+import os
+import torch.nn as nn
+from typing import List, Optional, Tuple, Union
+
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+NUM_QUANTIZERS = 8 # number of quantizers in total, currently assumes first layer AR.
+START_QUANTIZATION_LAYER = 1 # start quantization layer
+END_QUANTIZATION_LAYER = 7 # end quantization layer
+
+
+class LlamaAdaptiveRMSNorm(nn.Module):
+ def __init__(self, hidden_size=1024, eps=1e-9, dim_cond=1024):
+ super().__init__()
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
+ nn.init.normal_(self.to_weight.weight, mean=0.0, std=0.02)
+ # nn.init.zeros_(self.to_weight.weight)
+ # nn.init.ones_(self.to_weight.bias)
+ self.variance_epsilon = eps
+ self._is_hf_initialized = True # disable automatic init
+
+ def forward(self, hidden_states, cond_embedding):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ weight = self.to_weight(cond_embedding)
+
+ return (weight * hidden_states).to(input_dtype)
+
+
+class LlamaNARDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: LlamaConfig):
+ """Override to adaptive layer norm"""
+ super().__init__(config=config, layer_idx=0) # init attention, mlp, etc.
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+
+ # add `cond` in forward function
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cond_embedding: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
+
+
+class MultiEmbedding(nn.Module):
+ """Embedding for multiple quantization layers, summing up the embeddings of each layer."""
+
+ def __init__(
+ self,
+ num_embeddings=1034,
+ embedding_dim=1024,
+ num_quantization_layers=NUM_QUANTIZERS,
+ ):
+ super().__init__()
+ self.embeddings = nn.ModuleList(
+ [
+ nn.Embedding(num_embeddings, embedding_dim)
+ for _ in range(num_quantization_layers)
+ ]
+ )
+
+ # initialize embeddings
+ for i in range(num_quantization_layers):
+ self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02)
+ self._is_hf_initialized = True # disable automatic init
+
+ def forward(self, input_ids):
+ """Input: [num_quant, B, T] -> Output: [B, T, H]"""
+ num_quant, B, T = input_ids.shape
+ summed_embeddings = torch.zeros(
+ B, T, self.embeddings[0].embedding_dim, device=input_ids.device
+ )
+ for i in range(num_quant):
+ summed_embeddings += self.embeddings[i](input_ids[i])
+ return summed_embeddings
+
+
+class LlammaNARModel(LlamaModel):
+ def __init__(self, config):
+ """Adding adaptive layer norm, conditional embeddings, and multi-level input embeddings to the decoder layer"""
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [LlamaNARDecoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+
+ self.embed_cond = nn.Embedding(
+ NUM_QUANTIZERS, config.hidden_size
+ ) # 7 quantization layers
+
+ for layer in self.layers:
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+
+ self.post_init()
+
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create noncausal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+
+ def _expand_mask(
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+ ):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = (
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+ )
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None, # [num_quant, B, T]
+ cond: torch.LongTensor = None, # index for conditional embeddings, [B]
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ # retrieve some shape info
+ batch_size, seq_length, _ = input_ids.shape
+
+ inputs_embeds = input_ids # [B, T, H]
+ # embed cond
+ cond_embedding = self.embed_cond(cond) # [B, H]
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cond_embedding=cond_embedding, # using cond embed
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states, cond_embedding=cond_embedding)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
+from transformers.models.llama.modeling_llama import CrossEntropyLoss
+from easydict import EasyDict as edict
+
+
+class LlamaForNARModeling(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlammaNARModel(config)
+
+ self.lm_head = nn.ModuleList(
+ [
+ nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ for i in range(END_QUANTIZATION_LAYER - START_QUANTIZATION_LAYER + 1)
+ ]
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ cond: torch.LongTensor, # added
+ prediction_target: torch.LongTensor = None, # added. No shifting. -100 means no loss
+ input_ids: torch.LongTensor = None, # expect an embedding, [B, T, H]
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ # labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ """Prediction target: [B, T]"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ cond=cond, # added
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head[cond - START_QUANTIZATION_LAYER](hidden_states)
+
+ loss = None
+ loss_fct = CrossEntropyLoss()
+
+ if prediction_target is not None:
+ # calculate loss if prediction_target is provided
+ logits_tmp = logits.view(-1, logits.size(-1))
+ prediction_target = prediction_target.view(-1)
+ loss = loss_fct(logits_tmp, prediction_target)
+
+ return edict(
+ loss=loss,
+ logits=logits,
+ )
+
+
+class ValleNAR(nn.Module):
+ def __init__(
+ self,
+ phone_vocab_size=256,
+ target_vocab_size=1024,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_hidden_layers=12,
+ num_attention_heads=16,
+ pad_token_id=1024 + 256,
+ bos_target_id=1282,
+ eos_target_id=1283,
+ bos_phone_id=1284,
+ eos_phone_id=1285,
+ bos_prompt_id=1286,
+ eos_prompt_id=1287,
+ use_input_embeds=False,
+ emb_dim=256,
+ ):
+ super(ValleNAR, self).__init__()
+ self.config = LlamaConfig(
+ vocab_size=phone_vocab_size + target_vocab_size + 10,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_target_id,
+ eos_token_id=eos_target_id,
+ use_cache=False,
+ )
+ self.phone_vocab_size = phone_vocab_size
+ self.target_vocab_size = target_vocab_size
+ self.pad_token_id = pad_token_id
+ self.bos_target_id = bos_target_id
+ self.eos_target_id = eos_target_id
+ self.bos_phone_id = bos_phone_id
+ self.eos_phone_id = eos_phone_id
+ self.bos_prompt_id = bos_prompt_id
+ self.eos_prompt_id = eos_prompt_id
+ self.model = LlamaForNARModeling(self.config)
+
+ self.use_input_embeds = use_input_embeds
+
+ self.phone_embedder = nn.Embedding(
+ self.phone_vocab_size + 10, hidden_size
+ ) # use phone_embedder to embed all eos, bos tokens
+ self.prompt_embedder = MultiEmbedding(
+ num_embeddings=self.target_vocab_size,
+ embedding_dim=hidden_size,
+ num_quantization_layers=NUM_QUANTIZERS,
+ )
+ self.phone_embedder.weight.data.normal_(mean=0.0, std=0.02)
+
+ # use linear mask schedule when training
+ # another option is uniform
+ self.mask_layer_schedule = "uniform"
+
+ # no input embedding is used to provide speaker information
+ if self.use_input_embeds:
+ self.emb_linear = nn.Linear(emb_dim, hidden_size)
+ self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
+ self.emb_linear.bias.data.zero_()
+
+ def forward(
+ self,
+ phone_ids,
+ phone_mask,
+ target_ids,
+ target_mask,
+ target_quantization_layer=None,
+ prompt_len=None,
+ dropout=0.0,
+ ):
+ """
+ phone_ids: [B, T]
+ phone_mask: [B, T]
+ target_ids: [8,B,T]
+ target_mask: [B, T]
+ dropout: rate of dropping out the target tokens
+ """
+ assert (target_ids < 1024).all(), "target_ids should be less than 1024"
+ phone_ids = phone_ids + self.target_vocab_size
+ phone_ids = phone_ids * phone_mask + (1 - phone_mask) * self.pad_token_id
+ # assert (phone_ids >= 1024).all(), "phone_ids should be greater than 1024"
+ # phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
+ # phone_ids,
+ # phone_mask,
+ # self.eos_phone_id,
+ # self.bos_phone_id,
+ # self.pad_token_id,
+ # )
+ phone_label = -100 * (1 - phone_mask)
+ # get phone embedding
+ phone_embedding = self.phone_embedder(
+ phone_ids - self.target_vocab_size
+ ) # [B, T, H]
+
+ if prompt_len is not None:
+ assert not self.training # inference stage fix prompt len to input
+ NUM_PROMPT_TOKENS = prompt_len
+ else:
+ assert self.training
+ # randomly select a prompt length
+ assert self.training # randomize prompt len in training
+ NUM_PROMPT_TOKENS = np.random.randint(
+ min(target_ids.shape[-1] // 4, 5), target_ids.shape[-1] // 2
+ )
+
+ # extract 8-level prompts
+ prompt_tokens = target_ids[:, :, :NUM_PROMPT_TOKENS] # [Q, B, T]
+ prompt_mask = torch.ones_like(prompt_tokens[0])
+ prompt_label = -100 * prompt_mask
+ # get prompt embedding
+ prompt_embedding = self.prompt_embedder(prompt_tokens) # [B, T, H]
+
+ # randomly select a target qnt layer to predict
+ # total quant layer is 0 to 7
+ if target_quantization_layer is None:
+ if self.mask_layer_schedule == "linear":
+ weights = torch.tensor(
+ [
+ NUM_QUANTIZERS - i
+ for i in range(
+ START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
+ )
+ ]
+ )
+ weights = weights / weights.sum()
+ mask_layer = (
+ torch.multinomial(weights, 1, replacement=True)
+ + START_QUANTIZATION_LAYER
+ )
+ assert (
+ mask_layer >= START_QUANTIZATION_LAYER
+ and mask_layer <= END_QUANTIZATION_LAYER
+ )
+ target_quantization_layer = mask_layer.item()
+ elif self.mask_layer_schedule == "cosine":
+ weights = torch.tensor(
+ [
+ np.cos(i / NUM_QUANTIZERS * np.pi / 2)
+ for i in range(
+ START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
+ )
+ ]
+ )
+ weights = weights / weights.sum()
+ mask_layer = (
+ torch.multinomial(weights, 1, replacement=True)
+ + START_QUANTIZATION_LAYER
+ )
+ assert (
+ mask_layer >= START_QUANTIZATION_LAYER
+ and mask_layer <= END_QUANTIZATION_LAYER
+ )
+ target_quantization_layer = mask_layer.item()
+ breakpoint()
+ elif self.mask_layer_schedule == "uniform":
+ target_quantization_layer = np.random.randint(
+ START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
+ )
+
+ # print(f'target layer: {target_quantization_layer}')
+ # prompt of the target part
+ target_prompt_ids = target_ids[
+ :target_quantization_layer, :, NUM_PROMPT_TOKENS:
+ ]
+
+ def randomly_set_elements(tensor, fraction, value):
+ """
+ Randomly set a fraction of the elements in a tensor to a specific value.
+
+ Args:
+ tensor (torch.Tensor): The input tensor.
+ fraction (float): The fraction of elements to set to the specified value (between 0 and 1).
+ value (float or int): The value to set the elements to.
+
+ Returns:
+ torch.Tensor: The tensor with some elements set to the specified value.
+ """
+ # Create a mask with the same shape as the tensor
+ mask = torch.rand_like(tensor, dtype=torch.float32) < fraction
+ # Clone the tensor to avoid modifying the original tensor
+ result_tensor = tensor.clone()
+ # Set the elements where the mask is True to the specified value
+ result_tensor[mask] = value
+ return result_tensor
+
+ if dropout != 0.0:
+ target_prompt_ids = randomly_set_elements(
+ target_prompt_ids, dropout, self.target_vocab_size
+ )
+
+ target_embedding = self.prompt_embedder(target_prompt_ids)
+
+ # mask of the target part
+ target_mask = target_mask[:, NUM_PROMPT_TOKENS:]
+
+ target_labels = target_ids[
+ target_quantization_layer, :, NUM_PROMPT_TOKENS:
+ ] * target_mask + (-100 * (1 - target_mask))
+
+ # input embeddings
+ input_embeddings = torch.cat(
+ [phone_embedding, prompt_embedding, target_embedding], dim=1
+ )
+ input_mask = torch.cat([phone_mask, prompt_mask, target_mask], dim=1) # [B, T]
+ prediction_target = torch.cat(
+ [phone_label, prompt_label, target_labels], dim=1
+ ) # [B, T]
+
+ out = self.model(
+ cond=torch.tensor(
+ target_quantization_layer,
+ device=prediction_target.device,
+ dtype=torch.long,
+ ),
+ input_ids=input_embeddings,
+ prediction_target=prediction_target,
+ attention_mask=input_mask,
+ return_dict=True,
+ )
+ logits = out.logits[:, -target_embedding.shape[1] :, :]
+ targets = prediction_target[..., -target_embedding.shape[1] :]
+ top1_acc = logits.argmax(-1) == targets
+ top1_acc = (top1_acc * target_mask).sum() / target_mask.sum()
+
+ top5_acc = (logits.topk(5, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
+ top5_acc = (top5_acc * target_mask).sum() / target_mask.sum()
+
+ top10_acc = (logits.topk(10, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
+ top10_acc = (top10_acc * target_mask).sum() / target_mask.sum()
+
+ out.target_quantization_layer = target_quantization_layer
+ out.top1_acc = top1_acc
+ out.top5_acc = top5_acc
+ out.top10_acc = top10_acc
+
+ return out
+
+ def add_phone_eos_bos_label(
+ self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
+ ):
+ # phone_ids: [B, T]
+ # phone_mask: [B, T]
+
+ phone_ids = phone_ids + self.target_vocab_size * phone_mask
+
+ phone_ids = phone_ids * phone_mask
+ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
+ 1 - phone_mask, (0, 1), value=1
+ ) # make pad token eos token, add eos token at the end
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
+ phone_ids = phone_ids * phone_mask + pad_token_id * (
+ 1 - phone_mask
+ ) # restore pad token ids
+ phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
+ phone_label = -100 * torch.ones_like(
+ phone_ids
+ ) # loss for entire phone is not computed (passed to llama)
+ return phone_ids, phone_mask, phone_label
+
+ @torch.no_grad()
+ def sample_hf(
+ self,
+ phone_ids, # [B, T]
+ prompt_ids, # [8, B, T]
+ first_stage_ids, # [B, T]
+ top_k=50,
+ top_p=1,
+ temperature=1.1,
+ first_stage_ids_gt=None, # [Q, B, T]
+ first_stage_ids_gt_end_layer=None, # 2 to 8
+ ):
+ """
+ phone_ids: [B, T]
+ prompt_ids: [8, B, T]
+ first_stage_ids: [B, T] result from first quant layer. Should be continuation of prompt_ids
+ """
+ phone_mask = torch.ones_like(phone_ids, dtype=torch.long)
+
+ assert prompt_ids.shape[-1] >= 5, "prompt_ids should have at least 5 tokens"
+ target_ids = torch.cat(
+ [prompt_ids, first_stage_ids.expand(prompt_ids.shape[0], -1, -1)], dim=-1
+ )
+ target_mask = torch.ones_like(target_ids[0], dtype=torch.long)
+
+ if first_stage_ids_gt is not None:
+ target_ids[
+ :first_stage_ids_gt_end_layer, :, -first_stage_ids_gt.shape[-1] :
+ ] = first_stage_ids_gt[:first_stage_ids_gt_end_layer]
+
+ gen_len = first_stage_ids.shape[-1]
+
+ start_qnt_layer = 1
+ if first_stage_ids_gt_end_layer is not None:
+ start_qnt_layer = first_stage_ids_gt_end_layer
+ for qnt_level in range(start_qnt_layer, 8):
+ out = self.forward(
+ phone_ids=phone_ids,
+ phone_mask=phone_mask,
+ target_ids=target_ids,
+ target_mask=target_mask,
+ target_quantization_layer=qnt_level,
+ prompt_len=prompt_ids.shape[-1],
+ )
+ logits = out.logits
+ gen_tokens = torch.argmax(logits, dim=-1).reshape(-1)[
+ -gen_len:
+ ] # [T], generated tokens in this level
+
+ # overwrite the target_ids with the generated tokens
+ target_ids[qnt_level, :, -gen_len:] = gen_tokens
+
+ return target_ids[:, :, -gen_len:]
+
+
+def test():
+ model = ValleNAR().cuda()
+
+ phone_ids = torch.LongTensor([1, 2, 3, 4, 5]).reshape(1, -1).cuda()
+ phone_mask = torch.LongTensor([1, 1, 1, 1, 1]).reshape(1, -1).cuda()
+ target_ids = torch.randint(high=1024, size=(8, 1, 250), dtype=torch.long).cuda()
+ target_mask = torch.ones(1, 250, dtype=torch.long).cuda()
+ optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
+
+ for i in range(200):
+ optimizer.zero_grad()
+ out = model(
+ phone_ids=phone_ids,
+ phone_mask=phone_mask,
+ target_ids=target_ids,
+ target_mask=target_mask,
+ # target_quantization_layer=1+i%6,
+ )
+ loss = out.loss
+
+ loss.backward()
+
+ optimizer.step()
+
+ print(f"iter={i}, {loss}.")
+ target_ids_short = target_ids[:, :, :240]
+
+ model.eval()
+ sampled = model.sample_hf(
+ phone_ids, prompt_ids=target_ids_short, first_stage_ids=target_ids[0, :, 240:]
+ )
+
+ print(target_ids[:, :, -10:])
+ print(sampled)
+
+ print((sampled == target_ids[:, :, -10:]).all())
+
+
+if __name__ == "__main__":
+ test()
diff --git a/models/tts/valle_v2/valle_nar_trainer.py b/models/tts/valle_v2/valle_nar_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e5e8a362e90da7776feba7f108b4c7eafe8618
--- /dev/null
+++ b/models/tts/valle_v2/valle_nar_trainer.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torchaudio
+import numpy as np
+import time
+from .valle_ar_trainer import ValleARTrainer, make_pad_mask
+
+
+class ValleNARTrainer(ValleARTrainer):
+ def __init__(self, args=None, cfg=None):
+ super().__init__(args, cfg)
+ print("simple NAR")
+ self.top1_accuracies = {
+ 1: [],
+ 2: [],
+ 3: [],
+ 4: [],
+ 5: [],
+ 6: [],
+ 7: [],
+ }
+ self.top5_accuracies = {
+ 1: [],
+ 2: [],
+ 3: [],
+ 4: [],
+ 5: [],
+ 6: [],
+ 7: [],
+ }
+ self.top10_accuracies = {
+ 1: [],
+ 2: [],
+ 3: [],
+ 4: [],
+ 5: [],
+ 6: [],
+ 7: [],
+ }
+
+ def _build_model(self):
+ from .valle_nar import ValleNAR
+
+ return ValleNAR(**self.cfg.model)
+
+ def _train_step(self, batch):
+ # inference codec
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
+ speech: [B, T]
+ speech_len: [B]
+ phone_ids: [B, T]
+ phone_lens: [B]
+ """
+ device = self.accelerator.device
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(device)
+
+ with torch.no_grad():
+ if self.cfg.use_speechtokenizer:
+ # Extract discrete codes from SpeechTokenizer
+ # 16k
+ vq_id = self.codec_encoder.encode(
+ batch["speech"].unsqueeze(1)
+ ) # [B,T] -> (n_q, B, T)
+ # RVQ_1 = codes[:1, :, :] # Contain content info, can be considered as semantic tokens
+ # RVQ_supplement = codes[1:, :, :] # Contain timbre info, complete info lost by the first quantizer
+ # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding
+ # wav = self.codec_encoder.decode(vq_id)
+ # torchaudio.save('a.wav', wav[0].cpu(), 16000)
+
+ # # Decoding from RVQ-i:j tokens from the ith quantizers to the jth quantizers
+ # wav = model.decode(codes[i: (j + 1)], st=i)
+ else:
+ # using encodec, 24k
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
+ 0, 1
+ )
+
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
+ # vq_id: [8, B, T//320]
+ batch["speech"] = vq_id
+ batch["speech_len"] = batch["speech_len"] // 320 # our codec downsamples 320x
+ assert batch["speech_len"].max() <= batch["speech"].shape[-1]
+
+ phone_mask = 1 - make_pad_mask(
+ batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
+ ).to(torch.long)
+ speech_mask = 1 - make_pad_mask(
+ batch["speech_len"], max_len=batch["speech"].size(-1)
+ ).to(torch.long)
+
+ np.random.seed(int(time.time()) - 5 * self.accelerator.process_index)
+
+ if hasattr(self.cfg.train, "dropout"):
+ dropout = self.cfg.train.dropout
+ else:
+ dropout = 0.0
+
+ out = self.model(
+ phone_ids=batch["phone_ids"],
+ phone_mask=phone_mask,
+ target_ids=batch["speech"],
+ target_mask=speech_mask,
+ dropout=dropout,
+ )
+ loss = out.loss
+
+ self.accelerator.log(
+ {f"Train/NAR L{out.target_quantization_layer} Top1 acc": out.top1_acc},
+ step=self.step,
+ )
+ self.accelerator.log(
+ {f"Train/NAR L{out.target_quantization_layer} Top5 acc": out.top5_acc},
+ step=self.step,
+ )
+ self.accelerator.log(
+ {f"Train/NAR L{out.target_quantization_layer} Top10 acc": out.top10_acc},
+ step=self.step,
+ )
+
+ # if hasattr(out, 'top1_acc'):
+ # idx = out.target_quantization_layer
+ # self.top1_accuracies[idx].append(out.top1_acc)
+ # self.top5_accuracies[idx].append(out.top5_acc)
+ # self.top10_accuracies[idx].append(out.top10_acc)
+ # if len(self.top1_accuracies[idx]) >= 160:
+ # breakpoint()
+ # if self.accelerator.is_main_process:
+ # print(loss)
+ return loss
+
+ def _test_step(self, batch):
+ # inference codec
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
+ speech: [B, T]
+ speech_len: [B]
+ phone_ids: [B, T]
+ phone_lens: [B]
+ """
+ import torchaudio
+
+ device = self.accelerator.device
+ for k, v in batch.items():
+ if isinstance(v, torch.Tensor):
+ batch[k] = v.to(device)
+ with torch.no_grad():
+ if self.cfg.use_speechtokenizer:
+ # Extract discrete codes from SpeechTokenizer
+ # 16k
+ vq_id = self.codec_encoder.encode(
+ batch["speech"].unsqueeze(1)
+ ) # [B,1,T] -> (n_q, B, T)
+ # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding
+ # wav = self.codec_encoder.decode(vq_id)
+ # torchaudio.save('a.wav', wav[0].cpu(), 16000)
+
+ else:
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
+ 0, 1
+ )
+ # recovered_audio = self.codec_encoder.decode([(vq_id.transpose(0,1), None)])
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
+ # vq_id: [8, B, T//200]
+
+ # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
+ # recovered_audio.shape: torch.Size([1, 1, 50200])
+
+ batch["speech"] = vq_id
+
+ # save gt
+ if self.cfg.use_speechtokenizer:
+ recovered_audio = self.codec_encoder.decode(vq_id)
+ else:
+ recovered_audio = self.codec_encoder.decode(
+ [(vq_id.transpose(0, 1), None)]
+ )
+ torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)
+ self.model.eval()
+ out_vq_ids = self.model.sample_hf(
+ phone_ids=batch["phone_ids"][:1],
+ prompt_ids=batch["speech"][:, :1, :150],
+ first_stage_ids=batch["speech"][0, :1, 150:],
+ )
+ # breakpoint()
+ # out_vq_ids = torch.cat([batch['speech'][:, :225], out_vq_ids], dim=1)
+
+ # reconstruct form tokens
+ if self.cfg.use_speechtokenizer:
+ recovered_audio = self.codec_encoder.decode(out_vq_ids)
+ else:
+ recovered_audio = self.codec_encoder.decode(
+ [(out_vq_ids.transpose(0, 1)[:1], None)]
+ )
+ torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
+ breakpoint()
diff --git a/models/tts/vits/__init__.py b/models/tts/vits/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/tts/vits/vits.py b/models/tts/vits/vits.py
new file mode 100644
index 0000000000000000000000000000000000000000..61aff4b24caa7652e818b623bd89be3ee8eff3c7
--- /dev/null
+++ b/models/tts/vits/vits.py
@@ -0,0 +1,379 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from utils.util import *
+from modules.flow.modules import *
+from modules.base.base_module import *
+from modules.transformer.attentions import Encoder
+from modules.duration_predictor.standard_duration_predictor import DurationPredictor
+from modules.duration_predictor.stochastic_duration_predictor import (
+ StochasticDurationPredictor,
+)
+from models.vocoders.gan.generator.hifigan import HiFiGAN_vits as Generator
+
+try:
+ from modules import monotonic_align
+except ImportError:
+ print("Monotonic align not found. Please make sure you have compiled it.")
+
+
+class TextEncoder(nn.Module):
+ def __init__(
+ self,
+ n_vocab,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ ):
+ super().__init__()
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+ self.encoder = Encoder(
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths):
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ x = self.encoder(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ return x, m, logs, x_mask
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(
+ ResidualCouplingLayer(
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ mean_only=True,
+ )
+ )
+ self.flows.append(Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, reverse=reverse)
+ return x
+
+
+class PosteriorEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.enc = WN(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths, g=None):
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+ x = self.pre(x) * x_mask
+ x = self.enc(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask
+
+
+class SynthesizerTrn(nn.Module):
+ """
+ Synthesizer for Training
+ """
+
+ def __init__(
+ self,
+ n_vocab,
+ spec_channels,
+ segment_size,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ n_speakers=0,
+ gin_channels=0,
+ use_sdp=True,
+ **kwargs,
+ ):
+ super().__init__()
+ self.n_vocab = n_vocab
+ self.spec_channels = spec_channels
+ self.inter_channels = inter_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.resblock = resblock
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.upsample_rates = upsample_rates
+ self.upsample_initial_channel = upsample_initial_channel
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ self.segment_size = segment_size
+ self.n_speakers = n_speakers
+ self.gin_channels = gin_channels
+
+ self.use_sdp = use_sdp
+
+ self.enc_p = TextEncoder(
+ n_vocab,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ )
+ self.dec = Generator(
+ inter_channels,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=gin_channels,
+ )
+ self.enc_q = PosteriorEncoder(
+ spec_channels,
+ inter_channels,
+ hidden_channels,
+ 5,
+ 1,
+ 16,
+ gin_channels=gin_channels,
+ )
+ self.flow = ResidualCouplingBlock(
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
+ )
+
+ if use_sdp:
+ self.dp = StochasticDurationPredictor(
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
+ )
+ else:
+ self.dp = DurationPredictor(
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
+ )
+
+ if n_speakers >= 1:
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
+
+ def forward(self, data):
+ x = data["phone_seq"]
+ x_lengths = data["phone_len"]
+ y = data["linear"]
+ y_lengths = data["target_len"]
+
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+ if self.n_speakers > 0:
+ g = self.emb_g(data["spk_id"].squeeze(-1)).unsqueeze(-1) # [b, h, 1]
+ else:
+ g = None
+
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
+ z_p = self.flow(z, y_mask, g=g)
+
+ with torch.no_grad():
+ # negative cross-entropy
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
+ neg_cent1 = torch.sum(
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
+ ) # [b, 1, t_s]
+ neg_cent2 = torch.matmul(
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+ neg_cent3 = torch.matmul(
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+ neg_cent4 = torch.sum(
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
+ ) # [b, 1, t_s]
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = (
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
+ .unsqueeze(1)
+ .detach()
+ )
+
+ w = attn.sum(2)
+ if self.use_sdp:
+ l_length = self.dp(x, x_mask, w, g=g)
+ l_length = l_length / torch.sum(x_mask)
+ else:
+ logw_ = torch.log(w + 1e-6) * x_mask
+ logw = self.dp(x, x_mask, g=g)
+ l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
+
+ # expand prior
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+
+ z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
+ o = self.dec(z_slice, g=g)
+ outputs = {
+ "y_hat": o,
+ "l_length": l_length,
+ "attn": attn,
+ "ids_slice": ids_slice,
+ "x_mask": x_mask,
+ "z_mask": y_mask,
+ "z": z,
+ "z_p": z_p,
+ "m_p": m_p,
+ "logs_p": logs_p,
+ "m_q": m_q,
+ "logs_q": logs_q,
+ }
+ return outputs
+
+ def infer(
+ self,
+ x,
+ x_lengths,
+ sid=None,
+ noise_scale=1,
+ length_scale=1,
+ noise_scale_w=1.0,
+ max_len=None,
+ ):
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
+ if self.n_speakers > 0:
+ sid = sid.squeeze(-1)
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
+ else:
+ g = None
+
+ if self.use_sdp:
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
+ else:
+ logw = self.dp(x, x_mask, g=g)
+ w = torch.exp(logw) * x_mask * length_scale
+ w_ceil = torch.ceil(w)
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+ y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = generate_path(w_ceil, attn_mask)
+
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
+ 1, 2
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
+ 1, 2
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
+
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
+
+ outputs = {
+ "y_hat": o,
+ "attn": attn,
+ "mask": y_mask,
+ "z": z,
+ "z_p": z_p,
+ "m_p": m_p,
+ "logs_p": logs_p,
+ }
+
+ return outputs
+
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
+ assert self.n_speakers > 0, "n_speakers have to be larger than 0."
+ g_src = self.emb_g(sid_src).unsqueeze(-1)
+ g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
+ z_p = self.flow(z, y_mask, g=g_src)
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
+ return o_hat, y_mask, (z, z_p, z_hat)
diff --git a/models/tts/vits/vits_dataset.py b/models/tts/vits/vits_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a1444bcec76386fb4d0145c88d78437ffb4834
--- /dev/null
+++ b/models/tts/vits/vits_dataset.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import numpy as np
+from text import text_to_sequence
+from text.text_token_collation import phoneIDCollation
+from models.tts.base.tts_dataset import (
+ TTSDataset,
+ TTSCollator,
+ TTSTestDataset,
+ TTSTestCollator,
+)
+
+
+class VITSDataset(TTSDataset):
+ def __init__(self, cfg, dataset, is_valid):
+ super().__init__(cfg, dataset, is_valid=is_valid)
+
+ def __getitem__(self, index):
+ single_feature = super().__getitem__(index)
+ return single_feature
+
+ def __len__(self):
+ return super().__len__()
+
+ def get_metadata(self):
+ metadata_filter = []
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+ for utt_info in metadata:
+ duration = utt_info["Duration"]
+ frame_len = (
+ duration
+ * self.cfg.preprocess.sample_rate
+ // self.cfg.preprocess.hop_size
+ )
+ if (
+ frame_len
+ < self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size
+ ):
+ continue
+ metadata_filter.append(utt_info)
+
+ return metadata_filter
+
+
+class VITSCollator(TTSCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ def __call__(self, batch):
+ parsed_batch_features = super().__call__(batch)
+ return parsed_batch_features
+
+
+class VITSTestDataset(TTSTestDataset):
+ def __init__(self, args, cfg):
+ super().__init__(args, cfg)
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, args.dataset)
+ if cfg.preprocess.use_spkid:
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
+ with open(spk2id_path, "r") as f:
+ self.spk2id = json.load(f)
+
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
+ self.utt2spk = dict()
+ with open(utt2spk_path, "r") as f:
+ for line in f.readlines():
+ utt, spk = line.strip().split("\t")
+ self.utt2spk[utt] = spk
+
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
+ self.utt2seq = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ if cfg.preprocess.use_text:
+ text = utt_info["Text"]
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
+ elif cfg.preprocess.use_phone:
+ # load phoneme squence from phone file
+ phone_path = os.path.join(
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
+ )
+ with open(phone_path, "r") as fin:
+ phones = fin.readlines()
+ assert len(phones) == 1
+ phones = phones[0].strip()
+ phones_seq = phones.split(" ")
+
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
+
+ self.utt2seq[utt] = sequence
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_spkid:
+ single_feature["spk_id"] = np.array(
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
+ )
+
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
+ single_feature["phone_len"] = len(self.utt2seq[utt])
+
+ return single_feature
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+ return metadata
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class VITSTestCollator(TTSTestCollator):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ return super().__call__(batch)
diff --git a/models/tts/vits/vits_inference.py b/models/tts/vits/vits_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e28858af7c1897d8e0f1c6906605595e7a712c2
--- /dev/null
+++ b/models/tts/vits/vits_inference.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+import numpy as np
+from tqdm import tqdm
+import torch
+import json
+from models.tts.base.tts_inferece import TTSInference
+from models.tts.vits.vits_dataset import VITSTestDataset, VITSTestCollator
+from models.tts.vits.vits import SynthesizerTrn
+from processors.phone_extractor import phoneExtractor
+from text.text_token_collation import phoneIDCollation
+from utils.data_utils import *
+
+
+class VitsInference(TTSInference):
+ def __init__(self, args=None, cfg=None):
+ TTSInference.__init__(self, args, cfg)
+
+ def _build_model(self):
+ net_g = SynthesizerTrn(
+ self.cfg.model.text_token_num,
+ self.cfg.preprocess.n_fft // 2 + 1,
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ **self.cfg.model,
+ )
+
+ return net_g
+
+ def _build_test_dataset(sefl):
+ return VITSTestDataset, VITSTestCollator
+
+ def build_save_dir(self, dataset, speaker):
+ save_dir = os.path.join(
+ self.args.output_dir,
+ "tts_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
+ )
+ if dataset is not None:
+ save_dir = os.path.join(save_dir, "data_{}".format(dataset))
+ if speaker != -1:
+ save_dir = os.path.join(
+ save_dir,
+ "spk_{}".format(speaker),
+ )
+ os.makedirs(save_dir, exist_ok=True)
+ print("Saving to ", save_dir)
+ return save_dir
+
+ def inference_for_batches(
+ self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
+ ):
+ ###### Construct test_batch ######
+ n_batch = len(self.test_dataloader)
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
+ print(
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
+ now, self.test_batch_size, n_batch
+ )
+ )
+ self.model.eval()
+
+ ###### Inference for each batch ######
+ pred_res = []
+ with torch.no_grad():
+ for i, batch_data in enumerate(
+ self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
+ ):
+ spk_id = None
+ if (
+ self.cfg.preprocess.use_spkid
+ and self.cfg.train.multi_speaker_training
+ ):
+ spk_id = batch_data["spk_id"]
+
+ outputs = self.model.infer(
+ batch_data["phone_seq"],
+ batch_data["phone_len"],
+ spk_id,
+ noise_scale=noise_scale,
+ noise_scale_w=noise_scale_w,
+ length_scale=length_scale,
+ )
+
+ audios = outputs["y_hat"]
+ masks = outputs["mask"]
+
+ for idx in range(audios.size(0)):
+ audio = audios[idx, 0, :].data.cpu().float()
+ mask = masks[idx, :, :]
+ audio_length = (
+ mask.sum([0, 1]).long() * self.cfg.preprocess.hop_size
+ )
+ audio_length = audio_length.cpu().numpy()
+ audio = audio[:audio_length]
+ pred_res.append(audio)
+
+ return pred_res
+
+ def inference_for_single_utterance(
+ self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
+ ):
+ text = self.args.text
+
+ # get phone symbol file
+ phone_symbol_file = None
+ if self.cfg.preprocess.phone_extractor != "lexicon":
+ phone_symbol_file = os.path.join(
+ self.exp_dir, self.cfg.preprocess.symbols_dict
+ )
+ assert os.path.exists(phone_symbol_file)
+ # convert text to phone sequence
+ phone_extractor = phoneExtractor(self.cfg)
+ phone_seq = phone_extractor.extract_phone(text) # phone_seq: list
+ # convert phone sequence to phone id sequence
+ phon_id_collator = phoneIDCollation(
+ self.cfg, symbols_dict_file=phone_symbol_file
+ )
+ phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
+
+ if self.cfg.preprocess.add_blank:
+ phone_id_seq = intersperse(phone_id_seq, 0)
+
+ # convert phone sequence to phone id sequence
+ phone_id_seq = np.array(phone_id_seq)
+ phone_id_seq = torch.from_numpy(phone_id_seq)
+
+ # get speaker id if multi-speaker training and use speaker id
+ speaker_id = None
+ if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
+ spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+ with open(spk2id_file, "r") as f:
+ spk2id = json.load(f)
+ speaker_name = self.args.speaker_name
+ assert (
+ speaker_name in spk2id
+ ), f"Speaker {speaker_name} not found in the spk2id keys. \
+ Please make sure you've specified the correct speaker name in infer_speaker_name."
+ speaker_id = spk2id[speaker_name]
+ speaker_id = torch.from_numpy(
+ np.array([speaker_id], dtype=np.int32)
+ ).unsqueeze(0)
+
+ with torch.no_grad():
+ x_tst = phone_id_seq.to(self.device).unsqueeze(0)
+ x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
+ if speaker_id is not None:
+ speaker_id = speaker_id.to(self.device)
+ outputs = self.model.infer(
+ x_tst,
+ x_tst_lengths,
+ sid=speaker_id,
+ noise_scale=noise_scale,
+ noise_scale_w=noise_scale_w,
+ length_scale=length_scale,
+ )
+
+ audio = outputs["y_hat"][0, 0].data.cpu().float().numpy()
+
+ return audio
diff --git a/models/tts/vits/vits_trainer.py b/models/tts/vits/vits_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e799f9173887db01745901abee736df7e2ba9070
--- /dev/null
+++ b/models/tts/vits/vits_trainer.py
@@ -0,0 +1,439 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.optim.lr_scheduler import ExponentialLR
+
+from tqdm import tqdm
+
+from utils.util import *
+from utils.mel import mel_spectrogram_torch
+from models.tts.base import TTSTrainer
+from models.tts.vits.vits import SynthesizerTrn
+from models.tts.vits.vits_dataset import VITSDataset, VITSCollator
+from models.vocoders.gan.discriminator.mpd import (
+ MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
+)
+
+
+class VITSTrainer(TTSTrainer):
+ def __init__(self, args, cfg):
+ TTSTrainer.__init__(self, args, cfg)
+
+ if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
+ if cfg.model.n_speakers == 0:
+ cfg.model.n_speaker = len(self.speakers)
+
+ def _build_model(self):
+ net_g = SynthesizerTrn(
+ self.cfg.model.text_token_num,
+ self.cfg.preprocess.n_fft // 2 + 1,
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ **self.cfg.model,
+ )
+ net_d = MultiPeriodDiscriminator(self.cfg.model.use_spectral_norm)
+ model = {"generator": net_g, "discriminator": net_d}
+
+ return model
+
+ def _build_dataset(self):
+ return VITSDataset, VITSCollator
+
+ def _build_optimizer(self):
+ optimizer_g = torch.optim.AdamW(
+ self.model["generator"].parameters(),
+ self.cfg.train.learning_rate,
+ betas=self.cfg.train.AdamW.betas,
+ eps=self.cfg.train.AdamW.eps,
+ )
+ optimizer_d = torch.optim.AdamW(
+ self.model["discriminator"].parameters(),
+ self.cfg.train.learning_rate,
+ betas=self.cfg.train.AdamW.betas,
+ eps=self.cfg.train.AdamW.eps,
+ )
+ optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
+
+ return optimizer
+
+ def _build_scheduler(self):
+ scheduler_g = ExponentialLR(
+ self.optimizer["optimizer_g"],
+ gamma=self.cfg.train.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+ scheduler_d = ExponentialLR(
+ self.optimizer["optimizer_d"],
+ gamma=self.cfg.train.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+
+ scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
+ return scheduler
+
+ def _build_criterion(self):
+ class GeneratorLoss(nn.Module):
+ def __init__(self, cfg):
+ super(GeneratorLoss, self).__init__()
+ self.cfg = cfg
+ self.l1_loss = nn.L1Loss()
+
+ def generator_loss(self, disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ dg = dg.float()
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+ def feature_loss(self, fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ rl = rl.float().detach()
+ gl = gl.float()
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+ def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
+ """
+ z_p, logs_q: [b, h, t_t]
+ m_p, logs_p: [b, h, t_t]
+ """
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
+
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ kl = torch.sum(kl * z_mask)
+ l = kl / torch.sum(z_mask)
+ return l
+
+ def forward(
+ self,
+ outputs_g,
+ outputs_d,
+ y_mel,
+ y_hat_mel,
+ ):
+ loss_g = {}
+
+ # duration loss
+ loss_dur = torch.sum(outputs_g["l_length"].float())
+ loss_g["loss_dur"] = loss_dur
+
+ # mel loss
+ loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
+ loss_g["loss_mel"] = loss_mel
+
+ # kl loss
+ loss_kl = (
+ self.kl_loss(
+ outputs_g["z_p"],
+ outputs_g["logs_q"],
+ outputs_g["m_p"],
+ outputs_g["logs_p"],
+ outputs_g["z_mask"],
+ )
+ * self.cfg.train.c_kl
+ )
+ loss_g["loss_kl"] = loss_kl
+
+ # feature loss
+ loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
+ loss_g["loss_fm"] = loss_fm
+
+ # gan loss
+ loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
+ loss_g["loss_gen"] = loss_gen
+ loss_g["loss_gen_all"] = (
+ loss_dur + loss_mel + loss_kl + loss_fm + loss_gen
+ )
+
+ return loss_g
+
+ class DiscriminatorLoss(nn.Module):
+ def __init__(self, cfg):
+ super(DiscriminatorLoss, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+
+ def __call__(self, disc_real_outputs, disc_generated_outputs):
+ loss_d = {}
+
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ dr = dr.float()
+ dg = dg.float()
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg**2)
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ loss_d["loss_disc_all"] = loss
+
+ return loss_d
+
+ criterion = {
+ "generator": GeneratorLoss(self.cfg),
+ "discriminator": DiscriminatorLoss(self.cfg),
+ }
+ return criterion
+
+ def write_summary(
+ self,
+ losses,
+ stats,
+ images={},
+ audios={},
+ audio_sampling_rate=24000,
+ tag="train",
+ ):
+ for key, value in losses.items():
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
+ self.sw.add_scalar(
+ "learning_rate",
+ self.optimizer["optimizer_g"].param_groups[0]["lr"],
+ self.step,
+ )
+
+ if len(images) != 0:
+ for key, value in images.items():
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
+ if len(audios) != 0:
+ for key, value in audios.items():
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
+
+ def write_valid_summary(
+ self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
+ ):
+ for key, value in losses.items():
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
+
+ if len(images) != 0:
+ for key, value in images.items():
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
+ if len(audios) != 0:
+ for key, value in audios.items():
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
+
+ def get_state_dict(self):
+ state_dict = {
+ "generator": self.model["generator"].state_dict(),
+ "discriminator": self.model["discriminator"].state_dict(),
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ }
+ return state_dict
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+ self.model["generator"].load_state_dict(checkpoint["generator"])
+ self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
+ self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
+ self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
+ self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
+ self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])
+
+ @torch.inference_mode()
+ def _valid_step(self, batch):
+ r"""Testing forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_test_epoch`` for usage.
+ """
+
+ valid_losses = {}
+ total_loss = 0
+ valid_stats = {}
+
+ batch["linear"] = batch["linear"].transpose(2, 1) # [b, d, t]
+ batch["mel"] = batch["mel"].transpose(2, 1) # [b, d, t]
+ batch["audio"] = batch["audio"].unsqueeze(1) # [b, d, t]
+
+ # Discriminator
+ # Generator output
+ outputs_g = self.model["generator"](batch)
+
+ y_mel = slice_segments(
+ batch["mel"],
+ outputs_g["ids_slice"],
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ )
+ y_hat_mel = mel_spectrogram_torch(
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
+ )
+ y = slice_segments(
+ batch["audio"],
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
+ self.cfg.preprocess.segment_size,
+ )
+
+ # Discriminator output
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
+ ## Discriminator loss
+ loss_d = self.criterion["discriminator"](
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
+ )
+ valid_losses.update(loss_d)
+
+ ## Generator
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
+ valid_losses.update(loss_g)
+
+ for item in valid_losses:
+ valid_losses[item] = valid_losses[item].item()
+
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
+
+ return (
+ total_loss.item(),
+ valid_losses,
+ valid_stats,
+ )
+
+ def _train_step(self, batch):
+ r"""Forward step for training and inference. This function is called
+ in ``_train_step`` & ``_test_step`` function.
+ """
+
+ train_losses = {}
+ total_loss = 0
+ training_stats = {}
+
+ batch["linear"] = batch["linear"].transpose(2, 1) # [b, d, t]
+ batch["mel"] = batch["mel"].transpose(2, 1) # [b, d, t]
+ batch["audio"] = batch["audio"].unsqueeze(1) # [b, d, t]
+
+ # Train Discriminator
+ # Generator output
+ outputs_g = self.model["generator"](batch)
+
+ y_mel = slice_segments(
+ batch["mel"],
+ outputs_g["ids_slice"],
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
+ )
+ y_hat_mel = mel_spectrogram_torch(
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
+ )
+ y = slice_segments(
+ batch["audio"],
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
+ self.cfg.preprocess.segment_size,
+ )
+
+ # Discriminator output
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
+ ## Discriminator loss
+ loss_d = self.criterion["discriminator"](
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
+ )
+ train_losses.update(loss_d)
+
+ # BP and Grad Updated
+ self.optimizer["optimizer_d"].zero_grad()
+ self.accelerator.backward(loss_d["loss_disc_all"])
+ self.optimizer["optimizer_d"].step()
+
+ ## Train Generator
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
+ train_losses.update(loss_g)
+
+ # BP and Grad Updated
+ self.optimizer["optimizer_g"].zero_grad()
+ self.accelerator.backward(loss_g["loss_gen_all"])
+ self.optimizer["optimizer_g"].step()
+
+ for item in train_losses:
+ train_losses[item] = train_losses[item].item()
+
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
+
+ return (
+ total_loss.item(),
+ train_losses,
+ training_stats,
+ )
+
+ def _train_epoch(self):
+ r"""Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ epoch_sum_loss: float = 0.0
+ epoch_losses: dict = {}
+ epoch_step: int = 0
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ with self.accelerator.accumulate(self.model):
+ total_loss, train_losses, training_stats = self._train_step(batch)
+ self.batch_count += 1
+
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ epoch_sum_loss += total_loss
+ for key, value in train_losses.items():
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
+
+ self.accelerator.log(
+ {
+ "Step/Generator Loss": train_losses["loss_gen_all"],
+ "Step/Discriminator Loss": train_losses["loss_disc_all"],
+ "Step/Generator Learning Rate": self.optimizer[
+ "optimizer_d"
+ ].param_groups[0]["lr"],
+ "Step/Discriminator Learning Rate": self.optimizer[
+ "optimizer_g"
+ ].param_groups[0]["lr"],
+ },
+ step=self.step,
+ )
+ self.step += 1
+ epoch_step += 1
+
+ self.accelerator.wait_for_everyone()
+
+ epoch_sum_loss = (
+ epoch_sum_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+
+ return epoch_sum_loss, epoch_losses
diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py b/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_inference.py b/models/vocoders/autoregressive/autoregressive_vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py b/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/autoregressive/wavenet/conv.py b/models/vocoders/autoregressive/wavenet/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a095aad5d7203f6e5fb5a4d585b894e34dbe63c7
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/conv.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+from torch.nn import functional as F
+
+
+class Conv1d(nn.Conv1d):
+ """Extended nn.Conv1d for incremental dilated convolutions"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.clear_buffer()
+ self._linearized_weight = None
+ self.register_backward_hook(self._clear_linearized_weight)
+
+ def incremental_forward(self, input):
+ # input (B, T, C)
+ # run forward pre hooks
+ for hook in self._forward_pre_hooks.values():
+ hook(self, input)
+
+ # reshape weight
+ weight = self._get_linearized_weight()
+ kw = self.kernel_size[0]
+ dilation = self.dilation[0]
+
+ bsz = input.size(0)
+ if kw > 1:
+ input = input.data
+ if self.input_buffer is None:
+ self.input_buffer = input.new(
+ bsz, kw + (kw - 1) * (dilation - 1), input.size(2)
+ )
+ self.input_buffer.zero_()
+ else:
+ # shift buffer
+ self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
+ # append next input
+ self.input_buffer[:, -1, :] = input[:, -1, :]
+ input = self.input_buffer
+ if dilation > 1:
+ input = input[:, 0::dilation, :].contiguous()
+ output = F.linear(input.view(bsz, -1), weight, self.bias)
+ return output.view(bsz, 1, -1)
+
+ def clear_buffer(self):
+ self.input_buffer = None
+
+ def _get_linearized_weight(self):
+ if self._linearized_weight is None:
+ kw = self.kernel_size[0]
+ # nn.Conv1d
+ if self.weight.size() == (self.out_channels, self.in_channels, kw):
+ weight = self.weight.transpose(1, 2).contiguous()
+ else:
+ # fairseq.modules.conv_tbc.ConvTBC
+ weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
+ assert weight.size() == (self.out_channels, kw, self.in_channels)
+ self._linearized_weight = weight.view(self.out_channels, -1)
+ return self._linearized_weight
+
+ def _clear_linearized_weight(self, *args):
+ self._linearized_weight = None
diff --git a/models/vocoders/autoregressive/wavenet/modules.py b/models/vocoders/autoregressive/wavenet/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d51e52a50af3bc1f7fe9627aeae8d2b1b28b7d
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/modules.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .conv import Conv1d as conv_Conv1d
+
+
+def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
+ m = conv_Conv1d(in_channels, out_channels, kernel_size, **kwargs)
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ return nn.utils.weight_norm(m)
+
+
+def Conv1d1x1(in_channels, out_channels, bias=True):
+ return Conv1d(
+ in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
+ )
+
+
+def _conv1x1_forward(conv, x, is_incremental):
+ if is_incremental:
+ x = conv.incremental_forward(x)
+ else:
+ x = conv(x)
+ return x
+
+
+class ResidualConv1dGLU(nn.Module):
+ """Residual dilated conv1d + Gated linear unit
+
+ Args:
+ residual_channels (int): Residual input / output channels
+ gate_channels (int): Gated activation channels.
+ kernel_size (int): Kernel size of convolution layers.
+ skip_out_channels (int): Skip connection channels. If None, set to same
+ as ``residual_channels``.
+ cin_channels (int): Local conditioning channels. If negative value is
+ set, local conditioning is disabled.
+ dropout (float): Dropout probability.
+ padding (int): Padding for convolution layers. If None, proper padding
+ is computed depends on dilation and kernel_size.
+ dilation (int): Dilation factor.
+ """
+
+ def __init__(
+ self,
+ residual_channels,
+ gate_channels,
+ kernel_size,
+ skip_out_channels=None,
+ cin_channels=-1,
+ dropout=1 - 0.95,
+ padding=None,
+ dilation=1,
+ causal=True,
+ bias=True,
+ *args,
+ **kwargs,
+ ):
+ super(ResidualConv1dGLU, self).__init__()
+ self.dropout = dropout
+
+ if skip_out_channels is None:
+ skip_out_channels = residual_channels
+ if padding is None:
+ # no future time stamps available
+ if causal:
+ padding = (kernel_size - 1) * dilation
+ else:
+ padding = (kernel_size - 1) // 2 * dilation
+ self.causal = causal
+
+ self.conv = Conv1d(
+ residual_channels,
+ gate_channels,
+ kernel_size,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ *args,
+ **kwargs,
+ )
+
+ # mel conditioning
+ self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)
+
+ gate_out_channels = gate_channels // 2
+ self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
+ self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)
+
+ def forward(self, x, c=None):
+ return self._forward(x, c, False)
+
+ def incremental_forward(self, x, c=None):
+ return self._forward(x, c, True)
+
+ def clear_buffer(self):
+ for c in [
+ self.conv,
+ self.conv1x1_out,
+ self.conv1x1_skip,
+ self.conv1x1c,
+ ]:
+ if c is not None:
+ c.clear_buffer()
+
+ def _forward(self, x, c, is_incremental):
+ """Forward
+
+ Args:
+ x (Tensor): B x C x T
+ c (Tensor): B x C x T, Mel conditioning features
+ Returns:
+ Tensor: output
+ """
+ residual = x
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ if is_incremental:
+ splitdim = -1
+ x = self.conv.incremental_forward(x)
+ else:
+ splitdim = 1
+ x = self.conv(x)
+ # remove future time steps
+ x = x[:, :, : residual.size(-1)] if self.causal else x
+
+ a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+ assert self.conv1x1c is not None
+ c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
+ ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+ a, b = a + ca, b + cb
+
+ x = torch.tanh(a) * torch.sigmoid(b)
+
+ # For skip connection
+ s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)
+
+ # For residual connection
+ x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)
+
+ x = (x + residual) * math.sqrt(0.5)
+ return x, s
diff --git a/models/vocoders/autoregressive/wavenet/upsample.py b/models/vocoders/autoregressive/wavenet/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..b664302cd56545f1709a4f1874ebadd8e9375a9c
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/upsample.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import numpy as np
+
+from torch import nn
+from torch.nn import functional as F
+
+
+class Stretch2d(nn.Module):
+ def __init__(self, x_scale, y_scale, mode="nearest"):
+ super(Stretch2d, self).__init__()
+ self.x_scale = x_scale
+ self.y_scale = y_scale
+ self.mode = mode
+
+ def forward(self, x):
+ return F.interpolate(
+ x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
+ )
+
+
+def _get_activation(upsample_activation):
+ nonlinear = getattr(nn, upsample_activation)
+ return nonlinear
+
+
+class UpsampleNetwork(nn.Module):
+ def __init__(
+ self,
+ upsample_scales,
+ upsample_activation="none",
+ upsample_activation_params={},
+ mode="nearest",
+ freq_axis_kernel_size=1,
+ cin_pad=0,
+ cin_channels=128,
+ ):
+ super(UpsampleNetwork, self).__init__()
+ self.up_layers = nn.ModuleList()
+ total_scale = np.prod(upsample_scales)
+ self.indent = cin_pad * total_scale
+ for scale in upsample_scales:
+ freq_axis_padding = (freq_axis_kernel_size - 1) // 2
+ k_size = (freq_axis_kernel_size, scale * 2 + 1)
+ padding = (freq_axis_padding, scale)
+ stretch = Stretch2d(scale, 1, mode)
+ conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
+ conv.weight.data.fill_(1.0 / np.prod(k_size))
+ conv = nn.utils.weight_norm(conv)
+ self.up_layers.append(stretch)
+ self.up_layers.append(conv)
+ if upsample_activation != "none":
+ nonlinear = _get_activation(upsample_activation)
+ self.up_layers.append(nonlinear(**upsample_activation_params))
+
+ def forward(self, c):
+ """
+ Args:
+ c : B x C x T
+ """
+
+ # B x 1 x C x T
+ c = c.unsqueeze(1)
+ for f in self.up_layers:
+ c = f(c)
+ # B x C x T
+ c = c.squeeze(1)
+
+ if self.indent > 0:
+ c = c[:, :, self.indent : -self.indent]
+ return c
+
+
+class ConvInUpsampleNetwork(nn.Module):
+ def __init__(
+ self,
+ upsample_scales,
+ upsample_activation="none",
+ upsample_activation_params={},
+ mode="nearest",
+ freq_axis_kernel_size=1,
+ cin_pad=0,
+ cin_channels=128,
+ ):
+ super(ConvInUpsampleNetwork, self).__init__()
+ # To capture wide-context information in conditional features
+ # meaningless if cin_pad == 0
+ ks = 2 * cin_pad + 1
+ self.conv_in = nn.Conv1d(
+ cin_channels, cin_channels, kernel_size=ks, padding=cin_pad, bias=False
+ )
+ self.upsample = UpsampleNetwork(
+ upsample_scales,
+ upsample_activation,
+ upsample_activation_params,
+ mode,
+ freq_axis_kernel_size,
+ cin_pad=cin_pad,
+ cin_channels=cin_channels,
+ )
+
+ def forward(self, c):
+ c_up = self.upsample(self.conv_in(c))
+ return c_up
diff --git a/models/vocoders/autoregressive/wavenet/wavenet.py b/models/vocoders/autoregressive/wavenet/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63f22c2600fd0f83e5bdf339ebb121b3d2f35e6
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/wavenet.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .modules import Conv1d1x1, ResidualConv1dGLU
+from .upsample import ConvInUpsampleNetwork
+
+
+def receptive_field_size(
+ total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x
+):
+ """Compute receptive field size
+
+ Args:
+ total_layers (int): total layers
+ num_cycles (int): cycles
+ kernel_size (int): kernel size
+ dilation (lambda): lambda to compute dilation factor. ``lambda x : 1``
+ to disable dilated convolution.
+
+ Returns:
+ int: receptive field size in sample
+
+ """
+ assert total_layers % num_cycles == 0
+
+ layers_per_cycle = total_layers // num_cycles
+ dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)]
+ return (kernel_size - 1) * sum(dilations) + 1
+
+
+class WaveNet(nn.Module):
+ """The WaveNet model that supports local and global conditioning.
+
+ Args:
+ out_channels (int): Output channels. If input_type is mu-law quantized
+ one-hot vecror. this must equal to the quantize channels. Other wise
+ num_mixtures x 3 (pi, mu, log_scale).
+ layers (int): Number of total layers
+ stacks (int): Number of dilation cycles
+ residual_channels (int): Residual input / output channels
+ gate_channels (int): Gated activation channels.
+ skip_out_channels (int): Skip connection channels.
+ kernel_size (int): Kernel size of convolution layers.
+ dropout (float): Dropout probability.
+ input_dim (int): Number of mel-spec dimension.
+ upsample_scales (list): List of upsample scale.
+ ``np.prod(upsample_scales)`` must equal to hop size. Used only if
+ upsample_conditional_features is enabled.
+ freq_axis_kernel_size (int): Freq-axis kernel_size for transposed
+ convolution layers for upsampling. If you only care about time-axis
+ upsampling, set this to 1.
+ scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise
+ quantized one-hot vector is expected..
+ """
+
+ def __init__(self, cfg):
+ super(WaveNet, self).__init__()
+ self.cfg = cfg
+ self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT
+ self.out_channels = self.cfg.VOCODER.OUT_CHANNELS
+ self.cin_channels = self.cfg.VOCODER.INPUT_DIM
+ self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS
+ self.layers = self.cfg.VOCODER.LAYERS
+ self.stacks = self.cfg.VOCODER.STACKS
+ self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS
+ self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE
+ self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS
+ self.dropout = self.cfg.VOCODER.DROPOUT
+ self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES
+ self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD
+
+ assert self.layers % self.stacks == 0
+
+ layers_per_stack = self.layers // self.stacks
+ if self.scalar_input:
+ self.first_conv = Conv1d1x1(1, self.residual_channels)
+ else:
+ self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels)
+
+ self.conv_layers = nn.ModuleList()
+ for layer in range(self.layers):
+ dilation = 2 ** (layer % layers_per_stack)
+ conv = ResidualConv1dGLU(
+ self.residual_channels,
+ self.gate_channels,
+ kernel_size=self.kernel_size,
+ skip_out_channels=self.skip_out_channels,
+ bias=True,
+ dilation=dilation,
+ dropout=self.dropout,
+ cin_channels=self.cin_channels,
+ )
+ self.conv_layers.append(conv)
+
+ self.last_conv_layers = nn.ModuleList(
+ [
+ nn.ReLU(inplace=True),
+ Conv1d1x1(self.skip_out_channels, self.skip_out_channels),
+ nn.ReLU(inplace=True),
+ Conv1d1x1(self.skip_out_channels, self.out_channels),
+ ]
+ )
+
+ self.upsample_net = ConvInUpsampleNetwork(
+ upsample_scales=self.upsample_scales,
+ cin_pad=self.mel_frame_pad,
+ cin_channels=self.cin_channels,
+ )
+
+ self.receptive_field = receptive_field_size(
+ self.layers, self.stacks, self.kernel_size
+ )
+
+ def forward(self, x, mel, softmax=False):
+ """Forward step
+
+ Args:
+ x (Tensor): One-hot encoded audio signal, shape (B x C x T)
+ mel (Tensor): Local conditioning features,
+ shape (B x cin_channels x T)
+ softmax (bool): Whether applies softmax or not.
+
+ Returns:
+ Tensor: output, shape B x out_channels x T
+ """
+ B, _, T = x.size()
+
+ mel = self.upsample_net(mel)
+ assert mel.shape[-1] == x.shape[-1]
+
+ x = self.first_conv(x)
+ skips = 0
+ for f in self.conv_layers:
+ x, h = f(x, mel)
+ skips += h
+ skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+ x = skips
+ for f in self.last_conv_layers:
+ x = f(x)
+
+ x = F.softmax(x, dim=1) if softmax else x
+
+ return x
+
+ def clear_buffer(self):
+ self.first_conv.clear_buffer()
+ for f in self.conv_layers:
+ f.clear_buffer()
+ for f in self.last_conv_layers:
+ try:
+ f.clear_buffer()
+ except AttributeError:
+ pass
+
+ def make_generation_fast_(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
diff --git a/models/vocoders/autoregressive/wavernn/wavernn.py b/models/vocoders/autoregressive/wavernn/wavernn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7475fa8fe8b4575bf714e615349582ff98bbc27
--- /dev/null
+++ b/models/vocoders/autoregressive/wavernn/wavernn.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+
+
+class ResBlock(nn.Module):
+ def __init__(self, dims):
+ super().__init__()
+ self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+ self.batch_norm1 = nn.BatchNorm1d(dims)
+ self.batch_norm2 = nn.BatchNorm1d(dims)
+
+ def forward(self, x):
+ residual = x
+ x = self.conv1(x)
+ x = self.batch_norm1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = self.batch_norm2(x)
+ x = x + residual
+ return x
+
+
+class MelResNet(nn.Module):
+ def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
+ super().__init__()
+ kernel_size = pad * 2 + 1
+ self.conv_in = nn.Conv1d(
+ in_dims, compute_dims, kernel_size=kernel_size, bias=False
+ )
+ self.batch_norm = nn.BatchNorm1d(compute_dims)
+ self.layers = nn.ModuleList()
+ for i in range(res_blocks):
+ self.layers.append(ResBlock(compute_dims))
+ self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ x = self.batch_norm(x)
+ x = F.relu(x)
+ for f in self.layers:
+ x = f(x)
+ x = self.conv_out(x)
+ return x
+
+
+class Stretch2d(nn.Module):
+ def __init__(self, x_scale, y_scale):
+ super().__init__()
+ self.x_scale = x_scale
+ self.y_scale = y_scale
+
+ def forward(self, x):
+ b, c, h, w = x.size()
+ x = x.unsqueeze(-1).unsqueeze(3)
+ x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
+ return x.view(b, c, h * self.y_scale, w * self.x_scale)
+
+
+class UpsampleNetwork(nn.Module):
+ def __init__(
+ self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad
+ ):
+ super().__init__()
+ total_scale = np.cumproduct(upsample_scales)[-1]
+ self.indent = pad * total_scale
+ self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
+ self.resnet_stretch = Stretch2d(total_scale, 1)
+ self.up_layers = nn.ModuleList()
+ for scale in upsample_scales:
+ kernel_size = (1, scale * 2 + 1)
+ padding = (0, scale)
+ stretch = Stretch2d(scale, 1)
+ conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+ conv.weight.data.fill_(1.0 / kernel_size[1])
+ self.up_layers.append(stretch)
+ self.up_layers.append(conv)
+
+ def forward(self, m):
+ aux = self.resnet(m).unsqueeze(1)
+ aux = self.resnet_stretch(aux)
+ aux = aux.squeeze(1)
+ m = m.unsqueeze(1)
+ for f in self.up_layers:
+ m = f(m)
+ m = m.squeeze(1)[:, :, self.indent : -self.indent]
+ return m.transpose(1, 2), aux.transpose(1, 2)
+
+
+class WaveRNN(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+ self.pad = self.cfg.VOCODER.MEL_FRAME_PAD
+
+ if self.cfg.VOCODER.MODE == "mu_law_quantize":
+ self.n_classes = 2**self.cfg.VOCODER.BITS
+ elif self.cfg.VOCODER.MODE == "mu_law" or self.cfg.VOCODER:
+ self.n_classes = 30
+
+ self._to_flatten = []
+
+ self.rnn_dims = self.cfg.VOCODER.RNN_DIMS
+ self.aux_dims = self.cfg.VOCODER.RES_OUT_DIMS // 4
+ self.hop_length = self.cfg.VOCODER.HOP_LENGTH
+ self.fc_dims = self.cfg.VOCODER.FC_DIMS
+ self.upsample_factors = self.cfg.VOCODER.UPSAMPLE_FACTORS
+ self.feat_dims = self.cfg.VOCODER.INPUT_DIM
+ self.compute_dims = self.cfg.VOCODER.COMPUTE_DIMS
+ self.res_out_dims = self.cfg.VOCODER.RES_OUT_DIMS
+ self.res_blocks = self.cfg.VOCODER.RES_BLOCKS
+
+ self.upsample = UpsampleNetwork(
+ self.feat_dims,
+ self.upsample_factors,
+ self.compute_dims,
+ self.res_blocks,
+ self.res_out_dims,
+ self.pad,
+ )
+ self.I = nn.Linear(self.feat_dims + self.aux_dims + 1, self.rnn_dims)
+
+ self.rnn1 = nn.GRU(self.rnn_dims, self.rnn_dims, batch_first=True)
+ self.rnn2 = nn.GRU(
+ self.rnn_dims + self.aux_dims, self.rnn_dims, batch_first=True
+ )
+ self._to_flatten += [self.rnn1, self.rnn2]
+
+ self.fc1 = nn.Linear(self.rnn_dims + self.aux_dims, self.fc_dims)
+ self.fc2 = nn.Linear(self.fc_dims + self.aux_dims, self.fc_dims)
+ self.fc3 = nn.Linear(self.fc_dims, self.n_classes)
+
+ self.num_params()
+
+ self._flatten_parameters()
+
+ def forward(self, x, mels):
+ device = next(self.parameters()).device
+
+ self._flatten_parameters()
+
+ batch_size = x.size(0)
+ h1 = torch.zeros(1, batch_size, self.rnn_dims, device=device)
+ h2 = torch.zeros(1, batch_size, self.rnn_dims, device=device)
+ mels, aux = self.upsample(mels)
+
+ aux_idx = [self.aux_dims * i for i in range(5)]
+ a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
+ a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
+ a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
+ a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
+
+ x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
+ x = self.I(x)
+ res = x
+ x, _ = self.rnn1(x, h1)
+
+ x = x + res
+ res = x
+ x = torch.cat([x, a2], dim=2)
+ x, _ = self.rnn2(x, h2)
+
+ x = x + res
+ x = torch.cat([x, a3], dim=2)
+ x = F.relu(self.fc1(x))
+
+ x = torch.cat([x, a4], dim=2)
+ x = F.relu(self.fc2(x))
+ return self.fc3(x)
+
+ def num_params(self, print_out=True):
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print("Trainable Parameters: %.3fM" % parameters)
+ return parameters
+
+ def _flatten_parameters(self):
+ [m.flatten_parameters() for m in self._to_flatten]
diff --git a/models/vocoders/diffusion/diffusion_vocoder_dataset.py b/models/vocoders/diffusion/diffusion_vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7991aae6355278fa3cf7888e6c2b89aa883b5d
--- /dev/null
+++ b/models/vocoders/diffusion/diffusion_vocoder_dataset.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import random
+
+import numpy as np
+
+from torch.nn import functional as F
+
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.vocoders.vocoder_dataset import VocoderDataset
+
+
+class DiffusionVocoderDataset(VocoderDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+ super().__init__(cfg, dataset, is_valid)
+
+ eval_index = random.randint(0, len(self.metadata) - 1)
+ eval_utt_info = self.metadata[eval_index]
+ eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"])
+ self.eval_audio = np.load(self.utt2audio_path[eval_utt])
+ if cfg.preprocess.use_mel:
+ self.eval_mel = np.load(self.utt2mel_path[eval_utt])
+ if cfg.preprocess.use_frame_pitch:
+ self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt])
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+
+ if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+ mel = np.pad(
+ mel,
+ ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+ mode="constant",
+ )
+ else:
+ if "start" not in single_feature.keys():
+ start = random.randint(
+ 0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame
+ )
+ end = start + self.cfg.preprocess.cut_mel_frame
+ single_feature["start"] = start
+ single_feature["end"] = end
+ mel = mel[:, single_feature["start"] : single_feature["end"]]
+ single_feature["mel"] = mel
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+
+ if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+ aligned_frame_pitch = np.pad(
+ aligned_frame_pitch,
+ (
+ (
+ 0,
+ self.cfg.preprocess.cut_mel_frame
+ * self.cfg.preprocess.hop_size
+ - audio.shape[-1],
+ )
+ ),
+ mode="constant",
+ )
+ else:
+ if "start" not in single_feature.keys():
+ start = random.randint(
+ 0,
+ aligned_frame_pitch.shape[-1]
+ - self.cfg.preprocess.cut_mel_frame,
+ )
+ end = start + self.cfg.preprocess.cut_mel_frame
+ single_feature["start"] = start
+ single_feature["end"] = end
+ aligned_frame_pitch = aligned_frame_pitch[
+ single_feature["start"] : single_feature["end"]
+ ]
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_audio:
+ audio = np.load(self.utt2audio_path[utt])
+
+ assert "target_len" in single_feature.keys()
+
+ if (
+ audio.shape[-1]
+ <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size
+ ):
+ audio = np.pad(
+ audio,
+ (
+ (
+ 0,
+ self.cfg.preprocess.cut_mel_frame
+ * self.cfg.preprocess.hop_size
+ - audio.shape[-1],
+ )
+ ),
+ mode="constant",
+ )
+ else:
+ if "start" not in single_feature.keys():
+ audio = audio[
+ 0 : self.cfg.preprocess.cut_mel_frame
+ * self.cfg.preprocess.hop_size
+ ]
+ else:
+ audio = audio[
+ single_feature["start"]
+ * self.cfg.preprocess.hop_size : single_feature["end"]
+ * self.cfg.preprocess.hop_size,
+ ]
+ single_feature["audio"] = audio
+
+ return single_feature
+
+
+class DiffusionVocoderCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, n_mels, frame]
+ # frame_pitch: [b, frame]
+ # audios: [b, frame * hop_size]
+
+ for key in batch[0].keys():
+ if key in ["target_len", "start", "end"]:
+ continue
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/vocoders/diffusion/diffusion_vocoder_inference.py b/models/vocoders/diffusion/diffusion_vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a1afbce5282251748902b4ec22970a27693f17
--- /dev/null
+++ b/models/vocoders/diffusion/diffusion_vocoder_inference.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+
+from tqdm import tqdm
+from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
+
+
+def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
+ """Inference the vocoder
+ Args:
+ mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
+ Returns:
+ audios: A tensor of audios with the shape (batch_size, seq_len)
+ """
+ model.eval()
+
+ with torch.no_grad():
+ training_noise_schedule = np.array(cfg.model.diffwave.noise_schedule)
+ inference_noise_schedule = (
+ np.array(cfg.model.diffwave.inference_noise_schedule)
+ if fast_inference
+ else np.array(cfg.model.diffwave.noise_schedule)
+ )
+
+ talpha = 1 - training_noise_schedule
+ talpha_cum = np.cumprod(talpha)
+
+ beta = inference_noise_schedule
+ alpha = 1 - beta
+ alpha_cum = np.cumprod(alpha)
+
+ T = []
+ for s in range(len(inference_noise_schedule)):
+ for t in range(len(training_noise_schedule) - 1):
+ if talpha_cum[t + 1] <= alpha_cum[s] <= talpha_cum[t]:
+ twiddle = (talpha_cum[t] ** 0.5 - alpha_cum[s] ** 0.5) / (
+ talpha_cum[t] ** 0.5 - talpha_cum[t + 1] ** 0.5
+ )
+ T.append(t + twiddle)
+ break
+ T = np.array(T, dtype=np.float32)
+
+ mels = mels.to(device)
+ audio = torch.randn(
+ mels.shape[0],
+ cfg.preprocess.hop_size * mels.shape[-1],
+ device=device,
+ )
+
+ for n in tqdm(range(len(alpha) - 1, -1, -1)):
+ c1 = 1 / alpha[n] ** 0.5
+ c2 = beta[n] / (1 - alpha_cum[n]) ** 0.5
+ audio = c1 * (
+ audio
+ - c2
+ * model(audio, torch.tensor([T[n]], device=audio.device), mels).squeeze(
+ 1
+ )
+ )
+ if n > 0:
+ noise = torch.randn_like(audio)
+ sigma = (
+ (1.0 - alpha_cum[n - 1]) / (1.0 - alpha_cum[n]) * beta[n]
+ ) ** 0.5
+ audio += sigma * noise
+ audio = torch.clamp(audio, -1.0, 1.0)
+
+ return audio.detach().cpu()
+
+
+def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
+ """Inference the vocoder
+ Args:
+ mels: A list of mel-specs
+ Returns:
+ audios: A list of audios
+ """
+ # Get the device
+ device = next(model.parameters()).device
+
+ audios = []
+
+ # Pad the given list into tensors
+ mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
+ if f0s != None:
+ f0_batches = pad_f0_to_tensors(f0s, batch_size)
+
+ if f0s == None:
+ for mel_batch, mel_frame in zip(mel_batches, mel_frames):
+ for i in range(mel_batch.shape[0]):
+ mel = mel_batch[i]
+ frame = mel_frame[i]
+ audio = vocoder_inference(
+ cfg,
+ model,
+ mel.unsqueeze(0),
+ device=device,
+ fast_inference=fast_inference,
+ ).squeeze(0)
+
+ # calculate the audio length
+ audio_length = frame * cfg.preprocess.hop_size
+ audio = audio[:audio_length]
+
+ audios.append(audio)
+ else:
+ for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
+ for i in range(mel_batch.shape[0]):
+ mel = mel_batch[i]
+ f0 = f0_batch[i]
+ frame = mel_frame[i]
+ audio = vocoder_inference(
+ cfg,
+ model,
+ mel.unsqueeze(0),
+ f0s=f0.unsqueeze(0),
+ device=device,
+ fast_inference=fast_inference,
+ ).squeeze(0)
+
+ # calculate the audio length
+ audio_length = frame * cfg.preprocess.hop_size
+ audio = audio[:audio_length]
+
+ audios.append(audio)
+ return audios
diff --git a/models/vocoders/diffusion/diffusion_vocoder_trainer.py b/models/vocoders/diffusion/diffusion_vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eee8b65fb963181f1ba70da3f34baf31741b7f87
--- /dev/null
+++ b/models/vocoders/diffusion/diffusion_vocoder_trainer.py
@@ -0,0 +1,534 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+import time
+import torch
+import json
+import itertools
+import accelerate
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import ExponentialLR
+
+from librosa.filters import mel as librosa_mel_fn
+
+from accelerate.logging import get_logger
+from pathlib import Path
+
+from utils.io import save_audio
+from utils.data_utils import *
+from utils.util import (
+ Logger,
+ ValueWindow,
+ remove_older_ckpt,
+ set_all_random_seed,
+ save_config,
+)
+from utils.mel import extract_mel_features
+from models.vocoders.vocoder_trainer import VocoderTrainer
+from models.vocoders.diffusion.diffusion_vocoder_dataset import (
+ DiffusionVocoderDataset,
+ DiffusionVocoderCollator,
+)
+
+from models.vocoders.diffusion.diffwave.diffwave import DiffWave
+
+from models.vocoders.diffusion.diffusion_vocoder_inference import vocoder_inference
+
+supported_models = {
+ "diffwave": DiffWave,
+}
+
+
+class DiffusionVocoderTrainer(VocoderTrainer):
+ def __init__(self, args, cfg):
+ super().__init__()
+
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # Diffusion
+ self.cfg.model.diffwave.noise_schedule = np.linspace(
+ self.cfg.model.diffwave.noise_schedule_factors[0],
+ self.cfg.model.diffwave.noise_schedule_factors[1],
+ self.cfg.model.diffwave.noise_schedule_factors[2],
+ )
+ beta = np.array(self.cfg.model.diffwave.noise_schedule)
+ noise_level = np.cumprod(1 - beta)
+ self.noise_level = torch.tensor(noise_level.astype(np.float32))
+
+ # Init accelerator
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Init logger
+ with self.accelerator.main_process_first():
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
+
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Init training status
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check potential erorrs
+ if self.accelerator.is_main_process:
+ self._check_basic_configs()
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # Build dataloader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # Build model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.debug(self.model)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
+
+ # Build optimizers and schedulers
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ self.optimizer = self._build_optimizer()
+ self.scheduler = self._build_scheduler()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # Accelerator preparing
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.model,
+ self.optimizer,
+ self.scheduler,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.model,
+ self.optimizer,
+ self.scheduler,
+ )
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # Build criterions
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterion = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume checkpoints
+ with self.accelerator.main_process_first():
+ if args.resume_type:
+ self.logger.info("Resuming from checkpoint...")
+ start = time.monotonic_ns()
+ ckpt_path = Path(args.checkpoint)
+ if self._is_valid_pattern(ckpt_path.parts[-1]):
+ ckpt_path = self._load_model(
+ None, args.checkpoint, args.resume_type
+ )
+ else:
+ ckpt_path = self._load_model(
+ args.checkpoint, resume_type=args.resume_type
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
+ )
+
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Save config
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+ # Device
+ self.device = next(self.model.parameters()).device
+ self.noise_level = self.noise_level.to(self.device)
+
+ def _build_dataset(self):
+ return DiffusionVocoderDataset, DiffusionVocoderCollator
+
+ def _build_criterion(self):
+ criterion = nn.L1Loss()
+ return criterion
+
+ def _build_model(self):
+ model = supported_models[self.cfg.model.generator](self.cfg)
+ return model
+
+ def _build_optimizer(self):
+ optimizer = AdamW(
+ self.model.parameters(),
+ lr=self.cfg.train.adamw.lr,
+ betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
+ )
+ return optimizer
+
+ def _build_scheduler(self):
+ scheduler = ExponentialLR(
+ self.optimizer,
+ gamma=self.cfg.train.exponential_lr.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+ return scheduler
+
+ def train_loop(self):
+ """Training process"""
+ self.accelerator.wait_for_everyone()
+
+ # Dump config
+ if self.accelerator.is_main_process:
+ self._dump_cfg(self.config_save_path)
+ self.model.train()
+ self.optimizer.zero_grad()
+
+ # Sync and start training
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ # Train and Validate
+ train_total_loss = self._train_epoch()
+ valid_total_loss = self._valid_epoch()
+ self.accelerator.log(
+ {
+ "Epoch/Train Total Loss": train_total_loss,
+ "Epoch/Valid Total Loss": valid_total_loss,
+ },
+ step=self.epoch,
+ )
+
+ # Update scheduler
+ self.accelerator.wait_for_everyone()
+ self.scheduler.step()
+
+ # Check save checkpoint interval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ run_eval |= self.run_eval[i]
+
+ # Save checkpoints
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ self.accelerator.save_state(path)
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+
+ # Save eval audios
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and run_eval:
+ for i in range(len(self.valid_dataloader.dataset.eval_audios)):
+ if self.cfg.preprocess.use_frame_pitch:
+ eval_audio = self._inference(
+ self.valid_dataloader.dataset.eval_mels[i],
+ eval_pitch=self.valid_dataloader.dataset.eval_pitchs[i],
+ use_pitch=True,
+ )
+ else:
+ eval_audio = self._inference(
+ self.valid_dataloader.dataset.eval_mels[i]
+ )
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}.wav".format(
+ self.epoch,
+ self.step,
+ valid_total_loss,
+ self.valid_dataloader.dataset.eval_dataset_names[i],
+ ),
+ )
+ path_gt = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}_gt.wav".format(
+ self.epoch,
+ self.step,
+ valid_total_loss,
+ self.valid_dataloader.dataset.eval_dataset_names[i],
+ ),
+ )
+ save_audio(path, eval_audio, self.cfg.preprocess.sample_rate)
+ save_audio(
+ path_gt,
+ self.valid_dataloader.dataset.eval_audios[i],
+ self.cfg.preprocess.sample_rate,
+ )
+
+ self.accelerator.wait_for_everyone()
+
+ self.epoch += 1
+
+ # Finish training
+ self.accelerator.wait_for_everyone()
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ self.accelerator.save_state(path)
+
+ def _train_epoch(self):
+ """Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.train()
+
+ epoch_total_loss: int = 0
+
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Get losses
+ total_loss = self._train_step(batch)
+ self.batch_count += 1
+
+ # Log info
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ self.accelerator.log(
+ {
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
+ },
+ step=self.step,
+ )
+ epoch_total_loss += total_loss
+ self.step += 1
+
+ # Get and log total losses
+ self.accelerator.wait_for_everyone()
+ epoch_total_loss = (
+ epoch_total_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+ return epoch_total_loss
+
+ def _train_step(self, data):
+ """Training forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_train_epoch`` for usage.
+ """
+ # Init losses
+ total_loss = 0
+
+ # Use input feature to get predictions
+ mel_input = data["mel"]
+ audio_gt = data["audio"]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch_input = data["frame_pitch"]
+
+ self.optimizer.zero_grad()
+ N = audio_gt.shape[0]
+ t = torch.randint(
+ 0, len(self.cfg.model.diffwave.noise_schedule), [N], device=self.device
+ )
+ noise_scale = self.noise_level[t].unsqueeze(1)
+ noise_scale_sqrt = noise_scale**0.5
+ noise = torch.randn_like(audio_gt).to(self.device)
+ noisy_audio = noise_scale_sqrt * audio_gt + (1.0 - noise_scale) ** 0.5 * noise
+
+ audio_pred = self.model(noisy_audio, t, mel_input)
+ total_loss = self.criterion(noise, audio_pred.squeeze(1))
+
+ self.accelerator.backward(total_loss)
+ self.optimizer.step()
+
+ return total_loss.item()
+
+ def _valid_epoch(self):
+ """Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.model.eval()
+
+ epoch_total_loss: int = 0
+
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Get losses
+ total_loss = self._valid_step(batch)
+
+ # Log info
+ epoch_total_loss += total_loss
+
+ # Get and log total losses
+ self.accelerator.wait_for_everyone()
+ epoch_total_loss = epoch_total_loss / len(self.valid_dataloader)
+ return epoch_total_loss
+
+ def _valid_step(self, data):
+ """Testing forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_test_epoch`` for usage.
+ """
+ # Init losses
+ total_loss = 0
+
+ # Use feature inputs to get the predicted audio
+ mel_input = data["mel"]
+ audio_gt = data["audio"]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch_input = data["frame_pitch"]
+
+ N = audio_gt.shape[0]
+ t = torch.randint(
+ 0, len(self.cfg.model.diffwave.noise_schedule), [N], device=self.device
+ )
+ noise_scale = self.noise_level[t].unsqueeze(1)
+ noise_scale_sqrt = noise_scale**0.5
+ noise = torch.randn_like(audio_gt)
+ noisy_audio = noise_scale_sqrt * audio_gt + (1.0 - noise_scale) ** 0.5 * noise
+
+ audio_pred = self.model(noisy_audio, t, mel_input)
+ total_loss = self.criterion(noise, audio_pred.squeeze(1))
+
+ return total_loss.item()
+
+ def _inference(self, eval_mel, eval_pitch=None, use_pitch=False):
+ """Inference during training for test audios."""
+ if use_pitch:
+ eval_pitch = align_length(eval_pitch, eval_mel.shape[1])
+ eval_audio = vocoder_inference(
+ self.cfg,
+ self.model,
+ torch.from_numpy(eval_mel).unsqueeze(0),
+ f0s=torch.from_numpy(eval_pitch).unsqueeze(0).float(),
+ device=next(self.model.parameters()).device,
+ ).squeeze(0)
+ else:
+ eval_audio = vocoder_inference(
+ self.cfg,
+ self.model,
+ torch.from_numpy(eval_mel).unsqueeze(0),
+ device=next(self.model.parameters()).device,
+ ).squeeze(0)
+ return eval_audio
+
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ if resume_type == "resume":
+ self.accelerator.load_state(checkpoint_path)
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+ elif resume_type == "finetune":
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune SUCCESS!")
+ else:
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
+ return checkpoint_path
+
+ def _count_parameters(self):
+ result = sum(p.numel() for p in self.model.parameters())
+ return result
diff --git a/models/vocoders/diffusion/diffwave/diffwave.py b/models/vocoders/diffusion/diffwave/diffwave.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc9358a01c4dd0f4badf63ac50667c3ea2ec5e0
--- /dev/null
+++ b/models/vocoders/diffusion/diffwave/diffwave.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This model code is adopted from DiffWave/model.py under the Apache License
+# https://github.com/lmnt-com/diffwave
+# Only the config-related varaible names are changed.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from math import sqrt
+
+
+Linear = nn.Linear
+ConvTranspose2d = nn.ConvTranspose2d
+
+
+def Conv1d(*args, **kwargs):
+ layer = nn.Conv1d(*args, **kwargs)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+@torch.jit.script
+def silu(x):
+ return x * torch.sigmoid(x)
+
+
+class DiffusionEmbedding(nn.Module):
+ def __init__(self, max_steps):
+ super().__init__()
+ self.register_buffer(
+ "embedding", self._build_embedding(max_steps), persistent=False
+ )
+ self.projection1 = Linear(128, 512)
+ self.projection2 = Linear(512, 512)
+
+ def forward(self, diffusion_step):
+ if diffusion_step.dtype in [torch.int32, torch.int64]:
+ x = self.embedding[diffusion_step]
+ else:
+ x = self._lerp_embedding(diffusion_step)
+ x = self.projection1(x)
+ x = silu(x)
+ x = self.projection2(x)
+ x = silu(x)
+ return x
+
+ def _lerp_embedding(self, t):
+ low_idx = torch.floor(t).long()
+ high_idx = torch.ceil(t).long()
+ low = self.embedding[low_idx]
+ high = self.embedding[high_idx]
+ return low + (high - low) * (t - low_idx)
+
+ def _build_embedding(self, max_steps):
+ steps = torch.arange(max_steps).unsqueeze(1) # [T,1]
+ dims = torch.arange(64).unsqueeze(0) # [1,64]
+ table = steps * 10.0 ** (dims * 4.0 / 63.0) # [T,64]
+ table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
+ return table
+
+
+class SpectrogramUpsampler(nn.Module):
+ def __init__(self, upsample_factors):
+ super().__init__()
+ self.conv1 = ConvTranspose2d(
+ 1,
+ 1,
+ [3, upsample_factors[0] * 2],
+ stride=[1, upsample_factors[0]],
+ padding=[1, upsample_factors[0] // 2],
+ )
+ self.conv2 = ConvTranspose2d(
+ 1,
+ 1,
+ [3, upsample_factors[1] * 2],
+ stride=[1, upsample_factors[1]],
+ padding=[1, upsample_factors[1] // 2],
+ )
+
+ def forward(self, x):
+ x = torch.unsqueeze(x, 1)
+ x = self.conv1(x)
+ x = F.leaky_relu(x, 0.4)
+ x = self.conv2(x)
+ x = F.leaky_relu(x, 0.4)
+ x = torch.squeeze(x, 1)
+ return x
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, n_mels, residual_channels, dilation):
+ super().__init__()
+ self.dilated_conv = Conv1d(
+ residual_channels,
+ 2 * residual_channels,
+ 3,
+ padding=dilation,
+ dilation=dilation,
+ )
+ self.diffusion_projection = Linear(512, residual_channels)
+
+ self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
+
+ self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
+
+ def forward(self, x, diffusion_step, conditioner):
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+ y = x + diffusion_step
+
+ conditioner = self.conditioner_projection(conditioner)
+ y = self.dilated_conv(y) + conditioner
+
+ gate, filter = torch.chunk(y, 2, dim=1)
+ y = torch.sigmoid(gate) * torch.tanh(filter)
+
+ y = self.output_projection(y)
+ residual, skip = torch.chunk(y, 2, dim=1)
+ return (x + residual) / sqrt(2.0), skip
+
+
+class DiffWave(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.cfg.model.diffwave.noise_schedule = np.linspace(
+ self.cfg.model.diffwave.noise_schedule_factors[0],
+ self.cfg.model.diffwave.noise_schedule_factors[1],
+ self.cfg.model.diffwave.noise_schedule_factors[2],
+ ).tolist()
+ self.input_projection = Conv1d(1, self.cfg.model.diffwave.residual_channels, 1)
+ self.diffusion_embedding = DiffusionEmbedding(
+ len(self.cfg.model.diffwave.noise_schedule)
+ )
+ self.spectrogram_upsampler = SpectrogramUpsampler(
+ self.cfg.model.diffwave.upsample_factors
+ )
+
+ self.residual_layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ self.cfg.preprocess.n_mel,
+ self.cfg.model.diffwave.residual_channels,
+ 2 ** (i % self.cfg.model.diffwave.dilation_cycle_length),
+ )
+ for i in range(self.cfg.model.diffwave.residual_layers)
+ ]
+ )
+ self.skip_projection = Conv1d(
+ self.cfg.model.diffwave.residual_channels,
+ self.cfg.model.diffwave.residual_channels,
+ 1,
+ )
+ self.output_projection = Conv1d(self.cfg.model.diffwave.residual_channels, 1, 1)
+ nn.init.zeros_(self.output_projection.weight)
+
+ def forward(self, audio, diffusion_step, spectrogram):
+ x = audio.unsqueeze(1)
+ x = self.input_projection(x)
+ x = F.relu(x)
+
+ diffusion_step = self.diffusion_embedding(diffusion_step)
+ spectrogram = self.spectrogram_upsampler(spectrogram)
+
+ skip = None
+ for layer in self.residual_layers:
+ x, skip_connection = layer(x, diffusion_step, spectrogram)
+ skip = skip_connection if skip is None else skip_connection + skip
+
+ x = skip / sqrt(len(self.residual_layers))
+ x = self.skip_projection(x)
+ x = F.relu(x)
+ x = self.output_projection(x)
+ return x
diff --git a/models/vocoders/dsp/world/world.py b/models/vocoders/dsp/world/world.py
new file mode 100644
index 0000000000000000000000000000000000000000..59f28e8e896f883fe6ce243dfb7f254e78fd09c6
--- /dev/null
+++ b/models/vocoders/dsp/world/world.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# 1. Extract WORLD features including F0, AP, SP
+# 2. Transform between SP and MCEP
+import torchaudio
+import pyworld as pw
+import numpy as np
+import torch
+import diffsptk
+import os
+from tqdm import tqdm
+import pickle
+import json
+import re
+import torchaudio
+
+from cuhkszsvc.configs.config_parse import get_wav_path, get_wav_file_path
+from utils.io import has_existed
+
+
+def get_mcep_params(fs):
+ """Hyperparameters of transformation between SP and MCEP
+
+ Reference:
+ https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh
+
+ """
+ if fs in [44100, 48000]:
+ fft_size = 2048
+ alpha = 0.77
+ if fs in [16000]:
+ fft_size = 1024
+ alpha = 0.58
+ return fft_size, alpha
+
+
+def extract_world_features(wave_file, fs, frameshift):
+ # waveform: (1, seq)
+ waveform, sample_rate = torchaudio.load(wave_file)
+ if sample_rate != fs:
+ waveform = torchaudio.functional.resample(
+ waveform, orig_freq=sample_rate, new_freq=fs
+ )
+ # x: (seq,)
+ x = np.array(torch.clamp(waveform[0], -1.0, 1.0), dtype=np.double)
+
+ _f0, t = pw.dio(x, fs, frame_period=frameshift) # raw pitch extractor
+ f0 = pw.stonemask(x, _f0, t, fs) # pitch refinement
+ sp = pw.cheaptrick(x, f0, t, fs) # extract smoothed spectrogram
+ ap = pw.d4c(x, f0, t, fs) # extract aperiodicity
+
+ return f0, sp, ap, fs
+
+
+def sp2mcep(x, mcsize, fs):
+ fft_size, alpha = get_mcep_params(fs)
+ x = torch.as_tensor(x, dtype=torch.float)
+
+ tmp = diffsptk.ScalarOperation("SquareRoot")(x)
+ tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp)
+ mgc = diffsptk.MelCepstralAnalysis(
+ cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1
+ )(tmp)
+ return mgc.numpy()
+
+
+def mcep2sp(x, mcsize, fs):
+ fft_size, alpha = get_mcep_params(fs)
+ x = torch.as_tensor(x, dtype=torch.float)
+
+ tmp = diffsptk.MelGeneralizedCepstrumToSpectrum(
+ alpha=alpha,
+ cep_order=mcsize - 1,
+ fft_length=fft_size,
+ )(x)
+ tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp)
+ sp = diffsptk.ScalarOperation("Power", 2)(tmp)
+ return sp.double().numpy()
+
+
+def extract_mcep_features_of_dataset(
+ output_path, dataset_path, dataset, mcsize, fs, frameshift, splits=None
+):
+ output_dir = os.path.join(output_path, dataset, "mcep/{}".format(fs))
+
+ if not splits:
+ splits = ["train", "test"] if dataset != "m4singer" else ["test"]
+
+ for dataset_type in splits:
+ print("-" * 20)
+ print("Dataset: {}, {}".format(dataset, dataset_type))
+
+ output_file = os.path.join(output_dir, "{}.pkl".format(dataset_type))
+ if has_existed(output_file):
+ continue
+
+ # Extract SP features
+ print("\nExtracting SP featuers...")
+ sp_features = get_world_features_of_dataset(
+ output_path, dataset_path, dataset, dataset_type, fs, frameshift
+ )
+
+ # SP to MCEP
+ print("\nTransform SP to MCEP...")
+ mcep_features = [sp2mcep(sp, mcsize=mcsize, fs=fs) for sp in tqdm(sp_features)]
+
+ # Save
+ os.makedirs(output_dir, exist_ok=True)
+ with open(output_file, "wb") as f:
+ pickle.dump(mcep_features, f)
+
+
+def get_world_features_of_dataset(
+ output_path,
+ dataset_path,
+ dataset,
+ dataset_type,
+ fs,
+ frameshift,
+ save_sp_feature=False,
+):
+ data_dir = os.path.join(output_path, dataset)
+ wave_dir = get_wav_path(dataset_path, dataset)
+
+ # Dataset
+ dataset_file = os.path.join(data_dir, "{}.json".format(dataset_type))
+ if not os.path.exists(dataset_file):
+ print("File {} has not existed.".format(dataset_file))
+ return None
+
+ with open(dataset_file, "r") as f:
+ datasets = json.load(f)
+
+ # Save dir
+ f0_dir = os.path.join(output_path, dataset, "f0")
+ os.makedirs(f0_dir, exist_ok=True)
+
+ # Extract
+ f0_features = []
+ sp_features = []
+ for utt in tqdm(datasets):
+ wave_file = get_wav_file_path(dataset, wave_dir, utt)
+ f0, sp, _, _ = extract_world_features(wave_file, fs, frameshift)
+
+ sp_features.append(sp)
+ f0_features.append(f0)
+
+ # Save sp
+ if save_sp_feature:
+ sp_dir = os.path.join(output_path, dataset, "sp")
+ os.makedirs(sp_dir, exist_ok=True)
+ with open(os.path.join(sp_dir, "{}.pkl".format(dataset_type)), "wb") as f:
+ pickle.dump(sp_features, f)
+
+ # F0 statistics
+ f0_statistics_file = os.path.join(f0_dir, "{}_f0.pkl".format(dataset_type))
+ f0_statistics(f0_features, f0_statistics_file)
+
+ return sp_features
+
+
+def f0_statistics(f0_features, path):
+ print("\nF0 statistics...")
+
+ total_f0 = []
+ for f0 in tqdm(f0_features):
+ total_f0 += [f for f in f0 if f != 0]
+
+ mean = sum(total_f0) / len(total_f0)
+ print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean))
+
+ with open(path, "wb") as f:
+ pickle.dump([mean, total_f0], f)
+
+
+def world_synthesis(f0, sp, ap, fs, frameshift):
+ y = pw.synthesize(
+ f0, sp, ap, fs, frame_period=frameshift
+ ) # synthesize an utterance using the parameters
+ return y
diff --git a/models/vocoders/flow/flow_vocoder_dataset.py b/models/vocoders/flow/flow_vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/flow/flow_vocoder_inference.py b/models/vocoders/flow/flow_vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/flow/flow_vocoder_trainer.py b/models/vocoders/flow/flow_vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/flow/waveglow/waveglow.py b/models/vocoders/flow/waveglow/waveglow.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e2a1bf8f5e3c3d47a031ceec87e4ff111cd5fe
--- /dev/null
+++ b/models/vocoders/flow/waveglow/waveglow.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+class Invertible1x1Conv(torch.nn.Module):
+ """
+ The layer outputs both the convolution, and the log determinant
+ of its weight matrix. If reverse=True it does convolution with
+ inverse
+ """
+
+ def __init__(self, c):
+ super(Invertible1x1Conv, self).__init__()
+ self.conv = torch.nn.Conv1d(
+ c, c, kernel_size=1, stride=1, padding=0, bias=False
+ )
+
+ # Sample a random orthonormal matrix to initialize weights
+ W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
+
+ # Ensure determinant is 1.0 not -1.0
+ if torch.det(W) < 0:
+ W[:, 0] = -1 * W[:, 0]
+ W = W.view(c, c, 1)
+ self.conv.weight.data = W
+
+ def forward(self, z, reverse=False):
+ # shape
+ batch_size, group_size, n_of_groups = z.size()
+
+ W = self.conv.weight.squeeze()
+
+ if reverse:
+ if not hasattr(self, "W_inverse"):
+ # Reverse computation
+ W_inverse = W.float().inverse()
+ W_inverse = Variable(W_inverse[..., None])
+ if z.type() == "torch.cuda.HalfTensor":
+ W_inverse = W_inverse.half()
+ self.W_inverse = W_inverse
+ z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
+ return z
+ else:
+ # Forward computation
+ log_det_W = batch_size * n_of_groups * torch.logdet(W)
+ z = self.conv(z)
+ return z, log_det_W
+
+
+class WN(torch.nn.Module):
+ """
+ This is the WaveNet like layer for the affine coupling. The primary difference
+ from WaveNet is the convolutions need not be causal. There is also no dilation
+ size reset. The dilation only doubles on each layer
+ """
+
+ def __init__(
+ self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size
+ ):
+ super(WN, self).__init__()
+ assert kernel_size % 2 == 1
+ assert n_channels % 2 == 0
+ self.n_layers = n_layers
+ self.n_channels = n_channels
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+
+ start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
+ start = torch.nn.utils.weight_norm(start, name="weight")
+ self.start = start
+
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+
+ cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+ for i in range(n_layers):
+ dilation = 2**i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(
+ n_channels,
+ 2 * n_channels,
+ kernel_size,
+ dilation=dilation,
+ padding=padding,
+ )
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * n_channels
+ else:
+ res_skip_channels = n_channels
+ res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, forward_input):
+ audio, spect = forward_input
+ audio = self.start(audio)
+ output = torch.zeros_like(audio)
+ n_channels_tensor = torch.IntTensor([self.n_channels])
+
+ spect = self.cond_layer(spect)
+
+ for i in range(self.n_layers):
+ spect_offset = i * 2 * self.n_channels
+ acts = fused_add_tanh_sigmoid_multiply(
+ self.in_layers[i](audio),
+ spect[:, spect_offset : spect_offset + 2 * self.n_channels, :],
+ n_channels_tensor,
+ )
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ audio = audio + res_skip_acts[:, : self.n_channels, :]
+ output = output + res_skip_acts[:, self.n_channels :, :]
+ else:
+ output = output + res_skip_acts
+
+ return self.end(output)
+
+
+class WaveGlow(torch.nn.Module):
+ def __init__(self, cfg):
+ super(WaveGlow, self).__init__()
+
+ self.cfg = cfg
+
+ self.upsample = torch.nn.ConvTranspose1d(
+ self.cfg.VOCODER.INPUT_DIM,
+ self.cfg.VOCODER.INPUT_DIM,
+ 1024,
+ stride=256,
+ )
+ assert self.cfg.VOCODER.N_GROUP % 2 == 0
+ self.n_flows = self.cfg.VOCODER.N_FLOWS
+ self.n_group = self.cfg.VOCODER.N_GROUP
+ self.n_early_every = self.cfg.VOCODER.N_EARLY_EVERY
+ self.n_early_size = self.cfg.VOCODER.N_EARLY_SIZE
+ self.WN = torch.nn.ModuleList()
+ self.convinv = torch.nn.ModuleList()
+
+ n_half = int(self.cfg.VOCODER.N_GROUP / 2)
+
+ # Set up layers with the right sizes based on how many dimensions
+ # have been output already
+ n_remaining_channels = self.cfg.VOCODER.N_GROUP
+ for k in range(self.cfg.VOCODER.N_FLOWS):
+ if k % self.n_early_every == 0 and k > 0:
+ n_half = n_half - int(self.n_early_size / 2)
+ n_remaining_channels = n_remaining_channels - self.n_early_size
+ self.convinv.append(Invertible1x1Conv(n_remaining_channels))
+ self.WN.append(
+ WN(
+ n_half,
+ self.cfg.VOCODER.INPUT_DIM * self.cfg.VOCODER.N_GROUP,
+ self.cfg.VOCODER.N_LAYERS,
+ self.cfg.VOCODER.N_CHANNELS,
+ self.cfg.VOCODER.KERNEL_SIZE,
+ )
+ )
+ self.n_remaining_channels = n_remaining_channels # Useful during inference
+
+ def forward(self, forward_input):
+ """
+ forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
+ forward_input[1] = audio: batch x time
+ """
+ spect, audio = forward_input
+
+ # Upsample spectrogram to size of audio
+ spect = self.upsample(spect)
+ assert spect.size(2) >= audio.size(1)
+ if spect.size(2) > audio.size(1):
+ spect = spect[:, :, : audio.size(1)]
+
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+ spect = (
+ spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
+ )
+
+ audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
+ output_audio = []
+ log_s_list = []
+ log_det_W_list = []
+
+ for k in range(self.n_flows):
+ if k % self.n_early_every == 0 and k > 0:
+ output_audio.append(audio[:, : self.n_early_size, :])
+ audio = audio[:, self.n_early_size :, :]
+
+ audio, log_det_W = self.convinv[k](audio)
+ log_det_W_list.append(log_det_W)
+
+ n_half = int(audio.size(1) / 2)
+ audio_0 = audio[:, :n_half, :]
+ audio_1 = audio[:, n_half:, :]
+
+ output = self.WN[k]((audio_0, spect))
+ log_s = output[:, n_half:, :]
+ b = output[:, :n_half, :]
+ audio_1 = torch.exp(log_s) * audio_1 + b
+ log_s_list.append(log_s)
+
+ audio = torch.cat([audio_0, audio_1], 1)
+
+ output_audio.append(audio)
+ return torch.cat(output_audio, 1), log_s_list, log_det_W_list
+
+ @staticmethod
+ def remove_weightnorm(model):
+ waveglow = model
+ for WN in waveglow.WN:
+ WN.start = torch.nn.utils.remove_weight_norm(WN.start)
+ WN.in_layers = remove(WN.in_layers)
+ WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
+ WN.res_skip_layers = remove(WN.res_skip_layers)
+ return waveglow
+
+
+def remove(conv_list):
+ new_conv_list = torch.nn.ModuleList()
+ for old_conv in conv_list:
+ old_conv = torch.nn.utils.remove_weight_norm(old_conv)
+ new_conv_list.append(old_conv)
+ return new_conv_list
diff --git a/models/vocoders/gan/discriminator/mpd.py b/models/vocoders/gan/discriminator/mpd.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e4e7921711ace71cea2f36d0f683070a1904c0
--- /dev/null
+++ b/models/vocoders/gan/discriminator/mpd.py
@@ -0,0 +1,391 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv2d, Conv1d
+from torch.nn.utils import weight_norm, spectral_norm
+from torch import nn
+from modules.vocoder_blocks import *
+from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator_JETS
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, cfg, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ self.d_mult = cfg.model.mpd.discriminator_channel_mult_factor
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(
+ Conv2d(
+ 1,
+ int(32 * self.d_mult),
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ int(32 * self.d_mult),
+ int(128 * self.d_mult),
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ int(128 * self.d_mult),
+ int(512 * self.d_mult),
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ int(512 * self.d_mult),
+ int(1024 * self.d_mult),
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(5, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ int(1024 * self.d_mult),
+ int(1024 * self.d_mult),
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(2, 0),
+ )
+ ),
+ ]
+ )
+ self.conv_post = norm_f(
+ Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
+ )
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self, cfg):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.mpd_reshapes = cfg.model.mpd.mpd_reshapes
+ print("mpd_reshapes: {}".format(self.mpd_reshapes))
+ discriminators = [
+ DiscriminatorP(cfg, rs, use_spectral_norm=cfg.model.mpd.use_spectral_norm)
+ for rs in self.mpd_reshapes
+ ]
+ self.discriminators = nn.ModuleList(discriminators)
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# TODO: merge with DiscriminatorP (lmxue, yicheng)
+class DiscriminatorP_vits(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP_vits, self).__init__()
+ self.period = period
+ self.use_spectral_norm = use_spectral_norm
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(
+ Conv2d(
+ 1,
+ 32,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 32,
+ 128,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 128,
+ 512,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 512,
+ 1024,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 1024,
+ 1024,
+ (kernel_size, 1),
+ 1,
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ ]
+ )
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ]
+ )
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+# TODO: merge with MultiPeriodDiscriminator (lmxue, yicheng)
+class MultiPeriodDiscriminator_vits(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(MultiPeriodDiscriminator_vits, self).__init__()
+ periods = [2, 3, 5, 7, 11]
+
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+ discs = discs + [
+ DiscriminatorP_vits(i, use_spectral_norm=use_spectral_norm) for i in periods
+ ]
+ self.discriminators = nn.ModuleList(discs)
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ y_d_gs.append(y_d_g)
+ fmap_rs.append(fmap_r)
+ fmap_gs.append(fmap_g)
+
+ outputs = {
+ "y_d_hat_r": y_d_rs,
+ "y_d_hat_g": y_d_gs,
+ "fmap_rs": fmap_rs,
+ "fmap_gs": fmap_gs,
+ }
+
+ return outputs
+
+
+class DiscriminatorP_JETS(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP_JETS, self).__init__()
+ self.period = period
+ self.use_spectral_norm = use_spectral_norm
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(
+ Conv2d(
+ 1,
+ 32,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 32,
+ 128,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 128,
+ 512,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 512,
+ 1024,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ norm_f(
+ Conv2d(
+ 1024,
+ 1024,
+ (kernel_size, 1),
+ 1,
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ ),
+ ]
+ )
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ x = torch.flatten(x, 1, -1)
+ fmap.append(x)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator_JETS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(MultiPeriodDiscriminator_JETS, self).__init__()
+ periods = [2, 3, 5, 7, 11]
+
+ discs = [
+ DiscriminatorP_JETS(i, use_spectral_norm=use_spectral_norm) for i in periods
+ ]
+ self.discriminators = nn.ModuleList(discs)
+
+ def forward(self, y):
+ y_d_rs = []
+ fmap_rs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+
+ return y_d_rs, fmap_rs
+
+
+# JETS Multi-scale Multi-period discriminator module.
+class MultiScaleMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale + multi-period discriminator module."""
+
+ def __init__(self, use_spectral_norm=False):
+ super(MultiScaleMultiPeriodDiscriminator, self).__init__()
+
+ self.msd = MultiScaleDiscriminator_JETS()
+ self.mpd = MultiPeriodDiscriminator_JETS()
+
+ def forward(self, y):
+
+ _, msd_outs_d_rs = self.msd(y)
+ # msd_outs = self.msd(y, y_hat)
+ _, mpd_outs_d_rs = self.mpd(y)
+ # mpd_outs = self.mpd(y, y_hat)
+ return msd_outs_d_rs + mpd_outs_d_rs
+ # ground_truth, generated
+ # return msd_outs + mpd_outs
diff --git a/models/vocoders/gan/discriminator/mrd.py b/models/vocoders/gan/discriminator/mrd.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ee80bfbf82b6aa63c80dbc2c6ffed8cb50a924
--- /dev/null
+++ b/models/vocoders/gan/discriminator/mrd.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from torch import nn
+
+LRELU_SLOPE = 0.1
+
+
+# This code is a refined MRD adopted from BigVGAN under the MIT License
+# https://github.com/NVIDIA/BigVGAN
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(self, cfg, resolution):
+ super().__init__()
+
+ self.resolution = resolution
+ assert (
+ len(self.resolution) == 3
+ ), "MRD layer requires list with len=3, got {}".format(self.resolution)
+ self.lrelu_slope = LRELU_SLOPE
+
+ norm_f = (
+ weight_norm if cfg.model.mrd.use_spectral_norm == False else spectral_norm
+ )
+ if cfg.model.mrd.mrd_override:
+ print(
+ "INFO: overriding MRD use_spectral_norm as {}".format(
+ cfg.model.mrd.mrd_use_spectral_norm
+ )
+ )
+ norm_f = (
+ weight_norm
+ if cfg.model.mrd.mrd_use_spectral_norm == False
+ else spectral_norm
+ )
+ self.d_mult = cfg.model.mrd.discriminator_channel_mult_factor
+ if cfg.model.mrd.mrd_override:
+ print(
+ "INFO: overriding mrd channel multiplier as {}".format(
+ cfg.model.mrd.mrd_channel_mult
+ )
+ )
+ self.d_mult = cfg.model.mrd.mrd_channel_mult
+
+ self.convs = nn.ModuleList(
+ [
+ norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
+ norm_f(
+ nn.Conv2d(
+ int(32 * self.d_mult),
+ int(32 * self.d_mult),
+ (3, 9),
+ stride=(1, 2),
+ padding=(1, 4),
+ )
+ ),
+ norm_f(
+ nn.Conv2d(
+ int(32 * self.d_mult),
+ int(32 * self.d_mult),
+ (3, 9),
+ stride=(1, 2),
+ padding=(1, 4),
+ )
+ ),
+ norm_f(
+ nn.Conv2d(
+ int(32 * self.d_mult),
+ int(32 * self.d_mult),
+ (3, 9),
+ stride=(1, 2),
+ padding=(1, 4),
+ )
+ ),
+ norm_f(
+ nn.Conv2d(
+ int(32 * self.d_mult),
+ int(32 * self.d_mult),
+ (3, 3),
+ padding=(1, 1),
+ )
+ ),
+ ]
+ )
+ self.conv_post = norm_f(
+ nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
+ )
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.spectrogram(x)
+ x = x.unsqueeze(1)
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+ def spectrogram(self, x):
+ n_fft, hop_length, win_length = self.resolution
+ x = F.pad(
+ x,
+ (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+ mode="reflect",
+ )
+ x = x.squeeze(1)
+ x = torch.stft(
+ x,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ center=False,
+ return_complex=True,
+ )
+ x = torch.view_as_real(x) # [B, F, TT, 2]
+ mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
+
+ return mag
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(self, cfg, debug=False):
+ super().__init__()
+ self.resolutions = cfg.model.mrd.resolutions
+ assert (
+ len(self.resolutions) == 3
+ ), "MRD requires list of list with len=3, each element having a list with len=3. got {}".format(
+ self.resolutions
+ )
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(x=y)
+ y_d_g, fmap_g = d(x=y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
diff --git a/models/vocoders/gan/discriminator/msd.py b/models/vocoders/gan/discriminator/msd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c13c27d7235aff31d02839cab07d80f563605c3
--- /dev/null
+++ b/models/vocoders/gan/discriminator/msd.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, AvgPool1d
+from torch.nn.utils import weight_norm, spectral_norm
+from torch import nn
+from modules.vocoder_blocks import *
+
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorS(nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+ self.convs = nn.ModuleList(
+ [
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ]
+ )
+
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiScaleDiscriminator(nn.Module):
+ def __init__(self, cfg):
+ super(MultiScaleDiscriminator, self).__init__()
+
+ self.cfg = cfg
+
+ self.discriminators = nn.ModuleList(
+ [
+ DiscriminatorS(use_spectral_norm=True),
+ DiscriminatorS(),
+ DiscriminatorS(),
+ ]
+ )
+
+ self.meanpools = nn.ModuleList(
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for i, d in enumerate(self.discriminators):
+ if i != 0:
+ y = self.meanpools[i - 1](y)
+ y_hat = self.meanpools[i - 1](y_hat)
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiScaleDiscriminator_JETS(nn.Module):
+ def __init__(self):
+ super(MultiScaleDiscriminator_JETS, self).__init__()
+
+ self.discriminators = nn.ModuleList(
+ [
+ DiscriminatorS(use_spectral_norm=True),
+ DiscriminatorS(),
+ DiscriminatorS(),
+ ]
+ )
+
+ self.meanpools = nn.ModuleList(
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+ )
+
+ def forward(self, y):
+ y_d_rs = [] # p, y, groud-truth
+ fmap_rs = []
+
+ for i, d in enumerate(self.discriminators):
+ if i != 0:
+ y = self.meanpools[i - 1](y)
+ y_d_r, fmap_r = d(y)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+
+ return y_d_rs, fmap_rs
+ # fmap_rs is real, fmap_gs is generated.
diff --git a/models/vocoders/gan/discriminator/mssbcqtd.py b/models/vocoders/gan/discriminator/mssbcqtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..213de5441754944a360707e99a3734ad035d9077
--- /dev/null
+++ b/models/vocoders/gan/discriminator/mssbcqtd.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch import nn
+from modules.vocoder_blocks import *
+
+from einops import rearrange
+import torchaudio.transforms as T
+
+from nnAudio import features
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorCQT(nn.Module):
+ def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
+ super(DiscriminatorCQT, self).__init__()
+ self.cfg = cfg
+
+ self.filters = cfg.model.mssbcqtd.filters
+ self.max_filters = cfg.model.mssbcqtd.max_filters
+ self.filters_scale = cfg.model.mssbcqtd.filters_scale
+ self.kernel_size = (3, 9)
+ self.dilations = cfg.model.mssbcqtd.dilations
+ self.stride = (1, 2)
+
+ self.in_channels = cfg.model.mssbcqtd.in_channels
+ self.out_channels = cfg.model.mssbcqtd.out_channels
+ self.fs = cfg.preprocess.sample_rate
+ self.hop_length = hop_length
+ self.n_octaves = n_octaves
+ self.bins_per_octave = bins_per_octave
+
+ self.cqt_transform = features.cqt.CQT2010v2(
+ sr=self.fs * 2,
+ hop_length=self.hop_length,
+ n_bins=self.bins_per_octave * self.n_octaves,
+ bins_per_octave=self.bins_per_octave,
+ output_format="Complex",
+ pad_mode="constant",
+ )
+
+ self.conv_pres = nn.ModuleList()
+ for i in range(self.n_octaves):
+ self.conv_pres.append(
+ NormConv2d(
+ self.in_channels * 2,
+ self.in_channels * 2,
+ kernel_size=self.kernel_size,
+ padding=get_2d_padding(self.kernel_size),
+ )
+ )
+
+ self.convs = nn.ModuleList()
+
+ self.convs.append(
+ NormConv2d(
+ self.in_channels * 2,
+ self.filters,
+ kernel_size=self.kernel_size,
+ padding=get_2d_padding(self.kernel_size),
+ )
+ )
+
+ in_chs = min(self.filters_scale * self.filters, self.max_filters)
+ for i, dilation in enumerate(self.dilations):
+ out_chs = min(
+ (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+ )
+ self.convs.append(
+ NormConv2d(
+ in_chs,
+ out_chs,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ dilation=(dilation, 1),
+ padding=get_2d_padding(self.kernel_size, (dilation, 1)),
+ norm="weight_norm",
+ )
+ )
+ in_chs = out_chs
+ out_chs = min(
+ (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+ self.max_filters,
+ )
+ self.convs.append(
+ NormConv2d(
+ in_chs,
+ out_chs,
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+ padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+ norm="weight_norm",
+ )
+ )
+
+ self.conv_post = NormConv2d(
+ out_chs,
+ self.out_channels,
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+ padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+ norm="weight_norm",
+ )
+
+ self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
+ self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.resample(x)
+
+ z = self.cqt_transform(x)
+
+ z_amplitude = z[:, :, :, 0].unsqueeze(1)
+ z_phase = z[:, :, :, 1].unsqueeze(1)
+
+ z = torch.cat([z_amplitude, z_phase], dim=1)
+ z = rearrange(z, "b c w t -> b c t w")
+
+ latent_z = []
+ for i in range(self.n_octaves):
+ latent_z.append(
+ self.conv_pres[i](
+ z[
+ :,
+ :,
+ :,
+ i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+ ]
+ )
+ )
+ latent_z = torch.cat(latent_z, dim=-1)
+
+ for i, l in enumerate(self.convs):
+ latent_z = l(latent_z)
+
+ latent_z = self.activation(latent_z)
+ fmap.append(latent_z)
+
+ latent_z = self.conv_post(latent_z)
+
+ return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(nn.Module):
+ def __init__(self, cfg):
+ super(MultiScaleSubbandCQTDiscriminator, self).__init__()
+
+ self.cfg = cfg
+
+ self.discriminators = nn.ModuleList(
+ [
+ DiscriminatorCQT(
+ cfg,
+ hop_length=cfg.model.mssbcqtd.hop_lengths[i],
+ n_octaves=cfg.model.mssbcqtd.n_octaves[i],
+ bins_per_octave=cfg.model.mssbcqtd.bins_per_octaves[i],
+ )
+ for i in range(len(cfg.model.mssbcqtd.hop_lengths))
+ ]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for disc in self.discriminators:
+ y_d_r, fmap_r = disc(y)
+ y_d_g, fmap_g = disc(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
diff --git a/models/vocoders/gan/discriminator/msstftd.py b/models/vocoders/gan/discriminator/msstftd.py
new file mode 100644
index 0000000000000000000000000000000000000000..83dedb78848d2d73ac667e7a191f05de1ed7bf21
--- /dev/null
+++ b/models/vocoders/gan/discriminator/msstftd.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is adopted from META's Encodec under MIT License
+# https://github.com/facebookresearch/encodec
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from modules.vocoder_blocks import *
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+def get_2d_padding(
+ kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
+):
+ return (
+ ((kernel_size[0] - 1) * dilation[0]) // 2,
+ ((kernel_size[1] - 1) * dilation[1]) // 2,
+ )
+
+
+class DiscriminatorSTFT(nn.Module):
+ """STFT sub-discriminator.
+ Args:
+ filters (int): Number of filters in convolutions
+ in_channels (int): Number of input channels. Default: 1
+ out_channels (int): Number of output channels. Default: 1
+ n_fft (int): Size of FFT for each scale. Default: 1024
+ hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+ kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+ stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+ dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+ win_length (int): Window size for each scale. Default: 1024
+ normalized (bool): Whether to normalize by magnitude after stft. Default: True
+ norm (str): Normalization method. Default: `'weight_norm'`
+ activation (str): Activation function. Default: `'LeakyReLU'`
+ activation_params (dict): Parameters to provide to the activation function.
+ growth (int): Growth factor for the filters. Default: 1
+ """
+
+ def __init__(
+ self,
+ filters: int,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ n_fft: int = 1024,
+ hop_length: int = 256,
+ win_length: int = 1024,
+ max_filters: int = 1024,
+ filters_scale: int = 1,
+ kernel_size: tp.Tuple[int, int] = (3, 9),
+ dilations: tp.List = [1, 2, 4],
+ stride: tp.Tuple[int, int] = (1, 2),
+ normalized: bool = True,
+ norm: str = "weight_norm",
+ activation: str = "LeakyReLU",
+ activation_params: dict = {"negative_slope": 0.2},
+ ):
+ super().__init__()
+ assert len(kernel_size) == 2
+ assert len(stride) == 2
+ self.filters = filters
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.normalized = normalized
+ self.activation = getattr(torch.nn, activation)(**activation_params)
+ self.spec_transform = torchaudio.transforms.Spectrogram(
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window_fn=torch.hann_window,
+ normalized=self.normalized,
+ center=False,
+ pad_mode=None,
+ power=None,
+ )
+ spec_channels = 2 * self.in_channels
+ self.convs = nn.ModuleList()
+ self.convs.append(
+ NormConv2d(
+ spec_channels,
+ self.filters,
+ kernel_size=kernel_size,
+ padding=get_2d_padding(kernel_size),
+ )
+ )
+ in_chs = min(filters_scale * self.filters, max_filters)
+ for i, dilation in enumerate(dilations):
+ out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+ self.convs.append(
+ NormConv2d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=(dilation, 1),
+ padding=get_2d_padding(kernel_size, (dilation, 1)),
+ norm=norm,
+ )
+ )
+ in_chs = out_chs
+ out_chs = min(
+ (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
+ )
+ self.convs.append(
+ NormConv2d(
+ in_chs,
+ out_chs,
+ kernel_size=(kernel_size[0], kernel_size[0]),
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+ norm=norm,
+ )
+ )
+ self.conv_post = NormConv2d(
+ out_chs,
+ self.out_channels,
+ kernel_size=(kernel_size[0], kernel_size[0]),
+ padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+ norm=norm,
+ )
+
+ def forward(self, x: torch.Tensor):
+ """Discriminator STFT Module is the sub module of MultiScaleSTFTDiscriminator.
+
+ Args:
+ x (torch.Tensor): input tensor of shape [B, 1, Time]
+
+ Returns:
+ z: z is the output of the last convolutional layer of shape
+ fmap: fmap is the list of feature maps of every convolutional layer of shape
+ """
+ fmap = []
+ z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
+ z = torch.cat([z.real, z.imag], dim=1)
+ z = rearrange(z, "b c w t -> b c t w")
+ for i, layer in enumerate(self.convs):
+ z = layer(z)
+
+ z = self.activation(z)
+ fmap.append(z)
+ z = self.conv_post(z)
+ return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+ """Multi-Scale STFT (MS-STFT) discriminator.
+ Args:
+ filters (int): Number of filters in convolutions
+ in_channels (int): Number of input channels. Default: 1
+ out_channels (int): Number of output channels. Default: 1
+ n_ffts (Sequence[int]): Size of FFT for each scale
+ hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
+ win_lengths (Sequence[int]): Window size for each scale
+ **kwargs: additional args for STFTDiscriminator
+ """
+
+ def __init__(
+ self,
+ cfg,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ n_ffts: tp.List[int] = [1024, 2048, 512],
+ hop_lengths: tp.List[int] = [256, 512, 256],
+ win_lengths: tp.List[int] = [1024, 2048, 512],
+ **kwargs,
+ ):
+ self.cfg = cfg
+ super().__init__()
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+ self.discriminators = nn.ModuleList(
+ [
+ DiscriminatorSTFT(
+ filters=self.cfg.model.msstftd.filters,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ n_fft=n_ffts[i],
+ win_length=win_lengths[i],
+ hop_length=hop_lengths[i],
+ **kwargs,
+ )
+ for i in range(len(n_ffts))
+ ]
+ )
+ self.num_discriminators = len(self.discriminators)
+
+ def forward(self, y, y_hat) -> DiscriminatorOutput:
+ """Multi-Scale STFT (MS-STFT) discriminator.
+
+ Args:
+ x (torch.Tensor): input waveform
+
+ Returns:
+ logits: list of every discriminator's output
+ fmaps: list of every discriminator's feature maps,
+ each feature maps is a list of Discriminator STFT's every layer
+ """
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for disc in self.discriminators:
+ y_d_r, fmap_r = disc(y)
+ y_d_g, fmap_g = disc(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
diff --git a/models/vocoders/gan/gan_vocoder_dataset.py b/models/vocoders/gan/gan_vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf87c371647a44fb5bcae33701eda65616e5fd7
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_dataset.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import random
+
+import numpy as np
+
+from torch.nn import functional as F
+
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.vocoders.vocoder_dataset import VocoderDataset
+
+
+class GANVocoderDataset(VocoderDataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+ super().__init__(cfg, dataset, is_valid)
+
+ eval_index = random.randint(0, len(self.metadata) - 1)
+ eval_utt_info = self.metadata[eval_index]
+ eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"])
+ self.eval_audio = np.load(self.utt2audio_path[eval_utt])
+ if cfg.preprocess.use_mel:
+ self.eval_mel = np.load(self.utt2mel_path[eval_utt])
+ if cfg.preprocess.use_frame_pitch:
+ self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt])
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+
+ if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+ mel = np.pad(
+ mel,
+ ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+ mode="constant",
+ )
+ else:
+ if "start" not in single_feature.keys():
+ start = random.randint(
+ 0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame
+ )
+ end = start + self.cfg.preprocess.cut_mel_frame
+ single_feature["start"] = start
+ single_feature["end"] = end
+ mel = mel[:, single_feature["start"] : single_feature["end"]]
+ single_feature["mel"] = mel
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+
+ if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+ aligned_frame_pitch = np.pad(
+ aligned_frame_pitch,
+ (
+ (
+ 0,
+ self.cfg.preprocess.cut_mel_frame
+ * self.cfg.preprocess.hop_size
+ - audio.shape[-1],
+ )
+ ),
+ mode="constant",
+ )
+ else:
+ if "start" not in single_feature.keys():
+ start = random.randint(
+ 0,
+ aligned_frame_pitch.shape[-1]
+ - self.cfg.preprocess.cut_mel_frame,
+ )
+ end = start + self.cfg.preprocess.cut_mel_frame
+ single_feature["start"] = start
+ single_feature["end"] = end
+ aligned_frame_pitch = aligned_frame_pitch[
+ single_feature["start"] : single_feature["end"]
+ ]
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_audio:
+ audio = np.load(self.utt2audio_path[utt])
+
+ assert "target_len" in single_feature.keys()
+
+ if (
+ audio.shape[-1]
+ <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size
+ ):
+ audio = np.pad(
+ audio,
+ (
+ (
+ 0,
+ self.cfg.preprocess.cut_mel_frame
+ * self.cfg.preprocess.hop_size
+ - audio.shape[-1],
+ )
+ ),
+ mode="constant",
+ )
+ else:
+ if "start" not in single_feature.keys():
+ audio = audio[
+ 0 : self.cfg.preprocess.cut_mel_frame
+ * self.cfg.preprocess.hop_size
+ ]
+ else:
+ audio = audio[
+ single_feature["start"]
+ * self.cfg.preprocess.hop_size : single_feature["end"]
+ * self.cfg.preprocess.hop_size,
+ ]
+ single_feature["audio"] = audio
+
+ if self.cfg.preprocess.use_amplitude_phase:
+ logamp = np.load(self.utt2logamp_path[utt])
+ pha = np.load(self.utt2pha_path[utt])
+ rea = np.load(self.utt2rea_path[utt])
+ imag = np.load(self.utt2imag_path[utt])
+
+ assert "target_len" in single_feature.keys()
+
+ if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+ logamp = np.pad(
+ logamp,
+ ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+ mode="constant",
+ )
+ pha = np.pad(
+ pha,
+ ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+ mode="constant",
+ )
+ rea = np.pad(
+ rea,
+ ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+ mode="constant",
+ )
+ imag = np.pad(
+ imag,
+ ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+ mode="constant",
+ )
+ else:
+ logamp = logamp[:, single_feature["start"] : single_feature["end"]]
+ pha = pha[:, single_feature["start"] : single_feature["end"]]
+ rea = rea[:, single_feature["start"] : single_feature["end"]]
+ imag = imag[:, single_feature["start"] : single_feature["end"]]
+ single_feature["logamp"] = logamp
+ single_feature["pha"] = pha
+ single_feature["rea"] = rea
+ single_feature["imag"] = imag
+
+ return single_feature
+
+
+class GANVocoderCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, n_mels, frame]
+ # frame_pitch: [b, frame]
+ # audios: [b, frame * hop_size]
+
+ for key in batch[0].keys():
+ if key in ["target_len", "start", "end"]:
+ continue
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/vocoders/gan/gan_vocoder_inference.py b/models/vocoders/gan/gan_vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..4354d5b5810614da07db435ff76f34241de50c62
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_inference.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
+
+
+def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
+ """Inference the vocoder
+ Args:
+ mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
+ Returns:
+ audios: A tensor of audios with the shape (batch_size, seq_len)
+ """
+ model.eval()
+
+ with torch.no_grad():
+ mels = mels.to(device)
+ if f0s != None:
+ f0s = f0s.to(device)
+
+ if f0s == None and not cfg.preprocess.extract_amplitude_phase:
+ output = model.forward(mels)
+ elif cfg.preprocess.extract_amplitude_phase:
+ (
+ _,
+ _,
+ _,
+ _,
+ output,
+ ) = model.forward(mels)
+ else:
+ output = model.forward(mels, f0s)
+
+ return output.squeeze(1).detach().cpu()
+
+
+def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
+ """Inference the vocoder
+ Args:
+ mels: A list of mel-specs
+ Returns:
+ audios: A list of audios
+ """
+ # Get the device
+ device = next(model.parameters()).device
+
+ audios = []
+
+ # Pad the given list into tensors
+ mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
+ if f0s != None:
+ f0_batches = pad_f0_to_tensors(f0s, batch_size)
+
+ if f0s == None:
+ for mel_batch, mel_frame in zip(mel_batches, mel_frames):
+ for i in range(mel_batch.shape[0]):
+ mel = mel_batch[i]
+ frame = mel_frame[i]
+ audio = vocoder_inference(
+ cfg,
+ model,
+ mel.unsqueeze(0),
+ device=device,
+ fast_inference=fast_inference,
+ ).squeeze(0)
+
+ # calculate the audio length
+ audio_length = frame * model.cfg.preprocess.hop_size
+ audio = audio[:audio_length]
+
+ audios.append(audio)
+ else:
+ for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
+ for i in range(mel_batch.shape[0]):
+ mel = mel_batch[i]
+ f0 = f0_batch[i]
+ frame = mel_frame[i]
+ audio = vocoder_inference(
+ cfg,
+ model,
+ mel.unsqueeze(0),
+ f0s=f0.unsqueeze(0),
+ device=device,
+ fast_inference=fast_inference,
+ ).squeeze(0)
+
+ # calculate the audio length
+ audio_length = frame * model.cfg.preprocess.hop_size
+ audio = audio[:audio_length]
+
+ audios.append(audio)
+ return audios
diff --git a/models/vocoders/gan/gan_vocoder_trainer.py b/models/vocoders/gan/gan_vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7caa4ec63fe2bd11621b41b2c904dc1c62375922
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_trainer.py
@@ -0,0 +1,1109 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+import time
+import torch
+import json
+import itertools
+import accelerate
+import torch.distributed as dist
+import torch.nn.functional as F
+from tqdm import tqdm
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import ExponentialLR
+
+from librosa.filters import mel as librosa_mel_fn
+
+from accelerate.logging import get_logger
+from pathlib import Path
+
+from utils.io import save_audio
+from utils.data_utils import *
+from utils.util import (
+ Logger,
+ ValueWindow,
+ remove_older_ckpt,
+ set_all_random_seed,
+ save_config,
+)
+from utils.mel import extract_mel_features
+from models.vocoders.vocoder_trainer import VocoderTrainer
+from models.vocoders.gan.gan_vocoder_dataset import (
+ GANVocoderDataset,
+ GANVocoderCollator,
+)
+
+from models.vocoders.gan.generator.bigvgan import BigVGAN
+from models.vocoders.gan.generator.hifigan import HiFiGAN
+from models.vocoders.gan.generator.melgan import MelGAN
+from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN
+from models.vocoders.gan.generator.apnet import APNet
+
+from models.vocoders.gan.discriminator.mpd import MultiPeriodDiscriminator
+from models.vocoders.gan.discriminator.mrd import MultiResolutionDiscriminator
+from models.vocoders.gan.discriminator.mssbcqtd import MultiScaleSubbandCQTDiscriminator
+from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator
+from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator
+
+from models.vocoders.gan.gan_vocoder_inference import vocoder_inference
+
+supported_generators = {
+ "bigvgan": BigVGAN,
+ "hifigan": HiFiGAN,
+ "melgan": MelGAN,
+ "nsfhifigan": NSFHiFiGAN,
+ "apnet": APNet,
+}
+
+supported_discriminators = {
+ "mpd": MultiPeriodDiscriminator,
+ "msd": MultiScaleDiscriminator,
+ "mrd": MultiResolutionDiscriminator,
+ "msstftd": MultiScaleSTFTDiscriminator,
+ "mssbcqtd": MultiScaleSubbandCQTDiscriminator,
+}
+
+
+class GANVocoderTrainer(VocoderTrainer):
+ def __init__(self, args, cfg):
+ super().__init__()
+
+ self.args = args
+ self.cfg = cfg
+
+ cfg.exp_name = args.exp_name
+
+ # Init accelerator
+ self._init_accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Init logger
+ with self.accelerator.main_process_first():
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
+
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+ self.logger.info(f"Experiment name: {args.exp_name}")
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Init training status
+ self.batch_count: int = 0
+ self.step: int = 0
+ self.epoch: int = 0
+
+ self.max_epoch = (
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+ )
+ self.logger.info(
+ "Max epoch: {}".format(
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+ )
+ )
+
+ # Check potential erorrs
+ if self.accelerator.is_main_process:
+ self._check_basic_configs()
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+ self.checkpoints_path = [
+ [] for _ in range(len(self.save_checkpoint_stride))
+ ]
+ self.run_eval = self.cfg.train.run_eval
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # Build dataloader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # Build model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.generator, self.discriminators = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.debug(self.generator)
+ for _, discriminator in self.discriminators.items():
+ self.logger.debug(discriminator)
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+ self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
+
+ # Build optimizers and schedulers
+ with self.accelerator.main_process_first():
+ self.logger.info("Building optimizer and scheduler...")
+ start = time.monotonic_ns()
+ (
+ self.generator_optimizer,
+ self.discriminator_optimizer,
+ ) = self._build_optimizer()
+ (
+ self.generator_scheduler,
+ self.discriminator_scheduler,
+ ) = self._build_scheduler()
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+ )
+
+ # Accelerator preparing
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ (
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.generator,
+ self.generator_optimizer,
+ self.discriminator_optimizer,
+ self.generator_scheduler,
+ self.discriminator_scheduler,
+ ) = self.accelerator.prepare(
+ self.train_dataloader,
+ self.valid_dataloader,
+ self.generator,
+ self.generator_optimizer,
+ self.discriminator_optimizer,
+ self.generator_scheduler,
+ self.discriminator_scheduler,
+ )
+ for key, discriminator in self.discriminators.items():
+ self.discriminators[key] = self.accelerator.prepare_model(discriminator)
+ end = time.monotonic_ns()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+ # Build criterions
+ with self.accelerator.main_process_first():
+ self.logger.info("Building criterion...")
+ start = time.monotonic_ns()
+ self.criterions = self._build_criterion()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+ # Resume checkpoints
+ with self.accelerator.main_process_first():
+ if args.resume_type:
+ self.logger.info("Resuming from checkpoint...")
+ start = time.monotonic_ns()
+ ckpt_path = Path(args.checkpoint)
+ if self._is_valid_pattern(ckpt_path.parts[-1]):
+ ckpt_path = self._load_model(
+ None, args.checkpoint, args.resume_type
+ )
+ else:
+ ckpt_path = self._load_model(
+ args.checkpoint, resume_type=args.resume_type
+ )
+ end = time.monotonic_ns()
+ self.logger.info(
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.checkpoints_path = json.load(
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
+ )
+
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+ if self.accelerator.is_main_process:
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+ # Save config
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+ def _build_dataset(self):
+ return GANVocoderDataset, GANVocoderCollator
+
+ def _build_criterion(self):
+ class feature_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(feature_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, fmap_r, fmap_g):
+ loss = 0
+
+ if self.cfg.model.generator in [
+ "hifigan",
+ "nsfhifigan",
+ "bigvgan",
+ "apnet",
+ ]:
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ loss = loss * 2
+ elif self.cfg.model.generator in ["melgan"]:
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += self.l1Loss(rl, gl)
+
+ loss = loss * 10
+ elif self.cfg.model.generator in ["codec"]:
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss = loss + self.l1Loss(rl, gl) / torch.mean(
+ torch.abs(rl)
+ )
+
+ KL_scale = len(fmap_r) * len(fmap_r[0])
+
+ loss = 3 * loss / KL_scale
+ else:
+ raise NotImplementedError
+
+ return loss
+
+ class discriminator_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(discriminator_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+
+ if self.cfg.model.generator in [
+ "hifigan",
+ "nsfhifigan",
+ "bigvgan",
+ "apnet",
+ ]:
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg**2)
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+ elif self.cfg.model.generator in ["melgan"]:
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean(self.relu(1 - dr))
+ g_loss = torch.mean(self.relu(1 + dg))
+ loss = loss + r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+ elif self.cfg.model.generator in ["codec"]:
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean(self.relu(1 - dr))
+ g_loss = torch.mean(self.relu(1 + dg))
+ loss = loss + r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ loss = loss / len(disc_real_outputs)
+ else:
+ raise NotImplementedError
+
+ return loss, r_losses, g_losses
+
+ class generator_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(generator_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, disc_outputs):
+ loss = 0
+ gen_losses = []
+
+ if self.cfg.model.generator in [
+ "hifigan",
+ "nsfhifigan",
+ "bigvgan",
+ "apnet",
+ ]:
+ for dg in disc_outputs:
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+ elif self.cfg.model.generator in ["melgan"]:
+ for dg in disc_outputs:
+ l = -torch.mean(dg)
+ gen_losses.append(l)
+ loss += l
+ elif self.cfg.model.generator in ["codec"]:
+ for dg in disc_outputs:
+ l = torch.mean(self.relu(1 - dg)) / len(disc_outputs)
+ gen_losses.append(l)
+ loss += l
+ else:
+ raise NotImplementedError
+
+ return loss, gen_losses
+
+ class mel_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(mel_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, y_gt, y_pred):
+ loss = 0
+
+ if self.cfg.model.generator in [
+ "hifigan",
+ "nsfhifigan",
+ "bigvgan",
+ "melgan",
+ "codec",
+ "apnet",
+ ]:
+ y_gt_mel = extract_mel_features(y_gt, self.cfg.preprocess)
+ y_pred_mel = extract_mel_features(
+ y_pred.squeeze(1), self.cfg.preprocess
+ )
+
+ loss = self.l1Loss(y_gt_mel, y_pred_mel) * 45
+ else:
+ raise NotImplementedError
+
+ return loss
+
+ class wav_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(wav_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, y_gt, y_pred):
+ loss = 0
+
+ if self.cfg.model.generator in [
+ "hifigan",
+ "nsfhifigan",
+ "bigvgan",
+ "apnet",
+ ]:
+ loss = self.l2Loss(y_gt, y_pred.squeeze(1)) * 100
+ elif self.cfg.model.generator in ["melgan"]:
+ loss = self.l1Loss(y_gt, y_pred.squeeze(1)) / 10
+ elif self.cfg.model.generator in ["codec"]:
+ loss = self.l1Loss(y_gt, y_pred.squeeze(1)) + self.l2Loss(
+ y_gt, y_pred.squeeze(1)
+ )
+ loss /= 10
+ else:
+ raise NotImplementedError
+
+ return loss
+
+ class phase_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(phase_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, phase_gt, phase_pred):
+ n_fft = self.cfg.preprocess.n_fft
+ frames = phase_gt.size()[-1]
+
+ GD_matrix = (
+ torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1)
+ - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2)
+ - torch.eye(n_fft // 2 + 1)
+ )
+ GD_matrix = GD_matrix.to(phase_pred.device)
+
+ GD_r = torch.matmul(phase_gt.permute(0, 2, 1), GD_matrix)
+ GD_g = torch.matmul(phase_pred.permute(0, 2, 1), GD_matrix)
+
+ PTD_matrix = (
+ torch.triu(torch.ones(frames, frames), diagonal=1)
+ - torch.triu(torch.ones(frames, frames), diagonal=2)
+ - torch.eye(frames)
+ )
+ PTD_matrix = PTD_matrix.to(phase_pred.device)
+
+ PTD_r = torch.matmul(phase_gt, PTD_matrix)
+ PTD_g = torch.matmul(phase_pred, PTD_matrix)
+
+ IP_loss = torch.mean(-torch.cos(phase_gt - phase_pred))
+ GD_loss = torch.mean(-torch.cos(GD_r - GD_g))
+ PTD_loss = torch.mean(-torch.cos(PTD_r - PTD_g))
+
+ return 100 * (IP_loss + GD_loss + PTD_loss)
+
+ class amplitude_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(amplitude_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(self, log_amplitude_gt, log_amplitude_pred):
+ amplitude_loss = self.l2Loss(log_amplitude_gt, log_amplitude_pred)
+
+ return 45 * amplitude_loss
+
+ class consistency_criterion(torch.nn.Module):
+ def __init__(self, cfg):
+ super(consistency_criterion, self).__init__()
+ self.cfg = cfg
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
+ self.l2Loss = torch.nn.MSELoss(reduction="mean")
+ self.relu = torch.nn.ReLU()
+
+ def __call__(
+ self,
+ rea_gt,
+ rea_pred,
+ rea_pred_final,
+ imag_gt,
+ imag_pred,
+ imag_pred_final,
+ ):
+ C_loss = torch.mean(
+ torch.mean(
+ (rea_pred - rea_pred_final) ** 2
+ + (imag_pred - imag_pred_final) ** 2,
+ (1, 2),
+ )
+ )
+
+ L_R = self.l1Loss(rea_gt, rea_pred)
+ L_I = self.l1Loss(imag_gt, imag_pred)
+
+ return 20 * (C_loss + 2.25 * (L_R + L_I))
+
+ criterions = dict()
+ for key in self.cfg.train.criterions:
+ if key == "feature":
+ criterions["feature"] = feature_criterion(self.cfg)
+ elif key == "discriminator":
+ criterions["discriminator"] = discriminator_criterion(self.cfg)
+ elif key == "generator":
+ criterions["generator"] = generator_criterion(self.cfg)
+ elif key == "mel":
+ criterions["mel"] = mel_criterion(self.cfg)
+ elif key == "wav":
+ criterions["wav"] = wav_criterion(self.cfg)
+ elif key == "phase":
+ criterions["phase"] = phase_criterion(self.cfg)
+ elif key == "amplitude":
+ criterions["amplitude"] = amplitude_criterion(self.cfg)
+ elif key == "consistency":
+ criterions["consistency"] = consistency_criterion(self.cfg)
+ else:
+ raise NotImplementedError
+
+ return criterions
+
+ def _build_model(self):
+ generator = supported_generators[self.cfg.model.generator](self.cfg)
+ discriminators = dict()
+ for key in self.cfg.model.discriminators:
+ discriminators[key] = supported_discriminators[key](self.cfg)
+
+ return generator, discriminators
+
+ def _build_optimizer(self):
+ optimizer_params_generator = [dict(params=self.generator.parameters())]
+ generator_optimizer = AdamW(
+ optimizer_params_generator,
+ lr=self.cfg.train.adamw.lr,
+ betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
+ )
+
+ optimizer_params_discriminator = []
+ for discriminator in self.discriminators.keys():
+ optimizer_params_discriminator.append(
+ dict(params=self.discriminators[discriminator].parameters())
+ )
+ discriminator_optimizer = AdamW(
+ optimizer_params_discriminator,
+ lr=self.cfg.train.adamw.lr,
+ betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
+ )
+
+ return generator_optimizer, discriminator_optimizer
+
+ def _build_scheduler(self):
+ discriminator_scheduler = ExponentialLR(
+ self.discriminator_optimizer,
+ gamma=self.cfg.train.exponential_lr.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+
+ generator_scheduler = ExponentialLR(
+ self.generator_optimizer,
+ gamma=self.cfg.train.exponential_lr.lr_decay,
+ last_epoch=self.epoch - 1,
+ )
+
+ return generator_scheduler, discriminator_scheduler
+
+ def train_loop(self):
+ """Training process"""
+ self.accelerator.wait_for_everyone()
+
+ # Dump config
+ if self.accelerator.is_main_process:
+ self._dump_cfg(self.config_save_path)
+ self.generator.train()
+ for key in self.discriminators.keys():
+ self.discriminators[key].train()
+ self.generator_optimizer.zero_grad()
+ self.discriminator_optimizer.zero_grad()
+
+ # Sync and start training
+ self.accelerator.wait_for_everyone()
+ while self.epoch < self.max_epoch:
+ self.logger.info("\n")
+ self.logger.info("-" * 32)
+ self.logger.info("Epoch {}: ".format(self.epoch))
+
+ # Train and Validate
+ train_total_loss, train_losses = self._train_epoch()
+ for key, loss in train_losses.items():
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Train {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+ valid_total_loss, valid_losses = self._valid_epoch()
+ for key, loss in valid_losses.items():
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
+ self.accelerator.log(
+ {"Epoch/Valid {} Loss".format(key): loss},
+ step=self.epoch,
+ )
+ self.accelerator.log(
+ {
+ "Epoch/Train Total Loss": train_total_loss,
+ "Epoch/Valid Total Loss": valid_total_loss,
+ },
+ step=self.epoch,
+ )
+
+ # Update scheduler
+ self.accelerator.wait_for_everyone()
+ self.generator_scheduler.step()
+ self.discriminator_scheduler.step()
+
+ # Check save checkpoint interval
+ run_eval = False
+ if self.accelerator.is_main_process:
+ save_checkpoint = False
+ for i, num in enumerate(self.save_checkpoint_stride):
+ if self.epoch % num == 0:
+ save_checkpoint = True
+ run_eval |= self.run_eval[i]
+
+ # Save checkpoints
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and save_checkpoint:
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ self.accelerator.save_state(path)
+ json.dump(
+ self.checkpoints_path,
+ open(os.path.join(path, "ckpts.json"), "w"),
+ ensure_ascii=False,
+ indent=4,
+ )
+
+ # Save eval audios
+ self.accelerator.wait_for_everyone()
+ if self.accelerator.is_main_process and run_eval:
+ for i in range(len(self.valid_dataloader.dataset.eval_audios)):
+ if self.cfg.preprocess.use_frame_pitch:
+ eval_audio = self._inference(
+ self.valid_dataloader.dataset.eval_mels[i],
+ eval_pitch=self.valid_dataloader.dataset.eval_pitchs[i],
+ use_pitch=True,
+ )
+ else:
+ eval_audio = self._inference(
+ self.valid_dataloader.dataset.eval_mels[i]
+ )
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}.wav".format(
+ self.epoch,
+ self.step,
+ valid_total_loss,
+ self.valid_dataloader.dataset.eval_dataset_names[i],
+ ),
+ )
+ path_gt = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}_gt.wav".format(
+ self.epoch,
+ self.step,
+ valid_total_loss,
+ self.valid_dataloader.dataset.eval_dataset_names[i],
+ ),
+ )
+ save_audio(path, eval_audio, self.cfg.preprocess.sample_rate)
+ save_audio(
+ path_gt,
+ self.valid_dataloader.dataset.eval_audios[i],
+ self.cfg.preprocess.sample_rate,
+ )
+
+ self.accelerator.wait_for_everyone()
+
+ self.epoch += 1
+
+ # Finish training
+ self.accelerator.wait_for_everyone()
+ path = os.path.join(
+ self.checkpoint_dir,
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+ self.epoch, self.step, valid_total_loss
+ ),
+ )
+ self.accelerator.save_state(path)
+
+ def _train_epoch(self):
+ """Training epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.generator.train()
+ for key, _ in self.discriminators.items():
+ self.discriminators[key].train()
+
+ epoch_losses: dict = {}
+ epoch_total_loss: int = 0
+
+ for batch in tqdm(
+ self.train_dataloader,
+ desc=f"Training Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Get losses
+ total_loss, losses = self._train_step(batch)
+ self.batch_count += 1
+
+ # Log info
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+ self.accelerator.log(
+ {
+ "Step/Generator Learning Rate": self.generator_optimizer.param_groups[
+ 0
+ ][
+ "lr"
+ ],
+ "Step/Discriminator Learning Rate": self.discriminator_optimizer.param_groups[
+ 0
+ ][
+ "lr"
+ ],
+ },
+ step=self.step,
+ )
+ for key, _ in losses.items():
+ self.accelerator.log(
+ {
+ "Step/Train {} Loss".format(key): losses[key],
+ },
+ step=self.step,
+ )
+
+ if not epoch_losses:
+ epoch_losses = losses
+ else:
+ for key, value in losses.items():
+ epoch_losses[key] += value
+ epoch_total_loss += total_loss
+ self.step += 1
+
+ # Get and log total losses
+ self.accelerator.wait_for_everyone()
+ epoch_total_loss = (
+ epoch_total_loss
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+ for key in epoch_losses.keys():
+ epoch_losses[key] = (
+ epoch_losses[key]
+ / len(self.train_dataloader)
+ * self.cfg.train.gradient_accumulation_step
+ )
+ return epoch_total_loss, epoch_losses
+
+ def _train_step(self, data):
+ """Training forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_train_epoch`` for usage.
+ """
+ # Init losses
+ train_losses = {}
+ total_loss = 0
+
+ generator_losses = {}
+ generator_total_loss = 0
+ discriminator_losses = {}
+ discriminator_total_loss = 0
+
+ # Use input feature to get predictions
+ mel_input = data["mel"]
+ audio_gt = data["audio"]
+
+ if self.cfg.preprocess.extract_amplitude_phase:
+ logamp_gt = data["logamp"]
+ pha_gt = data["pha"]
+ rea_gt = data["rea"]
+ imag_gt = data["imag"]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch_input = data["frame_pitch"]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch_input = pitch_input.float()
+ audio_pred = self.generator.forward(mel_input, pitch_input)
+ elif self.cfg.preprocess.extract_amplitude_phase:
+ (
+ logamp_pred,
+ pha_pred,
+ rea_pred,
+ imag_pred,
+ audio_pred,
+ ) = self.generator.forward(mel_input)
+ from utils.mel import amplitude_phase_spectrum
+
+ _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
+ audio_pred.squeeze(1), self.cfg.preprocess
+ )
+ else:
+ audio_pred = self.generator.forward(mel_input)
+
+ # Calculate and BP Discriminator losses
+ self.discriminator_optimizer.zero_grad()
+ for key, _ in self.discriminators.items():
+ y_r, y_g, _, _ = self.discriminators[key].forward(
+ audio_gt.unsqueeze(1), audio_pred.detach()
+ )
+ (
+ discriminator_losses["{}_discriminator".format(key)],
+ _,
+ _,
+ ) = self.criterions["discriminator"](y_r, y_g)
+ discriminator_total_loss += discriminator_losses[
+ "{}_discriminator".format(key)
+ ]
+
+ self.accelerator.backward(discriminator_total_loss)
+ self.discriminator_optimizer.step()
+
+ # Calculate and BP Generator losses
+ self.generator_optimizer.zero_grad()
+ for key, _ in self.discriminators.items():
+ y_r, y_g, f_r, f_g = self.discriminators[key].forward(
+ audio_gt.unsqueeze(1), audio_pred
+ )
+ generator_losses["{}_feature".format(key)] = self.criterions["feature"](
+ f_r, f_g
+ )
+ generator_losses["{}_generator".format(key)], _ = self.criterions[
+ "generator"
+ ](y_g)
+ generator_total_loss += generator_losses["{}_feature".format(key)]
+ generator_total_loss += generator_losses["{}_generator".format(key)]
+
+ if "mel" in self.criterions.keys():
+ generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+ generator_total_loss += generator_losses["mel"]
+
+ if "wav" in self.criterions.keys():
+ generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+ generator_total_loss += generator_losses["wav"]
+
+ if "amplitude" in self.criterions.keys():
+ generator_losses["amplitude"] = self.criterions["amplitude"](
+ logamp_gt, logamp_pred
+ )
+ generator_total_loss += generator_losses["amplitude"]
+
+ if "phase" in self.criterions.keys():
+ generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
+ generator_total_loss += generator_losses["phase"]
+
+ if "consistency" in self.criterions.keys():
+ generator_losses["consistency"] = self.criterions["consistency"](
+ rea_gt, rea_pred, rea_pred_final, imag_gt, imag_pred, imag_pred_final
+ )
+ generator_total_loss += generator_losses["consistency"]
+
+ self.accelerator.backward(generator_total_loss)
+ self.generator_optimizer.step()
+
+ # Get the total losses
+ total_loss = discriminator_total_loss + generator_total_loss
+ train_losses.update(discriminator_losses)
+ train_losses.update(generator_losses)
+
+ for key, _ in train_losses.items():
+ train_losses[key] = train_losses[key].item()
+
+ return total_loss.item(), train_losses
+
+ def _valid_epoch(self):
+ """Testing epoch. Should return average loss of a batch (sample) over
+ one epoch. See ``train_loop`` for usage.
+ """
+ self.generator.eval()
+ for key, _ in self.discriminators.items():
+ self.discriminators[key].eval()
+
+ epoch_losses: dict = {}
+ epoch_total_loss: int = 0
+
+ for batch in tqdm(
+ self.valid_dataloader,
+ desc=f"Validating Epoch {self.epoch}",
+ unit="batch",
+ colour="GREEN",
+ leave=False,
+ dynamic_ncols=True,
+ smoothing=0.04,
+ disable=not self.accelerator.is_main_process,
+ ):
+ # Get losses
+ total_loss, losses = self._valid_step(batch)
+
+ # Log info
+ for key, _ in losses.items():
+ self.accelerator.log(
+ {
+ "Step/Valid {} Loss".format(key): losses[key],
+ },
+ step=self.step,
+ )
+
+ if not epoch_losses:
+ epoch_losses = losses
+ else:
+ for key, value in losses.items():
+ epoch_losses[key] += value
+ epoch_total_loss += total_loss
+
+ # Get and log total losses
+ self.accelerator.wait_for_everyone()
+ epoch_total_loss = epoch_total_loss / len(self.valid_dataloader)
+ for key in epoch_losses.keys():
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
+ return epoch_total_loss, epoch_losses
+
+ def _valid_step(self, data):
+ """Testing forward step. Should return average loss of a sample over
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
+ See ``_test_epoch`` for usage.
+ """
+ # Init losses
+ valid_losses = {}
+ total_loss = 0
+
+ generator_losses = {}
+ generator_total_loss = 0
+ discriminator_losses = {}
+ discriminator_total_loss = 0
+
+ # Use feature inputs to get the predicted audio
+ mel_input = data["mel"]
+ audio_gt = data["audio"]
+
+ if self.cfg.preprocess.extract_amplitude_phase:
+ logamp_gt = data["logamp"]
+ pha_gt = data["pha"]
+ rea_gt = data["rea"]
+ imag_gt = data["imag"]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch_input = data["frame_pitch"]
+
+ if self.cfg.preprocess.use_frame_pitch:
+ pitch_input = pitch_input.float()
+ audio_pred = self.generator.forward(mel_input, pitch_input)
+ elif self.cfg.preprocess.extract_amplitude_phase:
+ (
+ logamp_pred,
+ pha_pred,
+ rea_pred,
+ imag_pred,
+ audio_pred,
+ ) = self.generator.forward(mel_input)
+ from utils.mel import amplitude_phase_spectrum
+
+ _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
+ audio_pred.squeeze(1), self.cfg.preprocess
+ )
+ else:
+ audio_pred = self.generator.forward(mel_input)
+
+ # Get Discriminator losses
+ for key, _ in self.discriminators.items():
+ y_r, y_g, _, _ = self.discriminators[key].forward(
+ audio_gt.unsqueeze(1), audio_pred
+ )
+ (
+ discriminator_losses["{}_discriminator".format(key)],
+ _,
+ _,
+ ) = self.criterions["discriminator"](y_r, y_g)
+ discriminator_total_loss += discriminator_losses[
+ "{}_discriminator".format(key)
+ ]
+
+ for key, _ in self.discriminators.items():
+ y_r, y_g, f_r, f_g = self.discriminators[key].forward(
+ audio_gt.unsqueeze(1), audio_pred
+ )
+ generator_losses["{}_feature".format(key)] = self.criterions["feature"](
+ f_r, f_g
+ )
+ generator_losses["{}_generator".format(key)], _ = self.criterions[
+ "generator"
+ ](y_g)
+ generator_total_loss += generator_losses["{}_feature".format(key)]
+ generator_total_loss += generator_losses["{}_generator".format(key)]
+
+ if "mel" in self.criterions.keys():
+ generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+ generator_total_loss += generator_losses["mel"]
+ if "mel" in self.criterions.keys():
+ generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+ generator_total_loss += generator_losses["mel"]
+
+ if "wav" in self.criterions.keys():
+ generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+ generator_total_loss += generator_losses["wav"]
+ if "wav" in self.criterions.keys():
+ generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+ generator_total_loss += generator_losses["wav"]
+
+ if "amplitude" in self.criterions.keys():
+ generator_losses["amplitude"] = self.criterions["amplitude"](
+ logamp_gt, logamp_pred
+ )
+ generator_total_loss += generator_losses["amplitude"]
+
+ if "phase" in self.criterions.keys():
+ generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
+ generator_total_loss += generator_losses["phase"]
+
+ if "consistency" in self.criterions.keys():
+ generator_losses["consistency"] = self.criterions["consistency"](
+ rea_gt,
+ rea_pred,
+ rea_pred_final,
+ imag_gt,
+ imag_pred,
+ imag_pred_final,
+ )
+ generator_total_loss += generator_losses["consistency"]
+
+ total_loss = discriminator_total_loss + generator_total_loss
+ valid_losses.update(discriminator_losses)
+ valid_losses.update(generator_losses)
+
+ for item in valid_losses:
+ valid_losses[item] = valid_losses[item].item()
+
+ return total_loss.item(), valid_losses
+
+ def _inference(self, eval_mel, eval_pitch=None, use_pitch=False):
+ """Inference during training for test audios."""
+ if use_pitch:
+ eval_pitch = align_length(eval_pitch, eval_mel.shape[1])
+ eval_audio = vocoder_inference(
+ self.cfg,
+ self.generator,
+ torch.from_numpy(eval_mel).unsqueeze(0),
+ f0s=torch.from_numpy(eval_pitch).unsqueeze(0).float(),
+ device=next(self.generator.parameters()).device,
+ ).squeeze(0)
+ else:
+ eval_audio = vocoder_inference(
+ self.cfg,
+ self.generator,
+ torch.from_numpy(eval_mel).unsqueeze(0),
+ device=next(self.generator.parameters()).device,
+ ).squeeze(0)
+ return eval_audio
+
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If checkpoint_path is None, it will
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
+ method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ if resume_type == "resume":
+ self.accelerator.load_state(checkpoint_path)
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+ elif resume_type == "finetune":
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.generator),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ for key, _ in self.discriminators.items():
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.discriminators[key]),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune SUCCESS!")
+ else:
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
+ return checkpoint_path
+
+ def _count_parameters(self):
+ result = sum(p.numel() for p in self.generator.parameters())
+ for _, discriminator in self.discriminators.items():
+ result += sum(p.numel() for p in discriminator.parameters())
+ return result
diff --git a/models/vocoders/gan/generator/__init__.py b/models/vocoders/gan/generator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/gan/generator/apnet.py b/models/vocoders/gan/generator/apnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d529bbda7dd89857df9c54f1e60e873a5c9fc48
--- /dev/null
+++ b/models/vocoders/gan/generator/apnet.py
@@ -0,0 +1,399 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, spectral_norm
+from modules.vocoder_blocks import *
+
+LRELU_SLOPE = 0.1
+
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(
+ self,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ padding: str = "same",
+ ):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+
+ def forward(self, spec: torch.Tensor, window) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(
+ spec,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ window,
+ center=True,
+ )
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ assert (window_envelope > 1e-11).all()
+ y = y / window_envelope
+
+ return y
+
+
+# The ASP and PSP Module are adopted from APNet under the MIT License
+# https://github.com/YangAi520/APNet/blob/main/models.py
+
+
+class ASPResBlock(torch.nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ASPResBlock, self).__init__()
+ self.cfg = cfg
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+
+class PSPResBlock(torch.nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(PSPResBlock, self).__init__()
+ self.cfg = cfg
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+
+class APNet(torch.nn.Module):
+ def __init__(self, cfg):
+ super(APNet, self).__init__()
+ self.cfg = cfg
+ self.ASP_num_kernels = len(cfg.model.apnet.ASP_resblock_kernel_sizes)
+ self.PSP_num_kernels = len(cfg.model.apnet.PSP_resblock_kernel_sizes)
+
+ self.ASP_input_conv = weight_norm(
+ Conv1d(
+ cfg.preprocess.n_mel,
+ cfg.model.apnet.ASP_channel,
+ cfg.model.apnet.ASP_input_conv_kernel_size,
+ 1,
+ padding=get_padding(cfg.model.apnet.ASP_input_conv_kernel_size, 1),
+ )
+ )
+ self.PSP_input_conv = weight_norm(
+ Conv1d(
+ cfg.preprocess.n_mel,
+ cfg.model.apnet.PSP_channel,
+ cfg.model.apnet.PSP_input_conv_kernel_size,
+ 1,
+ padding=get_padding(cfg.model.apnet.PSP_input_conv_kernel_size, 1),
+ )
+ )
+
+ self.ASP_ResNet = nn.ModuleList()
+ for j, (k, d) in enumerate(
+ zip(
+ cfg.model.apnet.ASP_resblock_kernel_sizes,
+ cfg.model.apnet.ASP_resblock_dilation_sizes,
+ )
+ ):
+ self.ASP_ResNet.append(ASPResBlock(cfg, cfg.model.apnet.ASP_channel, k, d))
+
+ self.PSP_ResNet = nn.ModuleList()
+ for j, (k, d) in enumerate(
+ zip(
+ cfg.model.apnet.PSP_resblock_kernel_sizes,
+ cfg.model.apnet.PSP_resblock_dilation_sizes,
+ )
+ ):
+ self.PSP_ResNet.append(PSPResBlock(cfg, cfg.model.apnet.PSP_channel, k, d))
+
+ self.ASP_output_conv = weight_norm(
+ Conv1d(
+ cfg.model.apnet.ASP_channel,
+ cfg.preprocess.n_fft // 2 + 1,
+ cfg.model.apnet.ASP_output_conv_kernel_size,
+ 1,
+ padding=get_padding(cfg.model.apnet.ASP_output_conv_kernel_size, 1),
+ )
+ )
+ self.PSP_output_R_conv = weight_norm(
+ Conv1d(
+ cfg.model.apnet.PSP_channel,
+ cfg.preprocess.n_fft // 2 + 1,
+ cfg.model.apnet.PSP_output_R_conv_kernel_size,
+ 1,
+ padding=get_padding(cfg.model.apnet.PSP_output_R_conv_kernel_size, 1),
+ )
+ )
+ self.PSP_output_I_conv = weight_norm(
+ Conv1d(
+ cfg.model.apnet.PSP_channel,
+ cfg.preprocess.n_fft // 2 + 1,
+ cfg.model.apnet.PSP_output_I_conv_kernel_size,
+ 1,
+ padding=get_padding(cfg.model.apnet.PSP_output_I_conv_kernel_size, 1),
+ )
+ )
+
+ self.iSTFT = ISTFT(
+ self.cfg.preprocess.n_fft,
+ hop_length=self.cfg.preprocess.hop_size,
+ win_length=self.cfg.preprocess.win_size,
+ )
+
+ self.ASP_output_conv.apply(init_weights)
+ self.PSP_output_R_conv.apply(init_weights)
+ self.PSP_output_I_conv.apply(init_weights)
+
+ def forward(self, mel):
+ logamp = self.ASP_input_conv(mel)
+ logamps = None
+ for j in range(self.ASP_num_kernels):
+ if logamps is None:
+ logamps = self.ASP_ResNet[j](logamp)
+ else:
+ logamps += self.ASP_ResNet[j](logamp)
+ logamp = logamps / self.ASP_num_kernels
+ logamp = F.leaky_relu(logamp)
+ logamp = self.ASP_output_conv(logamp)
+
+ pha = self.PSP_input_conv(mel)
+ phas = None
+ for j in range(self.PSP_num_kernels):
+ if phas is None:
+ phas = self.PSP_ResNet[j](pha)
+ else:
+ phas += self.PSP_ResNet[j](pha)
+ pha = phas / self.PSP_num_kernels
+ pha = F.leaky_relu(pha)
+ R = self.PSP_output_R_conv(pha)
+ I = self.PSP_output_I_conv(pha)
+
+ pha = torch.atan2(I, R)
+
+ rea = torch.exp(logamp) * torch.cos(pha)
+ imag = torch.exp(logamp) * torch.sin(pha)
+
+ spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1)
+
+ spec = torch.view_as_complex(spec)
+
+ audio = self.iSTFT.forward(
+ spec, torch.hann_window(self.cfg.preprocess.win_size).to(mel.device)
+ )
+
+ return logamp, pha, rea, imag, audio.unsqueeze(1)
diff --git a/models/vocoders/gan/generator/bigvgan.py b/models/vocoders/gan/generator/bigvgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7658d31d59efe613aee3e7bf8089e91c8f484af
--- /dev/null
+++ b/models/vocoders/gan/generator/bigvgan.py
@@ -0,0 +1,341 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import torch.nn as nn
+
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+from modules.vocoder_blocks import *
+from modules.activation_functions import *
+from modules.anti_aliasing import *
+
+LRELU_SLOPE = 0.1
+
+# The AMPBlock Module is adopted from BigVGAN under the MIT License
+# https://github.com/NVIDIA/BigVGAN
+
+
+class AMPBlock1(torch.nn.Module):
+ def __init__(
+ self, cfg, channels, kernel_size=3, dilation=(1, 3, 5), activation=None
+ ):
+ super(AMPBlock1, self).__init__()
+ self.cfg = cfg
+
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ self.num_layers = len(self.convs1) + len(
+ self.convs2
+ ) # total number of conv layers
+
+ if (
+ activation == "snake"
+ ): # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=Snake(
+ channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+ )
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ elif (
+ activation == "snakebeta"
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=SnakeBeta(
+ channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+ )
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ def forward(self, x):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+ xt = a1(x)
+ xt = c1(xt)
+ xt = a2(xt)
+ xt = c2(xt)
+ x = xt + x
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class AMPBlock2(torch.nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3), activation=None):
+ super(AMPBlock2, self).__init__()
+ self.cfg = cfg
+
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ self.num_layers = len(self.convs) # total number of conv layers
+
+ if (
+ activation == "snake"
+ ): # periodic nonlinearity with snake function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=Snake(
+ channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+ )
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ elif (
+ activation == "snakebeta"
+ ): # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList(
+ [
+ Activation1d(
+ activation=SnakeBeta(
+ channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+ )
+ )
+ for _ in range(self.num_layers)
+ ]
+ )
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ def forward(self, x):
+ for c, a in zip(self.convs, self.activations):
+ xt = a(x)
+ xt = c(xt)
+ x = xt + x
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class BigVGAN(torch.nn.Module):
+ def __init__(self, cfg):
+ super(BigVGAN, self).__init__()
+ self.cfg = cfg
+
+ self.num_kernels = len(cfg.model.bigvgan.resblock_kernel_sizes)
+ self.num_upsamples = len(cfg.model.bigvgan.upsample_rates)
+
+ # Conv pre to boost channels
+ self.conv_pre = weight_norm(
+ Conv1d(
+ cfg.preprocess.n_mel,
+ cfg.model.bigvgan.upsample_initial_channel,
+ 7,
+ 1,
+ padding=3,
+ )
+ )
+
+ resblock = AMPBlock1 if cfg.model.bigvgan.resblock == "1" else AMPBlock2
+
+ # Upsamplers
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(
+ zip(
+ cfg.model.bigvgan.upsample_rates,
+ cfg.model.bigvgan.upsample_kernel_sizes,
+ )
+ ):
+ self.ups.append(
+ nn.ModuleList(
+ [
+ weight_norm(
+ ConvTranspose1d(
+ cfg.model.bigvgan.upsample_initial_channel // (2**i),
+ cfg.model.bigvgan.upsample_initial_channel
+ // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ ]
+ )
+ )
+
+ # Res Blocks with AMP and Anti-aliasing
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = cfg.model.bigvgan.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(
+ cfg.model.bigvgan.resblock_kernel_sizes,
+ cfg.model.bigvgan.resblock_dilation_sizes,
+ )
+ ):
+ self.resblocks.append(
+ resblock(cfg, ch, k, d, activation=cfg.model.bigvgan.activation)
+ )
+
+ # Conv post for result
+ if cfg.model.bigvgan.activation == "snake":
+ activation_post = Snake(ch, alpha_logscale=cfg.model.bigvgan.snake_logscale)
+ self.activation_post = Activation1d(activation=activation_post)
+ elif cfg.model.bigvgan.activation == "snakebeta":
+ activation_post = SnakeBeta(
+ ch, alpha_logscale=cfg.model.bigvgan.snake_logscale
+ )
+ self.activation_post = Activation1d(activation=activation_post)
+ else:
+ raise NotImplementedError(
+ "activation incorrectly specified. check the config file and look for 'activation'."
+ )
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+ # Weight Norm
+ for i in range(len(self.ups)):
+ self.ups[i].apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ for i_up in range(len(self.ups[i])):
+ x = self.ups[i][i_up](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print("Removing weight norm...")
+ for l in self.ups:
+ for l_i in l:
+ remove_weight_norm(l_i)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
diff --git a/models/vocoders/gan/generator/hifigan.py b/models/vocoders/gan/generator/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f5f32498f5eb6441db787b0ae204a1eeff36aa3
--- /dev/null
+++ b/models/vocoders/gan/generator/hifigan.py
@@ -0,0 +1,449 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+from modules.vocoder_blocks import *
+
+
+LRELU_SLOPE = 0.1
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.cfg = cfg
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.cfg = cfg
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class HiFiGAN(torch.nn.Module):
+ def __init__(self, cfg):
+ super(HiFiGAN, self).__init__()
+ self.cfg = cfg
+ self.num_kernels = len(self.cfg.model.hifigan.resblock_kernel_sizes)
+ self.num_upsamples = len(self.cfg.model.hifigan.upsample_rates)
+ self.conv_pre = weight_norm(
+ Conv1d(
+ cfg.preprocess.n_mel,
+ self.cfg.model.hifigan.upsample_initial_channel,
+ 7,
+ 1,
+ padding=3,
+ )
+ )
+ resblock = ResBlock1 if self.cfg.model.hifigan.resblock == "1" else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(
+ zip(
+ self.cfg.model.hifigan.upsample_rates,
+ self.cfg.model.hifigan.upsample_kernel_sizes,
+ )
+ ):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ self.cfg.model.hifigan.upsample_initial_channel // (2**i),
+ self.cfg.model.hifigan.upsample_initial_channel
+ // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = self.cfg.model.hifigan.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(
+ self.cfg.model.hifigan.resblock_kernel_sizes,
+ self.cfg.model.hifigan.resblock_dilation_sizes,
+ )
+ ):
+ self.resblocks.append(resblock(self.cfg, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print("Removing weight norm...")
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+
+# todo: merge with ResBlock1 (lmxue, yicheng)
+class ResBlock1_vits(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1_vits, self).__init__()
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c2(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+# todo: merge with ResBlock2 (lmxue, yicheng)
+class ResBlock2_vits(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2_vits, self).__init__()
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+# todo: merge with HiFiGAN (lmxue, yicheng)
+class HiFiGAN_vits(torch.nn.Module):
+ def __init__(
+ self,
+ initial_channel,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=0,
+ ):
+ super(HiFiGAN_vits, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.conv_pre = Conv1d(
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
+ )
+ resblock = ResBlock1_vits if resblock == "1" else ResBlock2_vits
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(ch, k, d))
+
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+ self.ups.apply(init_weights)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+ def forward(self, x, g=None):
+ x = self.conv_pre(x)
+ if g is not None:
+ x = x + self.cond(g)
+
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
diff --git a/models/vocoders/gan/generator/melgan.py b/models/vocoders/gan/generator/melgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..d13c5fe6c0844ff7d753ed14324a92b34cc7798c
--- /dev/null
+++ b/models/vocoders/gan/generator/melgan.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from torch.nn.utils import weight_norm
+
+# This code is adopted from MelGAN under the MIT License
+# https://github.com/descriptinc/melgan-neurips
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find("BatchNorm2d") != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, dim, dilation=1):
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.LeakyReLU(0.2),
+ nn.ReflectionPad1d(dilation),
+ WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
+ nn.LeakyReLU(0.2),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+ self.shortcut = WNConv1d(dim, dim, kernel_size=1)
+
+ def forward(self, x):
+ return self.shortcut(x) + self.block(x)
+
+
+class MelGAN(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+
+ self.hop_length = np.prod(self.cfg.model.melgan.ratios)
+ mult = int(2 ** len(self.cfg.model.melgan.ratios))
+
+ model = [
+ nn.ReflectionPad1d(3),
+ WNConv1d(
+ self.cfg.preprocess.n_mel,
+ mult * self.cfg.model.melgan.ngf,
+ kernel_size=7,
+ padding=0,
+ ),
+ ]
+
+ # Upsample to raw audio scale
+ for i, r in enumerate(self.cfg.model.melgan.ratios):
+ model += [
+ nn.LeakyReLU(0.2),
+ WNConvTranspose1d(
+ mult * self.cfg.model.melgan.ngf,
+ mult * self.cfg.model.melgan.ngf // 2,
+ kernel_size=r * 2,
+ stride=r,
+ padding=r // 2 + r % 2,
+ output_padding=r % 2,
+ ),
+ ]
+
+ for j in range(self.cfg.model.melgan.n_residual_layers):
+ model += [
+ ResnetBlock(mult * self.cfg.model.melgan.ngf // 2, dilation=3**j)
+ ]
+
+ mult //= 2
+
+ model += [
+ nn.LeakyReLU(0.2),
+ nn.ReflectionPad1d(3),
+ WNConv1d(self.cfg.model.melgan.ngf, 1, kernel_size=7, padding=0),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*model)
+ self.apply(weights_init)
+
+ def forward(self, x):
+ return self.model(x)
diff --git a/models/vocoders/gan/generator/nsfhifigan.py b/models/vocoders/gan/generator/nsfhifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db7f6d88b09525decc444a14115e6a63e548485
--- /dev/null
+++ b/models/vocoders/gan/generator/nsfhifigan.py
@@ -0,0 +1,283 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from modules.neural_source_filter import *
+from modules.vocoder_blocks import *
+
+
+LRELU_SLOPE = 0.1
+
+
+class ResBlock1(nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.cfg = cfg
+
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(nn.Module):
+ def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock1, self).__init__()
+ self.cfg = cfg
+
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+# This NSF Module is adopted from Xin Wang's NSF under the MIT License
+# https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts
+
+
+class SourceModuleHnNSF(nn.Module):
+ def __init__(
+ self, fs, harmonic_num=0, amp=0.1, noise_std=0.003, voiced_threshold=0
+ ):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.amp = amp
+ self.noise_std = noise_std
+ self.l_sin_gen = SineGen(fs, harmonic_num, amp, noise_std, voiced_threshold)
+
+ self.l_linear = nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = nn.Tanh()
+
+ def forward(self, x, upp):
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+ return sine_merge
+
+
+class NSFHiFiGAN(nn.Module):
+ def __init__(self, cfg):
+ super(NSFHiFiGAN, self).__init__()
+
+ self.cfg = cfg
+ self.num_kernels = len(self.cfg.model.nsfhifigan.resblock_kernel_sizes)
+ self.num_upsamples = len(self.cfg.model.nsfhifigan.upsample_rates)
+ self.m_source = SourceModuleHnNSF(
+ fs=self.cfg.preprocess.sample_rate,
+ harmonic_num=self.cfg.model.nsfhifigan.harmonic_num,
+ )
+ self.noise_convs = nn.ModuleList()
+ self.conv_pre = weight_norm(
+ Conv1d(
+ self.cfg.preprocess.n_mel,
+ self.cfg.model.nsfhifigan.upsample_initial_channel,
+ 7,
+ 1,
+ padding=3,
+ )
+ )
+
+ resblock = ResBlock1 if self.cfg.model.nsfhifigan.resblock == "1" else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(
+ zip(
+ self.cfg.model.nsfhifigan.upsample_rates,
+ self.cfg.model.nsfhifigan.upsample_kernel_sizes,
+ )
+ ):
+ c_cur = self.cfg.model.nsfhifigan.upsample_initial_channel // (2 ** (i + 1))
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ self.cfg.model.nsfhifigan.upsample_initial_channel // (2**i),
+ self.cfg.model.nsfhifigan.upsample_initial_channel
+ // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+ if i + 1 < len(self.cfg.model.nsfhifigan.upsample_rates):
+ stride_f0 = int(
+ np.prod(self.cfg.model.nsfhifigan.upsample_rates[i + 1 :])
+ )
+ self.noise_convs.append(
+ Conv1d(
+ 1,
+ c_cur,
+ kernel_size=stride_f0 * 2,
+ stride=stride_f0,
+ padding=stride_f0 // 2,
+ )
+ )
+ else:
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+ self.resblocks = nn.ModuleList()
+ ch = self.cfg.model.nsfhifigan.upsample_initial_channel
+ for i in range(len(self.ups)):
+ ch //= 2
+ for j, (k, d) in enumerate(
+ zip(
+ self.cfg.model.nsfhifigan.resblock_kernel_sizes,
+ self.cfg.model.nsfhifigan.resblock_dilation_sizes,
+ )
+ ):
+ self.resblocks.append(resblock(cfg, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+ self.upp = int(np.prod(self.cfg.model.nsfhifigan.upsample_rates))
+
+ def forward(self, x, f0):
+ har_source = self.m_source(f0, self.upp).transpose(1, 2)
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ x_source = self.noise_convs[i](har_source)
+
+ length = min(x.shape[-1], x_source.shape[-1])
+ x = x[:, :, :length]
+ x_source = x[:, :, :length]
+
+ x = x + x_source
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
diff --git a/models/vocoders/gan/generator/sifigan.py b/models/vocoders/gan/generator/sifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/vocoder_dataset.py b/models/vocoders/vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df17b97ba7a4f770f01971324126eca4a2db272
--- /dev/null
+++ b/models/vocoders/vocoder_dataset.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable
+import torch
+import numpy as np
+import torch.utils.data
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from torch.utils.data import ConcatDataset, Dataset
+
+
+class VocoderDataset(torch.utils.data.Dataset):
+ def __init__(self, cfg, dataset, is_valid=False):
+ """
+ Args:
+ cfg: config
+ dataset: dataset name
+ is_valid: whether to use train or valid dataset
+ """
+ assert isinstance(dataset, str)
+
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
+ self.metadata = self.get_metadata()
+
+ self.data_root = processed_data_dir
+ self.cfg = cfg
+
+ if cfg.preprocess.use_audio:
+ self.utt2audio_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2audio_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.audio_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_label:
+ self.utt2label_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2label_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.label_dir,
+ uid + ".npy",
+ )
+ elif cfg.preprocess.use_one_hot:
+ self.utt2one_hot_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2one_hot_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.one_hot_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_mel:
+ self.utt2mel_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2mel_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.mel_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_frame_pitch:
+ self.utt2frame_pitch_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ self.utt2frame_pitch_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.pitch_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_uv:
+ self.utt2uv_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2uv_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.uv_dir,
+ uid + ".npy",
+ )
+
+ if cfg.preprocess.use_amplitude_phase:
+ self.utt2logamp_path = {}
+ self.utt2pha_path = {}
+ self.utt2rea_path = {}
+ self.utt2imag_path = {}
+ for utt_info in self.metadata:
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+ self.utt2logamp_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.log_amplitude_dir,
+ uid + ".npy",
+ )
+ self.utt2pha_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.phase_dir,
+ uid + ".npy",
+ )
+ self.utt2rea_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.real_dir,
+ uid + ".npy",
+ )
+ self.utt2imag_path[utt] = os.path.join(
+ cfg.preprocess.processed_dir,
+ dataset,
+ cfg.preprocess.imaginary_dir,
+ uid + ".npy",
+ )
+
+ def __getitem__(self, index):
+ utt_info = self.metadata[index]
+
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ single_feature = dict()
+
+ if self.cfg.preprocess.use_mel:
+ mel = np.load(self.utt2mel_path[utt])
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = mel.shape[1]
+
+ single_feature["mel"] = mel
+
+ if self.cfg.preprocess.use_frame_pitch:
+ frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+
+ if "target_len" not in single_feature.keys():
+ single_feature["target_len"] = len(frame_pitch)
+
+ aligned_frame_pitch = align_length(
+ frame_pitch, single_feature["target_len"]
+ )
+
+ single_feature["frame_pitch"] = aligned_frame_pitch
+
+ if self.cfg.preprocess.use_audio:
+ audio = np.load(self.utt2audio_path[utt])
+
+ single_feature["audio"] = audio
+
+ return single_feature
+
+ def get_metadata(self):
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+
+ return metadata
+
+ def get_dataset_name(self):
+ return self.metadata[0]["Dataset"]
+
+ def __len__(self):
+ return len(self.metadata)
+
+
+class VocoderConcatDataset(ConcatDataset):
+ def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
+ """Concatenate a series of datasets with their random inference audio merged."""
+ super().__init__(datasets)
+
+ self.cfg = self.datasets[0].cfg
+
+ self.metadata = []
+
+ # Merge metadata
+ for dataset in self.datasets:
+ self.metadata += dataset.metadata
+
+ # Merge random inference features
+ if full_audio_inference:
+ self.eval_audios = []
+ self.eval_dataset_names = []
+ if self.cfg.preprocess.use_mel:
+ self.eval_mels = []
+ if self.cfg.preprocess.use_frame_pitch:
+ self.eval_pitchs = []
+ for dataset in self.datasets:
+ self.eval_audios.append(dataset.eval_audio)
+ self.eval_dataset_names.append(dataset.get_dataset_name())
+ if self.cfg.preprocess.use_mel:
+ self.eval_mels.append(dataset.eval_mel)
+ if self.cfg.preprocess.use_frame_pitch:
+ self.eval_pitchs.append(dataset.eval_pitch)
+
+
+class VocoderCollator(object):
+ """Zero-pads model inputs and targets based on number of frames per step"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def __call__(self, batch):
+ packed_batch_features = dict()
+
+ # mel: [b, n_mels, frame]
+ # frame_pitch: [b, frame]
+ # audios: [b, frame * hop_size]
+
+ for key in batch[0].keys():
+ if key == "target_len":
+ packed_batch_features["target_len"] = torch.LongTensor(
+ [b["target_len"] for b in batch]
+ )
+ masks = [
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+ ]
+ packed_batch_features["mask"] = pad_sequence(
+ masks, batch_first=True, padding_value=0
+ )
+ elif key == "mel":
+ values = [torch.from_numpy(b[key]).T for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+ else:
+ values = [torch.from_numpy(b[key]) for b in batch]
+ packed_batch_features[key] = pad_sequence(
+ values, batch_first=True, padding_value=0
+ )
+
+ return packed_batch_features
diff --git a/models/vocoders/vocoder_inference.py b/models/vocoders/vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e354c5db80cbac986543fdf7923014426c5078
--- /dev/null
+++ b/models/vocoders/vocoder_inference.py
@@ -0,0 +1,515 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import json
+import json5
+import time
+import accelerate
+import random
+import numpy as np
+import shutil
+
+from pathlib import Path
+from tqdm import tqdm
+from glob import glob
+from accelerate.logging import get_logger
+from torch.utils.data import DataLoader
+
+from models.vocoders.vocoder_dataset import (
+ VocoderDataset,
+ VocoderCollator,
+ VocoderConcatDataset,
+)
+
+from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
+from models.vocoders.flow.waveglow import waveglow
+from models.vocoders.diffusion.diffwave import diffwave
+from models.vocoders.autoregressive.wavenet import wavenet
+from models.vocoders.autoregressive.wavernn import wavernn
+
+from models.vocoders.gan import gan_vocoder_inference
+from models.vocoders.diffusion import diffusion_vocoder_inference
+
+from utils.io import save_audio
+
+_vocoders = {
+ "diffwave": diffwave.DiffWave,
+ "wavernn": wavernn.WaveRNN,
+ "wavenet": wavenet.WaveNet,
+ "waveglow": waveglow.WaveGlow,
+ "nsfhifigan": nsfhifigan.NSFHiFiGAN,
+ "bigvgan": bigvgan.BigVGAN,
+ "hifigan": hifigan.HiFiGAN,
+ "melgan": melgan.MelGAN,
+ "apnet": apnet.APNet,
+}
+
+# Forward call for generalized Inferencor
+_vocoder_forward_funcs = {
+ # "world": world_inference.synthesis_audios,
+ # "wavernn": wavernn_inference.synthesis_audios,
+ # "wavenet": wavenet_inference.synthesis_audios,
+ "diffwave": diffusion_vocoder_inference.vocoder_inference,
+ "nsfhifigan": gan_vocoder_inference.vocoder_inference,
+ "bigvgan": gan_vocoder_inference.vocoder_inference,
+ "melgan": gan_vocoder_inference.vocoder_inference,
+ "hifigan": gan_vocoder_inference.vocoder_inference,
+ "apnet": gan_vocoder_inference.vocoder_inference,
+}
+
+# APIs for other tasks. e.g. SVC, TTS, TTA...
+_vocoder_infer_funcs = {
+ # "world": world_inference.synthesis_audios,
+ # "wavernn": wavernn_inference.synthesis_audios,
+ # "wavenet": wavenet_inference.synthesis_audios,
+ "diffwave": diffusion_vocoder_inference.synthesis_audios,
+ "nsfhifigan": gan_vocoder_inference.synthesis_audios,
+ "bigvgan": gan_vocoder_inference.synthesis_audios,
+ "melgan": gan_vocoder_inference.synthesis_audios,
+ "hifigan": gan_vocoder_inference.synthesis_audios,
+ "apnet": gan_vocoder_inference.synthesis_audios,
+}
+
+
+class VocoderInference(object):
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+ super().__init__()
+
+ start = time.monotonic_ns()
+ self.args = args
+ self.cfg = cfg
+ self.infer_type = infer_type
+
+ # Init accelerator
+ self.accelerator = accelerate.Accelerator()
+ self.accelerator.wait_for_everyone()
+
+ # Get logger
+ with self.accelerator.main_process_first():
+ self.logger = get_logger("inference", log_level=args.log_level)
+
+ # Log some info
+ self.logger.info("=" * 56)
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
+ self.logger.info("=" * 56)
+ self.logger.info("\n")
+
+ self.vocoder_dir = args.vocoder_dir
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ if os.path.exists(os.path.join(args.output_dir, "pred")):
+ shutil.rmtree(os.path.join(args.output_dir, "pred"))
+ if os.path.exists(os.path.join(args.output_dir, "gt")):
+ shutil.rmtree(os.path.join(args.output_dir, "gt"))
+ os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
+ os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)
+
+ # Set random seed
+ with self.accelerator.main_process_first():
+ start = time.monotonic_ns()
+ self._set_random_seed(self.cfg.train.random_seed)
+ end = time.monotonic_ns()
+ self.logger.debug(
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+ )
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+ # Setup inference mode
+ if self.infer_type == "infer_from_dataset":
+ self.cfg.dataset = self.args.infer_datasets
+ elif self.infer_type == "infer_from_feature":
+ self._build_tmp_dataset_from_feature()
+ self.cfg.dataset = ["tmp"]
+ elif self.infer_type == "infer_from_audio":
+ self._build_tmp_dataset_from_audio()
+ self.cfg.dataset = ["tmp"]
+
+ # Setup data loader
+ with self.accelerator.main_process_first():
+ self.logger.info("Building dataset...")
+ start = time.monotonic_ns()
+ self.test_dataloader = self._build_dataloader()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+ # Build model
+ with self.accelerator.main_process_first():
+ self.logger.info("Building model...")
+ start = time.monotonic_ns()
+ self.model = self._build_model()
+ end = time.monotonic_ns()
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
+
+ # Init with accelerate
+ self.logger.info("Initializing accelerate...")
+ start = time.monotonic_ns()
+ self.accelerator = accelerate.Accelerator()
+ (self.model, self.test_dataloader) = self.accelerator.prepare(
+ self.model, self.test_dataloader
+ )
+ end = time.monotonic_ns()
+ self.accelerator.wait_for_everyone()
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
+
+ with self.accelerator.main_process_first():
+ self.logger.info("Loading checkpoint...")
+ start = time.monotonic_ns()
+ if os.path.isdir(args.vocoder_dir):
+ if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
+ self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
+ else:
+ self._load_model(os.path.join(args.vocoder_dir))
+ else:
+ self._load_model(os.path.join(args.vocoder_dir))
+ end = time.monotonic_ns()
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
+
+ self.model.eval()
+ self.accelerator.wait_for_everyone()
+
+ def _build_tmp_dataset_from_feature(self):
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+ utts = []
+ mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
+ for i, mel in enumerate(mels):
+ uid = mel.split("/")[-1].split(".")[0]
+ utt = {"Dataset": "tmp", "Uid": uid, "index": i}
+ utts.append(utt)
+
+ os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
+ ) as f:
+ json.dump(utts, f)
+
+ meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
+
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
+ "w",
+ ) as f:
+ json.dump(meta_info, f)
+
+ features = glob(os.path.join(self.args.feature_folder, "*"))
+ for feature in features:
+ feature_name = feature.split("/")[-1]
+ if os.path.isfile(feature):
+ continue
+ shutil.copytree(
+ os.path.join(self.args.feature_folder, feature_name),
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
+ )
+
+ def _build_tmp_dataset_from_audio(self):
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+ utts = []
+ audios = glob(os.path.join(self.args.audio_folder, "*"))
+ for i, audio in enumerate(audios):
+ uid = audio.split("/")[-1].split(".")[0]
+ utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
+ utts.append(utt)
+
+ os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
+ ) as f:
+ json.dump(utts, f)
+
+ meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
+
+ with open(
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
+ "w",
+ ) as f:
+ json.dump(meta_info, f)
+
+ from processors import acoustic_extractor
+
+ acoustic_extractor.extract_utt_acoustic_features_serial(
+ utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
+ )
+
+ def _build_test_dataset(self):
+ return VocoderDataset, VocoderCollator
+
+ def _build_model(self):
+ model = _vocoders[self.cfg.model.generator](self.cfg)
+ return model
+
+ def _build_dataloader(self):
+ """Build dataloader which merges a series of datasets."""
+ Dataset, Collator = self._build_test_dataset()
+
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
+ test_collate = Collator(self.cfg)
+ test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
+ test_dataloader = DataLoader(
+ test_dataset,
+ collate_fn=test_collate,
+ num_workers=1,
+ batch_size=test_batch_size,
+ shuffle=False,
+ )
+ self.test_batch_size = test_batch_size
+ self.test_dataset = test_dataset
+ return test_dataloader
+
+ def _load_model(self, checkpoint_dir, from_multi_gpu=False):
+ """Load model from checkpoint. If a folder is given, it will
+ load the latest checkpoint in checkpoint_dir. If a path is given
+ it will load the checkpoint specified by checkpoint_path.
+ **Only use this method after** ``accelerator.prepare()``.
+ """
+ if os.path.isdir(checkpoint_dir):
+ if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
+ checkpoint_path = checkpoint_dir
+ else:
+ # Load the latest accelerator state dicts
+ ls = [
+ str(i)
+ for i in Path(checkpoint_dir).glob("*")
+ if not "audio" in str(i)
+ ]
+ ls.sort(
+ key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]),
+ reverse=True,
+ )
+ checkpoint_path = ls[0]
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ return str(checkpoint_path)
+ else:
+ # Load old .pt checkpoints
+ if self.cfg.model.generator in [
+ "bigvgan",
+ "hifigan",
+ "melgan",
+ "nsfhifigan",
+ ]:
+ ckpt = torch.load(
+ checkpoint_dir,
+ map_location=(
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("cpu")
+ ),
+ )
+ if from_multi_gpu:
+ pretrained_generator_dict = ckpt["generator_state_dict"]
+ generator_dict = self.model.state_dict()
+
+ new_generator_dict = {
+ k.split("module.")[-1]: v
+ for k, v in pretrained_generator_dict.items()
+ if (
+ k.split("module.")[-1] in generator_dict
+ and v.shape == generator_dict[k.split("module.")[-1]].shape
+ )
+ }
+
+ generator_dict.update(new_generator_dict)
+
+ self.model.load_state_dict(generator_dict)
+ else:
+ self.model.load_state_dict(ckpt["generator_state_dict"])
+ else:
+ self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
+ return str(checkpoint_dir)
+
+ def inference(self):
+ """Inference via batches"""
+ for i, batch in tqdm(enumerate(self.test_dataloader)):
+ if self.cfg.preprocess.use_frame_pitch:
+ audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
+ self.cfg,
+ self.model,
+ batch["mel"].transpose(-1, -2),
+ f0s=batch["frame_pitch"].float(),
+ device=next(self.model.parameters()).device,
+ )
+ else:
+ audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
+ self.cfg,
+ self.model,
+ batch["mel"].transpose(-1, -2),
+ device=next(self.model.parameters()).device,
+ )
+ audio_ls = audio_pred.chunk(self.test_batch_size)
+ audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
+ length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
+ j = 0
+ for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
+ l = l.item()
+ it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
+ it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+ save_audio(
+ os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
+ it,
+ self.cfg.preprocess.sample_rate,
+ )
+ save_audio(
+ os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
+ it_gt,
+ self.cfg.preprocess.sample_rate,
+ )
+ j += 1
+
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+ def _set_random_seed(self, seed):
+ """Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ def _count_parameters(self, model):
+ return sum(p.numel() for p in model.parameters())
+
+ def _dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+
+def load_nnvocoder(
+ cfg,
+ vocoder_name,
+ weights_file,
+ from_multi_gpu=False,
+):
+ """Load the specified vocoder.
+ cfg: the vocoder config filer.
+ weights_file: a folder or a .pt path.
+ from_multi_gpu: automatically remove the "module" string in state dicts if "True".
+ """
+ print("Loading Vocoder from Weights file: {}".format(weights_file))
+
+ # Build model
+ model = _vocoders[vocoder_name](cfg)
+ if not os.path.isdir(weights_file):
+ # Load from .pt file
+ if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
+ ckpt = torch.load(
+ weights_file,
+ map_location=(
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("cpu")
+ ),
+ )
+ if from_multi_gpu:
+ pretrained_generator_dict = ckpt["generator_state_dict"]
+ generator_dict = model.state_dict()
+
+ new_generator_dict = {
+ k.split("module.")[-1]: v
+ for k, v in pretrained_generator_dict.items()
+ if (
+ k.split("module.")[-1] in generator_dict
+ and v.shape == generator_dict[k.split("module.")[-1]].shape
+ )
+ }
+
+ generator_dict.update(new_generator_dict)
+
+ model.load_state_dict(generator_dict)
+ else:
+ model.load_state_dict(ckpt["generator_state_dict"])
+ else:
+ model.load_state_dict(torch.load(weights_file)["state_dict"])
+ else:
+ # Load from accelerator state dict
+ weights_file = os.path.join(weights_file, "checkpoint")
+ ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ accelerator = accelerate.Accelerator()
+ model = accelerator.prepare(model)
+ accelerator.load_state(checkpoint_path)
+
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ model = model.eval()
+ return model
+
+
+def tensorize(data, device, n_samples):
+ """
+ data: a list of numpy array
+ """
+ assert type(data) == list
+ if n_samples:
+ data = data[:n_samples]
+ data = [torch.as_tensor(x, device=device) for x in data]
+ return data
+
+
+def synthesis(
+ cfg,
+ vocoder_weight_file,
+ n_samples,
+ pred,
+ f0s=None,
+ batch_size=64,
+ fast_inference=False,
+):
+ """Synthesis audios from a given vocoder and series of given features.
+ cfg: vocoder config.
+ vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
+ pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
+ """
+
+ vocoder_name = cfg.model.generator
+
+ print("Synthesis audios using {} vocoder...".format(vocoder_name))
+
+ ###### TODO: World Vocoder Refactor ######
+ # if vocoder_name == "world":
+ # world_inference.synthesis_audios(
+ # cfg, dataset_name, split, n_samples, pred, save_dir, tag
+ # )
+ # return
+
+ # ====== Loading neural vocoder model ======
+ vocoder = load_nnvocoder(
+ cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
+ )
+ device = next(vocoder.parameters()).device
+
+ # ====== Inference for predicted acoustic features ======
+ # pred: (frame_len, n_mels) -> (n_mels, frame_len)
+ mels_pred = tensorize([p.T for p in pred], device, n_samples)
+ print("For predicted mels, #sample = {}...".format(len(mels_pred)))
+ audios_pred = _vocoder_infer_funcs[vocoder_name](
+ cfg,
+ vocoder,
+ mels_pred,
+ f0s=f0s,
+ batch_size=batch_size,
+ fast_inference=fast_inference,
+ )
+ return audios_pred
diff --git a/models/vocoders/vocoder_sampler.py b/models/vocoders/vocoder_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d29f88a291dcf7386cadaeae0d990c8e76ebf98
--- /dev/null
+++ b/models/vocoders/vocoder_sampler.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+ BatchSampler,
+ RandomSampler,
+ Sampler,
+ SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+ """A sampler that samples data from a given concat-dataset.
+
+ Args:
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+ batch_size (int): batch size
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
+ logger (logging.Logger): logger to print warning message
+
+ Usage:
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
+ >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]])))
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
+ """
+
+ def __init__(
+ self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
+ ):
+ if not isinstance(concat_dataset, ConcatDataset):
+ raise ValueError(
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+ type(concat_dataset)
+ )
+ )
+ if not isinstance(batch_size, int):
+ raise ValueError(
+ "batch_size must be an integer, but got {}".format(type(batch_size))
+ )
+ if not isinstance(holistic_shuffle, bool):
+ raise ValueError(
+ "holistic_shuffle must be a boolean, but got {}".format(
+ type(holistic_shuffle)
+ )
+ )
+
+ self.concat_dataset = concat_dataset
+ self.batch_size = batch_size
+ self.holistic_shuffle = holistic_shuffle
+
+ affected_dataset_name = []
+ affected_dataset_len = []
+ for dataset in concat_dataset.datasets:
+ dataset_len = len(dataset)
+ dataset_name = dataset.get_dataset_name()
+ if dataset_len < batch_size:
+ affected_dataset_name.append(dataset_name)
+ affected_dataset_len.append(dataset_len)
+
+ self.type = type
+ for dataset_name, dataset_len in zip(
+ affected_dataset_name, affected_dataset_len
+ ):
+ if not type == "valid":
+ logger.warning(
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
+ type, dataset_name, dataset_len, batch_size
+ )
+ )
+
+ def __len__(self):
+ # the number of batches with drop last
+ num_of_batches = sum(
+ [
+ math.floor(len(dataset) / self.batch_size)
+ for dataset in self.concat_dataset.datasets
+ ]
+ )
+ return num_of_batches * self.batch_size
+
+ def __iter__(self):
+ iters = []
+ for dataset in self.concat_dataset.datasets:
+ iters.append(
+ SequentialSampler(dataset).__iter__()
+ if self.holistic_shuffle
+ else RandomSampler(dataset).__iter__()
+ )
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
+ output_batches = []
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
+ cur_batch = []
+ for idx in iters[dataset_idx]:
+ cur_batch.append(idx + init_indices[dataset_idx])
+ if len(cur_batch) == self.batch_size:
+ output_batches.append(cur_batch)
+ cur_batch = []
+ if self.type == "valid" and len(cur_batch) > 0:
+ output_batches.append(cur_batch)
+ cur_batch = []
+ # force drop last in training
+ random.shuffle(output_batches)
+ output_indices = [item for sublist in output_batches for item in sublist]
+ return iter(output_indices)
+
+
+def build_samplers(concat_dataset: Dataset, cfg, logger, type):
+ sampler = ScheduledSampler(
+ concat_dataset,
+ cfg.train.batch_size,
+ cfg.train.sampler.holistic_shuffle,
+ logger,
+ type,
+ )
+ batch_sampler = BatchSampler(
+ sampler,
+ cfg.train.batch_size,
+ cfg.train.sampler.drop_last if not type == "valid" else False,
+ )
+ return sampler, batch_sampler
diff --git a/models/vocoders/vocoder_trainer.py b/models/vocoders/vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5821e735a64f07fcf9c782712670e24ce6a91c04
--- /dev/null
+++ b/models/vocoders/vocoder_trainer.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import random
+from pathlib import Path
+import re
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models.vocoders.vocoder_dataset import VocoderConcatDataset
+from models.vocoders.vocoder_sampler import build_samplers
+
+
+class VocoderTrainer:
+ def __init__(self):
+ super().__init__()
+
+ def _init_accelerator(self):
+ """Initialize the accelerator components."""
+ self.exp_dir = os.path.join(
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
+ )
+ project_config = ProjectConfiguration(
+ project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
+ )
+ self.accelerator = accelerate.Accelerator(
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+ log_with=self.cfg.train.tracker,
+ project_config=project_config,
+ )
+ if self.accelerator.is_main_process:
+ os.makedirs(project_config.project_dir, exist_ok=True)
+ os.makedirs(project_config.logging_dir, exist_ok=True)
+ with self.accelerator.main_process_first():
+ self.accelerator.init_trackers(self.args.exp_name)
+
+ def _build_dataset(self):
+ pass
+
+ def _build_criterion(self):
+ pass
+
+ def _build_model(self):
+ pass
+
+ def _build_dataloader(self):
+ """Build dataloader which merges a series of datasets."""
+ # Build dataset instance for each dataset and combine them by ConcatDataset
+ Dataset, Collator = self._build_dataset()
+
+ # Build train set
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
+ datasets_list.append(subdataset)
+ train_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True)
+ train_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
+ train_loader = DataLoader(
+ train_dataset,
+ collate_fn=train_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+
+ # Build test set
+ datasets_list = []
+ for dataset in self.cfg.dataset:
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
+ datasets_list.append(subdataset)
+ valid_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True)
+ valid_collate = Collator(self.cfg)
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "train")
+ valid_loader = DataLoader(
+ valid_dataset,
+ collate_fn=valid_collate,
+ batch_sampler=batch_sampler,
+ num_workers=self.cfg.train.dataloader.num_worker,
+ pin_memory=self.cfg.train.dataloader.pin_memory,
+ )
+ return train_loader, valid_loader
+
+ def _build_optimizer(self):
+ pass
+
+ def _build_scheduler(self):
+ pass
+
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+ """Load model from checkpoint. If a folder is given, it will
+ load the latest checkpoint in checkpoint_dir. If a path is given
+ it will load the checkpoint specified by checkpoint_path.
+ **Only use this method after** ``accelerator.prepare()``.
+ """
+ if checkpoint_path is None:
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+ checkpoint_path = ls[0]
+ if resume_type == "resume":
+ self.accelerator.load_state(checkpoint_path)
+ elif resume_type == "finetune":
+ accelerate.load_checkpoint_and_dispatch(
+ self.accelerator.unwrap_model(self.model),
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
+ )
+ self.logger.info("Load model weights for finetune SUCCESS!")
+ else:
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+ return checkpoint_path
+
+ def train_loop(self):
+ pass
+
+ def _train_epoch(self):
+ pass
+
+ def _valid_epoch(self):
+ pass
+
+ def _train_step(self):
+ pass
+
+ def _valid_step(self):
+ pass
+
+ def _inference(self):
+ pass
+
+ def _set_random_seed(self, seed):
+ """Set random seed for all possible random modules."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+ def _check_nan(self, loss):
+ if torch.any(torch.isnan(loss)):
+ self.logger.fatal("Fatal Error: NaN!")
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
+
+ def _check_basic_configs(self):
+ if self.cfg.train.gradient_accumulation_step <= 0:
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
+ self.logger.error(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+ self.accelerator.end_training()
+ raise ValueError(
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+ )
+
+ def _count_parameters(self):
+ pass
+
+ def _dump_cfg(self, path):
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ json5.dump(
+ self.cfg,
+ open(path, "w"),
+ indent=4,
+ sort_keys=True,
+ ensure_ascii=False,
+ quote_keys=True,
+ )
+
+ def _is_valid_pattern(self, directory_name):
+ directory_name = str(directory_name)
+ pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
+ return re.match(pattern, directory_name) is not None
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/activation_functions/__init__.py b/modules/activation_functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d35dfbc27a825ea8ed286074d40971764b890468
--- /dev/null
+++ b/modules/activation_functions/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .gated_activation_unit import GaU
+from .snake import Snake, SnakeBeta
diff --git a/modules/activation_functions/gated_activation_unit.py b/modules/activation_functions/gated_activation_unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e55956981fa99299059dd1c02236011fbadac049
--- /dev/null
+++ b/modules/activation_functions/gated_activation_unit.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from modules.general.utils import Conv1d
+
+
+class GaU(nn.Module):
+ r"""Gated Activation Unit (GaU) proposed in `Gated Activation Units for Neural
+ Networks `_.
+
+ Args:
+ channels: number of input channels.
+ kernel_size: kernel size of the convolution.
+ dilation: dilation rate of the convolution.
+ d_context: dimension of context tensor, None if don't use context.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ d_context: int = None,
+ ):
+ super().__init__()
+
+ self.context = d_context
+
+ self.conv = Conv1d(
+ channels,
+ channels * 2,
+ kernel_size,
+ dilation=dilation,
+ padding=dilation * (kernel_size - 1) // 2,
+ )
+
+ if self.context:
+ self.context_proj = Conv1d(d_context, channels * 2, 1)
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
+ r"""Calculate forward propagation.
+
+ Args:
+ x: input tensor with shape [B, C, T].
+ context: context tensor with shape [B, ``d_context``, T], default to None.
+ """
+
+ h = self.conv(x)
+
+ if self.context:
+ h = h + self.context_proj(context)
+
+ h1, h2 = h.chunk(2, 1)
+ h = torch.tanh(h1) * torch.sigmoid(h2)
+
+ return h
diff --git a/modules/activation_functions/snake.py b/modules/activation_functions/snake.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5398819f2e7c836e02189321cbf048bedcd9e91
--- /dev/null
+++ b/modules/activation_functions/snake.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn, pow, sin
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+ r"""Implementation of a sine-based periodic activation function.
+ Alpha is initialized to 1 by default, higher values means higher frequency.
+ It will be trained along with the rest of your model.
+
+ Args:
+ in_features: shape of the input
+ alpha: trainable parameter
+
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+
+ References:
+ This activation function is from this paper by Liu Ziyin, Tilman Hartwig,
+ Masahito Ueda: https://arxiv.org/abs/2006.08195
+
+ Examples:
+ >>> a1 = Snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ """
+
+ def __init__(
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+ ):
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ r"""Forward pass of the function. Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (ax)
+ """
+
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ r"""A modified Snake function which uses separate parameters for the magnitude
+ of the periodic components. Alpha is initialized to 1 by default,
+ higher values means higher frequency. Beta is initialized to 1 by default,
+ higher values means higher magnitude. Both will be trained along with the
+ rest of your model.
+
+ Args:
+ in_features: shape of the input
+ alpha: trainable parameter that controls frequency
+ beta: trainable parameter that controls magnitude
+
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+
+ References:
+ This activation function is a modified version based on this paper by Liu Ziyin,
+ Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195
+
+ Examples:
+ >>> a1 = SnakeBeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ """
+
+ def __init__(
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+ ):
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.beta = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ r"""Forward pass of the function. Applies the function to the input elementwise.
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
+ """
+
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
diff --git a/modules/anti_aliasing/__init__.py b/modules/anti_aliasing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..193756bfa04218d10f65fbed8665a8a0abad1cd4
--- /dev/null
+++ b/modules/anti_aliasing/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .act import *
+from .filter import *
+from .resample import *
diff --git a/modules/anti_aliasing/act.py b/modules/anti_aliasing/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b62c29467ad33899dd699437236f9b808a4dbf5
--- /dev/null
+++ b/modules/anti_aliasing/act.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+
+from .resample import *
+
+# This code is adopted from BigVGAN under the MIT License
+# https://github.com/NVIDIA/BigVGAN
+
+
+class Activation1d(nn.Module):
+ def __init__(
+ self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12,
+ ):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
diff --git a/modules/anti_aliasing/filter.py b/modules/anti_aliasing/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91688866bd073c7ce268004791ec1263a6d76d7
--- /dev/null
+++ b/modules/anti_aliasing/filter.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(
+ x == 0,
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x,
+ )
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+def kaiser_sinc_filter1d(
+ cutoff, half_width, kernel_size
+): # return filter [1,1,kernel_size]
+ even = kernel_size % 2 == 0
+ half_size = kernel_size // 2
+
+ # For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.0:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.0:
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+ else:
+ beta = 0.0
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = torch.arange(-half_size, half_size) + 0.5
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(
+ self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = "replicate",
+ kernel_size: int = 12,
+ ):
+ # kernel_size should be even number for stylegan3 setup,
+ # in this implementation, odd number is also possible.
+ super().__init__()
+ if cutoff < -0.0:
+ raise ValueError("Minimum cutoff must be larger than zero.")
+ if cutoff > 0.5:
+ raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = kernel_size % 2 == 0
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ # input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+
+ return out
diff --git a/modules/anti_aliasing/resample.py b/modules/anti_aliasing/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..462c8a42380bd8f6d26e19350599bbd5820f79a1
--- /dev/null
+++ b/modules/anti_aliasing/resample.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+#################### Anti-aliasing ####################
+
+import torch.nn as nn
+from torch.nn import functional as F
+
+from .filter import *
+
+# This code is adopted from BigVGAN under the MIT License
+# https://github.com/NVIDIA/BigVGAN
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = (
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ )
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = (
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ )
+ filter = kaiser_sinc_filter1d(
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+ )
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+ )
+ x = x[..., self.pad_left : -self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = (
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ )
+ self.lowpass = LowPassFilter1d(
+ cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size,
+ )
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
diff --git a/modules/base/base_module.py b/modules/base/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3c81206f4f130b706fffd45e93a054ce1231ddd
--- /dev/null
+++ b/modules/base/base_module.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels,
+ out_channels,
+ kernel_size,
+ n_layers,
+ p_dropout,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 0."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(
+ nn.Conv1d(
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(
+ nn.Conv1d(
+ hidden_channels,
+ hidden_channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
diff --git a/modules/dac/__init__.py b/modules/dac/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..025cc5f6548f6dcfe9ba0f79a6e239d15013979d
--- /dev/null
+++ b/modules/dac/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/__init__.py
+
+__version__ = "1.0.0"
+
+# preserved here for legacy reasons
+__model_version__ = "latest"
+
+# import audiotools
+
+# audiotools.ml.BaseModel.INTERN += ["dac.**"]
+# audiotools.ml.BaseModel.EXTERN += ["einops"]
+
+
+from . import nn
+from . import model
+from .model import DAC
+from .model import DACFile
diff --git a/modules/dac/model/__init__.py b/modules/dac/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b26926624bb3cb84e25a9268bfb213ee4a10b105
--- /dev/null
+++ b/modules/dac/model/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/model/__init__.py
+
+from .base import CodecMixin
+from .base import DACFile
+from .dac import DAC
+from .discriminator import Discriminator
diff --git a/modules/dac/model/base.py b/modules/dac/model/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..69754c97f40876525890dd19eeb41823cc7800c0
--- /dev/null
+++ b/modules/dac/model/base.py
@@ -0,0 +1,301 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/model/base.py
+
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+ codes: torch.Tensor
+
+ # Metadata
+ chunk_length: int
+ original_length: int
+ input_db: float
+ channels: int
+ sample_rate: int
+ padding: bool
+ dac_version: str
+
+ def save(self, path):
+ artifacts = {
+ "codes": self.codes.numpy().astype(np.uint16),
+ "metadata": {
+ "input_db": self.input_db.numpy().astype(np.float32),
+ "original_length": self.original_length,
+ "sample_rate": self.sample_rate,
+ "chunk_length": self.chunk_length,
+ "channels": self.channels,
+ "padding": self.padding,
+ "dac_version": SUPPORTED_VERSIONS[-1],
+ },
+ }
+ path = Path(path).with_suffix(".dac")
+ with open(path, "wb") as f:
+ np.save(f, artifacts)
+ return path
+
+ @classmethod
+ def load(cls, path):
+ artifacts = np.load(path, allow_pickle=True)[()]
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+ raise RuntimeError(
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
+ )
+ return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+ @property
+ def padding(self):
+ if not hasattr(self, "_padding"):
+ self._padding = True
+ return self._padding
+
+ @padding.setter
+ def padding(self, value):
+ assert isinstance(value, bool)
+
+ layers = [
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
+ ]
+
+ for layer in layers:
+ if value:
+ if hasattr(layer, "original_padding"):
+ layer.padding = layer.original_padding
+ else:
+ layer.original_padding = layer.padding
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+ self._padding = value
+
+ def get_delay(self):
+ # Any number works here, delay is invariant to input length
+ l_out = self.get_output_length(0)
+ L = l_out
+
+ layers = []
+ for layer in self.modules():
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+ layers.append(layer)
+
+ for layer in reversed(layers):
+ d = layer.dilation[0]
+ k = layer.kernel_size[0]
+ s = layer.stride[0]
+
+ if isinstance(layer, nn.ConvTranspose1d):
+ L = ((L - d * (k - 1) - 1) / s) + 1
+ elif isinstance(layer, nn.Conv1d):
+ L = (L - 1) * s + d * (k - 1) + 1
+
+ L = math.ceil(L)
+
+ l_in = L
+
+ return (l_in - l_out) // 2
+
+ def get_output_length(self, input_length):
+ L = input_length
+ # Calculate output length
+ for layer in self.modules():
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+ d = layer.dilation[0]
+ k = layer.kernel_size[0]
+ s = layer.stride[0]
+
+ if isinstance(layer, nn.Conv1d):
+ L = ((L - d * (k - 1) - 1) / s) + 1
+ elif isinstance(layer, nn.ConvTranspose1d):
+ L = (L - 1) * s + d * (k - 1) + 1
+
+ L = math.floor(L)
+ return L
+
+ @torch.no_grad()
+ def compress(
+ self,
+ audio_path_or_signal: Union[str, Path, AudioSignal],
+ win_duration: float = 1.0,
+ verbose: bool = False,
+ normalize_db: float = -16,
+ n_quantizers: int = None,
+ ) -> DACFile:
+ """Processes an audio signal from a file or AudioSignal object into
+ discrete codes. This function processes the signal in short windows,
+ using constant GPU memory.
+
+ Parameters
+ ----------
+ audio_path_or_signal : Union[str, Path, AudioSignal]
+ audio signal to reconstruct
+ win_duration : float, optional
+ window duration in seconds, by default 5.0
+ verbose : bool, optional
+ by default False
+ normalize_db : float, optional
+ normalize db, by default -16
+
+ Returns
+ -------
+ DACFile
+ Object containing compressed codes and metadata
+ required for decompression
+ """
+ audio_signal = audio_path_or_signal
+ if isinstance(audio_signal, (str, Path)):
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+ self.eval()
+ original_padding = self.padding
+ original_device = audio_signal.device
+
+ audio_signal = audio_signal.clone()
+ original_sr = audio_signal.sample_rate
+
+ resample_fn = audio_signal.resample
+ loudness_fn = audio_signal.loudness
+
+ # If audio is > 10 minutes long, use the ffmpeg versions
+ if audio_signal.signal_duration >= 10 * 60 * 60:
+ resample_fn = audio_signal.ffmpeg_resample
+ loudness_fn = audio_signal.ffmpeg_loudness
+
+ original_length = audio_signal.signal_length
+ resample_fn(self.sample_rate)
+ input_db = loudness_fn()
+
+ if normalize_db is not None:
+ audio_signal.normalize(normalize_db)
+ audio_signal.ensure_max_of_audio()
+
+ nb, nac, nt = audio_signal.audio_data.shape
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+ win_duration = (
+ audio_signal.signal_duration if win_duration is None else win_duration
+ )
+
+ if audio_signal.signal_duration <= win_duration:
+ # Unchunked compression (used if signal length < win duration)
+ self.padding = True
+ n_samples = nt
+ hop = nt
+ else:
+ # Chunked inference
+ self.padding = False
+ # Zero-pad signal on either side by the delay
+ audio_signal.zero_pad(self.delay, self.delay)
+ n_samples = int(win_duration * self.sample_rate)
+ # Round n_samples to nearest hop length multiple
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+ hop = self.get_output_length(n_samples)
+
+ codes = []
+ range_fn = range if not verbose else tqdm.trange
+
+ for i in range_fn(0, nt, hop):
+ x = audio_signal[..., i : i + n_samples]
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+ audio_data = x.audio_data.to(self.device)
+ audio_data = self.preprocess(audio_data, self.sample_rate)
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+ codes.append(c.to(original_device))
+ chunk_length = c.shape[-1]
+
+ codes = torch.cat(codes, dim=-1)
+
+ dac_file = DACFile(
+ codes=codes,
+ chunk_length=chunk_length,
+ original_length=original_length,
+ input_db=input_db,
+ channels=nac,
+ sample_rate=original_sr,
+ padding=self.padding,
+ dac_version=SUPPORTED_VERSIONS[-1],
+ )
+
+ if n_quantizers is not None:
+ codes = codes[:, :n_quantizers, :]
+
+ self.padding = original_padding
+ return dac_file
+
+ @torch.no_grad()
+ def decompress(
+ self,
+ obj: Union[str, Path, DACFile],
+ verbose: bool = False,
+ ) -> AudioSignal:
+ """Reconstruct audio from a given .dac file
+
+ Parameters
+ ----------
+ obj : Union[str, Path, DACFile]
+ .dac file location or corresponding DACFile object.
+ verbose : bool, optional
+ Prints progress if True, by default False
+
+ Returns
+ -------
+ AudioSignal
+ Object with the reconstructed audio
+ """
+ self.eval()
+ if isinstance(obj, (str, Path)):
+ obj = DACFile.load(obj)
+
+ original_padding = self.padding
+ self.padding = obj.padding
+
+ range_fn = range if not verbose else tqdm.trange
+ codes = obj.codes
+ original_device = codes.device
+ chunk_length = obj.chunk_length
+ recons = []
+
+ for i in range_fn(0, codes.shape[-1], chunk_length):
+ c = codes[..., i : i + chunk_length].to(self.device)
+ z = self.quantizer.from_codes(c)[0]
+ r = self.decode(z)
+ recons.append(r.to(original_device))
+
+ recons = torch.cat(recons, dim=-1)
+ recons = AudioSignal(recons, self.sample_rate)
+
+ resample_fn = recons.resample
+ loudness_fn = recons.loudness
+
+ # If audio is > 10 minutes long, use the ffmpeg versions
+ if recons.signal_duration >= 10 * 60 * 60:
+ resample_fn = recons.ffmpeg_resample
+ loudness_fn = recons.ffmpeg_loudness
+
+ recons.normalize(obj.input_db)
+ resample_fn(obj.sample_rate)
+ recons = recons[..., : obj.original_length]
+ loudness_fn()
+ recons.audio_data = recons.audio_data.reshape(
+ -1, obj.channels, obj.original_length
+ )
+
+ self.padding = original_padding
+ return recons
diff --git a/modules/dac/model/dac.py b/modules/dac/model/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f49ee7f201bd7baedc4d557754b5c508f68a73
--- /dev/null
+++ b/modules/dac/model/dac.py
@@ -0,0 +1,439 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/model/dac.py
+
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from ..nn.layers import Snake1d
+from ..nn.layers import WNConv1d
+from ..nn.layers import WNConvTranspose1d
+from ..nn.quantize import ResidualVectorQuantize
+from .encodec import SConv1d, SConvTranspose1d, SLSTM
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
+ super().__init__()
+ conv1d_type = SConv1d # if causal else WNConv1d
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Snake1d(dim),
+ conv1d_type(
+ dim,
+ dim,
+ kernel_size=7,
+ dilation=dilation,
+ padding=pad,
+ causal=causal,
+ norm="weight_norm",
+ ),
+ Snake1d(dim),
+ conv1d_type(dim, dim, kernel_size=1, causal=causal, norm="weight_norm"),
+ )
+
+ def forward(self, x):
+ y = self.block(x)
+ pad = (x.shape[-1] - y.shape[-1]) // 2
+ if pad > 0:
+ x = x[..., pad:-pad]
+ return x + y
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
+ super().__init__()
+ conv1d_type = SConv1d # if causal else WNConv1d
+ self.block = nn.Sequential(
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
+ Snake1d(dim // 2),
+ conv1d_type(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ causal=causal,
+ norm="weight_norm",
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ d_model: int = 64,
+ strides: list = [2, 4, 8, 8],
+ d_latent: int = 64,
+ causal: bool = False,
+ lstm: int = 2,
+ ):
+ super().__init__()
+ conv1d_type = SConv1d # if causal else WNConv1d
+ # Create first convolution
+ self.block = [
+ conv1d_type(
+ 1, d_model, kernel_size=7, padding=3, causal=causal, norm="weight_norm"
+ )
+ ]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride in strides:
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
+
+ # Add LSTM if needed
+ self.use_lstm = lstm
+ if lstm:
+ self.block += [SLSTM(d_model, lstm)]
+
+ # Create last convolution
+ self.block += [
+ Snake1d(d_model),
+ conv1d_type(
+ d_model,
+ d_latent,
+ kernel_size=3,
+ padding=1,
+ causal=causal,
+ norm="weight_norm",
+ ),
+ ]
+
+ # Wrap black into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 16,
+ output_dim: int = 8,
+ stride: int = 1,
+ causal: bool = False,
+ ):
+ super().__init__()
+ conv1d_type = SConvTranspose1d # if causal else WNConvTranspose1d
+ self.block = nn.Sequential(
+ Snake1d(input_dim),
+ conv1d_type(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ causal=causal,
+ norm="weight_norm",
+ ),
+ ResidualUnit(output_dim, dilation=1, causal=causal),
+ ResidualUnit(output_dim, dilation=3, causal=causal),
+ ResidualUnit(output_dim, dilation=9, causal=causal),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ input_channel,
+ channels,
+ rates,
+ d_out: int = 1,
+ causal: bool = False,
+ lstm: int = 2,
+ ):
+ super().__init__()
+ conv1d_type = SConv1d # if causal else WNConv1d
+ # Add first conv layer
+ layers = [
+ conv1d_type(
+ input_channel,
+ channels,
+ kernel_size=7,
+ padding=3,
+ causal=causal,
+ norm="weight_norm",
+ )
+ ]
+
+ if lstm:
+ layers += [SLSTM(channels, num_layers=lstm)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(rates):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
+
+ # Add final conv layer
+ layers += [
+ Snake1d(output_dim),
+ conv1d_type(
+ output_dim,
+ d_out,
+ kernel_size=7,
+ padding=3,
+ causal=causal,
+ norm="weight_norm",
+ ),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+ def __init__(
+ self,
+ encoder_dim: int = 64,
+ encoder_rates: List[int] = [2, 4, 8, 8],
+ latent_dim: int = None,
+ decoder_dim: int = 1536,
+ decoder_rates: List[int] = [8, 8, 4, 2],
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: bool = False,
+ sample_rate: int = 44100,
+ lstm: int = 2,
+ causal: bool = False,
+ ):
+ super().__init__()
+
+ self.encoder_dim = encoder_dim
+ self.encoder_rates = encoder_rates
+ self.decoder_dim = decoder_dim
+ self.decoder_rates = decoder_rates
+ self.sample_rate = sample_rate
+
+ if latent_dim is None:
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+ self.latent_dim = latent_dim
+
+ self.hop_length = np.prod(encoder_rates)
+ self.encoder = Encoder(
+ encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm
+ )
+
+ self.n_codebooks = n_codebooks
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.quantizer = ResidualVectorQuantize(
+ input_dim=latent_dim,
+ n_codebooks=n_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+
+ self.decoder = Decoder(
+ latent_dim,
+ decoder_dim,
+ decoder_rates,
+ lstm=lstm,
+ causal=causal,
+ )
+ self.sample_rate = sample_rate
+ self.apply(init_weights)
+
+ self.delay = self.get_delay()
+
+ def preprocess(self, audio_data, sample_rate):
+ if sample_rate is None:
+ sample_rate = self.sample_rate
+ assert sample_rate == self.sample_rate
+
+ length = audio_data.shape[-1]
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+ return audio_data
+
+ def encode(
+ self,
+ audio_data: torch.Tensor,
+ n_quantizers: int = None,
+ ):
+ """Encode given audio data and return quantized latent codes
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None
+ If None, all quantizers are used.
+
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ "length" : int
+ Number of samples in input audio
+ """
+ z = self.encoder(audio_data)
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
+ z, n_quantizers
+ )
+ return z, codes, latents, commitment_loss, codebook_loss
+
+ def decode(self, z: torch.Tensor):
+ """Decode given latent codes and return audio data
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ Quantized continuous representation of input
+ length : int, optional
+ Number of samples in output audio, by default None
+
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+ "audio" : Tensor[B x 1 x length]
+ Decoded audio data.
+ """
+ return self.decoder(z)
+
+ def forward(
+ self,
+ audio_data: torch.Tensor,
+ sample_rate: int = None,
+ n_quantizers: int = None,
+ ):
+ """Model forward pass
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ sample_rate : int, optional
+ Sample rate of audio data in Hz, by default None
+ If None, defaults to `self.sample_rate`
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None.
+ If None, all quantizers are used.
+
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ "length" : int
+ Number of samples in input audio
+ "audio" : Tensor[B x 1 x length]
+ Decoded audio data.
+ """
+ length = audio_data.shape[-1]
+ audio_data = self.preprocess(audio_data, sample_rate)
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
+ audio_data, n_quantizers
+ )
+
+ x = self.decode(z)
+ return {
+ "audio": x[..., :length],
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+
+
+if __name__ == "__main__":
+ import numpy as np
+ from functools import partial
+
+ model = DAC().to("cpu")
+
+ for n, m in model.named_modules():
+ o = m.extra_repr()
+ p = sum([np.prod(p.size()) for p in m.parameters()])
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
+ print(model)
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+ length = 88200 * 2
+ x = torch.randn(1, 1, length).to(model.device)
+ x.requires_grad_(True)
+ x.retain_grad()
+
+ # Make a forward pass
+ out = model(x)["audio"]
+ print("Input shape:", x.shape)
+ print("Output shape:", out.shape)
+
+ # Create gradient variable
+ grad = torch.zeros_like(out)
+ grad[:, :, grad.shape[-1] // 2] = 1
+
+ # Make a backward pass
+ out.backward(grad)
+
+ # Check non-zero values
+ gradmap = x.grad.squeeze(0)
+ gradmap = (gradmap != 0).sum(0) # sum across features
+ rf = (gradmap != 0).sum()
+
+ print(f"Receptive field: {rf.item()}")
+
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
+ model.decompress(model.compress(x, verbose=True), verbose=True)
diff --git a/modules/dac/model/discriminator.py b/modules/dac/model/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..de5d63b28e6199046fdd2c2dad5bc02a446c05d1
--- /dev/null
+++ b/modules/dac/model/discriminator.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/model/discriminator.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import ml
+from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+class MPD(nn.Module):
+ def __init__(self, period):
+ super().__init__()
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+ ]
+ )
+ self.conv_post = WNConv2d(
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+ )
+
+ def pad_to_period(self, x):
+ t = x.shape[-1]
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+ return x
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.pad_to_period(x)
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+ for layer in self.convs:
+ x = layer(x)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class MSD(nn.Module):
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ WNConv1d(1, 16, 15, 1, padding=7),
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+ WNConv1d(1024, 1024, 5, 1, padding=2),
+ ]
+ )
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+ self.sample_rate = sample_rate
+ self.rate = rate
+
+ def forward(self, x):
+ x = AudioSignal(x, self.sample_rate)
+ x.resample(self.sample_rate // self.rate)
+ x = x.audio_data
+
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ hop_factor: float = 0.25,
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Complex multi-band spectrogram discriminator.
+ Parameters
+ ----------
+ window_length : int
+ Window length of STFT.
+ hop_factor : float, optional
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run discriminator over.
+ """
+ super().__init__()
+
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.sample_rate = sample_rate
+ self.stft_params = STFTParams(
+ window_length=window_length,
+ hop_length=int(window_length * hop_factor),
+ match_stride=True,
+ )
+
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
+
+ ch = 32
+ convs = lambda: nn.ModuleList(
+ [
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+ def spectrogram(self, x):
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+ x = torch.view_as_real(x.stft())
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+ # Split into bands
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+ return x_bands
+
+ def forward(self, x):
+ x_bands = self.spectrogram(x)
+ fmap = []
+
+ x = []
+ for band, stack in zip(x_bands, self.band_convs):
+ for layer in stack:
+ band = layer(band)
+ fmap.append(band)
+ x.append(band)
+
+ x = torch.cat(x, dim=-1)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class Discriminator(nn.Module):
+ def __init__(
+ self,
+ rates: list = [],
+ periods: list = [2, 3, 5, 7, 11],
+ fft_sizes: list = [2048, 1024, 512],
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Discriminator that combines multiple discriminators.
+
+ Parameters
+ ----------
+ rates : list, optional
+ sampling rates (in Hz) to run MSD at, by default []
+ If empty, MSD is not used.
+ periods : list, optional
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+ fft_sizes : list, optional
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run MRD at, by default `BANDS`
+ """
+ super().__init__()
+ discs = []
+ discs += [MPD(p) for p in periods]
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+ self.discriminators = nn.ModuleList(discs)
+
+ def preprocess(self, y):
+ # Remove DC offset
+ y = y - y.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ return y
+
+ def forward(self, x):
+ x = self.preprocess(x)
+ fmaps = [d(x) for d in self.discriminators]
+ return fmaps
+
+
+if __name__ == "__main__":
+ disc = Discriminator()
+ x = torch.zeros(1, 1, 44100)
+ results = disc(x)
+ for i, result in enumerate(results):
+ print(f"disc{i}")
+ for i, r in enumerate(result):
+ print(r.shape, r.mean(), r.min(), r.max())
+ print()
diff --git a/modules/dac/model/encodec.py b/modules/dac/model/encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..08bcc84ce328d309b7592426fade27a1dcaed43f
--- /dev/null
+++ b/modules/dac/model/encodec.py
@@ -0,0 +1,390 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/facebookresearch/encodec/blob/main/encodec/modules/conv.py
+
+"""Convolutional layers wrappers and utilities."""
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+import typing as tp
+
+import einops
+
+
+class ConvLayerNorm(nn.LayerNorm):
+ """
+ Convolution-friendly LayerNorm that moves channels to last dimensions
+ before running the normalization and moves them back to original position right after.
+ """
+
+ def __init__(
+ self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
+ ):
+ super().__init__(normalized_shape, **kwargs)
+
+ def forward(self, x):
+ x = einops.rearrange(x, "b ... t -> b t ...")
+ x = super().forward(x)
+ x = einops.rearrange(x, "b t ... -> b ... t")
+ return
+
+
+CONV_NORMALIZATIONS = frozenset(
+ [
+ "none",
+ "weight_norm",
+ "spectral_norm",
+ "time_layer_norm",
+ "layer_norm",
+ "time_group_norm",
+ ]
+)
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
+ assert norm in CONV_NORMALIZATIONS
+ if norm == "weight_norm":
+ return weight_norm(module)
+ elif norm == "spectral_norm":
+ return spectral_norm(module)
+ else:
+ # We already check was in CONV_NORMALIZATION, so any other choice
+ # doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
+) -> nn.Module:
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == "layer_norm":
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
+ elif norm == "time_group_norm":
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once you removed padding, we are missing one time step !
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(
+ x: torch.Tensor,
+ paddings: tp.Tuple[int, int],
+ mode: str = "zero",
+ value: float = 0.0,
+):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == "reflect":
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left:end]
+
+
+class NormConv1d(nn.Module):
+ """Wrapper around Conv1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ causal: bool = False,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose1d(nn.Module):
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ causal: bool = False,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(
+ nn.ConvTranspose1d(*args, **kwargs), norm
+ )
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose2d(nn.Module):
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(
+ nn.ConvTranspose2d(*args, **kwargs), norm
+ )
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class SConv1d(nn.Module):
+ """Conv1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ causal: bool = False,
+ norm: str = "none",
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ pad_mode: str = "reflect",
+ **kwargs,
+ ):
+ super().__init__()
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ warnings.warn(
+ "SConv1d has been initialized with stride > 1 and dilation > 1"
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+ )
+ self.conv = NormConv1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ causal=causal,
+ norm=norm,
+ norm_kwargs=norm_kwargs,
+ )
+ self.causal = causal
+ self.pad_mode = pad_mode
+
+ def forward(self, x):
+ B, C, T = x.shape
+ kernel_size = self.conv.conv.kernel_size[0]
+ stride = self.conv.conv.stride[0]
+ dilation = self.conv.conv.dilation[0]
+ kernel_size = (
+ kernel_size - 1
+ ) * dilation + 1 # effective kernel size with dilations
+ padding_total = kernel_size - stride
+ extra_padding = get_extra_padding_for_conv1d(
+ x, kernel_size, stride, padding_total
+ )
+ if self.causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(
+ x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+ )
+ return self.conv(x)
+
+
+class SConvTranspose1d(nn.Module):
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ causal: bool = False,
+ norm: str = "none",
+ trim_right_ratio: float = 1.0,
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
+ **kwargs,
+ ):
+ super().__init__()
+ self.convtr = NormConvTranspose1d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ causal=causal,
+ norm=norm,
+ norm_kwargs=norm_kwargs,
+ )
+ self.causal = causal
+ self.trim_right_ratio = trim_right_ratio
+ assert (
+ self.causal or self.trim_right_ratio == 1.0
+ ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+ assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
+
+ def forward(self, x):
+ kernel_size = self.convtr.convtr.kernel_size[0]
+ stride = self.convtr.convtr.stride[0]
+ padding_total = kernel_size - stride
+
+ y = self.convtr(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
+
+
+class SLSTM(nn.Module):
+ """
+ LSTM without worrying about the hidden state, nor the layout of the data.
+ Expects input as convolutional layout.
+ """
+
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+ super().__init__()
+ self.skip = skip
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
+ self.hidden = None
+
+ def forward(self, x):
+ x = x.permute(2, 0, 1)
+ if self.training:
+ y, _ = self.lstm(x)
+ else:
+ y, self.hidden = self.lstm(x, self.hidden)
+ if self.skip:
+ y = y + x
+ y = y.permute(1, 2, 0)
+ return y
diff --git a/modules/dac/nn/__init__.py b/modules/dac/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c0311e24cb8cde190033047e2feb6718a2194e
--- /dev/null
+++ b/modules/dac/nn/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is borrowed from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/__init__.py
+
+from . import layers
+from . import loss
+from . import quantize
diff --git a/modules/dac/nn/layers.py b/modules/dac/nn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5ea229f0367057c154fdaf0905ec7ae284d3e7
--- /dev/null
+++ b/modules/dac/nn/layers.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
diff --git a/modules/dac/nn/loss.py b/modules/dac/nn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..06b564fedd7cf4d74f3de1b7d163307d8b8f99bc
--- /dev/null
+++ b/modules/dac/nn/loss.py
@@ -0,0 +1,389 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py
+
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+ """L1 Loss between AudioSignals. Defaults
+ to comparing ``audio_data``, but any
+ attribute of an AudioSignal can be used.
+
+ Parameters
+ ----------
+ attribute : str, optional
+ Attribute of signal to compare, defaults to ``audio_data``.
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+ self.attribute = attribute
+ self.weight = weight
+ super().__init__(**kwargs)
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate AudioSignal
+ y : AudioSignal
+ Reference AudioSignal
+
+ Returns
+ -------
+ torch.Tensor
+ L1 loss between AudioSignal attributes.
+ """
+ if isinstance(x, AudioSignal):
+ x = getattr(x, self.attribute)
+ y = getattr(y, self.attribute)
+ return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+ """
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+ of estimated and reference audio signals or aligned features.
+
+ Parameters
+ ----------
+ scaling : int, optional
+ Whether to use scale-invariant (True) or
+ signal-to-noise ratio (False), by default True
+ reduction : str, optional
+ How to reduce across the batch (either 'mean',
+ 'sum', or none).], by default ' mean'
+ zero_mean : int, optional
+ Zero mean the references and estimates before
+ computing the loss, by default True
+ clip_min : int, optional
+ The minimum possible loss value. Helps network
+ to not focus on making already good examples better, by default None
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(
+ self,
+ scaling: int = True,
+ reduction: str = "mean",
+ zero_mean: int = True,
+ clip_min: int = None,
+ weight: float = 1.0,
+ ):
+ self.scaling = scaling
+ self.reduction = reduction
+ self.zero_mean = zero_mean
+ self.clip_min = clip_min
+ self.weight = weight
+ super().__init__()
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ eps = 1e-8
+ # nb, nc, nt
+ if isinstance(x, AudioSignal):
+ references = x.audio_data
+ estimates = y.audio_data
+ else:
+ references = x
+ estimates = y
+
+ nb = references.shape[0]
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+ # samples now on axis 1
+ if self.zero_mean:
+ mean_reference = references.mean(dim=1, keepdim=True)
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
+ else:
+ mean_reference = 0
+ mean_estimate = 0
+
+ _references = references - mean_reference
+ _estimates = estimates - mean_estimate
+
+ references_projection = (_references**2).sum(dim=-2) + eps
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+ scale = (
+ (references_on_estimates / references_projection).unsqueeze(1)
+ if self.scaling
+ else 1
+ )
+
+ e_true = scale * _references
+ e_res = _estimates - e_true
+
+ signal = (e_true**2).sum(dim=1)
+ noise = (e_res**2).sum(dim=1)
+ sdr = -10 * torch.log10(signal / noise + eps)
+
+ if self.clip_min is not None:
+ sdr = torch.clamp(sdr, min=self.clip_min)
+
+ if self.reduction == "mean":
+ sdr = sdr.mean()
+ elif self.reduction == "sum":
+ sdr = sdr.sum()
+ return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+ """Computes the multi-scale STFT loss from [1].
+
+ Parameters
+ ----------
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Clamp on the log magnitude, below, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ References
+ ----------
+
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+ "DDSP: Differentiable Digital Signal Processing."
+ International Conference on Learning Representations. 2019.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.loss_fn = loss_fn
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.clamp_eps = clamp_eps
+ self.weight = weight
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes multi-scale STFT between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Multi-scale STFT loss.
+ """
+ loss = 0.0
+ for s in self.stft_params:
+ x.stft(s.window_length, s.hop_length, s.window_type)
+ y.stft(s.window_length, s.hop_length, s.window_type)
+ loss += self.log_weight * self.loss_fn(
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+ return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+ """Compute distance between mel spectrograms. Can be used
+ in a multi-scale way.
+
+ Parameters
+ ----------
+ n_mels : List[int]
+ Number of mels per STFT, by default [150, 80],
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Clamp on the log magnitude, below, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ n_mels: List[int] = [150, 80],
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ mel_fmin: List[float] = [0.0, 0.0],
+ mel_fmax: List[float] = [None, None],
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.n_mels = n_mels
+ self.loss_fn = loss_fn
+ self.clamp_eps = clamp_eps
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.weight = weight
+ self.mel_fmin = mel_fmin
+ self.mel_fmax = mel_fmax
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes mel loss between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Mel loss.
+ """
+ loss = 0.0
+ for n_mels, fmin, fmax, s in zip(
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+ ):
+ kwargs = {
+ "window_length": s.window_length,
+ "hop_length": s.hop_length,
+ "window_type": s.window_type,
+ }
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+ loss += self.log_weight * self.loss_fn(
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+ return loss
+
+
+class FocalLoss(torch.nn.Module):
+ def __init__(self, gamma=0, eps=1e-7):
+ super(FocalLoss, self).__init__()
+ self.gamma = gamma
+ self.eps = eps
+ self.ce = torch.nn.CrossEntropyLoss()
+
+ def forward(self, input, target):
+ logp = self.ce(input, target)
+ p = torch.exp(-logp)
+ loss = (1 - p) ** self.gamma * logp
+ return loss.mean()
+
+
+class GANLoss(nn.Module):
+ """
+ Computes a discriminator loss, given a discriminator on
+ generated waveforms/spectrograms compared to ground truth
+ waveforms/spectrograms. Computes the loss for both the
+ discriminator and the generator in separate functions.
+ """
+
+ def __init__(self, discriminator):
+ super().__init__()
+ self.discriminator = discriminator
+
+ def forward(self, fake, real):
+ d_fake = self.discriminator(fake.audio_data)
+ d_real = self.discriminator(real.audio_data)
+ return d_fake, d_real
+
+ def discriminator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+ return loss_d
+
+ def generator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake, real)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+ return loss_g, loss_feature
diff --git a/modules/dac/nn/quantize.py b/modules/dac/nn/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0d88d5ea85e89f2fe5c0b92283405caa4c8f998
--- /dev/null
+++ b/modules/dac/nn/quantize.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/quantize.py
+
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from ..nn.layers import WNConv1d
+
+
+class VectorQuantize(nn.Module):
+ """
+ Implementation of VQ similar to Karpathy's repo:
+ https://github.com/karpathy/deep-vector-quantization
+ Additionally uses following tricks from Improved VQGAN
+ (https://arxiv.org/pdf/2110.04627.pdf):
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+ for improved codebook usage
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+ improves training stability
+ """
+
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+ def forward(self, z):
+ """Quantized the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+ z_e = self.in_proj(z) # z_e : (B x D x T)
+ z_q, indices = self.decode_latents(z_e)
+
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = self.out_proj(z_q)
+
+ return z_q, commitment_loss, codebook_loss, indices, z_e
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+
+ # L2 normalize encodings and codebook (ViT-VQGAN)
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+ return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+ """
+ Introduced in SoundStream: An end2end neural audio codec
+ https://arxiv.org/abs/2107.03312
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: float = 0.0,
+ ):
+ super().__init__()
+ if isinstance(codebook_dim, int):
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+ self.n_codebooks = n_codebooks
+ self.codebook_dim = codebook_dim
+ self.codebook_size = codebook_size
+
+ self.quantizers = nn.ModuleList(
+ [
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+ for i in range(n_codebooks)
+ ]
+ )
+ self.quantizer_dropout = quantizer_dropout
+
+ def forward(self, z, n_quantizers: int = None):
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
+ the corresponding codebook vectors
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ n_quantizers : int, optional
+ No. of quantizers to use
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
+ when in training mode, and a random number of quantizers is used.
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ """
+ z_q = 0
+ residual = z
+ commitment_loss = 0
+ codebook_loss = 0
+
+ codebook_indices = []
+ latents = []
+
+ if n_quantizers is None:
+ n_quantizers = self.n_codebooks
+ if self.training:
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(z.device)
+
+ for i, quantizer in enumerate(self.quantizers):
+ if self.training is False and i >= n_quantizers:
+ break
+
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+ residual
+ )
+
+ # Create mask to apply quantizer dropout
+ mask = (
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+ )
+ z_q = z_q + z_q_i * mask[:, None, None]
+ residual = residual - z_q_i
+
+ # Sum losses
+ commitment_loss += (commitment_loss_i * mask).mean()
+ codebook_loss += (codebook_loss_i * mask).mean()
+
+ codebook_indices.append(indices_i)
+ latents.append(z_e_i)
+
+ codes = torch.stack(codebook_indices, dim=1)
+ latents = torch.cat(latents, dim=1)
+
+ return z_q, codes, latents, commitment_loss, codebook_loss
+
+ def from_codes(self, codes: torch.Tensor):
+ """Given the quantized codes, reconstruct the continuous representation
+ Parameters
+ ----------
+ codes : Tensor[B x N x T]
+ Quantized discrete representation of input
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ """
+ z_q = 0.0
+ z_p = []
+ n_codebooks = codes.shape[1]
+ for i in range(n_codebooks):
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+ z_p.append(z_p_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+ return z_q, torch.cat(z_p, dim=1), codes
+
+ def from_latents(self, latents: torch.Tensor):
+ """Given the unquantized latents, reconstruct the
+ continuous representation after quantization.
+
+ Parameters
+ ----------
+ latents : Tensor[B x N x T]
+ Continuous representation of input after projection
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized representation of full-projected space
+ Tensor[B x D x T]
+ Quantized representation of latent space
+ """
+ z_q = 0
+ z_p = []
+ codes = []
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+ 0
+ ]
+ for i in range(n_codebooks):
+ j, k = dims[i], dims[i + 1]
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+ z_p.append(z_p_i)
+ codes.append(codes_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
+ x = torch.randn(16, 512, 80)
+ y = rvq(x)
+ print(y["latents"].shape)
diff --git a/modules/diffusion/__init__.py b/modules/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf1e4f9e71ef3b0ef2d677a9c0db8e03b022eaea
--- /dev/null
+++ b/modules/diffusion/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .bidilconv.bidilated_conv import BiDilConv
+from .unet.unet import UNet
diff --git a/modules/diffusion/bidilconv/bidilated_conv.py b/modules/diffusion/bidilconv/bidilated_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac75522a3da8a40ae1b76656dd2a78097d6d94f3
--- /dev/null
+++ b/modules/diffusion/bidilconv/bidilated_conv.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch.nn as nn
+
+from modules.general.utils import Conv1d, zero_module
+from .residual_block import ResidualBlock
+
+
+class BiDilConv(nn.Module):
+ r"""Dilated CNN architecture with residual connections, default diffusion decoder.
+
+ Args:
+ input_channel: The number of input channels.
+ base_channel: The number of base channels.
+ n_res_block: The number of residual blocks.
+ conv_kernel_size: The kernel size of convolutional layers.
+ dilation_cycle_length: The cycle length of dilation.
+ conditioner_size: The size of conditioner.
+ """
+
+ def __init__(
+ self,
+ input_channel,
+ base_channel,
+ n_res_block,
+ conv_kernel_size,
+ dilation_cycle_length,
+ conditioner_size,
+ output_channel: int = -1,
+ ):
+ super().__init__()
+
+ self.input_channel = input_channel
+ self.base_channel = base_channel
+ self.n_res_block = n_res_block
+ self.conv_kernel_size = conv_kernel_size
+ self.dilation_cycle_length = dilation_cycle_length
+ self.conditioner_size = conditioner_size
+ self.output_channel = output_channel if output_channel > 0 else input_channel
+
+ self.input = nn.Sequential(
+ Conv1d(
+ input_channel,
+ base_channel,
+ 1,
+ ),
+ nn.ReLU(),
+ )
+
+ self.residual_blocks = nn.ModuleList(
+ [
+ ResidualBlock(
+ channels=base_channel,
+ kernel_size=conv_kernel_size,
+ dilation=2 ** (i % dilation_cycle_length),
+ d_context=conditioner_size,
+ )
+ for i in range(n_res_block)
+ ]
+ )
+
+ self.out_proj = nn.Sequential(
+ Conv1d(
+ base_channel,
+ base_channel,
+ 1,
+ ),
+ nn.ReLU(),
+ zero_module(
+ Conv1d(
+ base_channel,
+ self.output_channel,
+ 1,
+ ),
+ ),
+ )
+
+ def forward(self, x, y, context=None):
+ """
+ Args:
+ x: Noisy mel-spectrogram [B x ``n_mel`` x L]
+ y: FILM embeddings with the shape of (B, ``base_channel``)
+ context: Context with the shape of [B x ``d_context`` x L], default to None.
+ """
+
+ h = self.input(x)
+
+ skip = None
+ for i in range(self.n_res_block):
+ h, skip_connection = self.residual_blocks[i](h, y, context)
+ skip = skip_connection if skip is None else skip_connection + skip
+
+ out = skip / math.sqrt(self.n_res_block)
+
+ out = self.out_proj(out)
+
+ return out
diff --git a/modules/diffusion/bidilconv/residual_block.py b/modules/diffusion/bidilconv/residual_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..df75d4fcb307af58fafe7258ba6e4a7a8cedd2f6
--- /dev/null
+++ b/modules/diffusion/bidilconv/residual_block.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+
+from modules.activation_functions import GaU
+from modules.general.utils import Conv1d
+
+
+class ResidualBlock(nn.Module):
+ r"""Residual block with dilated convolution, main portion of ``BiDilConv``.
+
+ Args:
+ channels: The number of channels of input and output.
+ kernel_size: The kernel size of dilated convolution.
+ dilation: The dilation rate of dilated convolution.
+ d_context: The dimension of content encoder output, None if don't use context.
+ """
+
+ def __init__(
+ self,
+ channels: int = 256,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ d_context: int = None,
+ ):
+ super().__init__()
+
+ self.context = d_context
+
+ self.gau = GaU(
+ channels,
+ kernel_size,
+ dilation,
+ d_context,
+ )
+
+ self.out_proj = Conv1d(
+ channels,
+ channels * 2,
+ 1,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ y_emb: torch.Tensor,
+ context: torch.Tensor = None,
+ ):
+ """
+ Args:
+ x: Latent representation inherited from previous residual block
+ with the shape of [B x C x T].
+ y_emb: Embeddings with the shape of [B x C], which will be FILM on the x.
+ context: Context with the shape of [B x ``d_context`` x T], default to None.
+ """
+
+ h = x + y_emb[..., None]
+
+ if self.context:
+ h = self.gau(h, context)
+ else:
+ h = self.gau(h)
+
+ h = self.out_proj(h)
+ res, skip = h.chunk(2, 1)
+
+ return (res + x) / math.sqrt(2.0), skip
diff --git a/modules/diffusion/karras/karras_diffusion.py b/modules/diffusion/karras/karras_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e3664967a673e4c80f32a5fa1220126f994d07e
--- /dev/null
+++ b/modules/diffusion/karras/karras_diffusion.py
@@ -0,0 +1,977 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Based on: https://github.com/crowsonkb/k-diffusion
+"""
+import random
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+# from piq import LPIPS
+from utils.ssim import SSIM
+
+from modules.diffusion.karras.random_utils import get_generator
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def append_zero(x):
+ return th.cat([x, x.new_zeros([1])])
+
+
+def get_weightings(weight_schedule, snrs, sigma_data):
+ if weight_schedule == "snr":
+ weightings = snrs
+ elif weight_schedule == "snr+1":
+ weightings = snrs + 1
+ elif weight_schedule == "karras":
+ weightings = snrs + 1.0 / sigma_data**2
+ elif weight_schedule == "truncated-snr":
+ weightings = th.clamp(snrs, min=1.0)
+ elif weight_schedule == "uniform":
+ weightings = th.ones_like(snrs)
+ else:
+ raise NotImplementedError()
+ return weightings
+
+
+class KarrasDenoiser:
+ def __init__(
+ self,
+ sigma_data: float = 0.5,
+ sigma_max=80.0,
+ sigma_min=0.002,
+ rho=7.0,
+ weight_schedule="karras",
+ distillation=False,
+ loss_norm="l2",
+ ):
+ self.sigma_data = sigma_data
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.weight_schedule = weight_schedule
+ self.distillation = distillation
+ self.loss_norm = loss_norm
+ # if loss_norm == "lpips":
+ # self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")
+ if loss_norm == "ssim":
+ self.ssim_loss = SSIM()
+ self.rho = rho
+ self.num_timesteps = 40
+
+ def get_snr(self, sigmas):
+ return sigmas**-2
+
+ def get_sigmas(self, sigmas):
+ return sigmas
+
+ def get_scalings(self, sigma):
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out, c_in
+
+ def get_scalings_for_boundary_condition(self, sigma):
+ c_skip = self.sigma_data**2 / (
+ (sigma - self.sigma_min) ** 2 + self.sigma_data**2
+ )
+ c_out = (
+ (sigma - self.sigma_min)
+ * self.sigma_data
+ / (sigma**2 + self.sigma_data**2) ** 0.5
+ )
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out, c_in
+
+ def training_losses(self, model, x_start, sigmas, condition=None, noise=None):
+ if noise is None:
+ noise = th.randn_like(x_start)
+
+ terms = {}
+
+ dims = x_start.ndim
+ x_t = x_start + noise * append_dims(sigmas, dims)
+ model_output, denoised = self.denoise(model, x_t, sigmas, condition)
+
+ snrs = self.get_snr(sigmas)
+ weights = append_dims(
+ get_weightings(self.weight_schedule, snrs, self.sigma_data), dims
+ )
+ # terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
+ terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2)
+ # terms["mae"] = mean_flat(weights * th.abs(denoised - x_start))
+ # terms["mse"] = nn.MSELoss(reduction="none")(denoised, x_start)
+
+ # if "vb" in terms:
+ # terms["loss"] = terms["mse"] + terms["vb"]
+ # else:
+ terms["loss"] = terms["mse"]
+
+ return terms
+
+ def consistency_losses(
+ self,
+ model,
+ x_start,
+ num_scales,
+ # model_kwargs=None,
+ condition=None,
+ target_model=None,
+ teacher_model=None,
+ teacher_diffusion=None,
+ noise=None,
+ ):
+ if noise is None:
+ noise = th.randn_like(x_start)
+
+ dims = x_start.ndim
+
+ def denoise_fn(x, t):
+ return self.denoise(model, x, t, condition)[1]
+
+ if target_model:
+
+ @th.no_grad()
+ def target_denoise_fn(x, t):
+ return self.denoise(target_model, x, t, condition)[1]
+
+ else:
+ raise NotImplementedError("Must have a target model")
+
+ if teacher_model:
+
+ @th.no_grad()
+ def teacher_denoise_fn(x, t):
+ return teacher_diffusion.denoise(teacher_model, x, t, condition)[1]
+
+ @th.no_grad()
+ def heun_solver(samples, t, next_t, x0):
+ x = samples
+ if teacher_model is None:
+ denoiser = x0
+ else:
+ denoiser = teacher_denoise_fn(x, t)
+
+ d = (x - denoiser) / append_dims(t, dims)
+ samples = x + d * append_dims(next_t - t, dims)
+ if teacher_model is None:
+ denoiser = x0
+ else:
+ denoiser = teacher_denoise_fn(samples, next_t)
+
+ next_d = (samples - denoiser) / append_dims(next_t, dims)
+ samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)
+
+ return samples
+
+ @th.no_grad()
+ def euler_solver(samples, t, next_t, x0):
+ x = samples
+ if teacher_model is None:
+ denoiser = x0
+ else:
+ denoiser = teacher_denoise_fn(x, t)
+ d = (x - denoiser) / append_dims(t, dims)
+ samples = x + d * append_dims(next_t - t, dims)
+
+ return samples
+
+ indices = th.randint(
+ 0, num_scales - 1, (x_start.shape[0],), device=x_start.device
+ )
+
+ t = self.sigma_max ** (1 / self.rho) + indices / (num_scales - 1) * (
+ self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
+ )
+ t = t**self.rho
+
+ t2 = self.sigma_max ** (1 / self.rho) + (indices + 1) / (num_scales - 1) * (
+ self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
+ )
+ t2 = t2**self.rho
+
+ x_t = x_start + noise * append_dims(t, dims)
+
+ dropout_state = th.get_rng_state()
+ distiller = denoise_fn(x_t, t)
+
+ if teacher_model is None:
+ x_t2 = euler_solver(x_t, t, t2, x_start).detach()
+ else:
+ x_t2 = heun_solver(x_t, t, t2, x_start).detach()
+
+ th.set_rng_state(dropout_state)
+ distiller_target = target_denoise_fn(x_t2, t2)
+ distiller_target = distiller_target.detach()
+
+ snrs = self.get_snr(t)
+ weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
+ if self.loss_norm == "l1":
+ diffs = th.abs(distiller - distiller_target)
+ loss = mean_flat(diffs) * weights
+ elif self.loss_norm == "l2":
+ # diffs = (distiller - distiller_target) ** 2
+ loss = F.mse_loss(distiller, distiller_target)
+ # loss = mean_flat(diffs) * weights
+ elif self.loss_norm == "ssim":
+ loss = self.ssim_loss(distiller, distiller_target) * weights
+ # elif self.loss_norm == "l2-32":
+ # distiller = F.interpolate(distiller, size=32, mode="bilinear")
+ # distiller_target = F.interpolate(
+ # distiller_target,
+ # size=32,
+ # mode="bilinear",
+ # )
+ # diffs = (distiller - distiller_target) ** 2
+ # loss = mean_flat(diffs) * weights
+ # elif self.loss_norm == "lpips":
+ # if x_start.shape[-1] < 256:
+ # distiller = F.interpolate(distiller, size=224, mode="bilinear")
+ # distiller_target = F.interpolate(
+ # distiller_target, size=224, mode="bilinear"
+ # )
+
+ # loss = (
+ # self.lpips_loss(
+ # (distiller + 1) / 2.0,
+ # (distiller_target + 1) / 2.0,
+ # )
+ # * weights
+ # )
+ else:
+ raise ValueError(f"Unknown loss norm {self.loss_norm}")
+
+ terms = {}
+ terms["loss"] = loss
+
+ return terms
+
+ # def progdist_losses(
+ # self,
+ # model,
+ # x_start,
+ # num_scales,
+ # model_kwargs=None,
+ # teacher_model=None,
+ # teacher_diffusion=None,
+ # noise=None,
+ # ):
+ # if model_kwargs is None:
+ # model_kwargs = {}
+ # if noise is None:
+ # noise = th.randn_like(x_start)
+
+ # dims = x_start.ndim
+
+ # def denoise_fn(x, t):
+ # return self.denoise(model, x, t, **model_kwargs)[1]
+
+ # @th.no_grad()
+ # def teacher_denoise_fn(x, t):
+ # return teacher_diffusion.denoise(teacher_model, x, t, **model_kwargs)[1]
+
+ # @th.no_grad()
+ # def euler_solver(samples, t, next_t):
+ # x = samples
+ # denoiser = teacher_denoise_fn(x, t)
+ # d = (x - denoiser) / append_dims(t, dims)
+ # samples = x + d * append_dims(next_t - t, dims)
+
+ # return samples
+
+ # @th.no_grad()
+ # def euler_to_denoiser(x_t, t, x_next_t, next_t):
+ # denoiser = x_t - append_dims(t, dims) * (x_next_t - x_t) / append_dims(
+ # next_t - t, dims
+ # )
+ # return denoiser
+
+ # indices = th.randint(0, num_scales, (x_start.shape[0],), device=x_start.device)
+
+ # t = self.sigma_max ** (1 / self.rho) + indices / num_scales * (
+ # self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
+ # )
+ # t = t**self.rho
+
+ # t2 = self.sigma_max ** (1 / self.rho) + (indices + 0.5) / num_scales * (
+ # self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
+ # )
+ # t2 = t2**self.rho
+
+ # t3 = self.sigma_max ** (1 / self.rho) + (indices + 1) / num_scales * (
+ # self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
+ # )
+ # t3 = t3**self.rho
+
+ # x_t = x_start + noise * append_dims(t, dims)
+
+ # denoised_x = denoise_fn(x_t, t)
+
+ # x_t2 = euler_solver(x_t, t, t2).detach()
+ # x_t3 = euler_solver(x_t2, t2, t3).detach()
+
+ # target_x = euler_to_denoiser(x_t, t, x_t3, t3).detach()
+
+ # snrs = self.get_snr(t)
+ # weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
+ # if self.loss_norm == "l1":
+ # diffs = th.abs(denoised_x - target_x)
+ # loss = mean_flat(diffs) * weights
+ # elif self.loss_norm == "l2":
+ # diffs = (denoised_x - target_x) ** 2
+ # loss = mean_flat(diffs) * weights
+ # elif self.loss_norm == "lpips":
+ # if x_start.shape[-1] < 256:
+ # denoised_x = F.interpolate(denoised_x, size=224, mode="bilinear")
+ # target_x = F.interpolate(target_x, size=224, mode="bilinear")
+ # loss = (
+ # self.lpips_loss(
+ # (denoised_x + 1) / 2.0,
+ # (target_x + 1) / 2.0,
+ # )
+ # * weights
+ # )
+ # else:
+ # raise ValueError(f"Unknown loss norm {self.loss_norm}")
+
+ # terms = {}
+ # terms["loss"] = loss
+
+ # return terms
+
+ def denoise(self, model, x_t, sigmas, condition):
+ if not self.distillation:
+ c_skip, c_out, c_in = [
+ append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
+ ]
+ else:
+ c_skip, c_out, c_in = [
+ append_dims(x, x_t.ndim)
+ for x in self.get_scalings_for_boundary_condition(sigmas)
+ ]
+ rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
+ # rescaled_t = rescaled_t[:, None]
+ model_output = model(c_in * x_t, rescaled_t, condition)
+ denoised = c_out * model_output + c_skip * x_t
+ return model_output, denoised
+
+
+def karras_sample(
+ diffusion,
+ model,
+ shape,
+ steps,
+ clip_denoised=True,
+ progress=True,
+ callback=None,
+ # model_kwargs=None,
+ condition=None,
+ device=None,
+ sigma_min=0.002,
+ sigma_max=80, # higher for highres?
+ rho=7.0,
+ sampler="heun",
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=float("inf"),
+ s_noise=1.0,
+ generator=None,
+ ts=None,
+):
+ if generator is None:
+ generator = get_generator("dummy")
+
+ if sampler == "progdist":
+ sigmas = get_sigmas_karras(steps + 1, sigma_min, sigma_max, rho, device=device)
+ else:
+ sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
+ th.manual_seed(42)
+ x_T = generator.randn(*shape, device=device) * sigma_max
+ sigmas = sigmas.unsqueeze(-1)
+ sample_fn = {
+ "heun": sample_heun,
+ "dpm": sample_dpm,
+ "ancestral": sample_euler_ancestral,
+ "onestep": sample_onestep,
+ "progdist": sample_progdist,
+ "euler": sample_euler,
+ "multistep": stochastic_iterative_sampler,
+ }[sampler]
+
+ if sampler in ["heun", "dpm"]:
+ sampler_args = dict(
+ s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise
+ )
+ elif sampler == "multistep":
+ sampler_args = dict(
+ ts=ts, t_min=sigma_min, t_max=sigma_max, rho=diffusion.rho, steps=steps
+ )
+ else:
+ sampler_args = {}
+
+ def denoiser(x_t, sigma):
+ _, denoised = diffusion.denoise(model, x_t, sigma, condition)
+ if clip_denoised:
+ denoised = denoised.clamp(-1, 1)
+ return denoised
+
+ x_0 = sample_fn(
+ denoiser,
+ x_T,
+ sigmas,
+ generator,
+ progress=progress,
+ callback=callback,
+ **sampler_args,
+ )
+ return x_0.clamp(-1, 1)
+
+
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
+ """Constructs the noise schedule of Karras et al. (2022)."""
+ ramp = th.linspace(0, 1, n)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return append_zero(sigmas).to(device)
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def get_ancestral_step(sigma_from, sigma_to):
+ """Calculates the noise level (sigma_down) to step down to and the amount
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
+
+
+@th.no_grad()
+def sample_euler_ancestral(model, x, sigmas, generator, progress=False, callback=None):
+ """Ancestral sampling with Euler method steps."""
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ denoised = model(x, sigmas[i] * s_in)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+ if callback is not None:
+ callback(
+ {
+ "x": x,
+ "i": i,
+ "sigma": sigmas[i],
+ "sigma_hat": sigmas[i],
+ "denoised": denoised,
+ }
+ )
+ d = to_d(x, sigmas[i], denoised)
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ x = x + generator.randn_like(x) * sigma_up
+ return x
+
+
+@th.no_grad()
+def sample_midpoint_ancestral(model, x, ts, generator, progress=False, callback=None):
+ """Ancestral sampling with midpoint method steps."""
+ s_in = x.new_ones([x.shape[0]])
+ step_size = 1 / len(ts)
+ if progress:
+ from tqdm.auto import tqdm
+
+ ts = tqdm(ts)
+
+ for tn in ts:
+ dn = model(x, tn * s_in)
+ dn_2 = model(x + (step_size / 2) * dn, (tn + step_size / 2) * s_in)
+ x = x + step_size * dn_2
+ if callback is not None:
+ callback({"x": x, "tn": tn, "dn": dn, "dn_2": dn_2})
+ return x
+
+
+@th.no_grad()
+def sample_heun(
+ denoiser,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=float("inf"),
+ s_noise=1.0,
+):
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ gamma = (
+ min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
+ if s_tmin <= sigmas[i] <= s_tmax
+ else 0.0
+ )
+ eps = generator.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
+ denoised = denoiser(x, sigma_hat * s_in)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback(
+ {
+ "x": x,
+ "i": i,
+ "sigma": sigmas[i],
+ "sigma_hat": sigma_hat,
+ "denoised": denoised,
+ }
+ )
+ dt = sigmas[i + 1] - sigma_hat
+ if sigmas[i + 1] == 0:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ return x
+
+
+@th.no_grad()
+def sample_euler(
+ denoiser,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+):
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ sigma = sigmas[i]
+ denoised = denoiser(x, sigma * s_in)
+ d = to_d(x, sigma, denoised)
+ if callback is not None:
+ callback(
+ {
+ "x": x,
+ "i": i,
+ "sigma": sigmas[i],
+ "denoised": denoised,
+ }
+ )
+ dt = sigmas[i + 1] - sigma
+ x = x + d * dt
+ return x
+
+
+@th.no_grad()
+def sample_dpm(
+ denoiser,
+ x,
+ sigmas,
+ generator,
+ progress=False,
+ callback=None,
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=float("inf"),
+ s_noise=1.0,
+):
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
+ s_in = x.new_ones([x.shape[0]])
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ gamma = (
+ min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
+ if s_tmin <= sigmas[i] <= s_tmax
+ else 0.0
+ )
+ eps = generator.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
+ denoised = denoiser(x, sigma_hat * s_in)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback(
+ {
+ "x": x,
+ "i": i,
+ "sigma": sigmas[i],
+ "sigma_hat": sigma_hat,
+ "denoised": denoised,
+ }
+ )
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+ sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
+ dt_1 = sigma_mid - sigma_hat
+ dt_2 = sigmas[i + 1] - sigma_hat
+ x_2 = x + d * dt_1
+ denoised_2 = denoiser(x_2, sigma_mid * s_in)
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ return x
+
+
+@th.no_grad()
+def sample_onestep(
+ distiller,
+ x,
+ sigmas,
+ generator=None,
+ progress=False,
+ callback=None,
+):
+ """Single-step generation from a distilled model."""
+ s_in = x.new_ones([x.shape[0]])
+ return distiller(x, sigmas[0] * s_in)
+
+
+@th.no_grad()
+def stochastic_iterative_sampler(
+ distiller,
+ x,
+ sigmas,
+ generator,
+ ts,
+ progress=False,
+ callback=None,
+ t_min=0.002,
+ t_max=80.0,
+ rho=7.0,
+ steps=40,
+):
+ t_max_rho = t_max ** (1 / rho)
+ t_min_rho = t_min ** (1 / rho)
+ s_in = x.new_ones([x.shape[0]])
+
+ for i in range(len(ts) - 1):
+ t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+ x0 = distiller(x, t * s_in)
+ next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+ next_t = np.clip(next_t, t_min, t_max)
+ x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
+
+ return x
+
+
+@th.no_grad()
+def sample_progdist(
+ denoiser,
+ x,
+ sigmas,
+ generator=None,
+ progress=False,
+ callback=None,
+):
+ s_in = x.new_ones([x.shape[0]])
+ sigmas = sigmas[:-1] # skip the zero sigma
+
+ indices = range(len(sigmas) - 1)
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ sigma = sigmas[i]
+ denoised = denoiser(x, sigma * s_in)
+ d = to_d(x, sigma, denoised)
+ if callback is not None:
+ callback(
+ {
+ "x": x,
+ "i": i,
+ "sigma": sigma,
+ "denoised": denoised,
+ }
+ )
+ dt = sigmas[i + 1] - sigma
+ x = x + d * dt
+
+ return x
+
+
+# @th.no_grad()
+# def iterative_colorization(
+# distiller,
+# images,
+# x,
+# ts,
+# t_min=0.002,
+# t_max=80.0,
+# rho=7.0,
+# steps=40,
+# generator=None,
+# ):
+# def obtain_orthogonal_matrix():
+# vector = np.asarray([0.2989, 0.5870, 0.1140])
+# vector = vector / np.linalg.norm(vector)
+# matrix = np.eye(3)
+# matrix[:, 0] = vector
+# matrix = np.linalg.qr(matrix)[0]
+# if np.sum(matrix[:, 0]) < 0:
+# matrix = -matrix
+# return matrix
+
+# Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
+# mask = th.zeros(*x.shape[1:], device=dist_util.dev())
+# mask[0, ...] = 1.0
+
+# def replacement(x0, x1):
+# x0 = th.einsum("bchw,cd->bdhw", x0, Q)
+# x1 = th.einsum("bchw,cd->bdhw", x1, Q)
+
+# x_mix = x0 * mask + x1 * (1.0 - mask)
+# x_mix = th.einsum("bdhw,cd->bchw", x_mix, Q)
+# return x_mix
+
+# t_max_rho = t_max ** (1 / rho)
+# t_min_rho = t_min ** (1 / rho)
+# s_in = x.new_ones([x.shape[0]])
+# images = replacement(images, th.zeros_like(images))
+
+# for i in range(len(ts) - 1):
+# t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+# x0 = distiller(x, t * s_in)
+# x0 = th.clamp(x0, -1.0, 1.0)
+# x0 = replacement(images, x0)
+# next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+# next_t = np.clip(next_t, t_min, t_max)
+# x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
+
+# return x, images
+
+
+# @th.no_grad()
+# def iterative_inpainting(
+# distiller,
+# images,
+# x,
+# ts,
+# t_min=0.002,
+# t_max=80.0,
+# rho=7.0,
+# steps=40,
+# generator=None,
+# ):
+# from PIL import Image, ImageDraw, ImageFont
+
+# image_size = x.shape[-1]
+
+# # create a blank image with a white background
+# img = Image.new("RGB", (image_size, image_size), color="white")
+
+# # get a drawing context for the image
+# draw = ImageDraw.Draw(img)
+
+# # load a font
+# font = ImageFont.truetype("arial.ttf", 250)
+
+# # draw the letter "C" in black
+# draw.text((50, 0), "S", font=font, fill=(0, 0, 0))
+
+# # convert the image to a numpy array
+# img_np = np.array(img)
+# img_np = img_np.transpose(2, 0, 1)
+# img_th = th.from_numpy(img_np).to(dist_util.dev())
+
+# mask = th.zeros(*x.shape, device=dist_util.dev())
+# mask = mask.reshape(-1, 7, 3, image_size, image_size)
+
+# mask[::2, :, img_th > 0.5] = 1.0
+# mask[1::2, :, img_th < 0.5] = 1.0
+# mask = mask.reshape(-1, 3, image_size, image_size)
+
+# def replacement(x0, x1):
+# x_mix = x0 * mask + x1 * (1 - mask)
+# return x_mix
+
+# t_max_rho = t_max ** (1 / rho)
+# t_min_rho = t_min ** (1 / rho)
+# s_in = x.new_ones([x.shape[0]])
+# images = replacement(images, -th.ones_like(images))
+
+# for i in range(len(ts) - 1):
+# t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+# x0 = distiller(x, t * s_in)
+# x0 = th.clamp(x0, -1.0, 1.0)
+# x0 = replacement(images, x0)
+# next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+# next_t = np.clip(next_t, t_min, t_max)
+# x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
+
+# return x, images
+
+
+# @th.no_grad()
+# def iterative_superres(
+# distiller,
+# images,
+# x,
+# ts,
+# t_min=0.002,
+# t_max=80.0,
+# rho=7.0,
+# steps=40,
+# generator=None,
+# ):
+# patch_size = 8
+
+# def obtain_orthogonal_matrix():
+# vector = np.asarray([1] * patch_size**2)
+# vector = vector / np.linalg.norm(vector)
+# matrix = np.eye(patch_size**2)
+# matrix[:, 0] = vector
+# matrix = np.linalg.qr(matrix)[0]
+# if np.sum(matrix[:, 0]) < 0:
+# matrix = -matrix
+# return matrix
+
+# Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
+
+# image_size = x.shape[-1]
+
+# def replacement(x0, x1):
+# x0_flatten = (
+# x0.reshape(-1, 3, image_size, image_size)
+# .reshape(
+# -1,
+# 3,
+# image_size // patch_size,
+# patch_size,
+# image_size // patch_size,
+# patch_size,
+# )
+# .permute(0, 1, 2, 4, 3, 5)
+# .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
+# )
+# x1_flatten = (
+# x1.reshape(-1, 3, image_size, image_size)
+# .reshape(
+# -1,
+# 3,
+# image_size // patch_size,
+# patch_size,
+# image_size // patch_size,
+# patch_size,
+# )
+# .permute(0, 1, 2, 4, 3, 5)
+# .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
+# )
+# x0 = th.einsum("bcnd,de->bcne", x0_flatten, Q)
+# x1 = th.einsum("bcnd,de->bcne", x1_flatten, Q)
+# x_mix = x0.new_zeros(x0.shape)
+# x_mix[..., 0] = x0[..., 0]
+# x_mix[..., 1:] = x1[..., 1:]
+# x_mix = th.einsum("bcne,de->bcnd", x_mix, Q)
+# x_mix = (
+# x_mix.reshape(
+# -1,
+# 3,
+# image_size // patch_size,
+# image_size // patch_size,
+# patch_size,
+# patch_size,
+# )
+# .permute(0, 1, 2, 4, 3, 5)
+# .reshape(-1, 3, image_size, image_size)
+# )
+# return x_mix
+
+# def average_image_patches(x):
+# x_flatten = (
+# x.reshape(-1, 3, image_size, image_size)
+# .reshape(
+# -1,
+# 3,
+# image_size // patch_size,
+# patch_size,
+# image_size // patch_size,
+# patch_size,
+# )
+# .permute(0, 1, 2, 4, 3, 5)
+# .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
+# )
+# x_flatten[..., :] = x_flatten.mean(dim=-1, keepdim=True)
+# return (
+# x_flatten.reshape(
+# -1,
+# 3,
+# image_size // patch_size,
+# image_size // patch_size,
+# patch_size,
+# patch_size,
+# )
+# .permute(0, 1, 2, 4, 3, 5)
+# .reshape(-1, 3, image_size, image_size)
+# )
+
+# t_max_rho = t_max ** (1 / rho)
+# t_min_rho = t_min ** (1 / rho)
+# s_in = x.new_ones([x.shape[0]])
+# images = average_image_patches(images)
+
+# for i in range(len(ts) - 1):
+# t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+# x0 = distiller(x, t * s_in)
+# x0 = th.clamp(x0, -1.0, 1.0)
+# x0 = replacement(images, x0)
+# next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
+# next_t = np.clip(next_t, t_min, t_max)
+# x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)
+
+# return x, images
diff --git a/modules/diffusion/karras/random_utils.py b/modules/diffusion/karras/random_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ebf4265934edad153adaf83777731afa313d000
--- /dev/null
+++ b/modules/diffusion/karras/random_utils.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch as th
+
+
+def get_generator(generator, num_samples=0, seed=0):
+ if generator == "dummy":
+ return DummyGenerator()
+ elif generator == "determ":
+ return DeterministicGenerator(num_samples, seed)
+ elif generator == "determ-indiv":
+ return DeterministicIndividualGenerator(num_samples, seed)
+ else:
+ raise NotImplementedError
+
+
+class DummyGenerator:
+ def randn(self, *args, **kwargs):
+ return th.randn(*args, **kwargs)
+
+ def randint(self, *args, **kwargs):
+ return th.randint(*args, **kwargs)
+
+ def randn_like(self, *args, **kwargs):
+ return th.randn_like(*args, **kwargs)
+
+
+class DeterministicGenerator:
+ """
+ RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
+ Uses a single rng and samples num_samples sized randomness and subsamples the current indices
+ """
+
+ def __init__(self, num_samples, seed=0):
+ print("Warning: Distributed not initialised, using single rank")
+ self.rank = 0
+ self.world_size = 1
+ self.num_samples = num_samples
+ self.done_samples = 0
+ self.seed = seed
+ self.rng_cpu = th.Generator()
+ if th.cuda.is_available():
+ self.rng_cuda = th.Generator(dist_util.dev())
+ self.set_seed(seed)
+
+ def get_global_size_and_indices(self, size):
+ global_size = (self.num_samples, *size[1:])
+ indices = th.arange(
+ self.done_samples + self.rank,
+ self.done_samples + self.world_size * int(size[0]),
+ self.world_size,
+ )
+ indices = th.clamp(indices, 0, self.num_samples - 1)
+ assert (
+ len(indices) == size[0]
+ ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
+ return global_size, indices
+
+ def get_generator(self, device):
+ return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda
+
+ def randn(self, *size, dtype=th.float, device="cpu"):
+ global_size, indices = self.get_global_size_and_indices(size)
+ generator = self.get_generator(device)
+ return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[
+ indices
+ ]
+
+ def randint(self, low, high, size, dtype=th.long, device="cpu"):
+ global_size, indices = self.get_global_size_and_indices(size)
+ generator = self.get_generator(device)
+ return th.randint(
+ low, high, generator=generator, size=global_size, dtype=dtype, device=device
+ )[indices]
+
+ def randn_like(self, tensor):
+ size, dtype, device = tensor.size(), tensor.dtype, tensor.device
+ return self.randn(*size, dtype=dtype, device=device)
+
+ def set_done_samples(self, done_samples):
+ self.done_samples = done_samples
+ self.set_seed(self.seed)
+
+ def get_seed(self):
+ return self.seed
+
+ def set_seed(self, seed):
+ self.rng_cpu.manual_seed(seed)
+ if th.cuda.is_available():
+ self.rng_cuda.manual_seed(seed)
+
+
+class DeterministicIndividualGenerator:
+ """
+ RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
+ Uses a separate rng for each sample to reduce memoery usage
+ """
+
+ def __init__(self, num_samples, seed=0):
+ print("Warning: Distributed not initialised, using single rank")
+ self.rank = 0
+ self.world_size = 1
+ self.num_samples = num_samples
+ self.done_samples = 0
+ self.seed = seed
+ self.rng_cpu = [th.Generator() for _ in range(num_samples)]
+ if th.cuda.is_available():
+ self.rng_cuda = [th.Generator(dist_util.dev()) for _ in range(num_samples)]
+ self.set_seed(seed)
+
+ def get_size_and_indices(self, size):
+ indices = th.arange(
+ self.done_samples + self.rank,
+ self.done_samples + self.world_size * int(size[0]),
+ self.world_size,
+ )
+ indices = th.clamp(indices, 0, self.num_samples - 1)
+ assert (
+ len(indices) == size[0]
+ ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
+ return (1, *size[1:]), indices
+
+ def get_generator(self, device):
+ return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda
+
+ def randn(self, *size, dtype=th.float, device="cpu"):
+ size, indices = self.get_size_and_indices(size)
+ generator = self.get_generator(device)
+ return th.cat(
+ [
+ th.randn(*size, generator=generator[i], dtype=dtype, device=device)
+ for i in indices
+ ],
+ dim=0,
+ )
+
+ def randint(self, low, high, size, dtype=th.long, device="cpu"):
+ size, indices = self.get_size_and_indices(size)
+ generator = self.get_generator(device)
+ return th.cat(
+ [
+ th.randint(
+ low,
+ high,
+ generator=generator[i],
+ size=size,
+ dtype=dtype,
+ device=device,
+ )
+ for i in indices
+ ],
+ dim=0,
+ )
+
+ def randn_like(self, tensor):
+ size, dtype, device = tensor.size(), tensor.dtype, tensor.device
+ return self.randn(*size, dtype=dtype, device=device)
+
+ def set_done_samples(self, done_samples):
+ self.done_samples = done_samples
+
+ def get_seed(self):
+ return self.seed
+
+ def set_seed(self, seed):
+ [
+ rng_cpu.manual_seed(i + self.num_samples * seed)
+ for i, rng_cpu in enumerate(self.rng_cpu)
+ ]
+ if th.cuda.is_available():
+ [
+ rng_cuda.manual_seed(i + self.num_samples * seed)
+ for i, rng_cuda in enumerate(self.rng_cuda)
+ ]
diff --git a/modules/diffusion/karras/sample.py b/modules/diffusion/karras/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7472a37342c8e8eb06ccd6e292a28b80e8dfef
--- /dev/null
+++ b/modules/diffusion/karras/sample.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+from scipy.stats import norm
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ elif name == "lognormal":
+ return LogNormalSampler()
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
+
+
+class LogNormalSampler:
+ def __init__(self, p_mean=-1.2, p_std=1.2, even=False):
+ self.p_mean = p_mean
+ self.p_std = p_std
+ self.even = even
+ if self.even:
+ self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std)
+ self.rank, self.size = dist.get_rank(), dist.get_world_size()
+
+ def sample(self, bs, device):
+ if self.even:
+ # buckets = [1/G]
+ start_i, end_i = self.rank * bs, (self.rank + 1) * bs
+ global_batch_size = self.size * bs
+ locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size
+ log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device)
+ else:
+ log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device)
+ sigmas = th.exp(log_sigmas)
+ weights = th.ones_like(sigmas)
+ return sigmas, weights
diff --git a/modules/diffusion/unet/attention.py b/modules/diffusion/unet/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5a210cb392e3e1beb9f8c2884978f6b0b71ced4
--- /dev/null
+++ b/modules/diffusion/unet/attention.py
@@ -0,0 +1,241 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules.general.utils import Conv1d, normalization, zero_module
+from .basic import UNetBlock
+
+
+class AttentionBlock(UNetBlock):
+ r"""A spatial transformer encoder block that allows spatial positions to attend
+ to each other. Reference from `latent diffusion repo
+ `_.
+
+ Args:
+ channels: Number of channels in the input.
+ num_head_channels: Number of channels per attention head.
+ num_heads: Number of attention heads. Overrides ``num_head_channels`` if set.
+ encoder_channels: Number of channels in the encoder output for cross-attention.
+ If ``None``, then self-attention is performed.
+ use_self_attention: Whether to use self-attention before cross-attention, only applicable if encoder_channels is set.
+ dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
+ h_dim: The dimension of the height, would be applied if ``dims`` is 2.
+ encoder_hdim: The dimension of the height of the encoder output, would be applied if ``dims`` is 2.
+ p_dropout: Dropout probability.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: int = 32,
+ num_heads: int = -1,
+ encoder_channels: int = None,
+ use_self_attention: bool = False,
+ dims: int = 1,
+ h_dim: int = 100,
+ encoder_hdim: int = 384,
+ p_dropout: float = 0.0,
+ ):
+ super().__init__()
+
+ self.channels = channels
+ self.p_dropout = p_dropout
+ self.dims = dims
+
+ if dims == 1:
+ self.channels = channels
+ elif dims == 2:
+ # We consider the channel as product of channel and height, i.e. C x H
+ # This is because we want to apply attention on the audio signal, which is 1D
+ self.channels = channels * h_dim
+ else:
+ raise ValueError(f"invalid number of dimensions: {dims}")
+
+ if num_head_channels == -1:
+ assert (
+ self.channels % num_heads == 0
+ ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
+ self.num_heads = num_heads
+ self.num_head_channels = self.channels // num_heads
+ else:
+ assert (
+ self.channels % num_head_channels == 0
+ ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = self.channels // num_head_channels
+ self.num_head_channels = num_head_channels
+
+ if encoder_channels is not None:
+ self.use_self_attention = use_self_attention
+
+ if dims == 1:
+ self.encoder_channels = encoder_channels
+ elif dims == 2:
+ self.encoder_channels = encoder_channels * encoder_hdim
+ else:
+ raise ValueError(f"invalid number of dimensions: {dims}")
+
+ if use_self_attention:
+ self.self_attention = BasicAttentionBlock(
+ self.channels,
+ self.num_head_channels,
+ self.num_heads,
+ p_dropout=self.p_dropout,
+ )
+ self.cross_attention = BasicAttentionBlock(
+ self.channels,
+ self.num_head_channels,
+ self.num_heads,
+ self.encoder_channels,
+ p_dropout=self.p_dropout,
+ )
+ else:
+ self.encoder_channels = None
+ self.self_attention = BasicAttentionBlock(
+ self.channels,
+ self.num_head_channels,
+ self.num_heads,
+ p_dropout=self.p_dropout,
+ )
+
+ def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
+ r"""
+ Args:
+ x: input tensor with shape [B x ``channels`` x ...]
+ encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...], if ``None``, then self-attention is performed.
+
+ Returns:
+ output tensor with shape [B x ``channels`` x ...]
+ """
+ shape = x.size()
+ x = x.reshape(shape[0], self.channels, -1).contiguous()
+
+ if self.encoder_channels is None:
+ assert (
+ encoder_output is None
+ ), "encoder_output must be None for self-attention."
+ h = self.self_attention(x)
+
+ else:
+ assert (
+ encoder_output is not None
+ ), "encoder_output must be given for cross-attention."
+ encoder_output = encoder_output.reshape(
+ shape[0], self.encoder_channels, -1
+ ).contiguous()
+
+ if self.use_self_attention:
+ x = self.self_attention(x)
+ h = self.cross_attention(x, encoder_output)
+
+ return h.reshape(*shape).contiguous()
+
+
+class BasicAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: int = 32,
+ num_heads: int = -1,
+ context_channels: int = None,
+ p_dropout: float = 0.0,
+ ):
+ super().__init__()
+
+ self.channels = channels
+ self.p_dropout = p_dropout
+ self.context_channels = context_channels
+
+ if num_head_channels == -1:
+ assert (
+ self.channels % num_heads == 0
+ ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
+ self.num_heads = num_heads
+ self.num_head_channels = self.channels // num_heads
+ else:
+ assert (
+ self.channels % num_head_channels == 0
+ ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = self.channels // num_head_channels
+ self.num_head_channels = num_head_channels
+
+ if context_channels is not None:
+ self.to_q = nn.Sequential(
+ normalization(self.channels),
+ Conv1d(self.channels, self.channels, 1),
+ )
+ self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
+ else:
+ self.to_qkv = nn.Sequential(
+ normalization(self.channels),
+ Conv1d(self.channels, 3 * self.channels, 1),
+ )
+
+ self.linear = Conv1d(self.channels, self.channels)
+
+ self.proj_out = nn.Sequential(
+ normalization(self.channels),
+ Conv1d(self.channels, self.channels, 1),
+ nn.GELU(),
+ nn.Dropout(p=self.p_dropout),
+ zero_module(Conv1d(self.channels, self.channels, 1)),
+ )
+
+ def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
+ r"""
+ Args:
+ q: input tensor with shape [B, ``channels``, L]
+ kv: feature tensor with shape [B, ``context_channels``, T], if ``None``, then self-attention is performed.
+
+ Returns:
+ output tensor with shape [B, ``channels``, L]
+ """
+ N, C, L = q.size()
+
+ if self.context_channels is not None:
+ assert kv is not None, "kv must be given for cross-attention."
+
+ q = (
+ self.to_q(q)
+ .reshape(self.num_heads, self.num_head_channels, -1)
+ .transpose(-1, -2)
+ .contiguous()
+ )
+ kv = (
+ self.to_kv(kv)
+ .reshape(2, self.num_heads, self.num_head_channels, -1)
+ .transpose(-1, -2)
+ .chunk(2)
+ )
+ k, v = (
+ kv[0].squeeze(0).contiguous(),
+ kv[1].squeeze(0).contiguous(),
+ )
+
+ else:
+ qkv = (
+ self.to_qkv(q)
+ .reshape(3, self.num_heads, self.num_head_channels, -1)
+ .transpose(-1, -2)
+ .chunk(3)
+ )
+ q, k, v = (
+ qkv[0].squeeze(0).contiguous(),
+ qkv[1].squeeze(0).contiguous(),
+ qkv[2].squeeze(0).contiguous(),
+ )
+
+ h = F.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout).transpose(
+ -1, -2
+ )
+ h = h.reshape(N, -1, L).contiguous()
+ h = self.linear(h)
+
+ x = q + h
+ h = self.proj_out(x)
+
+ return x + h
diff --git a/modules/diffusion/unet/basic.py b/modules/diffusion/unet/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cd4bc55ab58b3ebb7f9b2c0ec8dd37ad6d43721
--- /dev/null
+++ b/modules/diffusion/unet/basic.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from abc import abstractmethod
+
+
+class UNetBlock(nn.Module):
+ r"""Any module where forward() takes timestep embeddings as a second argument."""
+
+ @abstractmethod
+ def forward(self, x, emb):
+ r"""Apply the module to `x` given `emb` timestep embeddings."""
diff --git a/modules/diffusion/unet/resblock.py b/modules/diffusion/unet/resblock.py
new file mode 100644
index 0000000000000000000000000000000000000000..144c867785bbf76fa877ea6e24dfbdcffb78008e
--- /dev/null
+++ b/modules/diffusion/unet/resblock.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .basic import UNetBlock
+from modules.general.utils import (
+ append_dims,
+ ConvNd,
+ normalization,
+ zero_module,
+)
+
+
+class ResBlock(UNetBlock):
+ r"""A residual block that can optionally change the number of channels.
+
+ Args:
+ channels: the number of input channels.
+ emb_channels: the number of timestep embedding channels.
+ dropout: the rate of dropout.
+ out_channels: if specified, the number of out channels.
+ use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ dims: determines if the signal is 1D, 2D, or 3D.
+ up: if True, use this block for upsampling.
+ down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout: float = 0.0,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ ConvNd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ ConvNd(
+ dims,
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ 1,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ ConvNd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = ConvNd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = ConvNd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ x: an [N x C x ...] Tensor of features.
+ emb: an [N x emb_channels x ...] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb)
+ emb_out = append_dims(emb_out, h.dim())
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class Upsample(nn.Module):
+ r"""An upsampling layer with an optional convolution.
+
+ Args:
+ channels: channels in the inputs and outputs.
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ out_channels: if specified, the number of out channels.
+ """
+
+ def __init__(self, channels, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.dims = dims
+ self.conv = ConvNd(dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ r"""A downsampling layer with an optional convolution.
+
+ Args:
+ channels: channels in the inputs and outputs.
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ out_channels: if specified, the number of output channels.
+ """
+
+ def __init__(self, channels, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ self.op = ConvNd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
diff --git a/modules/diffusion/unet/unet.py b/modules/diffusion/unet/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d39f5d1e07f64012ba08268a0a4a5e71ad01e88
--- /dev/null
+++ b/modules/diffusion/unet/unet.py
@@ -0,0 +1,310 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from modules.encoder.position_encoder import PositionEncoder
+from modules.general.utils import append_dims, ConvNd, normalization, zero_module
+from .attention import AttentionBlock
+from .resblock import Downsample, ResBlock, Upsample
+
+
+class UNet(nn.Module):
+ r"""The full UNet model with attention and timestep embedding.
+
+ Args:
+ dims: determines if the signal is 1D (temporal), 2D(spatial).
+ in_channels: channels in the input Tensor.
+ model_channels: base channel count for the model.
+ out_channels: channels in the output Tensor.
+ num_res_blocks: number of residual blocks per downsample.
+ channel_mult: channel multiplier for each level of the UNet.
+ num_attn_blocks: number of attention blocks at place.
+ attention_resolutions: a collection of downsample rates at which attention will
+ take place. May be a set, list, or tuple. For example, if this contains 4,
+ then at 4x downsampling, attention will be used.
+ num_heads: the number of attention heads in each attention layer.
+ num_head_channels: if specified, ignore num_heads and instead use a fixed
+ channel width per attention head.
+ d_context: if specified, use for cross-attention channel project.
+ p_dropout: the dropout probability.
+ use_self_attention: Apply self attention before cross attention.
+ num_classes: if specified (as an int), then this model will be class-conditional
+ with ``num_classes`` classes.
+ use_extra_film: if specified, use an extra FiLM-like conditioning mechanism.
+ d_emb: if specified, use for FiLM-like conditioning.
+ use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ resblock_updown: use residual blocks for up/downsampling.
+ """
+
+ def __init__(
+ self,
+ dims: int = 1,
+ in_channels: int = 100,
+ model_channels: int = 128,
+ out_channels: int = 100,
+ h_dim: int = 128,
+ num_res_blocks: int = 1,
+ channel_mult: tuple = (1, 2, 4),
+ num_attn_blocks: int = 1,
+ attention_resolutions: tuple = (1, 2, 4),
+ num_heads: int = 1,
+ num_head_channels: int = -1,
+ d_context: int = None,
+ context_hdim: int = 128,
+ p_dropout: float = 0.0,
+ num_classes: int = -1,
+ use_extra_film: str = None,
+ d_emb: int = None,
+ use_scale_shift_norm: bool = True,
+ resblock_updown: bool = False,
+ ):
+ super().__init__()
+
+ self.dims = dims
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.channel_mult = channel_mult
+ self.num_attn_blocks = num_attn_blocks
+ self.attention_resolutions = attention_resolutions
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.d_context = d_context
+ self.p_dropout = p_dropout
+ self.num_classes = num_classes
+ self.use_extra_film = use_extra_film
+ self.d_emb = d_emb
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.resblock_updown = resblock_updown
+
+ time_embed_dim = model_channels * 4
+ self.pos_enc = PositionEncoder(model_channels, time_embed_dim)
+
+ assert (
+ num_classes == -1 or use_extra_film is None
+ ), "You cannot set both num_classes and use_extra_film."
+
+ if self.num_classes > 0:
+ # TODO: if used for singer, norm should be 1, correct?
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0)
+ elif use_extra_film is not None:
+ assert (
+ d_emb is not None
+ ), "d_emb must be specified if use_extra_film is not None"
+ assert use_extra_film in [
+ "add",
+ "concat",
+ ], f"use_extra_film only supported by add or concat. Your input is {use_extra_film}"
+ self.use_extra_film = use_extra_film
+ self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1)
+ if use_extra_film == "concat":
+ time_embed_dim *= 2
+
+ # Input blocks
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ p_dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ for _ in range(num_attn_blocks):
+ layers.append(
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=d_context,
+ dims=dims,
+ h_dim=h_dim // (level + 1),
+ encoder_hdim=context_hdim,
+ p_dropout=p_dropout,
+ )
+ )
+ self.input_blocks.append(UNetSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ UNetSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ p_dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(ch, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ # Middle blocks
+ self.middle_block = UNetSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ p_dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=d_context,
+ dims=dims,
+ h_dim=h_dim // (level + 1),
+ encoder_hdim=context_hdim,
+ p_dropout=p_dropout,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ p_dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ # Output blocks
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in tuple(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ p_dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ for _ in range(num_attn_blocks):
+ layers.append(
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=d_context,
+ dims=dims,
+ h_dim=h_dim // (level + 1),
+ encoder_hdim=context_hdim,
+ p_dropout=p_dropout,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ p_dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(UNetSequential(*layers))
+ self._feature_size += ch
+
+ # Final proj out
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ r"""Apply the model to an input batch.
+
+ Args:
+ x: an [N x C x ...] Tensor of inputs.
+ timesteps: a 1-D batch of timesteps, i.e. [N].
+ context: conditioning Tensor with shape of [N x ``d_context`` x ...] plugged
+ in via cross attention.
+ y: an [N] Tensor of labels, if **class-conditional**.
+ an [N x ``d_emb`` x ...] Tensor if **film-embed conditional**.
+
+ Returns:
+ an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is None) or (
+ (y is not None)
+ and ((self.num_classes > 0) or (self.use_extra_film is not None))
+ ), f"y must be specified if num_classes or use_extra_film is not None. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n"
+
+ hs = []
+ emb = self.pos_enc(timesteps)
+ emb = append_dims(emb, x.dim())
+
+ if self.num_classes > 0:
+ assert y.size() == (x.size(0),)
+ emb = emb + self.label_emb(y)
+ elif self.use_extra_film is not None:
+ assert y.size() == (x.size(0), self.d_emb, *x.size()[2:])
+ y = self.film_emb(y)
+ if self.use_extra_film == "add":
+ emb = emb + y
+ elif self.use_extra_film == "concat":
+ emb = torch.cat([emb, y], dim=1)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+
+ return self.out(h)
+
+
+class UNetSequential(nn.Sequential):
+ r"""A sequential module that passes embeddings to the children that support it."""
+
+ def forward(self, x, emb=None, context=None):
+ for layer in self:
+ if isinstance(layer, ResBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, AttentionBlock):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
diff --git a/modules/distributions/__init__.py b/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/distributions/distributions.py b/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0287d76c7ff2d0c1436c8e8b24ccd43513d38714
--- /dev/null
+++ b/modules/distributions/distributions.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ device=self.parameters.device
+ )
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/modules/duration_predictor/__init__.py b/modules/duration_predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/duration_predictor/standard_duration_predictor.py b/modules/duration_predictor/standard_duration_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9924e70a7571cef0359019d6d083db7d9bd6e9cf
--- /dev/null
+++ b/modules/duration_predictor/standard_duration_predictor.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
+
+import torch
+from torch import nn
+from modules.base.base_module import LayerNorm
+
+
+class DurationPredictor(nn.Module):
+ def __init__(
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.gin_channels = gin_channels
+
+ self.drop = nn.Dropout(p_dropout)
+ self.conv_1 = nn.Conv1d(
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+ )
+ self.norm_1 = LayerNorm(filter_channels)
+ self.conv_2 = nn.Conv1d(
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+ )
+ self.norm_2 = LayerNorm(filter_channels)
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+ def forward(self, x, x_mask, g=None):
+ x = torch.detach(x)
+ if g is not None:
+ g = torch.detach(g)
+ x = x + self.cond(g)
+ x = self.conv_1(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_1(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ x = torch.relu(x)
+ x = self.norm_2(x)
+ x = self.drop(x)
+ x = self.proj(x * x_mask)
+ return x * x_mask
diff --git a/modules/duration_predictor/stochastic_duration_predictor.py b/modules/duration_predictor/stochastic_duration_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..60bbb95547bb89a24bb8852c57a462a2d2326a32
--- /dev/null
+++ b/modules/duration_predictor/stochastic_duration_predictor.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.pyimport torch
+
+from torch import nn
+from torch.nn import functional as F
+import math
+from modules.flow.modules import *
+
+
+class StochasticDurationPredictor(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout,
+ n_flows=4,
+ gin_channels=0,
+ ):
+ super().__init__()
+ filter_channels = in_channels
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.log_flow = Log()
+ self.flows = nn.ModuleList()
+ self.flows.append(ElementwiseAffine(2))
+ for i in range(n_flows):
+ self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
+ self.flows.append(Flip())
+
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
+ self.post_convs = DDSConv(
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
+ )
+ self.post_flows = nn.ModuleList()
+ self.post_flows.append(ElementwiseAffine(2))
+ for i in range(4):
+ self.post_flows.append(
+ ConvFlow(2, filter_channels, kernel_size, n_layers=3)
+ )
+ self.post_flows.append(Flip())
+
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
+ self.convs = DDSConv(
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
+ )
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
+
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
+ x = torch.detach(x)
+ x = self.pre(x)
+ if g is not None:
+ g = torch.detach(g)
+ x = x + self.cond(g)
+ x = self.convs(x, x_mask)
+ x = self.proj(x) * x_mask
+
+ if not reverse:
+ flows = self.flows
+ assert w is not None
+
+ logdet_tot_q = 0
+ h_w = self.post_pre(w)
+ h_w = self.post_convs(h_w, x_mask)
+ h_w = self.post_proj(h_w) * x_mask
+ e_q = (
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
+ * x_mask
+ )
+ z_q = e_q
+ for flow in self.post_flows:
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+ logdet_tot_q += logdet_q
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
+ u = torch.sigmoid(z_u) * x_mask
+ z0 = (w - u) * x_mask
+ logdet_tot_q += torch.sum(
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
+ )
+ logq = (
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
+ - logdet_tot_q
+ )
+
+ logdet_tot = 0
+ z0, logdet = self.log_flow(z0, x_mask)
+ logdet_tot += logdet
+ z = torch.cat([z0, z1], 1)
+ for flow in flows:
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
+ logdet_tot = logdet_tot + logdet
+ nll = (
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
+ - logdet_tot
+ )
+ return nll + logq
+ else:
+ flows = list(reversed(self.flows))
+ flows = flows[:-2] + [flows[-1]]
+ z = (
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
+ * noise_scale
+ )
+ for flow in flows:
+ z = flow(z, x_mask, g=x, reverse=reverse)
+ z0, z1 = torch.split(z, [1, 1], 1)
+ logw = z0
+ return logw
diff --git a/modules/encoder/__init__.py b/modules/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d704d2a703dac50edb49cec355c7a0b1ce92b5fa
--- /dev/null
+++ b/modules/encoder/__init__.py
@@ -0,0 +1 @@
+from .token_encoder import TokenEmbedding
diff --git a/modules/encoder/condition_encoder.py b/modules/encoder/condition_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1600d078dfe8ed8c228dd113f7e68931dcc545ae
--- /dev/null
+++ b/modules/encoder/condition_encoder.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torchaudio.models import Conformer
+from models.svc.transformer.transformer import PositionalEncoding
+
+from utils.f0 import f0_to_coarse
+
+
+class ContentEncoder(nn.Module):
+ def __init__(self, cfg, input_dim, output_dim):
+ super().__init__()
+ self.cfg = cfg
+
+ assert input_dim != 0
+ self.nn = nn.Linear(input_dim, output_dim)
+
+ # Introduce conformer or not
+ if (
+ "use_conformer_for_content_features" in cfg
+ and cfg.use_conformer_for_content_features
+ ):
+ self.pos_encoder = PositionalEncoding(input_dim)
+ self.conformer = Conformer(
+ input_dim=input_dim,
+ num_heads=2,
+ ffn_dim=256,
+ num_layers=6,
+ depthwise_conv_kernel_size=3,
+ )
+ else:
+ self.conformer = None
+
+ def forward(self, x, length=None):
+ # x: (N, seq_len, input_dim) -> (N, seq_len, output_dim)
+ if self.conformer:
+ x = self.pos_encoder(x)
+ x, _ = self.conformer(x, length)
+ return self.nn(x)
+
+
+class MelodyEncoder(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ self.input_dim = self.cfg.input_melody_dim
+ self.output_dim = self.cfg.output_melody_dim
+ self.n_bins = self.cfg.n_bins_melody
+
+ if self.input_dim != 0:
+ if self.n_bins == 0:
+ # Not use quantization
+ self.nn = nn.Linear(self.input_dim, self.output_dim)
+ else:
+ self.f0_min = cfg.f0_min
+ self.f0_max = cfg.f0_max
+
+ self.nn = nn.Embedding(
+ num_embeddings=self.n_bins,
+ embedding_dim=self.output_dim,
+ padding_idx=None,
+ )
+ self.uv_embedding = nn.Embedding(2, self.output_dim)
+
+ def forward(self, x, uv=None, length=None):
+ # x: (B, frame_len)
+ if self.n_bins == 0:
+ x = x.unsqueeze(-1)
+ else:
+ x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
+ x = self.nn(x)
+
+ if self.cfg.use_uv:
+ uv = self.uv_embedding(uv)
+ x = x + uv
+ return x
+
+
+class LoudnessEncoder(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ self.input_dim = self.cfg.input_loudness_dim
+ self.output_dim = self.cfg.output_loudness_dim
+ self.n_bins = self.cfg.n_bins_loudness
+
+ if self.input_dim != 0:
+ if self.n_bins == 0:
+ # Not use quantization
+ self.nn = nn.Linear(self.input_dim, self.output_dim)
+ else:
+ # TODO: set empirically now
+ self.loudness_min = 1e-30
+ self.loudness_max = 1.5
+ self.energy_bins = nn.Parameter(
+ torch.exp(
+ torch.linspace(
+ np.log(self.loudness_min),
+ np.log(self.loudness_max),
+ self.n_bins - 1,
+ )
+ ),
+ requires_grad=False,
+ )
+
+ self.nn = nn.Embedding(
+ num_embeddings=self.n_bins,
+ embedding_dim=self.output_dim,
+ padding_idx=None,
+ )
+
+ def forward(self, x):
+ # x: (N, frame_len)
+ if self.n_bins == 0:
+ x = x.unsqueeze(-1)
+ else:
+ x = torch.bucketize(x, self.energy_bins)
+ return self.nn(x)
+
+
+class SingerEncoder(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+
+ self.input_dim = 1
+ self.output_dim = self.cfg.output_singer_dim
+
+ self.nn = nn.Embedding(
+ num_embeddings=cfg.singer_table_size,
+ embedding_dim=self.output_dim,
+ padding_idx=None,
+ )
+
+ def forward(self, x):
+ # x: (N, 1) -> (N, 1, output_dim)
+ return self.nn(x)
+
+
+class ConditionEncoder(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.merge_mode = cfg.merge_mode
+
+ ### Semantic Features ###
+ if cfg.use_whisper:
+ self.whisper_encoder = ContentEncoder(
+ self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
+ )
+ if cfg.use_contentvec:
+ self.contentvec_encoder = ContentEncoder(
+ self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
+ )
+ if cfg.use_mert:
+ self.mert_encoder = ContentEncoder(
+ self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
+ )
+ if cfg.use_wenet:
+ self.wenet_encoder = ContentEncoder(
+ self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
+ )
+
+ ### Prosody Features ###
+ if cfg.use_f0:
+ self.melody_encoder = MelodyEncoder(self.cfg)
+ if cfg.use_energy:
+ self.loudness_encoder = LoudnessEncoder(self.cfg)
+
+ ### Speaker Features ###
+ if cfg.use_spkid:
+ self.singer_encoder = SingerEncoder(self.cfg)
+
+ def forward(self, x):
+ outputs = []
+
+ if self.cfg.use_f0:
+ if self.cfg.use_uv:
+ pitch_enc_out = self.melody_encoder(
+ x["frame_pitch"], uv=x["frame_uv"], length=x["target_len"]
+ )
+ else:
+ pitch_enc_out = self.melody_encoder(
+ x["frame_pitch"], uv=None, length=x["target_len"]
+ )
+ outputs.append(pitch_enc_out)
+
+ if self.cfg.use_energy:
+ loudness_enc_out = self.loudness_encoder(x["frame_energy"])
+ outputs.append(loudness_enc_out)
+
+ if self.cfg.use_whisper:
+ # whisper_feat: [b, T, 1024]
+ whiser_enc_out = self.whisper_encoder(
+ x["whisper_feat"], length=x["target_len"]
+ )
+ outputs.append(whiser_enc_out)
+ seq_len = whiser_enc_out.shape[1]
+
+ if self.cfg.use_contentvec:
+ contentvec_enc_out = self.contentvec_encoder(
+ x["contentvec_feat"], length=x["target_len"]
+ )
+ outputs.append(contentvec_enc_out)
+ seq_len = contentvec_enc_out.shape[1]
+
+ if self.cfg.use_mert:
+ mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
+ outputs.append(mert_enc_out)
+ seq_len = mert_enc_out.shape[1]
+
+ if self.cfg.use_wenet:
+ wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
+ outputs.append(wenet_enc_out)
+ seq_len = wenet_enc_out.shape[1]
+
+ if self.cfg.use_spkid:
+ speaker_enc_out = self.singer_encoder(x["spk_id"]) # [b, 1, 384]
+ assert (
+ "whisper_feat" in x.keys()
+ or "contentvec_feat" in x.keys()
+ or "mert_feat" in x.keys()
+ or "wenet_feat" in x.keys()
+ )
+ singer_info = speaker_enc_out.expand(-1, seq_len, -1)
+ outputs.append(singer_info)
+
+ encoder_output = None
+ if self.merge_mode == "concat":
+ encoder_output = torch.cat(outputs, dim=-1)
+ if self.merge_mode == "add":
+ # (#modules, N, seq_len, output_dim)
+ outputs = torch.cat([out[None, :, :, :] for out in outputs], dim=0)
+ # (N, seq_len, output_dim)
+ encoder_output = torch.sum(outputs, dim=0)
+
+ return encoder_output
diff --git a/modules/encoder/position_encoder.py b/modules/encoder/position_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f67c688a2b55390c75142ba57d159aca4a31c91
--- /dev/null
+++ b/modules/encoder/position_encoder.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn as nn
+
+from modules.general.utils import Linear
+
+
+class PositionEncoder(nn.Module):
+ r"""Encoder of positional embedding, generates PE and then
+ feed into 2 full-connected layers with ``SiLU``.
+
+ Args:
+ d_raw_emb: The dimension of raw embedding vectors.
+ d_out: The dimension of output embedding vectors, default to ``d_raw_emb``.
+ d_mlp: The dimension of hidden layer in MLP, default to ``d_raw_emb`` * 4.
+ activation_function: The activation function used in MLP, default to ``SiLU``.
+ n_layer: The number of layers in MLP, default to 2.
+ max_period: controls the minimum frequency of the embeddings.
+ """
+
+ def __init__(
+ self,
+ d_raw_emb: int = 128,
+ d_out: int = None,
+ d_mlp: int = None,
+ activation_function: str = "SiLU",
+ n_layer: int = 2,
+ max_period: int = 10000,
+ ):
+ super().__init__()
+
+ self.d_raw_emb = d_raw_emb
+ self.d_out = d_raw_emb if d_out is None else d_out
+ self.d_mlp = d_raw_emb * 4 if d_mlp is None else d_mlp
+ self.n_layer = n_layer
+ self.max_period = max_period
+
+ if activation_function.lower() == "silu":
+ self.activation_function = "SiLU"
+ elif activation_function.lower() == "relu":
+ self.activation_function = "ReLU"
+ elif activation_function.lower() == "gelu":
+ self.activation_function = "GELU"
+ else:
+ raise ValueError("activation_function must be one of SiLU, ReLU, GELU")
+ self.activation_function = activation_function
+
+ tmp = [Linear(self.d_raw_emb, self.d_mlp), getattr(nn, activation_function)()]
+ for _ in range(self.n_layer - 1):
+ tmp.append(Linear(self.d_mlp, self.d_mlp))
+ tmp.append(getattr(nn, activation_function)())
+ tmp.append(Linear(self.d_mlp, self.d_out))
+
+ self.out = nn.Sequential(*tmp)
+
+ def forward(self, steps: torch.Tensor) -> torch.Tensor:
+ r"""Create and return sinusoidal timestep embeddings directly.
+
+ Args:
+ steps: a 1D Tensor of N indices, one per batch element.
+ These may be fractional.
+
+ Returns:
+ an [N x ``d_out``] Tensor of positional embeddings.
+ """
+
+ half = self.d_raw_emb // 2
+ freqs = torch.exp(
+ -math.log(self.max_period)
+ / half
+ * torch.arange(half, dtype=torch.float32, device=steps.device)
+ )
+ args = steps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if self.d_raw_emb % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ return self.out(embedding)
diff --git a/modules/encoder/token_encoder.py b/modules/encoder/token_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7594c173fd9127d543ff9e1e5fb6471704c6468
--- /dev/null
+++ b/modules/encoder/token_encoder.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/lifeiteng/vall-e
+
+import torch
+import torch.nn as nn
+
+
+class TokenEmbedding(nn.Module):
+ def __init__(self, dim_model: int, vocab_size: int, dropout: float = 0.0):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ self.word_embeddings = nn.Embedding(vocab_size, dim_model)
+
+ @property
+ def weight(self) -> torch.Tensor:
+ return self.word_embeddings.weight
+
+ def forward(self, x: torch.Tensor):
+ x = self.word_embeddings(x)
+ x = self.dropout(x)
+ return x
diff --git a/modules/flow/modules.py b/modules/flow/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4d897a4ae151a90d4a1fb886108903511ff14b7
--- /dev/null
+++ b/modules/flow/modules.py
@@ -0,0 +1,457 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/jaywalnut310/vits/
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.nn import Conv1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+from utils.util import *
+from modules.transformer.transforms import (
+ piecewise_rational_quadratic_transform,
+)
+from modules.base.base_module import LayerNorm
+
+LRELU_SLOPE = 0.1
+
+
+class DDSConv(nn.Module):
+ """
+ Dialted and Depth-Separable Convolution
+ """
+
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.drop = nn.Dropout(p_dropout)
+ self.convs_sep = nn.ModuleList()
+ self.convs_1x1 = nn.ModuleList()
+ self.norms_1 = nn.ModuleList()
+ self.norms_2 = nn.ModuleList()
+ for i in range(n_layers):
+ dilation = kernel_size**i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs_sep.append(
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ )
+ )
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+ self.norms_1.append(LayerNorm(channels))
+ self.norms_2.append(LayerNorm(channels))
+
+ def forward(self, x, x_mask, g=None):
+ if g is not None:
+ x = x + g
+ for i in range(self.n_layers):
+ y = self.convs_sep[i](x * x_mask)
+ y = self.norms_1[i](y)
+ y = F.gelu(y)
+ y = self.convs_1x1[i](y)
+ y = self.norms_2[i](y)
+ y = F.gelu(y)
+ y = self.drop(y)
+ x = x + y
+ return x * x_mask
+
+
+class WN(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ p_dropout=0,
+ ):
+ super(WN, self).__init__()
+ assert kernel_size % 2 == 1
+ self.hidden_channels = hidden_channels
+ self.kernel_size = (kernel_size,)
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ cond_layer = torch.nn.Conv1d(
+ gin_channels, 2 * hidden_channels * n_layers, 1
+ )
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+ for i in range(n_layers):
+ dilation = dilation_rate**i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(
+ hidden_channels,
+ 2 * hidden_channels,
+ kernel_size,
+ dilation=dilation,
+ padding=padding,
+ )
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ torch.nn.utils.remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ torch.nn.utils.remove_weight_norm(l)
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c2(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+ self.convs.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ if x_mask is not None:
+ xt = xt * x_mask
+ xt = c(xt)
+ x = xt + x
+ if x_mask is not None:
+ x = x * x_mask
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Log(nn.Module):
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ if not reverse:
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+ else:
+ return x
+
+
+class ElementwiseAffine(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.m = nn.Parameter(torch.zeros(channels, 1))
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False,
+ ):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ self.enc = WN(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=p_dropout,
+ gin_channels=gin_channels,
+ )
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ return x
+
+
+class ConvFlow(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ filter_channels,
+ kernel_size,
+ n_layers,
+ num_bins=10,
+ tail_bound=5.0,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.num_bins = num_bins
+ self.tail_bound = tail_bound
+ self.half_channels = in_channels // 2
+
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
+ self.proj = nn.Conv1d(
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
+ )
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0)
+ h = self.convs(h, x_mask, g=g)
+ h = self.proj(h) * x_mask
+
+ b, c, t = x0.shape
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
+
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
+ self.filter_channels
+ )
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
+
+ x1, logabsdet = piecewise_rational_quadratic_transform(
+ x1,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=reverse,
+ tails="linear",
+ tail_bound=self.tail_bound,
+ )
+
+ x = torch.cat([x0, x1], 1) * x_mask
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
+ if not reverse:
+ return x, logdet
+ else:
+ return x
diff --git a/modules/general/__init__.py b/modules/general/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f68ee3911c2b31c32ad0bd0fcfb017d3743370f
--- /dev/null
+++ b/modules/general/__init__.py
@@ -0,0 +1,3 @@
+from .input_strategies import PromptedFeatures, PromptedPrecomputedFeatures
+from .scaling import BalancedDoubleSwish
+from .utils import Transpose
diff --git a/modules/general/input_strategies.py b/modules/general/input_strategies.py
new file mode 100644
index 0000000000000000000000000000000000000000..2edefd9ec9224a650c1002a416062c2ece37f553
--- /dev/null
+++ b/modules/general/input_strategies.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# This code is modified from
+# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/input_strategies.py
+import random
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from typing import Tuple, Type
+
+from lhotse import CutSet
+from lhotse.dataset.collation import collate_features
+from lhotse.dataset.input_strategies import (
+ ExecutorType,
+ PrecomputedFeatures,
+ _get_executor,
+)
+from lhotse.utils import fastcopy
+
+
+class PromptedFeatures:
+ def __init__(self, prompts, features):
+ self.prompts = prompts
+ self.features = features
+
+ def to(self, device):
+ return PromptedFeatures(self.prompts.to(device), self.features.to(device))
+
+ def sum(self):
+ return self.features.sum()
+
+ @property
+ def ndim(self):
+ return self.features.ndim
+
+ @property
+ def data(self):
+ return (self.prompts, self.features)
+
+
+class PromptedPrecomputedFeatures(PrecomputedFeatures):
+ def __init__(
+ self,
+ dataset: str,
+ cuts: CutSet,
+ num_workers: int = 0,
+ executor_type: Type[ExecutorType] = ThreadPoolExecutor,
+ ) -> None:
+ super().__init__(num_workers, executor_type)
+ self.utt2neighbors = self._create_utt2neighbors(dataset, cuts)
+
+ def __call__(self, cuts: CutSet) -> Tuple[PromptedFeatures, PromptedFeatures]:
+ features, features_lens = self._collate_features(cuts)
+ prompts, prompts_lens = self._collate_prompts(cuts)
+ return PromptedFeatures(prompts, features), PromptedFeatures(
+ prompts_lens, features_lens
+ )
+
+ def _create_utt2neighbors(self, dataset, cuts):
+ utt2neighbors = defaultdict(lambda: [])
+ utt2cut = {cut.id: cut for cut in cuts}
+ if dataset.lower() == "libritts":
+ self._process_libritts_dataset(utt2neighbors, utt2cut, cuts)
+ elif dataset.lower() == "ljspeech":
+ self._process_ljspeech_dataset(utt2neighbors, utt2cut, cuts)
+ else:
+ raise ValueError("Unsupported dataset")
+ return utt2neighbors
+
+ def _process_libritts_dataset(self, utt2neighbors, utt2cut, cuts):
+ speaker2utts = defaultdict(lambda: [])
+ for cut in cuts:
+ speaker = cut.supervisions[0].speaker
+ speaker2utts[speaker].append(cut.id)
+
+ for spk, uttids in speaker2utts.items():
+ sorted_uttids = sorted(uttids)
+ if len(sorted_uttids) == 1:
+ utt2neighbors[sorted_uttids[0]].append(utt2cut[sorted_uttids[0]])
+ continue
+
+ utt2prevutt = dict(
+ zip(sorted_uttids, [sorted_uttids[1]] + sorted_uttids[:-1])
+ )
+ utt2postutt = dict(zip(sorted_uttids[:-1], sorted_uttids[1:]))
+ for utt in sorted_uttids:
+ if utt in utt2prevutt:
+ utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
+ if utt in utt2postutt:
+ utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
+
+ def _process_ljspeech_dataset(self, utt2neighbors, utt2cut, cuts):
+ uttids = [cut.id for cut in cuts]
+ if len(uttids) == 1:
+ utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
+ return
+
+ utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
+ utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
+ for utt in uttids:
+ prevutt, postutt = utt2prevutt.get(utt), utt2postutt.get(utt)
+ if prevutt and utt[:5] == prevutt[:5]:
+ utt2neighbors[utt].append(utt2cut[prevutt])
+ if postutt and utt[:5] == postutt[:5]:
+ utt2neighbors[utt].append(utt2cut[postutt])
+
+ def _collate_features(self, cuts):
+ return collate_features(
+ cuts,
+ executor=_get_executor(self.num_workers, executor_type=self._executor_type),
+ )
+
+ def _collate_prompts(self, cuts):
+ prompts_cuts = []
+ for k, cut in enumerate(cuts):
+ prompts_cut = random.choice(self.utt2neighbors[cut.id])
+ prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
+
+ mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
+ prompts_cuts = CutSet(
+ cuts={k: cut for k, cut in enumerate(prompts_cuts)}
+ ).truncate(max_duration=mini_duration, offset_type="random", preserve_id=False)
+
+ return collate_features(
+ prompts_cuts,
+ executor=_get_executor(self.num_workers, executor_type=self._executor_type),
+ )
diff --git a/modules/general/scaling.py b/modules/general/scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef858ffc65f7bcefedc0040f3377965f9eb3992
--- /dev/null
+++ b/modules/general/scaling.py
@@ -0,0 +1,1349 @@
+# This module is modified from https://github.com/Plachtaa/VALL-E-X/blob/3faaf8ccadb154d63b38070caf518ce9309ea0f4/modules/scaling.py
+
+
+import logging
+import random
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class Transpose(nn.Identity):
+ """(N, T, D) -> (N, D, T)"""
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.transpose(1, 2)
+
+
+class ActivationBalancerFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ scale_factor: Tensor,
+ sign_factor: Optional[Tensor],
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ xgt0 = x > 0
+ if sign_factor is None:
+ ctx.save_for_backward(xgt0, scale_factor)
+ else:
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+ if len(ctx.saved_tensors) == 3:
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+ scale_factor = scale_factor.unsqueeze(-1)
+ sign_factor = sign_factor.unsqueeze(-1)
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+ else:
+ xgt0, scale_factor = ctx.saved_tensors
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+ scale_factor = scale_factor.unsqueeze(-1)
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+ neg_delta_grad = x_grad.abs() * factor
+ return (
+ x_grad - neg_delta_grad,
+ None,
+ None,
+ None,
+ )
+
+
+def _compute_scale_factor(
+ x: Tensor,
+ channel_dim: int,
+ min_abs: float,
+ max_abs: float,
+ gain_factor: float,
+ max_factor: float,
+) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+
+ if min_abs == 0.0:
+ below_threshold = 0.0
+ else:
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
+ # x_abs)_mean , min_abs.
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
+ min=0, max=max_factor
+ )
+
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
+ min=0, max=max_factor
+ )
+
+ return below_threshold - above_threshold
+
+
+def _compute_sign_factor(
+ x: Tensor,
+ channel_dim: int,
+ min_positive: float,
+ max_positive: float,
+ gain_factor: float,
+ max_factor: float,
+) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
+ if min_positive == 0.0:
+ factor1 = 0.0
+ else:
+ # 0 if proportion_positive >= min_positive, else can be
+ # as large as max_factor.
+ factor1 = (
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
+ ).clamp_(min=0, max=max_factor)
+
+ if max_positive == 1.0:
+ factor2 = 0.0
+ else:
+ # 0 if self.proportion_positive <= max_positive, else can be
+ # as large as -max_factor.
+ factor2 = (
+ (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
+ ).clamp_(min=0, max=max_factor)
+ sign_factor = factor1 - factor2
+ # require min_positive != 0 or max_positive != 1:
+ assert not isinstance(sign_factor, float)
+ return sign_factor
+
+
+class ActivationScaleBalancerFunction(torch.autograd.Function):
+ """
+ This object is used in class ActivationBalancer when the user specified
+ min_positive=0, max_positive=1, so there are no constraints on the signs
+ of the activations and only the absolute value has a constraint.
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ sign_factor: Tensor,
+ scale_factor: Tensor,
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ xgt0 = x > 0
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+ sign_factor = sign_factor.unsqueeze(-1)
+ scale_factor = scale_factor.unsqueeze(-1)
+
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+ neg_delta_grad = x_grad.abs() * factor
+ return (
+ x_grad - neg_delta_grad,
+ None,
+ None,
+ None,
+ )
+
+
+class RandomClampFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ min: Optional[float],
+ max: Optional[float],
+ prob: float,
+ reflect: float,
+ ) -> Tensor:
+ x_clamped = torch.clamp(x, min=min, max=max)
+ mask = torch.rand_like(x) < prob
+ ans = torch.where(mask, x_clamped, x)
+ if x.requires_grad:
+ ctx.save_for_backward(ans == x)
+ ctx.reflect = reflect
+ if reflect != 0.0:
+ ans = ans * (1.0 + reflect) - (x * reflect)
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
+ (is_same,) = ctx.saved_tensors
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
+ reflect = ctx.reflect
+ if reflect != 0.0:
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+ return x_grad, None, None, None, None
+
+
+def random_clamp(
+ x: Tensor,
+ min: Optional[float] = None,
+ max: Optional[float] = None,
+ prob: float = 0.5,
+ reflect: float = 0.0,
+):
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
+
+
+def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
+ """
+ A randomized way of casting a floating point value to half precision.
+ """
+ if x.dtype == torch.float16:
+ return x
+ x_abs = x.abs()
+ is_too_small = x_abs < min_abs
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
+ # for those elements].
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class RandomGradFunction(torch.autograd.Function):
+ """
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
+ randomized approach that preserves expectations (intended to reduce roundoff).
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+ ctx.min_abs = min_abs
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+ if ans_grad.dtype == torch.float16:
+ return (
+ random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs),
+ None,
+ )
+ else:
+ return ans_grad, None
+
+
+class RandomGrad(torch.nn.Module):
+ """
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
+ accuracy of training when using amp (automatic mixed precision)
+ """
+
+ def __init__(self, min_abs: float = 5.0e-06):
+ super(RandomGrad, self).__init__()
+ self.min_abs = min_abs
+
+ def forward(self, x: Tensor):
+ if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
+ return x
+ else:
+ return RandomGradFunction.apply(x, self.min_abs)
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ """
+ Tries to handle half-precision derivatives in a randomized way that should
+ be more accurate for training than the default behavior.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, dim: int):
+ ans = x.softmax(dim=dim)
+ # if x dtype is float16, x.softmax() returns a float32 because
+ # (presumably) that op does not support float16, and autocast
+ # is enabled.
+ if torch.is_autocast_enabled():
+ ans = ans.to(torch.float16)
+ ctx.save_for_backward(ans)
+ ctx.x_dtype = x.dtype
+ ctx.dim = dim
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ (ans,) = ctx.saved_tensors
+ with torch.cuda.amp.autocast(enabled=False):
+ ans_grad = ans_grad.to(torch.float32)
+ ans = ans.to(torch.float32)
+ x_grad = ans_grad * ans
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+ return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x.softmax(dim)
+
+ return SoftmaxFunction.apply(x, dim)
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ coeffs: Tensor,
+ direction: Tensor,
+ channel_dim: int,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.channel_dim = channel_dim
+ ctx.grad_scale = grad_scale
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad, *args):
+ with torch.enable_grad():
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
+ x_orig.requires_grad = True
+ num_channels = x_orig.shape[ctx.channel_dim]
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+ new_direction.requires_grad = False
+ x = x - x.mean(dim=0)
+ x_var = (x**2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual**2).mean()
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction. This is to be minimized.
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+ variance_proportion.backward()
+ x_orig_grad = x_orig.grad
+ x_extra_grad = (
+ x_orig.grad
+ * ctx.grad_scale
+ * x_grad.norm()
+ / (x_orig_grad.norm() + 1.0e-20)
+ )
+ return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BasicNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on, is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+
+ So the idea is to introduce this large constant value as an explicit
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
+ doesn't have to do this trick. We make the "eps" learnable.
+
+ Args:
+ num_channels: the number of channels, e.g. 512.
+ channel_dim: the axis/dimension corresponding to the channel,
+ interprted as an offset from the input's ndim if negative.
+ shis is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ eps: the initial "epsilon" that we add as ballast in:
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
+ Note: our epsilon is actually large, but we keep the name
+ to indicate the connection with conventional LayerNorm.
+ learn_eps: if true, we learn epsilon; if false, we keep it
+ at the initial value.
+ eps_min: float
+ eps_max: float
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ eps: float = 0.25,
+ learn_eps: bool = True,
+ eps_min: float = -3.0,
+ eps_max: float = 3.0,
+ ) -> None:
+ super(BasicNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ if learn_eps:
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+ else:
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
+ self.eps_min = eps_min
+ self.eps_max = eps_max
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.channel_dim] == self.num_channels
+ eps = self.eps
+ if self.training and random.random() < 0.25:
+ # with probability 0.25, in training mode, clamp eps between the min
+ # and max; this will encourage it to learn parameters within the
+ # allowed range by making parameters that are outside the allowed
+ # range noisy.
+
+ # gradients to allow the parameter to get back into the allowed
+ # region if it happens to exit it.
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+ scales = (
+ torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp()
+ ) ** -0.5
+ return x * scales
+
+
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+ """
+ Behaves like a constructor of a modified version of nn.Linear
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Linear(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+def ScaledConv1d(
+ *args,
+ initial_scale: float = 1.0,
+ kernel_size: int = 3,
+ padding: str = "same",
+ **kwargs,
+) -> nn.Conv1d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv1d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+def TransposeScaledConv1d(
+ *args,
+ initial_scale: float = 1.0,
+ kernel_size: int = 3,
+ padding: str = "same",
+ **kwargs,
+) -> nn.Sequential:
+ """
+ Transpose -> ScaledConv1d
+ """
+ return nn.Sequential(
+ Transpose(),
+ ScaledConv1d(
+ *args,
+ initial_scale=initial_scale,
+ kernel_size=kernel_size,
+ padding=padding,
+ **kwargs,
+ ),
+ )
+
+
+def ScaledConv1dTranspose(
+ *args,
+ initial_scale: float = 1.0,
+ kernel_size: int = 3,
+ padding: str = "same",
+ **kwargs,
+) -> nn.Sequential:
+ """
+ Transpose -> ScaledConv1d
+ """
+ return nn.Sequential(
+ ScaledConv1d(
+ *args,
+ initial_scale=initial_scale,
+ kernel_size=kernel_size,
+ padding=padding,
+ **kwargs,
+ ),
+ Transpose(),
+ )
+
+
+def TransposeConv1d(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ Transpose -> Conv1d
+ """
+ return nn.Sequential(
+ Transpose(),
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ )
+
+
+def Conv1dTranspose(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ ScaledConv1d -> Transpose
+ """
+ return nn.Sequential(
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ Transpose(),
+ )
+
+
+class SRLinear(nn.Linear):
+ """https://arxiv.org/abs/2303.06296
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
+ """
+
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
+ self.register_buffer(
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
+ )
+ with torch.no_grad():
+ sigma = self.get_sigma()
+ self.register_buffer("spectral_norm", sigma)
+ self.sigma = nn.Parameter(torch.ones(1))
+
+ def get_sigma(self):
+ with torch.no_grad():
+ u = self.u
+ v = self.weight.mv(u)
+ v = nn.functional.normalize(v, dim=0)
+ u = self.weight.T.mv(v)
+ u = nn.functional.normalize(u, dim=0)
+ self.u.data.copy_(u)
+ return torch.einsum("c,cd,d->", v, self.weight, u)
+
+ def get_weight(self):
+ sigma = self.get_sigma()
+ if self.training:
+ self.spectral_norm.data.copy_(sigma)
+ weight = (self.sigma / sigma) * self.weight
+ return weight
+
+ def forward(self, x):
+ return nn.functional.linear(x, self.get_weight(), self.bias)
+
+
+class SRConv1d(SRLinear):
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ kernel_size,
+ stride: int = 1,
+ padding: str = "same",
+ bias: bool = True,
+ **kwargs,
+ ):
+ in_features = in_features * kernel_size
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+
+ def forward(self, x):
+ in_features = self.in_features // self.kernel_size
+ weight = self.get_weight().view(
+ self.out_features, in_features, self.kernel_size
+ )
+ return nn.functional.conv1d(
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
+ )
+
+
+def TransposeSRConv1d(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ Transpose -> SRConv1d
+ """
+ return nn.Sequential(
+ Transpose(),
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ )
+
+
+def SRConv1dTranspose(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ SRConv1d -> Transpose
+ """
+ return nn.Sequential(
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ Transpose(),
+ )
+
+
+class ActivationBalancer(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to encourage, for
+ each channel, that it is positive at least a proportion `threshold` of the
+ time. It does this by multiplying negative derivative values by up to
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
+ interpolated from 1 at the threshold to those extremal values when none
+ of the inputs are positive.
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+ max_factor: the maximum factor by which we modify the derivatives for
+ either the sign constraint or the magnitude constraint;
+ e.g. with max_factor=0.02, the the derivatives would be multiplied by
+ values in the range [0.98..1.02].
+ sign_gain_factor: determines the 'gain' with which we increase the
+ change in gradient once the constraints on min_positive and max_positive
+ are violated.
+ scale_gain_factor: determines the 'gain' with which we increase the
+ change in gradient once the constraints on min_abs and max_abs
+ are violated.
+ min_abs: the minimum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ max_abs: the maximum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ min_prob: determines the minimum probability with which we modify the
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
+ on each forward(). This is done randomly to prevent all layers
+ from doing it at the same time. Early in training we may use
+ higher probabilities than this; it will decay to this value.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ min_positive: float = 0.05,
+ max_positive: float = 0.95,
+ max_factor: float = 0.04,
+ sign_gain_factor: float = 0.01,
+ scale_gain_factor: float = 0.02,
+ min_abs: float = 0.2,
+ max_abs: float = 100.0,
+ min_prob: float = 0.1,
+ ):
+ super(ActivationBalancer, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.max_factor = max_factor
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ self.min_prob = min_prob
+ self.sign_gain_factor = sign_gain_factor
+ self.scale_gain_factor = scale_gain_factor
+
+ # count measures how many times the forward() function has been called.
+ # We occasionally sync this to a tensor called `count`, that exists to
+ # make sure it is synced to disk when we load and save the model.
+ self.cpu_count = 0
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
+
+ def forward(self, x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
+ return _no_op(x)
+
+ count = self.cpu_count
+ self.cpu_count += 1
+
+ if random.random() < 0.01:
+ # Occasionally sync self.cpu_count with self.count.
+ # count affects the decay of 'prob'. don't do this on every iter,
+ # because syncing with the GPU is slow.
+ self.cpu_count = max(self.cpu_count, self.count.item())
+ self.count.fill_(self.cpu_count)
+
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
+ # a floor at min_prob (==0.1, by default)
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
+
+ if random.random() < prob:
+ sign_gain_factor = 0.5
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
+ sign_factor = _compute_sign_factor(
+ x,
+ self.channel_dim,
+ self.min_positive,
+ self.max_positive,
+ gain_factor=self.sign_gain_factor / prob,
+ max_factor=self.max_factor,
+ )
+ else:
+ sign_factor = None
+
+ scale_factor = _compute_scale_factor(
+ x.detach(),
+ self.channel_dim,
+ min_abs=self.min_abs,
+ max_abs=self.max_abs,
+ gain_factor=self.scale_gain_factor / prob,
+ max_factor=self.max_factor,
+ )
+ return ActivationBalancerFunction.apply(
+ x,
+ scale_factor,
+ sign_factor,
+ self.channel_dim,
+ )
+ else:
+ return _no_op(x)
+
+
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+ """
+ Returns x unmodified, but in backprop will put a penalty for the excess of
+ the absolute values of elements of x over the limit "limit". E.g. if
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+ Caution: the value of this penalty will be affected by grad scaling used
+ in automatic mixed precision training. For this reasons we use this,
+ it shouldn't really matter, or may even be helpful; we just use this
+ to disallow really implausible values of scores to be given to softmax.
+ """
+ x_sign = x.sign()
+ over_limit = (x.abs() - limit) > 0
+ # The following is a memory efficient way to penalize the absolute values of
+ # x that's over the limit. (The memory efficiency comes when you think
+ # about which items torch needs to cache for the autograd, and which ones it
+ # can throw away). The numerical value of aux_loss as computed here will
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+ # limit).relu().
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+ # note: we don't do sum() here on aux)_loss, but it's as if we had done
+ # sum() due to how with_loss() works.
+ x = with_loss(x, aux_loss)
+ # you must use x for something, or this will be ineffective.
+ return x
+
+
+def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
+ if x.ndim == 2:
+ return x.diag()
+ else:
+ (batch, dim, dim) = x.shape
+ x = x.reshape(batch, dim * dim)
+ x = x[:, :: dim + 1]
+ assert x.shape == (batch, dim)
+ return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+ """
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
+ of the centered feature covariance are the same within each group's covariance matrix
+ and also between groups.
+ Args:
+ x: a Tensor of shape (*, num_channels)
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
+ Returns:
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+ greater than 1.0 otherwise.
+ """
+ assert x.dtype != torch.float16
+ x = x.reshape(-1, x.shape[-1])
+ (num_frames, num_channels) = x.shape
+ assert num_channels % num_groups == 0
+ channels_per_group = num_channels // num_groups
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+ # x now has shape (num_groups, num_frames, channels_per_group)
+ # subtract the mean so we use the centered, not uncentered, covariance.
+ # My experience has been that when we "mess with the gradients" like this,
+ # it's better not do anything that tries to move the mean around, because
+ # that can easily cause instability.
+ x = x - x.mean(dim=1, keepdim=True)
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
+ x_covar = torch.matmul(x.transpose(1, 2), x)
+ x_covar_mean_diag = _diag(x_covar).mean()
+ # the following expression is what we'd get if we took the matrix product
+ # of each covariance and measured the mean of its trace, i.e.
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
+ return metric
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ num_groups: int,
+ whitening_limit: float,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.save_for_backward(x)
+ ctx.num_groups = num_groups
+ ctx.whitening_limit = whitening_limit
+ ctx.grad_scale = grad_scale
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x_orig,) = ctx.saved_tensors
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x_detached = x_orig.to(torch.float32).detach()
+ x_detached.requires_grad = True
+
+ metric = _whitening_metric(x_detached, ctx.num_groups)
+
+ if random.random() < 0.005 or __name__ == "__main__":
+ logging.info(
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
+ )
+
+ (metric - ctx.whitening_limit).relu().backward()
+ penalty_grad = x_detached.grad
+ scale = ctx.grad_scale * (
+ x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
+ )
+ penalty_grad = penalty_grad * scale
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+
+
+class Whiten(nn.Module):
+ def __init__(
+ self,
+ num_groups: int,
+ whitening_limit: float,
+ prob: Union[float, Tuple[float, float]],
+ grad_scale: float,
+ ):
+ """
+ Args:
+ num_groups: the number of groups to divide the channel dim into before
+ whitening. We will attempt to make the feature covariance
+ within each group, after mean subtraction, as "white" as possible,
+ while having the same trace across all groups.
+ whitening_limit: a value greater than 1.0, that dictates how much
+ freedom we have to violate the constraints. 1.0 would mean perfectly
+ white, with exactly the same trace across groups; larger values
+ give more freedom. E.g. 2.0.
+ prob: the probability with which we apply the gradient modification
+ (also affects the grad scale). May be supplied as a float,
+ or as a pair (min_prob, max_prob)
+
+ grad_scale: determines the scale on the gradient term from this object,
+ relative to the rest of the gradient on the attention weights.
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
+ """
+ super(Whiten, self).__init__()
+ assert num_groups >= 1
+ assert whitening_limit >= 1
+ assert grad_scale >= 0
+ self.num_groups = num_groups
+ self.whitening_limit = whitening_limit
+ if isinstance(prob, float):
+ assert 0 < prob <= 1
+ self.prob = prob
+ else:
+ (self.min_prob, self.max_prob) = prob
+ assert 0 < self.min_prob < self.max_prob <= 1
+ self.prob = self.max_prob
+
+ self.grad_scale = grad_scale
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ In the forward pass, this function just returns the input unmodified.
+ In the backward pass, it will modify the gradients to ensure that the
+ distribution in each group has close to (lambda times I) as the covariance
+ after mean subtraction, with the same lambda across groups.
+ For whitening_limit > 1, there will be more freedom to violate this
+ constraint.
+
+ Args:
+ x: the input of shape (*, num_channels)
+
+ Returns:
+ x, unmodified. You should make sure
+ you use the returned value, or the graph will be freed
+ and nothing will happen in backprop.
+ """
+ if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+ return _no_op(x)
+ else:
+ if hasattr(self, "min_prob") and random.random() < 0.25:
+ # occasionally switch between min_prob and max_prob, based on whether
+ # we are above or below the threshold.
+ if (
+ _whitening_metric(x.to(torch.float32), self.num_groups)
+ > self.whitening_limit
+ ):
+ # there would be a change to the grad.
+ self.prob = self.max_prob
+ else:
+ self.prob = self.min_prob
+
+ return WhiteningPenaltyFunction.apply(
+ x, self.num_groups, self.whitening_limit, self.grad_scale
+ )
+
+
+class WithLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, y: Tensor):
+ ctx.y_shape = y.shape
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ return ans_grad, torch.ones(
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
+ )
+
+
+def with_loss(x, y):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ # returns x but adds y.sum() to the loss function.
+ return WithLoss.apply(x, y)
+
+
+def _no_op(x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ else:
+ # a no-op function that will have a node in the autograd graph,
+ # to avoid certain bugs relating to backward hooks
+ return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return _no_op(x)
+
+
+class MaxEig(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to discourage
+ that any given direction in activation space accounts for more than
+ a specified proportion of the covariance (e.g. 0.2).
+
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ max_var_per_eig: the maximum proportion of the variance of the
+ features/channels, after mean subtraction, that can come from
+ any given eigenvalue.
+ min_prob: the minimum probability with which we apply this during any invocation
+ of forward(), assuming last time we applied the constraint it was
+ not active; supplied for speed.
+ scale: determines the scale with which we modify the gradients, relative
+ to the existing / unmodified gradients
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ max_var_per_eig: float = 0.2,
+ min_prob: float = 0.01,
+ scale: float = 0.01,
+ ):
+ super(MaxEig, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.scale = scale
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+ self.max_var_per_eig = max_var_per_eig
+
+ # we figure out the dominant direction using the power method: starting with
+ # a random vector, keep multiplying by the covariance and renormalizing.
+ with torch.no_grad():
+ # arbitrary.. would use randn() but want to leave the rest of the model's
+ # random parameters unchanged for comparison
+ direction = torch.arange(num_channels).to(torch.float)
+ direction = direction / direction.norm()
+ self.register_buffer("max_eig_direction", direction)
+
+ self.min_prob = min_prob
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
+ # We'll regress this towards prob, each tiem we try to apply it and it is not
+ # active.
+ self.cur_prob = 1.0
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or self.max_var_per_eig <= 0
+ or random.random() > self.cur_prob
+ or torch.jit.is_tracing()
+ ):
+ return _no_op(x)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ eps = 1.0e-20
+ orig_x = x
+ x = x.to(torch.float32)
+ with torch.no_grad():
+ x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
+ x = x - x.mean(dim=0)
+ new_direction, coeffs = self._find_direction_coeffs(
+ x, self.max_eig_direction
+ )
+ x_var = (x**2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual**2).mean()
+
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction.
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
+ self._set_direction(0.1 * self.max_eig_direction + new_direction)
+
+ if random.random() < 0.01 or __name__ == "__main__":
+ logging.info(
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
+ )
+
+ if variance_proportion >= self.max_var_per_eig:
+ # The constraint is active. Note, we should quite rarely
+ # reach here, only near the beginning of training if we are
+ # starting to diverge, should this constraint be active.
+ cur_prob = self.cur_prob
+ self.cur_prob = 1.0 # next time, do the update with probability 1.0.
+ return MaxEigLimiterFunction.apply(
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
+ )
+ else:
+ # let self.cur_prob exponentially approach self.min_prob, as
+ # long as the constraint is inactive.
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
+ return orig_x
+
+ def _set_direction(self, direction: Tensor):
+ """
+ Sets self.max_eig_direction to a normalized version of `direction`
+ """
+ direction = direction.detach()
+ direction = direction / direction.norm()
+ direction_sum = direction.sum().item()
+ if direction_sum - direction_sum == 0: # no inf/nan
+ self.max_eig_direction[:] = direction
+ else:
+ logging.info(
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
+ )
+
+ def _find_direction_coeffs(
+ self, x: Tensor, prev_direction: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Figure out (an approximation to) the proportion of the variance of a set of
+ feature vectors that can be attributed to the top eigen-direction.
+ Args:
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
+ of the top eigen-direction, or a random direction if this is the first
+ iteration. Does not have to be normalized, but should be nonzero.
+
+ Returns: (cur_direction, coeffs), where:
+ cur_direction: a Tensor of shape (num_channels,) that is the current
+ estimate of the top eigen-direction.
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+ approximately minimizes, (x - coeffs * cur_direction).norm()
+ """
+ (num_frames, num_channels) = x.shape
+ assert num_channels > 1 and num_frames > 1
+ assert prev_direction.shape == (num_channels,)
+ # `coeffs` are the coefficients of `prev_direction` in x.
+ # actually represent the coeffs up to a constant positive factor.
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+ cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)
+ return cur_direction, coeffs
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+ """
+ double_swish(x) = x * torch.sigmoid(x-1)
+ This is a definition, originally motivated by its close numerical
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+ Memory-efficient derivative computation:
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+ Now, s'(x) = s(x) * (1-s(x)).
+ double_swish'(x) = x * s'(x) + s(x).
+ = x * s(x) * (1-s(x)) + s(x).
+ = double_swish(x) * (1-s(x)) + s(x)
+ ... so we just need to remember s(x) but not x itself.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ x_dtype = x.dtype
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+
+ if requires_grad:
+ deriv = y * (1 - s) + s
+ # notes on derivative of x * sigmoid(x - 1):
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
+ # floors), should be expectation-preserving.
+ floor = -0.043637
+ ceil = 1.2
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ deriv
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.043637
+ ceil = 1.2
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class DoubleSwish(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ return DoubleSwishFunction.apply(x)
+
+
+def BalancedDoubleSwish(
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
+) -> nn.Sequential:
+ """
+ ActivationBalancer -> DoubleSwish
+ """
+ balancer = ActivationBalancer(
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
+ )
+ return nn.Sequential(
+ balancer,
+ DoubleSwish(),
+ )
+
+
+def _test_max_eig():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ num_channels = 128
+ m = MaxEig(
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
+ ) # grad_scale
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_whiten():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"_test_whiten(): proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ num_channels = 128
+ m = Whiten(
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
+ ) # grad_scale
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_activation_balancer_sign():
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
+ x = x.detach()
+ x.requires_grad = True
+ m = ActivationBalancer(
+ probs.numel(),
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ max_factor=0.2,
+ min_abs=0.0,
+ )
+
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_activation_balancer_sign: x = ", x)
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
+
+
+def _test_activation_balancer_magnitude():
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+ x = x.detach()
+ x.requires_grad = True
+ m = ActivationBalancer(
+ magnitudes.numel(),
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ max_factor=0.2,
+ min_abs=0.2,
+ max_abs=0.8,
+ min_prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_activation_balancer_magnitude: x = ", x)
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_basic_norm():
+ num_channels = 128
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
+
+ x = torch.randn(500, num_channels)
+
+ y = m(x)
+
+ assert y.shape == x.shape
+ x_rms = (x**2).mean().sqrt()
+ y_rms = (y**2).mean().sqrt()
+ print("x rms = ", x_rms)
+ print("y rms = ", y_rms)
+ assert y_rms < x_rms
+ assert y_rms > 0.5 * x_rms
+
+
+def _test_double_swish_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = DoubleSwish()
+
+ tol = (1.2 - (-0.043637)) / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_softmax():
+ a = torch.randn(2, 10, dtype=torch.float64)
+ b = a.clone()
+ a.requires_grad = True
+ b.requires_grad = True
+ a.softmax(dim=1)[:, 0].sum().backward()
+ print("a grad = ", a.grad)
+ softmax(b, dim=1)[:, 0].sum().backward()
+ print("b grad = ", b.grad)
+ assert torch.allclose(a.grad, b.grad)
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_softmax()
+ _test_whiten()
+ _test_max_eig()
+ _test_activation_balancer_sign()
+ _test_activation_balancer_magnitude()
+ _test_basic_norm()
+ _test_double_swish_deriv()
diff --git a/modules/general/utils.py b/modules/general/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c94dfbb9859661322f3d2e8204b423fffae736
--- /dev/null
+++ b/modules/general/utils.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def normalization(channels: int, groups: int = 32):
+ r"""Make a standard normalization layer, i.e. GroupNorm.
+
+ Args:
+ channels: number of input channels.
+ groups: number of groups for group normalization.
+
+ Returns:
+ a ``nn.Module`` for normalization.
+ """
+ assert groups > 0, f"invalid number of groups: {groups}"
+ return nn.GroupNorm(groups, channels)
+
+
+def Linear(*args, **kwargs):
+ r"""Wrapper of ``nn.Linear`` with kaiming_normal_ initialization."""
+ layer = nn.Linear(*args, **kwargs)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+def Conv1d(*args, **kwargs):
+ r"""Wrapper of ``nn.Conv1d`` with kaiming_normal_ initialization."""
+ layer = nn.Conv1d(*args, **kwargs)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+def Conv2d(*args, **kwargs):
+ r"""Wrapper of ``nn.Conv2d`` with kaiming_normal_ initialization."""
+ layer = nn.Conv2d(*args, **kwargs)
+ nn.init.kaiming_normal_(layer.weight)
+ return layer
+
+
+def ConvNd(dims: int = 1, *args, **kwargs):
+ r"""Wrapper of N-dimension convolution with kaiming_normal_ initialization.
+
+ Args:
+ dims: number of dimensions of the convolution.
+ """
+ if dims == 1:
+ return Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return Conv2d(*args, **kwargs)
+ else:
+ raise ValueError(f"invalid number of dimensions: {dims}")
+
+
+def zero_module(module: nn.Module):
+ r"""Zero out the parameters of a module and return it."""
+ nn.init.zeros_(module.weight)
+ nn.init.zeros_(module.bias)
+ return module
+
+
+def scale_module(module: nn.Module, scale):
+ r"""Scale the parameters of a module and return it."""
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor: torch.Tensor):
+ r"""Take the mean over all non-batch dimensions."""
+ return tensor.mean(dim=tuple(range(1, tensor.dim())))
+
+
+def append_dims(x, target_dims):
+ r"""Appends dimensions to the end of a tensor until
+ it has target_dims dimensions.
+ """
+ dims_to_append = target_dims - x.dim()
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.dim()} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def append_zero(x, count=1):
+ r"""Appends ``count`` zeros to the end of a tensor along the last dimension."""
+ assert count > 0, f"invalid count: {count}"
+ return torch.cat([x, x.new_zeros((*x.size()[:-1], count))], dim=-1)
+
+
+class Transpose(nn.Identity):
+ """(N, T, D) -> (N, D, T)"""
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.transpose(1, 2)
diff --git a/modules/monotonic_align/__init__.py b/modules/monotonic_align/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..784d9081ca1aa1355b893dbd91d159d6f911cb4b
--- /dev/null
+++ b/modules/monotonic_align/__init__.py
@@ -0,0 +1,21 @@
+# This code from https://github.com/jaywalnut310/vits/
+
+import numpy as np
+import torch
+from .monotonic_align.core import maximum_path_c
+
+
+def maximum_path(neg_cent, mask):
+ """Cython optimized version.
+ neg_cent: [b, t_t, t_s]
+ mask: [b, t_t, t_s]
+ """
+ device = neg_cent.device
+ dtype = neg_cent.dtype
+ neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
+ path = np.zeros(neg_cent.shape, dtype=np.int32)
+
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
diff --git a/modules/monotonic_align/core.pyx b/modules/monotonic_align/core.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..bfaabd4d21c2299cdd978f0cc0caefa20ad186e5
--- /dev/null
+++ b/modules/monotonic_align/core.pyx
@@ -0,0 +1,42 @@
+cimport cython
+from cython.parallel import prange
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
+ cdef int x
+ cdef int y
+ cdef float v_prev
+ cdef float v_cur
+ cdef float tmp
+ cdef int index = t_x - 1
+
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y-1, x]
+ if x == 0:
+ if y == 0:
+ v_prev = 0.
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y-1, x-1]
+ value[y, x] += max(v_prev, v_cur)
+
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
+ index = index - 1
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
+ cdef int b = paths.shape[0]
+ cdef int i
+ for i in prange(b, nogil=True):
+ maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/modules/monotonic_align/setup.py b/modules/monotonic_align/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..773fafd96fb0f1bac88037187542c84eccb26ef5
--- /dev/null
+++ b/modules/monotonic_align/setup.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/jaywalnut310/vits/
+
+from distutils.core import setup, Extension
+from Cython.Build import cythonize
+import numpy
+
+extension = Extension(
+ name="monotonic_align.core",
+ sources=["core.pyx"],
+ include_dirs=[numpy.get_include()],
+ # Define additional arguments if needed
+)
+
+setup(
+ name="monotonic_align",
+ ext_modules=cythonize([extension]),
+)
diff --git a/modules/naturalpseech2/transformers.py b/modules/naturalpseech2/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d094c465f97cab4063f3e9baf1e749b2d403e5a8
--- /dev/null
+++ b/modules/naturalpseech2/transformers.py
@@ -0,0 +1,514 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter
+import torch.nn.functional as F
+import numpy as np
+
+
+class StyleAdaptiveLayerNorm(nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5):
+ super().__init__()
+ self.in_dim = normalized_shape
+ self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
+ self.style = nn.Linear(self.in_dim, self.in_dim * 2)
+ self.style.bias.data[: self.in_dim] = 1
+ self.style.bias.data[self.in_dim :] = 0
+
+ def forward(self, x, condition):
+ # x: (B, T, d); condition: (B, T, d)
+
+ style = self.style(torch.mean(condition, dim=1, keepdim=True))
+
+ gamma, beta = style.chunk(2, -1)
+
+ out = self.norm(x)
+
+ out = gamma * out + beta
+ return out
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout, max_len=5000):
+ super().__init__()
+
+ self.dropout = dropout
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+ )
+ pe = torch.zeros(max_len, 1, d_model)
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ x = x + self.pe[: x.size(0)]
+ return F.dropout(x, self.dropout, training=self.training)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(
+ self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
+ ):
+ super().__init__()
+
+ self.encoder_hidden = encoder_hidden
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+
+ self.ffn_1 = nn.Conv1d(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ padding=self.conv_kernel_size // 2,
+ )
+ self.ffn_1.weight.data.normal_(0.0, 0.02)
+ self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
+ self.ffn_2.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x):
+ # x: (B, T, d)
+ x = self.ffn_1(x.permute(0, 2, 1)).permute(
+ 0, 2, 1
+ ) # (B, T, d) -> (B, d, T) -> (B, T, d)
+ x = F.relu(x)
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ encoder_hidden,
+ encoder_head,
+ conv_filter_size,
+ conv_kernel_size,
+ encoder_dropout,
+ use_cln,
+ ):
+ super().__init__()
+ self.encoder_hidden = encoder_hidden
+ self.encoder_head = encoder_head
+ self.conv_filter_size = conv_filter_size
+ self.conv_kernel_size = conv_kernel_size
+ self.encoder_dropout = encoder_dropout
+ self.use_cln = use_cln
+
+ if not self.use_cln:
+ self.ln_1 = nn.LayerNorm(self.encoder_hidden)
+ self.ln_2 = nn.LayerNorm(self.encoder_hidden)
+ else:
+ self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+ self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
+
+ self.self_attn = nn.MultiheadAttention(
+ self.encoder_hidden, self.encoder_head, batch_first=True
+ )
+
+ self.ffn = TransformerFFNLayer(
+ self.encoder_hidden,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ )
+
+ def forward(self, x, key_padding_mask, conditon=None):
+ # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
+
+ # self attention
+ residual = x
+ if self.use_cln:
+ x = self.ln_1(x, conditon)
+ else:
+ x = self.ln_1(x)
+
+ if key_padding_mask != None:
+ key_padding_mask_input = ~(key_padding_mask.bool())
+ else:
+ key_padding_mask_input = None
+ x, _ = self.self_attn(
+ query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
+ )
+ x = F.dropout(x, self.encoder_dropout, training=self.training)
+ x = residual + x
+
+ # ffn
+ residual = x
+ if self.use_cln:
+ x = self.ln_2(x, conditon)
+ else:
+ x = self.ln_2(x)
+ x = self.ffn(x)
+ x = residual + x
+
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ enc_emb_tokens=None,
+ encoder_layer=None,
+ encoder_hidden=None,
+ encoder_head=None,
+ conv_filter_size=None,
+ conv_kernel_size=None,
+ encoder_dropout=None,
+ use_cln=None,
+ cfg=None,
+ ):
+ super().__init__()
+
+ self.encoder_layer = (
+ encoder_layer if encoder_layer is not None else cfg.encoder_layer
+ )
+ self.encoder_hidden = (
+ encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
+ )
+ self.encoder_head = (
+ encoder_head if encoder_head is not None else cfg.encoder_head
+ )
+ self.conv_filter_size = (
+ conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
+ )
+ self.conv_kernel_size = (
+ conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
+ )
+ self.encoder_dropout = (
+ encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
+ )
+ self.use_cln = use_cln if use_cln is not None else cfg.use_cln
+
+ if enc_emb_tokens != None:
+ self.use_enc_emb = True
+ self.enc_emb_tokens = enc_emb_tokens
+ else:
+ self.use_enc_emb = False
+
+ self.position_emb = PositionalEncoding(
+ self.encoder_hidden, self.encoder_dropout
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend(
+ [
+ TransformerEncoderLayer(
+ self.encoder_hidden,
+ self.encoder_head,
+ self.conv_filter_size,
+ self.conv_kernel_size,
+ self.encoder_dropout,
+ self.use_cln,
+ )
+ for i in range(self.encoder_layer)
+ ]
+ )
+
+ if self.use_cln:
+ self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
+ else:
+ self.last_ln = nn.LayerNorm(self.encoder_hidden)
+
+ def forward(self, x, key_padding_mask, condition=None):
+ if len(x.shape) == 2 and self.use_enc_emb:
+ x = self.enc_emb_tokens(x)
+ x = self.position_emb(x)
+ else:
+ x = self.position_emb(x) # (B, T, d)
+
+ for layer in self.layers:
+ x = layer(x, key_padding_mask, condition)
+
+ if self.use_cln:
+ x = self.last_ln(x, condition)
+ else:
+ x = self.last_ln(x)
+
+ return x
+
+
+class DurationPredictor(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.input_size = cfg.input_size
+ self.filter_size = cfg.filter_size
+ self.kernel_size = cfg.kernel_size
+ self.conv_layers = cfg.conv_layers
+ self.cross_attn_per_layer = cfg.cross_attn_per_layer
+ self.attn_head = cfg.attn_head
+ self.drop_out = cfg.drop_out
+
+ self.conv = nn.ModuleList()
+ self.cattn = nn.ModuleList()
+
+ for idx in range(self.conv_layers):
+ in_dim = self.input_size if idx == 0 else self.filter_size
+ self.conv += [
+ nn.Sequential(
+ nn.Conv1d(
+ in_dim,
+ self.filter_size,
+ self.kernel_size,
+ padding=self.kernel_size // 2,
+ ),
+ nn.ReLU(),
+ nn.LayerNorm(self.filter_size),
+ nn.Dropout(self.drop_out),
+ )
+ ]
+ if idx % self.cross_attn_per_layer == 0:
+ self.cattn.append(
+ torch.nn.Sequential(
+ nn.MultiheadAttention(
+ self.filter_size,
+ self.attn_head,
+ batch_first=True,
+ kdim=self.filter_size,
+ vdim=self.filter_size,
+ ),
+ nn.LayerNorm(self.filter_size),
+ nn.Dropout(0.2),
+ )
+ )
+
+ self.linear = nn.Linear(self.filter_size, 1)
+ self.linear.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x, mask, ref_emb, ref_mask):
+ """
+ input:
+ x: (B, N, d)
+ mask: (B, N), mask is 0
+ ref_emb: (B, d, T')
+ ref_mask: (B, T'), mask is 0
+
+ output:
+ dur_pred: (B, N)
+ dur_pred_log: (B, N)
+ dur_pred_round: (B, N)
+ """
+
+ input_ref_mask = ~(ref_mask.bool()) # (B, T')
+ # print(input_ref_mask)
+
+ x = x.transpose(1, -1) # (B, N, d) -> (B, d, N)
+
+ for idx, (conv, act, ln, dropout) in enumerate(self.conv):
+ res = x
+ # print(torch.min(x), torch.max(x))
+ if idx % self.cross_attn_per_layer == 0:
+ attn_idx = idx // self.cross_attn_per_layer
+ attn, attn_ln, attn_drop = self.cattn[attn_idx]
+
+ attn_res = y_ = x.transpose(1, 2) # (B, d, N) -> (B, N, d)
+
+ y_ = attn_ln(y_)
+ # print(torch.min(y_), torch.min(y_))
+ # print(torch.min(ref_emb), torch.max(ref_emb))
+ y_, _ = attn(
+ y_,
+ ref_emb.transpose(1, 2),
+ ref_emb.transpose(1, 2),
+ key_padding_mask=input_ref_mask,
+ )
+ # y_, _ = attn(y_, ref_emb.transpose(1, 2), ref_emb.transpose(1, 2))
+ # print(torch.min(y_), torch.min(y_))
+ y_ = attn_drop(y_)
+ y_ = (y_ + attn_res) / math.sqrt(2.0)
+
+ x = y_.transpose(1, 2)
+
+ x = conv(x)
+ # print(torch.min(x), torch.max(x))
+ x = act(x)
+ x = ln(x.transpose(1, 2))
+ # print(torch.min(x), torch.max(x))
+ x = x.transpose(1, 2)
+
+ x = dropout(x)
+
+ if idx != 0:
+ x += res
+
+ if mask is not None:
+ x = x * mask.to(x.dtype)[:, None, :]
+
+ x = self.linear(x.transpose(1, 2))
+ x = torch.squeeze(x, -1)
+
+ dur_pred = x.exp() - 1
+ dur_pred_round = torch.clamp(torch.round(x.exp() - 1), min=0).long()
+
+ return {
+ "dur_pred_log": x,
+ "dur_pred": dur_pred,
+ "dur_pred_round": dur_pred_round,
+ }
+
+
+class PitchPredictor(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.input_size = cfg.input_size
+ self.filter_size = cfg.filter_size
+ self.kernel_size = cfg.kernel_size
+ self.conv_layers = cfg.conv_layers
+ self.cross_attn_per_layer = cfg.cross_attn_per_layer
+ self.attn_head = cfg.attn_head
+ self.drop_out = cfg.drop_out
+
+ self.conv = nn.ModuleList()
+ self.cattn = nn.ModuleList()
+
+ for idx in range(self.conv_layers):
+ in_dim = self.input_size if idx == 0 else self.filter_size
+ self.conv += [
+ nn.Sequential(
+ nn.Conv1d(
+ in_dim,
+ self.filter_size,
+ self.kernel_size,
+ padding=self.kernel_size // 2,
+ ),
+ nn.ReLU(),
+ nn.LayerNorm(self.filter_size),
+ nn.Dropout(self.drop_out),
+ )
+ ]
+ if idx % self.cross_attn_per_layer == 0:
+ self.cattn.append(
+ torch.nn.Sequential(
+ nn.MultiheadAttention(
+ self.filter_size,
+ self.attn_head,
+ batch_first=True,
+ kdim=self.filter_size,
+ vdim=self.filter_size,
+ ),
+ nn.LayerNorm(self.filter_size),
+ nn.Dropout(0.2),
+ )
+ )
+
+ self.linear = nn.Linear(self.filter_size, 1)
+ self.linear.weight.data.normal_(0.0, 0.02)
+
+ def forward(self, x, mask, ref_emb, ref_mask):
+ """
+ input:
+ x: (B, N, d)
+ mask: (B, N), mask is 0
+ ref_emb: (B, d, T')
+ ref_mask: (B, T'), mask is 0
+
+ output:
+ pitch_pred: (B, T)
+ """
+
+ input_ref_mask = ~(ref_mask.bool()) # (B, T')
+
+ x = x.transpose(1, -1) # (B, N, d) -> (B, d, N)
+
+ for idx, (conv, act, ln, dropout) in enumerate(self.conv):
+ res = x
+ if idx % self.cross_attn_per_layer == 0:
+ attn_idx = idx // self.cross_attn_per_layer
+ attn, attn_ln, attn_drop = self.cattn[attn_idx]
+
+ attn_res = y_ = x.transpose(1, 2) # (B, d, N) -> (B, N, d)
+
+ y_ = attn_ln(y_)
+ y_, _ = attn(
+ y_,
+ ref_emb.transpose(1, 2),
+ ref_emb.transpose(1, 2),
+ key_padding_mask=input_ref_mask,
+ )
+ # y_, _ = attn(y_, ref_emb.transpose(1, 2), ref_emb.transpose(1, 2))
+ y_ = attn_drop(y_)
+ y_ = (y_ + attn_res) / math.sqrt(2.0)
+
+ x = y_.transpose(1, 2)
+
+ x = conv(x)
+ x = act(x)
+ x = ln(x.transpose(1, 2))
+ x = x.transpose(1, 2)
+
+ x = dropout(x)
+
+ if idx != 0:
+ x += res
+
+ x = self.linear(x.transpose(1, 2))
+ x = torch.squeeze(x, -1)
+
+ return x
+
+
+def pad(input_ele, mel_max_length=None):
+ if mel_max_length:
+ max_len = mel_max_length
+ else:
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
+
+ out_list = list()
+ for i, batch in enumerate(input_ele):
+ if len(batch.shape) == 1:
+ one_batch_padded = F.pad(
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
+ )
+ elif len(batch.shape) == 2:
+ one_batch_padded = F.pad(
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
+ )
+ out_list.append(one_batch_padded)
+ out_padded = torch.stack(out_list)
+ return out_padded
+
+
+class LengthRegulator(nn.Module):
+ """Length Regulator"""
+
+ def __init__(self):
+ super(LengthRegulator, self).__init__()
+
+ def LR(self, x, duration, max_len):
+ device = x.device
+ output = list()
+ mel_len = list()
+ for batch, expand_target in zip(x, duration):
+ expanded = self.expand(batch, expand_target)
+ output.append(expanded)
+ mel_len.append(expanded.shape[0])
+
+ if max_len is not None:
+ output = pad(output, max_len)
+ else:
+ output = pad(output)
+
+ return output, torch.LongTensor(mel_len).to(device)
+
+ def expand(self, batch, predicted):
+ out = list()
+
+ for i, vec in enumerate(batch):
+ expand_size = predicted[i].item()
+ out.append(vec.expand(max(int(expand_size), 0), -1))
+ out = torch.cat(out, 0)
+
+ return out
+
+ def forward(self, x, duration, max_len):
+ output, mel_len = self.LR(x, duration, max_len)
+ return output, mel_len
diff --git a/modules/neural_source_filter/__init__.py b/modules/neural_source_filter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2749078427c312fa046f354ded858bae82f666
--- /dev/null
+++ b/modules/neural_source_filter/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sine_excitation import *
diff --git a/modules/neural_source_filter/sine_excitation.py b/modules/neural_source_filter/sine_excitation.py
new file mode 100644
index 0000000000000000000000000000000000000000..08b72b2fc4f5e3065e32b2891bdd7f510c727e70
--- /dev/null
+++ b/modules/neural_source_filter/sine_excitation.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+#################### NSF ####################
+
+import torch
+
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+# This code is adopted from Xin Wang's NSF under the MIT License
+# https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts
+
+
+class SineGen(nn.Module):
+ def __init__(
+ self, fs, harmonic_num=0, amp=0.1, noise_std=0.003, voiced_threshold=0
+ ):
+ super(SineGen, self).__init__()
+ self.amp = amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = harmonic_num + 1
+ self.fs = fs
+ self.voice_threshold = voiced_threshold
+
+ def _f0toUnvoiced(self, f0):
+ uv = torch.ones_like(f0)
+ uv = uv * (f0 > self.voice_threshold)
+ return uv
+
+ @torch.no_grad()
+ def forward(self, f0, upp):
+ f0 = f0.unsqueeze(-1)
+ fn = torch.multiply(
+ f0, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1)
+ )
+ rad_values = (fn / self.fs) % 1
+ rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+ is_half = rad_values.dtype is not torch.float32
+ tmp_over_one = torch.cumsum(rad_values.double(), 1)
+ if is_half:
+ tmp_over_one = tmp_over_one.half()
+ else:
+ tmp_over_one = tmp_over_one.float()
+ tmp_over_one *= upp
+ tmp_over_one = F.interpolate(
+ tmp_over_one.transpose(2, 1),
+ scale_factor=upp,
+ mode="linear",
+ align_corners=True,
+ ).transpose(2, 1)
+ rad_values = F.interpolate(
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
+ ).transpose(2, 1)
+ tmp_over_one %= 1
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
+ cumsum_shift = torch.zeros_like(rad_values)
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * (-1.0)
+ rad_values = rad_values.double()
+ cumsum_shift = cumsum_shift.double()
+ sine_waves = torch.sin(
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
+ )
+ if is_half:
+ sine_waves = sine_waves.half()
+ else:
+ sine_waves = sine_waves.float()
+ sine_waves = sine_waves * self.amp
+ uv = self._f0toUnvoiced(f0)
+ uv = F.interpolate(
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
+ ).transpose(2, 1)
+ noise_amp = uv * self.noise_std + (1 - uv) * self.amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
diff --git a/modules/norms/__init__.py b/modules/norms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c8ce8395657c349c2cf41a4d64eb20ceda7e53
--- /dev/null
+++ b/modules/norms/__init__.py
@@ -0,0 +1 @@
+from .norm import AdaptiveLayerNorm, LayerNorm, BalancedBasicNorm, IdentityNorm
diff --git a/modules/norms/norm.py b/modules/norms/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf6992f1a70250406ccd48ff132718f2af884a6
--- /dev/null
+++ b/modules/norms/norm.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+import numbers
+from typing import Any, List, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from modules.general.scaling import ActivationBalancer
+from modules.general.scaling import BasicNorm as _BasicNorm
+
+
+_shape_t = Union[int, List[int], torch.Size]
+
+
+class LayerNorm(nn.Module):
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
+ normalized_shape: Tuple[int, ...]
+ eps: float
+ elementwise_affine: bool
+
+ def __init__(
+ self,
+ normalized_shape: _shape_t,
+ eps: float = 1e-5,
+ elementwise_affine: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(LayerNorm, self).__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ self.normalized_shape = tuple(normalized_shape)
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = nn.Parameter(
+ torch.empty(self.normalized_shape, **factory_kwargs)
+ )
+ self.bias = nn.Parameter(
+ torch.empty(self.normalized_shape, **factory_kwargs)
+ )
+ else:
+ self.register_parameter("weight", None)
+ self.register_parameter("bias", None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ if self.elementwise_affine:
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ output = F.layer_norm(
+ input, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ return output, embedding
+
+ assert embedding is None
+ return F.layer_norm(
+ input, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+
+ def extra_repr(self) -> str:
+ return (
+ "{normalized_shape}, eps={eps}, "
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
+ )
+
+
+class AdaptiveLayerNorm(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNorm, self).__init__()
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return (weight * self.norm(input) + bias, embedding)
+
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return weight * self.norm(input) + bias
+
+
+class BasicNorm(_BasicNorm):
+ def __init__(
+ self,
+ d_model: int,
+ eps: float = 1e-5,
+ device=None,
+ dtype=None,
+ ):
+ super(BasicNorm, self).__init__(d_model, eps=eps)
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ return (
+ super(BasicNorm, self).forward(input),
+ embedding,
+ )
+
+ assert embedding is None
+ return super(BasicNorm, self).forward(input)
+
+
+class BalancedBasicNorm(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ eps: float = 1e-5,
+ device=None,
+ dtype=None,
+ ):
+ super(BalancedBasicNorm, self).__init__()
+ self.balancer = ActivationBalancer(
+ d_model,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ max_abs=6.0,
+ )
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ return self.norm((self.balancer(input), embedding))
+
+ assert embedding is None
+ return self.norm(self.balancer(input))
+
+
+class IdentityNorm(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ eps: float = 1e-5,
+ device=None,
+ dtype=None,
+ ) -> None:
+ super(IdentityNorm, self).__init__()
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ return input
+
+ assert embedding is None
+ return input
diff --git a/modules/transformer/Constants.py b/modules/transformer/Constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd586f3185114b7ab22f9a2479e93088d74ced0
--- /dev/null
+++ b/modules/transformer/Constants.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+PAD = 0
+UNK = 1
+BOS = 2
+EOS = 3
+
+PAD_WORD = ""
+UNK_WORD = ""
+BOS_WORD = ""
+EOS_WORD = ""
diff --git a/modules/transformer/Layers.py b/modules/transformer/Layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..915b86a7f434456f24643e28fd93813bd6e13d52
--- /dev/null
+++ b/modules/transformer/Layers.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
+
+
+class FFTBlock(torch.nn.Module):
+ """FFT Block"""
+
+ def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
+ super(FFTBlock, self).__init__()
+ self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
+ self.pos_ffn = PositionwiseFeedForward(
+ d_model, d_inner, kernel_size, dropout=dropout
+ )
+
+ def forward(self, enc_input, mask=None, slf_attn_mask=None):
+ enc_output, enc_slf_attn = self.slf_attn(
+ enc_input, enc_input, enc_input, mask=slf_attn_mask
+ )
+ enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
+
+ enc_output = self.pos_ffn(enc_output)
+ enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)
+
+ return enc_output, enc_slf_attn
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=None,
+ dilation=1,
+ bias=True,
+ w_init_gain="linear",
+ ):
+ super(ConvNorm, self).__init__()
+
+ if padding is None:
+ assert kernel_size % 2 == 1
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+
+ return conv_signal
+
+
+class PostNet(nn.Module):
+ """
+ PostNet: Five 1-d convolution with 512 channels and kernel size 5
+ """
+
+ def __init__(
+ self,
+ n_mel_channels=80,
+ postnet_embedding_dim=512,
+ postnet_kernel_size=5,
+ postnet_n_convolutions=5,
+ ):
+ super(PostNet, self).__init__()
+ self.convolutions = nn.ModuleList()
+
+ self.convolutions.append(
+ nn.Sequential(
+ ConvNorm(
+ n_mel_channels,
+ postnet_embedding_dim,
+ kernel_size=postnet_kernel_size,
+ stride=1,
+ padding=int((postnet_kernel_size - 1) / 2),
+ dilation=1,
+ w_init_gain="tanh",
+ ),
+ nn.BatchNorm1d(postnet_embedding_dim),
+ )
+ )
+
+ for i in range(1, postnet_n_convolutions - 1):
+ self.convolutions.append(
+ nn.Sequential(
+ ConvNorm(
+ postnet_embedding_dim,
+ postnet_embedding_dim,
+ kernel_size=postnet_kernel_size,
+ stride=1,
+ padding=int((postnet_kernel_size - 1) / 2),
+ dilation=1,
+ w_init_gain="tanh",
+ ),
+ nn.BatchNorm1d(postnet_embedding_dim),
+ )
+ )
+
+ self.convolutions.append(
+ nn.Sequential(
+ ConvNorm(
+ postnet_embedding_dim,
+ n_mel_channels,
+ kernel_size=postnet_kernel_size,
+ stride=1,
+ padding=int((postnet_kernel_size - 1) / 2),
+ dilation=1,
+ w_init_gain="linear",
+ ),
+ nn.BatchNorm1d(n_mel_channels),
+ )
+ )
+
+ def forward(self, x):
+ x = x.contiguous().transpose(1, 2)
+
+ for i in range(len(self.convolutions) - 1):
+ x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
+ x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
+
+ x = x.contiguous().transpose(1, 2)
+ return x
diff --git a/modules/transformer/Models.py b/modules/transformer/Models.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2bf3796da5324dbdacd50297563d135e08af6c
--- /dev/null
+++ b/modules/transformer/Models.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .Layers import FFTBlock
+from text.symbols import symbols
+
+PAD = 0
+UNK = 1
+BOS = 2
+EOS = 3
+
+PAD_WORD = ""
+UNK_WORD = ""
+BOS_WORD = ""
+EOS_WORD = ""
+
+
+def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
+ """Sinusoid position encoding table"""
+
+ def cal_angle(position, hid_idx):
+ return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
+
+ def get_posi_angle_vec(position):
+ return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+ sinusoid_table = np.array(
+ [get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
+ )
+
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+
+ if padding_idx is not None:
+ # zero vector for padding dimension
+ sinusoid_table[padding_idx] = 0.0
+
+ return torch.FloatTensor(sinusoid_table)
+
+
+class Encoder(nn.Module):
+ """Encoder"""
+
+ def __init__(self, config):
+ super(Encoder, self).__init__()
+
+ n_position = config["max_seq_len"] + 1
+ n_src_vocab = len(symbols) + 1
+ d_word_vec = config["transformer"]["encoder_hidden"]
+ n_layers = config["transformer"]["encoder_layer"]
+ n_head = config["transformer"]["encoder_head"]
+ d_k = d_v = (
+ config["transformer"]["encoder_hidden"]
+ // config["transformer"]["encoder_head"]
+ )
+ d_model = config["transformer"]["encoder_hidden"]
+ d_inner = config["transformer"]["conv_filter_size"]
+ kernel_size = config["transformer"]["conv_kernel_size"]
+ dropout = config["transformer"]["encoder_dropout"]
+
+ self.max_seq_len = config["max_seq_len"]
+ self.d_model = d_model
+
+ self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=PAD)
+ self.position_enc = nn.Parameter(
+ get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
+ requires_grad=False,
+ )
+
+ self.layer_stack = nn.ModuleList(
+ [
+ FFTBlock(
+ d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
+ )
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(self, src_seq, mask, return_attns=False):
+ enc_slf_attn_list = []
+ batch_size, max_len = src_seq.shape[0], src_seq.shape[1]
+
+ # -- Prepare masks
+ slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
+
+ # -- Forward
+ if not self.training and src_seq.shape[1] > self.max_seq_len:
+ enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table(
+ src_seq.shape[1], self.d_model
+ )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
+ src_seq.device
+ )
+ else:
+ enc_output = self.src_word_emb(src_seq) + self.position_enc[
+ :, :max_len, :
+ ].expand(batch_size, -1, -1)
+
+ for enc_layer in self.layer_stack:
+ enc_output, enc_slf_attn = enc_layer(
+ enc_output, mask=mask, slf_attn_mask=slf_attn_mask
+ )
+ if return_attns:
+ enc_slf_attn_list += [enc_slf_attn]
+
+ return enc_output
+
+
+class Decoder(nn.Module):
+ """Decoder"""
+
+ def __init__(self, config):
+ super(Decoder, self).__init__()
+
+ n_position = config["max_seq_len"] + 1
+ d_word_vec = config["transformer"]["decoder_hidden"]
+ n_layers = config["transformer"]["decoder_layer"]
+ n_head = config["transformer"]["decoder_head"]
+ d_k = d_v = (
+ config["transformer"]["decoder_hidden"]
+ // config["transformer"]["decoder_head"]
+ )
+ d_model = config["transformer"]["decoder_hidden"]
+ d_inner = config["transformer"]["conv_filter_size"]
+ kernel_size = config["transformer"]["conv_kernel_size"]
+ dropout = config["transformer"]["decoder_dropout"]
+
+ self.max_seq_len = config["max_seq_len"]
+ self.d_model = d_model
+
+ self.position_enc = nn.Parameter(
+ get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
+ requires_grad=False,
+ )
+
+ self.layer_stack = nn.ModuleList(
+ [
+ FFTBlock(
+ d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
+ )
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(self, enc_seq, mask, return_attns=False):
+ dec_slf_attn_list = []
+ batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]
+
+ # -- Forward
+ if not self.training and enc_seq.shape[1] > self.max_seq_len:
+ # -- Prepare masks
+ slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
+ dec_output = enc_seq + get_sinusoid_encoding_table(
+ enc_seq.shape[1], self.d_model
+ )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to(
+ enc_seq.device
+ )
+ else:
+ max_len = min(max_len, self.max_seq_len)
+
+ # -- Prepare masks
+ slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
+ dec_output = enc_seq[:, :max_len, :] + self.position_enc[
+ :, :max_len, :
+ ].expand(batch_size, -1, -1)
+ mask = mask[:, :max_len]
+ slf_attn_mask = slf_attn_mask[:, :, :max_len]
+
+ for dec_layer in self.layer_stack:
+ dec_output, dec_slf_attn = dec_layer(
+ dec_output, mask=mask, slf_attn_mask=slf_attn_mask
+ )
+ if return_attns:
+ dec_slf_attn_list += [dec_slf_attn]
+
+ return dec_output, mask
diff --git a/modules/transformer/Modules.py b/modules/transformer/Modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..339aedd0cd3ec603724204e98aced5372cd0a334
--- /dev/null
+++ b/modules/transformer/Modules.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class ScaledDotProductAttention(nn.Module):
+ """Scaled Dot-Product Attention"""
+
+ def __init__(self, temperature):
+ super().__init__()
+ self.temperature = temperature
+ self.softmax = nn.Softmax(dim=2)
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.bmm(q, k.transpose(1, 2))
+ attn = attn / self.temperature
+
+ if mask is not None:
+ attn = attn.masked_fill(mask, -np.inf)
+
+ attn = self.softmax(attn)
+ output = torch.bmm(attn, v)
+
+ return output, attn
diff --git a/modules/transformer/SubLayers.py b/modules/transformer/SubLayers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d1b88dfee5eaaec4cc9bfd2017d88847b18b60
--- /dev/null
+++ b/modules/transformer/SubLayers.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .Modules import ScaledDotProductAttention
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-Head Attention module"""
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs = nn.Linear(d_model, n_head * d_k)
+ self.w_ks = nn.Linear(d_model, n_head * d_k)
+ self.w_vs = nn.Linear(d_model, n_head * d_v)
+
+ self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ self.fc = nn.Linear(n_head * d_v, d_model)
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v, mask=None):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+
+ sz_b, len_q, _ = q.size()
+ sz_b, len_k, _ = k.size()
+ sz_b, len_v, _ = v.size()
+
+ residual = q
+
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+ q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
+ k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
+ v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
+
+ mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
+ output, attn = self.attention(q, k, v, mask=mask)
+
+ output = output.view(n_head, sz_b, len_q, d_v)
+ output = (
+ output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)
+ ) # b x lq x (n*dv)
+
+ output = self.dropout(self.fc(output))
+ output = self.layer_norm(output + residual)
+
+ return output, attn
+
+
+class PositionwiseFeedForward(nn.Module):
+ """A two-feed-forward-layer module"""
+
+ def __init__(self, d_in, d_hid, kernel_size, dropout=0.1):
+ super().__init__()
+
+ # Use Conv1D
+ # position-wise
+ self.w_1 = nn.Conv1d(
+ d_in,
+ d_hid,
+ kernel_size=kernel_size[0],
+ padding=(kernel_size[0] - 1) // 2,
+ )
+ # position-wise
+ self.w_2 = nn.Conv1d(
+ d_hid,
+ d_in,
+ kernel_size=kernel_size[1],
+ padding=(kernel_size[1] - 1) // 2,
+ )
+
+ self.layer_norm = nn.LayerNorm(d_in)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ residual = x
+ output = x.transpose(1, 2)
+ output = self.w_2(F.relu(self.w_1(output)))
+ output = output.transpose(1, 2)
+ output = self.dropout(output)
+ output = self.layer_norm(output + residual)
+
+ return output
diff --git a/modules/transformer/__init__.py b/modules/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b790531c24500929331a498856196b3254e14e4c
--- /dev/null
+++ b/modules/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .mh_attention import MultiheadAttention
+from .position_embedding import SinePositionalEmbedding
diff --git a/modules/transformer/attentions.py b/modules/transformer/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b0cfc54ae038d4ca35ba32732dc69c776f4c4f0
--- /dev/null
+++ b/modules/transformer/attentions.py
@@ -0,0 +1,416 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from utils.util import *
+from modules.base.base_module import *
+from modules.base.base_module import LayerNorm
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ window_size=4,
+ **kwargs,
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ window_size=window_size,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=True,
+ **kwargs,
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+
+ self.drop = nn.Dropout(p_dropout)
+ self.self_attn_layers = nn.ModuleList()
+ self.norm_layers_0 = nn.ModuleList()
+ self.encdec_attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.self_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ proximal_bias=proximal_bias,
+ proximal_init=proximal_init,
+ )
+ )
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
+ self.encdec_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ causal=True,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, h, h_mask):
+ """
+ x: decoder input
+ h: encoder output
+ """
+ self_attn_mask = subsequent_mask(x_mask.size(2)).to(
+ device=x.device, dtype=x.dtype
+ )
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_0[i](x + y)
+
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ p_dropout=0.0,
+ window_size=None,
+ heads_share=True,
+ block_length=None,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert (
+ t_s == t_t
+ ), "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(
+ query / math.sqrt(self.k_channels), key_relative_embeddings
+ )
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert (
+ t_s == t_t
+ ), "Local attention is only available for self-attention."
+ block_mask = (
+ torch.ones_like(scores)
+ .triu(-self.block_length)
+ .tril(self.block_length)
+ )
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(
+ self.emb_rel_v, t_s
+ )
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = (
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # padd along column
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=0.0,
+ activation=None,
+ causal=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, convert_pad_shape(padding))
+ return x
diff --git a/modules/transformer/mh_attention.py b/modules/transformer/mh_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf576ca729950790bcec0f1f0706739a86fb0690
--- /dev/null
+++ b/modules/transformer/mh_attention.py
@@ -0,0 +1,417 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+from torch.nn.parameter import Parameter
+
+
+class MultiheadAttention(Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces as described in the paper:
+ `Attention Is All You Need `_.
+
+ Multi-Head Attention is defined as:
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
+
+ ``forward()`` will use a special optimized implementation if all of the following
+ conditions are met:
+
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
+ restriction will be loosened in the future.)
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
+ - training is disabled (using ``.eval()``)
+ - dropout is 0
+ - ``add_bias_kv`` is ``False``
+ - ``add_zero_attn`` is ``False``
+ - ``batch_first`` is ``True`` and the input is batched
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
+ - if a `NestedTensor `_ is passed, neither ``key_padding_mask``
+ nor ``attn_mask`` is passed
+
+ If the optimized implementation is in use, a
+ `NestedTensor `_ can be passed for
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
+ padding mask. In this case, a `NestedTensor `_
+ will be returned, and an additional speedup proportional to the fraction of the input
+ that is padding can be expected.
+
+ Args:
+ embed_dim: Total dimension of the model.
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
+ Default: ``False``.
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
+ batch_first: If ``True``, then the input and output tensors are provided
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+
+ Examples::
+
+ >>> # xdoctest: +SKIP
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+
+ """
+
+ __constants__ = ["batch_first"]
+ bias_k: Optional[torch.Tensor]
+ bias_v: Optional[torch.Tensor]
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ kdim=None,
+ vdim=None,
+ batch_first=False,
+ linear1_cls=Linear,
+ linear2_cls=Linear,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.batch_first = batch_first
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
+ else:
+ self.bias_k = self.bias_v = None
+
+ if linear1_cls == Linear:
+ if not self._qkv_same_embed_dim:
+ self.q_proj_weight = Parameter(
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
+ )
+ self.k_proj_weight = Parameter(
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
+ )
+ self.v_proj_weight = Parameter(
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
+ )
+ self.register_parameter("in_proj_weight", None)
+ else:
+ self.in_proj_weight = Parameter(
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
+ )
+ self.register_parameter("q_proj_weight", None)
+ self.register_parameter("k_proj_weight", None)
+ self.register_parameter("v_proj_weight", None)
+
+ if bias:
+ self.in_proj_bias = Parameter(
+ torch.empty(3 * embed_dim, **factory_kwargs)
+ )
+ else:
+ self.register_parameter("in_proj_bias", None)
+ self.out_proj = NonDynamicallyQuantizableLinear(
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
+ )
+
+ self._reset_parameters()
+ else:
+ if not self._qkv_same_embed_dim:
+ raise NotImplementedError
+ else:
+ self.in_proj_linear = linear1_cls(
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+ )
+ self.in_proj_weight = self.in_proj_linear.weight
+
+ self.register_parameter("q_proj_weight", None)
+ self.register_parameter("k_proj_weight", None)
+ self.register_parameter("v_proj_weight", None)
+
+ if bias:
+ self.in_proj_bias = self.in_proj_linear.bias
+ else:
+ self.register_parameter("in_proj_bias", None)
+
+ self.out_proj = linear2_cls(
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
+ )
+
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ self.add_zero_attn = add_zero_attn
+
+ def _reset_parameters(self):
+ if self._qkv_same_embed_dim:
+ xavier_uniform_(self.in_proj_weight)
+ else:
+ xavier_uniform_(self.q_proj_weight)
+ xavier_uniform_(self.k_proj_weight)
+ xavier_uniform_(self.v_proj_weight)
+
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.0)
+ constant_(self.out_proj.bias, 0.0)
+
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ def __setstate__(self, state):
+ if "_qkv_same_embed_dim" not in state:
+ state["_qkv_same_embed_dim"] = True
+
+ super(MultiheadAttention, self).__setstate__(state)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
+ Queries are compared against key-value pairs to produce the output.
+ See "Attention Is All You Need" for more details.
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
+ See "Attention Is All You Need" for more details.
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
+ See "Attention Is All You Need" for more details.
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
+ Binary and byte masks are supported.
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
+ Default: ``True``.
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
+ the attention weight.
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
+
+ Outputs:
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
+ embedding dimension ``embed_dim``.
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
+
+ .. note::
+ `batch_first` argument is ignored for unbatched inputs.
+ """
+ is_batched = query.dim() == 3
+ if key_padding_mask is not None:
+ _kpm_dtype = key_padding_mask.dtype
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
+ key_padding_mask
+ ):
+ raise AssertionError(
+ "only bool and floating types of key_padding_mask are supported"
+ )
+ why_not_fast_path = ""
+ if not is_batched:
+ why_not_fast_path = (
+ f"input not batched; expected query.dim() of 3 but got {query.dim()}"
+ )
+ elif query is not key or key is not value:
+ # When lifting this restriction, don't forget to either
+ # enforce that the dtypes all match or test cases where
+ # they don't!
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ elif (
+ self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
+ ):
+ # this case will fail anyway, but at least they'll get a useful error message.
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ elif self.training:
+ why_not_fast_path = "training is enabled"
+ elif not self.batch_first:
+ why_not_fast_path = "batch_first was not True"
+ elif self.bias_k is not None:
+ why_not_fast_path = "self.bias_k was not None"
+ elif self.bias_v is not None:
+ why_not_fast_path = "self.bias_v was not None"
+ elif self.dropout:
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
+ elif self.add_zero_attn:
+ why_not_fast_path = "add_zero_attn was enabled"
+ elif not self._qkv_same_embed_dim:
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
+ elif attn_mask is not None:
+ why_not_fast_path = "attn_mask was not None"
+ elif query.is_nested and key_padding_mask is not None:
+ why_not_fast_path = (
+ "key_padding_mask is not supported with NestedTensor input"
+ )
+ elif self.num_heads % 2 == 1:
+ why_not_fast_path = "num_heads is odd"
+ elif torch.is_autocast_enabled():
+ why_not_fast_path = "autocast is enabled"
+
+ if not why_not_fast_path:
+ tensor_args = (
+ query,
+ key,
+ value,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ )
+ # We have to use list comprehensions below because TorchScript does not support
+ # generator expressions.
+ if torch.overrides.has_torch_function(tensor_args):
+ why_not_fast_path = "some Tensor argument has_torch_function"
+ elif not all(
+ [
+ (x is None or x.is_cuda or "cpu" in str(x.device))
+ for x in tensor_args
+ ]
+ ):
+ why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
+ elif torch.is_grad_enabled() and any(
+ [x is not None and x.requires_grad for x in tensor_args]
+ ):
+ why_not_fast_path = (
+ "grad is enabled and at least one of query or the "
+ "input/output projection weights or biases requires_grad"
+ )
+ if not why_not_fast_path:
+ return torch._native_multi_head_attention(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ key_padding_mask if key_padding_mask is not None else attn_mask,
+ need_weights,
+ average_attn_weights,
+ (
+ 1
+ if key_padding_mask is not None
+ else 0 if attn_mask is not None else None
+ ),
+ )
+
+ any_nested = query.is_nested or key.is_nested or value.is_nested
+ assert not any_nested, (
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
+ + f"The fast path was not hit because {why_not_fast_path}"
+ )
+
+ if self.batch_first and is_batched:
+ # make sure that the transpose op does not affect the "is" property
+ if key is value:
+ if query is key:
+ query = key = value = query.transpose(1, 0)
+ else:
+ query, key = [x.transpose(1, 0) for x in (query, key)]
+ value = key
+ else:
+ query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+ if not self._qkv_same_embed_dim:
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight,
+ average_attn_weights=average_attn_weights,
+ )
+ else:
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ average_attn_weights=average_attn_weights,
+ )
+ if self.batch_first and is_batched:
+ return attn_output.transpose(1, 0), attn_output_weights
+ else:
+ return attn_output, attn_output_weights
diff --git a/modules/transformer/position_embedding.py b/modules/transformer/position_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..aca97ff35fccb0ab84a7f5556c62be2337dda6c0
--- /dev/null
+++ b/modules/transformer/position_embedding.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import math
+
+
+class SinePositionalEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim_model: int,
+ dropout: float = 0.0,
+ scale: bool = False,
+ alpha: bool = False,
+ ):
+ super().__init__()
+ self.dim_model = dim_model
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
+ self.dropout = torch.nn.Dropout(p=dropout)
+
+ self.reverse = False
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.dim_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.dim_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ self.extend_pe(x)
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
+ return self.dropout(output)
+
+
+# import torch
+# import torch.nn as nn
+# import math
+
+# class SinePositionalEmbedding(nn.Module):
+# def __init__(
+# self,
+# dim_model: int,
+# dropout: float = 0.0,
+# scale: bool = False,
+# alpha: bool = False,
+# ):
+# super().__init__()
+# self.dim_model = dim_model
+# self.x_scale = math.sqrt(dim_model) if scale else 1.0
+# self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
+# self.dropout = torch.nn.Dropout(p=dropout)
+
+# self.reverse = False
+# self.pe = None
+# self.extend_pe(torch.zeros(1, 4000))
+
+# def extend_pe(self, x):
+# """Reset the positional encodings."""
+# if self._pe_needs_extension(x):
+# self.pe = self._generate_positional_encodings(x)
+
+# def _pe_needs_extension(self, x):
+# return self.pe is None or self.pe.size(1) < x.size(1) or self.pe.dtype != x.dtype or self.pe.device != x.device
+
+# def _generate_positional_encodings(self, x):
+# pe = torch.zeros(x.size(1), self.dim_model)
+# position = self._get_position_tensor(x)
+# div_term = self._get_div_term()
+# pe[:, 0::2] = torch.sin(position * div_term)
+# pe[:, 1::2] = torch.cos(position * div_term)
+# return pe.unsqueeze(0).to(device=x.device, dtype=x.dtype).detach()
+
+# def _get_position_tensor(self, x):
+# position = torch.arange(x.size(1), dtype=torch.float32).unsqueeze(1)
+# return position.flip(0) if self.reverse else position
+
+# def _get_div_term(self):
+# return torch.exp(torch.arange(0, self.dim_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.dim_model))
+
+# def forward(self, x: torch.Tensor) -> torch.Tensor:
+# self.extend_pe(x)
+# output = x.unsqueeze(-1) if x.ndim == 2 else x
+# output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
+# return self.dropout(output)
diff --git a/modules/transformer/transformer.py b/modules/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ca7370f1193285b5802462319d6c6680490b23
--- /dev/null
+++ b/modules/transformer/transformer.py
@@ -0,0 +1,415 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from functools import partial
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from modules.norms import AdaptiveLayerNorm, LayerNorm, BalancedBasicNorm, IdentityNorm
+from modules.transformer import MultiheadAttention
+from modules.general.scaling import BalancedDoubleSwish
+
+
+class TransformerEncoderLayer(nn.Module):
+ __constants__ = ["batch_first", "norm_first"]
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ batch_first: bool = False,
+ norm_first: bool = False,
+ device=None,
+ dtype=None,
+ linear1_self_attention_cls: nn.Module = nn.Linear,
+ linear2_self_attention_cls: nn.Module = nn.Linear,
+ linear1_feedforward_cls: nn.Module = nn.Linear,
+ linear2_feedforward_cls: nn.Module = nn.Linear,
+ layer_norm_cls: nn.Module = LayerNorm,
+ layer_norm_eps: float = 1e-5,
+ adaptive_layer_norm=False,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(TransformerEncoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first,
+ linear1_cls=linear1_self_attention_cls,
+ linear2_cls=linear2_self_attention_cls,
+ **factory_kwargs,
+ )
+
+ # Implementation of Feedforward model
+ self.linear1 = linear1_feedforward_cls(
+ d_model, dim_feedforward, **factory_kwargs
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = linear2_feedforward_cls(
+ dim_feedforward, d_model, **factory_kwargs
+ )
+
+ self.norm_first = norm_first
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ if isinstance(activation, str):
+ activation = _get_activation_fn(activation)
+ elif isinstance(activation, partial):
+ activation = activation(d_model)
+ elif activation == BalancedDoubleSwish:
+ activation = BalancedDoubleSwish(d_model)
+
+ self.activation = activation
+
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
+ if layer_norm_cls == IdentityNorm:
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ else:
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
+
+ if adaptive_layer_norm:
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
+ else:
+ self.norm1 = norm1
+ self.norm2 = norm2
+
+ def __setstate__(self, state):
+ super(TransformerEncoderLayer, self).__setstate__(state)
+ if not hasattr(self, "activation"):
+ self.activation = F.relu
+
+ def forward(
+ self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ x, stage_embedding = src, None
+ is_src_tuple = False
+ if isinstance(src, tuple):
+ x, stage_embedding = src
+ is_src_tuple = True
+
+ if src_key_padding_mask is not None:
+ _skpm_dtype = src_key_padding_mask.dtype
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
+ src_key_padding_mask
+ ):
+ raise AssertionError(
+ "only bool and floating types of key_padding_mask are supported"
+ )
+
+ if self.norm_first:
+ x = x + self._sa_block(
+ self.norm1(x, stage_embedding),
+ src_mask,
+ src_key_padding_mask,
+ )
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
+ else:
+ x = self.norm1(
+ x + self._sa_block(x, src_mask, src_key_padding_mask),
+ stage_embedding,
+ )
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
+
+ if is_src_tuple:
+ return (x, stage_embedding)
+ return x
+
+ def _sa_block(
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ x = self.self_attn(
+ x,
+ x,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout1(x)
+
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ return self.dropout2(x)
+
+
+class TransformerEncoder(nn.Module):
+ """TransformerEncoder is a stack of N encoder layers."""
+
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super(TransformerEncoder, self).__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ src: Tensor,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ return_layer_states: bool = False,
+ ) -> Tensor:
+ # Pass the input through the encoder layers
+ output = src
+ layer_states = [] if return_layer_states else None
+
+ for mod in self.layers:
+ output = self._apply_module(
+ mod, output, mask, src_key_padding_mask, layer_states
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return (layer_states, output) if return_layer_states else output
+
+ def _apply_module(self, module, output, mask, key_padding_mask, layer_states):
+ # Apply a single transformer module
+ output = module(output, src_mask=mask, src_key_padding_mask=key_padding_mask)
+ if layer_states is not None:
+ layer_states.append(output)
+ return output
+
+
+class TransformerDecoderLayer(nn.Module):
+ __constants__ = ["batch_first", "norm_first"]
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ linear1_self_attention_cls: nn.Module = nn.Linear,
+ linear2_self_attention_cls: nn.Module = nn.Linear,
+ linear1_feedforward_cls: nn.Module = nn.Linear,
+ linear2_feedforward_cls: nn.Module = nn.Linear,
+ batch_first: bool = False,
+ norm_first: bool = False,
+ device=None,
+ dtype=None,
+ layer_norm_cls: nn.Module = LayerNorm,
+ layer_norm_eps: float = 1e-5,
+ adaptive_layer_norm=False,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(TransformerDecoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first,
+ linear1_cls=linear1_self_attention_cls,
+ linear2_cls=linear2_self_attention_cls,
+ **factory_kwargs,
+ )
+ self.multihead_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first,
+ linear1_cls=linear1_self_attention_cls,
+ linear2_cls=linear2_self_attention_cls,
+ **factory_kwargs,
+ )
+ self.linear1 = linear1_feedforward_cls(
+ d_model, dim_feedforward, **factory_kwargs
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = linear2_feedforward_cls(
+ dim_feedforward, d_model, **factory_kwargs
+ )
+
+ self.norm_first = norm_first
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = self._get_activation_fn(activation)
+ self.norm1, self.norm2, self.norm3 = self._init_norm_layers(
+ d_model, layer_norm_cls, layer_norm_eps, adaptive_layer_norm, factory_kwargs
+ )
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt: the sequence to the decoder layer (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ tgt_is_tuple = False
+ if isinstance(tgt, tuple):
+ x, stage_embedding = tgt
+ tgt_is_tuple = True
+ else:
+ x, stage_embedding = tgt, None
+
+ if self.norm_first:
+ x = x + self._sa_block(
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
+ )
+ x = x + self._mha_block(
+ self.norm2(x, stage_embedding),
+ memory,
+ memory_mask,
+ memory_key_padding_mask,
+ )
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
+ else:
+ x = self.norm1(
+ x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
+ stage_embedding,
+ )
+ x = self.norm2(
+ x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
+ stage_embedding,
+ )
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
+
+ if tgt_is_tuple:
+ return (x, stage_embedding)
+ return x
+
+ def _sa_block(
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ x = self.self_attn(
+ x,
+ x,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout1(x)
+
+ def _mha_block(
+ self,
+ x: Tensor,
+ mem: Tensor,
+ attn_mask: Optional[Tensor],
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ x = self.multihead_attn(
+ x,
+ mem,
+ mem,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout2(x)
+
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ return self.dropout3(x)
+
+ def _get_activation_fn(self, activation):
+ if isinstance(activation, str):
+ return _get_activation_fn(activation)
+ elif callable(activation):
+ return activation
+ else:
+ raise ValueError("Unsupported activation type")
+
+ def _init_norm_layers(
+ self,
+ d_model,
+ layer_norm_cls,
+ layer_norm_eps,
+ adaptive_layer_norm,
+ factory_kwargs,
+ ):
+ if adaptive_layer_norm:
+ return (
+ AdaptiveLayerNorm(
+ d_model,
+ layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs),
+ ),
+ AdaptiveLayerNorm(
+ d_model,
+ layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs),
+ ),
+ AdaptiveLayerNorm(
+ d_model,
+ layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs),
+ ),
+ )
+ else:
+ return (
+ layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs),
+ layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs),
+ (
+ layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
+ if layer_norm_cls != IdentityNorm
+ else BalancedBasicNorm(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ ),
+ )
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return F.gelu
+
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+
+
+class Transpose(nn.Identity):
+ """(N, T, D) -> (N, D, T)"""
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.transpose(1, 2)
diff --git a/modules/transformer/transforms.py b/modules/transformer/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d28b2e037e80f81ab75665c1f397c930bd0eca20
--- /dev/null
+++ b/modules/transformer/transforms.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.nn import functional as F
+
+import numpy as np
+
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+def piecewise_rational_quadratic_transform(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails=None,
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if tails is None:
+ spline_fn = rational_quadratic_spline
+ spline_kwargs = {}
+ else:
+ spline_fn = unconstrained_rational_quadratic_spline
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+ outputs, logabsdet = spline_fn(
+ inputs=inputs,
+ unnormalized_widths=unnormalized_widths,
+ unnormalized_heights=unnormalized_heights,
+ unnormalized_derivatives=unnormalized_derivatives,
+ inverse=inverse,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ **spline_kwargs,
+ )
+ return outputs, logabsdet
+
+
+def searchsorted(bin_locations, inputs, eps=1e-6):
+ bin_locations[..., -1] += eps
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
+
+
+def unconstrained_rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails="linear",
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+ outside_interval_mask = ~inside_interval_mask
+
+ outputs = torch.zeros_like(inputs)
+ logabsdet = torch.zeros_like(inputs)
+
+ if tails == "linear":
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+ constant = np.log(np.exp(1 - min_derivative) - 1)
+ unnormalized_derivatives[..., 0] = constant
+ unnormalized_derivatives[..., -1] = constant
+
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
+ logabsdet[outside_interval_mask] = 0
+ else:
+ raise RuntimeError("{} tails are not implemented.".format(tails))
+
+ (
+ outputs[inside_interval_mask],
+ logabsdet[inside_interval_mask],
+ ) = rational_quadratic_spline(
+ inputs=inputs[inside_interval_mask],
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+ inverse=inverse,
+ left=-tail_bound,
+ right=tail_bound,
+ bottom=-tail_bound,
+ top=tail_bound,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ )
+
+ return outputs, logabsdet
+
+
+def rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ left=0.0,
+ right=1.0,
+ bottom=0.0,
+ top=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if torch.min(inputs) < left or torch.max(inputs) > right:
+ raise ValueError("Input to a transform is not within its domain")
+
+ num_bins = unnormalized_widths.shape[-1]
+
+ if min_bin_width * num_bins > 1.0:
+ raise ValueError("Minimal bin width too large for the number of bins")
+ if min_bin_height * num_bins > 1.0:
+ raise ValueError("Minimal bin height too large for the number of bins")
+
+ widths = F.softmax(unnormalized_widths, dim=-1)
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+ cumwidths = torch.cumsum(widths, dim=-1)
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+ cumwidths = (right - left) * cumwidths + left
+ cumwidths[..., 0] = left
+ cumwidths[..., -1] = right
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+ heights = F.softmax(unnormalized_heights, dim=-1)
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+ cumheights = torch.cumsum(heights, dim=-1)
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+ cumheights = (top - bottom) * cumheights + bottom
+ cumheights[..., 0] = bottom
+ cumheights[..., -1] = top
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+ if inverse:
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
+ else:
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
+
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+ delta = heights / widths
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+ if inverse:
+ a = (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ ) + input_heights * (input_delta - input_derivatives)
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ )
+ c = -input_delta * (inputs - input_cumheights)
+
+ discriminant = b.pow(2) - 4 * a * c
+ assert (discriminant >= 0).all()
+
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
+ outputs = root * input_bin_widths + input_cumwidths
+
+ theta_one_minus_theta = root * (1 - root)
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * root.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - root).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, -logabsdet
+ else:
+ theta = (inputs - input_cumwidths) / input_bin_widths
+ theta_one_minus_theta = theta * (1 - theta)
+
+ numerator = input_heights * (
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+ )
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ outputs = input_cumheights + numerator / denominator
+
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * theta.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - theta).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, logabsdet
diff --git a/modules/vocoder_blocks/__init__.py b/modules/vocoder_blocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b678da73bf7b15cf9c42a1f4f36379f62fc43896
--- /dev/null
+++ b/modules/vocoder_blocks/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .gan_utils import *
+from .norm2d import *
diff --git a/modules/vocoder_blocks/gan_utils.py b/modules/vocoder_blocks/gan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e6fc38cb3c047ea856f8d63b0f3a755644e98ee
--- /dev/null
+++ b/modules/vocoder_blocks/gan_utils.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+#################### GAN utils ####################
+
+
+import typing as tp
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def get_2d_padding(
+ kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
+):
+ return (
+ ((kernel_size[0] - 1) * dilation[0]) // 2,
+ ((kernel_size[1] - 1) * dilation[1]) // 2,
+ )
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
diff --git a/modules/vocoder_blocks/norm2d.py b/modules/vocoder_blocks/norm2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a907f907faae2ce9cf9e695b4a8b103ab39773ce
--- /dev/null
+++ b/modules/vocoder_blocks/norm2d.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+#################### Norm2D for Discriminators ####################
+
+import torch
+import torch.nn as nn
+import einops
+from torch.nn.utils import spectral_norm, weight_norm
+
+CONV_NORMALIZATIONS = frozenset(
+ [
+ "none",
+ "weight_norm",
+ "spectral_norm",
+ "time_layer_norm",
+ "layer_norm",
+ "time_group_norm",
+ ]
+)
+
+
+class ConvLayerNorm(nn.LayerNorm):
+ """
+ Convolution-friendly LayerNorm that moves channels to last dimensions
+ before running the normalization and moves them back to original position right after.
+ """
+
+ def __init__(self, normalized_shape, **kwargs):
+ super().__init__(normalized_shape, **kwargs)
+
+ def forward(self, x):
+ x = einops.rearrange(x, "b ... t -> b t ...")
+ x = super().forward(x)
+ x = einops.rearrange(x, "b t ... -> b ... t")
+ return
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
+ assert norm in CONV_NORMALIZATIONS
+ if norm == "weight_norm":
+ return weight_norm(module)
+ elif norm == "spectral_norm":
+ return spectral_norm(module)
+ else:
+ # We already check was in CONV_NORMALIZATION, so any other choice
+ # doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(
+ module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
+) -> nn.Module:
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == "layer_norm":
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
+ elif norm == "time_group_norm":
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+
+ def __init__(
+ self,
+ *args,
+ norm: str = "none",
+ norm_kwargs={},
+ **kwargs,
+ ):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
diff --git a/modules/wenet_extractor/README.md b/modules/wenet_extractor/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b8f0abc4d5d9fda42edba488d01bd40222a78304
--- /dev/null
+++ b/modules/wenet_extractor/README.md
@@ -0,0 +1,23 @@
+## Acknowledgement
+
+This module borrows some codes from [WeNet](https://github.com/wenet-e2e/wenet).
+
+## Citations
+
+```bibtex
+@inproceedings{yao2021wenet,
+ title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+ author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+ booktitle={Proc. Interspeech},
+ year={2021},
+ address={Brno, Czech Republic },
+ organization={IEEE}
+}
+
+@article{zhang2022wenet,
+ title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+ author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+ journal={arXiv preprint arXiv:2203.15455},
+ year={2022}
+}
+```
\ No newline at end of file
diff --git a/modules/wenet_extractor/__init__.py b/modules/wenet_extractor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/wenet_extractor/cif/predictor.py b/modules/wenet_extractor/cif/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..91428c4c418da3026e6865f334e4b4cdafd02900
--- /dev/null
+++ b/modules/wenet_extractor/cif/predictor.py
@@ -0,0 +1,274 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+from typing import Optional
+
+import torch
+from torch import nn
+from modules.wenet_extractor.utils.mask import make_pad_mask
+
+
+class Predictor(nn.Module):
+ def __init__(
+ self,
+ idim,
+ l_order,
+ r_order,
+ threshold=1.0,
+ dropout=0.1,
+ smooth_factor=1.0,
+ noise_threshold=0,
+ tail_threshold=0.45,
+ ):
+ super().__init__()
+
+ self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
+ self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+ self.cif_output = nn.Linear(idim, 1)
+ self.dropout = torch.nn.Dropout(p=dropout)
+ self.threshold = threshold
+ self.smooth_factor = smooth_factor
+ self.noise_threshold = noise_threshold
+ self.tail_threshold = tail_threshold
+
+ def forward(
+ self,
+ hidden,
+ target_label: Optional[torch.Tensor] = None,
+ mask: torch.Tensor = torch.tensor(0),
+ ignore_id: int = -1,
+ mask_chunk_predictor: Optional[torch.Tensor] = None,
+ target_label_length: Optional[torch.Tensor] = None,
+ ):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.cif_conv1d(queries)
+ output = memory + context
+ output = self.dropout(output)
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(
+ alphas * self.smooth_factor - self.noise_threshold
+ )
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ hidden, alphas, token_num = self.tail_process_fn(
+ hidden, alphas, token_num, mask=mask
+ )
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def tail_process_fn(
+ self,
+ hidden,
+ alphas,
+ token_num: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ ):
+ b, t, d = hidden.size()
+ tail_threshold = self.tail_threshold
+ if mask is not None:
+ zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+ ones_t = torch.ones_like(zeros_t)
+ mask_1 = torch.cat([mask, zeros_t], dim=1)
+ mask_2 = torch.cat([ones_t, mask], dim=1)
+ mask = mask_2 - mask_1
+ tail_threshold = mask * tail_threshold
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
+ else:
+ tail_threshold_tensor = torch.tensor(
+ [tail_threshold], dtype=alphas.dtype
+ ).to(alphas.device)
+ tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
+ alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
+ zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+ hidden = torch.cat([hidden, zeros], dim=1)
+ token_num = alphas.sum(dim=-1)
+ token_num_floor = torch.floor(token_num)
+
+ return hidden, alphas, token_num_floor
+
+ def gen_frame_alignments(
+ self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
+ ):
+ batch_size, maximum_length = alphas.size()
+ int_type = torch.int32
+
+ is_training = self.training
+ if is_training:
+ token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
+ else:
+ token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
+
+ max_token_num = torch.max(token_num).item()
+
+ alphas_cumsum = torch.cumsum(alphas, dim=1)
+ alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
+ alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
+
+ index = torch.ones([batch_size, max_token_num], dtype=int_type)
+ index = torch.cumsum(index, dim=1)
+ index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
+
+ index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
+ index_div_bool_zeros = index_div.eq(0)
+ index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
+ index_div_bool_zeros_count = torch.clamp(
+ index_div_bool_zeros_count, 0, encoder_sequence_length.max()
+ )
+ token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
+ token_num.device
+ )
+ index_div_bool_zeros_count *= token_num_mask
+
+ index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
+ 1, 1, maximum_length
+ )
+ ones = torch.ones_like(index_div_bool_zeros_count_tile)
+ zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
+ ones = torch.cumsum(ones, dim=2)
+ cond = index_div_bool_zeros_count_tile == ones
+ index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
+
+ index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
+ torch.bool
+ )
+ index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
+ int_type
+ )
+ index_div_bool_zeros_count_tile_out = torch.sum(
+ index_div_bool_zeros_count_tile, dim=1
+ )
+ index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
+ int_type
+ )
+ predictor_mask = (
+ (
+ ~make_pad_mask(
+ encoder_sequence_length, max_len=encoder_sequence_length.max()
+ )
+ )
+ .type(int_type)
+ .to(encoder_sequence_length.device)
+ )
+ index_div_bool_zeros_count_tile_out = (
+ index_div_bool_zeros_count_tile_out * predictor_mask
+ )
+
+ predictor_alignments = index_div_bool_zeros_count_tile_out
+ predictor_alignments_length = predictor_alignments.sum(-1).type(
+ encoder_sequence_length.dtype
+ )
+ return predictor_alignments.detach(), predictor_alignments_length.detach()
+
+
+class MAELoss(nn.Module):
+ def __init__(self, normalize_length=False):
+ super(MAELoss, self).__init__()
+ self.normalize_length = normalize_length
+ self.criterion = torch.nn.L1Loss(reduction="sum")
+
+ def forward(self, token_length, pre_token_length):
+ loss_token_normalizer = token_length.size(0)
+ if self.normalize_length:
+ loss_token_normalizer = token_length.sum().type(torch.float32)
+ loss = self.criterion(token_length, pre_token_length)
+ loss = loss / loss_token_normalizer
+ return loss
+
+
+def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
+ batch_size, len_time, hidden_size = hidden.size()
+
+ # loop varss
+ integrate = torch.zeros([batch_size], device=hidden.device)
+ frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+ # intermediate vars along time
+ list_fires = []
+ list_frames = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+ distribution_completion = (
+ torch.ones([batch_size], device=hidden.device) - integrate
+ )
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(
+ fire_place,
+ integrate - torch.ones([batch_size], device=hidden.device),
+ integrate,
+ )
+ cur = torch.where(fire_place, distribution_completion, alpha)
+ remainds = alpha - cur
+
+ frame += cur[:, None] * hidden[:, t, :]
+ list_frames.append(frame)
+ frame = torch.where(
+ fire_place[:, None].repeat(1, hidden_size),
+ remainds[:, None] * hidden[:, t, :],
+ frame,
+ )
+
+ fires = torch.stack(list_fires, 1)
+ frames = torch.stack(list_frames, 1)
+ list_ls = []
+ len_labels = torch.round(alphas.sum(-1)).int()
+ max_label_len = len_labels.max()
+ for b in range(batch_size):
+ fire = fires[b, :]
+ l = torch.index_select(
+ frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()
+ )
+ pad_l = torch.zeros(
+ [int(max_label_len - l.size(0)), hidden_size], device=hidden.device
+ )
+ list_ls.append(torch.cat([l, pad_l], 0))
+ return torch.stack(list_ls, 0), fires
diff --git a/modules/wenet_extractor/efficient_conformer/__init__.py b/modules/wenet_extractor/efficient_conformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/wenet_extractor/efficient_conformer/attention.py b/modules/wenet_extractor/efficient_conformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..96f5ac10d56e44c0873e6d659ed501fdedf43192
--- /dev/null
+++ b/modules/wenet_extractor/efficient_conformer/attention.py
@@ -0,0 +1,273 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple, Optional
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
+
+
+class GroupedRelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper:
+ https://arxiv.org/abs/1901.02860
+ https://arxiv.org/abs/2109.01163
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, group_size=3):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ self.group_size = group_size
+ self.d_k = n_feat // n_head # for GroupedAttention
+ self.n_feat = n_feat
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k * self.group_size))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k * self.group_size))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x, zero_triu: bool = False):
+ """Compute relative positinal encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, size).
+ zero_triu (bool): If true, return the lower triangular part of
+ the matrix.
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+
+ zero_pad = torch.zeros(
+ (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
+ )
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def pad4group(self, Q, K, V, P, mask, group_size: int = 3):
+ """
+ q: (#batch, time1, size) -> (#batch, head, time1, size/head)
+ k,v: (#batch, time2, size) -> (#batch, head, time2, size/head)
+ p: (#batch, time2, size)
+ """
+ # Compute Overflows
+ overflow_Q = Q.size(2) % group_size
+ overflow_KV = K.size(2) % group_size
+
+ # if-else for ONNX export
+ # 0 // 0.00000000000000001 = 0
+ # 1 // 1.00000000000000001 = 1
+ padding_Q = (group_size - overflow_Q) * int(
+ overflow_Q // (overflow_Q + 0.00000000000000001)
+ )
+ padding_KV = (group_size - overflow_KV) * int(
+ overflow_KV // (overflow_KV + 0.00000000000000001)
+ )
+
+ batch_size, _, seq_len_KV, _ = K.size()
+
+ # Input Padding (B, T, D) -> (B, T + P, D)
+ Q = F.pad(Q, (0, 0, 0, padding_Q), value=0.0)
+ K = F.pad(K, (0, 0, 0, padding_KV), value=0.0)
+ V = F.pad(V, (0, 0, 0, padding_KV), value=0.0)
+
+ if mask is not None and mask.size(2) > 0: # time2 > 0:
+ mask = mask[:, ::group_size, ::group_size]
+
+ Q = (
+ Q.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.h, self.d_k * group_size)
+ .transpose(1, 2)
+ )
+ K = (
+ K.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.h, self.d_k * group_size)
+ .transpose(1, 2)
+ )
+ V = (
+ V.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.h, self.d_k * group_size)
+ .transpose(1, 2)
+ )
+
+ # process pos_emb
+ P_batch_size = P.size(0)
+ overflow_P = P.size(1) % group_size
+ padding_P = group_size - overflow_P if overflow_P else 0
+ P = F.pad(P, (0, 0, 0, padding_P), value=0.0)
+ P = P.view(P_batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
+
+ return Q, K, V, P, mask, padding_Q
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ padding_q: Optional[int] = None,
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ padding_q : for GroupedAttention in efficent conformer
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+ # 1st chunk to ease the onnx export.]
+ # 2. pytorch training
+ if mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, : scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float("inf"))
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+ # 1. onnx(16/-1, -1/-1, 16/0)
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+
+ # n_feat!=h*d_k may be happened in GroupAttention
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.n_feat)
+ ) # (batch, time1, d_model)
+ if padding_q is not None:
+ # for GroupedAttention in efficent conformer
+ x = x[:, : x.size(1) - padding_q]
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ q = self.linear_q(query)
+ k = self.linear_k(key) # (#batch, time2, size)
+ v = self.linear_v(value)
+ p = self.linear_pos(pos_emb) # (#batch, time2, size)
+
+ batch_size, seq_len_KV, _ = k.size() # seq_len_KV = time2
+
+ # (#batch, time2, size) -> (#batch, head, time2, size/head)
+ q = q.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
+ k = k.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
+ v = v.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
+ if cache.size(0) > 0:
+ # use attention cache
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ new_cache = torch.cat((k, v), dim=-1)
+
+ # May be k and p does not match. eg. time2=18+18/2=27 > mask=36/2=18
+ if mask is not None and mask.size(2) > 0:
+ time2 = mask.size(2)
+ k = k[:, :, -time2:, :]
+ v = v[:, :, -time2:, :]
+
+ # q k v p: (batch, head, time1, d_k)
+ q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask, self.group_size)
+
+ # q_with_bias_u & q_with_bias_v = (batch, head, time1, d_k)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ # Remove rel_shift since it is useless in speech recognition,
+ # and it requires special attention for streaming.
+ # matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k * self.group_size
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask, padding_q), new_cache
diff --git a/modules/wenet_extractor/efficient_conformer/convolution.py b/modules/wenet_extractor/efficient_conformer/convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..88ca82bffbd6ed5fd48a64d71f61bc281594efa2
--- /dev/null
+++ b/modules/wenet_extractor/efficient_conformer/convolution.py
@@ -0,0 +1,163 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""ConvolutionModule definition."""
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True,
+ stride: int = 1,
+ ):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (int): Whether use causal convolution or not
+ stride (int): Stride Convolution, for efficient Conformer
+ """
+ super().__init__()
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for none causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=stride, # for depthwise_conv in StrideConv
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+
+ assert norm in ["batch_norm", "layer_norm"]
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+ self.stride = stride
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+ (0, 0, 0) meas fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ # When export ONNX,the first cache is not None but all-zero,
+ # cause shape error in residual block,
+ # eg. cache14 + x9 = 23, 23-7+1=17 != 9
+ cache = cache[:, :, -self.lorder :]
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert x.size(2) > self.lorder
+ new_cache = x[:, :, -self.lorder :]
+ else:
+ # It's better we just return None if no cache is requried,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ if mask_pad.size(2) != x.size(2):
+ mask_pad = mask_pad[:, :, :: self.stride]
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
diff --git a/modules/wenet_extractor/efficient_conformer/encoder.py b/modules/wenet_extractor/efficient_conformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b4a91b8f3d303fda189e7d84eea4458688f7dbe
--- /dev/null
+++ b/modules/wenet_extractor/efficient_conformer/encoder.py
@@ -0,0 +1,635 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Encoder definition."""
+from typing import Tuple, Optional, List, Union
+
+import torch
+import logging
+import torch.nn.functional as F
+
+from modules.wenet_extractor.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward,
+)
+from modules.wenet_extractor.transformer.embedding import PositionalEncoding
+from modules.wenet_extractor.transformer.embedding import RelPositionalEncoding
+from modules.wenet_extractor.transformer.embedding import NoPositionalEncoding
+from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling4
+from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling6
+from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling8
+from modules.wenet_extractor.transformer.subsampling import LinearNoSubsampling
+from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
+from modules.wenet_extractor.transformer.attention import (
+ RelPositionMultiHeadedAttention,
+)
+from modules.wenet_extractor.transformer.encoder_layer import ConformerEncoderLayer
+
+from modules.wenet_extractor.efficient_conformer.subsampling import Conv2dSubsampling2
+from modules.wenet_extractor.efficient_conformer.convolution import ConvolutionModule
+from modules.wenet_extractor.efficient_conformer.attention import (
+ GroupedRelPositionMultiHeadedAttention,
+)
+from modules.wenet_extractor.efficient_conformer.encoder_layer import (
+ StrideConformerEncoderLayer,
+)
+
+from modules.wenet_extractor.utils.common import get_activation
+from modules.wenet_extractor.utils.mask import make_pad_mask
+from modules.wenet_extractor.utils.mask import add_optional_chunk_mask
+
+
+class EfficientConformerEncoder(torch.nn.Module):
+ """Conformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "rel_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ macaron_style: bool = True,
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ stride_layer_idx: Optional[Union[int, List[int]]] = 3,
+ stride: Optional[Union[int, List[int]]] = 2,
+ group_layer_idx: Optional[Union[int, List[int], tuple]] = (0, 1, 2, 3),
+ group_size: int = 3,
+ stride_kernel: bool = True,
+ **kwargs,
+ ):
+ """Construct Efficient Conformer Encoder
+
+ Args:
+ input_size to use_dynamic_chunk, see in BaseEncoder
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+ stride_layer_idx (list): layer id with StrideConv, start from 0
+ stride (list): stride size of each StrideConv in efficient conformer
+ group_layer_idx (list): layer id with GroupedAttention, start from 0
+ group_size (int): group size of every GroupedAttention layer
+ stride_kernel (bool): default True. True: recompute cnn kernels with stride.
+ """
+ super().__init__()
+ self._output_size = output_size
+
+ if pos_enc_layer_type == "abs_pos":
+ pos_enc_class = PositionalEncoding
+ elif pos_enc_layer_type == "rel_pos":
+ pos_enc_class = RelPositionalEncoding
+ elif pos_enc_layer_type == "no_pos":
+ pos_enc_class = NoPositionalEncoding
+ else:
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+ if input_layer == "linear":
+ subsampling_class = LinearNoSubsampling
+ elif input_layer == "conv2d2":
+ subsampling_class = Conv2dSubsampling2
+ elif input_layer == "conv2d":
+ subsampling_class = Conv2dSubsampling4
+ elif input_layer == "conv2d6":
+ subsampling_class = Conv2dSubsampling6
+ elif input_layer == "conv2d8":
+ subsampling_class = Conv2dSubsampling8
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+
+ logging.info(
+ f"input_layer = {input_layer}, " f"subsampling_class = {subsampling_class}"
+ )
+
+ self.global_cmvn = global_cmvn
+ self.embed = subsampling_class(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ self.input_layer = input_layer
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+
+ activation = get_activation(activation_type)
+ self.num_blocks = num_blocks
+ self.attention_heads = attention_heads
+ self.cnn_module_kernel = cnn_module_kernel
+ self.global_chunk_size = 0
+ self.chunk_feature_map = 0
+
+ # efficient conformer configs
+ self.stride_layer_idx = (
+ [stride_layer_idx] if type(stride_layer_idx) == int else stride_layer_idx
+ )
+ self.stride = [stride] if type(stride) == int else stride
+ self.group_layer_idx = (
+ [group_layer_idx] if type(group_layer_idx) == int else group_layer_idx
+ )
+ self.grouped_size = group_size # group size of every GroupedAttention layer
+
+ assert len(self.stride) == len(self.stride_layer_idx)
+ self.cnn_module_kernels = [cnn_module_kernel] # kernel size of each StridedConv
+ for i in self.stride:
+ if stride_kernel:
+ self.cnn_module_kernels.append(self.cnn_module_kernels[-1] // i)
+ else:
+ self.cnn_module_kernels.append(self.cnn_module_kernels[-1])
+
+ logging.info(
+ f"stride_layer_idx= {self.stride_layer_idx}, "
+ f"stride = {self.stride}, "
+ f"cnn_module_kernel = {self.cnn_module_kernels}, "
+ f"group_layer_idx = {self.group_layer_idx}, "
+ f"grouped_size = {self.grouped_size}"
+ )
+
+ # feed-forward module definition
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ # convolution module definition
+ convolution_layer = ConvolutionModule
+
+ # encoder definition
+ index = 0
+ layers = []
+ for i in range(num_blocks):
+ # self-attention module definition
+ if i in self.group_layer_idx:
+ encoder_selfattn_layer = GroupedRelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ self.grouped_size,
+ )
+ else:
+ if pos_enc_layer_type == "no_pos":
+ encoder_selfattn_layer = MultiHeadedAttention
+ else:
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+
+ # conformer module definition
+ if i in self.stride_layer_idx:
+ # conformer block with downsampling
+ convolution_layer_args_stride = (
+ output_size,
+ self.cnn_module_kernels[index],
+ activation,
+ cnn_module_norm,
+ causal,
+ True,
+ self.stride[index],
+ )
+ layers.append(
+ StrideConformerEncoderLayer(
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ (
+ positionwise_layer(*positionwise_layer_args)
+ if macaron_style
+ else None
+ ),
+ (
+ convolution_layer(*convolution_layer_args_stride)
+ if use_cnn_module
+ else None
+ ),
+ torch.nn.AvgPool1d(
+ kernel_size=self.stride[index],
+ stride=self.stride[index],
+ padding=0,
+ ceil_mode=True,
+ count_include_pad=False,
+ ), # pointwise_conv_layer
+ dropout_rate,
+ normalize_before,
+ )
+ )
+ index = index + 1
+ else:
+ # conformer block
+ convolution_layer_args_normal = (
+ output_size,
+ self.cnn_module_kernels[index],
+ activation,
+ cnn_module_norm,
+ causal,
+ )
+ layers.append(
+ ConformerEncoderLayer(
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ (
+ positionwise_layer(*positionwise_layer_args)
+ if macaron_style
+ else None
+ ),
+ (
+ convolution_layer(*convolution_layer_args_normal)
+ if use_cnn_module
+ else None
+ ),
+ dropout_rate,
+ normalize_before,
+ )
+ )
+
+ self.encoders = torch.nn.ModuleList(layers)
+
+ def set_global_chunk_size(self, chunk_size):
+ """Used in ONNX export."""
+ logging.info(f"set global chunk size: {chunk_size}, default is 0.")
+ self.global_chunk_size = chunk_size
+ if self.embed.subsampling_rate == 2:
+ self.chunk_feature_map = 2 * self.global_chunk_size + 1
+ elif self.embed.subsampling_rate == 6:
+ self.chunk_feature_map = 6 * self.global_chunk_size + 5
+ elif self.embed.subsampling_rate == 8:
+ self.chunk_feature_map = 8 * self.global_chunk_size + 7
+ else:
+ self.chunk_feature_map = 4 * self.global_chunk_size + 3
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def calculate_downsampling_factor(self, i: int) -> int:
+ factor = 1
+ for idx, stride_idx in enumerate(self.stride_layer_idx):
+ if i > stride_idx:
+ factor *= self.stride[idx]
+ return factor
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embed positions in tensor.
+ Args:
+ xs: padded input tensor (B, T, D)
+ xs_lens: input length (B)
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+ encoder output tensor xs, and subsampled masks
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+ masks: torch.Tensor batch padding mask after subsample
+ (B, 1, T' ~= T/subsample_rate)
+ """
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ xs, pos_emb, masks = self.embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(
+ xs,
+ masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks,
+ )
+ index = 0 # traverse stride
+ for i, layer in enumerate(self.encoders):
+ # layer return : x, mask, new_att_cache, new_cnn_cache
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ if i in self.stride_layer_idx:
+ masks = masks[:, :, :: self.stride[index]]
+ chunk_masks = chunk_masks[
+ :, :: self.stride[index], :: self.stride[index]
+ ]
+ mask_pad = masks
+ pos_emb = pos_emb[:, :: self.stride[index], :]
+ index = index + 1
+
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks
+
+ def forward_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward just one chunk
+
+ Args:
+ xs (torch.Tensor): chunk input
+ offset (int): current offset in encoder output time stamp
+ required_cache_size (int): cache size required for next chunk
+ compuation
+ >=0: actual cache size
+ <0: means all history cache is required
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (elayers, b=1, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+ att_mask : mask matrix of self attention
+
+ Returns:
+ torch.Tensor: output of current input xs
+ torch.Tensor: subsampling cache required for next chunk computation
+ List[torch.Tensor]: encoder layers output cache required for next
+ chunk computation
+ List[torch.Tensor]: conformer cnn cache
+
+ """
+ assert xs.size(0) == 1
+
+ # using downsampling factor to recover offset
+ offset *= self.calculate_downsampling_factor(self.num_blocks + 1)
+
+ chunk_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
+ chunk_masks = chunk_masks.unsqueeze(1) # (1, 1, xs-time)
+
+ real_len = 0
+ if self.global_chunk_size > 0:
+ # for ONNX decode simulation, padding xs to chunk_size
+ real_len = xs.size(1)
+ pad_len = self.chunk_feature_map - real_len
+ xs = F.pad(xs, (0, 0, 0, pad_len), value=0.0)
+ chunk_masks = F.pad(chunk_masks, (0, pad_len), value=0.0)
+
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+ xs, pos_emb, chunk_masks = self.embed(xs, chunk_masks, offset)
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
+ chunk_size = xs.size(1)
+ attention_key_size = cache_t1 + chunk_size
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
+ # shape(pos_emb) = (b=1, chunk_size, emb_size=output_size=hidden-dim)
+
+ if required_cache_size < 0:
+ next_cache_start = 0
+ elif required_cache_size == 0:
+ next_cache_start = attention_key_size
+ else:
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
+
+ r_att_cache = []
+ r_cnn_cache = []
+ mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
+ mask_pad = mask_pad.unsqueeze(1) # batchPad (b=1, 1, time=chunk_size)
+
+ if self.global_chunk_size > 0:
+ # for ONNX decode simulation
+ pos_emb = self.embed.position_encoding(
+ offset=max(offset - cache_t1, 0), size=cache_t1 + self.global_chunk_size
+ )
+ att_mask[:, :, -self.global_chunk_size :] = chunk_masks
+ mask_pad = chunk_masks.to(torch.bool)
+ else:
+ pos_emb = self.embed.position_encoding(
+ offset=offset - cache_t1, size=attention_key_size
+ )
+
+ max_att_len, max_cnn_len = 0, 0 # for repeat_interleave of new_att_cache
+ for i, layer in enumerate(self.encoders):
+ factor = self.calculate_downsampling_factor(i)
+ # NOTE(xcsong): Before layer.forward
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
+ # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
+ att_cache_trunc = 0
+ if xs.size(1) + att_cache.size(2) / factor > pos_emb.size(1):
+ # The time step is not divisible by the downsampling multiple
+ att_cache_trunc = (
+ xs.size(1) + att_cache.size(2) // factor - pos_emb.size(1) + 1
+ )
+ xs, _, new_att_cache, new_cnn_cache = layer(
+ xs,
+ att_mask,
+ pos_emb,
+ mask_pad=mask_pad,
+ att_cache=att_cache[i : i + 1, :, ::factor, :][
+ :, :, att_cache_trunc:, :
+ ],
+ cnn_cache=cnn_cache[i, :, :, :] if cnn_cache.size(0) > 0 else cnn_cache,
+ )
+
+ if i in self.stride_layer_idx:
+ # compute time dimension for next block
+ efficient_index = self.stride_layer_idx.index(i)
+ att_mask = att_mask[
+ :, :: self.stride[efficient_index], :: self.stride[efficient_index]
+ ]
+ mask_pad = mask_pad[
+ :, :: self.stride[efficient_index], :: self.stride[efficient_index]
+ ]
+ pos_emb = pos_emb[:, :: self.stride[efficient_index], :]
+
+ # shape(new_att_cache) = [batch, head, time2, outdim]
+ new_att_cache = new_att_cache[:, :, next_cache_start // factor :, :]
+ # shape(new_cnn_cache) = [1, batch, outdim, cache_t2]
+ new_cnn_cache = new_cnn_cache.unsqueeze(0)
+
+ # use repeat_interleave to new_att_cache
+ new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
+ # padding new_cnn_cache to cnn.lorder for casual convolution
+ new_cnn_cache = F.pad(
+ new_cnn_cache, (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0)
+ )
+
+ if i == 0:
+ # record length for the first block as max length
+ max_att_len = new_att_cache.size(2)
+ max_cnn_len = new_cnn_cache.size(3)
+
+ # update real shape of att_cache and cnn_cache
+ r_att_cache.append(new_att_cache[:, :, -max_att_len:, :])
+ r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])
+
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
+ # ? may be larger than cache_t1, it depends on required_cache_size
+ r_att_cache = torch.cat(r_att_cache, dim=0)
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
+
+ if self.global_chunk_size > 0 and real_len:
+ chunk_real_len = (
+ real_len
+ // self.embed.subsampling_rate
+ // self.calculate_downsampling_factor(self.num_blocks + 1)
+ )
+ # Keeping 1 more timestep can mitigate information leakage
+ # from the encoder caused by the padding
+ xs = xs[:, : chunk_real_len + 1, :]
+
+ return xs, r_att_cache, r_cnn_cache
+
+ def forward_chunk_by_chunk(
+ self,
+ xs: torch.Tensor,
+ decoding_chunk_size: int,
+ num_decoding_left_chunks: int = -1,
+ use_onnx=False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward input chunk by chunk with chunk_size like a streaming
+ fashion
+
+ Here we should pay special attention to computation cache in the
+ streaming style forward chunk by chunk. Three things should be taken
+ into account for computation in the current network:
+ 1. transformer/conformer encoder layers output cache
+ 2. convolution in conformer
+ 3. convolution in subsampling
+
+ However, we don't implement subsampling cache for:
+ 1. We can control subsampling module to output the right result by
+ overlapping input instead of cache left context, even though it
+ wastes some computation, but subsampling only takes a very
+ small fraction of computation in the whole model.
+ 2. Typically, there are several covolution layers with subsampling
+ in subsampling module, it is tricky and complicated to do cache
+ with different convolution layers with different subsampling
+ rate.
+ 3. Currently, nn.Sequential is used to stack all the convolution
+ layers in subsampling, we need to rewrite it to make it work
+ with cache, which is not prefered.
+ Args:
+ xs (torch.Tensor): (1, max_len, dim)
+ decoding_chunk_size (int): decoding chunk size
+ num_decoding_left_chunks (int):
+ use_onnx (bool): True for simulating ONNX model inference.
+ """
+ assert decoding_chunk_size > 0
+ # The model is trained by static or dynamic chunk
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
+ subsampling = self.embed.subsampling_rate
+ context = self.embed.right_context + 1 # Add current frame
+ stride = subsampling * decoding_chunk_size
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
+ num_frames = xs.size(1)
+
+ outputs = []
+ offset = 0
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+ if use_onnx:
+ logging.info("Simulating for ONNX runtime ...")
+ att_cache: torch.Tensor = torch.zeros(
+ (
+ self.num_blocks,
+ self.attention_heads,
+ required_cache_size,
+ self.output_size() // self.attention_heads * 2,
+ ),
+ device=xs.device,
+ )
+ cnn_cache: torch.Tensor = torch.zeros(
+ (self.num_blocks, 1, self.output_size(), self.cnn_module_kernel - 1),
+ device=xs.device,
+ )
+ self.set_global_chunk_size(chunk_size=decoding_chunk_size)
+ else:
+ logging.info("Simulating for JIT runtime ...")
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+
+ # Feed forward overlap input step by step
+ for cur in range(0, num_frames - context + 1, stride):
+ end = min(cur + decoding_window, num_frames)
+ logging.info(
+ f"-->> frame chunk msg: cur={cur}, "
+ f"end={end}, num_frames={end-cur}, "
+ f"decoding_window={decoding_window}"
+ )
+ if use_onnx:
+ att_mask: torch.Tensor = torch.ones(
+ (1, 1, required_cache_size + decoding_chunk_size),
+ dtype=torch.bool,
+ device=xs.device,
+ )
+ if cur == 0:
+ att_mask[:, :, :required_cache_size] = 0
+ else:
+ att_mask: torch.Tensor = torch.ones(
+ (0, 0, 0), dtype=torch.bool, device=xs.device
+ )
+
+ chunk_xs = xs[:, cur:end, :]
+ (y, att_cache, cnn_cache) = self.forward_chunk(
+ chunk_xs, offset, required_cache_size, att_cache, cnn_cache, att_mask
+ )
+ outputs.append(y)
+ offset += y.size(1)
+
+ ys = torch.cat(outputs, 1)
+ masks = torch.ones(1, 1, ys.size(1), device=ys.device, dtype=torch.bool)
+ return ys, masks
diff --git a/modules/wenet_extractor/efficient_conformer/encoder_layer.py b/modules/wenet_extractor/efficient_conformer/encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9156c0edf9f79df6425accd542369e139591ed42
--- /dev/null
+++ b/modules/wenet_extractor/efficient_conformer/encoder_layer.py
@@ -0,0 +1,172 @@
+## This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Encoder self-attention layer definition."""
+
+from typing import Optional, Tuple
+import torch
+from torch import nn
+
+
+class StrideConformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvlutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: Optional[nn.Module] = None,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ pointwise_conv_layer: Optional[nn.Module] = None,
+ dropout_rate: float = 0.1,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.pointwise_conv_layer = pointwise_conv_layer
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
+ self.norm_final = nn.LayerNorm(
+ size, eps=1e-5
+ ) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_linear = nn.Linear(size + size, size)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): positional encoding, must not be None
+ for ConformerEncoderLayer.
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
+ (#batch, 1,time), (0, 0, 0) means fake mask.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
+ """
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
+
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ # Fake new cnn cache here, and then change it in conv_module
+ new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+
+ # add pointwise_conv for efficient conformer
+ # pointwise_conv_layer does not change shape
+ if self.pointwise_conv_layer is not None:
+ residual = residual.transpose(1, 2)
+ residual = self.pointwise_conv_layer(residual)
+ residual = residual.transpose(1, 2)
+ assert residual.size(0) == x.size(0)
+ assert residual.size(1) == x.size(1)
+ assert residual.size(2) == x.size(2)
+
+ x = residual + self.dropout(x)
+
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
diff --git a/modules/wenet_extractor/efficient_conformer/subsampling.py b/modules/wenet_extractor/efficient_conformer/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..6017084c2accb40c86be16d798e587b4f32c26d4
--- /dev/null
+++ b/modules/wenet_extractor/efficient_conformer/subsampling.py
@@ -0,0 +1,81 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+
+"""Subsampling layer definition."""
+
+from typing import Tuple, Union
+
+import torch
+from modules.wenet_extractor.transformer.subsampling import BaseSubsampling
+
+
+class Conv2dSubsampling2(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
+ ):
+ """Construct an Conv2dSubsampling4 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU())
+ self.out = torch.nn.Sequential(torch.nn.Linear(odim * ((idim - 1) // 2), odim))
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 2
+ # 2 = (3 - 1) * 1
+ self.right_context = 2
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 2.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 2.
+ torch.Tensor: positional encoding
+
+ """
+ x = x.unsqueeze(1) # (b, c=1, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, :-2:2]
diff --git a/modules/wenet_extractor/paraformer/paraformer.py b/modules/wenet_extractor/paraformer/paraformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f144baa2a5df2292fa4616952338c8f937bfcba
--- /dev/null
+++ b/modules/wenet_extractor/paraformer/paraformer.py
@@ -0,0 +1,366 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from modules.wenet_extractor.cif.predictor import MAELoss
+from modules.wenet_extractor.paraformer.search.beam_search import Hypothesis
+from modules.wenet_extractor.transformer.asr_model import ASRModel
+from modules.wenet_extractor.transformer.ctc import CTC
+from modules.wenet_extractor.transformer.decoder import TransformerDecoder
+from modules.wenet_extractor.transformer.encoder import TransformerEncoder
+from modules.wenet_extractor.utils.common import IGNORE_ID, add_sos_eos, th_accuracy
+from modules.wenet_extractor.utils.mask import make_pad_mask
+
+
+class Paraformer(ASRModel):
+ """Paraformer: Fast and Accurate Parallel Transformer for
+ Non-autoregressive End-to-End Speech Recognition
+ see https://arxiv.org/pdf/2206.08317.pdf
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder: TransformerEncoder,
+ decoder: TransformerDecoder,
+ ctc: CTC,
+ predictor,
+ ctc_weight: float = 0.5,
+ predictor_weight: float = 1.0,
+ predictor_bias: int = 0,
+ ignore_id: int = IGNORE_ID,
+ reverse_weight: float = 0.0,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ ):
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= predictor_weight <= 1.0, predictor_weight
+
+ super().__init__(
+ vocab_size,
+ encoder,
+ decoder,
+ ctc,
+ ctc_weight,
+ ignore_id,
+ reverse_weight,
+ lsm_weight,
+ length_normalized_loss,
+ )
+ self.predictor = predictor
+ self.predictor_weight = predictor_weight
+ self.predictor_bias = predictor_bias
+ self.criterion_pre = MAELoss(normalize_length=length_normalized_loss)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ # 1. Encoder
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+
+ # 2a. Attention-decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, acc_att, loss_pre = self._calc_att_loss(
+ encoder_out, encoder_mask, text, text_lengths
+ )
+ else:
+ # loss_att = None
+ # loss_pre = None
+ loss_att: torch.Tensor = torch.tensor(0)
+ loss_pre: torch.Tensor = torch.tensor(0)
+
+ # 2b. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
+ else:
+ loss_ctc = None
+
+ if loss_ctc is None:
+ loss = loss_att + self.predictor_weight * loss_pre
+ # elif loss_att is None:
+ elif loss_att == torch.tensor(0):
+ loss = loss_ctc
+ else:
+ loss = (
+ self.ctc_weight * loss_ctc
+ + (1 - self.ctc_weight) * loss_att
+ + self.predictor_weight * loss_pre
+ )
+ return {
+ "loss": loss,
+ "loss_att": loss_att,
+ "loss_ctc": loss_ctc,
+ "loss_pre": loss_pre,
+ }
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_mask: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, float, torch.Tensor]:
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
+ encoder_out, ys_pad, encoder_mask, ignore_id=self.ignore_id
+ )
+ # 1. Forward decoder
+ decoder_out, _, _ = self.decoder(
+ encoder_out, encoder_mask, pre_acoustic_embeds, ys_pad_lens
+ )
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre: torch.Tensor = self.criterion_pre(
+ ys_pad_lens.type_as(pre_token_length), pre_token_length
+ )
+
+ return loss_att, acc_att, loss_pre
+
+ def calc_predictor(self, encoder_out, encoder_mask):
+ encoder_mask = (
+ ~make_pad_mask(encoder_mask, max_len=encoder_out.size(1))[:, None, :]
+ ).to(encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
+ encoder_out, None, encoder_mask, ignore_id=self.ignore_id
+ )
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def cal_decoder_with_predictor(
+ self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ ):
+ decoder_out, _, _ = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def recognize(self):
+ raise NotImplementedError
+
+ def paraformer_greedy_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply beam search on attention decoder
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+
+ Returns:
+ torch.Tensor: decoding result, (batch, max_result_len)
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ device = speech.device
+ batch_size = speech.shape[0]
+
+ # Let's assume B = batch_size and N = beam_size
+ # 1. Encoder
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ # 2. Predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_mask)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
+ predictor_outs[0],
+ predictor_outs[1],
+ predictor_outs[2],
+ predictor_outs[3],
+ )
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return torch.tensor([]), torch.tensor([])
+ # 2. Decoder forward
+ decoder_outs = self.cal_decoder_with_predictor(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
+ )
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ hyps = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, : encoder_out_lens[i], :]
+ am_scores = decoder_out[i, : pre_token_length[i], :]
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get hyps
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id and unk id, which is assumed to be 0
+ # and 1
+ token_int = list(filter(lambda x: x != 0 and x != 1, token_int))
+ hyps.append(token_int)
+ return hyps
+
+ def paraformer_beam_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ beam_search: torch.nn.Module = None,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply beam search on attention decoder
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_lengths (torch.Tensor): (batch, )
+ beam_search (torch.nn.Moudle): beam search module
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+
+ Returns:
+ torch.Tensor: decoding result, (batch, max_result_len)
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ device = speech.device
+ batch_size = speech.shape[0]
+
+ # Let's assume B = batch_size and N = beam_size
+ # 1. Encoder
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ # 2. Predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_mask)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
+ predictor_outs[0],
+ predictor_outs[1],
+ predictor_outs[2],
+ predictor_outs[3],
+ )
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return torch.tensor([]), torch.tensor([])
+ # 2. Decoder forward
+ decoder_outs = self.cal_decoder_with_predictor(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
+ )
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ hyps = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, : encoder_out_lens[i], :]
+ am_scores = decoder_out[i, : pre_token_length[i], :]
+ if beam_search is not None:
+ nbest_hyps = beam_search(x=x, am_scores=am_scores)
+ nbest_hyps = nbest_hyps[:1]
+ else:
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos
+ # tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get hyps
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id and unk id, which is assumed to be 0
+ # and 1
+ token_int = list(filter(lambda x: x != 0 and x != 1, token_int))
+ hyps.append(token_int)
+ return hyps
diff --git a/modules/wenet_extractor/paraformer/search/beam_search.py b/modules/wenet_extractor/paraformer/search/beam_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..5389ce4ac7f0026ab1dcf042c62b9c790bd61085
--- /dev/null
+++ b/modules/wenet_extractor/paraformer/search/beam_search.py
@@ -0,0 +1,479 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+from itertools import chain
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+from typing import NamedTuple
+
+import torch
+
+from modules.wenet_extractor.paraformer.utils import end_detect
+from modules.wenet_extractor.paraformer.search.ctc import CTCPrefixScorer
+from modules.wenet_extractor.paraformer.search.scorer_interface import (
+ ScorerInterface,
+ PartialScorerInterface,
+)
+
+
+class Hypothesis(NamedTuple):
+ """Hypothesis data type."""
+
+ yseq: torch.Tensor
+ score: Union[float, torch.Tensor] = 0
+ scores: Dict[str, Union[float, torch.Tensor]] = dict()
+ states: Dict[str, Any] = dict()
+
+ def asdict(self) -> dict:
+ """Convert data to JSON-friendly dict."""
+ return self._replace(
+ yseq=self.yseq.tolist(),
+ score=float(self.score),
+ scores={k: float(v) for k, v in self.scores.items()},
+ )._asdict()
+
+
+class BeamSearchCIF(torch.nn.Module):
+ """Beam search implementation."""
+
+ def __init__(
+ self,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ beam_size: int,
+ vocab_size: int,
+ sos: int,
+ eos: int,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = None,
+ ):
+ """Initialize beam search.
+
+ Args:
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ """
+ super().__init__()
+ # set scorers
+ self.weights = weights
+ self.scorers = dict()
+ self.full_scorers = dict()
+ self.part_scorers = dict()
+ # this module dict is required for recursive cast
+ # `self.to(device, dtype)` in `recog.py`
+ self.nn_dict = torch.nn.ModuleDict()
+ for k, v in scorers.items():
+ w = weights.get(k, 0)
+ if w == 0 or v is None:
+ continue
+ assert isinstance(
+ v, ScorerInterface
+ ), f"{k} ({type(v)}) does not implement ScorerInterface"
+ self.scorers[k] = v
+ if isinstance(v, PartialScorerInterface):
+ self.part_scorers[k] = v
+ else:
+ self.full_scorers[k] = v
+ if isinstance(v, torch.nn.Module):
+ self.nn_dict[k] = v
+
+ # set configurations
+ self.sos = sos
+ self.eos = eos
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
+ self.beam_size = beam_size
+ self.n_vocab = vocab_size
+ if (
+ pre_beam_score_key is not None
+ and pre_beam_score_key != "full"
+ and pre_beam_score_key not in self.full_scorers
+ ):
+ raise KeyError(
+ f"{pre_beam_score_key} is not found in " f"{self.full_scorers}"
+ )
+ self.pre_beam_score_key = pre_beam_score_key
+ self.do_pre_beam = (
+ self.pre_beam_score_key is not None
+ and self.pre_beam_size < self.n_vocab
+ and len(self.part_scorers) > 0
+ )
+
+ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+ """Get an initial hypothesis data.
+
+ Args:
+ x (torch.Tensor): The encoder output feature
+
+ Returns:
+ Hypothesis: The initial hypothesis.
+
+ """
+ init_states = dict()
+ init_scores = dict()
+ for k, d in self.scorers.items():
+ init_states[k] = d.init_state(x)
+ init_scores[k] = 0.0
+ return [
+ Hypothesis(
+ score=0.0,
+ scores=init_scores,
+ states=init_states,
+ yseq=torch.tensor([self.sos], device=x.device),
+ )
+ ]
+
+ @staticmethod
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+ """Append new token to prefix tokens.
+
+ Args:
+ xs (torch.Tensor): The prefix token
+ x (int): The new token to append
+
+ Returns:
+ torch.Tensor: New tensor contains: xs + [x] with xs.dtype and
+ xs.device
+
+ """
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+ return torch.cat((xs, x))
+
+ def score_full(
+ self, hyp: Hypothesis, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.full_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.full_scorers`
+ and tensor score values of shape: `(self.n_vocab,)`,
+ and state dict that has string keys
+ and state values of `self.full_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.full_scorers.items():
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
+ return scores, states
+
+ def score_partial(
+ self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.part_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.part_scorers`
+ and tensor score values of shape: `(len(ids),)`,
+ and state dict that has string keys
+ and state values of `self.part_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.part_scorers.items():
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+ return scores, states
+
+ def beam(
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute topk full token ids and partial token ids.
+
+ Args:
+ weighted_scores (torch.Tensor): The weighted sum scores for each
+ tokens.
+ Its shape is `(self.n_vocab,)`.
+ ids (torch.Tensor): The partial token ids to compute topk
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ The topk full token ids and partial token ids.
+ Their shapes are `(self.beam_size,)`
+
+ """
+ # no pre beam performed
+ if weighted_scores.size(0) == ids.size(0):
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ return top_ids, top_ids
+
+ # mask pruned in pre-beam not to select in topk
+ tmp = weighted_scores[ids]
+ weighted_scores[:] = -float("inf")
+ weighted_scores[ids] = tmp
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+ return top_ids, local_ids
+
+ @staticmethod
+ def merge_scores(
+ prev_scores: Dict[str, float],
+ next_full_scores: Dict[str, torch.Tensor],
+ full_idx: int,
+ next_part_scores: Dict[str, torch.Tensor],
+ part_idx: int,
+ ) -> Dict[str, torch.Tensor]:
+ """Merge scores for new hypothesis.
+
+ Args:
+ prev_scores (Dict[str, float]):
+ The previous hypothesis scores by `self.scorers`
+ next_full_scores (Dict[str, torch.Tensor]): scores by
+ `self.full_scorers`
+ full_idx (int): The next token id for `next_full_scores`
+ next_part_scores (Dict[str, torch.Tensor]):
+ scores of partial tokens by `self.part_scorers`
+ part_idx (int): The new token id for `next_part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and
+ `self.part_scorers`.
+ Its values are scalar tensors by the scorers.
+
+ """
+ new_scores = dict()
+ for k, v in next_full_scores.items():
+ new_scores[k] = prev_scores[k] + v[full_idx]
+ for k, v in next_part_scores.items():
+ new_scores[k] = prev_scores[k] + v[part_idx]
+ return new_scores
+
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+ """Merge states for new hypothesis.
+
+ Args:
+ states: states of `self.full_scorers`
+ part_states: states of `self.part_scorers`
+ part_idx (int): The new token id for `part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and
+ `self.part_scorers`.
+ Its values are states of the scorers.
+
+ """
+ new_states = dict()
+ for k, v in states.items():
+ new_states[k] = v
+ for k, d in self.part_scorers.items():
+ new_states[k] = d.select_state(part_states[k], part_idx)
+ return new_states
+
+ def search(
+ self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
+ ) -> List[Hypothesis]:
+ """Search new tokens for running hypotheses and encoded speech x.
+
+ Args:
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
+ x (torch.Tensor): Encoded speech feature (T, D)
+
+ Returns:
+ List[Hypotheses]: Best sorted hypotheses
+
+ """
+ best_hyps = []
+ part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
+ for hyp in running_hyps:
+ # scoring
+ weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+ weighted_scores += am_score
+ scores, states = self.score_full(hyp, x)
+ for k in self.full_scorers:
+ weighted_scores += self.weights[k] * scores[k]
+ # partial scoring
+ if self.do_pre_beam:
+ pre_beam_scores = (
+ weighted_scores
+ if self.pre_beam_score_key == "full"
+ else scores[self.pre_beam_score_key]
+ )
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+ part_scores, part_states = self.score_partial(hyp, part_ids, x)
+ for k in self.part_scorers:
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+ # add previous hyp score
+ weighted_scores += hyp.score
+
+ # update hyps
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+ # will be (2 x beam at most)
+ best_hyps.append(
+ Hypothesis(
+ score=weighted_scores[j],
+ yseq=self.append_token(hyp.yseq, j),
+ scores=self.merge_scores(
+ hyp.scores, scores, j, part_scores, part_j
+ ),
+ states=self.merge_states(states, part_states, part_j),
+ )
+ )
+
+ # sort and prune 2 x beam -> beam
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+ : min(len(best_hyps), self.beam_size)
+ ]
+ return best_hyps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ am_scores: torch.Tensor,
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ ) -> List[Hypothesis]:
+ """Perform beam search.
+
+ Args:
+ x (torch.Tensor): Encoded speech feature (T, D)
+ maxlenratio (float): Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses a end-detect function
+ to automatically find maximum hypothesis lengths
+ If maxlenratio<0.0, its absolute value is interpreted
+ as a constant max output length.
+ minlenratio (float): Input length ratio to obtain min output length.
+
+ Returns:
+ list[Hypothesis]: N-best decoding results
+
+ """
+ # set length bounds
+ maxlen = am_scores.shape[0]
+
+ # main loop of prefix search
+ running_hyps = self.init_hyp(x)
+ ended_hyps = []
+ for i in range(maxlen):
+ best = self.search(running_hyps, x, am_scores[i])
+ # post process of one iteration
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+ # end detection
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+ break
+
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+ # check the number of hypotheses reaching to eos
+ if len(nbest_hyps) == 0:
+ return (
+ []
+ if minlenratio < 0.1
+ else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+ )
+
+ best = nbest_hyps[0]
+ return nbest_hyps
+
+ def post_process(
+ self,
+ i: int,
+ maxlen: int,
+ maxlenratio: float,
+ running_hyps: List[Hypothesis],
+ ended_hyps: List[Hypothesis],
+ ) -> List[Hypothesis]:
+ """Perform post-processing of beam search iterations.
+
+ Args:
+ i (int): The length of hypothesis tokens.
+ maxlen (int): The maximum length of tokens in beam search.
+ maxlenratio (int): The maximum length ratio in beam search.
+ running_hyps (List[Hypothesis]): The running hypotheses in beam
+ search.
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+ Returns:
+ List[Hypothesis]: The new running hypotheses.
+
+ """
+
+ # add eos in the final loop to avoid that there are no ended hyps
+ if i == maxlen - 1:
+ # logging.info("adding in the last position in the loop")
+ running_hyps = [
+ h._replace(yseq=self.append_token(h.yseq, self.eos))
+ for h in running_hyps
+ ]
+
+ # add ended hypotheses to a final list, and removed them from current
+ # hypotheses
+ # (this will be a problem, number of hyps < beam)
+ remained_hyps = []
+ for hyp in running_hyps:
+ if hyp.yseq[-1] == self.eos:
+ # e.g., Word LM needs to add final score
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+ s = d.final_score(hyp.states[k])
+ hyp.scores[k] += s
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+ ended_hyps.append(hyp)
+ else:
+ remained_hyps.append(hyp)
+ return remained_hyps
+
+
+def build_beam_search(model, args, device):
+ scorers = {}
+ if model.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=model.ctc, eos=model.eos)
+ scorers.update(ctc=ctc)
+ weights = dict(
+ decoder=1.0 - args.ctc_weight,
+ ctc=args.ctc_weight,
+ length_bonus=args.penalty,
+ )
+ beam_search = BeamSearchCIF(
+ beam_size=args.beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=model.sos,
+ eos=model.eos,
+ vocab_size=model.vocab_size,
+ pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
+ )
+ beam_search.to(device=device, dtype=torch.float32).eval()
+ return beam_search
diff --git a/modules/wenet_extractor/paraformer/search/ctc.py b/modules/wenet_extractor/paraformer/search/ctc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a40c337ed3ce69051292096a9648b6675f3b8049
--- /dev/null
+++ b/modules/wenet_extractor/paraformer/search/ctc.py
@@ -0,0 +1,181 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""ScorerInterface implementation for CTC."""
+import numpy as np
+import torch
+
+from modules.wenet_extractor.paraformer.search.ctc_prefix_score import CTCPrefixScore
+from modules.wenet_extractor.paraformer.search.ctc_prefix_score import CTCPrefixScoreTH
+from modules.wenet_extractor.paraformer.search.scorer_interface import (
+ BatchPartialScorerInterface,
+)
+
+
+class CTCPrefixScorer(BatchPartialScorerInterface):
+ """Decoder interface wrapper for CTCPrefixScore."""
+
+ def __init__(self, ctc: torch.nn.Module, eos: int):
+ """Initialize class.
+
+ Args:
+ ctc (torch.nn.Module): The CTC implementation.
+ For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
+ eos (int): The end-of-sequence id.
+
+ """
+ self.ctc = ctc
+ self.eos = eos
+ self.impl = None
+
+ def init_state(self, x: torch.Tensor):
+ """Get an initial state for decoding.
+
+ Args:
+ x (torch.Tensor): The encoded feature tensor
+
+ Returns: initial state
+
+ """
+ logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
+ # TODO(karita): use CTCPrefixScoreTH
+ self.impl = CTCPrefixScore(logp, 0, self.eos, np)
+ return 0, self.impl.initial_state()
+
+ def select_state(self, state, i, new_id=None):
+ """Select state with relative ids in the main beam search.
+
+ Args:
+ state: Decoder state for prefix tokens
+ i (int): Index to select a state in the main beam search
+ new_id (int): New label id to select a state if necessary
+
+ Returns:
+ state: pruned state
+
+ """
+ if type(state) == tuple:
+ if len(state) == 2: # for CTCPrefixScore
+ sc, st = state
+ return sc[i], st[i]
+ else: # for CTCPrefixScoreTH (need new_id > 0)
+ r, log_psi, f_min, f_max, scoring_idmap = state
+ s = log_psi[i, new_id].expand(log_psi.size(1))
+ if scoring_idmap is not None:
+ return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
+ else:
+ return r[:, :, i, new_id], s, f_min, f_max
+ return None if state is None else state[i]
+
+ def score_partial(self, y, ids, state, x):
+ """Score new token.
+
+ Args:
+ y (torch.Tensor): 1D prefix token
+ next_tokens (torch.Tensor): torch.int64 next token to score
+ state: decoder state for prefix tokens
+ x (torch.Tensor): 2D encoder feature that generates ys
+
+ Returns:
+ tuple[torch.Tensor, Any]:
+ Tuple of a score tensor for y that has a shape
+ `(len(next_tokens),)` and next state for ys
+
+ """
+ prev_score, state = state
+ presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
+ tscore = torch.as_tensor(
+ presub_score - prev_score, device=x.device, dtype=x.dtype
+ )
+ return tscore, (presub_score, new_st)
+
+ def batch_init_state(self, x: torch.Tensor):
+ """Get an initial state for decoding.
+
+ Args:
+ x (torch.Tensor): The encoded feature tensor
+
+ Returns: initial state
+
+ """
+ logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
+ xlen = torch.tensor([logp.size(1)])
+ self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
+ return None
+
+ def batch_score_partial(self, y, ids, state, x):
+ """Score new token.
+
+ Args:
+ y (torch.Tensor): 1D prefix token
+ ids (torch.Tensor): torch.int64 next token to score
+ state: decoder state for prefix tokens
+ x (torch.Tensor): 2D encoder feature that generates ys
+
+ Returns:
+ tuple[torch.Tensor, Any]:
+ Tuple of a score tensor for y that has a shape
+ `(len(next_tokens),)` and next state for ys
+
+ """
+ batch_state = (
+ (
+ torch.stack([s[0] for s in state], dim=2),
+ torch.stack([s[1] for s in state]),
+ state[0][2],
+ state[0][3],
+ )
+ if state[0] is not None
+ else None
+ )
+ return self.impl(y, batch_state, ids)
+
+ def extend_prob(self, x: torch.Tensor):
+ """Extend probs for decoding.
+
+ This extension is for streaming decoding
+ as in Eq (14) in https://arxiv.org/abs/2006.14941
+
+ Args:
+ x (torch.Tensor): The encoded feature tensor
+
+ """
+ logp = self.ctc.log_softmax(x.unsqueeze(0))
+ self.impl.extend_prob(logp)
+
+ def extend_state(self, state):
+ """Extend state for decoding.
+
+ This extension is for streaming decoding
+ as in Eq (14) in https://arxiv.org/abs/2006.14941
+
+ Args:
+ state: The states of hyps
+
+ Returns: exteded state
+
+ """
+ new_state = []
+ for s in state:
+ new_state.append(self.impl.extend_state(s))
+
+ return new_state
diff --git a/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py b/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4a4eb50b7556f6a463aef5428dba3ebc4f49875
--- /dev/null
+++ b/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py
@@ -0,0 +1,377 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import torch
+import numpy as np
+
+import six
+
+
+class CTCPrefixScore(object):
+ """Compute CTC label sequence scores
+
+ which is based on Algorithm 2 in WATANABE et al.
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+ but extended to efficiently compute the probablities of multiple labels
+ simultaneously
+ """
+
+ def __init__(self, x, blank, eos, xp):
+ self.xp = xp
+ self.logzero = -10000000000.0
+ self.blank = blank
+ self.eos = eos
+ self.input_length = len(x)
+ self.x = x
+
+ def initial_state(self):
+ """Obtain an initial CTC state
+
+ :return: CTC state
+ """
+ # initial CTC state is made of a frame x 2 tensor that corresponds to
+ # r_t^n() and r_t^b(), where 0 and 1 of axis=1 represent
+ # superscripts n and b (non-blank and blank), respectively.
+ r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
+ r[0, 1] = self.x[0, self.blank]
+ for i in six.moves.range(1, self.input_length):
+ r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
+ return r
+
+ def __call__(self, y, cs, r_prev):
+ """Compute CTC prefix scores for next labels
+
+ :param y : prefix label sequence
+ :param cs : array of next labels
+ :param r_prev: previous CTC state
+ :return ctc_scores, ctc_states
+ """
+ # initialize CTC states
+ output_length = len(y) - 1 # ignore sos
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
+ # that corresponds to r_t^n(h) and r_t^b(h).
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
+ xs = self.x[:, cs]
+ if output_length == 0:
+ r[0, 0] = xs[0]
+ r[0, 1] = self.logzero
+ else:
+ r[output_length - 1] = self.logzero
+
+ # prepare forward probabilities for the last label
+ r_sum = self.xp.logaddexp(
+ r_prev[:, 0], r_prev[:, 1]
+ ) # log(r_t^n(g) + r_t^b(g))
+ last = y[-1]
+ if output_length > 0 and last in cs:
+ log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
+ for i in six.moves.range(len(cs)):
+ log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
+ else:
+ log_phi = r_sum
+
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
+ # and log prefix probabilities log(psi)
+ start = max(output_length, 1)
+ log_psi = r[start - 1, 0]
+ for t in six.moves.range(start, self.input_length):
+ r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
+ r[t, 1] = (
+ self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
+ )
+ log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
+
+ # get P(...eos|X) that ends with the prefix itself
+ eos_pos = self.xp.where(cs == self.eos)[0]
+ if len(eos_pos) > 0:
+ log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
+
+ # exclude blank probs
+ blank_pos = self.xp.where(cs == self.blank)[0]
+ if len(blank_pos) > 0:
+ log_psi[blank_pos] = self.logzero
+
+ # return the log prefix probability and CTC states, where the label axis
+ # of the CTC states is moved to the first axis to slice it easily
+ return log_psi, self.xp.rollaxis(r, 2)
+
+
+class CTCPrefixScoreTH(object):
+ """Batch processing of CTCPrefixScore
+
+ which is based on Algorithm 2 in WATANABE et al.
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
+ but extended to efficiently compute the label probablities for multiple
+ hypotheses simultaneously
+ See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
+ Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
+ """
+
+ def __init__(self, x, xlens, blank, eos, margin=0):
+ """Construct CTC prefix scorer
+
+ :param torch.Tensor x: input label posterior sequences (B, T, O)
+ :param torch.Tensor xlens: input lengths (B,)
+ :param int blank: blank label id
+ :param int eos: end-of-sequence id
+ :param int margin: margin parameter for windowing (0 means no windowing)
+ """
+ # In the comment lines,
+ # we assume T: input_length, B: batch size, W: beam width, O: output dim
+ self.logzero = -10000000000.0
+ self.blank = blank
+ self.eos = eos
+ self.batch = x.size(0)
+ self.input_length = x.size(1)
+ self.odim = x.size(2)
+ self.dtype = x.dtype
+ self.device = (
+ torch.device("cuda:%d" % x.get_device())
+ if x.is_cuda
+ else torch.device("cpu")
+ )
+ # Pad the rest of posteriors in the batch
+ # TODO(takaaki-hori): need a better way without for-loops
+ for i, l in enumerate(xlens):
+ if l < self.input_length:
+ x[i, l:, :] = self.logzero
+ x[i, l:, blank] = 0
+ # Reshape input x
+ xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
+ xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
+ self.x = torch.stack([xn, xb]) # (2, T, B, O)
+ self.end_frames = torch.as_tensor(xlens) - 1
+
+ # Setup CTC windowing
+ self.margin = margin
+ if margin > 0:
+ self.frame_ids = torch.arange(
+ self.input_length, dtype=self.dtype, device=self.device
+ )
+ # Base indices for index conversion
+ self.idx_bh = None
+ self.idx_b = torch.arange(self.batch, device=self.device)
+ self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
+
+ def __call__(self, y, state, scoring_ids=None, att_w=None):
+ """Compute CTC prefix scores for next labels
+
+ :param list y: prefix label sequences
+ :param tuple state: previous CTC state
+ :param torch.Tensor pre_scores: scores for pre-selection of hypotheses
+ (BW, O)
+ :param torch.Tensor att_w: attention weights to decide CTC window
+ :return new_state, ctc_local_scores (BW, O)
+ """
+ output_length = len(y[0]) - 1 # ignore sos
+ last_ids = [yi[-1] for yi in y] # last output label ids
+ n_bh = len(last_ids) # batch * hyps
+ n_hyps = n_bh // self.batch # assuming each utterance has the same
+ self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
+ # prepare state info
+ if state is None:
+ r_prev = torch.full(
+ (self.input_length, 2, self.batch, n_hyps),
+ self.logzero,
+ dtype=self.dtype,
+ device=self.device,
+ )
+ r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
+ r_prev = r_prev.view(-1, 2, n_bh)
+ s_prev = 0.0
+ f_min_prev = 0
+ f_max_prev = 1
+ else:
+ r_prev, s_prev, f_min_prev, f_max_prev = state
+
+ # select input dimensions for scoring
+ if self.scoring_num > 0:
+ scoring_idmap = torch.full(
+ (n_bh, self.odim), -1, dtype=torch.long, device=self.device
+ )
+ snum = self.scoring_num
+ if self.idx_bh is None or n_bh > len(self.idx_bh):
+ self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
+ scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
+ snum, device=self.device
+ )
+ scoring_idx = (
+ scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
+ ).view(-1)
+ x_ = torch.index_select(
+ self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
+ ).view(2, -1, n_bh, snum)
+ else:
+ scoring_ids = None
+ scoring_idmap = None
+ snum = self.odim
+ x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
+
+ # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
+ # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
+ r = torch.full(
+ (self.input_length, 2, n_bh, snum),
+ self.logzero,
+ dtype=self.dtype,
+ device=self.device,
+ )
+ if output_length == 0:
+ r[0, 0] = x_[0, 0]
+
+ r_sum = torch.logsumexp(r_prev, 1)
+ log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
+ if scoring_ids is not None:
+ for idx in range(n_bh):
+ pos = scoring_idmap[idx, last_ids[idx]]
+ if pos >= 0:
+ log_phi[:, idx, pos] = r_prev[:, 1, idx]
+ else:
+ for idx in range(n_bh):
+ log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
+
+ # decide start and end frames based on attention weights
+ if att_w is not None and self.margin > 0:
+ f_arg = torch.matmul(att_w, self.frame_ids)
+ f_min = max(int(f_arg.min().cpu()), f_min_prev)
+ f_max = max(int(f_arg.max().cpu()), f_max_prev)
+ start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
+ end = min(f_max + self.margin, self.input_length)
+ else:
+ f_min = f_max = 0
+ start = max(output_length, 1)
+ end = self.input_length
+
+ # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
+ for t in range(start, end):
+ rp = r[t - 1]
+ rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
+ 2, 2, n_bh, snum
+ )
+ r[t] = torch.logsumexp(rr, 1) + x_[:, t]
+
+ # compute log prefix probabilities log(psi)
+ log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
+ if scoring_ids is not None:
+ log_psi = torch.full(
+ (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
+ )
+ log_psi_ = torch.logsumexp(
+ torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
+ dim=0,
+ )
+ for si in range(n_bh):
+ log_psi[si, scoring_ids[si]] = log_psi_[si]
+ else:
+ log_psi = torch.logsumexp(
+ torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
+ dim=0,
+ )
+
+ for si in range(n_bh):
+ log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
+
+ # exclude blank probs
+ log_psi[:, self.blank] = self.logzero
+
+ return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
+
+ def index_select_state(self, state, best_ids):
+ """Select CTC states according to best ids
+
+ :param state : CTC state
+ :param best_ids : index numbers selected by beam pruning (B, W)
+ :return selected_state
+ """
+ r, s, f_min, f_max, scoring_idmap = state
+ # convert ids to BHO space
+ n_bh = len(s)
+ n_hyps = n_bh // self.batch
+ vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
+ # select hypothesis scores
+ s_new = torch.index_select(s.view(-1), 0, vidx)
+ s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
+ # convert ids to BHS space (S: scoring_num)
+ if scoring_idmap is not None:
+ snum = self.scoring_num
+ hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
+ -1
+ )
+ label_ids = torch.fmod(best_ids, self.odim).view(-1)
+ score_idx = scoring_idmap[hyp_idx, label_ids]
+ score_idx[score_idx == -1] = 0
+ vidx = score_idx + hyp_idx * snum
+ else:
+ snum = self.odim
+ # select forward probabilities
+ r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
+ -1, 2, n_bh
+ )
+ return r_new, s_new, f_min, f_max
+
+ def extend_prob(self, x):
+ """Extend CTC prob.
+
+ :param torch.Tensor x: input label posterior sequences (B, T, O)
+ """
+
+ if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
+ # Pad the rest of posteriors in the batch
+ # TODO(takaaki-hori): need a better way without for-loops
+ xlens = [x.size(1)]
+ for i, l in enumerate(xlens):
+ if l < self.input_length:
+ x[i, l:, :] = self.logzero
+ x[i, l:, self.blank] = 0
+ tmp_x = self.x
+ xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
+ xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
+ self.x = torch.stack([xn, xb]) # (2, T, B, O)
+ self.x[:, : tmp_x.shape[1], :, :] = tmp_x
+ self.input_length = x.size(1)
+ self.end_frames = torch.as_tensor(xlens) - 1
+
+ def extend_state(self, state):
+ """Compute CTC prefix state.
+
+
+ :param state : CTC state
+ :return ctc_state
+ """
+
+ if state is None:
+ # nothing to do
+ return state
+ else:
+ r_prev, s_prev, f_min_prev, f_max_prev = state
+
+ r_prev_new = torch.full(
+ (self.input_length, 2),
+ self.logzero,
+ dtype=self.dtype,
+ device=self.device,
+ )
+ start = max(r_prev.shape[0], 1)
+ r_prev_new[0:start] = r_prev
+ for t in six.moves.range(start, self.input_length):
+ r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
+
+ return r_prev_new, s_prev, f_min_prev, f_max_prev
diff --git a/modules/wenet_extractor/paraformer/search/scorer_interface.py b/modules/wenet_extractor/paraformer/search/scorer_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed021726f3a93dddd9b00eab23bc6fbe7583a04b
--- /dev/null
+++ b/modules/wenet_extractor/paraformer/search/scorer_interface.py
@@ -0,0 +1,208 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Scorer interface module."""
+from abc import ABC
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import torch
+
+
+class ScorerInterface:
+ """Scorer interface for beam search.
+
+ The scorer performs scoring of the all tokens in vocabulary.
+
+ Examples:
+ * Search heuristics
+ * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
+ * Decoder networks of the sequence-to-sequence models
+ * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder
+ .Decoder`
+ * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
+ * Neural language models
+ * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
+ * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
+ * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
+
+ """
+
+ def init_state(self, x: torch.Tensor) -> Any:
+ """Get an initial state for decoding (optional).
+
+ Args:
+ x (torch.Tensor): The encoded feature tensor
+
+ Returns: initial state
+
+ """
+ return None
+
+ def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
+ """Select state with relative ids in the main beam search.
+
+ Args:
+ state: Decoder state for prefix tokens
+ i (int): Index to select a state in the main beam search
+ new_id (int): New label index to select a state if necessary
+
+ Returns:
+ state: pruned state
+
+ """
+ return None if state is None else state[i]
+
+ def score(
+ self, y: torch.Tensor, state: Any, x: torch.Tensor
+ ) -> Tuple[torch.Tensor, Any]:
+ """Score new token (required).
+
+ Args:
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
+ state: Scorer state for prefix tokens
+ x (torch.Tensor): The encoder feature that generates ys.
+
+ Returns:
+ tuple[torch.Tensor, Any]: Tuple of
+ scores for next token that has a shape of `(n_vocab)`
+ and next state for ys
+
+ """
+ raise NotImplementedError
+
+ def final_score(self, state: Any) -> float:
+ """Score eos (optional).
+
+ Args:
+ state: Scorer state for prefix tokens
+
+ Returns:
+ float: final score
+
+ """
+ return 0.0
+
+
+class BatchScorerInterface(ScorerInterface, ABC):
+ """Batch scorer interface."""
+
+ def batch_init_state(self, x: torch.Tensor) -> Any:
+ """Get an initial state for decoding (optional).
+
+ Args:
+ x (torch.Tensor): The encoded feature tensor
+
+ Returns: initial state
+
+ """
+ return self.init_state(x)
+
+ def batch_score(
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[Any]]:
+ """Score new token batch (required).
+
+ Args:
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+ states (List[Any]): Scorer states for prefix tokens.
+ xs (torch.Tensor):
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+ Returns:
+ tuple[torch.Tensor, List[Any]]: Tuple of
+ batchfied scores for next token with shape of `(n_batch,
+ n_vocab)`
+ and next state list for ys.
+
+ """
+ scores = list()
+ outstates = list()
+ for i, (y, state, x) in enumerate(zip(ys, states, xs)):
+ score, outstate = self.score(y, state, x)
+ outstates.append(outstate)
+ scores.append(score)
+ scores = torch.cat(scores, 0).view(ys.shape[0], -1)
+ return scores, outstates
+
+
+class PartialScorerInterface(ScorerInterface, ABC):
+ """Partial scorer interface for beam search.
+
+ The partial scorer performs scoring when non-partial scorer finished scoring
+ and receives pre-pruned next tokens to score because it is too heavy to
+ score all the tokens.
+
+ Examples:
+ * Prefix search for connectionist-temporal-classification models
+ * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
+
+ """
+
+ def score_partial(
+ self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
+ ) -> Tuple[torch.Tensor, Any]:
+ """Score new token (required).
+
+ Args:
+ y (torch.Tensor): 1D prefix token
+ next_tokens (torch.Tensor): torch.int64 next token to score
+ state: decoder state for prefix tokens
+ x (torch.Tensor): The encoder feature that generates ys
+
+ Returns:
+ tuple[torch.Tensor, Any]:
+ Tuple of a score tensor for y that has a shape
+ `(len(next_tokens),)` and next state for ys
+
+ """
+ raise NotImplementedError
+
+
+class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface, ABC):
+ """Batch partial scorer interface for beam search."""
+
+ def batch_score_partial(
+ self,
+ ys: torch.Tensor,
+ next_tokens: torch.Tensor,
+ states: List[Any],
+ xs: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Any]:
+ """Score new token (required).
+
+ Args:
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+ next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch,
+ n_token).
+ states (List[Any]): Scorer states for prefix tokens.
+ xs (torch.Tensor):
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+ Returns:
+ tuple[torch.Tensor, Any]:
+ Tuple of a score tensor for ys that has a shape `(n_batch,
+ n_vocab)`
+ and next states for ys
+ """
+ raise NotImplementedError
diff --git a/modules/wenet_extractor/paraformer/utils.py b/modules/wenet_extractor/paraformer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..83d644f7643f3f0c213a54ab4b634e89e51a1c40
--- /dev/null
+++ b/modules/wenet_extractor/paraformer/utils.py
@@ -0,0 +1,76 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+from typing import Optional
+
+import six
+import torch
+import numpy as np
+
+
+def sequence_mask(
+ lengths,
+ maxlen: Optional[int] = None,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None,
+) -> torch.Tensor:
+ if maxlen is None:
+ maxlen = lengths.max()
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
+ matrix = torch.unsqueeze(lengths, dim=-1)
+ mask = row_vector < matrix
+ mask = mask.detach()
+
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+
+
+def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
+ """End detection.
+
+ described in Eq. (50) of S. Watanabe et al
+ "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
+
+ :param ended_hyps:
+ :param i:
+ :param M:
+ :param d_end:
+ :return:
+ """
+ if len(ended_hyps) == 0:
+ return False
+ count = 0
+ best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
+ for m in six.moves.range(M):
+ # get ended_hyps with their length is i - m
+ hyp_length = i - m
+ hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
+ if len(hyps_same_length) > 0:
+ best_hyp_same_length = sorted(
+ hyps_same_length, key=lambda x: x["score"], reverse=True
+ )[0]
+ if best_hyp_same_length["score"] - best_hyp["score"] < d_end:
+ count += 1
+
+ if count == M:
+ return True
+ else:
+ return False
diff --git a/modules/wenet_extractor/squeezeformer/__init__.py b/modules/wenet_extractor/squeezeformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/wenet_extractor/squeezeformer/attention.py b/modules/wenet_extractor/squeezeformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..830da3dccc0257afc36747c7f572455746ef652e
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/attention.py
@@ -0,0 +1,239 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Multi-Head Attention layer definition."""
+
+import math
+import torch
+import torch.nn as nn
+from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
+from typing import Tuple
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(
+ self,
+ n_head,
+ n_feat,
+ dropout_rate,
+ do_rel_shift=False,
+ adaptive_scale=False,
+ init_weights=False,
+ ):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.do_rel_shift = do_rel_shift
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+ self.adaptive_scale = adaptive_scale
+ self.ada_scale = nn.Parameter(
+ torch.ones([1, 1, n_feat]), requires_grad=adaptive_scale
+ )
+ self.ada_bias = nn.Parameter(
+ torch.zeros([1, 1, n_feat]), requires_grad=adaptive_scale
+ )
+ if init_weights:
+ self.init_weights()
+
+ def init_weights(self):
+ input_max = (self.h * self.d_k) ** -0.5
+ torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_k.bias, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_v.weight, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_v.bias, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_pos.weight, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_out.weight, -input_max, input_max)
+ torch.nn.init.uniform_(self.linear_out.bias, -input_max, input_max)
+
+ def rel_shift(self, x, zero_triu: bool = False):
+ """Compute relative positinal encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, size).
+ zero_triu (bool): If true, return the lower triangular part of
+ the matrix.
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+
+ zero_pad = torch.zeros(
+ (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
+ )
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+ # 1st chunk to ease the onnx export.]
+ # 2. pytorch training
+ if mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, : scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float("inf"))
+ # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+ # 1. onnx(16/-1, -1/-1, 16/0)
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ if self.adaptive_scale:
+ query = self.ada_scale * query + self.ada_bias
+ key = self.ada_scale * key + self.ada_bias
+ value = self.ada_scale * value + self.ada_bias
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
+ # and we will always do splitting and
+ # concatnation(this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ # Remove rel_shift since it is useless in speech recognition,
+ # and it requires special attention for streaming.
+ if self.do_rel_shift:
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask), new_cache
diff --git a/modules/wenet_extractor/squeezeformer/conv2d.py b/modules/wenet_extractor/squeezeformer/conv2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d2981e4b6853013ce1a629496e03b04517d0c0
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/conv2d.py
@@ -0,0 +1,93 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Conv2d Module with Valid Padding"""
+
+import torch.nn.functional as F
+from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional
+
+
+class Conv2dValid(_ConvNd):
+ """
+ Conv2d operator for VALID mode padding.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: _size_2_t,
+ stride: _size_2_t = 1,
+ padding: Union[str, _size_2_t] = 0,
+ dilation: _size_2_t = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros", # TODO: refine this type
+ device=None,
+ dtype=None,
+ valid_trigx: bool = False,
+ valid_trigy: bool = False,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ kernel_size_ = _pair(kernel_size)
+ stride_ = _pair(stride)
+ padding_ = padding if isinstance(padding, str) else _pair(padding)
+ dilation_ = _pair(dilation)
+ super(Conv2dValid, self).__init__(
+ in_channels,
+ out_channels,
+ kernel_size_,
+ stride_,
+ padding_,
+ dilation_,
+ False,
+ _pair(0),
+ groups,
+ bias,
+ padding_mode,
+ **factory_kwargs,
+ )
+ self.valid_trigx = valid_trigx
+ self.valid_trigy = valid_trigy
+
+ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
+ validx, validy = 0, 0
+ if self.valid_trigx:
+ validx = (
+ input.size(-2) * (self.stride[-2] - 1) - 1 + self.kernel_size[-2]
+ ) // 2
+ if self.valid_trigy:
+ validy = (
+ input.size(-1) * (self.stride[-1] - 1) - 1 + self.kernel_size[-1]
+ ) // 2
+ return F.conv2d(
+ input,
+ weight,
+ bias,
+ self.stride,
+ (validx, validy),
+ self.dilation,
+ self.groups,
+ )
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self._conv_forward(input, self.weight, self.bias)
diff --git a/modules/wenet_extractor/squeezeformer/convolution.py b/modules/wenet_extractor/squeezeformer/convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0ec43dadcd22aff53f79c459888344def57ef12
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/convolution.py
@@ -0,0 +1,182 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""ConvolutionModule definition."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True,
+ adaptive_scale: bool = False,
+ init_weights: bool = False,
+ ):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (int): Whether use causal convolution or not
+ """
+ super().__init__()
+ self.bias = bias
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.adaptive_scale = adaptive_scale
+ self.ada_scale = torch.nn.Parameter(
+ torch.ones([1, 1, channels]), requires_grad=adaptive_scale
+ )
+ self.ada_bias = torch.nn.Parameter(
+ torch.zeros([1, 1, channels]), requires_grad=adaptive_scale
+ )
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for none causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+
+ assert norm in ["batch_norm", "layer_norm"]
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+ if init_weights:
+ self.init_weights()
+
+ def init_weights(self):
+ pw_max = self.channels**-0.5
+ dw_max = self.kernel_size**-0.5
+ torch.nn.init.uniform_(self.pointwise_conv1.weight.data, -pw_max, pw_max)
+ if self.bias:
+ torch.nn.init.uniform_(self.pointwise_conv1.bias.data, -pw_max, pw_max)
+ torch.nn.init.uniform_(self.depthwise_conv.weight.data, -dw_max, dw_max)
+ if self.bias:
+ torch.nn.init.uniform_(self.depthwise_conv.bias.data, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.pointwise_conv2.weight.data, -pw_max, pw_max)
+ if self.bias:
+ torch.nn.init.uniform_(self.pointwise_conv2.bias.data, -pw_max, pw_max)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+ (0, 0, 0) meas fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ if self.adaptive_scale:
+ x = self.ada_scale * x + self.ada_bias
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert x.size(2) > self.lorder
+ new_cache = x[:, :, -self.lorder :]
+ else:
+ # It's better we just return None if no cache is required,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
diff --git a/modules/wenet_extractor/squeezeformer/encoder.py b/modules/wenet_extractor/squeezeformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a51568bceb0189eaf1b3577e56d0d5808b2d9b5d
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/encoder.py
@@ -0,0 +1,516 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import torch
+import torch.nn as nn
+from typing import Tuple, Union, Optional, List
+from modules.wenet_extractor.squeezeformer.subsampling import (
+ DepthwiseConv2dSubsampling4,
+ TimeReductionLayer1D,
+ TimeReductionLayer2D,
+ TimeReductionLayerStream,
+)
+from modules.wenet_extractor.squeezeformer.encoder_layer import (
+ SqueezeformerEncoderLayer,
+)
+from modules.wenet_extractor.transformer.embedding import RelPositionalEncoding
+from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
+from modules.wenet_extractor.squeezeformer.attention import (
+ RelPositionMultiHeadedAttention,
+)
+from modules.wenet_extractor.squeezeformer.positionwise_feed_forward import (
+ PositionwiseFeedForward,
+)
+from modules.wenet_extractor.squeezeformer.convolution import ConvolutionModule
+from modules.wenet_extractor.utils.mask import make_pad_mask, add_optional_chunk_mask
+from modules.wenet_extractor.utils.common import get_activation
+
+
+class SqueezeformerEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: int = 80,
+ encoder_dim: int = 256,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ num_blocks: int = 12,
+ reduce_idx: Optional[Union[int, List[int]]] = 5,
+ recover_idx: Optional[Union[int, List[int]]] = 11,
+ feed_forward_expansion_factor: int = 4,
+ dw_stride: bool = False,
+ input_dropout_rate: float = 0.1,
+ pos_enc_layer_type: str = "rel_pos",
+ time_reduction_layer_type: str = "conv1d",
+ do_rel_shift: bool = True,
+ feed_forward_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.1,
+ cnn_module_kernel: int = 31,
+ cnn_norm_type: str = "batch_norm",
+ dropout: float = 0.1,
+ causal: bool = False,
+ adaptive_scale: bool = True,
+ activation_type: str = "swish",
+ init_weights: bool = True,
+ global_cmvn: torch.nn.Module = None,
+ normalize_before: bool = False,
+ use_dynamic_chunk: bool = False,
+ concat_after: bool = False,
+ static_chunk_size: int = 0,
+ use_dynamic_left_chunk: bool = False,
+ ):
+ """Construct SqueezeformerEncoder
+
+ Args:
+ input_size to use_dynamic_chunk, see in Transformer BaseEncoder.
+ encoder_dim (int): The hidden dimension of encoder layer.
+ output_size (int): The output dimension of final projection layer.
+ attention_heads (int): Num of attention head in attention module.
+ num_blocks (int): Num of encoder layers.
+ reduce_idx Optional[Union[int, List[int]]]:
+ reduce layer index, from 40ms to 80ms per frame.
+ recover_idx Optional[Union[int, List[int]]]:
+ recover layer index, from 80ms to 40ms per frame.
+ feed_forward_expansion_factor (int): Enlarge coefficient of FFN.
+ dw_stride (bool): Whether do depthwise convolution
+ on subsampling module.
+ input_dropout_rate (float): Dropout rate of input projection layer.
+ pos_enc_layer_type (str): Self attention type.
+ time_reduction_layer_type (str): Conv1d or Conv2d reduction layer.
+ do_rel_shift (bool): Whether to do relative shift
+ operation on rel-attention module.
+ cnn_module_kernel (int): Kernel size of CNN module.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ adaptive_scale (bool): Whether to use adaptive scale.
+ init_weights (bool): Whether to initialize weights.
+ causal (bool): whether to use causal convolution or not.
+ """
+ super(SqueezeformerEncoder, self).__init__()
+ self.global_cmvn = global_cmvn
+ self.reduce_idx: Optional[Union[int, List[int]]] = (
+ [reduce_idx] if type(reduce_idx) == int else reduce_idx
+ )
+ self.recover_idx: Optional[Union[int, List[int]]] = (
+ [recover_idx] if type(recover_idx) == int else recover_idx
+ )
+ self.check_ascending_list()
+ if reduce_idx is None:
+ self.time_reduce = None
+ else:
+ if recover_idx is None:
+ self.time_reduce = "normal" # no recovery at the end
+ else:
+ self.time_reduce = "recover" # recovery at the end
+ assert len(self.reduce_idx) == len(self.recover_idx)
+ self.reduce_stride = 2
+ self._output_size = output_size
+ self.normalize_before = normalize_before
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ self.pos_enc_layer_type = pos_enc_layer_type
+ activation = get_activation(activation_type)
+
+ # self-attention module definition
+ if pos_enc_layer_type != "rel_pos":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ else:
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ encoder_dim,
+ attention_dropout_rate,
+ do_rel_shift,
+ adaptive_scale,
+ init_weights,
+ )
+
+ # feed-forward module definition
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ encoder_dim,
+ encoder_dim * feed_forward_expansion_factor,
+ feed_forward_dropout_rate,
+ activation,
+ adaptive_scale,
+ init_weights,
+ )
+
+ # convolution module definition
+ convolution_layer = ConvolutionModule
+ convolution_layer_args = (
+ encoder_dim,
+ cnn_module_kernel,
+ activation,
+ cnn_norm_type,
+ causal,
+ True,
+ adaptive_scale,
+ init_weights,
+ )
+
+ self.embed = DepthwiseConv2dSubsampling4(
+ 1,
+ encoder_dim,
+ RelPositionalEncoding(encoder_dim, dropout_rate=0.1),
+ dw_stride,
+ input_size,
+ input_dropout_rate,
+ init_weights,
+ )
+
+ self.preln = nn.LayerNorm(encoder_dim)
+ self.encoders = torch.nn.ModuleList(
+ [
+ SqueezeformerEncoderLayer(
+ encoder_dim,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ convolution_layer(*convolution_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ normalize_before,
+ dropout,
+ concat_after,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+ if time_reduction_layer_type == "conv1d":
+ time_reduction_layer = TimeReductionLayer1D
+ time_reduction_layer_args = {
+ "channel": encoder_dim,
+ "out_dim": encoder_dim,
+ }
+ elif time_reduction_layer_type == "stream":
+ time_reduction_layer = TimeReductionLayerStream
+ time_reduction_layer_args = {
+ "channel": encoder_dim,
+ "out_dim": encoder_dim,
+ }
+ else:
+ time_reduction_layer = TimeReductionLayer2D
+ time_reduction_layer_args = {"encoder_dim": encoder_dim}
+
+ self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args)
+ self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim)
+ self.final_proj = None
+ if output_size != encoder_dim:
+ self.final_proj = nn.Linear(encoder_dim, output_size)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ xs, pos_emb, masks = self.embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(
+ xs,
+ masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks,
+ )
+ xs_lens = mask_pad.squeeze(1).sum(1)
+ xs = self.preln(xs)
+ recover_activations: List[
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+ ] = []
+ index = 0
+ for i, layer in enumerate(self.encoders):
+ if self.reduce_idx is not None:
+ if self.time_reduce is not None and i in self.reduce_idx:
+ recover_activations.append((xs, chunk_masks, pos_emb, mask_pad))
+ xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(
+ xs, xs_lens, chunk_masks, mask_pad
+ )
+ pos_emb = pos_emb[:, ::2, :]
+ index += 1
+
+ if self.recover_idx is not None:
+ if self.time_reduce == "recover" and i in self.recover_idx:
+ index -= 1
+ (
+ recover_tensor,
+ recover_chunk_masks,
+ recover_pos_emb,
+ recover_mask_pad,
+ ) = recover_activations[index]
+ # recover output length for ctc decode
+ xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
+ xs = self.time_recover_layer(xs)
+ recoverd_t = recover_tensor.size(1)
+ xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
+ chunk_masks = recover_chunk_masks
+ pos_emb = recover_pos_emb
+ mask_pad = recover_mask_pad
+ xs = xs.masked_fill(~mask_pad[:, 0, :].unsqueeze(-1), 0.0)
+
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+
+ if self.final_proj is not None:
+ xs = self.final_proj(xs)
+ return xs, masks
+
+ def check_ascending_list(self):
+ if self.reduce_idx is not None:
+ assert self.reduce_idx == sorted(
+ self.reduce_idx
+ ), "reduce_idx should be int or ascending list"
+ if self.recover_idx is not None:
+ assert self.recover_idx == sorted(
+ self.recover_idx
+ ), "recover_idx should be int or ascending list"
+
+ def calculate_downsampling_factor(self, i: int) -> int:
+ if self.reduce_idx is None:
+ return 1
+ else:
+ reduce_exp, recover_exp = 0, 0
+ for exp, rd_idx in enumerate(self.reduce_idx):
+ if i >= rd_idx:
+ reduce_exp = exp + 1
+ if self.recover_idx is not None:
+ for exp, rc_idx in enumerate(self.recover_idx):
+ if i >= rc_idx:
+ recover_exp = exp + 1
+ return int(2 ** (reduce_exp - recover_exp))
+
+ def forward_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """ Forward just one chunk
+
+ Args:
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
+ where `time == (chunk_size - 1) * subsample_rate + \
+ subsample.right_context + 1`
+ offset (int): current offset in encoder output time stamp
+ required_cache_size (int): cache size required for next chunk
+ compuation
+ >=0: actual cache size
+ <0: means all history cache is required
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (elayers, b=1, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+
+ Returns:
+ torch.Tensor: output of current input xs,
+ with shape (b=1, chunk_size, hidden-dim).
+ torch.Tensor: new attention cache required for next chunk, with
+ dynamic shape (elayers, head, ?, d_k * 2)
+ depending on required_cache_size.
+ torch.Tensor: new conformer cnn cache required for next chunk, with
+ same shape as the original cnn_cache.
+
+ """
+ assert xs.size(0) == 1
+ # tmp_masks is just for interface compatibility
+ tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
+ chunk_size = xs.size(1)
+ attention_key_size = cache_t1 + chunk_size
+ pos_emb = self.embed.position_encoding(
+ offset=offset - cache_t1, size=attention_key_size
+ )
+ if required_cache_size < 0:
+ next_cache_start = 0
+ elif required_cache_size == 0:
+ next_cache_start = attention_key_size
+ else:
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
+
+ r_att_cache = []
+ r_cnn_cache = []
+
+ mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
+ mask_pad = mask_pad.unsqueeze(1)
+ max_att_len: int = 0
+ recover_activations: List[
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+ ] = []
+ index = 0
+ xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
+ xs = self.preln(xs)
+ for i, layer in enumerate(self.encoders):
+ # NOTE(xcsong): Before layer.forward
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
+ if self.reduce_idx is not None:
+ if self.time_reduce is not None and i in self.reduce_idx:
+ recover_activations.append((xs, att_mask, pos_emb, mask_pad))
+ xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(
+ xs, xs_lens, att_mask, mask_pad
+ )
+ pos_emb = pos_emb[:, ::2, :]
+ index += 1
+
+ if self.recover_idx is not None:
+ if self.time_reduce == "recover" and i in self.recover_idx:
+ index -= 1
+ (
+ recover_tensor,
+ recover_att_mask,
+ recover_pos_emb,
+ recover_mask_pad,
+ ) = recover_activations[index]
+ # recover output length for ctc decode
+ xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
+ xs = self.time_recover_layer(xs)
+ recoverd_t = recover_tensor.size(1)
+ xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
+ att_mask = recover_att_mask
+ pos_emb = recover_pos_emb
+ mask_pad = recover_mask_pad
+ if att_mask.size(1) != 0:
+ xs = xs.masked_fill(~att_mask[:, 0, :].unsqueeze(-1), 0.0)
+
+ factor = self.calculate_downsampling_factor(i)
+
+ xs, _, new_att_cache, new_cnn_cache = layer(
+ xs,
+ att_mask,
+ pos_emb,
+ att_cache=(
+ att_cache[i : i + 1][:, :, ::factor, :][
+ :, :, : pos_emb.size(1) - xs.size(1), :
+ ]
+ if elayers > 0
+ else att_cache[:, :, ::factor, :]
+ ),
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
+ )
+ # NOTE(xcsong): After layer.forward
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
+ cached_att = new_att_cache[:, :, next_cache_start // factor :, :]
+ cached_cnn = new_cnn_cache.unsqueeze(0)
+ cached_att = (
+ cached_att.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
+ )
+ if i == 0:
+ # record length for the first block as max length
+ max_att_len = cached_att.size(2)
+ r_att_cache.append(cached_att[:, :, :max_att_len, :])
+ r_cnn_cache.append(cached_cnn)
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
+ # ? may be larger than cache_t1, it depends on required_cache_size
+ r_att_cache = torch.cat(r_att_cache, dim=0)
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
+
+ if self.final_proj is not None:
+ xs = self.final_proj(xs)
+ return (xs, r_att_cache, r_cnn_cache)
+
+ def forward_chunk_by_chunk(
+ self,
+ xs: torch.Tensor,
+ decoding_chunk_size: int,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward input chunk by chunk with chunk_size like a streaming
+ fashion
+
+ Here we should pay special attention to computation cache in the
+ streaming style forward chunk by chunk. Three things should be taken
+ into account for computation in the current network:
+ 1. transformer/conformer encoder layers output cache
+ 2. convolution in conformer
+ 3. convolution in subsampling
+
+ However, we don't implement subsampling cache for:
+ 1. We can control subsampling module to output the right result by
+ overlapping input instead of cache left context, even though it
+ wastes some computation, but subsampling only takes a very
+ small fraction of computation in the whole model.
+ 2. Typically, there are several covolution layers with subsampling
+ in subsampling module, it is tricky and complicated to do cache
+ with different convolution layers with different subsampling
+ rate.
+ 3. Currently, nn.Sequential is used to stack all the convolution
+ layers in subsampling, we need to rewrite it to make it work
+ with cache, which is not prefered.
+ Args:
+ xs (torch.Tensor): (1, max_len, dim)
+ chunk_size (int): decoding chunk size
+ """
+ assert decoding_chunk_size > 0
+ # The model is trained by static or dynamic chunk
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
+ subsampling = self.embed.subsampling_rate
+ context = self.embed.right_context + 1 # Add current frame
+ stride = subsampling * decoding_chunk_size
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
+ num_frames = xs.size(1)
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ outputs = []
+ offset = 0
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+
+ # Feed forward overlap input step by step
+ for cur in range(0, num_frames - context + 1, stride):
+ end = min(cur + decoding_window, num_frames)
+ chunk_xs = xs[:, cur:end, :]
+ (y, att_cache, cnn_cache) = self.forward_chunk(
+ chunk_xs, offset, required_cache_size, att_cache, cnn_cache
+ )
+ outputs.append(y)
+ offset += y.size(1)
+ ys = torch.cat(outputs, 1)
+ masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
+ return ys, masks
diff --git a/modules/wenet_extractor/squeezeformer/encoder_layer.py b/modules/wenet_extractor/squeezeformer/encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e7c0fdeee6f7156468b14416d268702cf4d40a
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/encoder_layer.py
@@ -0,0 +1,129 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""SqueezeformerEncoderLayer definition."""
+
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple
+
+
+class SqueezeformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward1 (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvlutionModule` instance can be used as the argument.
+ feed_forward2 (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward1: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ feed_forward2: Optional[nn.Module] = None,
+ normalize_before: bool = False,
+ dropout_rate: float = 0.1,
+ concat_after: bool = False,
+ ):
+ super(SqueezeformerEncoderLayer, self).__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.layer_norm1 = nn.LayerNorm(size)
+ self.ffn1 = feed_forward1
+ self.layer_norm2 = nn.LayerNorm(size)
+ self.conv_module = conv_module
+ self.layer_norm3 = nn.LayerNorm(size)
+ self.ffn2 = feed_forward2
+ self.layer_norm4 = nn.LayerNorm(size)
+ self.normalize_before = normalize_before
+ self.dropout = nn.Dropout(dropout_rate)
+ self.concat_after = concat_after
+ if concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ else:
+ self.concat_linear = nn.Identity()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # self attention module
+ residual = x
+ if self.normalize_before:
+ x = self.layer_norm1(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
+ if self.concat_after:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = residual + self.concat_linear(x_concat)
+ else:
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.layer_norm1(x)
+
+ # ffn module
+ residual = x
+ if self.normalize_before:
+ x = self.layer_norm2(x)
+ x = self.ffn1(x)
+ x = residual + self.dropout(x)
+ if not self.normalize_before:
+ x = self.layer_norm2(x)
+
+ # conv module
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ residual = x
+ if self.normalize_before:
+ x = self.layer_norm3(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+ x = residual + self.dropout(x)
+ if not self.normalize_before:
+ x = self.layer_norm3(x)
+
+ # ffn module
+ residual = x
+ if self.normalize_before:
+ x = self.layer_norm4(x)
+ x = self.ffn2(x)
+ # we do not use dropout here since it is inside feed forward function
+ x = residual + self.dropout(x)
+ if not self.normalize_before:
+ x = self.layer_norm4(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
diff --git a/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py b/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfd4aabf9a7e246243d9792616bbe995c2a8a6ba
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py
@@ -0,0 +1,88 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+ FeedForward are appied on each position of the sequence.
+ The output dim is same with the input dim.
+
+ Args:
+ idim (int): Input dimenstion.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ adaptive_scale: bool = False,
+ init_weights: bool = False,
+ ):
+ """Construct a PositionwiseFeedForward object."""
+ super(PositionwiseFeedForward, self).__init__()
+ self.idim = idim
+ self.hidden_units = hidden_units
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
+ self.activation = activation
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
+ self.ada_scale = None
+ self.ada_bias = None
+ self.adaptive_scale = adaptive_scale
+ self.ada_scale = torch.nn.Parameter(
+ torch.ones([1, 1, idim]), requires_grad=adaptive_scale
+ )
+ self.ada_bias = torch.nn.Parameter(
+ torch.zeros([1, 1, idim]), requires_grad=adaptive_scale
+ )
+ if init_weights:
+ self.init_weights()
+
+ def init_weights(self):
+ ffn1_max = self.idim**-0.5
+ ffn2_max = self.hidden_units**-0.5
+ torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
+ torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
+ torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
+ torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+ """
+ if self.adaptive_scale:
+ xs = self.ada_scale * xs + self.ada_bias
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
diff --git a/modules/wenet_extractor/squeezeformer/subsampling.py b/modules/wenet_extractor/squeezeformer/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..329dd6c51f1719d73307d636dd625cf98a34cd34
--- /dev/null
+++ b/modules/wenet_extractor/squeezeformer/subsampling.py
@@ -0,0 +1,321 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""DepthwiseConv2dSubsampling4 and TimeReductionLayer definition."""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.wenet_extractor.transformer.subsampling import BaseSubsampling
+from typing import Tuple
+from modules.wenet_extractor.squeezeformer.conv2d import Conv2dValid
+
+
+class DepthwiseConv2dSubsampling4(BaseSubsampling):
+ """Depthwise Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ pos_enc_class (nn.Module): position encoding class.
+ dw_stride (int): Whether do depthwise convolution.
+ input_size (int): filter bank dimension.
+
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ odim: int,
+ pos_enc_class: torch.nn.Module,
+ dw_stride: bool = False,
+ input_size: int = 80,
+ input_dropout_rate: float = 0.1,
+ init_weights: bool = True,
+ ):
+ super(DepthwiseConv2dSubsampling4, self).__init__()
+ self.idim = idim
+ self.odim = odim
+ self.pw_conv = nn.Conv2d(
+ in_channels=idim, out_channels=odim, kernel_size=3, stride=2
+ )
+ self.act1 = nn.ReLU()
+ self.dw_conv = nn.Conv2d(
+ in_channels=odim,
+ out_channels=odim,
+ kernel_size=3,
+ stride=2,
+ groups=odim if dw_stride else 1,
+ )
+ self.act2 = nn.ReLU()
+ self.pos_enc = pos_enc_class
+ self.input_proj = nn.Sequential(
+ nn.Linear(odim * (((input_size - 1) // 2 - 1) // 2), odim),
+ nn.Dropout(p=input_dropout_rate),
+ )
+ if init_weights:
+ linear_max = (odim * input_size / 4) ** -0.5
+ torch.nn.init.uniform_(
+ self.input_proj.state_dict()["0.weight"], -linear_max, linear_max
+ )
+ torch.nn.init.uniform_(
+ self.input_proj.state_dict()["0.bias"], -linear_max, linear_max
+ )
+ self.subsampling_rate = 4
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
+ self.right_context = 6
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, offset: int = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ x = x.unsqueeze(1) # (b, c=1, t, f)
+ x = self.pw_conv(x)
+ x = self.act1(x)
+ x = self.dw_conv(x)
+ x = self.act2(x)
+ b, c, t, f = x.size()
+ x = x.permute(0, 2, 1, 3)
+ x = x.contiguous().view(b, t, c * f)
+ x, pos_emb = self.pos_enc(x, offset)
+ x = self.input_proj(x)
+ return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
+
+
+class TimeReductionLayer1D(nn.Module):
+ """
+ Modified NeMo,
+ Squeezeformer Time Reduction procedure.
+ Downsamples the audio by `stride` in the time dimension.
+ Args:
+ channel (int): input dimension of
+ MultiheadAttentionMechanism and PositionwiseFeedForward
+ out_dim (int): Output dimension of the module.
+ kernel_size (int): Conv kernel size for
+ depthwise convolution in convolution module
+ stride (int): Downsampling factor in time dimension.
+ """
+
+ def __init__(
+ self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2
+ ):
+ super(TimeReductionLayer1D, self).__init__()
+
+ self.channel = channel
+ self.out_dim = out_dim
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = max(0, self.kernel_size - self.stride)
+
+ self.dw_conv = nn.Conv1d(
+ in_channels=channel,
+ out_channels=channel,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ groups=channel,
+ )
+
+ self.pw_conv = nn.Conv1d(
+ in_channels=channel,
+ out_channels=out_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ )
+
+ self.init_weights()
+
+ def init_weights(self):
+ dw_max = self.kernel_size**-0.5
+ pw_max = self.channel**-0.5
+ torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max)
+ torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max)
+
+ def forward(
+ self,
+ xs,
+ xs_lens: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ):
+ xs = xs.transpose(1, 2) # [B, C, T]
+ xs = xs.masked_fill(mask_pad.eq(0), 0.0)
+
+ xs = self.dw_conv(xs)
+ xs = self.pw_conv(xs)
+
+ xs = xs.transpose(1, 2) # [B, T, C]
+
+ B, T, D = xs.size()
+ mask = mask[:, :: self.stride, :: self.stride]
+ mask_pad = mask_pad[:, :, :: self.stride]
+ L = mask_pad.size(-1)
+ # For JIT exporting, we remove F.pad operator.
+ if L - T < 0:
+ xs = xs[:, : L - T, :].contiguous()
+ else:
+ dummy_pad = torch.zeros(B, L - T, D, device=xs.device)
+ xs = torch.cat([xs, dummy_pad], dim=1)
+
+ xs_lens = torch.div(xs_lens + 1, 2, rounding_mode="trunc")
+ return xs, xs_lens, mask, mask_pad
+
+
+class TimeReductionLayer2D(nn.Module):
+ def __init__(self, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256):
+ super(TimeReductionLayer2D, self).__init__()
+ self.encoder_dim = encoder_dim
+ self.kernel_size = kernel_size
+ self.dw_conv = Conv2dValid(
+ in_channels=encoder_dim,
+ out_channels=encoder_dim,
+ kernel_size=(kernel_size, 1),
+ stride=stride,
+ valid_trigy=True,
+ )
+ self.pw_conv = Conv2dValid(
+ in_channels=encoder_dim,
+ out_channels=encoder_dim,
+ kernel_size=1,
+ stride=1,
+ valid_trigx=False,
+ valid_trigy=False,
+ )
+
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.init_weights()
+
+ def init_weights(self):
+ dw_max = self.kernel_size**-0.5
+ pw_max = self.encoder_dim**-0.5
+ torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max)
+ torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max)
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0)
+ xs = xs.unsqueeze(2)
+ padding1 = self.kernel_size - self.stride
+ xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode="constant", value=0.0)
+ xs = self.dw_conv(xs.permute(0, 3, 1, 2))
+ xs = self.pw_conv(xs).permute(0, 3, 2, 1).squeeze(1).contiguous()
+ tmp_length = xs.size(1)
+ xs_lens = torch.div(xs_lens + 1, 2, rounding_mode="trunc")
+ padding2 = max(0, (xs_lens.max() - tmp_length).data.item())
+ batch_size, hidden = xs.size(0), xs.size(-1)
+ dummy_pad = torch.zeros(batch_size, padding2, hidden, device=xs.device)
+ xs = torch.cat([xs, dummy_pad], dim=1)
+ mask = mask[:, ::2, ::2]
+ mask_pad = mask_pad[:, :, ::2]
+ return xs, xs_lens, mask, mask_pad
+
+
+class TimeReductionLayerStream(nn.Module):
+ """
+ Squeezeformer Time Reduction procedure.
+ Downsamples the audio by `stride` in the time dimension.
+ Args:
+ channel (int): input dimension of
+ MultiheadAttentionMechanism and PositionwiseFeedForward
+ out_dim (int): Output dimension of the module.
+ kernel_size (int): Conv kernel size for
+ depthwise convolution in convolution module
+ stride (int): Downsampling factor in time dimension.
+ """
+
+ def __init__(
+ self, channel: int, out_dim: int, kernel_size: int = 1, stride: int = 2
+ ):
+ super(TimeReductionLayerStream, self).__init__()
+
+ self.channel = channel
+ self.out_dim = out_dim
+ self.kernel_size = kernel_size
+ self.stride = stride
+
+ self.dw_conv = nn.Conv1d(
+ in_channels=channel,
+ out_channels=channel,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=0,
+ groups=channel,
+ )
+
+ self.pw_conv = nn.Conv1d(
+ in_channels=channel,
+ out_channels=out_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ )
+
+ self.init_weights()
+
+ def init_weights(self):
+ dw_max = self.kernel_size**-0.5
+ pw_max = self.channel**-0.5
+ torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max)
+ torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max)
+ torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max)
+
+ def forward(
+ self,
+ xs,
+ xs_lens: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ):
+ xs = xs.transpose(1, 2) # [B, C, T]
+ xs = xs.masked_fill(mask_pad.eq(0), 0.0)
+
+ xs = self.dw_conv(xs)
+ xs = self.pw_conv(xs)
+
+ xs = xs.transpose(1, 2) # [B, T, C]
+
+ B, T, D = xs.size()
+ mask = mask[:, :: self.stride, :: self.stride]
+ mask_pad = mask_pad[:, :, :: self.stride]
+ L = mask_pad.size(-1)
+ # For JIT exporting, we remove F.pad operator.
+ if L - T < 0:
+ xs = xs[:, : L - T, :].contiguous()
+ else:
+ dummy_pad = torch.zeros(B, L - T, D, device=xs.device)
+ xs = torch.cat([xs, dummy_pad], dim=1)
+
+ xs_lens = torch.div(xs_lens + 1, 2, rounding_mode="trunc")
+ return xs, xs_lens, mask, mask_pad
diff --git a/modules/wenet_extractor/transducer/__init__.py b/modules/wenet_extractor/transducer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/wenet_extractor/transducer/joint.py b/modules/wenet_extractor/transducer/joint.py
new file mode 100644
index 0000000000000000000000000000000000000000..8592136c9f33902942fb01ee42af82761b7ad26b
--- /dev/null
+++ b/modules/wenet_extractor/transducer/joint.py
@@ -0,0 +1,72 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from modules.wenet_extractor.utils.common import get_activation
+
+
+class TransducerJoint(torch.nn.Module):
+ def __init__(
+ self,
+ voca_size: int,
+ enc_output_size: int,
+ pred_output_size: int,
+ join_dim: int,
+ prejoin_linear: bool = True,
+ postjoin_linear: bool = False,
+ joint_mode: str = "add",
+ activation: str = "tanh",
+ ):
+ # TODO(Mddct): concat in future
+ assert joint_mode in ["add"]
+ super().__init__()
+
+ self.activatoin = get_activation(activation)
+ self.prejoin_linear = prejoin_linear
+ self.postjoin_linear = postjoin_linear
+ self.joint_mode = joint_mode
+
+ if not self.prejoin_linear and not self.postjoin_linear:
+ assert enc_output_size == pred_output_size == join_dim
+ # torchscript compatibility
+ self.enc_ffn: Optional[nn.Linear] = None
+ self.pred_ffn: Optional[nn.Linear] = None
+ if self.prejoin_linear:
+ self.enc_ffn = nn.Linear(enc_output_size, join_dim)
+ self.pred_ffn = nn.Linear(pred_output_size, join_dim)
+ # torchscript compatibility
+ self.post_ffn: Optional[nn.Linear] = None
+ if self.postjoin_linear:
+ self.post_ffn = nn.Linear(join_dim, join_dim)
+
+ self.ffn_out = nn.Linear(join_dim, voca_size)
+
+ def forward(self, enc_out: torch.Tensor, pred_out: torch.Tensor):
+ """
+ Args:
+ enc_out (torch.Tensor): [B, T, E]
+ pred_out (torch.Tensor): [B, T, P]
+ Return:
+ [B,T,U,V]
+ """
+ if (
+ self.prejoin_linear
+ and self.enc_ffn is not None
+ and self.pred_ffn is not None
+ ):
+ enc_out = self.enc_ffn(enc_out) # [B,T,E] -> [B,T,V]
+ pred_out = self.pred_ffn(pred_out)
+
+ enc_out = enc_out.unsqueeze(2) # [B,T,V] -> [B,T,1,V]
+ pred_out = pred_out.unsqueeze(1) # [B,U,V] -> [B,1 U, V]
+
+ # TODO(Mddct): concat joint
+ _ = self.joint_mode
+ out = enc_out + pred_out # [B,T,U,V]
+
+ if self.postjoin_linear and self.post_ffn is not None:
+ out = self.post_ffn(out)
+
+ out = self.activatoin(out)
+ out = self.ffn_out(out)
+ return out
diff --git a/modules/wenet_extractor/transducer/predictor.py b/modules/wenet_extractor/transducer/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c1ffdc7b30c23147ba80ae928eb3bd18f602932
--- /dev/null
+++ b/modules/wenet_extractor/transducer/predictor.py
@@ -0,0 +1,477 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from modules.wenet_extractor.utils.common import get_activation, get_rnn
+
+
+def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
+ """
+ Args:
+ input: [bs, max_time_step, dim]
+ padding: [bs, max_time_step]
+ """
+ return padding * pad_value + input * (1 - padding)
+
+
+class PredictorBase(torch.nn.Module):
+ # NOTE(Mddct): We can use ABC abstract here, but
+ # keep this class simple enough for now
+ def __init__(self) -> None:
+ super().__init__()
+
+ def init_state(
+ self, batch_size: int, device: torch.device, method: str = "zero"
+ ) -> List[torch.Tensor]:
+ _, _, _ = batch_size, method, device
+ raise NotImplementedError("this is a base precictor")
+
+ def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+ _ = cache
+ raise NotImplementedError("this is a base precictor")
+
+ def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ _ = cache
+ raise NotImplementedError("this is a base precictor")
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ):
+ (
+ _,
+ _,
+ ) = (
+ input,
+ cache,
+ )
+ raise NotImplementedError("this is a base precictor")
+
+ def forward_step(
+ self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ (
+ _,
+ _,
+ _,
+ ) = (
+ input,
+ padding,
+ cache,
+ )
+ raise NotImplementedError("this is a base precictor")
+
+
+class RNNPredictor(PredictorBase):
+ def __init__(
+ self,
+ voca_size: int,
+ embed_size: int,
+ output_size: int,
+ embed_dropout: float,
+ hidden_size: int,
+ num_layers: int,
+ bias: bool = True,
+ rnn_type: str = "lstm",
+ dropout: float = 0.1,
+ ) -> None:
+ super().__init__()
+ self.n_layers = num_layers
+ self.hidden_size = hidden_size
+ # disable rnn base out projection
+ self.embed = nn.Embedding(voca_size, embed_size)
+ self.dropout = nn.Dropout(embed_dropout)
+ # NOTE(Mddct): rnn base from torch not support layer norm
+ # will add layer norm and prune value in cell and layer
+ # ref: https://github.com/Mddct/neural-lm/blob/main/models/gru_cell.py
+ self.rnn = get_rnn(rnn_type=rnn_type)(
+ input_size=embed_size,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ bias=bias,
+ batch_first=True,
+ dropout=dropout,
+ )
+ self.projection = nn.Linear(hidden_size, output_size)
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): [batch, max_time).
+ padding (torch.Tensor): [batch, max_time]
+ cache : rnn predictor cache[0] == state_m
+ cache[1] == state_c
+ Returns:
+ output: [batch, max_time, output_size]
+ """
+
+ # NOTE(Mddct): we don't use pack input format
+ embed = self.embed(input) # [batch, max_time, emb_size]
+ embed = self.dropout(embed)
+ states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+ if cache is None:
+ state = self.init_state(batch_size=input.size(0), device=input.device)
+ states = (state[0], state[1])
+ else:
+ assert len(cache) == 2
+ states = (cache[0], cache[1])
+ out, (m, c) = self.rnn(embed, states)
+ out = self.projection(out)
+
+ # NOTE(Mddct): Although we don't use staate in transducer
+ # training forward, we need make it right for padding value
+ # so we create forward_step for infering, forward for training
+ _, _ = m, c
+ return out
+
+ def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+ """
+ Args:
+ cache: [state_m, state_c]
+ state_ms: [1*n_layers, bs, ...]
+ state_cs: [1*n_layers, bs, ...]
+ Returns:
+ new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2]...]
+ """
+ assert len(cache) == 2
+ state_ms = cache[0]
+ state_cs = cache[1]
+
+ assert state_ms.size(1) == state_cs.size(1)
+
+ new_cache: List[List[torch.Tensor]] = []
+ for state_m, state_c in zip(
+ torch.split(state_ms, 1, dim=1), torch.split(state_cs, 1, dim=1)
+ ):
+ new_cache.append([state_m, state_c])
+ return new_cache
+
+ def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """
+ Args:
+ cache : [[state_m_1, state_c_1], [state_m_1, state_c_1]...]
+
+ Returns:
+ new_caceh: [state_ms, state_cs],
+ state_ms: [1*n_layers, bs, ...]
+ state_cs: [1*n_layers, bs, ...]
+ """
+ state_ms = torch.cat([states[0] for states in cache], dim=1)
+ state_cs = torch.cat([states[1] for states in cache], dim=1)
+ return [state_ms, state_cs]
+
+ def init_state(
+ self,
+ batch_size: int,
+ device: torch.device,
+ method: str = "zero",
+ ) -> List[torch.Tensor]:
+ assert batch_size > 0
+ # TODO(Mddct): xavier init method
+ _ = method
+ return [
+ torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
+ torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
+ ]
+
+ def forward_step(
+ self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """
+ Args:
+ input (torch.Tensor): [batch_size, time_step=1]
+ padding (torch.Tensor): [batch_size,1], 1 is padding value
+ cache : rnn predictor cache[0] == state_m
+ cache[1] == state_c
+ """
+ assert len(cache) == 2
+ state_m, state_c = cache[0], cache[1]
+ embed = self.embed(input) # [batch, 1, emb_size]
+ embed = self.dropout(embed)
+ out, (m, c) = self.rnn(embed, (state_m, state_c))
+
+ out = self.projection(out)
+ m = ApplyPadding(m, padding.unsqueeze(0), state_m)
+ c = ApplyPadding(c, padding.unsqueeze(0), state_c)
+
+ return (out, [m, c])
+
+
+class EmbeddingPredictor(PredictorBase):
+ """Embedding predictor
+
+ Described in:
+ https://arxiv.org/pdf/2109.07513.pdf
+
+ embed-> proj -> layer norm -> swish
+ """
+
+ def __init__(
+ self,
+ voca_size: int,
+ embed_size: int,
+ embed_dropout: float,
+ n_head: int,
+ history_size: int = 2,
+ activation: str = "swish",
+ bias: bool = False,
+ layer_norm_epsilon: float = 1e-5,
+ ) -> None:
+ super().__init__()
+ # multi head
+ self.num_heads = n_head
+ self.embed_size = embed_size
+ self.context_size = history_size + 1
+ self.pos_embed = torch.nn.Linear(
+ embed_size * self.context_size, self.num_heads, bias=bias
+ )
+ self.embed = nn.Embedding(voca_size, self.embed_size)
+ self.embed_dropout = nn.Dropout(p=embed_dropout)
+ self.ffn = nn.Linear(self.embed_size, self.embed_size)
+ self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
+ self.activatoin = get_activation(activation)
+
+ def init_state(
+ self, batch_size: int, device: torch.device, method: str = "zero"
+ ) -> List[torch.Tensor]:
+ assert batch_size > 0
+ _ = method
+ return [
+ torch.zeros(
+ batch_size, self.context_size - 1, self.embed_size, device=device
+ ),
+ ]
+
+ def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+ """
+ Args:
+ cache : [history]
+ history: [bs, ...]
+ Returns:
+ new_ache : [[history_1], [history_2], [history_3]...]
+ """
+ assert len(cache) == 1
+ cache_0 = cache[0]
+ history: List[List[torch.Tensor]] = []
+ for h in torch.split(cache_0, 1, dim=0):
+ history.append([h])
+ return history
+
+ def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """
+ Args:
+ cache : [[history_1], [history_2], [history3]...]
+
+ Returns:
+ new_caceh: [history],
+ history: [bs, ...]
+ """
+ history = torch.cat([h[0] for h in cache], dim=0)
+ return [history]
+
+ def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
+ """forward for training"""
+ input = self.embed(input) # [bs, seq_len, embed]
+ input = self.embed_dropout(input)
+ if cache is None:
+ zeros = self.init_state(input.size(0), device=input.device)[0]
+ else:
+ assert len(cache) == 1
+ zeros = cache[0]
+
+ input = torch.cat(
+ (zeros, input), dim=1
+ ) # [bs, context_size-1 + seq_len, embed]
+
+ input = input.unfold(1, self.context_size, 1).permute(
+ 0, 1, 3, 2
+ ) # [bs, seq_len, context_size, embed]
+ # multi head pos: [n_head, embed, context_size]
+ multi_head_pos = self.pos_embed.weight.view(
+ self.num_heads, self.embed_size, self.context_size
+ )
+
+ # broadcast dot attenton
+ input_expand = input.unsqueeze(2) # [bs, seq_len, 1, context_size, embed]
+ multi_head_pos = multi_head_pos.permute(
+ 0, 2, 1
+ ) # [num_heads, context_size, embed]
+
+ # [bs, seq_len, num_heads, context_size, embed]
+ weight = input_expand * multi_head_pos
+ weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
+ 3
+ ) # [bs, seq_len, num_heads, 1, context_size]
+ output = weight.matmul(input_expand).squeeze(
+ dim=3
+ ) # [bs, seq_len, num_heads, embed]
+ output = output.sum(dim=2) # [bs, seq_len, embed]
+ output = output / (self.num_heads * self.context_size)
+
+ output = self.ffn(output)
+ output = self.norm(output)
+ output = self.activatoin(output)
+ return output
+
+ def forward_step(
+ self,
+ input: torch.Tensor,
+ padding: torch.Tensor,
+ cache: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """forward step for inference
+ Args:
+ input (torch.Tensor): [batch_size, time_step=1]
+ padding (torch.Tensor): [batch_size,1], 1 is padding value
+ cache: for embedding predictor, cache[0] == history
+ """
+ assert input.size(1) == 1
+ assert len(cache) == 1
+ history = cache[0]
+ assert history.size(1) == self.context_size - 1
+ input = self.embed(input) # [bs, 1, embed]
+ input = self.embed_dropout(input)
+ context_input = torch.cat((history, input), dim=1)
+ input_expand = context_input.unsqueeze(1).unsqueeze(
+ 2
+ ) # [bs, 1, 1, context_size, embed]
+
+ # multi head pos: [n_head, embed, context_size]
+ multi_head_pos = self.pos_embed.weight.view(
+ self.num_heads, self.embed_size, self.context_size
+ )
+
+ multi_head_pos = multi_head_pos.permute(
+ 0, 2, 1
+ ) # [num_heads, context_size, embed]
+ # [bs, 1, num_heads, context_size, embed]
+ weight = input_expand * multi_head_pos
+ weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
+ 3
+ ) # [bs, 1, num_heads, 1, context_size]
+ output = weight.matmul(input_expand).squeeze(dim=3) # [bs, 1, num_heads, embed]
+ output = output.sum(dim=2) # [bs, 1, embed]
+ output = output / (self.num_heads * self.context_size)
+
+ output = self.ffn(output)
+ output = self.norm(output)
+ output = self.activatoin(output)
+ new_cache = context_input[:, 1:, :]
+ # TODO(Mddct): we need padding new_cache in future
+ # new_cache = ApplyPadding(history, padding, new_cache)
+ return (output, [new_cache])
+
+
+class ConvPredictor(PredictorBase):
+ def __init__(
+ self,
+ voca_size: int,
+ embed_size: int,
+ embed_dropout: float,
+ history_size: int = 2,
+ activation: str = "relu",
+ bias: bool = False,
+ layer_norm_epsilon: float = 1e-5,
+ ) -> None:
+ super().__init__()
+
+ assert history_size >= 0
+ self.embed_size = embed_size
+ self.context_size = history_size + 1
+ self.embed = nn.Embedding(voca_size, self.embed_size)
+ self.embed_dropout = nn.Dropout(p=embed_dropout)
+ self.conv = nn.Conv1d(
+ in_channels=embed_size,
+ out_channels=embed_size,
+ kernel_size=self.context_size,
+ padding=0,
+ groups=embed_size,
+ bias=bias,
+ )
+ self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
+ self.activatoin = get_activation(activation)
+
+ def init_state(
+ self, batch_size: int, device: torch.device, method: str = "zero"
+ ) -> List[torch.Tensor]:
+ assert batch_size > 0
+ assert method == "zero"
+ return [
+ torch.zeros(
+ batch_size, self.context_size - 1, self.embed_size, device=device
+ )
+ ]
+
+ def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """
+ Args:
+ cache : [[history_1], [history_2], [history3]...]
+
+ Returns:
+ new_caceh: [history],
+ history: [bs, ...]
+ """
+ history = torch.cat([h[0] for h in cache], dim=0)
+ return [history]
+
+ def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+ """
+ Args:
+ cache : [history]
+ history: [bs, ...]
+ Returns:
+ new_ache : [[history_1], [history_2], [history_3]...]
+ """
+ assert len(cache) == 1
+ cache_0 = cache[0]
+ history: List[List[torch.Tensor]] = []
+ for h in torch.split(cache_0, 1, dim=0):
+ history.append([h])
+ return history
+
+ def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
+ """forward for training"""
+ input = self.embed(input) # [bs, seq_len, embed]
+ input = self.embed_dropout(input)
+ if cache is None:
+ zeros = self.init_state(input.size(0), device=input.device)[0]
+ else:
+ assert len(cache) == 1
+ zeros = cache[0]
+
+ input = torch.cat(
+ (zeros, input), dim=1
+ ) # [bs, context_size-1 + seq_len, embed]
+ input = input.permute(0, 2, 1)
+ out = self.conv(input).permute(0, 2, 1)
+ out = self.activatoin(self.norm(out))
+ return out
+
+ def forward_step(
+ self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """forward step for inference
+ Args:
+ input (torch.Tensor): [batch_size, time_step=1]
+ padding (torch.Tensor): [batch_size,1], 1 is padding value
+ cache: for embedding predictor, cache[0] == history
+ """
+ assert input.size(1) == 1
+ assert len(cache) == 1
+ history = cache[0]
+ assert history.size(1) == self.context_size - 1
+ input = self.embed(input) # [bs, 1, embed]
+ input = self.embed_dropout(input)
+ context_input = torch.cat((history, input), dim=1)
+ input = context_input.permute(0, 2, 1)
+ out = self.conv(input).permute(0, 2, 1)
+ out = self.activatoin(self.norm(out))
+
+ new_cache = context_input[:, 1:, :]
+ # TODO(Mddct): apply padding in future
+ return (out, [new_cache])
diff --git a/modules/wenet_extractor/transducer/search/greedy_search.py b/modules/wenet_extractor/transducer/search/greedy_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..df4cfea28902aac0005b0b72e930c790a2d79de4
--- /dev/null
+++ b/modules/wenet_extractor/transducer/search/greedy_search.py
@@ -0,0 +1,52 @@
+from typing import List
+
+import torch
+
+
+def basic_greedy_search(
+ model: torch.nn.Module,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ n_steps: int = 64,
+) -> List[List[int]]:
+ # fake padding
+ padding = torch.zeros(1, 1).to(encoder_out.device)
+ # sos
+ pred_input_step = torch.tensor([model.blank]).reshape(1, 1)
+ cache = model.predictor.init_state(1, method="zero", device=encoder_out.device)
+ new_cache: List[torch.Tensor] = []
+ t = 0
+ hyps = []
+ prev_out_nblk = True
+ pred_out_step = None
+ per_frame_max_noblk = n_steps
+ per_frame_noblk = 0
+ while t < encoder_out_lens:
+ encoder_out_step = encoder_out[:, t : t + 1, :] # [1, 1, E]
+ if prev_out_nblk:
+ step_outs = model.predictor.forward_step(
+ pred_input_step, padding, cache
+ ) # [1, 1, P]
+ pred_out_step, new_cache = step_outs[0], step_outs[1]
+
+ joint_out_step = model.joint(encoder_out_step, pred_out_step) # [1,1,v]
+ joint_out_probs = joint_out_step.log_softmax(dim=-1)
+
+ joint_out_max = joint_out_probs.argmax(dim=-1).squeeze() # []
+ if joint_out_max != model.blank:
+ hyps.append(joint_out_max.item())
+ prev_out_nblk = True
+ per_frame_noblk = per_frame_noblk + 1
+ pred_input_step = joint_out_max.reshape(1, 1)
+ # state_m, state_c = clstate_out_m, state_out_c
+ cache = new_cache
+
+ if joint_out_max == model.blank or per_frame_noblk >= per_frame_max_noblk:
+ if joint_out_max == model.blank:
+ prev_out_nblk = False
+ # TODO(Mddct): make t in chunk for streamming
+ # or t should't be too lang to predict none blank
+ t = t + 1
+ per_frame_noblk = 0
+
+ return [hyps]
diff --git a/modules/wenet_extractor/transducer/search/prefix_beam_search.py b/modules/wenet_extractor/transducer/search/prefix_beam_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b4f4ea46f24f75f43f7cef27d9a12b57a7452d
--- /dev/null
+++ b/modules/wenet_extractor/transducer/search/prefix_beam_search.py
@@ -0,0 +1,149 @@
+from typing import List, Tuple
+
+import torch
+from modules.wenet_extractor.utils.common import log_add
+
+
+class Sequence:
+ __slots__ = {"hyp", "score", "cache"}
+
+ def __init__(
+ self,
+ hyp: List[torch.Tensor],
+ score,
+ cache: List[torch.Tensor],
+ ):
+ self.hyp = hyp
+ self.score = score
+ self.cache = cache
+
+
+class PrefixBeamSearch:
+ def __init__(self, encoder, predictor, joint, ctc, blank):
+ self.encoder = encoder
+ self.predictor = predictor
+ self.joint = joint
+ self.ctc = ctc
+ self.blank = blank
+
+ def forward_decoder_one_step(
+ self, encoder_x: torch.Tensor, pre_t: torch.Tensor, cache: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
+ pre_t, new_cache = self.predictor.forward_step(
+ pre_t.unsqueeze(-1), padding, cache
+ )
+ x = self.joint(encoder_x, pre_t) # [beam, 1, 1, vocab]
+ x = x.log_softmax(dim=-1)
+ return x, new_cache
+
+ def prefix_beam_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ beam_size: int = 5,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ctc_weight: float = 0.3,
+ transducer_weight: float = 0.7,
+ ):
+ """prefix beam search
+ also see wenet.transducer.transducer.beam_search
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ device = speech.device
+ batch_size = speech.shape[0]
+ assert batch_size == 1
+
+ # 1. Encoder
+ encoder_out, _ = self.encoder(
+ speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks
+ ) # (B, maxlen, encoder_dim)
+ maxlen = encoder_out.size(1)
+
+ ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)
+ beam_init: List[Sequence] = []
+
+ # 2. init beam using Sequence to save beam unit
+ cache = self.predictor.init_state(1, method="zero", device=device)
+ beam_init.append(Sequence(hyp=[self.blank], score=0.0, cache=cache))
+ # 3. start decoding (notice: we use breathwise first searching)
+ # !!!! In this decoding method: one frame do not output multi units. !!!!
+ # !!!! Experiments show that this strategy has little impact !!!!
+ for i in range(maxlen):
+ # 3.1 building input
+ # decoder taking the last token to predict the next token
+ input_hyp = [s.hyp[-1] for s in beam_init]
+ input_hyp_tensor = torch.tensor(input_hyp, dtype=torch.int, device=device)
+ # building statement from beam
+ cache_batch = self.predictor.cache_to_batch([s.cache for s in beam_init])
+ # build score tensor to do torch.add() function
+ scores = torch.tensor([s.score for s in beam_init]).to(device)
+
+ # 3.2 forward decoder
+ logp, new_cache = self.forward_decoder_one_step(
+ encoder_out[:, i, :].unsqueeze(1),
+ input_hyp_tensor,
+ cache_batch,
+ ) # logp: (N, 1, 1, vocab_size)
+ logp = logp.squeeze(1).squeeze(1) # logp: (N, vocab_size)
+ new_cache = self.predictor.batch_to_cache(new_cache)
+
+ # 3.3 shallow fusion for transducer score
+ # and ctc score where we can also add the LM score
+ logp = torch.log(
+ torch.add(
+ transducer_weight * torch.exp(logp),
+ ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0)),
+ )
+ )
+
+ # 3.4 first beam prune
+ top_k_logp, top_k_index = logp.topk(beam_size) # (N, N)
+ scores = torch.add(scores.unsqueeze(1), top_k_logp)
+
+ # 3.5 generate new beam (N*N)
+ beam_A = []
+ for j in range(len(beam_init)):
+ # update seq
+ base_seq = beam_init[j]
+ for t in range(beam_size):
+ # blank: only update the score
+ if top_k_index[j, t] == self.blank:
+ new_seq = Sequence(
+ hyp=base_seq.hyp.copy(),
+ score=scores[j, t].item(),
+ cache=base_seq.cache,
+ )
+
+ beam_A.append(new_seq)
+ # other unit: update hyp score statement and last
+ else:
+ hyp_new = base_seq.hyp.copy()
+ hyp_new.append(top_k_index[j, t].item())
+ new_seq = Sequence(
+ hyp=hyp_new, score=scores[j, t].item(), cache=new_cache[j]
+ )
+ beam_A.append(new_seq)
+
+ # 3.6 prefix fusion
+ fusion_A = [beam_A[0]]
+ for j in range(1, len(beam_A)):
+ s1 = beam_A[j]
+ if_do_append = True
+ for t in range(len(fusion_A)):
+ # notice: A_ can not fusion with A
+ if s1.hyp == fusion_A[t].hyp:
+ fusion_A[t].score = log_add([fusion_A[t].score, s1.score])
+ if_do_append = False
+ break
+ if if_do_append:
+ fusion_A.append(s1)
+
+ # 4. second pruned
+ fusion_A.sort(key=lambda x: x.score, reverse=True)
+ beam_init = fusion_A[:beam_size]
+
+ return beam_init, encoder_out
diff --git a/modules/wenet_extractor/transducer/transducer.py b/modules/wenet_extractor/transducer/transducer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3618eeb2e355955a606d59e3a91f6253872843e0
--- /dev/null
+++ b/modules/wenet_extractor/transducer/transducer.py
@@ -0,0 +1,486 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torchaudio
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+
+from modules.wenet_extractor.transducer.predictor import PredictorBase
+from modules.wenet_extractor.transducer.search.greedy_search import basic_greedy_search
+from modules.wenet_extractor.transducer.search.prefix_beam_search import (
+ PrefixBeamSearch,
+)
+from modules.wenet_extractor.transformer.asr_model import ASRModel
+from modules.wenet_extractor.transformer.ctc import CTC
+from modules.wenet_extractor.transformer.decoder import (
+ BiTransformerDecoder,
+ TransformerDecoder,
+)
+from modules.wenet_extractor.transformer.label_smoothing_loss import LabelSmoothingLoss
+from modules.wenet_extractor.utils.common import (
+ IGNORE_ID,
+ add_blank,
+ add_sos_eos,
+ reverse_pad_list,
+)
+
+
+class Transducer(ASRModel):
+ """Transducer-ctc-attention hybrid Encoder-Predictor-Decoder model"""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ blank: int,
+ encoder: nn.Module,
+ predictor: PredictorBase,
+ joint: nn.Module,
+ attention_decoder: Optional[
+ Union[TransformerDecoder, BiTransformerDecoder]
+ ] = None,
+ ctc: Optional[CTC] = None,
+ ctc_weight: float = 0,
+ ignore_id: int = IGNORE_ID,
+ reverse_weight: float = 0.0,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ transducer_weight: float = 1.0,
+ attention_weight: float = 0.0,
+ ) -> None:
+ assert attention_weight + ctc_weight + transducer_weight == 1.0
+ super().__init__(
+ vocab_size,
+ encoder,
+ attention_decoder,
+ ctc,
+ ctc_weight,
+ ignore_id,
+ reverse_weight,
+ lsm_weight,
+ length_normalized_loss,
+ )
+
+ self.blank = blank
+ self.transducer_weight = transducer_weight
+ self.attention_decoder_weight = 1 - self.transducer_weight - self.ctc_weight
+
+ self.predictor = predictor
+ self.joint = joint
+ self.bs = None
+
+ # Note(Mddct): decoder also means predictor in transducer,
+ # but here decoder is attention decoder
+ del self.criterion_att
+ if attention_decoder is not None:
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """Frontend + Encoder + predictor + joint + loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+ # Encoder
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ # predictor
+ ys_in_pad = add_blank(text, self.blank, self.ignore_id)
+ predictor_out = self.predictor(ys_in_pad)
+ # joint
+ joint_out = self.joint(encoder_out, predictor_out)
+ # NOTE(Mddct): some loss implementation require pad valid is zero
+ # torch.int32 rnnt_loss required
+ rnnt_text = text.to(torch.int64)
+ rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to(
+ torch.int32
+ )
+ rnnt_text_lengths = text_lengths.to(torch.int32)
+ encoder_out_lens = encoder_out_lens.to(torch.int32)
+ loss = torchaudio.functional.rnnt_loss(
+ joint_out,
+ rnnt_text,
+ encoder_out_lens,
+ rnnt_text_lengths,
+ blank=self.blank,
+ reduction="mean",
+ )
+ loss_rnnt = loss
+
+ loss = self.transducer_weight * loss
+ # optional attention decoder
+ loss_att: Optional[torch.Tensor] = None
+ if self.attention_decoder_weight != 0.0 and self.decoder is not None:
+ loss_att, _ = self._calc_att_loss(
+ encoder_out, encoder_mask, text, text_lengths
+ )
+
+ # optional ctc
+ loss_ctc: Optional[torch.Tensor] = None
+ if self.ctc_weight != 0.0 and self.ctc is not None:
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
+ else:
+ loss_ctc = None
+
+ if loss_ctc is not None:
+ loss = loss + self.ctc_weight * loss_ctc.sum()
+ if loss_att is not None:
+ loss = loss + self.attention_decoder_weight * loss_att.sum()
+ # NOTE: 'loss' must be in dict
+ return {
+ "loss": loss,
+ "loss_att": loss_att,
+ "loss_ctc": loss_ctc,
+ "loss_rnnt": loss_rnnt,
+ }
+
+ def init_bs(self):
+ if self.bs is None:
+ self.bs = PrefixBeamSearch(
+ self.encoder, self.predictor, self.joint, self.ctc, self.blank
+ )
+
+ def _cal_transducer_score(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_mask: torch.Tensor,
+ hyps_lens: torch.Tensor,
+ hyps_pad: torch.Tensor,
+ ):
+ # ignore id -> blank, add blank at head
+ hyps_pad_blank = add_blank(hyps_pad, self.blank, self.ignore_id)
+ xs_in_lens = encoder_mask.squeeze(1).sum(1).int()
+
+ # 1. Forward predictor
+ predictor_out = self.predictor(hyps_pad_blank)
+ # 2. Forward joint
+ joint_out = self.joint(encoder_out, predictor_out)
+ rnnt_text = hyps_pad.to(torch.int64)
+ rnnt_text = torch.where(rnnt_text == self.ignore_id, 0, rnnt_text).to(
+ torch.int32
+ )
+ # 3. Compute transducer loss
+ loss_td = torchaudio.functional.rnnt_loss(
+ joint_out,
+ rnnt_text,
+ xs_in_lens,
+ hyps_lens.int(),
+ blank=self.blank,
+ reduction="none",
+ )
+ return loss_td * -1
+
+ def _cal_attn_score(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_mask: torch.Tensor,
+ hyps_pad: torch.Tensor,
+ hyps_lens: torch.Tensor,
+ ):
+ # (beam_size, max_hyps_len)
+ ori_hyps_pad = hyps_pad
+
+ # td_score = loss_td * -1
+ hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
+ hyps_lens = hyps_lens + 1 # Add at begining
+ # used for right to left decoder
+ r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
+ r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id)
+ decoder_out, r_decoder_out, _ = self.decoder(
+ encoder_out,
+ encoder_mask,
+ hyps_pad,
+ hyps_lens,
+ r_hyps_pad,
+ self.reverse_weight,
+ ) # (beam_size, max_hyps_len, vocab_size)
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+ decoder_out = decoder_out.cpu().numpy()
+ # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+ # conventional transformer decoder.
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+ r_decoder_out = r_decoder_out.cpu().numpy()
+ return decoder_out, r_decoder_out
+
+ def beam_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ beam_size: int = 5,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ctc_weight: float = 0.3,
+ transducer_weight: float = 0.7,
+ ):
+ """beam search
+
+ Args:
+ speech (torch.Tensor): (batch=1, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ ctc_weight (float): ctc probability weight in transducer
+ prefix beam search.
+ final_prob = ctc_weight * ctc_prob + transducer_weight * transducer_prob
+ transducer_weight (float): transducer probability weight in
+ prefix beam search
+ Returns:
+ List[List[int]]: best path result
+
+ """
+ self.init_bs()
+ beam, _ = self.bs.prefix_beam_search(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ beam_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ctc_weight,
+ transducer_weight,
+ )
+ return beam[0].hyp[1:], beam[0].score
+
+ def transducer_attention_rescoring(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ beam_size: int,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ reverse_weight: float = 0.0,
+ ctc_weight: float = 0.0,
+ attn_weight: float = 0.0,
+ transducer_weight: float = 0.0,
+ search_ctc_weight: float = 1.0,
+ search_transducer_weight: float = 0.0,
+ beam_search_type: str = "transducer",
+ ) -> List[List[int]]:
+ """beam search
+
+ Args:
+ speech (torch.Tensor): (batch=1, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ ctc_weight (float): ctc probability weight using in rescoring.
+ rescore_prob = ctc_weight * ctc_prob +
+ transducer_weight * (transducer_loss * -1) +
+ attn_weight * attn_prob
+ attn_weight (float): attn probability weight using in rescoring.
+ transducer_weight (float): transducer probability weight using in
+ rescoring
+ search_ctc_weight (float): ctc weight using
+ in rnnt beam search (seeing in self.beam_search)
+ search_transducer_weight (float): transducer weight using
+ in rnnt beam search (seeing in self.beam_search)
+ Returns:
+ List[List[int]]: best path result
+
+ """
+
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ if reverse_weight > 0.0:
+ # decoder should be a bitransformer decoder if reverse_weight > 0.0
+ assert hasattr(self.decoder, "right_decoder")
+ device = speech.device
+ batch_size = speech.shape[0]
+ # For attention rescoring we only support batch_size=1
+ assert batch_size == 1
+ # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
+ self.init_bs()
+ if beam_search_type == "transducer":
+ beam, encoder_out = self.bs.prefix_beam_search(
+ speech,
+ speech_lengths,
+ decoding_chunk_size=decoding_chunk_size,
+ beam_size=beam_size,
+ num_decoding_left_chunks=num_decoding_left_chunks,
+ ctc_weight=search_ctc_weight,
+ transducer_weight=search_transducer_weight,
+ )
+ beam_score = [s.score for s in beam]
+ hyps = [s.hyp[1:] for s in beam]
+
+ elif beam_search_type == "ctc":
+ hyps, encoder_out = self._ctc_prefix_beam_search(
+ speech,
+ speech_lengths,
+ beam_size=beam_size,
+ decoding_chunk_size=decoding_chunk_size,
+ num_decoding_left_chunks=num_decoding_left_chunks,
+ simulate_streaming=simulate_streaming,
+ )
+ beam_score = [hyp[1] for hyp in hyps]
+ hyps = [hyp[0] for hyp in hyps]
+ assert len(hyps) == beam_size
+
+ # build hyps and encoder output
+ hyps_pad = pad_sequence(
+ [torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps],
+ True,
+ self.ignore_id,
+ ) # (beam_size, max_hyps_len)
+ hyps_lens = torch.tensor(
+ [len(hyp) for hyp in hyps], device=device, dtype=torch.long
+ ) # (beam_size,)
+
+ encoder_out = encoder_out.repeat(beam_size, 1, 1)
+ encoder_mask = torch.ones(
+ beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device
+ )
+
+ # 2.1 calculate transducer score
+ td_score = self._cal_transducer_score(
+ encoder_out,
+ encoder_mask,
+ hyps_lens,
+ hyps_pad,
+ )
+ # 2.2 calculate attention score
+ decoder_out, r_decoder_out = self._cal_attn_score(
+ encoder_out,
+ encoder_mask,
+ hyps_pad,
+ hyps_lens,
+ )
+
+ # Only use decoder score for rescoring
+ best_score = -float("inf")
+ best_index = 0
+ for i, hyp in enumerate(hyps):
+ score = 0.0
+ for j, w in enumerate(hyp):
+ score += decoder_out[i][j][w]
+ score += decoder_out[i][len(hyp)][self.eos]
+ td_s = td_score[i]
+ # add right to left decoder score
+ if reverse_weight > 0:
+ r_score = 0.0
+ for j, w in enumerate(hyp):
+ r_score += r_decoder_out[i][len(hyp) - j - 1][w]
+ r_score += r_decoder_out[i][len(hyp)][self.eos]
+ score = score * (1 - reverse_weight) + r_score * reverse_weight
+ # add ctc score
+ score = (
+ score * attn_weight
+ + beam_score[i] * ctc_weight
+ + td_s * transducer_weight
+ )
+ if score > best_score:
+ best_score = score
+ best_index = i
+
+ return hyps[best_index], best_score
+
+ def greedy_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ n_steps: int = 64,
+ ) -> List[List[int]]:
+ """greedy search
+
+ Args:
+ speech (torch.Tensor): (batch=1, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ Returns:
+ List[List[int]]: best path result
+ """
+ # TODO(Mddct): batch decode
+ assert speech.size(0) == 1
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ # TODO(Mddct): forward chunk by chunk
+ _ = simulate_streaming
+ # Let's assume B = batch_size
+ encoder_out, encoder_mask = self.encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ )
+ encoder_out_lens = encoder_mask.squeeze(1).sum()
+ hyps = basic_greedy_search(self, encoder_out, encoder_out_lens, n_steps=n_steps)
+
+ return hyps
+
+ @torch.jit.export
+ def forward_encoder_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ return self.encoder.forward_chunk(
+ xs, offset, required_cache_size, att_cache, cnn_cache
+ )
+
+ @torch.jit.export
+ def forward_predictor_step(
+ self, xs: torch.Tensor, cache: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ assert len(cache) == 2
+ # fake padding
+ padding = torch.zeros(1, 1)
+ return self.predictor.forward_step(xs, padding, cache)
+
+ @torch.jit.export
+ def forward_joint_step(
+ self, enc_out: torch.Tensor, pred_out: torch.Tensor
+ ) -> torch.Tensor:
+ return self.joint(enc_out, pred_out)
+
+ @torch.jit.export
+ def forward_predictor_init_state(self) -> List[torch.Tensor]:
+ return self.predictor.init_state(1, device=torch.device("cpu"))
diff --git a/modules/wenet_extractor/transformer/__init__.py b/modules/wenet_extractor/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/wenet_extractor/transformer/asr_model.py b/modules/wenet_extractor/transformer/asr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad7869679ee15fdcf6a759a7c39762d7f166ff2
--- /dev/null
+++ b/modules/wenet_extractor/transformer/asr_model.py
@@ -0,0 +1,1056 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from modules.wenet_extractor.transformer.ctc import CTC
+from modules.wenet_extractor.transformer.decoder import TransformerDecoder
+from modules.wenet_extractor.transformer.encoder import TransformerEncoder
+from modules.wenet_extractor.transformer.label_smoothing_loss import LabelSmoothingLoss
+from modules.wenet_extractor.utils.common import (
+ IGNORE_ID,
+ add_sos_eos,
+ log_add,
+ remove_duplicates_and_blank,
+ th_accuracy,
+ reverse_pad_list,
+)
+from modules.wenet_extractor.utils.mask import (
+ make_pad_mask,
+ mask_finished_preds,
+ mask_finished_scores,
+ subsequent_mask,
+)
+
+
+class ASRModel(torch.nn.Module):
+ """CTC-attention hybrid Encoder-Decoder model"""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder: TransformerEncoder,
+ decoder: TransformerDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ ignore_id: int = IGNORE_ID,
+ reverse_weight: float = 0.0,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ lfmmi_dir: str = "",
+ ):
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+
+ super().__init__()
+ # note that eos is the same as sos (equivalent ID)
+ self.sos = vocab_size - 1
+ self.eos = vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ self.reverse_weight = reverse_weight
+
+ self.encoder = encoder
+ self.decoder = decoder
+ self.ctc = ctc
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ self.lfmmi_dir = lfmmi_dir
+ if self.lfmmi_dir != "":
+ self.load_lfmmi_resource()
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ # 1. Encoder
+ encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+
+ # 2a. Attention-decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, acc_att = self._calc_att_loss(
+ encoder_out, encoder_mask, text, text_lengths
+ )
+ else:
+ loss_att = None
+
+ # 2b. CTC branch or LF-MMI loss
+ if self.ctc_weight != 0.0:
+ if self.lfmmi_dir != "":
+ loss_ctc = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
+ else:
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
+ else:
+ loss_ctc = None
+
+ if loss_ctc is None:
+ loss = loss_att
+ elif loss_att is None:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+ return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc}
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_mask: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, float]:
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # reverse the seq, used for right to left decoder
+ r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
+ r_ys_in_pad, r_ys_out_pad = add_sos_eos(
+ r_ys_pad, self.sos, self.eos, self.ignore_id
+ )
+ # 1. Forward decoder
+ decoder_out, r_decoder_out, _ = self.decoder(
+ encoder_out,
+ encoder_mask,
+ ys_in_pad,
+ ys_in_lens,
+ r_ys_in_pad,
+ self.reverse_weight,
+ )
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ r_loss_att = torch.tensor(0.0)
+ if self.reverse_weight > 0.0:
+ r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
+ loss_att = (
+ loss_att * (1 - self.reverse_weight) + r_loss_att * self.reverse_weight
+ )
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+ return loss_att, acc_att
+
+ def _forward_encoder(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Let's assume B = batch_size
+ # 1. Encoder
+ if simulate_streaming and decoding_chunk_size > 0:
+ encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
+ speech,
+ decoding_chunk_size=decoding_chunk_size,
+ num_decoding_left_chunks=num_decoding_left_chunks,
+ ) # (B, maxlen, encoder_dim)
+ else:
+ encoder_out, encoder_mask = self.encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size=decoding_chunk_size,
+ num_decoding_left_chunks=num_decoding_left_chunks,
+ ) # (B, maxlen, encoder_dim)
+ return encoder_out, encoder_mask
+
+ def encoder_extractor(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # assert speech.shape[0] == speech_lengths[0]
+ assert decoding_chunk_size != 0
+ batch_size = speech.shape[0]
+
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+
+ return encoder_out
+
+ def recognize(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ beam_size: int = 10,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> torch.Tensor:
+ """Apply beam search on attention decoder
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+
+ Returns:
+ torch.Tensor: decoding result, (batch, max_result_len)
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ device = speech.device
+ batch_size = speech.shape[0]
+
+ # Let's assume B = batch_size and N = beam_size
+ # 1. Encoder
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ maxlen = encoder_out.size(1)
+ encoder_dim = encoder_out.size(2)
+ running_size = batch_size * beam_size
+ encoder_out = (
+ encoder_out.unsqueeze(1)
+ .repeat(1, beam_size, 1, 1)
+ .view(running_size, maxlen, encoder_dim)
+ ) # (B*N, maxlen, encoder_dim)
+ encoder_mask = (
+ encoder_mask.unsqueeze(1)
+ .repeat(1, beam_size, 1, 1)
+ .view(running_size, 1, maxlen)
+ ) # (B*N, 1, max_len)
+
+ hyps = torch.ones([running_size, 1], dtype=torch.long, device=device).fill_(
+ self.sos
+ ) # (B*N, 1)
+ scores = torch.tensor(
+ [0.0] + [-float("inf")] * (beam_size - 1), dtype=torch.float
+ )
+ scores = (
+ scores.to(device).repeat([batch_size]).unsqueeze(1).to(device)
+ ) # (B*N, 1)
+ end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
+ cache: Optional[List[torch.Tensor]] = None
+ # 2. Decoder forward step by step
+ for i in range(1, maxlen + 1):
+ # Stop if all batch and all beam produce eos
+ if end_flag.sum() == running_size:
+ break
+ # 2.1 Forward decoder step
+ hyps_mask = (
+ subsequent_mask(i).unsqueeze(0).repeat(running_size, 1, 1).to(device)
+ ) # (B*N, i, i)
+ # logp: (B*N, vocab)
+ logp, cache = self.decoder.forward_one_step(
+ encoder_out, encoder_mask, hyps, hyps_mask, cache
+ )
+ # 2.2 First beam prune: select topk best prob at current time
+ top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
+ top_k_logp = mask_finished_scores(top_k_logp, end_flag)
+ top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
+ # 2.3 Second beam prune: select topk score with history
+ scores = scores + top_k_logp # (B*N, N), broadcast add
+ scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
+ scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
+ # Update cache to be consistent with new topk scores / hyps
+ cache_index = (offset_k_index // beam_size).view(-1) # (B*N)
+ base_cache_index = (
+ torch.arange(batch_size, device=device)
+ .view(-1, 1)
+ .repeat([1, beam_size])
+ * beam_size
+ ).view(
+ -1
+ ) # (B*N)
+ cache_index = base_cache_index + cache_index
+ cache = [torch.index_select(c, dim=0, index=cache_index) for c in cache]
+ scores = scores.view(-1, 1) # (B*N, 1)
+ # 2.4. Compute base index in top_k_index,
+ # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
+ # then find offset_k_index in top_k_index
+ base_k_index = (
+ torch.arange(batch_size, device=device)
+ .view(-1, 1)
+ .repeat([1, beam_size])
+ ) # (B, N)
+ base_k_index = base_k_index * beam_size * beam_size
+ best_k_index = base_k_index.view(-1) + offset_k_index.view(-1) # (B*N)
+
+ # 2.5 Update best hyps
+ best_k_pred = torch.index_select(
+ top_k_index.view(-1), dim=-1, index=best_k_index
+ ) # (B*N)
+ best_hyps_index = best_k_index // beam_size
+ last_best_k_hyps = torch.index_select(
+ hyps, dim=0, index=best_hyps_index
+ ) # (B*N, i)
+ hyps = torch.cat(
+ (last_best_k_hyps, best_k_pred.view(-1, 1)), dim=1
+ ) # (B*N, i+1)
+
+ # 2.6 Update end flag
+ end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)
+
+ # 3. Select best of best
+ scores = scores.view(batch_size, beam_size)
+ # TODO: length normalization
+ best_scores, best_index = scores.max(dim=-1)
+ best_hyps_index = (
+ best_index
+ + torch.arange(batch_size, dtype=torch.long, device=device) * beam_size
+ )
+ best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
+ best_hyps = best_hyps[:, 1:]
+ return best_hyps, best_scores
+
+ def ctc_greedy_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> List[List[int]]:
+ """Apply CTC greedy search
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ Returns:
+ List[List[int]]: best path result
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ batch_size = speech.shape[0]
+ # Let's assume B = batch_size
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ maxlen = encoder_out.size(1)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
+ topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
+ topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
+ mask = make_pad_mask(encoder_out_lens, maxlen) # (B, maxlen)
+ topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen)
+ hyps = [hyp.tolist() for hyp in topk_index]
+ scores = topk_prob.max(1)
+ hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
+ return hyps, scores
+
+ def _ctc_prefix_beam_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ beam_size: int,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> Tuple[List[List[int]], torch.Tensor]:
+ """CTC prefix beam search inner implementation
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+
+ Returns:
+ List[List[int]]: nbest results
+ torch.Tensor: encoder output, (1, max_len, encoder_dim),
+ it will be used for rescoring in attention rescoring mode
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ batch_size = speech.shape[0]
+ # For CTC prefix beam search, we only support batch_size=1
+ assert batch_size == 1
+ # Let's assume B = batch_size and N = beam_size
+ # 1. Encoder forward and get CTC score
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ maxlen = encoder_out.size(1)
+ ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
+ ctc_probs = ctc_probs.squeeze(0)
+ # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
+ cur_hyps = [(tuple(), (0.0, -float("inf")))]
+ # 2. CTC beam search step by step
+ for t in range(0, maxlen):
+ logp = ctc_probs[t] # (vocab_size,)
+ # key: prefix, value (pb, pnb), default value(-inf, -inf)
+ next_hyps = defaultdict(lambda: (-float("inf"), -float("inf")))
+ # 2.1 First beam prune: select topk best
+ top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
+ for s in top_k_index:
+ s = s.item()
+ ps = logp[s].item()
+ for prefix, (pb, pnb) in cur_hyps:
+ last = prefix[-1] if len(prefix) > 0 else None
+ if s == 0: # blank
+ n_pb, n_pnb = next_hyps[prefix]
+ n_pb = log_add([n_pb, pb + ps, pnb + ps])
+ next_hyps[prefix] = (n_pb, n_pnb)
+ elif s == last:
+ # Update *ss -> *s;
+ n_pb, n_pnb = next_hyps[prefix]
+ n_pnb = log_add([n_pnb, pnb + ps])
+ next_hyps[prefix] = (n_pb, n_pnb)
+ # Update *s-s -> *ss, - is for blank
+ n_prefix = prefix + (s,)
+ n_pb, n_pnb = next_hyps[n_prefix]
+ n_pnb = log_add([n_pnb, pb + ps])
+ next_hyps[n_prefix] = (n_pb, n_pnb)
+ else:
+ n_prefix = prefix + (s,)
+ n_pb, n_pnb = next_hyps[n_prefix]
+ n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
+ next_hyps[n_prefix] = (n_pb, n_pnb)
+
+ # 2.2 Second beam prune
+ next_hyps = sorted(
+ next_hyps.items(), key=lambda x: log_add(list(x[1])), reverse=True
+ )
+ cur_hyps = next_hyps[:beam_size]
+ hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
+ return hyps, encoder_out
+
+ def ctc_prefix_beam_search(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ beam_size: int,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ ) -> List[int]:
+ """Apply CTC prefix beam search
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+
+ Returns:
+ List[int]: CTC prefix beam search nbest results
+ """
+ hyps, _ = self._ctc_prefix_beam_search(
+ speech,
+ speech_lengths,
+ beam_size,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ )
+ return hyps[0]
+
+ def attention_rescoring(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ beam_size: int,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ ctc_weight: float = 0.0,
+ simulate_streaming: bool = False,
+ reverse_weight: float = 0.0,
+ ) -> List[int]:
+ """Apply attention rescoring decoding, CTC prefix beam search
+ is applied first to get nbest, then we resoring the nbest on
+ attention decoder with corresponding encoder out
+
+ Args:
+ speech (torch.Tensor): (batch, max_len, feat_dim)
+ speech_length (torch.Tensor): (batch, )
+ beam_size (int): beam size for beam search
+ decoding_chunk_size (int): decoding chunk for dynamic chunk
+ trained model.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ 0: used for training, it's prohibited here
+ simulate_streaming (bool): whether do encoder forward in a
+ streaming fashion
+ reverse_weight (float): right to left decoder weight
+ ctc_weight (float): ctc score weight
+
+ Returns:
+ List[int]: Attention rescoring result
+ """
+ assert speech.shape[0] == speech_lengths.shape[0]
+ assert decoding_chunk_size != 0
+ if reverse_weight > 0.0:
+ # decoder should be a bitransformer decoder if reverse_weight > 0.0
+ assert hasattr(self.decoder, "right_decoder")
+ device = speech.device
+ batch_size = speech.shape[0]
+ # For attention rescoring we only support batch_size=1
+ assert batch_size == 1
+ # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
+ hyps, encoder_out = self._ctc_prefix_beam_search(
+ speech,
+ speech_lengths,
+ beam_size,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ )
+
+ assert len(hyps) == beam_size
+ hyps_pad = pad_sequence(
+ [torch.tensor(hyp[0], device=device, dtype=torch.long) for hyp in hyps],
+ True,
+ self.ignore_id,
+ ) # (beam_size, max_hyps_len)
+ ori_hyps_pad = hyps_pad
+ hyps_lens = torch.tensor(
+ [len(hyp[0]) for hyp in hyps], device=device, dtype=torch.long
+ ) # (beam_size,)
+ hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
+ hyps_lens = hyps_lens + 1 # Add at begining
+ encoder_out = encoder_out.repeat(beam_size, 1, 1)
+ encoder_mask = torch.ones(
+ beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device
+ )
+ # used for right to left decoder
+ r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
+ r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id)
+ decoder_out, r_decoder_out, _ = self.decoder(
+ encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, reverse_weight
+ ) # (beam_size, max_hyps_len, vocab_size)
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+ decoder_out = decoder_out.cpu().numpy()
+ # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+ # conventional transformer decoder.
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+ r_decoder_out = r_decoder_out.cpu().numpy()
+ # Only use decoder score for rescoring
+ best_score = -float("inf")
+ best_index = 0
+ for i, hyp in enumerate(hyps):
+ score = 0.0
+ for j, w in enumerate(hyp[0]):
+ score += decoder_out[i][j][w]
+ score += decoder_out[i][len(hyp[0])][self.eos]
+ # add right to left decoder score
+ if reverse_weight > 0:
+ r_score = 0.0
+ for j, w in enumerate(hyp[0]):
+ r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
+ r_score += r_decoder_out[i][len(hyp[0])][self.eos]
+ score = score * (1 - reverse_weight) + r_score * reverse_weight
+ # add ctc score
+ score += hyp[1] * ctc_weight
+ if score > best_score:
+ best_score = score
+ best_index = i
+ return hyps[best_index][0], best_score
+
+ @torch.jit.unused
+ def load_lfmmi_resource(self):
+ with open("{}/tokens.txt".format(self.lfmmi_dir), "r") as fin:
+ for line in fin:
+ arr = line.strip().split()
+ if arr[0] == "":
+ self.sos_eos_id = int(arr[1])
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.graph_compiler = MmiTrainingGraphCompiler(
+ self.lfmmi_dir,
+ device=device,
+ oov="",
+ sos_id=self.sos_eos_id,
+ eos_id=self.sos_eos_id,
+ )
+ self.lfmmi = LFMMILoss(
+ graph_compiler=self.graph_compiler,
+ den_scale=1,
+ use_pruned_intersect=False,
+ )
+ self.word_table = {}
+ with open("{}/words.txt".format(self.lfmmi_dir), "r") as fin:
+ for line in fin:
+ arr = line.strip().split()
+ assert len(arr) == 2
+ self.word_table[int(arr[1])] = arr[0]
+
+ @torch.jit.unused
+ def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
+ ctc_probs = self.ctc.log_softmax(encoder_out)
+ supervision_segments = torch.stack(
+ (
+ torch.arange(len(encoder_mask)),
+ torch.zeros(len(encoder_mask)),
+ encoder_mask.squeeze(dim=1).sum(dim=1).to("cpu"),
+ ),
+ 1,
+ ).to(torch.int32)
+ dense_fsa_vec = k2.DenseFsaVec(
+ ctc_probs,
+ supervision_segments,
+ allow_truncate=3,
+ )
+ text = [
+ " ".join([self.word_table[j.item()] for j in i if j != -1]) for i in text
+ ]
+ loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
+ return loss
+
+ def load_hlg_resource_if_necessary(self, hlg, word):
+ if not hasattr(self, "hlg"):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
+ if not hasattr(self.hlg, "lm_scores"):
+ self.hlg.lm_scores = self.hlg.scores.clone()
+ if not hasattr(self, "word_table"):
+ self.word_table = {}
+ with open(word, "r") as fin:
+ for line in fin:
+ arr = line.strip().split()
+ assert len(arr) == 2
+ self.word_table[int(arr[1])] = arr[0]
+
+ @torch.no_grad()
+ def hlg_onebest(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ hlg: str = "",
+ word: str = "",
+ symbol_table: Dict[str, int] = None,
+ ) -> List[int]:
+ self.load_hlg_resource_if_necessary(hlg, word)
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
+ supervision_segments = torch.stack(
+ (
+ torch.arange(len(encoder_mask)),
+ torch.zeros(len(encoder_mask)),
+ encoder_mask.squeeze(dim=1).sum(dim=1).cpu(),
+ ),
+ 1,
+ ).to(torch.int32)
+ lattice = get_lattice(
+ nnet_output=ctc_probs,
+ decoding_graph=self.hlg,
+ supervision_segments=supervision_segments,
+ search_beam=20,
+ output_beam=7,
+ min_active_states=30,
+ max_active_states=10000,
+ subsampling_factor=4,
+ )
+ best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
+ hyps = get_texts(best_path)
+ hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] for i in hyps]
+ return hyps
+
+ @torch.no_grad()
+ def hlg_rescore(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ decoding_chunk_size: int = -1,
+ num_decoding_left_chunks: int = -1,
+ simulate_streaming: bool = False,
+ lm_scale: float = 0,
+ decoder_scale: float = 0,
+ r_decoder_scale: float = 0,
+ hlg: str = "",
+ word: str = "",
+ symbol_table: Dict[str, int] = None,
+ ) -> List[int]:
+ self.load_hlg_resource_if_necessary(hlg, word)
+ device = speech.device
+ encoder_out, encoder_mask = self._forward_encoder(
+ speech,
+ speech_lengths,
+ decoding_chunk_size,
+ num_decoding_left_chunks,
+ simulate_streaming,
+ ) # (B, maxlen, encoder_dim)
+ ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
+ supervision_segments = torch.stack(
+ (
+ torch.arange(len(encoder_mask)),
+ torch.zeros(len(encoder_mask)),
+ encoder_mask.squeeze(dim=1).sum(dim=1).cpu(),
+ ),
+ 1,
+ ).to(torch.int32)
+ lattice = get_lattice(
+ nnet_output=ctc_probs,
+ decoding_graph=self.hlg,
+ supervision_segments=supervision_segments,
+ search_beam=20,
+ output_beam=7,
+ min_active_states=30,
+ max_active_states=10000,
+ subsampling_factor=4,
+ )
+ nbest = Nbest.from_lattice(
+ lattice=lattice,
+ num_paths=100,
+ use_double_scores=True,
+ nbest_scale=0.5,
+ )
+ nbest = nbest.intersect(lattice)
+ assert hasattr(nbest.fsa, "lm_scores")
+ assert hasattr(nbest.fsa, "tokens")
+ assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+ tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+ tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+ tokens = tokens.remove_values_leq(0)
+ hyps = tokens.tolist()
+
+ # cal attention_score
+ hyps_pad = pad_sequence(
+ [torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps],
+ True,
+ self.ignore_id,
+ ) # (beam_size, max_hyps_len)
+ ori_hyps_pad = hyps_pad
+ hyps_lens = torch.tensor(
+ [len(hyp) for hyp in hyps], device=device, dtype=torch.long
+ ) # (beam_size,)
+ hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
+ hyps_lens = hyps_lens + 1 # Add at begining
+ encoder_out_repeat = []
+ tot_scores = nbest.tot_scores()
+ repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
+ for i in range(len(encoder_out)):
+ encoder_out_repeat.append(encoder_out[i : i + 1].repeat(repeats[i], 1, 1))
+ encoder_out = torch.concat(encoder_out_repeat, dim=0)
+ encoder_mask = torch.ones(
+ encoder_out.size(0), 1, encoder_out.size(1), dtype=torch.bool, device=device
+ )
+ # used for right to left decoder
+ r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
+ r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, self.ignore_id)
+ reverse_weight = 0.5
+ decoder_out, r_decoder_out, _ = self.decoder(
+ encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, reverse_weight
+ ) # (beam_size, max_hyps_len, vocab_size)
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+ decoder_out = decoder_out
+ # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+ # conventional transformer decoder.
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+ r_decoder_out = r_decoder_out
+
+ decoder_scores = torch.tensor(
+ [
+ sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))])
+ for i in range(len(hyps))
+ ],
+ device=device,
+ )
+ r_decoder_scores = []
+ for i in range(len(hyps)):
+ score = 0
+ for j in range(len(hyps[i])):
+ score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
+ score += r_decoder_out[i, len(hyps[i]), self.eos]
+ r_decoder_scores.append(score)
+ r_decoder_scores = torch.tensor(r_decoder_scores, device=device)
+
+ am_scores = nbest.compute_am_scores()
+ ngram_lm_scores = nbest.compute_lm_scores()
+ tot_scores = (
+ am_scores.values
+ + lm_scale * ngram_lm_scores.values
+ + decoder_scale * decoder_scores
+ + r_decoder_scale * r_decoder_scores
+ )
+ ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+ max_indexes = ragged_tot_scores.argmax()
+ best_path = k2.index_fsa(nbest.fsa, max_indexes)
+ hyps = get_texts(best_path)
+ hyps = [[symbol_table[k] for j in i for k in self.word_table[j]] for i in hyps]
+ return hyps
+
+ @torch.jit.export
+ def subsampling_rate(self) -> int:
+ """Export interface for c++ call, return subsampling_rate of the
+ model
+ """
+ return self.encoder.embed.subsampling_rate
+
+ @torch.jit.export
+ def right_context(self) -> int:
+ """Export interface for c++ call, return right_context of the model"""
+ return self.encoder.embed.right_context
+
+ @torch.jit.export
+ def sos_symbol(self) -> int:
+ """Export interface for c++ call, return sos symbol id of the model"""
+ return self.sos
+
+ @torch.jit.export
+ def eos_symbol(self) -> int:
+ """Export interface for c++ call, return eos symbol id of the model"""
+ return self.eos
+
+ @torch.jit.export
+ def forward_encoder_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """ Export interface for c++ call, give input chunk xs, and return
+ output from time 0 to current chunk.
+
+ Args:
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
+ where `time == (chunk_size - 1) * subsample_rate + \
+ subsample.right_context + 1`
+ offset (int): current offset in encoder output time stamp
+ required_cache_size (int): cache size required for next chunk
+ compuation
+ >=0: actual cache size
+ <0: means all history cache is required
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (elayers, b=1, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+
+ Returns:
+ torch.Tensor: output of current input xs,
+ with shape (b=1, chunk_size, hidden-dim).
+ torch.Tensor: new attention cache required for next chunk, with
+ dynamic shape (elayers, head, ?, d_k * 2)
+ depending on required_cache_size.
+ torch.Tensor: new conformer cnn cache required for next chunk, with
+ same shape as the original cnn_cache.
+
+ """
+ return self.encoder.forward_chunk(
+ xs, offset, required_cache_size, att_cache, cnn_cache
+ )
+
+ @torch.jit.export
+ def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
+ """Export interface for c++ call, apply linear transform and log
+ softmax before ctc
+ Args:
+ xs (torch.Tensor): encoder output
+
+ Returns:
+ torch.Tensor: activation before ctc
+
+ """
+ return self.ctc.log_softmax(xs)
+
+ @torch.jit.export
+ def is_bidirectional_decoder(self) -> bool:
+ """
+ Returns:
+ torch.Tensor: decoder output
+ """
+ if hasattr(self.decoder, "right_decoder"):
+ return True
+ else:
+ return False
+
+ @torch.jit.export
+ def forward_attention_decoder(
+ self,
+ hyps: torch.Tensor,
+ hyps_lens: torch.Tensor,
+ encoder_out: torch.Tensor,
+ reverse_weight: float = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Export interface for c++ call, forward decoder with multiple
+ hypothesis from ctc prefix beam search and one encoder output
+ Args:
+ hyps (torch.Tensor): hyps from ctc prefix beam search, already
+ pad sos at the begining
+ hyps_lens (torch.Tensor): length of each hyp in hyps
+ encoder_out (torch.Tensor): corresponding encoder output
+ r_hyps (torch.Tensor): hyps from ctc prefix beam search, already
+ pad eos at the begining which is used fo right to left decoder
+ reverse_weight: used for verfing whether used right to left decoder,
+ > 0 will use.
+
+ Returns:
+ torch.Tensor: decoder output
+ """
+ assert encoder_out.size(0) == 1
+ num_hyps = hyps.size(0)
+ assert hyps_lens.size(0) == num_hyps
+ encoder_out = encoder_out.repeat(num_hyps, 1, 1)
+ encoder_mask = torch.ones(
+ num_hyps,
+ 1,
+ encoder_out.size(1),
+ dtype=torch.bool,
+ device=encoder_out.device,
+ )
+
+ # input for right to left decoder
+ # this hyps_lens has count token, we need minus it.
+ r_hyps_lens = hyps_lens - 1
+ # this hyps has included token, so it should be
+ # convert the original hyps.
+ r_hyps = hyps[:, 1:]
+ # >>> r_hyps
+ # >>> tensor([[ 1, 2, 3],
+ # >>> [ 9, 8, 4],
+ # >>> [ 2, -1, -1]])
+ # >>> r_hyps_lens
+ # >>> tensor([3, 3, 1])
+
+ # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
+ # in `reverse_pad_list` thus we have to refine the below code.
+ # Issue: https://github.com/wenet-e2e/wenet/issues/1113
+ # Equal to:
+ # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
+ # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
+ max_len = torch.max(r_hyps_lens)
+ index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
+ seq_len_expand = r_hyps_lens.unsqueeze(1)
+ seq_mask = seq_len_expand > index_range # (beam, max_len)
+ # >>> seq_mask
+ # >>> tensor([[ True, True, True],
+ # >>> [ True, True, True],
+ # >>> [ True, False, False]])
+ index = (seq_len_expand - 1) - index_range # (beam, max_len)
+ # >>> index
+ # >>> tensor([[ 2, 1, 0],
+ # >>> [ 2, 1, 0],
+ # >>> [ 0, -1, -2]])
+ index = index * seq_mask
+ # >>> index
+ # >>> tensor([[2, 1, 0],
+ # >>> [2, 1, 0],
+ # >>> [0, 0, 0]])
+ r_hyps = torch.gather(r_hyps, 1, index)
+ # >>> r_hyps
+ # >>> tensor([[3, 2, 1],
+ # >>> [4, 8, 9],
+ # >>> [2, 2, 2]])
+ r_hyps = torch.where(seq_mask, r_hyps, self.eos)
+ # >>> r_hyps
+ # >>> tensor([[3, 2, 1],
+ # >>> [4, 8, 9],
+ # >>> [2, eos, eos]])
+ r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
+ # >>> r_hyps
+ # >>> tensor([[sos, 3, 2, 1],
+ # >>> [sos, 4, 8, 9],
+ # >>> [sos, 2, eos, eos]])
+
+ decoder_out, r_decoder_out, _ = self.decoder(
+ encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight
+ ) # (num_hyps, max_hyps_len, vocab_size)
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+
+ # right to left decoder may be not used during decoding process,
+ # which depends on reverse_weight param.
+ # r_dccoder_out will be 0.0, if reverse_weight is 0.0
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+ return decoder_out, r_decoder_out
diff --git a/modules/wenet_extractor/transformer/attention.py b/modules/wenet_extractor/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..a13c0ce0aa408166a990fba119ba2e167f2ee868
--- /dev/null
+++ b/modules/wenet_extractor/transformer/attention.py
@@ -0,0 +1,326 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor, size
+ (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+ # 1st chunk to ease the onnx export.]
+ # 2. pytorch training
+ if mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, : scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float("inf"))
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+ # 1. onnx(16/-1, -1/-1, 16/0)
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ 1.When applying cross attention between decoder and encoder,
+ the batch padding mask for input is in (#batch, 1, T) shape.
+ 2.When applying self attention of encoder,
+ the mask is in (#batch, T, T) shape.
+ 3.When applying self attention of decoder,
+ the mask is in (#batch, L, L) shape.
+ 4.If the different position in decoder see different block
+ of the encoder, such as Mocha, the passed in mask could be
+ in (#batch, L, T) shape. But there is no such case in current
+ Wenet.
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
+ # and we will always do splitting and
+ # concatnation(this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask), new_cache
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x, zero_triu: bool = False):
+ """Compute relative positinal encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, size).
+ zero_triu (bool): If true, return the lower triangular part of
+ the matrix.
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+
+ zero_pad = torch.zeros(
+ (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
+ )
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ # NOTE(xcsong):
+ # when export onnx model, for 1st chunk, we feed
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
+ # and we will always do splitting and
+ # concatnation(this will simplify onnx export). Note that
+ # it's OK to concat & split zero-shaped tensors(see code below).
+ # when export jit model, for 1st chunk, we always feed
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+ # >>> a = torch.ones((1, 2, 0, 4))
+ # >>> b = torch.ones((1, 2, 3, 4))
+ # >>> c = torch.cat((a, b), dim=2)
+ # >>> torch.equal(b, c) # True
+ # >>> d = torch.split(a, 2, dim=-1)
+ # >>> torch.equal(d[0], d[1]) # True
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ # Remove rel_shift since it is useless in speech recognition,
+ # and it requires special attention for streaming.
+ # matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask), new_cache
diff --git a/modules/wenet_extractor/transformer/cmvn.py b/modules/wenet_extractor/transformer/cmvn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97f8924e7bf0db25d475be79256086682aa9f6a
--- /dev/null
+++ b/modules/wenet_extractor/transformer/cmvn.py
@@ -0,0 +1,51 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import torch
+
+
+class GlobalCMVN(torch.nn.Module):
+ def __init__(self, mean: torch.Tensor, istd: torch.Tensor, norm_var: bool = True):
+ """
+ Args:
+ mean (torch.Tensor): mean stats
+ istd (torch.Tensor): inverse std, std which is 1.0 / std
+ """
+ super().__init__()
+ assert mean.shape == istd.shape
+ self.norm_var = norm_var
+ # The buffer can be accessed from this module using self.mean
+ self.register_buffer("mean", mean)
+ self.register_buffer("istd", istd)
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): (batch, max_len, feat_dim)
+
+ Returns:
+ (torch.Tensor): normalized feature
+ """
+ x = x - self.mean
+ if self.norm_var:
+ x = x * self.istd
+ return x
diff --git a/modules/wenet_extractor/transformer/convolution.py b/modules/wenet_extractor/transformer/convolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..be1148386f8aab6c1b01244f4a84c7e90aba7952
--- /dev/null
+++ b/modules/wenet_extractor/transformer/convolution.py
@@ -0,0 +1,154 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""ConvolutionModule definition."""
+
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True,
+ ):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (int): Whether use causal convolution or not
+ """
+ super().__init__()
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+ # kernel_size should be an odd number for none causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+
+ assert norm in ["batch_norm", "layer_norm"]
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = nn.LayerNorm(channels)
+
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+ (0, 0, 0) meas fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert x.size(2) > self.lorder
+ new_cache = x[:, :, -self.lorder :]
+ else:
+ # It's better we just return None if no cache is required,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
diff --git a/modules/wenet_extractor/transformer/ctc.py b/modules/wenet_extractor/transformer/ctc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbf13a83137642a61b2c46033907ca22d9d6498
--- /dev/null
+++ b/modules/wenet_extractor/transformer/ctc.py
@@ -0,0 +1,95 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import torch
+import torch.nn.functional as F
+
+
+class CTC(torch.nn.Module):
+ """CTC module"""
+
+ def __init__(
+ self,
+ odim: int,
+ encoder_output_size: int,
+ dropout_rate: float = 0.0,
+ reduce: bool = True,
+ ):
+ """Construct CTC module
+ Args:
+ odim: dimension of outputs
+ encoder_output_size: number of encoder projection units
+ dropout_rate: dropout rate (0.0 ~ 1.0)
+ reduce: reduce the CTC loss into a scalar
+ """
+ super().__init__()
+ eprojs = encoder_output_size
+ self.dropout_rate = dropout_rate
+ self.ctc_lo = torch.nn.Linear(eprojs, odim)
+
+ reduction_type = "sum" if reduce else "none"
+ self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate CTC loss.
+
+ Args:
+ hs_pad: batch of padded hidden state sequences (B, Tmax, D)
+ hlens: batch of lengths of hidden state sequences (B)
+ ys_pad: batch of padded character id sequence tensor (B, Lmax)
+ ys_lens: batch of lengths of character sequence (B)
+ """
+ # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
+ ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+ # ys_hat: (B, L, D) -> (L, B, D)
+ ys_hat = ys_hat.transpose(0, 1)
+ ys_hat = ys_hat.log_softmax(2)
+ loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
+ # Batch-size average
+ loss = loss / ys_hat.size(1)
+ return loss
+
+ def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
+ """log_softmax of frame activations
+
+ Args:
+ Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
+ Returns:
+ torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
+ """
+ return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+
+ def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
+ """argmax of frame activations
+
+ Args:
+ torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
+ Returns:
+ torch.Tensor: argmax applied 2d tensor (B, Tmax)
+ """
+ return torch.argmax(self.ctc_lo(hs_pad), dim=2)
diff --git a/modules/wenet_extractor/transformer/decoder.py b/modules/wenet_extractor/transformer/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..83a4238930470167a0898f773c34683b5851ae20
--- /dev/null
+++ b/modules/wenet_extractor/transformer/decoder.py
@@ -0,0 +1,325 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+"""Decoder definition."""
+from typing import Tuple, List, Optional
+
+import torch
+
+from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
+from modules.wenet_extractor.transformer.decoder_layer import DecoderLayer
+from modules.wenet_extractor.transformer.embedding import PositionalEncoding
+from modules.wenet_extractor.transformer.embedding import NoPositionalEncoding
+from modules.wenet_extractor.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward,
+)
+from modules.wenet_extractor.utils.mask import subsequent_mask, make_pad_mask
+
+
+class TransformerDecoder(torch.nn.Module):
+ """Base class of Transfomer decoder module.
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the hidden units number of position-wise feedforward
+ num_blocks: the number of decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before:
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ src_attention: if false, encoder-decoder cross attention is not
+ applied, such as CIF model
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ normalize_before: bool = True,
+ src_attention: bool = True,
+ ):
+ super().__init__()
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ PositionalEncoding(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "none":
+ self.embed = NoPositionalEncoding(attention_dim, positional_dropout_rate)
+ else:
+ raise ValueError(f"only 'embed' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
+ self.use_output_layer = use_output_layer
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ self.num_blocks = num_blocks
+ self.decoders = torch.nn.ModuleList(
+ [
+ DecoderLayer(
+ attention_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ ),
+ (
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ )
+ if src_attention
+ else None
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ )
+ for _ in range(self.num_blocks)
+ ]
+ )
+
+ def forward(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
+ reverse_weight: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+ ys_in_lens: input lengths of this batch (batch)
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
+ with bidirectional decoder
+ reverse_weight: not used in transformer decoder, in order to unify
+ api with bidirectional decode
+ Returns:
+ (tuple): tuple containing:
+ x: decoded token score before softmax (batch, maxlen_out,
+ vocab_size) if use_output_layer is True,
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ maxlen = tgt.size(1)
+ # tgt_mask: (B, 1, L)
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
+ tgt_mask = tgt_mask.to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
+ x, _ = self.embed(tgt)
+ for layer in self.decoders:
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, memory_mask)
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.use_output_layer:
+ x = self.output_layer(x)
+ olens = tgt_mask.sum(1)
+ return x, torch.tensor(0.0), olens
+
+ def forward_one_step(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+ This is only used for decoding.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ """
+ x, _ = self.embed(tgt)
+ new_cache = []
+ for i, decoder in enumerate(self.decoders):
+ if cache is None:
+ c = None
+ else:
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask = decoder(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(x)
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.use_output_layer:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+ return y, new_cache
+
+
+class BiTransformerDecoder(torch.nn.Module):
+ """Base class of Transfomer decoder module.
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the hidden units number of position-wise feedforward
+ num_blocks: the number of decoder blocks
+ r_num_blocks: the number of right to left decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before:
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ r_num_blocks: int = 0,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ normalize_before: bool = True,
+ ):
+ super().__init__()
+ self.left_decoder = TransformerDecoder(
+ vocab_size,
+ encoder_output_size,
+ attention_heads,
+ linear_units,
+ num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ input_layer,
+ use_output_layer,
+ normalize_before,
+ )
+
+ self.right_decoder = TransformerDecoder(
+ vocab_size,
+ encoder_output_size,
+ attention_heads,
+ linear_units,
+ r_num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ input_layer,
+ use_output_layer,
+ normalize_before,
+ )
+
+ def forward(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ r_ys_in_pad: torch.Tensor,
+ reverse_weight: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
+ ys_in_lens: input lengths of this batch (batch)
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
+ used for right to left decoder
+ reverse_weight: used for right to left decoder
+ Returns:
+ (tuple): tuple containing:
+ x: decoded token score before softmax (batch, maxlen_out,
+ vocab_size) if use_output_layer is True,
+ r_x: x: decoded token score (right to left decoder)
+ before softmax (batch, maxlen_out, vocab_size)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
+ r_x = torch.tensor(0.0)
+ if reverse_weight > 0.0:
+ r_x, _, olens = self.right_decoder(
+ memory, memory_mask, r_ys_in_pad, ys_in_lens
+ )
+ return l_x, r_x, olens
+
+ def forward_one_step(
+ self,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ cache: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+ This is only used for decoding.
+ Args:
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ """
+ return self.left_decoder.forward_one_step(
+ memory, memory_mask, tgt, tgt_mask, cache
+ )
diff --git a/modules/wenet_extractor/transformer/decoder_layer.py b/modules/wenet_extractor/transformer/decoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc4171c47d4cccddade323f04ee1730b98feff76
--- /dev/null
+++ b/modules/wenet_extractor/transformer/decoder_layer.py
@@ -0,0 +1,140 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Decoder self-attention layer definition."""
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class DecoderLayer(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ src_attn (torch.nn.Module): Inter-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ If `None` is passed, Inter-attention is not used, such as
+ CIF, GPT, and other decoder only model.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: to use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: nn.Module,
+ src_attn: Optional[nn.Module],
+ feed_forward: nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an DecoderLayer object."""
+ super().__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor,
+ cache: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor
+ (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory
+ (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask
+ (#batch, maxlen_in).
+ cache (torch.Tensor): cached tensors.
+ (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+ # compute only the last frame query keeping dim: max_time_out -> 1
+ assert cache.shape == (
+ tgt.shape[0],
+ tgt.shape[1] - 1,
+ self.size,
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(
+ self.src_attn(x, memory, memory, memory_mask)[0]
+ )
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
diff --git a/modules/wenet_extractor/transformer/embedding.py b/modules/wenet_extractor/transformer/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..55ee113f8d267875f147a55326883c207cc0c646
--- /dev/null
+++ b/modules/wenet_extractor/transformer/embedding.py
@@ -0,0 +1,174 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Positonal Encoding Module."""
+
+import math
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+
+ :param int d_model: embedding dim
+ :param float dropout_rate: dropout rate
+ :param int max_len: maximum input length
+
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ dropout_rate: float,
+ max_len: int = 5000,
+ reverse: bool = False,
+ ):
+ """Construct an PositionalEncoding object."""
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ self.pe = torch.zeros(self.max_len, self.d_model)
+ position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ self.pe[:, 0::2] = torch.sin(position * div_term)
+ self.pe[:, 1::2] = torch.cos(position * div_term)
+ self.pe = self.pe.unsqueeze(0)
+
+ def forward(
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
+ offset (int, torch.tensor): position offset
+
+ Returns:
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ torch.Tensor: for compatibility to RelPositionalEncoding
+ """
+
+ self.pe = self.pe.to(x.device)
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ x = x * self.xscale + pos_emb
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(
+ self, offset: Union[int, torch.Tensor], size: int, apply_dropout: bool = True
+ ) -> torch.Tensor:
+ """For getting encoding in a streaming fashion
+
+ Attention!!!!!
+ we apply dropout only once at the whole utterance level in a none
+ streaming way, but will call this function several times with
+ increasing input size in a streaming scenario, so the dropout will
+ be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ # How to subscript a Union type:
+ # https://github.com/pytorch/pytorch/issues/69434
+ if isinstance(offset, int):
+ assert offset + size < self.max_len
+ pos_emb = self.pe[:, offset : offset + size]
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
+ assert offset + size < self.max_len
+ pos_emb = self.pe[:, offset : offset + size]
+ else: # for batched streaming decoding on GPU
+ assert torch.max(offset) + size < self.max_len
+ index = offset.unsqueeze(1) + torch.arange(0, size).to(
+ offset.device
+ ) # B X T
+ flag = index > 0
+ # remove negative offset
+ index = index * flag
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
+
+ if apply_dropout:
+ pos_emb = self.dropout(pos_emb)
+ return pos_emb
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.pe = self.pe.to(x.device)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(offset, x.size(1), False)
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class NoPositionalEncoding(torch.nn.Module):
+ """No position encoding"""
+
+ def __init__(self, d_model: int, dropout_rate: float):
+ super().__init__()
+ self.d_model = d_model
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ def forward(
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Just return zero vector for interface compatibility"""
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
+ return self.dropout(x), pos_emb
+
+ def position_encoding(
+ self, offset: Union[int, torch.Tensor], size: int
+ ) -> torch.Tensor:
+ return torch.zeros(1, size, self.d_model)
diff --git a/modules/wenet_extractor/transformer/encoder.py b/modules/wenet_extractor/transformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eda3c2cb721804559f5ade930753c7ed9e43172
--- /dev/null
+++ b/modules/wenet_extractor/transformer/encoder.py
@@ -0,0 +1,507 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Encoder definition."""
+from typing import Tuple
+
+import torch
+
+from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
+from modules.wenet_extractor.transformer.attention import (
+ RelPositionMultiHeadedAttention,
+)
+from modules.wenet_extractor.transformer.convolution import ConvolutionModule
+from modules.wenet_extractor.transformer.embedding import PositionalEncoding
+from modules.wenet_extractor.transformer.embedding import RelPositionalEncoding
+from modules.wenet_extractor.transformer.embedding import NoPositionalEncoding
+from modules.wenet_extractor.transformer.encoder_layer import TransformerEncoderLayer
+from modules.wenet_extractor.transformer.encoder_layer import ConformerEncoderLayer
+from modules.wenet_extractor.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward,
+)
+from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling4
+from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling6
+from modules.wenet_extractor.transformer.subsampling import Conv2dSubsampling8
+from modules.wenet_extractor.transformer.subsampling import LinearNoSubsampling
+from modules.wenet_extractor.utils.common import get_activation
+from modules.wenet_extractor.utils.mask import make_pad_mask
+from modules.wenet_extractor.utils.mask import add_optional_chunk_mask
+
+
+class BaseEncoder(torch.nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "abs_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ ):
+ """
+ Args:
+ input_size (int): input dim
+ output_size (int): dimension of attention
+ attention_heads (int): the number of heads of multi head attention
+ linear_units (int): the hidden units number of position-wise feed
+ forward
+ num_blocks (int): the number of decoder blocks
+ dropout_rate (float): dropout rate
+ attention_dropout_rate (float): dropout rate in attention
+ positional_dropout_rate (float): dropout rate after adding
+ positional encoding
+ input_layer (str): input layer type.
+ optional [linear, conv2d, conv2d6, conv2d8]
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
+ opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+ normalize_before (bool):
+ True: use layer_norm before each sub-block of a layer.
+ False: use layer_norm after each sub-block of a layer.
+ static_chunk_size (int): chunk size for static chunk training and
+ decoding
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
+ training or not, You can only use fixed chunk(chunk_size > 0)
+ or dyanmic chunk size(use_dynamic_chunk = True)
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
+ dynamic chunk training
+ """
+ super().__init__()
+ self._output_size = output_size
+
+ if pos_enc_layer_type == "abs_pos":
+ pos_enc_class = PositionalEncoding
+ elif pos_enc_layer_type == "rel_pos":
+ pos_enc_class = RelPositionalEncoding
+ elif pos_enc_layer_type == "no_pos":
+ pos_enc_class = NoPositionalEncoding
+ else:
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+ if input_layer == "linear":
+ subsampling_class = LinearNoSubsampling
+ elif input_layer == "conv2d":
+ subsampling_class = Conv2dSubsampling4
+ elif input_layer == "conv2d6":
+ subsampling_class = Conv2dSubsampling6
+ elif input_layer == "conv2d8":
+ subsampling_class = Conv2dSubsampling8
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+
+ self.global_cmvn = global_cmvn
+ self.embed = subsampling_class(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ xs_lens: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embed positions in tensor.
+
+ Args:
+ xs: padded input tensor (B, T, D)
+ xs_lens: input length (B)
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+ encoder output tensor xs, and subsampled masks
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+ masks: torch.Tensor batch padding mask after subsample
+ (B, 1, T' ~= T/subsample_rate)
+ """
+ T = xs.size(1)
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ xs, pos_emb, masks = self.embed(xs, masks)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(
+ xs,
+ masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks,
+ )
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks
+
+ def forward_chunk(
+ self,
+ xs: torch.Tensor,
+ offset: int,
+ required_cache_size: int,
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """ Forward just one chunk
+
+ Args:
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
+ where `time == (chunk_size - 1) * subsample_rate + \
+ subsample.right_context + 1`
+ offset (int): current offset in encoder output time stamp
+ required_cache_size (int): cache size required for next chunk
+ compuation
+ >=0: actual cache size
+ <0: means all history cache is required
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (elayers, b=1, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+
+ Returns:
+ torch.Tensor: output of current input xs,
+ with shape (b=1, chunk_size, hidden-dim).
+ torch.Tensor: new attention cache required for next chunk, with
+ dynamic shape (elayers, head, ?, d_k * 2)
+ depending on required_cache_size.
+ torch.Tensor: new conformer cnn cache required for next chunk, with
+ same shape as the original cnn_cache.
+
+ """
+ assert xs.size(0) == 1
+ # tmp_masks is just for interface compatibility
+ tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
+ tmp_masks = tmp_masks.unsqueeze(1)
+ if self.global_cmvn is not None:
+ xs = self.global_cmvn(xs)
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
+ chunk_size = xs.size(1)
+ attention_key_size = cache_t1 + chunk_size
+ pos_emb = self.embed.position_encoding(
+ offset=offset - cache_t1, size=attention_key_size
+ )
+ if required_cache_size < 0:
+ next_cache_start = 0
+ elif required_cache_size == 0:
+ next_cache_start = attention_key_size
+ else:
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
+ r_att_cache = []
+ r_cnn_cache = []
+ for i, layer in enumerate(self.encoders):
+ # NOTE(xcsong): Before layer.forward
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
+ xs, _, new_att_cache, new_cnn_cache = layer(
+ xs,
+ att_mask,
+ pos_emb,
+ att_cache=att_cache[i : i + 1] if elayers > 0 else att_cache,
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
+ )
+ # NOTE(xcsong): After layer.forward
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
+ # ? may be larger than cache_t1, it depends on required_cache_size
+ r_att_cache = torch.cat(r_att_cache, dim=0)
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
+
+ return (xs, r_att_cache, r_cnn_cache)
+
+ def forward_chunk_by_chunk(
+ self,
+ xs: torch.Tensor,
+ decoding_chunk_size: int,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward input chunk by chunk with chunk_size like a streaming
+ fashion
+
+ Here we should pay special attention to computation cache in the
+ streaming style forward chunk by chunk. Three things should be taken
+ into account for computation in the current network:
+ 1. transformer/conformer encoder layers output cache
+ 2. convolution in conformer
+ 3. convolution in subsampling
+
+ However, we don't implement subsampling cache for:
+ 1. We can control subsampling module to output the right result by
+ overlapping input instead of cache left context, even though it
+ wastes some computation, but subsampling only takes a very
+ small fraction of computation in the whole model.
+ 2. Typically, there are several covolution layers with subsampling
+ in subsampling module, it is tricky and complicated to do cache
+ with different convolution layers with different subsampling
+ rate.
+ 3. Currently, nn.Sequential is used to stack all the convolution
+ layers in subsampling, we need to rewrite it to make it work
+ with cache, which is not prefered.
+ Args:
+ xs (torch.Tensor): (1, max_len, dim)
+ chunk_size (int): decoding chunk size
+ """
+ assert decoding_chunk_size > 0
+ # The model is trained by static or dynamic chunk
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
+ subsampling = self.embed.subsampling_rate
+ context = self.embed.right_context + 1 # Add current frame
+ stride = subsampling * decoding_chunk_size
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
+ num_frames = xs.size(1)
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+ outputs = []
+ offset = 0
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+
+ # Feed forward overlap input step by step
+ for cur in range(0, num_frames - context + 1, stride):
+ end = min(cur + decoding_window, num_frames)
+ chunk_xs = xs[:, cur:end, :]
+ (y, att_cache, cnn_cache) = self.forward_chunk(
+ chunk_xs, offset, required_cache_size, att_cache, cnn_cache
+ )
+ outputs.append(y)
+ offset += y.size(1)
+ ys = torch.cat(outputs, 1)
+ masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
+ return ys, masks
+
+
+class TransformerEncoder(BaseEncoder):
+ """Transformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "abs_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ ):
+ """Construct TransformerEncoder
+
+ See Encoder for the meaning of each parameter.
+ """
+ super().__init__(
+ input_size,
+ output_size,
+ attention_heads,
+ linear_units,
+ num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ attention_dropout_rate,
+ input_layer,
+ pos_enc_layer_type,
+ normalize_before,
+ static_chunk_size,
+ use_dynamic_chunk,
+ global_cmvn,
+ use_dynamic_left_chunk,
+ )
+ self.encoders = torch.nn.ModuleList(
+ [
+ TransformerEncoderLayer(
+ output_size,
+ MultiHeadedAttention(
+ attention_heads, output_size, attention_dropout_rate
+ ),
+ PositionwiseFeedForward(output_size, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+
+
+class ConformerEncoder(BaseEncoder):
+ """Conformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ pos_enc_layer_type: str = "rel_pos",
+ normalize_before: bool = True,
+ static_chunk_size: int = 0,
+ use_dynamic_chunk: bool = False,
+ global_cmvn: torch.nn.Module = None,
+ use_dynamic_left_chunk: bool = False,
+ positionwise_conv_kernel_size: int = 1,
+ macaron_style: bool = True,
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ ):
+ """Construct ConformerEncoder
+
+ Args:
+ input_size to use_dynamic_chunk, see in BaseEncoder
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
+ conv1d layer.
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ selfattention_layer_type (str): Encoder attention layer type,
+ the parameter has no effect now, it's just for configure
+ compatibility.
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+ """
+ super().__init__(
+ input_size,
+ output_size,
+ attention_heads,
+ linear_units,
+ num_blocks,
+ dropout_rate,
+ positional_dropout_rate,
+ attention_dropout_rate,
+ input_layer,
+ pos_enc_layer_type,
+ normalize_before,
+ static_chunk_size,
+ use_dynamic_chunk,
+ global_cmvn,
+ use_dynamic_left_chunk,
+ )
+ activation = get_activation(activation_type)
+
+ # self-attention module definition
+ if pos_enc_layer_type != "rel_pos":
+ encoder_selfattn_layer = MultiHeadedAttention
+ else:
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ # feed-forward module definition
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ # convolution module definition
+ convolution_layer = ConvolutionModule
+ convolution_layer_args = (
+ output_size,
+ cnn_module_kernel,
+ activation,
+ cnn_module_norm,
+ causal,
+ )
+
+ self.encoders = torch.nn.ModuleList(
+ [
+ ConformerEncoderLayer(
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ (
+ positionwise_layer(*positionwise_layer_args)
+ if macaron_style
+ else None
+ ),
+ (
+ convolution_layer(*convolution_layer_args)
+ if use_cnn_module
+ else None
+ ),
+ dropout_rate,
+ normalize_before,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
diff --git a/modules/wenet_extractor/transformer/encoder_layer.py b/modules/wenet_extractor/transformer/encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2ebb2f4fc9cdb5f6cd6c2401597638aa27fa4c
--- /dev/null
+++ b/modules/wenet_extractor/transformer/encoder_layer.py
@@ -0,0 +1,242 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Encoder self-attention layer definition."""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+
+class TransformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: to use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ dropout_rate: float,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): just for interface compatibility
+ to ConformerEncoderLayer
+ mask_pad (torch.Tensor): does not used in transformer layer,
+ just for unified api with conformer.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2), not used here, it's for interface
+ compatibility to ConformerEncoderLayer.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
+
+ """
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ return x, mask, new_att_cache, fake_cnn_cache
+
+
+class ConformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvlutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: Optional[nn.Module] = None,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float = 0.1,
+ normalize_before: bool = True,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
+ self.norm_final = nn.LayerNorm(
+ size, eps=1e-5
+ ) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): positional encoding, must not be None
+ for ConformerEncoderLayer.
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
+ (#batch, 1,time), (0, 0, 0) means fake mask.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+ torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
+ """
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ # Fake new cnn cache here, and then change it in conv_module
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+ x = residual + self.dropout(x)
+
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
diff --git a/modules/wenet_extractor/transformer/label_smoothing_loss.py b/modules/wenet_extractor/transformer/label_smoothing_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c2b1cae5923fef66665adf9f84604af03c1ea9a
--- /dev/null
+++ b/modules/wenet_extractor/transformer/label_smoothing_loss.py
@@ -0,0 +1,106 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Label smoothing module."""
+
+import torch
+from torch import nn
+
+
+class LabelSmoothingLoss(nn.Module):
+ """Label-smoothing loss.
+
+ In a standard CE loss, the label's data distribution is:
+ [0,1,2] ->
+ [
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0],
+ ]
+
+ In the smoothing version CE Loss,some probabilities
+ are taken from the true label prob (1.0) and are divided
+ among other labels.
+
+ e.g.
+ smoothing=0.1
+ [0,1,2] ->
+ [
+ [0.9, 0.05, 0.05],
+ [0.05, 0.9, 0.05],
+ [0.05, 0.05, 0.9],
+ ]
+
+ Args:
+ size (int): the number of class
+ padding_idx (int): padding class id which will be ignored for loss
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
+ normalize_length (bool):
+ normalize loss by sequence length if True
+ normalize loss by batch size if False
+ """
+
+ def __init__(
+ self,
+ size: int,
+ padding_idx: int,
+ smoothing: float,
+ normalize_length: bool = False,
+ ):
+ """Construct an LabelSmoothingLoss object."""
+ super(LabelSmoothingLoss, self).__init__()
+ self.criterion = nn.KLDivLoss(reduction="none")
+ self.padding_idx = padding_idx
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.size = size
+ self.normalize_length = normalize_length
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Compute loss between x and target.
+
+ The model outputs and data labels tensors are flatten to
+ (batch*seqlen, class) shape and a mask is applied to the
+ padding part which should not be calculated for loss.
+
+ Args:
+ x (torch.Tensor): prediction (batch, seqlen, class)
+ target (torch.Tensor):
+ target signal masked with self.padding_id (batch, seqlen)
+ Returns:
+ loss (torch.Tensor) : The KL loss, scalar float value
+ """
+ assert x.size(2) == self.size
+ batch_size = x.size(0)
+ x = x.view(-1, self.size)
+ target = target.view(-1)
+ # use zeros_like instead of torch.no_grad() for true_dist,
+ # since no_grad() can not be exported by JIT
+ true_dist = torch.zeros_like(x)
+ true_dist.fill_(self.smoothing / (self.size - 1))
+ ignore = target == self.padding_idx # (B,)
+ total = len(target) - ignore.sum().item()
+ target = target.masked_fill(ignore, 0) # avoid -1 index
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
+ denom = total if self.normalize_length else batch_size
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
diff --git a/modules/wenet_extractor/transformer/positionwise_feed_forward.py b/modules/wenet_extractor/transformer/positionwise_feed_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaaefdc61d392d44e503458deb490690b3183a03
--- /dev/null
+++ b/modules/wenet_extractor/transformer/positionwise_feed_forward.py
@@ -0,0 +1,63 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+ FeedForward are appied on each position of the sequence.
+ The output dim is same with the input dim.
+
+ Args:
+ idim (int): Input dimenstion.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ ):
+ """Construct a PositionwiseFeedForward object."""
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
+ self.activation = activation
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+ """
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
diff --git a/modules/wenet_extractor/transformer/subsampling.py b/modules/wenet_extractor/transformer/subsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffbb5dae0dd013c916cc4a4266cdc838ba48a84f
--- /dev/null
+++ b/modules/wenet_extractor/transformer/subsampling.py
@@ -0,0 +1,257 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+
+"""Subsampling layer definition."""
+
+from typing import Tuple, Union
+
+import torch
+
+
+class BaseSubsampling(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def position_encoding(
+ self, offset: Union[int, torch.Tensor], size: int
+ ) -> torch.Tensor:
+ return self.pos_enc.position_encoding(offset, size)
+
+
+class LinearNoSubsampling(BaseSubsampling):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
+ ):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.LayerNorm(odim, eps=1e-5),
+ torch.nn.Dropout(dropout_rate),
+ )
+ self.pos_enc = pos_enc_class
+ self.right_context = 0
+ self.subsampling_rate = 1
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Input x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: linear input tensor (#batch, time', odim),
+ where time' = time .
+ torch.Tensor: linear input mask (#batch, 1, time'),
+ where time' = time .
+
+ """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask
+
+
+class Conv2dSubsampling4(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
+ ):
+ """Construct an Conv2dSubsampling4 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.out = torch.nn.Sequential(
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+ )
+ self.pos_enc = pos_enc_class
+ # The right context for every conv layer is computed by:
+ # (kernel_size - 1) * frame_rate_of_this_layer
+ self.subsampling_rate = 4
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
+ self.right_context = 6
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 4.
+ torch.Tensor: positional encoding
+
+ """
+ x = x.unsqueeze(1) # (b, c=1, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
+
+
+class Conv2dSubsampling6(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/6 length).
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+ pos_enc (torch.nn.Module): Custom position encoding layer.
+ """
+
+ def __init__(
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
+ ):
+ """Construct an Conv2dSubsampling6 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 5, 3),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
+ self.pos_enc = pos_enc_class
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
+ self.subsampling_rate = 6
+ self.right_context = 10
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 6.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 6.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
+
+
+class Conv2dSubsampling8(BaseSubsampling):
+ """Convolutional 2D subsampling (to 1/8 length).
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
+ ):
+ """Construct an Conv2dSubsampling8 object."""
+ super().__init__()
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(odim, odim, 3, 2),
+ torch.nn.ReLU(),
+ )
+ self.linear = torch.nn.Linear(
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim
+ )
+ self.pos_enc = pos_enc_class
+ self.subsampling_rate = 8
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
+ self.right_context = 14
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, idim).
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 8.
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
+ where time' = time // 8.
+ torch.Tensor: positional encoding
+ """
+ x = x.unsqueeze(1) # (b, c, t, f)
+ x = self.conv(x)
+ b, c, t, f = x.size()
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
diff --git a/modules/wenet_extractor/transformer/swish.py b/modules/wenet_extractor/transformer/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..8728089430d0c20fa1575d12e3b6573417b4448c
--- /dev/null
+++ b/modules/wenet_extractor/transformer/swish.py
@@ -0,0 +1,33 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+"""Swish() activation function for Conformer."""
+
+import torch
+
+
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return Swish activation function."""
+ return x * torch.sigmoid(x)
diff --git a/modules/wenet_extractor/utils/__init__.py b/modules/wenet_extractor/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/wenet_extractor/utils/checkpoint.py b/modules/wenet_extractor/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..f17784b1e1b41284191b0caff7570dd2066f7e91
--- /dev/null
+++ b/modules/wenet_extractor/utils/checkpoint.py
@@ -0,0 +1,113 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import logging
+import os
+import re
+
+import yaml
+import torch
+from collections import OrderedDict
+
+import datetime
+
+
+def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
+ if torch.cuda.is_available():
+ logging.info("Checkpoint: loading from checkpoint %s for GPU" % path)
+ checkpoint = torch.load(path)
+ else:
+ logging.info("Checkpoint: loading from checkpoint %s for CPU" % path)
+ checkpoint = torch.load(path, map_location="cpu")
+ model.load_state_dict(checkpoint, strict=False)
+ info_path = re.sub(".pt$", ".yaml", path)
+ configs = {}
+ if os.path.exists(info_path):
+ with open(info_path, "r") as fin:
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
+ return configs
+
+
+def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
+ """
+ Args:
+ infos (dict or None): any info you want to save.
+ """
+ logging.info("Checkpoint: save to checkpoint %s" % path)
+ if isinstance(model, torch.nn.DataParallel):
+ state_dict = model.module.state_dict()
+ elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+ torch.save(state_dict, path)
+ info_path = re.sub(".pt$", ".yaml", path)
+ if infos is None:
+ infos = {}
+ infos["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
+ with open(info_path, "w") as fout:
+ data = yaml.dump(infos)
+ fout.write(data)
+
+
+def filter_modules(model_state_dict, modules):
+ new_mods = []
+ incorrect_mods = []
+ mods_model = model_state_dict.keys()
+ for mod in modules:
+ if any(key.startswith(mod) for key in mods_model):
+ new_mods += [mod]
+ else:
+ incorrect_mods += [mod]
+ if incorrect_mods:
+ logging.warning(
+ "module(s) %s don't match or (partially match) "
+ "available modules in model.",
+ incorrect_mods,
+ )
+ logging.warning("for information, the existing modules in model are:")
+ logging.warning("%s", mods_model)
+
+ return new_mods
+
+
+def load_trained_modules(model: torch.nn.Module, args: None):
+ # Load encoder modules with pre-trained model(s).
+ enc_model_path = args.enc_init
+ enc_modules = args.enc_init_mods
+ main_state_dict = model.state_dict()
+ logging.warning("model(s) found for pre-initialization")
+ if os.path.isfile(enc_model_path):
+ logging.info("Checkpoint: loading from checkpoint %s for CPU" % enc_model_path)
+ model_state_dict = torch.load(enc_model_path, map_location="cpu")
+ modules = filter_modules(model_state_dict, enc_modules)
+ partial_state_dict = OrderedDict()
+ for key, value in model_state_dict.items():
+ if any(key.startswith(m) for m in modules):
+ partial_state_dict[key] = value
+ main_state_dict.update(partial_state_dict)
+ else:
+ logging.warning("model was not found : %s", enc_model_path)
+
+ model.load_state_dict(main_state_dict)
+ configs = {}
+ return configs
diff --git a/modules/wenet_extractor/utils/cmvn.py b/modules/wenet_extractor/utils/cmvn.py
new file mode 100644
index 0000000000000000000000000000000000000000..483adc1f709786ae83f570faf10ee1a1ded85e52
--- /dev/null
+++ b/modules/wenet_extractor/utils/cmvn.py
@@ -0,0 +1,103 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import json
+import math
+
+import numpy as np
+
+
+def _load_json_cmvn(json_cmvn_file):
+ """Load the json format cmvn stats file and calculate cmvn
+
+ Args:
+ json_cmvn_file: cmvn stats file in json format
+
+ Returns:
+ a numpy array of [means, vars]
+ """
+ with open(json_cmvn_file) as f:
+ cmvn_stats = json.load(f)
+
+ means = cmvn_stats["mean_stat"]
+ variance = cmvn_stats["var_stat"]
+ count = cmvn_stats["frame_num"]
+ for i in range(len(means)):
+ means[i] /= count
+ variance[i] = variance[i] / count - means[i] * means[i]
+ if variance[i] < 1.0e-20:
+ variance[i] = 1.0e-20
+ variance[i] = 1.0 / math.sqrt(variance[i])
+ cmvn = np.array([means, variance])
+ return cmvn
+
+
+def _load_kaldi_cmvn(kaldi_cmvn_file):
+ """Load the kaldi format cmvn stats file and calculate cmvn
+
+ Args:
+ kaldi_cmvn_file: kaldi text style global cmvn file, which
+ is generated by:
+ compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
+
+ Returns:
+ a numpy array of [means, vars]
+ """
+ means = []
+ variance = []
+ with open(kaldi_cmvn_file, "r") as fid:
+ # kaldi binary file start with '\0B'
+ if fid.read(2) == "\0B":
+ logging.error(
+ "kaldi cmvn binary file is not supported, please "
+ "recompute it by: compute-cmvn-stats --binary=false "
+ " scp:feats.scp global_cmvn"
+ )
+ sys.exit(1)
+ fid.seek(0)
+ arr = fid.read().split()
+ assert arr[0] == "["
+ assert arr[-2] == "0"
+ assert arr[-1] == "]"
+ feat_dim = int((len(arr) - 2 - 2) / 2)
+ for i in range(1, feat_dim + 1):
+ means.append(float(arr[i]))
+ count = float(arr[feat_dim + 1])
+ for i in range(feat_dim + 2, 2 * feat_dim + 2):
+ variance.append(float(arr[i]))
+
+ for i in range(len(means)):
+ means[i] /= count
+ variance[i] = variance[i] / count - means[i] * means[i]
+ if variance[i] < 1.0e-20:
+ variance[i] = 1.0e-20
+ variance[i] = 1.0 / math.sqrt(variance[i])
+ cmvn = np.array([means, variance])
+ return cmvn
+
+
+def load_cmvn(cmvn_file, is_json):
+ if is_json:
+ cmvn = _load_json_cmvn(cmvn_file)
+ else:
+ cmvn = _load_kaldi_cmvn(cmvn_file)
+ return cmvn[0], cmvn[1]
diff --git a/modules/wenet_extractor/utils/common.py b/modules/wenet_extractor/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..8900ce4cd279e87829807c14a14e2ca38841b89d
--- /dev/null
+++ b/modules/wenet_extractor/utils/common.py
@@ -0,0 +1,266 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+"""Unility functions for Transformer."""
+
+import math
+from typing import List, Tuple
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+IGNORE_ID = -1
+
+
+def pad_list(xs: List[torch.Tensor], pad_value: int):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ n_batch = len(xs)
+ max_len = max([x.size(0) for x in xs])
+ pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
+ pad = pad.fill_(pad_value)
+ for i in range(n_batch):
+ pad[i, : xs[i].size(0)] = xs[i]
+
+ return pad
+
+
+def add_blank(ys_pad: torch.Tensor, blank: int, ignore_id: int) -> torch.Tensor:
+ """Prepad blank for transducer predictor
+
+ Args:
+ ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
+ blank (int): index of
+
+ Returns:
+ ys_in (torch.Tensor) : (B, Lmax + 1)
+
+ Examples:
+ >>> blank = 0
+ >>> ignore_id = -1
+ >>> ys_pad
+ tensor([[ 1, 2, 3, 4, 5],
+ [ 4, 5, 6, -1, -1],
+ [ 7, 8, 9, -1, -1]], dtype=torch.int32)
+ >>> ys_in = add_blank(ys_pad, 0, -1)
+ >>> ys_in
+ tensor([[0, 1, 2, 3, 4, 5],
+ [0, 4, 5, 6, 0, 0],
+ [0, 7, 8, 9, 0, 0]])
+ """
+ bs = ys_pad.size(0)
+ _blank = torch.tensor(
+ [blank], dtype=torch.long, requires_grad=False, device=ys_pad.device
+ )
+ _blank = _blank.repeat(bs).unsqueeze(1) # [bs,1]
+ out = torch.cat([_blank, ys_pad], dim=1) # [bs, Lmax+1]
+ return torch.where(out == ignore_id, blank, out)
+
+
+def add_sos_eos(
+ ys_pad: torch.Tensor, sos: int, eos: int, ignore_id: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add and labels.
+
+ Args:
+ ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
+ sos (int): index of
+ eos (int): index of
+ ignore_id (int): index of padding
+
+ Returns:
+ ys_in (torch.Tensor) : (B, Lmax + 1)
+ ys_out (torch.Tensor) : (B, Lmax + 1)
+
+ Examples:
+ >>> sos_id = 10
+ >>> eos_id = 11
+ >>> ignore_id = -1
+ >>> ys_pad
+ tensor([[ 1, 2, 3, 4, 5],
+ [ 4, 5, 6, -1, -1],
+ [ 7, 8, 9, -1, -1]], dtype=torch.int32)
+ >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
+ >>> ys_in
+ tensor([[10, 1, 2, 3, 4, 5],
+ [10, 4, 5, 6, 11, 11],
+ [10, 7, 8, 9, 11, 11]])
+ >>> ys_out
+ tensor([[ 1, 2, 3, 4, 5, 11],
+ [ 4, 5, 6, 11, -1, -1],
+ [ 7, 8, 9, 11, -1, -1]])
+ """
+ _sos = torch.tensor(
+ [sos], dtype=torch.long, requires_grad=False, device=ys_pad.device
+ )
+ _eos = torch.tensor(
+ [eos], dtype=torch.long, requires_grad=False, device=ys_pad.device
+ )
+ ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
+ ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
+ ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
+ return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
+
+
+def reverse_pad_list(
+ ys_pad: torch.Tensor, ys_lens: torch.Tensor, pad_value: float = -1.0
+) -> torch.Tensor:
+ """Reverse padding for the list of tensors.
+
+ Args:
+ ys_pad (tensor): The padded tensor (B, Tokenmax).
+ ys_lens (tensor): The lens of token seqs (B)
+ pad_value (int): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tokenmax).
+
+ Examples:
+ >>> x
+ tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+ >>> pad_list(x, 0)
+ tensor([[4, 3, 2, 1],
+ [7, 6, 5, 0],
+ [9, 8, 0, 0]])
+
+ """
+ r_ys_pad = pad_sequence(
+ [(torch.flip(y.int()[:i], [0])) for y, i in zip(ys_pad, ys_lens)],
+ True,
+ pad_value,
+ )
+ return r_ys_pad
+
+
+def th_accuracy(
+ pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
+) -> float:
+ """Calculate accuracy.
+
+ Args:
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ float: Accuracy value (0.0 - 1.0).
+
+ """
+ pad_pred = pad_outputs.view(
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
+ ).argmax(2)
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
+ )
+ denominator = torch.sum(mask)
+ return float(numerator) / float(denominator)
+
+
+def get_rnn(rnn_type: str) -> torch.nn.Module:
+ assert rnn_type in ["rnn", "lstm", "gru"]
+ if rnn_type == "rnn":
+ return torch.nn.RNN
+ elif rnn_type == "lstm":
+ return torch.nn.LSTM
+ else:
+ return torch.nn.GRU
+
+
+def get_activation(act):
+ """Return activation function."""
+ # Lazy load to avoid unused import
+ from modules.wenet_extractor.transformer.swish import Swish
+
+ activation_funcs = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": getattr(torch.nn, "SiLU", Swish),
+ "gelu": torch.nn.GELU,
+ }
+
+ return activation_funcs[act]()
+
+
+def get_subsample(config):
+ input_layer = config["encoder_conf"]["input_layer"]
+ assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
+ if input_layer == "conv2d":
+ return 4
+ elif input_layer == "conv2d6":
+ return 6
+ elif input_layer == "conv2d8":
+ return 8
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
+ new_hyp: List[int] = []
+ cur = 0
+ while cur < len(hyp):
+ if hyp[cur] != 0:
+ new_hyp.append(hyp[cur])
+ prev = cur
+ while cur < len(hyp) and hyp[cur] == hyp[prev]:
+ cur += 1
+ return new_hyp
+
+
+def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
+ new_hyp: List[int] = []
+ cur = 0
+ while cur < len(hyp):
+ new_hyp.append(hyp[cur])
+ prev = cur
+ cur += 1
+ while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0:
+ new_hyp.append(0)
+ cur += 1
+ return new_hyp
+
+
+def log_add(args: List[int]) -> float:
+ """
+ Stable log add
+ """
+ if all(a == -float("inf") for a in args):
+ return -float("inf")
+ a_max = max(args)
+ lsp = math.log(sum(math.exp(a - a_max) for a in args))
+ return a_max + lsp
diff --git a/modules/wenet_extractor/utils/config.py b/modules/wenet_extractor/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f7ab928eec9d18d1af46655f9f324a7ac1eb9dc
--- /dev/null
+++ b/modules/wenet_extractor/utils/config.py
@@ -0,0 +1,48 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+
+import copy
+
+
+def override_config(configs, override_list):
+ new_configs = copy.deepcopy(configs)
+ for item in override_list:
+ arr = item.split()
+ if len(arr) != 2:
+ print(f"the overrive {item} format not correct, skip it")
+ continue
+ keys = arr[0].split(".")
+ s_configs = new_configs
+ for i, key in enumerate(keys):
+ if key not in s_configs:
+ print(f"the overrive {item} format not correct, skip it")
+ if i == len(keys) - 1:
+ param_type = type(s_configs[key])
+ if param_type != bool:
+ s_configs[key] = param_type(arr[1])
+ else:
+ s_configs[key] = arr[1] in ["true", "True"]
+ print(f"override {arr[0]} with {arr[1]}")
+ else:
+ s_configs = s_configs[key]
+ return new_configs
diff --git a/modules/wenet_extractor/utils/ctc_util.py b/modules/wenet_extractor/utils/ctc_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..03154248bb087d930a9cc579e73d782756cd6ea0
--- /dev/null
+++ b/modules/wenet_extractor/utils/ctc_util.py
@@ -0,0 +1,96 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import numpy as np
+import torch
+
+
+def insert_blank(label, blank_id=0):
+ """Insert blank token between every two label token."""
+ label = np.expand_dims(label, 1)
+ blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
+ label = np.concatenate([blanks, label], axis=1)
+ label = label.reshape(-1)
+ label = np.append(label, label[0])
+ return label
+
+
+def forced_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
+ """ctc forced alignment.
+
+ Args:
+ torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
+ torch.Tensor y: id sequence tensor 1d tensor (L)
+ int blank_id: blank symbol index
+ Returns:
+ torch.Tensor: alignment result
+ """
+ y_insert_blank = insert_blank(y, blank_id)
+
+ log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
+ log_alpha = log_alpha - float("inf") # log of zero
+ state_path = (
+ torch.zeros((ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
+ ) # state path
+
+ # init start state
+ log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
+ log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]
+
+ for t in range(1, ctc_probs.size(0)):
+ for s in range(len(y_insert_blank)):
+ if (
+ y_insert_blank[s] == blank_id
+ or s < 2
+ or y_insert_blank[s] == y_insert_blank[s - 2]
+ ):
+ candidates = torch.tensor(
+ [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]
+ )
+ prev_state = [s, s - 1]
+ else:
+ candidates = torch.tensor(
+ [
+ log_alpha[t - 1, s],
+ log_alpha[t - 1, s - 1],
+ log_alpha[t - 1, s - 2],
+ ]
+ )
+ prev_state = [s, s - 1, s - 2]
+ log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
+ state_path[t, s] = prev_state[torch.argmax(candidates)]
+
+ state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
+
+ candidates = torch.tensor(
+ [log_alpha[-1, len(y_insert_blank) - 1], log_alpha[-1, len(y_insert_blank) - 2]]
+ )
+ final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
+ state_seq[-1] = final_state[torch.argmax(candidates)]
+ for t in range(ctc_probs.size(0) - 2, -1, -1):
+ state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
+
+ output_alignment = []
+ for t in range(0, ctc_probs.size(0)):
+ output_alignment.append(y_insert_blank[state_seq[t, 0]])
+
+ return output_alignment
diff --git a/modules/wenet_extractor/utils/executor.py b/modules/wenet_extractor/utils/executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d55549fec4f4d350e56093f3995ac93a31ff029
--- /dev/null
+++ b/modules/wenet_extractor/utils/executor.py
@@ -0,0 +1,163 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import logging
+from contextlib import nullcontext
+
+# if your python version < 3.7 use the below one
+# from contextlib import suppress as nullcontext
+import torch
+from torch.nn.utils import clip_grad_norm_
+
+
+class Executor:
+ def __init__(self):
+ self.step = 0
+
+ def train(
+ self, model, optimizer, scheduler, data_loader, device, writer, args, scaler
+ ):
+ """Train one epoch"""
+ model.train()
+ clip = args.get("grad_clip", 50.0)
+ log_interval = args.get("log_interval", 10)
+ rank = args.get("rank", 0)
+ epoch = args.get("epoch", 0)
+ accum_grad = args.get("accum_grad", 1)
+ is_distributed = args.get("is_distributed", True)
+ use_amp = args.get("use_amp", False)
+ logging.info(
+ "using accumulate grad, new batch size is {} times"
+ " larger than before".format(accum_grad)
+ )
+ if use_amp:
+ assert scaler is not None
+ # A context manager to be used in conjunction with an instance of
+ # torch.nn.parallel.DistributedDataParallel to be able to train
+ # with uneven inputs across participating processes.
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ model_context = model.join
+ else:
+ model_context = nullcontext
+ num_seen_utts = 0
+ with model_context():
+ for batch_idx, batch in enumerate(data_loader):
+ key, feats, target, feats_lengths, target_lengths = batch
+ feats = feats.to(device)
+ target = target.to(device)
+ feats_lengths = feats_lengths.to(device)
+ target_lengths = target_lengths.to(device)
+ num_utts = target_lengths.size(0)
+ if num_utts == 0:
+ continue
+ context = None
+ # Disable gradient synchronizations across DDP processes.
+ # Within this context, gradients will be accumulated on module
+ # variables, which will later be synchronized.
+ if is_distributed and batch_idx % accum_grad != 0:
+ context = model.no_sync
+ # Used for single gpu training and DDP gradient synchronization
+ # processes.
+ else:
+ context = nullcontext
+ with context():
+ # autocast context
+ # The more details about amp can be found in
+ # https://pytorch.org/docs/stable/notes/amp_examples.html
+ with torch.cuda.amp.autocast(scaler is not None):
+ loss_dict = model(feats, feats_lengths, target, target_lengths)
+ loss = loss_dict["loss"] / accum_grad
+ if use_amp:
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ num_seen_utts += num_utts
+ if batch_idx % accum_grad == 0:
+ if rank == 0 and writer is not None:
+ writer.add_scalar("train_loss", loss, self.step)
+ # Use mixed precision training
+ if use_amp:
+ scaler.unscale_(optimizer)
+ grad_norm = clip_grad_norm_(model.parameters(), clip)
+ # Must invoke scaler.update() if unscale_() is used in
+ # the iteration to avoid the following error:
+ # RuntimeError: unscale_() has already been called
+ # on this optimizer since the last update().
+ # We don't check grad here since that if the gradient
+ # has inf/nan values, scaler.step will skip
+ # optimizer.step().
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ grad_norm = clip_grad_norm_(model.parameters(), clip)
+ if torch.isfinite(grad_norm):
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+ self.step += 1
+ if batch_idx % log_interval == 0:
+ lr = optimizer.param_groups[0]["lr"]
+ log_str = "TRAIN Batch {}/{} loss {:.6f} ".format(
+ epoch, batch_idx, loss.item() * accum_grad
+ )
+ for name, value in loss_dict.items():
+ if name != "loss" and value is not None:
+ log_str += "{} {:.6f} ".format(name, value.item())
+ log_str += "lr {:.8f} rank {}".format(lr, rank)
+ logging.debug(log_str)
+
+ def cv(self, model, data_loader, device, args):
+ """Cross validation on"""
+ model.eval()
+ rank = args.get("rank", 0)
+ epoch = args.get("epoch", 0)
+ log_interval = args.get("log_interval", 10)
+ # in order to avoid division by 0
+ num_seen_utts = 1
+ total_loss = 0.0
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(data_loader):
+ key, feats, target, feats_lengths, target_lengths = batch
+ feats = feats.to(device)
+ target = target.to(device)
+ feats_lengths = feats_lengths.to(device)
+ target_lengths = target_lengths.to(device)
+ num_utts = target_lengths.size(0)
+ if num_utts == 0:
+ continue
+ loss_dict = model(feats, feats_lengths, target, target_lengths)
+ loss = loss_dict["loss"]
+ if torch.isfinite(loss):
+ num_seen_utts += num_utts
+ total_loss += loss.item() * num_utts
+ if batch_idx % log_interval == 0:
+ log_str = "CV Batch {}/{} loss {:.6f} ".format(
+ epoch, batch_idx, loss.item()
+ )
+ for name, value in loss_dict.items():
+ if name != "loss" and value is not None:
+ log_str += "{} {:.6f} ".format(name, value.item())
+ log_str += "history loss {:.6f}".format(total_loss / num_seen_utts)
+ log_str += " rank {}".format(rank)
+ logging.debug(log_str)
+ return total_loss, num_seen_utts
diff --git a/modules/wenet_extractor/utils/file_utils.py b/modules/wenet_extractor/utils/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..443759bd4da9d7f85d48f0534205f9e5cd7048df
--- /dev/null
+++ b/modules/wenet_extractor/utils/file_utils.py
@@ -0,0 +1,77 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import re
+
+
+def read_lists(list_file):
+ lists = []
+ with open(list_file, "r", encoding="utf8") as fin:
+ for line in fin:
+ lists.append(line.strip())
+ return lists
+
+
+def read_non_lang_symbols(non_lang_sym_path):
+ """read non-linguistic symbol from file.
+
+ The file format is like below:
+
+ {NOISE}\n
+ {BRK}\n
+ ...
+
+
+ Args:
+ non_lang_sym_path: non-linguistic symbol file path, None means no any
+ syms.
+
+ """
+ if non_lang_sym_path is None:
+ return None
+ else:
+ syms = read_lists(non_lang_sym_path)
+ non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
+ for sym in syms:
+ if non_lang_syms_pattern.fullmatch(sym) is None:
+
+ class BadSymbolFormat(Exception):
+ pass
+
+ raise BadSymbolFormat(
+ "Non-linguistic symbols should be "
+ "formatted in {xxx}//[xxx], consider"
+ " modify '%s' to meet the requirment. "
+ "More details can be found in discussions here : "
+ "https://github.com/wenet-e2e/wenet/pull/819" % (sym)
+ )
+ return syms
+
+
+def read_symbol_table(symbol_table_file):
+ symbol_table = {}
+ with open(symbol_table_file, "r", encoding="utf8") as fin:
+ for line in fin:
+ arr = line.strip().split()
+ assert len(arr) == 2
+ symbol_table[arr[0]] = int(arr[1])
+ return symbol_table
diff --git a/modules/wenet_extractor/utils/init_model.py b/modules/wenet_extractor/utils/init_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a76dd199de3c3cd87116b5eb2d5a0d7eec284a7c
--- /dev/null
+++ b/modules/wenet_extractor/utils/init_model.py
@@ -0,0 +1,154 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+import torch
+from modules.wenet_extractor.transducer.joint import TransducerJoint
+from modules.wenet_extractor.transducer.predictor import (
+ ConvPredictor,
+ EmbeddingPredictor,
+ RNNPredictor,
+)
+from modules.wenet_extractor.transducer.transducer import Transducer
+from modules.wenet_extractor.transformer.asr_model import ASRModel
+from modules.wenet_extractor.transformer.cmvn import GlobalCMVN
+from modules.wenet_extractor.transformer.ctc import CTC
+from modules.wenet_extractor.transformer.decoder import (
+ BiTransformerDecoder,
+ TransformerDecoder,
+)
+from modules.wenet_extractor.transformer.encoder import (
+ ConformerEncoder,
+ TransformerEncoder,
+)
+from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder
+from modules.wenet_extractor.efficient_conformer.encoder import (
+ EfficientConformerEncoder,
+)
+from modules.wenet_extractor.paraformer.paraformer import Paraformer
+from modules.wenet_extractor.cif.predictor import Predictor
+from modules.wenet_extractor.utils.cmvn import load_cmvn
+
+
+def init_model(configs):
+ if configs["cmvn_file"] is not None:
+ mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
+ global_cmvn = GlobalCMVN(
+ torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
+ )
+ else:
+ global_cmvn = None
+
+ input_dim = configs["input_dim"]
+ vocab_size = configs["output_dim"]
+
+ encoder_type = configs.get("encoder", "conformer")
+ decoder_type = configs.get("decoder", "bitransformer")
+
+ if encoder_type == "conformer":
+ encoder = ConformerEncoder(
+ input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
+ )
+ elif encoder_type == "squeezeformer":
+ encoder = SqueezeformerEncoder(
+ input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
+ )
+ elif encoder_type == "efficientConformer":
+ encoder = EfficientConformerEncoder(
+ input_dim,
+ global_cmvn=global_cmvn,
+ **configs["encoder_conf"],
+ **(
+ configs["encoder_conf"]["efficient_conf"]
+ if "efficient_conf" in configs["encoder_conf"]
+ else {}
+ ),
+ )
+ else:
+ encoder = TransformerEncoder(
+ input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
+ )
+ if decoder_type == "transformer":
+ decoder = TransformerDecoder(
+ vocab_size, encoder.output_size(), **configs["decoder_conf"]
+ )
+ else:
+ assert 0.0 < configs["model_conf"]["reverse_weight"] < 1.0
+ assert configs["decoder_conf"]["r_num_blocks"] > 0
+ decoder = BiTransformerDecoder(
+ vocab_size, encoder.output_size(), **configs["decoder_conf"]
+ )
+ ctc = CTC(vocab_size, encoder.output_size())
+
+ # Init joint CTC/Attention or Transducer model
+ if "predictor" in configs:
+ predictor_type = configs.get("predictor", "rnn")
+ if predictor_type == "rnn":
+ predictor = RNNPredictor(vocab_size, **configs["predictor_conf"])
+ elif predictor_type == "embedding":
+ predictor = EmbeddingPredictor(vocab_size, **configs["predictor_conf"])
+ configs["predictor_conf"]["output_size"] = configs["predictor_conf"][
+ "embed_size"
+ ]
+ elif predictor_type == "conv":
+ predictor = ConvPredictor(vocab_size, **configs["predictor_conf"])
+ configs["predictor_conf"]["output_size"] = configs["predictor_conf"][
+ "embed_size"
+ ]
+ else:
+ raise NotImplementedError("only rnn, embedding and conv type support now")
+ configs["joint_conf"]["enc_output_size"] = configs["encoder_conf"][
+ "output_size"
+ ]
+ configs["joint_conf"]["pred_output_size"] = configs["predictor_conf"][
+ "output_size"
+ ]
+ joint = TransducerJoint(vocab_size, **configs["joint_conf"])
+ model = Transducer(
+ vocab_size=vocab_size,
+ blank=0,
+ predictor=predictor,
+ encoder=encoder,
+ attention_decoder=decoder,
+ joint=joint,
+ ctc=ctc,
+ **configs["model_conf"],
+ )
+ elif "paraformer" in configs:
+ predictor = Predictor(**configs["cif_predictor_conf"])
+ model = Paraformer(
+ vocab_size=vocab_size,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ predictor=predictor,
+ **configs["model_conf"],
+ )
+ else:
+ model = ASRModel(
+ vocab_size=vocab_size,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ lfmmi_dir=configs.get("lfmmi_dir", ""),
+ **configs["model_conf"],
+ )
+ return model
diff --git a/modules/wenet_extractor/utils/mask.py b/modules/wenet_extractor/utils/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..97b6149781893da3fb37e225ef47338a79ef601f
--- /dev/null
+++ b/modules/wenet_extractor/utils/mask.py
@@ -0,0 +1,304 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+
+import torch
+
+'''
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+ In encoder, fully attention is used when streaming is not necessary and
+ the sequence is not long. In this case, no attention mask is needed.
+
+ When streaming is need, chunk-based attention is used in encoder. See
+ subsequent_chunk_mask for the chunk-based attention mask.
+
+ Args:
+ size (int): size of mask
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
+ dtype (torch.device): result dtype
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
+ return torch.tril(ret)
+'''
+
+
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+ In encoder, fully attention is used when streaming is not necessary and
+ the sequence is not long. In this case, no attention mask is needed.
+
+ When streaming is need, chunk-based attention is used in encoder. See
+ subsequent_chunk_mask for the chunk-based attention mask.
+
+ Args:
+ size (int): size of mask
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
+ dtype (torch.device): result dtype
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ arange = torch.arange(size, device=device)
+ mask = arange.expand(size, size)
+ arange = arange.unsqueeze(-1)
+ mask = mask <= arange
+ return mask
+
+
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ for i in range(size):
+ if num_left_chunks < 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+ ending = min((i // chunk_size + 1) * chunk_size, size)
+ ret[i, start:ending] = True
+ return ret
+
+
+def add_optional_chunk_mask(
+ xs: torch.Tensor,
+ masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int,
+ static_chunk_size: int,
+ num_decoding_left_chunks: int,
+):
+ """Apply optional mask for encoder.
+
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+ mask (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+
+ Returns:
+ torch.Tensor: chunk mask of the input xs.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = num_decoding_left_chunks
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ chunk_size = torch.randint(1, max_len, (1,)).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
+ chunk_masks = subsequent_chunk_mask(
+ xs.size(1), chunk_size, num_left_chunks, xs.device
+ ) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = num_decoding_left_chunks
+ chunk_masks = subsequent_chunk_mask(
+ xs.size(1), static_chunk_size, num_left_chunks, xs.device
+ ) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """Make mask tensor containing indices of padded part.
+
+ See description of make_non_pad_mask.
+
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ Returns:
+ torch.Tensor: Mask tensor containing indices of padded part.
+
+ Examples:
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ """
+ batch_size = lengths.size(0)
+ max_len = max_len if max_len > 0 else lengths.max().item()
+ seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+ seq_length_expand = lengths.unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+ return mask
+
+
+def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
+ """Make mask tensor containing indices of non-padded part.
+
+ The sequences in a batch may have different lengths. To enable
+ batch computing, padding is need to make all sequence in same
+ size. To avoid the padding part pass value to context dependent
+ block such as attention or convolution , this padding part is
+ masked.
+
+ This pad_mask is used in both encoder and decoder.
+
+ 1 for non-padded part and 0 for padded part.
+
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ Returns:
+ torch.Tensor: mask tensor containing indices of padded part.
+
+ Examples:
+ >>> lengths = [5, 3, 2]
+ >>> make_non_pad_mask(lengths)
+ masks = [[1, 1, 1, 1 ,1],
+ [1, 1, 1, 0, 0],
+ [1, 1, 0, 0, 0]]
+ """
+ return ~make_pad_mask(lengths)
+
+
+def mask_finished_scores(score: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
+ """
+ If a sequence is finished, we only allow one alive branch. This function
+ aims to give one branch a zero score and the rest -inf score.
+
+ Args:
+ score (torch.Tensor): A real value array with shape
+ (batch_size * beam_size, beam_size).
+ flag (torch.Tensor): A bool array with shape
+ (batch_size * beam_size, 1).
+
+ Returns:
+ torch.Tensor: (batch_size * beam_size, beam_size).
+ """
+ beam_size = score.size(-1)
+ zero_mask = torch.zeros_like(flag, dtype=torch.bool)
+ if beam_size > 1:
+ unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])), dim=1)
+ finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])), dim=1)
+ else:
+ unfinished = zero_mask
+ finished = flag
+ score.masked_fill_(unfinished, -float("inf"))
+ score.masked_fill_(finished, 0)
+ return score
+
+
+def mask_finished_preds(
+ pred: torch.Tensor, flag: torch.Tensor, eos: int
+) -> torch.Tensor:
+ """
+ If a sequence is finished, all of its branch should be
+
+ Args:
+ pred (torch.Tensor): A int array with shape
+ (batch_size * beam_size, beam_size).
+ flag (torch.Tensor): A bool array with shape
+ (batch_size * beam_size, 1).
+
+ Returns:
+ torch.Tensor: (batch_size * beam_size).
+ """
+ beam_size = pred.size(-1)
+ finished = flag.repeat([1, beam_size])
+ return pred.masked_fill_(finished, eos)
diff --git a/modules/wenet_extractor/utils/scheduler.py b/modules/wenet_extractor/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5885dbb1faca52272f8825df2da7aa7533d1540
--- /dev/null
+++ b/modules/wenet_extractor/utils/scheduler.py
@@ -0,0 +1,738 @@
+# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
+
+# ## Citations
+
+# ```bibtex
+# @inproceedings{yao2021wenet,
+# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
+# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
+# booktitle={Proc. Interspeech},
+# year={2021},
+# address={Brno, Czech Republic },
+# organization={IEEE}
+# }
+
+# @article{zhang2022wenet,
+# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
+# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
+# journal={arXiv preprint arXiv:2203.15455},
+# year={2022}
+# }
+#
+
+from typing import Union
+
+import math
+import warnings
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class WarmupLR(_LRScheduler):
+ """The WarmupLR scheduler
+
+ This scheduler is almost same as NoamLR Scheduler except for following
+ difference:
+
+ NoamLR:
+ lr = optimizer.lr * model_size ** -0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+ WarmupLR:
+ lr = optimizer.lr * warmup_step ** 0.5
+ * min(step ** -0.5, step * warmup_step ** -1.5)
+
+ Note that the maximum lr equals to optimizer.lr in this scheduler.
+
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ warmup_steps: Union[int, float] = 25000,
+ last_epoch: int = -1,
+ ):
+ self.warmup_steps = warmup_steps
+
+ # __init__() must be invoked before setting field
+ # because step() is also invoked in __init__()
+ super().__init__(optimizer, last_epoch)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+
+ def get_lr(self):
+ step_num = self.last_epoch + 1
+ if self.warmup_steps == 0:
+ return [lr * step_num**-0.5 for lr in self.base_lrs]
+ else:
+ return [
+ lr
+ * self.warmup_steps**0.5
+ * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
+ for lr in self.base_lrs
+ ]
+
+ def set_step(self, step: int):
+ self.last_epoch = step
+
+
+class WarmupPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (
+ warmup_steps is not None and warmup_ratio is not None
+ ), "Either use particular number of step or ratio"
+ assert (
+ warmup_ratio is None or max_steps is not None
+ ), "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ step = self.last_epoch
+
+ if step <= self.warmup_steps and self.warmup_steps > 0:
+ return self._get_warmup_lr(step)
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_warmup_lr(self, step):
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class SquareRootConstantPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ constant_steps=None,
+ constant_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (
+ constant_steps is not None and constant_ratio is not None
+ ), "Either use particular number of step or ratio"
+ assert (
+ constant_ratio is None or max_steps is not None
+ ), "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if constant_steps is not None:
+ self.constant_steps = constant_steps
+ elif constant_ratio is not None:
+ self.constant_steps = int(constant_ratio * max_steps)
+ else:
+ self.constant_steps = 0
+
+ self.constant_lr = 1 / (constant_steps**0.5)
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ step = self.last_epoch
+
+ if step <= self.constant_steps:
+ return [self.constant_lr for _ in self.base_lrs]
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+class WarmupHoldPolicy(WarmupPolicy):
+ """Variant of WarmupPolicy which maintains high
+ learning rate for a defined number of steps.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to
+ hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ hold_steps=None,
+ hold_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (
+ hold_steps is not None and hold_ratio is not None
+ ), "Either use particular number of step or ratio"
+ assert (
+ hold_ratio is None or max_steps is not None
+ ), "If there is a ratio, there should be a total steps"
+
+ self.min_lr = min_lr
+ self._last_warmup_lr = 0.0
+
+ # Necessary to duplicate as class attributes are hidden in inner class
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if hold_steps is not None:
+ self.hold_steps = hold_steps + self.warmup_steps
+ elif hold_ratio is not None:
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
+ else:
+ self.hold_steps = 0
+
+ super().__init__(
+ optimizer,
+ warmup_steps=warmup_steps,
+ warmup_ratio=warmup_ratio,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ )
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler,"
+ " "
+ "please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ step = self.last_epoch
+
+ # Warmup phase
+ if step <= self.warmup_steps and self.warmup_steps > 0:
+ return self._get_warmup_lr(step)
+
+ # Hold phase
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
+ return self.base_lrs
+
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+
+class WarmupAnnealHoldPolicy(_LRScheduler):
+ """Adds warmup kwargs and warmup logic to lr policy.
+ All arguments should be passed as kwargs for clarity,
+ Args:
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ min_lr: Minimum lr to hold the learning rate after decay at.
+ constant_steps: Number of steps to keep lr constant at.
+ constant_ratio: Ratio of steps to keep lr constant.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ *,
+ warmup_steps=None,
+ warmup_ratio=None,
+ constant_steps=None,
+ constant_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ assert not (
+ warmup_steps is not None and warmup_ratio is not None
+ ), "Either use particular number of step or ratio"
+ assert not (
+ constant_steps is not None and constant_ratio is not None
+ ), "Either use constant_steps or constant_ratio"
+ assert (
+ warmup_ratio is None or max_steps is not None
+ ), "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ if constant_steps is not None:
+ self.constant_steps = constant_steps
+ elif constant_ratio is not None:
+ self.constant_steps = int(constant_ratio * max_steps)
+ else:
+ self.constant_steps = 0
+
+ self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ step = self.last_epoch
+
+ # Warmup steps
+ if self.warmup_steps > 0 and step <= self.warmup_steps:
+ return self._get_warmup_lr(step)
+
+ # Constant steps after warmup and decay
+ if (
+ self.constant_steps > 0
+ and (self.warmup_steps + self.decay_steps) < step <= self.max_steps
+ ):
+ return self._get_constant_lr(step)
+
+ # Min lr after max steps of updates
+ if step > self.max_steps:
+ return [self.min_lr for _ in self.base_lrs]
+
+ return self._get_lr(step)
+
+ def _get_warmup_lr(self, step):
+ lr_val = (step + 1) / (self.warmup_steps + 1)
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
+
+ def _get_constant_lr(self, step):
+ return [self.min_lr for _ in self.base_lrs]
+
+ def _get_lr(self, step):
+ """Simple const lr policy"""
+ return self.base_lrs
+
+
+def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps) ** 0.5
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _square_annealing(initial_lr, step, max_steps, min_lr):
+ mult = ((max_steps - step) / max_steps) ** 2
+ out_lr = initial_lr * mult
+ out_lr = max(out_lr, min_lr)
+ return out_lr
+
+
+def _cosine_annealing(initial_lr, step, max_steps, min_lr):
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+ out_lr = (initial_lr - min_lr) * mult + min_lr
+ return out_lr
+
+
+def _linear_warmup_with_cosine_annealing(
+ max_lr, warmup_steps, step, decay_steps, min_lr
+):
+ assert max_lr > min_lr
+ # Use linear warmup for the initial part.
+ if warmup_steps > 0 and step <= warmup_steps:
+ return max_lr * float(step) / float(warmup_steps)
+
+ # For any steps larger than `decay_steps`, use `min_lr`.
+ if step > warmup_steps + decay_steps:
+ return min_lr
+
+ # If we are done with the warmup period, use the decay style.
+ num_steps_ = step - warmup_steps
+ decay_steps_ = decay_steps
+ decay_ratio = float(num_steps_) / float(decay_steps_)
+ assert decay_ratio >= 0.0
+ assert decay_ratio <= 1.0
+ delta_lr = max_lr - min_lr
+
+ coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+
+ return min_lr + coeff * delta_lr
+
+
+def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
+ if cycle:
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
+ decay_steps *= multiplier
+ else:
+ step = min(step, decay_steps)
+ p = step / decay_steps
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
+ lr += min_lr
+ return lr
+
+
+def _noam_hold_annealing(
+ initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr
+):
+ # hold_steps = total number of steps
+ # to hold the LR, not the warmup + hold steps.
+ T_warmup_decay = max(1, warmup_steps**decay_rate)
+ T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
+ lr = (initial_lr * T_warmup_decay) / T_hold_decay
+ lr = max(lr, min_lr)
+ return lr
+
+
+class SquareAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
+ super().__init__(
+ optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs,
+ )
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _square_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class SquareRootAnnealing(WarmupPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
+ super().__init__(
+ optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs,
+ )
+
+ def _get_lr(self, step):
+ new_lrs = [
+ _squareroot_annealing(
+ initial_lr=initial_lr,
+ step=step,
+ max_steps=self.max_steps,
+ min_lr=self.min_lr,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+
+class CosineAnnealing(WarmupAnnealHoldPolicy):
+ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
+ super().__init__(
+ optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs,
+ )
+
+ def _get_lr(self, step):
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate "
+ f"that was lower than the minimum learning rate."
+ )
+
+ if self.constant_steps is None or self.constant_steps == 0:
+ new_lrs = [
+ _cosine_annealing(
+ initial_lr=initial_lr,
+ step=step - self.warmup_steps,
+ max_steps=self.max_steps - self.warmup_steps,
+ min_lr=self.min_lr,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ else:
+ new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
+ return new_lrs
+
+ def _get_warmup_lr(self, step):
+ if self.constant_steps is None or self.constant_steps == 0:
+ return super()._get_warmup_lr(step)
+ else:
+ # Use linear warmup for the initial part.
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
+
+ def _get_constant_lr(self, step):
+ # Only called when `constant_steps` > 0.
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
+
+ def _get_linear_warmup_with_cosine_annealing_lr(self, step):
+ # Cosine Schedule for Megatron LM,
+ # slightly different warmup schedule + constant LR at the end.
+ new_lrs = [
+ _linear_warmup_with_cosine_annealing(
+ max_lr=self.base_lrs[0],
+ warmup_steps=self.warmup_steps,
+ step=step,
+ decay_steps=self.decay_steps,
+ min_lr=self.min_lr,
+ )
+ for _ in self.base_lrs
+ ]
+ return new_lrs
+
+
+class NoamAnnealing(_LRScheduler):
+ def __init__(
+ self,
+ optimizer,
+ *,
+ d_model,
+ warmup_steps=None,
+ warmup_ratio=None,
+ max_steps=None,
+ min_lr=0.0,
+ last_epoch=-1,
+ ):
+ self._normalize = d_model ** (-0.5)
+ assert not (
+ warmup_steps is not None and warmup_ratio is not None
+ ), "Either use particular number of step or ratio"
+ assert (
+ warmup_ratio is None or max_steps is not None
+ ), "If there is a ratio, there should be a total steps"
+
+ # It is necessary to assign all attributes *before* __init__,
+ # as class is wrapped by an inner class.
+ self.max_steps = max_steps
+ if warmup_steps is not None:
+ self.warmup_steps = warmup_steps
+ elif warmup_ratio is not None:
+ self.warmup_steps = int(warmup_ratio * max_steps)
+ else:
+ self.warmup_steps = 0
+
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed "
+ "by the scheduler, please use `get_last_lr()`.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ step = max(1, self.last_epoch)
+
+ for initial_lr in self.base_lrs:
+ if initial_lr < self.min_lr:
+ raise ValueError(
+ f"{self} received an initial learning rate "
+ f"that was lower than the minimum learning rate."
+ )
+
+ new_lrs = [
+ self._noam_annealing(initial_lr=initial_lr, step=step)
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+ def _noam_annealing(self, initial_lr, step):
+ if self.warmup_steps > 0:
+ mult = self._normalize * min(
+ step ** (-0.5), step * (self.warmup_steps ** (-1.5))
+ )
+ else:
+ mult = self._normalize * step ** (-0.5)
+
+ out_lr = initial_lr * mult
+ if step > self.warmup_steps:
+ out_lr = max(out_lr, self.min_lr)
+ return out_lr
+
+
+class NoamHoldAnnealing(WarmupHoldPolicy):
+ def __init__(
+ self,
+ optimizer,
+ *,
+ max_steps,
+ decay_rate=0.5,
+ min_lr=0.0,
+ last_epoch=-1,
+ **kwargs,
+ ):
+ """
+ From Nemo:
+ Implementation of the Noam Hold Annealing policy
+ from the SqueezeFormer paper.
+
+ Unlike NoamAnnealing, the peak learning rate
+ can be explicitly set for this scheduler.
+ The schedule first performs linear warmup,
+ then holds the peak LR, then decays with some schedule for
+ the remainder of the steps.
+ Therefore the min-lr is still dependent
+ on the hyper parameters selected.
+
+ It's schedule is determined by three factors-
+
+ Warmup Steps: Initial stage, where linear warmup
+ occurs uptil the peak LR is reached. Unlike NoamAnnealing,
+ the peak LR is explicitly stated here instead of a scaling factor.
+
+ Hold Steps: Intermediate stage, where the peak LR
+ is maintained for some number of steps. In this region,
+ the high peak LR allows the model to converge faster
+ if training is stable. However the high LR
+ may also cause instability during training.
+ Should usually be a significant fraction of training
+ steps (around 30-40% of the entire training steps).
+
+ Decay Steps: Final stage, where the LR rapidly decays
+ with some scaling rate (set by decay rate).
+ To attain Noam decay, use 0.5,
+ for Squeezeformer recommended decay, use 1.0.
+ The fast decay after prolonged high LR during
+ hold phase allows for rapid convergence.
+
+ References:
+ - [Squeezeformer:
+ An Efficient Transformer for Automatic Speech Recognition]
+ (https://arxiv.org/abs/2206.00888)
+
+ Args:
+ optimizer: Pytorch compatible Optimizer object.
+ warmup_steps: Number of training steps in warmup stage
+ warmup_ratio: Ratio of warmup steps to total steps
+ hold_steps: Number of training steps to
+ hold the learning rate after warm up
+ hold_ratio: Ratio of hold steps to total steps
+ max_steps: Total number of steps while training or `None` for
+ infinite training
+ decay_rate: Float value describing the polynomial decay
+ after the hold period. Default value
+ of 0.5 corresponds to Noam decay.
+ min_lr: Minimum learning rate.
+ """
+ self.decay_rate = decay_rate
+ super().__init__(
+ optimizer=optimizer,
+ max_steps=max_steps,
+ last_epoch=last_epoch,
+ min_lr=min_lr,
+ **kwargs,
+ )
+
+ def _get_lr(self, step):
+ if self.warmup_steps is None or self.warmup_steps == 0:
+ raise ValueError("Noam scheduler cannot be used without warmup steps")
+
+ if self.hold_steps > 0:
+ hold_steps = self.hold_steps - self.warmup_steps
+ else:
+ hold_steps = 0
+
+ new_lrs = [
+ _noam_hold_annealing(
+ initial_lr,
+ step=step,
+ warmup_steps=self.warmup_steps,
+ hold_steps=hold_steps,
+ decay_rate=self.decay_rate,
+ min_lr=self.min_lr,
+ )
+ for initial_lr in self.base_lrs
+ ]
+ return new_lrs
+
+ def set_step(self, step: int):
+ self.last_epoch = step
diff --git a/optimizer/__init__.py b/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..94260aeefa2aefd8373888e2a2f0b2ab81a5fd59
--- /dev/null
+++ b/optimizer/optimizers.py
@@ -0,0 +1,774 @@
+# This module is modified from https://github.com/Plachtaa/VALL-E-X/blob/3faaf8ccadb154d63b38070caf518ce9309ea0f4/modules/optim.py#L836
+
+import logging
+import contextlib
+import torch
+from torch import Tensor
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim import Optimizer
+from typing import List, Tuple
+from collections import defaultdict
+
+
+class NoamLR(_LRScheduler):
+ """
+ Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
+ linearly for the first ``num_warmup`` training steps, and decreasing it thereafter proportionally
+ to the inverse square root of the step number, scaled by the inverse square root of the
+ dimensionality of the model. Time will tell if this is just madness or it's actually important.
+ Parameters
+ ----------
+ num_warmup: ``int``, required.
+ The number of steps to linearly increase the learning rate.
+ """
+
+ def __init__(self, optimizer, num_warmup):
+ self.num_warmup = num_warmup
+ self.base_lr = optimizer.param_groups[0]["lr"]
+ super().__init__(optimizer)
+
+ def get_lr(self):
+ last_epoch = max(1, self.last_epoch)
+ scale = min(last_epoch ** (-0.5), last_epoch * self.num_warmup ** (-1.5))
+ return [scale * self.base_lr]
+
+
+class Eve(Optimizer):
+ """
+ Implements Eve algorithm. This is a modified version of AdamW with a special
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
+ will be close to invariant to the absolute scale on the parameter matrix.
+
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+ Eve is unpublished so far.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 3e-4;
+ this value means that the weight would decay significantly after
+ about 3k minibatches. Is not multiplied by learning rate, but
+ is conditional on RMS-value of parameter being > target_rms.
+ target_rms (float, optional): target root-mean-square value of
+ parameters, if they fall below this we will stop applying weight decay.
+
+
+ .. _Adam: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.98),
+ eps=1e-8,
+ weight_decay=1e-3,
+ target_rms=0.1,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0 <= weight_decay <= 0.1:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0 < target_rms <= 10.0:
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ target_rms=target_rms,
+ )
+ super(Eve, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Eve, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ # Perform optimization step
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError("AdamW does not support sparse gradients")
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+ bias_correction1 = 1 - beta1 ** state["step"]
+ bias_correction2 = 1 - beta2 ** state["step"]
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
+ group["eps"]
+ )
+
+ step_size = group["lr"] / bias_correction1
+ target_rms = group["target_rms"]
+ weight_decay = group["weight_decay"]
+
+ if p.numel() > 1:
+ # avoid applying this weight-decay on "scaling factors"
+ # (which are scalar).
+ is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
+ p.mul_(1 - (weight_decay * is_above_target_rms))
+
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ # if random.random() < 0.0005:
+ # step = (exp_avg / denom) * step_size
+ # logging.info(
+ # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
+ # )
+
+ return loss
+
+
+class BatchedOptimizer(Optimizer):
+ """
+ This class adds to class Optimizer the capability to optimize parameters in batches:
+ it will stack the parameters and their grads for you so the optimizer can work
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
+ as it reduces the number of kernels launched in the optimizer.
+
+ Args:
+ params:
+ """
+
+ def __init__(self, params, defaults):
+ super(BatchedOptimizer, self).__init__(params, defaults)
+
+ @contextlib.contextmanager
+ def batched_params(self, param_group, group_params_names):
+ """
+ This function returns (technically, yields) a list of
+ of tuples (p, state), where
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
+ that share the same shape, and its gradient is also stacked;
+ `state` is the state corresponding to this batch of parameters
+ (it will be physically located in the "state" for one of the real
+ parameters, the last one that has any particular shape and dtype).
+
+ This function is decorated as a context manager so that it can
+ write parameters back to their "real" locations.
+
+ The idea is, instead of doing:
+
+ for p in group["params"]:
+ state = self.state[p]
+ ...
+
+ you can do:
+
+ with self.batched_params(group["params"]) as batches:
+ for p, state, p_names in batches:
+ ...
+
+
+ Args:
+ group: a parameter group, which is a list of parameters; should be
+ one of self.param_groups.
+ group_params_names: name for each parameter in group,
+ which is List[str].
+ """
+ batches = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+ batches_names = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
+
+ assert len(param_group) == len(group_params_names)
+ for p, named_p in zip(param_group, group_params_names):
+ key = (str(p.dtype), *p.shape)
+ batches[key].append(p)
+ batches_names[key].append(named_p)
+
+ batches_names_keys = list(batches_names.keys())
+ sorted_idx = sorted(
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
+ )
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
+
+ stacked_params_dict = dict()
+
+ # turn batches into a list, in deterministic order.
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+ # one for each batch in `batches`.
+ tuples = []
+
+ for batch, batch_names in zip(batches, batches_names):
+ p = batch[0]
+ # we arbitrarily store the state in the
+ # state corresponding to the 1st parameter in the
+ # group. class Optimizer will take care of saving/loading state.
+ state = self.state[p]
+ p_stacked = torch.stack(batch)
+ grad = torch.stack(
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
+ )
+ p_stacked.grad = grad
+ stacked_params_dict[key] = p_stacked
+ tuples.append((p_stacked, state, batch_names))
+
+ yield tuples
+
+ for (stacked_params, _state, _names), batch in zip(tuples, batches):
+ for i, p in enumerate(batch):
+ p.copy_(stacked_params[i])
+
+
+class ScaledAdam(BatchedOptimizer):
+ """
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
+ param = underlying_param * log_scale.exp())
+
+
+ Args:
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+ lr: The learning rate. We will typically use a learning rate schedule that starts
+ at 0.03 and decreases over time, i.e. much higher than other common
+ optimizers.
+ clipping_scale: (e.g. 2.0)
+ A scale for gradient-clipping: if specified, the normalized gradients
+ over the whole model will be clipped to have 2-norm equal to
+ `clipping_scale` times the median 2-norm over the most recent period
+ of `clipping_update_period` minibatches. By "normalized gradients",
+ we mean after multiplying by the rms parameter value for this tensor
+ [for non-scalars]; this is appropriate because our update is scaled
+ by this quantity.
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+ Must satisfy 0 < beta <= beta2 < 1.
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+ scale of each parameter tensor and scalar parameters of the mode..
+ If each parameter were decomposed
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+ would be a the scaling factor on the learning rate of p_scale.
+ eps: A general-purpose epsilon to prevent division by zero
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be >= this value)
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be <= this value)
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
+ model has any parameters with numel() == 1).
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
+ of the parameter tensor. This is provided to save a little time
+ in the update.
+ clipping_update_period: if clipping_scale is specified, this is the period
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=3e-02,
+ clipping_scale=None,
+ betas=(0.9, 0.98),
+ scalar_lr_scale=0.1,
+ eps=1.0e-08,
+ param_min_rms=1.0e-05,
+ param_max_rms=3.0,
+ scalar_max=10.0,
+ size_update_period=4,
+ clipping_update_period=100,
+ parameters_names=None,
+ show_dominant_parameters=True,
+ ):
+ assert parameters_names is not None, (
+ "Please prepare parameters_names,"
+ "which is a List[List[str]]. Each List[str] is for a group"
+ "and each str is for a parameter"
+ )
+ defaults = dict(
+ lr=lr,
+ clipping_scale=clipping_scale,
+ betas=betas,
+ scalar_lr_scale=scalar_lr_scale,
+ eps=eps,
+ param_min_rms=param_min_rms,
+ param_max_rms=param_max_rms,
+ scalar_max=scalar_max,
+ size_update_period=size_update_period,
+ clipping_update_period=clipping_update_period,
+ )
+
+ super(ScaledAdam, self).__init__(params, defaults)
+ assert len(self.param_groups) == len(parameters_names)
+ self.parameters_names = parameters_names
+ self.show_dominant_parameters = show_dominant_parameters
+
+ def __setstate__(self, state):
+ super(ScaledAdam, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ batch = True
+
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
+ with self.batched_params(group["params"], group_params_names) as batches:
+ # batches is list of pairs (stacked_param, state). stacked_param is like
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
+ # a stacking dim, it is not a real dim.
+
+ if len(batches[0][1]) == 0:
+ clipping_scale = 1
+ else:
+ clipping_scale = self._get_clipping_scale(group, batches)
+
+ for p, state, _ in batches:
+ # Perform optimization step.
+ # grad is not going to be None, we handled that when creating the batches.
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ # State initialization
+ if len(state) == 0:
+ self._init_state(group, p, state)
+
+ self._step_one_batch(group, p, state, clipping_scale)
+
+ return loss
+
+ def _init_state(self, group: dict, p: Tensor, state: dict):
+ """
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
+ is actually the batch dimension, corresponding to batched-together
+ parameters of a given shape.
+
+
+ Args:
+ group: Dict to look up configuration values.
+ p: The parameter that we are initializing the state for
+ state: Dict from string to whatever state we are initializing
+ """
+ size_update_period = group["size_update_period"]
+
+ state["step"] = 0
+
+ kwargs = {"device": p.device, "dtype": p.dtype}
+
+ # 'delta' implements conventional momentum. There are
+ # several different kinds of update going on, so rather than
+ # compute "exp_avg" like in Adam, we store and decay a
+ # parameter-change "delta", which combines all forms of
+ # update. this is equivalent to how it's done in Adam,
+ # except for the first few steps.
+ state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+ numel = p.numel()
+
+ if numel > 1:
+ # "param_rms" just periodically records the scalar root-mean-square value of
+ # the parameter tensor.
+ # it has a shape like (batch_size, 1, 1, 1, 1)
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ state["param_rms"] = param_rms
+
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
+ state["scale_grads"] = torch.zeros(
+ size_update_period, *param_rms.shape, **kwargs
+ )
+
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ def _get_clipping_scale(
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
+ ) -> float:
+ """
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
+ by this amount before applying the rest of the update.
+
+ Args:
+ group: the parameter group, an item in self.param_groups
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ """
+ assert len(tuples) >= 1
+ clipping_scale = group["clipping_scale"]
+ (first_p, first_state, _) = tuples[0]
+ step = first_state["step"]
+ if clipping_scale is None or step == 0:
+ # no clipping. return early on step == 0 because the other
+ # parameters' state won't have been initialized yet.
+ return 1.0
+ clipping_update_period = group["clipping_update_period"]
+
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
+ for p, state, param_names in tuples:
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ if p.numel() == p.shape[0]: # a batch of scalars
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
+ else:
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
+
+ tot_norm = tot_sumsq.sqrt()
+ if "model_norms" not in first_state:
+ first_state["model_norms"] = torch.zeros(
+ clipping_update_period, device=p.device
+ )
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
+
+ if step % clipping_update_period == 0:
+ # Print some stats.
+ # We don't reach here if step == 0 because we would have returned
+ # above.
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+ quartiles = []
+ for n in range(0, 5):
+ index = min(
+ clipping_update_period - 1,
+ (clipping_update_period // 4) * n,
+ )
+ quartiles.append(sorted_norms[index].item())
+
+ median = quartiles[2]
+ threshold = clipping_scale * median
+ first_state["model_norm_threshold"] = threshold
+ percent_clipped = (
+ first_state["num_clipped"] * 100.0 / clipping_update_period
+ if "num_clipped" in first_state
+ else 0.0
+ )
+ first_state["num_clipped"] = 0
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
+ logging.info(
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
+ )
+
+ if step < clipping_update_period:
+ return 1.0 # We have not yet estimated a norm to clip to.
+ else:
+ try:
+ model_norm_threshold = first_state["model_norm_threshold"]
+ except KeyError:
+ logging.info(
+ "Warning: model_norm_threshold not in state: possibly "
+ "you changed config when restarting, adding clipping_scale option?"
+ )
+ return 1.0
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+ if ans < 1.0:
+ first_state["num_clipped"] += 1
+ if ans < 0.1:
+ logging.warn(
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+ )
+ if self.show_dominant_parameters:
+ assert p.shape[0] == len(param_names)
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
+ return ans
+
+ def _show_gradient_dominating_parameter(
+ self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+ ):
+ """
+ Show information of parameter wihch dominanting tot_sumsq.
+
+ Args:
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ tot_sumsq: sumsq of all parameters. Though it's could be calculated
+ from tuples, we still pass it to save some time.
+ """
+ all_sumsq_orig = {}
+ for p, state, batch_param_names in tuples:
+ # p is a stacked batch parameters.
+ batch_grad = p.grad
+ if p.numel() == p.shape[0]: # a batch of scalars
+ batch_sumsq_orig = batch_grad**2
+ # Dummpy values used by following `zip` statement.
+ batch_rms_orig = torch.ones(p.shape[0])
+ else:
+ batch_rms_orig = state["param_rms"]
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+ dim=list(range(1, batch_grad.ndim))
+ )
+
+ for name, sumsq_orig, rms, grad in zip(
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+ ):
+ proportion_orig = sumsq_orig / tot_sumsq
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+ assert torch.isclose(
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
+ torch.tensor(1.0),
+ )
+ sorted_by_proportion = {
+ k: v
+ for k, v in sorted(
+ all_sumsq_orig.items(),
+ key=lambda item: item[1][0],
+ reverse=True,
+ )
+ }
+ dominant_param_name = next(iter(sorted_by_proportion))
+ (
+ dominant_proportion,
+ dominant_sumsq,
+ dominant_rms,
+ dominant_grad,
+ ) = sorted_by_proportion[dominant_param_name]
+ logging.info(
+ f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+ f" with proportion {dominant_proportion:.2f},"
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+ f"={dominant_sumsq:.3e},"
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+ )
+
+ def _step_one_batch(
+ self, group: dict, p: Tensor, state: dict, clipping_scale: float
+ ):
+ """
+ Do the step for one parameter, which is actually going to be a batch of
+ `real` parameters, with dim 0 as the batch dim.
+ Args:
+ group: dict to look up configuration values
+ p: parameter to update (actually multiple parameters stacked together
+ as a batch)
+ state: state-dict for p, to look up the optimizer state
+ """
+ lr = group["lr"]
+ size_update_period = group["size_update_period"]
+ beta1 = group["betas"][0]
+
+ grad = p.grad
+ if clipping_scale != 1.0:
+ grad = grad * clipping_scale
+ step = state["step"]
+ delta = state["delta"]
+
+ delta.mul_(beta1)
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+ if numel > 1:
+ # Update the size/scale of p, and set param_rms
+ scale_grads = state["scale_grads"]
+ scale_grads[step % size_update_period] = (p * grad).sum(
+ dim=list(range(1, p.ndim)), keepdim=True
+ )
+ if step % size_update_period == size_update_period - 1:
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
+ param_rms.copy_(
+ (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ )
+ if step > 0:
+ # self._size_update() learns the overall scale on the
+ # parameter, by shrinking or expanding it.
+ self._size_update(group, scale_grads, p, state)
+
+ if numel == 1:
+ # For parameters with 1 element we just use regular Adam.
+ # Updates delta.
+ self._step_scalar(group, p, state)
+ else:
+ self._step(group, p, state)
+
+ state["step"] = step + 1
+
+ def _size_update(
+ self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
+ ) -> None:
+ """
+ Called only where p.numel() > 1, this updates the scale of the parameter.
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
+ gradient descent on underlying param and on scale, this function does the update
+ on `scale`.
+
+ Args:
+ group: dict to look up configuration values
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
+ grads w.r.t. the scales.
+ p: The parameter to update
+ state: The state-dict of p
+ """
+
+ param_rms = state["param_rms"]
+ beta1, beta2 = group["betas"]
+ size_lr = group["lr"] * group["scalar_lr_scale"]
+ param_min_rms = group["param_min_rms"]
+ param_max_rms = group["param_max_rms"]
+ eps = group["eps"]
+ step = state["step"]
+ batch_size = p.shape[0]
+
+ size_update_period = scale_grads.shape[0]
+ # correct beta2 for the size update period: we will have
+ # faster decay at this level.
+ beta2_corr = beta2**size_update_period
+
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
+ alpha=1 - beta2_corr,
+ ) # shape is (batch_size, 1, 1, ...)
+
+ # The 1st time we reach here is when size_step == 1.
+ size_step = (step + 1) // size_update_period
+ bias_correction2 = 1 - beta2_corr**size_step
+ # we don't bother with bias_correction1; this will help prevent divergence
+ # at the start of training.
+
+ denom = scale_exp_avg_sq.sqrt() + eps
+
+ scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
+
+ is_too_small = param_rms < param_min_rms
+ is_too_large = param_rms > param_max_rms
+
+ # when the param gets too small, just don't shrink it any further.
+ scale_step.masked_fill_(is_too_small, 0.0)
+ # when it gets too large, stop it from getting any larger.
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
+ delta = state["delta"]
+ # the factor of (1-beta1) relates to momentum.
+ delta.add_(p * scale_step, alpha=(1 - beta1))
+
+ def _step(self, group: dict, p: Tensor, state: dict):
+ """
+ This function does the core update of self.step(), in the case where the members of
+ the batch have more than 1 element.
+
+ Args:
+ group: A dict which will be used to look up configuration values
+ p: The parameter to be updated
+ grad: The grad of p
+ state: The state-dict corresponding to parameter p
+
+ This function modifies p.
+ """
+ grad = p.grad
+ lr = group["lr"]
+ beta1, beta2 = group["betas"]
+ eps = group["eps"]
+ param_min_rms = group["param_min_rms"]
+ step = state["step"]
+
+ exp_avg_sq = state["exp_avg_sq"]
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
+
+ this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
+ if bias_correction2 < 0.99:
+ # note: not in-place.
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+
+ denom = exp_avg_sq.sqrt()
+ denom += eps
+ grad = grad / denom
+
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
+
+ delta = state["delta"]
+ delta.add_(grad * alpha)
+ p.add_(delta)
+
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
+ """
+ A simplified form of the core update for scalar tensors, where we cannot get a good
+ estimate of the parameter rms.
+ """
+ beta1, beta2 = group["betas"]
+ scalar_max = group["scalar_max"]
+ eps = group["eps"]
+ lr = group["lr"] * group["scalar_lr_scale"]
+ grad = p.grad
+
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
+ # slower update at the start will help stability anyway.
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
+
+ delta = state["delta"]
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
+ p.clamp_(min=-scalar_max, max=scalar_max)
+ p.add_(delta)
diff --git a/preprocessors/Emilia/README.md b/preprocessors/Emilia/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2fb9db5d69d5bf4d8ad313de3e286f4297462ec1
--- /dev/null
+++ b/preprocessors/Emilia/README.md
@@ -0,0 +1,230 @@
+# Emilia: An Extensive, Multilingual, and Diverse Speech Dataset for Large-Scale Speech Generation
+[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia) [![GitHub](https://img.shields.io/badge/GitHub-Repo-green)](https://github.com/open-mmlab/Amphion/tree/main/preprocessors/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/)
+
+This is the official repository 👑 for the **Emilia** dataset and the source code for **Emilia-Pipe** speech data preprocessing pipeline.
+
+
+
+## News 🔥
+- **2024/09/01**: [Emilia](https://arxiv.org/abs/2407.05361) got accepted by IEEE SLT 2024! 🤗
+- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.com/invite/ZxxREr3Y) to stay connected and engage with our community!
+- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
+- **2024/07/08**: Our preprint [paper](https://arxiv.org/abs/2407.05361) is now available! 🔥🔥🔥
+- **2024/07/03**: We welcome everyone to check our [homepage](https://emilia-dataset.github.io/Emilia-Demo-Page/) for our brief introduction for Emilia dataset and our demos!
+- **2024/07/01**: We release of Emilia and Emilia-Pipe! We welcome everyone to explore it on our [GitHub](https://github.com/open-mmlab/Amphion/tree/main/preprocessors/Emilia)! 🎉🎉🎉
+
+## Emilia Overview ⭐️
+The **Emilia** dataset is a comprehensive, multilingual dataset with the following features:
+- containing over *101k* hours of speech data;
+- covering six different languages: *English (En), Chinese (Zh), German (De), French (Fr), Japanese (Ja), and Korean (Ko)*;
+- containing diverse speech data with *various speaking styles* from diverse video platforms and podcasts on the Internet, covering various content genres such as talk shows, interviews, debates, sports commentary, and audiobooks.
+
+The table below provides the duration statistics for each language in the dataset.
+
+| Language | Duration (hours) |
+|:-----------:|:----------------:|
+| English | 46,828 |
+| Chinese | 49,922 |
+| German | 1,590 |
+| French | 1,381 |
+| Japanese | 1,715 |
+| Korean | 217 |
+
+
+The **Emilia-Pipe** is the first open-source preprocessing pipeline designed to transform raw, in-the-wild speech data into high-quality training data with annotations for speech generation. This pipeline can process one hour of raw audio into model-ready data in just a few minutes, requiring only the raw speech data.
+
+Detailed description for the Emilia and Emilia-Pipe could be found in our [paper](https://arxiv.org/abs/2407.05361).
+
+## Emilia Dataset Usage 📖
+The Emilia dataset is now publicly available at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset)! Users in mainland China can also download Emilia from [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)!
+
+- To download from HuggingFace, you must first gain access to the dataset by completing the request form and accepting the terms of access. Please note that due to HuggingFace's file size limit of 50 GB, the `EN/EN_B00008.tar.gz` file has been split into `EN/EN_B00008.tar.gz.0` and `EN/EN_B00008.tar.gz.1`. Before extracting the files, you will need to run the following command to combine the parts: `cat EN/EN_B00008.tar.gz.* > EN/EN_B00008.tar.gz`
+
+- To download from OpenDataLab (i.e., OpenXLab), please follow the guidence [here](https://speechteam.feishu.cn/wiki/PC8Ew5igviqBiJkElMJcJxNonJc) to gain access.
+
+**ENJOY USING EMILIA!!!** 🔥
+
+If you wish to re-build Emilia from scratch, you may download the raw audio files from the [provided URL list](https://huggingface.co/datasets/amphion/Emilia) and use our open-source [Emilia-Pipe](https://github.com/open-mmlab/Amphion/tree/main/preprocessors/Emilia) preprocessing pipeline to preprocess the raw data. Additionally, users can easily use Emilia-Pipe to preprocess their own raw speech data for custom needs. By open-sourcing the Emilia-Pipe code, we aim to enable the speech community to collaborate on large-scale speech generation research.
+
+*Please note that Emilia does not own the copyright to the audio files; the copyright remains with the original owners of the videos or audio. Users are permitted to use this dataset only for non-commercial purposes under the CC BY-NC-4.0 license.*
+
+## Emilia Dataset Structure ⛪️
+The Emilia dataset will be structured as follows:
+
+Structure example:
+```
+|-- openemilia_all.tar.gz (all .JSONL files are gzipped with directory structure in this file)
+|-- EN (114 batches)
+| |-- EN_B00000.jsonl
+| |-- EN_B00000 (= EN_B00000.tar.gz)
+| | |-- EN_B00000_S00000
+| | | `-- mp3
+| | | |-- EN_B00000_S00000_W000000.mp3
+| | | `-- EN_B00000_S00000_W000001.mp3
+| | |-- ...
+| |-- ...
+| |-- EN_B00113.jsonl
+| `-- EN_B00113
+|-- ZH (92 batches)
+|-- DE (9 batches)
+|-- FR (10 batches)
+|-- JA (7 batches)
+|-- KO (4 batches)
+
+```
+JSONL files example:
+```
+{"id": "EN_B00000_S00000_W000000", "wav": "EN_B00000/EN_B00000_S00000/mp3/EN_B00000_S00000_W000000.mp3", "text": " You can help my mother and you- No. You didn't leave a bad situation back home to get caught up in another one here. What happened to you, Los Angeles?", "duration": 6.264, "speaker": "EN_B00000_S00000", "language": "en", "dnsmos": 3.2927}
+{"id": "EN_B00000_S00000_W000001", "wav": "EN_B00000/EN_B00000_S00000/mp3/EN_B00000_S00000_W000001.mp3", "text": " Honda's gone, 20 squads done. X is gonna split us up and put us on different squads. The team's come and go, but 20 squad, can't believe it's ending.", "duration": 8.031, "speaker": "EN_B00000_S00000", "language": "en", "dnsmos": 3.0442}
+```
+
+
+## Emilia-Pipe Overview 👀
+The Emilia-Pipe includes the following major steps:
+
+0. Standardization:Audio normalization
+1. Source Separation: Long audio -> Long audio without BGM
+2. Speaker Diarization: Get medium-length single-speaker speech data
+3. Fine-grained Segmentation by VAD: Get 3-30s single-speaker speech segments
+4. ASR: Get transcriptions of the speech segments
+5. Filtering: Obtain the final processed dataset
+
+## Setup Steps 👨💻
+
+### 0. Prepare Environment
+
+1. Install Python and CUDA.
+2. Run the following commands to install the required packages:
+
+ ```bash
+ conda create -y -n AudioPipeline python=3.9
+ conda activate AudioPipeline
+
+ bash env.sh
+ ```
+
+3. Download the model files from the third-party repositories.
+ - Manually download the checkpoints of UVR-MDX-NET-Inst_HQ_3 ([UVR-MDX-NET-Inst_3.onnx](https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Inst_HQ_3.onnx)) and DNSMOS P.835 ([sig_bak_ovr.onnx](https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx)), then save their path for the next step configuration (i.e. #2 and #3 TODO).
+ - Creat the access token to pyannote/speaker-diarization-3.1 following [the guide](https://huggingface.co/pyannote/speaker-diarization-3.1#requirements), then save it for the next step configuration (i.e. #4 TODO).
+ - Make sure you have stable connection to GitHub and HuggingFace. The checkpoints of Silero and Whisperx-medium will be downloaded automatically on the pipeline's first run.
+
+
+### 1. Modify Config File
+
+Change the config.json file according to the following TODOs.
+
+```json
+{
+ "language": {
+ "multilingual": true,
+ "supported": [
+ "zh",
+ "en",
+ "fr",
+ "ja",
+ "ko",
+ "de"
+ ]
+ },
+ "entrypoint": {
+ // TODO: Fill in the input_folder_path.
+ "input_folder_path": "examples", // #1: Data input folder for processing
+ "SAMPLE_RATE": 24000
+ },
+ "separate": {
+ "step1": {
+ // TODO: Fill in the source separation model's path.
+ "model_path": "/path/to/model/separate_model/UVR-MDX-NET-Inst_HQ_3.onnx", // #2: Model path
+ "denoise": true,
+ "margin": 44100,
+ "chunks": 15,
+ "n_fft": 6144,
+ "dim_t": 8,
+ "dim_f": 3072
+ }
+ },
+ "mos_model": {
+ // TODO: Fill in the DNSMOS prediction model's path.
+ "primary_model_path": "/path/to/model/mos_model/DNSMOS/sig_bak_ovr.onnx" // #3: Model path
+ },
+ // TODO: Fill in your huggingface access token for pynannote.
+ "huggingface_token": "" // #4: Huggingface access token for pyannote
+}
+```
+
+### 2. Run Script
+
+1. Change the `input_folder_path` in `config.json` to the folder path where the downloaded audio files are stored (i.e. #1 TODO).
+2. Run the following command to process the audio files:
+
+```bash
+conda activate AudioPipeline
+export CUDA_VISIBLE_DEVICES=0 # Setting the GPU to run the pipeline, separate by comma
+
+python main.py
+```
+
+3. Processed audio will be saved into `input_folder_path`_processed folder.
+
+
+### 3. Check the Results
+
+The processed audio (default 24k sample rate) files will be saved into `input_folder_path`_processed folder. The results for a single audio will be saved in a same folder with its original name and include the following information:
+
+1. **MP3 file**: `_.mp3` where `idx` is corresponding to the index in the JSON-encoded array.
+2. **JSON file**: `.json`
+
+```json
+[
+ {
+ "text": "So, don't worry about that. But, like for instance, like yesterday was very hard for me to say, you know what, I should go to bed.", // Transcription
+ "start": 67.18, // Start timestamp, in second unit
+ "end": 74.41, // End timestamp, in second unit
+ "language": "en", // Language
+ "dnsmos": 3.44 // DNSMOS P.835 score
+ }
+]
+```
+
+## TODOs 📝
+
+Here are some potential improvements for the Emilia-Pipe pipeline:
+
+- [x] Optimize the pipeline for better processing speed.
+- [ ] Support input audio files larger than 4GB (calculated in WAVE format).
+- [ ] Update source separation model to better handle noisy audio (e.g., reverberation).
+- [ ] Ensure single speaker in each segment in the speaker diarization step.
+- [ ] Move VAD to the first step to filter out non-speech segments. (for better speed)
+- [ ] Extend ASR supported max length over 30s while keeping the speed.
+- [ ] Fine-tune the ASR model to improve transcription accuracy on puctuation.
+- [ ] Adding multimodal features to the pipeline for better transcription accuracy.
+- [ ] Filter segments with unclean background noise, speaker overlap, hallucination transcriptions, etc.
+- [ ] Labeling the data: speaker info (e.g., gender, age, native language, health), emotion, speaking style (pitch, rate, accent), acoustic features (e.g., fundamental frequency, formants), and environmental factors (background noise, microphone setup). Besides, non-verbal cues (e.g., laughter, coughing, silence, filters) and paralinguistic features could be labeled as well.
+
+## Acknowledgement 🔔
+We acknowledge the wonderful work by these excellent developers!
+- Source Separation: [UVR-MDX-NET-Inst_HQ_3](https://github.com/TRvlvr/model_repo/releases/tag/all_public_uvr_models)
+- VAD: [snakers4/silero-vad](https://github.com/snakers4/silero-vad)
+- Speaker Diarization: [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1)
+- ASR: [m-bain/whisperX](https://github.com/m-bain/whisperX), using [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2) backend.
+- DNSMOS Prediction: [DNSMOS P.835](https://github.com/microsoft/DNS-Challenge)
+
+
+## Reference 📖
+If you use the Emilia dataset or the Emilia-Pipe pipeline, please cite the following papers:
+```bibtex
+@inproceedings{emilia,
+ author={He, Haorui and Shang, Zengqiang and Wang, Chaoren and Li, Xuyuan and Gu, Yicheng and Hua, Hua and Liu, Liwei and Yang, Chen and Li, Jiaqi and Shi, Peiyang and Wang, Yuancheng and Chen, Kai and Zhang, Pengyuan and Wu, Zhizheng},
+ title={Emilia: An Extensive, Multilingual, and Diverse Speech Dataset for Large-Scale Speech Generation},
+ booktitle={Proc.~of SLT},
+ year={2024}
+}
+```
+```bibtex
+@inproceedings{amphion,
+ author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and Li, Jiaqi and He, Haorui and Wang, Chaoren and Song, Ting and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
+ title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
+ booktitle={{IEEE} Spoken Language Technology Workshop, {SLT} 2024},
+ year={2024}
+}
+```
diff --git a/preprocessors/Emilia/config.json b/preprocessors/Emilia/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..bf5da33298e2cfd801945e43722b84660f50b864
--- /dev/null
+++ b/preprocessors/Emilia/config.json
@@ -0,0 +1,35 @@
+{
+ "language": {
+ "multilingual": true,
+ "supported": [
+ "zh",
+ "en",
+ "fr",
+ "ja",
+ "ko",
+ "de"
+ ]
+ },
+ "entrypoint": {
+ // TODO: Fill in the input_folder_path.
+ "input_folder_path": "examples",
+ "SAMPLE_RATE": 24000
+ },
+ "separate": {
+ "step1": {
+ // TODO: Fill in the source separation model's path.
+ "model_path": "/path/to/model/separate_model/UVR-MDX-NET-Inst_HQ_3.onnx",
+ "denoise": true,
+ "margin": 44100,
+ "chunks": 15,
+ "n_fft": 6144,
+ "dim_t": 8,
+ "dim_f": 3072
+ }
+ },
+ "mos_model": {
+ // TODO: Fill in the DNSMOS prediction model's path.
+ "primary_model_path": "/path/to/model/mos_model/DNSMOS/sig_bak_ovr.onnx"
+ },
+ "huggingface_token": ""
+}
\ No newline at end of file
diff --git a/preprocessors/Emilia/env.sh b/preprocessors/Emilia/env.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bbc4b1d27e67d5ad69bca401f14de3153e7eef33
--- /dev/null
+++ b/preprocessors/Emilia/env.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+conda install ffmpeg -y
+conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
+pip install -r requirements.txt
+pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
diff --git a/preprocessors/Emilia/main.py b/preprocessors/Emilia/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1663c313cdc5a20250e2560ade5ba8136a1e1eb
--- /dev/null
+++ b/preprocessors/Emilia/main.py
@@ -0,0 +1,571 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import json
+import librosa
+import numpy as np
+import sys
+import os
+import tqdm
+import warnings
+import torch
+from pydub import AudioSegment
+from pyannote.audio import Pipeline
+import pandas as pd
+
+from utils.tool import (
+ export_to_mp3,
+ load_cfg,
+ get_audio_files,
+ detect_gpu,
+ check_env,
+ calculate_audio_stats,
+)
+from utils.logger import Logger, time_logger
+from models import separate_fast, dnsmos, whisper_asr, silero_vad
+
+warnings.filterwarnings("ignore")
+audio_count = 0
+
+
+@time_logger
+def standardization(audio):
+ """
+ Preprocess the audio file, including setting sample rate, bit depth, channels, and volume normalization.
+
+ Args:
+ audio (str or AudioSegment): Audio file path or AudioSegment object, the audio to be preprocessed.
+
+ Returns:
+ dict: A dictionary containing the preprocessed audio waveform, audio file name, and sample rate, formatted as:
+ {
+ "waveform": np.ndarray, the preprocessed audio waveform, dtype is np.float32, shape is (num_samples,)
+ "name": str, the audio file name
+ "sample_rate": int, the audio sample rate
+ }
+
+ Raises:
+ ValueError: If the audio parameter is neither a str nor an AudioSegment.
+ """
+ global audio_count
+ name = "audio"
+
+ if isinstance(audio, str):
+ name = os.path.basename(audio)
+ audio = AudioSegment.from_file(audio)
+ elif isinstance(audio, AudioSegment):
+ name = f"audio_{audio_count}"
+ audio_count += 1
+ else:
+ raise ValueError("Invalid audio type")
+
+ logger.debug("Entering the preprocessing of audio")
+
+ # Convert the audio file to WAV format
+ audio = audio.set_frame_rate(cfg["entrypoint"]["SAMPLE_RATE"])
+ audio = audio.set_sample_width(2) # Set bit depth to 16bit
+ audio = audio.set_channels(1) # Set to mono
+
+ logger.debug("Audio file converted to WAV format")
+
+ # Calculate the gain to be applied
+ target_dBFS = -20
+ gain = target_dBFS - audio.dBFS
+ logger.info(f"Calculating the gain needed for the audio: {gain} dB")
+
+ # Normalize volume and limit gain range to between -3 and 3
+ normalized_audio = audio.apply_gain(min(max(gain, -3), 3))
+
+ waveform = np.array(normalized_audio.get_array_of_samples(), dtype=np.float32)
+ max_amplitude = np.max(np.abs(waveform))
+ waveform /= max_amplitude # Normalize
+
+ logger.debug(f"waveform shape: {waveform.shape}")
+ logger.debug("waveform in np ndarray, dtype=" + str(waveform.dtype))
+
+ return {
+ "waveform": waveform,
+ "name": name,
+ "sample_rate": cfg["entrypoint"]["SAMPLE_RATE"],
+ }
+
+
+@time_logger
+def source_separation(predictor, audio):
+ """
+ Separate the audio into vocals and non-vocals using the given predictor.
+
+ Args:
+ predictor: The separation model predictor.
+ audio (str or dict): The audio file path or a dictionary containing audio waveform and sample rate.
+
+ Returns:
+ dict: A dictionary containing the separated vocals and updated audio waveform.
+ """
+
+ mix, rate = None, None
+
+ if isinstance(audio, str):
+ mix, rate = librosa.load(audio, mono=False, sr=44100)
+ else:
+ # resample to 44100
+ rate = audio["sample_rate"]
+ mix = librosa.resample(audio["waveform"], orig_sr=rate, target_sr=44100)
+
+ vocals, no_vocals = predictor.predict(mix)
+
+ # convert vocals back to previous sample rate
+ logger.debug(f"vocals shape before resample: {vocals.shape}")
+ vocals = librosa.resample(vocals.T, orig_sr=44100, target_sr=rate).T
+ logger.debug(f"vocals shape after resample: {vocals.shape}")
+ audio["waveform"] = vocals[:, 0] # vocals is stereo, only use one channel
+
+ return audio
+
+
+# Step 2: Speaker Diarization
+@time_logger
+def speaker_diarization(audio):
+ """
+ Perform speaker diarization on the given audio.
+
+ Args:
+ audio (dict): A dictionary containing the audio waveform and sample rate.
+
+ Returns:
+ pd.DataFrame: A dataframe containing segments with speaker labels.
+ """
+ logger.debug(f"Start speaker diarization")
+ logger.debug(f"audio waveform shape: {audio['waveform'].shape}")
+
+ waveform = torch.tensor(audio["waveform"]).to(device)
+ waveform = torch.unsqueeze(waveform, 0)
+
+ segments = dia_pipeline(
+ {
+ "waveform": waveform,
+ "sample_rate": audio["sample_rate"],
+ "channel": 0,
+ }
+ )
+
+ diarize_df = pd.DataFrame(
+ segments.itertracks(yield_label=True),
+ columns=["segment", "label", "speaker"],
+ )
+ diarize_df["start"] = diarize_df["segment"].apply(lambda x: x.start)
+ diarize_df["end"] = diarize_df["segment"].apply(lambda x: x.end)
+
+ logger.debug(f"diarize_df: {diarize_df}")
+
+ return diarize_df
+
+
+@time_logger
+def cut_by_speaker_label(vad_list):
+ """
+ Merge and trim VAD segments by speaker labels, enforcing constraints on segment length and merge gaps.
+
+ Args:
+ vad_list (list): List of VAD segments with start, end, and speaker labels.
+
+ Returns:
+ list: A list of updated VAD segments after merging and trimming.
+ """
+ MERGE_GAP = 2 # merge gap in seconds, if smaller than this, merge
+ MIN_SEGMENT_LENGTH = 3 # min segment length in seconds
+ MAX_SEGMENT_LENGTH = 30 # max segment length in seconds
+
+ updated_list = []
+
+ for idx, vad in enumerate(vad_list):
+ last_start_time = updated_list[-1]["start"] if updated_list else None
+ last_end_time = updated_list[-1]["end"] if updated_list else None
+ last_speaker = updated_list[-1]["speaker"] if updated_list else None
+
+ if vad["end"] - vad["start"] >= MAX_SEGMENT_LENGTH:
+ current_start = vad["start"]
+ segment_end = vad["end"]
+ logger.warning(
+ f"cut_by_speaker_label > segment longer than 30s, force trimming to 30s smaller segments"
+ )
+ while segment_end - current_start >= MAX_SEGMENT_LENGTH:
+ vad["end"] = current_start + MAX_SEGMENT_LENGTH # update end time
+ updated_list.append(vad)
+ vad = vad.copy()
+ current_start += MAX_SEGMENT_LENGTH
+ vad["start"] = current_start # update start time
+ vad["end"] = segment_end
+ updated_list.append(vad)
+ continue
+
+ if (
+ last_speaker is None
+ or last_speaker != vad["speaker"]
+ or vad["end"] - vad["start"] >= MIN_SEGMENT_LENGTH
+ ):
+ updated_list.append(vad)
+ continue
+
+ if (
+ vad["start"] - last_end_time >= MERGE_GAP
+ or vad["end"] - last_start_time >= MAX_SEGMENT_LENGTH
+ ):
+ updated_list.append(vad)
+ else:
+ updated_list[-1]["end"] = vad["end"] # merge the time
+
+ logger.debug(
+ f"cut_by_speaker_label > merged {len(vad_list) - len(updated_list)} segments"
+ )
+
+ filter_list = [
+ vad for vad in updated_list if vad["end"] - vad["start"] >= MIN_SEGMENT_LENGTH
+ ]
+
+ logger.debug(
+ f"cut_by_speaker_label > removed: {len(updated_list) - len(filter_list)} segments by length"
+ )
+
+ return filter_list
+
+
+@time_logger
+def asr(vad_segments, audio):
+ """
+ Perform Automatic Speech Recognition (ASR) on the VAD segments of the given audio.
+
+ Args:
+ vad_segments (list): List of VAD segments with start and end times.
+ audio (dict): A dictionary containing the audio waveform and sample rate.
+
+ Returns:
+ list: A list of ASR results with transcriptions and language details.
+ """
+ if len(vad_segments) == 0:
+ return []
+
+ temp_audio = audio["waveform"]
+ start_time = vad_segments[0]["start"]
+ end_time = vad_segments[-1]["end"]
+ start_frame = int(start_time * audio["sample_rate"])
+ end_frame = int(end_time * audio["sample_rate"])
+ temp_audio = temp_audio[start_frame:end_frame] # remove silent start and end
+
+ # update vad_segments start and end time (this is a little trick for batched asr:)
+ for idx, segment in enumerate(vad_segments):
+ vad_segments[idx]["start"] -= start_time
+ vad_segments[idx]["end"] -= start_time
+
+ # resample to 16k
+ temp_audio = librosa.resample(
+ temp_audio, orig_sr=audio["sample_rate"], target_sr=16000
+ )
+
+ if multilingual_flag:
+ logger.debug("Multilingual flag is on")
+ valid_vad_segments, valid_vad_segments_language = [], []
+ # get valid segments to be transcripted
+ for idx, segment in enumerate(vad_segments):
+ start_frame = int(segment["start"] * 16000)
+ end_frame = int(segment["end"] * 16000)
+ segment_audio = temp_audio[start_frame:end_frame]
+ language, prob = asr_model.detect_language(segment_audio)
+ # 1. if language is in supported list, 2. if prob > 0.8
+ if language in supported_languages and prob > 0.8:
+ valid_vad_segments.append(vad_segments[idx])
+ valid_vad_segments_language.append(language)
+
+ # if no valid segment, return empty
+ if len(valid_vad_segments) == 0:
+ return []
+ all_transcribe_result = []
+ logger.debug(f"valid_vad_segments_language: {valid_vad_segments_language}")
+ unique_languages = list(set(valid_vad_segments_language))
+ logger.debug(f"unique_languages: {unique_languages}")
+ # process each language one by one
+ for language_token in unique_languages:
+ language = language_token
+ # filter out segments with different language
+ vad_segments = [
+ valid_vad_segments[i]
+ for i, x in enumerate(valid_vad_segments_language)
+ if x == language
+ ]
+ # bacthed trascription
+ transcribe_result_temp = asr_model.transcribe(
+ temp_audio,
+ vad_segments,
+ batch_size=batch_size,
+ language=language,
+ print_progress=True,
+ )
+ result = transcribe_result_temp["segments"]
+ # restore the segment annotation
+ for idx, segment in enumerate(result):
+ result[idx]["start"] += start_time
+ result[idx]["end"] += start_time
+ result[idx]["language"] = transcribe_result_temp["language"]
+ all_transcribe_result.extend(result)
+ # sort by start time
+ all_transcribe_result = sorted(all_transcribe_result, key=lambda x: x["start"])
+ return all_transcribe_result
+ else:
+ logger.debug("Multilingual flag is off")
+ language, prob = asr_model.detect_language(temp_audio)
+ if language in supported_languages and prob > 0.8:
+ transcribe_result = asr_model.transcribe(
+ temp_audio,
+ vad_segments,
+ batch_size=batch_size,
+ language=language,
+ print_progress=True,
+ )
+ result = transcribe_result["segments"]
+ for idx, segment in enumerate(result):
+ result[idx]["start"] += start_time
+ result[idx]["end"] += start_time
+ result[idx]["language"] = transcribe_result["language"]
+ return result
+ else:
+ return []
+
+
+@time_logger
+def mos_prediction(audio, vad_list):
+ """
+ Predict the Mean Opinion Score (MOS) for the given audio and VAD segments.
+
+ Args:
+ audio (dict): A dictionary containing the audio waveform and sample rate.
+ vad_list (list): List of VAD segments with start and end times.
+
+ Returns:
+ tuple: A tuple containing the average MOS and the updated VAD segments with MOS scores.
+ """
+ audio = audio["waveform"]
+ sample_rate = 16000
+
+ audio = librosa.resample(
+ audio, orig_sr=cfg["entrypoint"]["SAMPLE_RATE"], target_sr=sample_rate
+ )
+
+ for index, vad in enumerate(tqdm.tqdm(vad_list, desc="DNSMOS")):
+ start, end = int(vad["start"] * sample_rate), int(vad["end"] * sample_rate)
+ segment = audio[start:end]
+
+ dnsmos = dnsmos_compute_score(segment, sample_rate, False)["OVRL"]
+
+ vad_list[index]["dnsmos"] = dnsmos
+
+ predict_dnsmos = np.mean([vad["dnsmos"] for vad in vad_list])
+
+ logger.debug(f"avg predict_dnsmos for whole audio: {predict_dnsmos}")
+
+ return predict_dnsmos, vad_list
+
+
+def filter(mos_list):
+ """
+ Filter out the segments with MOS scores, wrong char duration, and total duration.
+
+ Args:
+ mos_list (list): List of VAD segments with MOS scores.
+
+ Returns:
+ list: A list of VAD segments with MOS scores above the average MOS.
+ """
+ filtered_audio_stats, all_audio_stats = calculate_audio_stats(mos_list)
+ filtered_segment = len(filtered_audio_stats)
+ all_segment = len(all_audio_stats)
+ logger.debug(
+ f"> {all_segment - filtered_segment}/{all_segment} {(all_segment - filtered_segment) / all_segment:.2%} segments filtered."
+ )
+ filtered_list = [mos_list[idx] for idx, _ in filtered_audio_stats]
+ return filtered_list
+
+
+def main_process(audio_path, save_path=None, audio_name=None):
+ """
+ Process the audio file, including standardization, source separation, speaker segmentation, VAD, ASR, export to MP3, and MOS prediction.
+
+ Args:
+ audio_path (str): Audio file path.
+ save_path (str, optional): Save path, defaults to None, which means saving in the "_processed" folder in the audio file's directory.
+ audio_name (str, optional): Audio file name, defaults to None, which means using the file name from the audio file path.
+
+ Returns:
+ tuple: Contains the save path and the MOS list.
+ """
+ if not audio_path.endswith((".mp3", ".wav", ".flac", ".m4a", ".aac")):
+ logger.warning(f"Unsupported file type: {audio_path}")
+
+ # for a single audio from path Ïaaa/bbb/ccc.wav ---> save to aaa/bbb_processed/ccc/ccc_0.wav
+ audio_name = audio_name or os.path.splitext(os.path.basename(audio_path))[0]
+ save_path = save_path or os.path.join(
+ os.path.dirname(audio_path) + "_processed", audio_name
+ )
+ os.makedirs(save_path, exist_ok=True)
+ logger.debug(
+ f"Processing audio: {audio_name}, from {audio_path}, save to: {save_path}"
+ )
+
+ logger.info(
+ "Step 0: Preprocess all audio files --> 24k sample rate + wave format + loudnorm + bit depth 16"
+ )
+ audio = standardization(audio_path)
+
+ logger.info("Step 1: Source Separation")
+ audio = source_separation(separate_predictor1, audio)
+
+ logger.info("Step 2: Speaker Diarization")
+ speakerdia = speaker_diarization(audio)
+
+ logger.info("Step 3: Fine-grained Segmentation by VAD")
+ vad_list = vad.vad(speakerdia, audio)
+ segment_list = cut_by_speaker_label(vad_list) # post process after vad
+
+ logger.info("Step 4: ASR")
+ asr_result = asr(segment_list, audio)
+
+ logger.info("Step 5: Filter")
+ logger.info("Step 5.1: calculate mos_prediction")
+ avg_mos, mos_list = mos_prediction(audio, asr_result)
+
+ logger.info(f"Step 5.1: done, average MOS: {avg_mos}")
+
+ logger.info("Step 5.2: Filter out files with less than average MOS")
+ filtered_list = filter(mos_list)
+
+ logger.info("Step 6: write result into MP3 and JSON file")
+ export_to_mp3(audio, filtered_list, save_path, audio_name)
+
+ final_path = os.path.join(save_path, audio_name + ".json")
+ with open(final_path, "w") as f:
+ json.dump(filtered_list, f, ensure_ascii=False)
+
+ logger.info(f"All done, Saved to: {final_path}")
+ return final_path, filtered_list
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input_folder_path",
+ type=str,
+ default="",
+ help="input folder path, this will override config if set",
+ )
+ parser.add_argument(
+ "--config_path", type=str, default="config.json", help="config path"
+ )
+ parser.add_argument("--batch_size", type=int, default=16, help="batch size")
+ parser.add_argument(
+ "--compute_type",
+ type=str,
+ default="float16",
+ help="The compute type to use for the model",
+ )
+ parser.add_argument(
+ "--whisper_arch",
+ type=str,
+ default="medium",
+ help="The name of the Whisper model to load.",
+ )
+ parser.add_argument(
+ "--threads",
+ type=int,
+ default=4,
+ help="The number of CPU threads to use per worker, e.g. will be multiplied by num workers.",
+ )
+ parser.add_argument(
+ "--exit_pipeline",
+ type=bool,
+ default=False,
+ help="Exit pipeline when task done.",
+ )
+ args = parser.parse_args()
+
+ batch_size = args.batch_size
+ cfg = load_cfg(args.config_path)
+
+ logger = Logger.get_logger()
+
+ if args.input_folder_path:
+ logger.info(f"Using input folder path: {args.input_folder_path}")
+ cfg["entrypoint"]["input_folder_path"] = args.input_folder_path
+
+ logger.debug("Loading models...")
+
+ # Load models
+ if detect_gpu():
+ logger.info("Using GPU")
+ device_name = "cuda"
+ device = torch.device(device_name)
+ else:
+ logger.info("Using CPU")
+ device_name = "cpu"
+ device = torch.device(device_name)
+
+ check_env(logger)
+
+ # Speaker Diarization
+ logger.debug(" * Loading Speaker Diarization Model")
+ if not cfg["huggingface_token"].startswith("hf"):
+ raise ValueError(
+ "huggingface_token must start with 'hf', check the config file. "
+ "You can get the token at https://huggingface.co/settings/tokens. "
+ "Remeber grant access following https://github.com/pyannote/pyannote-audio?tab=readme-ov-file#tldr"
+ )
+ dia_pipeline = Pipeline.from_pretrained(
+ "pyannote/speaker-diarization-3.1",
+ use_auth_token=cfg["huggingface_token"],
+ )
+ dia_pipeline.to(device)
+
+ # ASR
+ logger.debug(" * Loading ASR Model")
+ asr_model = whisper_asr.load_asr_model(
+ args.whisper_arch,
+ device_name,
+ compute_type=args.compute_type,
+ threads=args.threads,
+ asr_options={
+ "initial_prompt": "Um, Uh, Ah. Like, you know. I mean, right. Actually. Basically, and right? okay. Alright. Emm. So. Oh. 生于忧患,死于安乐。岂不快哉?当然,嗯,呃,就,这样,那个,哪个,啊,呀,哎呀,哎哟,唉哇,啧,唷,哟,噫!微斯人,吾谁与归?ええと、あの、ま、そう、ええ。äh, hm, so, tja, halt, eigentlich. euh, quoi, bah, ben, tu vois, tu sais, t'sais, eh bien, du coup. genre, comme, style. 응,어,그,음."
+ },
+ )
+
+ # VAD
+ logger.debug(" * Loading VAD Model")
+ vad = silero_vad.SileroVAD(device=device)
+
+ # Background Noise Separation
+ logger.debug(" * Loading Background Noise Model")
+ separate_predictor1 = separate_fast.Predictor(
+ args=cfg["separate"]["step1"], device=device_name
+ )
+
+ # DNSMOS Scoring
+ logger.debug(" * Loading DNSMOS Model")
+ primary_model_path = cfg["mos_model"]["primary_model_path"]
+ dnsmos_compute_score = dnsmos.ComputeScore(primary_model_path, device_name)
+ logger.debug("All models loaded")
+
+ supported_languages = cfg["language"]["supported"]
+ multilingual_flag = cfg["language"]["multilingual"]
+ logger.debug(f"supported languages multilingual {supported_languages}")
+ logger.debug(f"using multilingual asr {multilingual_flag}")
+
+ input_folder_path = cfg["entrypoint"]["input_folder_path"]
+
+ if not os.path.exists(input_folder_path):
+ raise FileNotFoundError(f"input_folder_path: {input_folder_path} not found")
+
+ audio_paths = get_audio_files(input_folder_path) # Get all audio files
+ logger.debug(f"Scanning {len(audio_paths)} audio files in {input_folder_path}")
+
+ for path in audio_paths:
+ main_process(path)
diff --git a/preprocessors/Emilia/main_multi.py b/preprocessors/Emilia/main_multi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f27e7df4925383fcdfb2ff3b15e410a3b4e629b
--- /dev/null
+++ b/preprocessors/Emilia/main_multi.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import multiprocessing
+import os
+import subprocess
+import time
+
+from utils.logger import Logger
+from utils.tool import get_gpu_nums
+
+
+def run_script(args, gpu_id, self_id):
+ """
+ Run the script by passing the GPU ID and self ID to environment variables and execute the main.py script.
+
+ Args:
+ gpu_id (int): ID of the GPU.
+ self_id (int): ID of the process.
+
+ Returns:
+ None
+ """
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+ env["SELF_ID"] = str(self_id)
+
+ command = (
+ f"source {args.conda_path} &&"
+ 'eval "$(conda shell.bash hook)" && '
+ f"conda activate {args.conda_env_name} && "
+ "python main.py"
+ )
+
+ try:
+ process = subprocess.Popen(command, shell=True, env=env, executable="/bin/bash")
+ process.wait()
+ logger.info(f"Process for GPU {gpu_id} completed successfully.")
+ except KeyboardInterrupt:
+ logger.warning(f"Multi - GPU {gpu_id}: Interrupted by keyboard, exiting...")
+ except Exception as e:
+ logger.error(f"Error occurred for GPU {gpu_id}: {e}")
+
+
+def main(args, self_id):
+ """
+ Start multiple script tasks using multiple processes, each process using one GPU.
+
+ Args:
+ self_id (str): Identifier for the current process.
+
+ Returns:
+ None
+ """
+ disabled_ids = []
+ if args.disabled_gpu_ids:
+ disabled_ids = [int(i) for i in args.disabled_gpu_ids.split(",")]
+ logger.info(f"CUDA_DISABLE_ID is set, not using: {disabled_ids}")
+
+ gpus_count = get_gpu_nums()
+
+ available_gpus = [i for i in range(gpus_count) if i not in disabled_ids]
+ processes = []
+
+ for gpu_id in available_gpus:
+ process = multiprocessing.Process(
+ target=run_script, args=(args, gpu_id, self_id)
+ )
+ process.start()
+ logger.info(f"GPU {gpu_id}: started...")
+ time.sleep(1)
+ processes.append(process)
+
+ for process in processes:
+ process.join()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--self_id", type=str, default="main_multi", help="Log ID")
+ parser.add_argument(
+ "--disabled_gpu_ids",
+ type=str,
+ default="",
+ help="Comma-separated list of disabled GPU IDs, default uses all available GPUs",
+ )
+ parser.add_argument(
+ "--conda_path",
+ type=str,
+ default="/opt/conda/etc/profile.d/conda.sh",
+ help="Conda path",
+ )
+ parser.add_argument(
+ "--conda_env_name",
+ type=str,
+ default="AudioPipeline",
+ help="Conda environment name",
+ )
+ parser.add_argument(
+ "--main_command_args",
+ type=str,
+ default="",
+ help="Main command args, check available options by `python main.py --help`",
+ )
+ args = parser.parse_args()
+
+ self_id = args.self_id
+ if "SELF_ID" in os.environ:
+ self_id = f"{self_id}_#{os.environ['SELF_ID']}"
+
+ logger = Logger.get_logger(self_id)
+
+ logger.info(f"Starting main_multi.py with self_id: {self_id}, args: {vars(args)}.")
+ main(args, self_id)
+ logger.info("Exiting main_multi.py...")
diff --git a/preprocessors/Emilia/models/__init__.py b/preprocessors/Emilia/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocessors/Emilia/models/dnsmos.py b/preprocessors/Emilia/models/dnsmos.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b17f196569d61968a2fb3867b4da75048b3ef6c
--- /dev/null
+++ b/preprocessors/Emilia/models/dnsmos.py
@@ -0,0 +1,174 @@
+# Source: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS
+#
+# Copyright (c) 2022 Microsoft
+#
+# This code is licensed under the Creative Commons Attribution 4.0 International (CC BY 4.0) license.
+# The full license text is available at the root of the source repository.
+#
+# Note: This code has been modified to fit the context of this repository.
+# This code is included in an MIT-licensed repository.
+# The repository's MIT license does not apply to this code.
+
+import os
+import librosa
+import numpy as np
+import onnxruntime as ort
+import pandas as pd
+import tqdm
+import warnings
+
+
+warnings.filterwarnings("ignore")
+
+SAMPLING_RATE = 16000
+INPUT_LENGTH = 9.01
+
+
+class ComputeScore:
+ """
+ ComputeScore class for evaluating DNSMOS.
+ """
+
+ def __init__(self, primary_model_path, device="cpu") -> None:
+ """
+ Initialize the ComputeScore object.
+
+ Args:
+ primary_model_path (str): Path to the primary model.
+ device (str): Device to run the models on ('cpu' or 'cuda').
+
+ Returns:
+ None
+
+ Raises:
+ RuntimeError: If the device is not supported.
+ """
+ if device == "cuda":
+ self.onnx_sess = ort.InferenceSession(
+ primary_model_path, providers=["CUDAExecutionProvider"]
+ )
+ print("Using CUDA:", self.onnx_sess.get_providers())
+ else:
+ self.onnx_sess = ort.InferenceSession(primary_model_path)
+
+ def audio_melspec(
+ self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True
+ ):
+ """
+ Compute the mel spectrogram of an audio signal.
+
+ Args:
+ audio (np.ndarray): Input audio signal.
+ n_mels (int): Number of mel bands.
+ frame_size (int): Size of the FFT window.
+ hop_length (int): Number of samples between successive frames.
+ sr (int): Sampling rate.
+ to_db (bool): Whether to convert the power spectrogram to decibel units.
+
+ Returns:
+ np.ndarray: Mel spectrogram.
+ """
+ mel_spec = librosa.feature.melspectrogram(
+ y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels
+ )
+ if to_db:
+ mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40
+ return mel_spec.T
+
+ def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
+ """
+ Apply polynomial fitting to MOS scores.
+
+ Args:
+ sig (float): Signal MOS score.
+ bak (float): Background MOS score.
+ ovr (float): Overall MOS score.
+ is_personalized_MOS (bool): Flag for personalized MOS.
+
+ Returns:
+ tuple: Tuple containing the adjusted signal, background, and overall MOS scores.
+ """
+ if is_personalized_MOS:
+ p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046])
+ p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
+ p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132])
+ else:
+ p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
+ p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
+ p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
+
+ sig_poly = p_sig(sig)
+ bak_poly = p_bak(bak)
+ ovr_poly = p_ovr(ovr)
+
+ return sig_poly, bak_poly, ovr_poly
+
+ def __call__(self, audio, sampling_rate, is_personalized_MOS):
+ """
+ Compute DNSMOS scores for an audio signal.
+
+ Args:
+ audio (np.ndarray or str): Input audio signal or path to audio file.
+ sampling_rate (int): Sampling rate of the input audio.
+ is_personalized_MOS (bool): Flag for personalized MOS.
+
+ Returns:
+ dict: Dictionary containing MOS scores.
+
+ Raises:
+ ValueError: If the input audio is not valid.
+ """
+ fs = SAMPLING_RATE
+ if isinstance(audio, str):
+ audio, _ = librosa.load(audio, sr=fs)
+ elif sampling_rate != fs:
+ # resample audio
+ audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=fs)
+
+ actual_audio_len = len(audio)
+
+ len_samples = int(INPUT_LENGTH * fs)
+ while len(audio) < len_samples:
+ audio = np.append(audio, audio)
+
+ num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1
+ hop_len_samples = fs
+ predicted_mos_sig_seg_raw = []
+ predicted_mos_bak_seg_raw = []
+ predicted_mos_ovr_seg_raw = []
+ predicted_mos_sig_seg = []
+ predicted_mos_bak_seg = []
+ predicted_mos_ovr_seg = []
+
+ for idx in range(num_hops):
+ audio_seg = audio[
+ int(idx * hop_len_samples) : int((idx + INPUT_LENGTH) * hop_len_samples)
+ ]
+ if len(audio_seg) < len_samples:
+ continue
+ input_features = np.array(audio_seg).astype("float32")[np.newaxis, :]
+ oi = {"input_1": input_features}
+ mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]
+ mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(
+ mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS
+ )
+ predicted_mos_sig_seg_raw.append(mos_sig_raw)
+ predicted_mos_bak_seg_raw.append(mos_bak_raw)
+ predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
+ predicted_mos_sig_seg.append(mos_sig)
+ predicted_mos_bak_seg.append(mos_bak)
+ predicted_mos_ovr_seg.append(mos_ovr)
+
+ clip_dict = {
+ "filename": "audio_clip",
+ "len_in_sec": actual_audio_len / fs,
+ "sr": fs,
+ "num_hops": num_hops,
+ "OVRL_raw": np.mean(predicted_mos_ovr_seg_raw),
+ "SIG_raw": np.mean(predicted_mos_sig_seg_raw),
+ "BAK_raw": np.mean(predicted_mos_bak_seg_raw),
+ "OVRL": np.mean(predicted_mos_ovr_seg),
+ "SIG": np.mean(predicted_mos_sig_seg),
+ "BAK": np.mean(predicted_mos_bak_seg),
+ }
+ return clip_dict
diff --git a/preprocessors/Emilia/models/separate_fast.py b/preprocessors/Emilia/models/separate_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..d761dd8d114ffcbfd1f63bcde83b3d6723763433
--- /dev/null
+++ b/preprocessors/Emilia/models/separate_fast.py
@@ -0,0 +1,293 @@
+# Copyright (c) 2023 seanghay
+#
+# This code is from an unliscensed repository.
+#
+# Note: This code has been modified to fit the context of this repository.
+# This code is included in an MIT-licensed repository.
+# The repository's MIT license does not apply to this code.
+
+# This code is modified from https://github.com/seanghay/uvr-mdx-infer/blob/main/separate.py
+
+import torch
+import numpy as np
+import onnxruntime as ort
+from tqdm import tqdm
+
+
+class ConvTDFNet:
+ """
+ ConvTDFNet - Convolutional Temporal Frequency Domain Network.
+ """
+
+ def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
+ """
+ Initialize ConvTDFNet.
+
+ Args:
+ target_name (str): The target name for separation.
+ L (int): Number of layers.
+ dim_f (int): Dimension in the frequency domain.
+ dim_t (int): Dimension in the time domain (log2).
+ n_fft (int): FFT size.
+ hop (int, optional): Hop size. Defaults to 1024.
+
+ Returns:
+ None
+ """
+ super(ConvTDFNet, self).__init__()
+ self.dim_c = 4
+ self.dim_f = dim_f
+ self.dim_t = 2**dim_t
+ self.n_fft = n_fft
+ self.hop = hop
+ self.n_bins = self.n_fft // 2 + 1
+ self.chunk_size = hop * (self.dim_t - 1)
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+ self.target_name = target_name
+
+ out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
+
+ self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t])
+ self.n = L // 2
+
+ def stft(self, x):
+ """
+ Perform Short-Time Fourier Transform (STFT).
+
+ Args:
+ x (torch.Tensor): Input waveform.
+
+ Returns:
+ torch.Tensor: STFT of the input waveform.
+ """
+ x = x.reshape([-1, self.chunk_size])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop,
+ window=self.window,
+ center=True,
+ return_complex=True,
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+ [-1, self.dim_c, self.n_bins, self.dim_t]
+ )
+ return x[:, :, : self.dim_f]
+
+ def istft(self, x, freq_pad=None):
+ """
+ Perform Inverse Short-Time Fourier Transform (ISTFT).
+
+ Args:
+ x (torch.Tensor): Input STFT.
+ freq_pad (torch.Tensor, optional): Frequency padding. Defaults to None.
+
+ Returns:
+ torch.Tensor: Inverse STFT of the input.
+ """
+ freq_pad = (
+ self.freq_pad.repeat([x.shape[0], 1, 1, 1])
+ if freq_pad is None
+ else freq_pad
+ )
+ x = torch.cat([x, freq_pad], -2)
+ c = 4 * 2 if self.target_name == "*" else 2
+ x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
+ [-1, 2, self.n_bins, self.dim_t]
+ )
+ x = x.permute([0, 2, 3, 1])
+ x = x.contiguous()
+ x = torch.view_as_complex(x)
+ x = torch.istft(
+ x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
+ )
+ return x.reshape([-1, c, self.chunk_size])
+
+
+class Predictor:
+ """
+ Predictor class for source separation using ConvTDFNet and ONNX Runtime.
+ """
+
+ def __init__(self, args, device):
+ """
+ Initialize the Predictor.
+
+ Args:
+ args (dict): Configuration arguments.
+ device (str): Device to run the model ('cuda' or 'cpu').
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If the provided device is not 'cuda' or 'cpu'.
+ """
+ self.args = args
+ self.model_ = ConvTDFNet(
+ target_name="vocals",
+ L=11,
+ dim_f=args["dim_f"],
+ dim_t=args["dim_t"],
+ n_fft=args["n_fft"],
+ )
+
+ if device == "cuda":
+ self.model = ort.InferenceSession(
+ args["model_path"], providers=["CUDAExecutionProvider"]
+ )
+ elif device == "cpu":
+ self.model = ort.InferenceSession(
+ args["model_path"], providers=["CPUExecutionProvider"]
+ )
+ else:
+ raise ValueError("Device must be either 'cuda' or 'cpu'")
+
+ def demix(self, mix):
+ """
+ Separate the sources from the input mix.
+
+ Args:
+ mix (np.ndarray): Input mixture signal.
+
+ Returns:
+ np.ndarray: Separated sources.
+
+ Raises:
+ AssertionError: If margin is zero.
+ """
+ samples = mix.shape[-1]
+ margin = self.args["margin"]
+ chunk_size = self.args["chunks"] * 44100
+
+ assert margin != 0, "Margin cannot be zero!"
+
+ if margin > chunk_size:
+ margin = chunk_size
+
+ segmented_mix = {}
+
+ if self.args["chunks"] == 0 or samples < chunk_size:
+ chunk_size = samples
+
+ counter = -1
+ for skip in range(0, samples, chunk_size):
+ counter += 1
+ s_margin = 0 if counter == 0 else margin
+ end = min(skip + chunk_size + margin, samples)
+ start = skip - s_margin
+ segmented_mix[skip] = mix[:, start:end].copy()
+ if end == samples:
+ break
+
+ sources = self.demix_base(segmented_mix, margin_size=margin)
+ return sources
+
+ def demix_base(self, mixes, margin_size):
+ """
+ Base function for source separation.
+
+ Args:
+ mixes (dict): Dictionary of segmented mixtures.
+ margin_size (int): Size of the margin.
+
+ Returns:
+ np.ndarray: Separated sources.
+ """
+ chunked_sources = []
+ progress_bar = tqdm(total=len(mixes))
+ progress_bar.set_description("Source separation")
+
+ for mix in mixes:
+ cmix = mixes[mix]
+ sources = []
+ n_sample = cmix.shape[1]
+ model = self.model_
+ trim = model.n_fft // 2
+ gen_size = model.chunk_size - 2 * trim
+ pad = gen_size - n_sample % gen_size
+ mix_p = np.concatenate(
+ (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
+ )
+ mix_waves = []
+ i = 0
+ while i < n_sample + pad:
+ waves = np.array(mix_p[:, i : i + model.chunk_size])
+ mix_waves.append(waves)
+ i += gen_size
+
+ mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32)
+
+ with torch.no_grad():
+ _ort = self.model
+ spek = model.stft(mix_waves)
+ if self.args["denoise"]:
+ spec_pred = (
+ -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
+ + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
+ )
+ tar_waves = model.istft(torch.tensor(spec_pred))
+ else:
+ tar_waves = model.istft(
+ torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
+ )
+ tar_signal = (
+ tar_waves[:, :, trim:-trim]
+ .transpose(0, 1)
+ .reshape(2, -1)
+ .numpy()[:, :-pad]
+ )
+
+ start = 0 if mix == 0 else margin_size
+ end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
+
+ if margin_size == 0:
+ end = None
+
+ sources.append(tar_signal[:, start:end])
+
+ progress_bar.update(1)
+
+ chunked_sources.append(sources)
+ _sources = np.concatenate(chunked_sources, axis=-1)
+
+ progress_bar.close()
+ return _sources
+
+ def predict(self, mix):
+ """
+ Predict the separated sources from the input mix.
+
+ Args:
+ mix (np.ndarray): Input mixture signal.
+
+ Returns:
+ tuple: Tuple containing the mixture minus the separated sources and the separated sources.
+ """
+ if mix.ndim == 1:
+ mix = np.asfortranarray([mix, mix])
+
+ tail = mix.shape[1] % (self.args["chunks"] * 44100)
+ if mix.shape[1] % (self.args["chunks"] * 44100) != 0:
+ mix = np.pad(
+ mix,
+ (
+ (0, 0),
+ (
+ 0,
+ self.args["chunks"] * 44100
+ - mix.shape[1] % (self.args["chunks"] * 44100),
+ ),
+ ),
+ )
+
+ mix = mix.T
+ sources = self.demix(mix.T)
+ opt = sources[0].T
+
+ if tail != 0:
+ return ((mix - opt)[: -(self.args["chunks"] * 44100 - tail), :], opt)
+ else:
+ return ((mix - opt), opt)
diff --git a/preprocessors/Emilia/models/silero_vad.py b/preprocessors/Emilia/models/silero_vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca9390c46e2a92d51a2e213b238d8187c0480a7a
--- /dev/null
+++ b/preprocessors/Emilia/models/silero_vad.py
@@ -0,0 +1,181 @@
+# Source: https://github.com/snakers4/silero-vad
+#
+# Copyright (c) 2024 snakers4
+#
+# This code is from a MIT-licensed repository. The full license text is available at the root of the source repository.
+#
+# Note: This code has been modified to fit the context of this repository.
+
+import librosa
+import torch
+import numpy as np
+
+VAD_THRESHOLD = 20
+SAMPLING_RATE = 16000
+
+
+class SileroVAD:
+ """
+ Voice Activity Detection (VAD) using Silero-VAD.
+ """
+
+ def __init__(self, local=False, model="silero_vad", device=torch.device("cpu")):
+ """
+ Initialize the VAD object.
+
+ Args:
+ local (bool, optional): Whether to load the model locally. Defaults to False.
+ model (str, optional): The VAD model name to load. Defaults to "silero_vad".
+ device (torch.device, optional): The device to run the model on. Defaults to 'cpu'.
+
+ Returns:
+ None
+
+ Raises:
+ RuntimeError: If loading the model fails.
+ """
+ try:
+ vad_model, utils = torch.hub.load(
+ repo_or_dir="snakers4/silero-vad" if not local else "vad/silero-vad",
+ model=model,
+ force_reload=False,
+ onnx=True,
+ source="github" if not local else "local",
+ )
+ self.vad_model = vad_model
+ (get_speech_timestamps, _, _, _, _) = utils
+ self.get_speech_timestamps = get_speech_timestamps
+ except Exception as e:
+ raise RuntimeError(f"Failed to load VAD model: {e}")
+
+ def segment_speech(self, audio_segment, start_time, end_time, sampling_rate):
+ """
+ Segment speech from an audio segment and return a list of timestamps.
+
+ Args:
+ audio_segment (np.ndarray): The audio segment to be segmented.
+ start_time (int): The start time of the audio segment in frames.
+ end_time (int): The end time of the audio segment in frames.
+ sampling_rate (int): The sampling rate of the audio segment.
+
+ Returns:
+ list: A list of timestamps, each containing the start and end times of speech segments in frames.
+
+ Raises:
+ ValueError: If the audio segment is invalid.
+ """
+ if audio_segment is None or not isinstance(audio_segment, (np.ndarray, list)):
+ raise ValueError("Invalid audio segment")
+
+ speech_timestamps = self.get_speech_timestamps(
+ audio_segment, self.vad_model, sampling_rate=sampling_rate
+ )
+
+ adjusted_timestamps = [
+ (ts["start"] + start_time, ts["end"] + start_time)
+ for ts in speech_timestamps
+ ]
+ if not adjusted_timestamps:
+ return []
+
+ intervals = [
+ end[0] - start[1]
+ for start, end in zip(adjusted_timestamps[:-1], adjusted_timestamps[1:])
+ ]
+
+ segments = []
+
+ def split_timestamps(start_index, end_index):
+ if (
+ start_index == end_index
+ or adjusted_timestamps[end_index][1]
+ - adjusted_timestamps[start_index][0]
+ < 20 * sampling_rate
+ ):
+ segments.append([start_index, end_index])
+ else:
+ if not intervals[start_index:end_index]:
+ return
+ max_interval_index = intervals[start_index:end_index].index(
+ max(intervals[start_index:end_index])
+ )
+ split_index = start_index + max_interval_index
+ split_timestamps(start_index, split_index)
+ split_timestamps(split_index + 1, end_index)
+
+ split_timestamps(0, len(adjusted_timestamps) - 1)
+
+ merged_timestamps = [
+ [adjusted_timestamps[start][0], adjusted_timestamps[end][1]]
+ for start, end in segments
+ ]
+ return merged_timestamps
+
+ def vad(self, speakerdia, audio):
+ """
+ Process the audio based on the given speaker diarization dataframe.
+
+ Args:
+ speakerdia (pd.DataFrame): The diarization dataframe containing start, end, and speaker info.
+ audio (dict): A dictionary containing the audio waveform and sample rate.
+
+ Returns:
+ list: A list of dictionaries containing processed audio segments with start, end, and speaker.
+ """
+ sampling_rate = audio["sample_rate"]
+ audio_data = audio["waveform"]
+
+ out = []
+ last_end = 0
+ speakers_seen = set()
+ count_id = 0
+
+ for index, row in speakerdia.iterrows():
+ start = float(row["start"])
+ end = float(row["end"])
+
+ if end <= last_end:
+ continue
+ last_end = end
+
+ start_frame = int(start * sampling_rate)
+ end_frame = int(end * sampling_rate)
+ if row["speaker"] not in speakers_seen:
+ speakers_seen.add(row["speaker"])
+
+ if end - start <= VAD_THRESHOLD:
+ out.append(
+ {
+ "index": str(count_id).zfill(5),
+ "start": start, # in seconds
+ "end": end,
+ "speaker": row["speaker"], # same for all
+ }
+ )
+ count_id += 1
+ continue
+
+ temp_audio = audio_data[start_frame:end_frame]
+
+ # resample from 24k to 16k
+ temp_audio_resampled = librosa.resample(
+ temp_audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE
+ )
+
+ for start_frame_sub, end_frame_sub in self.segment_speech(
+ temp_audio_resampled,
+ int(start * SAMPLING_RATE),
+ int(end * SAMPLING_RATE),
+ SAMPLING_RATE,
+ ):
+ out.append(
+ {
+ "index": str(count_id).zfill(5),
+ "start": start_frame_sub / SAMPLING_RATE, # in seconds
+ "end": end_frame_sub / SAMPLING_RATE,
+ "speaker": row["speaker"], # same for all
+ }
+ )
+ count_id += 1
+
+ return out
diff --git a/preprocessors/Emilia/models/whisper_asr.py b/preprocessors/Emilia/models/whisper_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd062b80ef1cc4e91502f94db4d95ac6452fa9f1
--- /dev/null
+++ b/preprocessors/Emilia/models/whisper_asr.py
@@ -0,0 +1,299 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import faster_whisper
+from typing import List, Union, Optional, NamedTuple
+import torch
+import numpy as np
+import tqdm
+from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from whisperx.types import TranscriptionResult, SingleSegment
+from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens
+
+
+class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
+ """
+ FasterWhisperModel without VAD
+ """
+
+ def __init__(
+ self,
+ model,
+ options: NamedTuple,
+ tokenizer=None,
+ device: Union[int, str, "torch.device"] = -1,
+ framework="pt",
+ language: Optional[str] = None,
+ suppress_numerals: bool = False,
+ **kwargs,
+ ):
+ """
+ Initialize the VadFreeFasterWhisperPipeline.
+
+ Args:
+ model: The Whisper model instance.
+ options: Transcription options.
+ tokenizer: The tokenizer instance.
+ device: Device to run the model on.
+ framework: The framework to use ('pt' for PyTorch).
+ language: The language for transcription.
+ suppress_numerals: Whether to suppress numeral tokens.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ super().__init__(
+ model=model,
+ vad=None,
+ vad_params={},
+ options=options,
+ tokenizer=tokenizer,
+ device=device,
+ framework=framework,
+ language=language,
+ suppress_numerals=suppress_numerals,
+ **kwargs,
+ )
+
+ def detect_language(self, audio: np.ndarray):
+ """
+ Detect the language of the audio.
+
+ Args:
+ audio (np.ndarray): The input audio signal.
+
+ Returns:
+ tuple: Detected language and its probability.
+ """
+ model_n_mels = self.model.feat_kwargs.get("feature_size")
+ if audio.shape[0] > N_SAMPLES:
+ # Randomly sample N_SAMPLES from the audio array
+ start_index = np.random.randint(0, audio.shape[0] - N_SAMPLES)
+ audio_sample = audio[start_index : start_index + N_SAMPLES]
+ else:
+ audio_sample = audio[:N_SAMPLES]
+ padding = 0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]
+ segment = log_mel_spectrogram(
+ audio_sample,
+ n_mels=model_n_mels if model_n_mels is not None else 80,
+ padding=padding,
+ )
+ encoder_output = self.model.encode(segment)
+ results = self.model.model.detect_language(encoder_output)
+ language_token, language_probability = results[0][0]
+ language = language_token[2:-2]
+ return language, language_probability
+
+ def transcribe(
+ self,
+ audio: Union[str, np.ndarray],
+ vad_segments: List[dict],
+ batch_size=None,
+ num_workers=0,
+ language=None,
+ task=None,
+ chunk_size=30,
+ print_progress=False,
+ combined_progress=False,
+ ) -> TranscriptionResult:
+ """
+ Transcribe the audio into text.
+
+ Args:
+ audio (Union[str, np.ndarray]): The input audio signal or path to audio file.
+ vad_segments (List[dict]): List of VAD segments.
+ batch_size (int, optional): Batch size for transcription. Defaults to None.
+ num_workers (int, optional): Number of workers for loading data. Defaults to 0.
+ language (str, optional): Language for transcription. Defaults to None.
+ task (str, optional): Task type ('transcribe' or 'translate'). Defaults to None.
+ chunk_size (int, optional): Size of chunks for processing. Defaults to 30.
+ print_progress (bool, optional): Whether to print progress. Defaults to False.
+ combined_progress (bool, optional): Whether to combine progress. Defaults to False.
+
+ Returns:
+ TranscriptionResult: The transcription result containing segments and language.
+ """
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+
+ def data(audio, segments):
+ for seg in segments:
+ f1 = int(seg["start"] * SAMPLE_RATE)
+ f2 = int(seg["end"] * SAMPLE_RATE)
+ yield {"inputs": audio[f1:f2]}
+
+ if self.tokenizer is None:
+ language = language or self.detect_language(audio)
+ task = task or "transcribe"
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(
+ self.model.hf_tokenizer,
+ self.model.model.is_multilingual,
+ task=task,
+ language=language,
+ )
+ else:
+ language = language or self.tokenizer.language_code
+ task = task or self.tokenizer.task
+ if task != self.tokenizer.task or language != self.tokenizer.language_code:
+ self.tokenizer = faster_whisper.tokenizer.Tokenizer(
+ self.model.hf_tokenizer,
+ self.model.model.is_multilingual,
+ task=task,
+ language=language,
+ )
+
+ if self.suppress_numerals:
+ previous_suppress_tokens = self.options.suppress_tokens
+ numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
+ new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
+ new_suppressed_tokens = list(set(new_suppressed_tokens))
+ self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
+
+ segments: List[SingleSegment] = []
+ batch_size = batch_size or self._batch_size
+ total_segments = len(vad_segments)
+ progress = tqdm.tqdm(total=total_segments, desc="Transcribing")
+ for idx, out in enumerate(
+ self.__call__(
+ data(audio, vad_segments),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+ ):
+ if print_progress:
+ progress.update(1)
+ text = out["text"]
+ if batch_size in [0, 1, None]:
+ text = text[0]
+ segments.append(
+ {
+ "text": text,
+ "start": round(vad_segments[idx]["start"], 3),
+ "end": round(vad_segments[idx]["end"], 3),
+ "speaker": vad_segments[idx].get("speaker", None),
+ }
+ )
+
+ # revert the tokenizer if multilingual inference is enabled
+ if self.preset_language is None:
+ self.tokenizer = None
+
+ # revert suppressed tokens if suppress_numerals is enabled
+ if self.suppress_numerals:
+ self.options = self.options._replace(
+ suppress_tokens=previous_suppress_tokens
+ )
+
+ return {"segments": segments, "language": language}
+
+
+def load_asr_model(
+ whisper_arch: str,
+ device: str,
+ device_index: int = 0,
+ compute_type: str = "float16",
+ asr_options: Optional[dict] = None,
+ language: Optional[str] = None,
+ vad_model=None,
+ vad_options=None,
+ model: Optional[WhisperModel] = None,
+ task: str = "transcribe",
+ download_root: Optional[str] = None,
+ threads: int = 4,
+) -> VadFreeFasterWhisperPipeline:
+ """
+ Load a Whisper model for inference.
+
+ Args:
+ whisper_arch (str): The name of the Whisper model to load.
+ device (str): The device to load the model on.
+ device_index (int, optional): The device index. Defaults to 0.
+ compute_type (str, optional): The compute type to use for the model. Defaults to "float16".
+ asr_options (Optional[dict], optional): Options for ASR. Defaults to None.
+ language (Optional[str], optional): The language of the model. Defaults to None.
+ vad_model: The VAD model instance. Defaults to None.
+ vad_options: Options for VAD. Defaults to None.
+ model (Optional[WhisperModel], optional): The WhisperModel instance to use. Defaults to None.
+ task (str, optional): The task type ('transcribe' or 'translate'). Defaults to "transcribe".
+ download_root (Optional[str], optional): The root directory to download the model to. Defaults to None.
+ threads (int, optional): The number of CPU threads to use per worker. Defaults to 4.
+
+ Returns:
+ VadFreeFasterWhisperPipeline: The loaded Whisper pipeline.
+
+ Raises:
+ ValueError: If the whisper architecture is not recognized.
+ """
+
+ if whisper_arch.endswith(".en"):
+ language = "en"
+
+ model = model or WhisperModel(
+ whisper_arch,
+ device=device,
+ device_index=device_index,
+ compute_type=compute_type,
+ download_root=download_root,
+ cpu_threads=threads,
+ )
+ if language is not None:
+ tokenizer = faster_whisper.tokenizer.Tokenizer(
+ model.hf_tokenizer,
+ model.model.is_multilingual,
+ task=task,
+ language=language,
+ )
+ else:
+ print(
+ "No language specified, language will be detected for each audio file (increases inference time)."
+ )
+ tokenizer = None
+
+ default_asr_options = {
+ "beam_size": 5,
+ "best_of": 5,
+ "patience": 1,
+ "length_penalty": 1,
+ "repetition_penalty": 1,
+ "no_repeat_ngram_size": 0,
+ "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+ "compression_ratio_threshold": 2.4,
+ "log_prob_threshold": -1.0,
+ "no_speech_threshold": 0.6,
+ "condition_on_previous_text": False,
+ "prompt_reset_on_temperature": 0.5,
+ "initial_prompt": None,
+ "prefix": None,
+ "suppress_blank": True,
+ "suppress_tokens": [-1],
+ "without_timestamps": True,
+ "max_initial_timestamp": 0.0,
+ "word_timestamps": False,
+ "prepend_punctuations": "\"'“¿([{-",
+ "append_punctuations": "\"'.。,,!!??::”)]}、",
+ "suppress_numerals": False,
+ "max_new_tokens": None,
+ "clip_timestamps": None,
+ "hallucination_silence_threshold": None,
+ }
+
+ if asr_options is not None:
+ default_asr_options.update(asr_options)
+
+ suppress_numerals = default_asr_options["suppress_numerals"]
+ del default_asr_options["suppress_numerals"]
+
+ default_asr_options = faster_whisper.transcribe.TranscriptionOptions(
+ **default_asr_options
+ )
+
+ return VadFreeFasterWhisperPipeline(
+ model=model,
+ options=default_asr_options,
+ tokenizer=tokenizer,
+ language=language,
+ suppress_numerals=suppress_numerals,
+ )
diff --git a/preprocessors/Emilia/requirements.txt b/preprocessors/Emilia/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8657331fb5e9b58b5c83a4375c9cfe712880c271
--- /dev/null
+++ b/preprocessors/Emilia/requirements.txt
@@ -0,0 +1,7 @@
+librosa
+numpy
+tqdm
+pydub
+pyannote.audio
+pandas
+git+https://github.com/m-bain/whisperx.git # needs torch >= 2.0.0
diff --git a/preprocessors/Emilia/utils/__init__.py b/preprocessors/Emilia/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocessors/Emilia/utils/logger.py b/preprocessors/Emilia/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd6864f2aee1edc11b46e70596d3229519214ca8
--- /dev/null
+++ b/preprocessors/Emilia/utils/logger.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import time
+import os
+
+
+class Logger:
+ """
+ Logger class for managing logging operations.
+ """
+
+ _logger = None
+
+ @classmethod
+ def get_logger(cls, name=None):
+ """
+ Get the logger instance with the specified name. If it doesn't exist, create and cache it.
+
+ Args:
+ cls (type): The class type.
+ name (str, optional): The name of the logger. Defaults to None, which uses the class name.
+
+ Returns:
+ logging.Logger: The logger instance.
+ """
+ if cls._logger is None:
+ cls._logger = cls.init_logger(name)
+ return cls._logger
+
+ @classmethod
+ def init_logger(cls, name=None):
+ """
+ Initialize the logger, including file and console logging.
+
+ Args:
+ cls (type): The class type.
+ name (str, optional): The name of the logger. Defaults to None.
+
+ Returns:
+ logging.Logger: The initialized logger instance.
+ """
+ if name is None:
+ name = "main"
+ if "SELF_ID" in os.environ:
+ name = name + "_ID" + os.environ["SELF_ID"]
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ name = name + "_GPU" + os.environ["CUDA_VISIBLE_DEVICES"]
+ print(f"Initialize logger for {name}")
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+
+ # Add file handler to save logs to a file
+ log_date = time.strftime("%Y-%m-%d", time.localtime())
+ log_time = time.strftime("%H-%M-%S", time.localtime())
+ os.makedirs(f"logs/{log_date}", exist_ok=True)
+
+ formatter = logging.Formatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ fh = logging.FileHandler(f"logs/{log_date}/{name}-{log_time}.log")
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+
+ # Create a custom log formatter to set specific log levels to color
+ class ColorFormatter(logging.Formatter):
+ """
+ Custom log formatter to add color to specific log levels.
+ """
+
+ def format(self, record):
+ """
+ Format the log record with color based on log level.
+
+ Args:
+ record (logging.LogRecord): The log record to format.
+
+ Returns:
+ str: The formatted log message.
+ """
+ if record.levelno >= logging.ERROR:
+ record.msg = "\033[1;31m" + str(record.msg) + "\033[0m"
+ elif record.levelno >= logging.WARNING:
+ record.msg = "\033[1;33m" + str(record.msg) + "\033[0m"
+ elif record.levelno >= logging.INFO:
+ record.msg = "\033[1;34m" + str(record.msg) + "\033[0m"
+ elif record.levelno >= logging.DEBUG:
+ record.msg = "\033[1;32m" + str(record.msg) + "\033[0m"
+ return super().format(record)
+
+ color_formatter = ColorFormatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ ch = logging.StreamHandler()
+ ch.setFormatter(color_formatter)
+ logger.addHandler(ch)
+
+ return logger
+
+
+def time_logger(func):
+ """
+ Decorator to log the execution time of a function.
+
+ Args:
+ func (callable): The function whose execution time is to be logged.
+
+ Returns:
+ callable: The wrapper function that logs the execution time of the original function.
+ """
+
+ def wrapper(*args, **kwargs):
+ start_time = time.time()
+ result = func(*args, **kwargs)
+ end_time = time.time()
+ Logger.get_logger().debug(
+ f"Function {func.__name__} took {end_time - start_time} seconds to execute"
+ )
+ return result
+
+ return wrapper
diff --git a/preprocessors/Emilia/utils/tool.py b/preprocessors/Emilia/utils/tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d3a278a57e0718ec9b1a77ef225cca440586492
--- /dev/null
+++ b/preprocessors/Emilia/utils/tool.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ThreadPoolExecutor
+import json
+import os
+import librosa
+import numpy as np
+import time
+import torch
+from pydub import AudioSegment
+import soundfile as sf
+import onnxruntime as ort
+import tqdm
+import subprocess
+import re
+
+from utils.logger import Logger, time_logger
+
+
+def load_cfg(cfg_path):
+ """
+ Load configuration from a JSON file.
+
+ Args:
+ cfg_path (str): Path to the configuration file.
+
+ Returns:
+ dict: Configuration dictionary.
+ """
+ if not os.path.exists(cfg_path):
+ raise FileNotFoundError(
+ f"{cfg_path} not found. Please copy, configure, and rename `config.json.example` to `{cfg_path}`."
+ )
+ with open(cfg_path, "r") as f:
+ try:
+ cfg = json.load(f)
+ except json.decoder.JSONDecodeError as e:
+ raise TypeError(
+ "Please finish the `// TODO:` in the `config.json` file before running the script. Check README.md for details."
+ )
+ return cfg
+
+
+def write_wav(path, sr, x):
+ """Write numpy array to WAV file."""
+ sf.write(path, x, sr)
+
+
+def write_mp3(path, sr, x):
+ """Convert numpy array to MP3."""
+ try:
+ # Ensure x is in the correct format and normalize if necessary
+ if x.dtype != np.int16:
+ # Normalize the array to fit in int16 range if it's not already int16
+ x = np.int16(x / np.max(np.abs(x)) * 32767)
+
+ # Create audio segment from numpy array
+ audio = AudioSegment(
+ x.tobytes(), frame_rate=sr, sample_width=x.dtype.itemsize, channels=1
+ )
+ # Export as MP3 file
+ audio.export(path, format="mp3")
+ except Exception as e:
+ print(e)
+ print("Error: Failed to write MP3 file.")
+
+
+def get_audio_files(folder_path):
+ """Get all audio files in a folder."""
+ audio_files = []
+ for root, _, files in os.walk(folder_path):
+ if "_processed" in root:
+ continue
+ for file in files:
+ if ".temp" in file:
+ continue
+ if file.endswith((".mp3", ".wav", ".flac", ".m4a", ".aac")):
+ audio_files.append(os.path.join(root, file))
+ return audio_files
+
+
+def get_specific_files(folder_path, ext):
+ """Get specific files with a given extension in a folder."""
+ audio_files = []
+ for root, _, files in os.walk(folder_path):
+ if "_processed" in root:
+ continue
+ for file in files:
+ if ".temp" in file:
+ continue
+ if file.endswith(ext):
+ audio_files.append(os.path.join(root, file))
+ return audio_files
+
+
+def export_to_srt(asr_result, file_path):
+ """Export ASR result to SRT file."""
+ with open(file_path, "w") as f:
+
+ def format_time(seconds):
+ return (
+ time.strftime("%H:%M:%S", time.gmtime(seconds))
+ + f",{int(seconds * 1000 % 1000):03d}"
+ )
+
+ for idx, segment in enumerate(asr_result):
+ f.write(f"{idx + 1}\n")
+ f.write(
+ f"{format_time(segment['start'])} --> {format_time(segment['end'])}\n"
+ )
+ f.write(f"{segment['speaker']}: {segment['text']}\n\n")
+
+
+def detect_gpu():
+ """Detect if GPU is available and print related information."""
+ logger = Logger.get_logger()
+
+ if "CUDA_VISIBLE_DEVICES" not in os.environ:
+ logger.info("ENV: CUDA_VISIBLE_DEVICES not set, use default setting")
+ else:
+ gpu_id = os.environ["CUDA_VISIBLE_DEVICES"]
+ logger.info(f"ENV: CUDA_VISIBLE_DEVICES = {gpu_id}")
+
+ if not torch.cuda.is_available():
+ logger.error("Torch CUDA: No GPU detected. torch.cuda.is_available() = False.")
+ return False
+
+ num_gpus = torch.cuda.device_count()
+ logger.debug(f"Torch CUDA: Detected {num_gpus} GPUs.")
+ for i in range(num_gpus):
+ gpu_name = torch.cuda.get_device_name(i)
+ logger.debug(f" * GPU {i}: {gpu_name}")
+
+ logger.debug("Torch: CUDNN version = " + str(torch.backends.cudnn.version()))
+ if not torch.backends.cudnn.is_available():
+ logger.error("Torch: CUDNN is not available.")
+ return False
+ logger.debug("Torch: CUDNN is available.")
+
+ ort_providers = ort.get_available_providers()
+ logger.debug(f"ORT: Available providers: {ort_providers}")
+ if "CUDAExecutionProvider" not in ort_providers:
+ logger.warning(
+ "ORT: CUDAExecutionProvider is not available. "
+ "Please install a compatible version of ONNX Runtime. "
+ "See https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html"
+ )
+
+ return True
+
+
+def get_gpu_nums():
+ """Get GPU nums by nvidia-smi."""
+ logger = Logger.get_logger()
+ try:
+ result = subprocess.check_output("nvidia-smi -L | wc -l", shell=True)
+ gpus_count = int(result.decode().strip())
+ except Exception as e:
+ logger.error("Error occurred while getting GPU count: " + str(e))
+ gpus_count = 8 # Default to 8 if GPU count retrieval fails
+ return gpus_count
+
+
+def check_env(logger):
+ """Check environment variables."""
+ if "http_proxy" in os.environ:
+ logger.info(f"ENV: http_proxy = {os.environ['http_proxy']}")
+ else:
+ logger.info("ENV: http_proxy not set")
+
+ if "https_proxy" in os.environ:
+ logger.info(f"ENV: https_proxy = {os.environ['https_proxy']}")
+ else:
+ logger.info("ENV: https_proxy not set")
+
+ if "HF_ENDPOINT" in os.environ:
+ logger.info(
+ f"ENV: HF_ENDPOINT = {os.environ['HF_ENDPOINT']}, if downloading slow, try `unset HF_ENDPOINT`"
+ )
+ else:
+ logger.info("ENV: HF_ENDPOINT not set")
+
+ hostname = os.popen("hostname").read().strip()
+ logger.debug(f"HOSTNAME: {hostname}")
+
+ environ_path = os.environ["PATH"]
+ environ_ld_library = os.environ.get("LD_LIBRARY_PATH", "")
+ logger.debug(f"ENV: PATH = {environ_path}, LD_LIBRARY_PATH = {environ_ld_library}")
+
+
+@time_logger
+def export_to_mp3(audio, asr_result, folder_path, file_name):
+ """Export segmented audio to MP3 files."""
+ sr = audio["sample_rate"]
+ audio = audio["waveform"]
+
+ os.makedirs(folder_path, exist_ok=True)
+
+ # Function to process each segment in a separate thread
+ def process_segment(idx, segment):
+ start, end = int(segment["start"] * sr), int(segment["end"] * sr)
+ split_audio = audio[start:end]
+ split_audio = librosa.to_mono(split_audio)
+ out_file = f"{file_name}_{idx}.mp3"
+ out_path = os.path.join(folder_path, out_file)
+ write_mp3(out_path, sr, split_audio)
+
+ # Use ThreadPoolExecutor for concurrent execution
+ with ThreadPoolExecutor(max_workers=72) as executor:
+ # Submit each segment processing as a separate thread
+ futures = [
+ executor.submit(process_segment, idx, segment)
+ for idx, segment in enumerate(asr_result)
+ ]
+
+ # Wait for all threads to complete
+ for future in tqdm.tqdm(
+ futures, total=len(asr_result), desc="Exporting to MP3"
+ ):
+ future.result()
+
+
+@time_logger
+def export_to_wav(audio, asr_result, folder_path, file_name):
+ """Export segmented audio to WAV files."""
+ sr = audio["sample_rate"]
+ audio = audio["waveform"]
+
+ os.makedirs(folder_path, exist_ok=True)
+
+ for idx, segment in enumerate(tqdm.tqdm(asr_result, desc="Exporting to WAV")):
+ start, end = int(segment["start"] * sr), int(segment["end"] * sr)
+ split_audio = audio[start:end]
+ split_audio = librosa.to_mono(split_audio)
+ out_file = f"{file_name}_{idx}.wav"
+ out_path = os.path.join(folder_path, out_file)
+ write_wav(out_path, sr, split_audio)
+
+
+def get_char_count(text):
+ """
+ Get the number of characters in the text.
+
+ Args:
+ text (str): Input text.
+
+ Returns:
+ int: Number of characters in the text.
+ """
+ # Using regular expression to remove punctuation and spaces
+ cleaned_text = re.sub(r"[,.!?\"',。!?“”‘’ ]", "", text)
+ char_count = len(cleaned_text)
+ return char_count
+
+
+def calculate_audio_stats(
+ data, min_duration=3, max_duration=30, min_dnsmos=3, min_char_count=2
+):
+ """
+ Reading the proviced json, calculate and return the audio ID and their duration that meet the given filtering criteria.
+
+ Args:
+ data: JSON.
+ min_duration: Minimum duration of the audio in seconds.
+ max_duration: Maximum duration of the audio in seconds.
+ min_dnsmos: Minimum DNSMOS value.
+ min_char_count: Minimum number of characters.
+
+ Returns:
+ valid_audio_stats: A list containing tuples of audio ID and their duration.
+ """
+ all_audio_stats = []
+ valid_audio_stats = []
+ avg_durations = []
+
+ # iterate over each entry in the JSON to collect the average duration of the phonemes
+ for entry in data:
+ # remove punctuation and spaces
+ char_count = get_char_count(entry["text"])
+ duration = entry["end"] - entry["start"]
+ if char_count > 0:
+ avg_durations.append(duration / char_count)
+
+ # calculate the bounds for the average character duration
+ if len(avg_durations) > 0:
+ q1 = np.percentile(avg_durations, 25)
+ q3 = np.percentile(avg_durations, 75)
+ iqr = q3 - q1
+ lower_bound = q1 - 1.5 * iqr
+ upper_bound = q3 + 1.5 * iqr
+ else:
+ # if no valid character data, use default values
+ lower_bound, upper_bound = 0, np.inf
+
+ # iterate over each entry in the JSON to apply all filtering criteria
+ for idx, entry in enumerate(data):
+ duration = entry["end"] - entry["start"]
+ dnsmos = entry["dnsmos"]
+ # remove punctuation and spaces
+ char_count = get_char_count(entry["text"])
+ if char_count > 0:
+ avg_char_duration = duration / char_count
+ else:
+ avg_char_duration = 0
+
+ # collect the duration of all audios
+ all_audio_stats.append((idx, duration))
+
+ # apply filtering criteria
+ if (
+ (min_duration <= duration <= max_duration) # withing duration range
+ and (dnsmos >= min_dnsmos)
+ and (char_count >= min_char_count)
+ and (
+ lower_bound <= avg_char_duration <= upper_bound
+ ) # average character duration within bounds
+ ):
+ valid_audio_stats.append((idx, duration))
+
+ return valid_audio_stats, all_audio_stats
diff --git a/preprocessors/__init__.py b/preprocessors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92514cbbafea25f61bcc4f4bbdce13bfa58b5fd8
--- /dev/null
+++ b/preprocessors/__init__.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+For source datasets' standard samples
+"""
+
+from collections import defaultdict
+import os
+import json
+
+SPEECH_DATASETS = ["vctk", "vctksample"]
+
+GOLDEN_TEST_SAMPLES = defaultdict(list)
+GOLDEN_TEST_SAMPLES["m4singer"] = [
+ "Alto-1_美错_0014",
+ "Bass-1_十年_0008",
+ "Soprano-2_同桌的你_0018",
+ "Tenor-5_爱笑的眼睛_0010",
+]
+GOLDEN_TEST_SAMPLES["svcc"] = [
+ # IDF1
+ "IDF1_10030",
+ "IDF1_10120",
+ "IDF1_10140",
+ # IDM1
+ "IDM1_10001",
+ "IDM1_10030",
+ "IDM1_10120",
+ # CDF1
+ "CDF1_10030",
+ "CDF1_10120",
+ "CDF1_10140",
+ # CDM1
+ "CDM1_10001",
+ "CDM1_10030",
+ "CDM1_10120",
+]
+GOLDEN_TEST_SAMPLES["svcceval"] = [
+ # SF1
+ "SF1_30001",
+ "SF1_30002",
+ "SF1_30003",
+ # SM1
+ "SM1_30001",
+ "SM1_30002",
+ "SM1_30003",
+]
+GOLDEN_TEST_SAMPLES["popbutfy"] = [
+ "Female1#you_are_my_sunshine_Professional#0",
+ "Female4#Someone_Like_You_Professional#10",
+ "Male2#Lemon_Tree_Professional#12",
+ "Male5#can_you_feel_the_love_tonight_Professional#20",
+]
+GOLDEN_TEST_SAMPLES["opensinger"] = [
+ "Man_0_大鱼_10",
+ "Man_21_丑八怪_14",
+ "Woman_39_mojito_22",
+ "Woman_40_易燃易爆炸_12",
+]
+GOLDEN_TEST_SAMPLES["nus48e"] = [
+ "ADIZ_read#01#0000",
+ "MCUR_sing#10#0000",
+ "JLEE_read#08#0001",
+ "SAMF_sing#18#0001",
+]
+GOLDEN_TEST_SAMPLES["popcs"] = [
+ "明天会更好_0004",
+ "欧若拉_0005",
+ "虫儿飞_0006",
+ "隐形的翅膀_0008",
+]
+GOLDEN_TEST_SAMPLES["kising"] = [
+ "421_0040",
+ "424_0013",
+ "431_0026",
+]
+GOLDEN_TEST_SAMPLES["csd"] = [
+ "en_004a_0001",
+ "en_042b_0006",
+ "kr_013a_0006",
+ "kr_045b_0004",
+]
+GOLDEN_TEST_SAMPLES["opera"] = [
+ "fem_01#neg_1#0000",
+ "fem_12#pos_3#0003",
+ "male_02#neg_1#0002",
+ "male_11#pos_2#0001",
+]
+GOLDEN_TEST_SAMPLES["lijian"] = [
+ "058矜持_0000",
+ "079绒花_0000",
+ "120遥远的天空底下_0000",
+]
+GOLDEN_TEST_SAMPLES["cdmusiceval"] = ["陶喆_普通朋友", "蔡琴_给电影人的情书"]
+
+GOLDEN_TRAIN_SAMPLES = defaultdict(list)
+
+
+def get_golden_samples_indexes(
+ dataset_name,
+ dataset_dir=None,
+ cfg=None,
+ split=None,
+ min_samples=5,
+):
+ """
+ # Get Standard samples' indexes
+ """
+ if dataset_dir is None:
+ assert cfg is not None
+ dataset_dir = os.path.join(
+ cfg.OUTPUT_PATH,
+ "preprocess/{}_version".format(cfg.PREPROCESS_VERSION),
+ dataset_name,
+ )
+
+ assert split is not None
+ utt_file = os.path.join(dataset_dir, "{}.json".format(split))
+ with open(utt_file, "r", encoding="utf-8") as f:
+ samples = json.load(f)
+
+ if "train" in split:
+ golden_samples = GOLDEN_TRAIN_SAMPLES[dataset_name]
+ if "test" in split:
+ golden_samples = GOLDEN_TEST_SAMPLES[dataset_name]
+
+ res = []
+ for idx, utt in enumerate(samples):
+ if utt["Uid"] in golden_samples:
+ res.append(idx)
+
+ if dataset_name == "cdmusiceval":
+ if "_".join(utt["Uid"].split("_")[:2]) in golden_samples:
+ res.append(idx)
+
+ if len(res) == 0:
+ res = [i for i in range(min_samples)]
+
+ return res
+
+
+def get_specific_singer_indexes(dataset_dir, singer_name, split):
+ utt_file = os.path.join(dataset_dir, "{}.json".format(split))
+ with open(utt_file, "r", encoding="utf-8") as f:
+ samples = json.load(f)
+
+ res = []
+ for idx, utt in enumerate(samples):
+ if utt["Singer"] == singer_name:
+ res.append(idx)
+
+ assert len(res) != 0
+ return res
+
+
+def get_uids_and_wav_paths(
+ cfg, dataset, dataset_type="train", only_specific_singer=None, return_singers=False
+):
+ dataset_dir = os.path.join(
+ cfg.OUTPUT_PATH, "preprocess/{}_version".format(cfg.PREPROCESS_VERSION), dataset
+ )
+ dataset_file = os.path.join(
+ dataset_dir, "{}.json".format(dataset_type.split("_")[-1])
+ )
+ with open(dataset_file, "r") as f:
+ utterances = json.load(f)
+
+ indexes = range(len(utterances))
+ if "golden" in dataset_type:
+ # golden_train or golden_test
+ indexes = get_golden_samples_indexes(
+ dataset, dataset_dir, split=dataset_type.split("_")[-1]
+ )
+ if only_specific_singer is not None:
+ indexes = get_specific_singer_indexes(
+ dataset_dir, only_specific_singer, dataset_type
+ )
+
+ uids = [utterances[i]["Uid"] for i in indexes]
+ wav_paths = [utterances[i]["Path"] for i in indexes]
+ singers = [utterances[i]["Singer"] for i in indexes]
+
+ if not return_singers:
+ return uids, wav_paths
+ else:
+ return uids, wav_paths, singers
diff --git a/preprocessors/bigdata.py b/preprocessors/bigdata.py
new file mode 100644
index 0000000000000000000000000000000000000000..da541191259e337933ce469503c0d057808ebb85
--- /dev/null
+++ b/preprocessors/bigdata.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import os
+from collections import defaultdict
+from tqdm import tqdm
+
+
+def get_uids_and_wav_paths(cfg, dataset, dataset_type):
+ assert dataset == "bigdata"
+ dataset_dir = os.path.join(
+ cfg.OUTPUT_PATH,
+ "preprocess/{}_version".format(cfg.PREPROCESS_VERSION),
+ "bigdata/{}".format(cfg.BIGDATA_VERSION),
+ )
+ dataset_file = os.path.join(
+ dataset_dir, "{}.json".format(dataset_type.split("_")[-1])
+ )
+ with open(dataset_file, "r") as f:
+ utterances = json.load(f)
+
+ # Uids
+ uids = [u["Uid"] for u in utterances]
+
+ # Wav paths
+ wav_paths = [u["Path"] for u in utterances]
+
+ return uids, wav_paths
+
+
+def take_duration(utt):
+ return utt["Duration"]
+
+
+def main(output_path, cfg):
+ datasets = cfg.dataset
+
+ print("-" * 10)
+ print("Preparing samples for bigdata...")
+ print("Including: \n{}\n".format("\n".join(datasets)))
+
+ datasets.sort()
+ bigdata_version = "_".join(datasets)
+
+ save_dir = os.path.join(output_path, bigdata_version)
+ os.makedirs(save_dir, exist_ok=True)
+
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, cfg.preprocess.spk2id)
+ utt2singer_file = os.path.join(save_dir, cfg.preprocess.utt2spk)
+ utt2singer = open(utt2singer_file, "a+")
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ # Singer unique names
+ singer_names = set()
+
+ for dataset in datasets:
+ dataset_path = os.path.join(output_path, dataset)
+ train_json = os.path.join(dataset_path, "train.json")
+ test_json = os.path.join(dataset_path, "test.json")
+
+ with open(train_json, "r", encoding="utf-8") as f:
+ train_utterances = json.load(f)
+
+ with open(test_json, "r", encoding="utf-8") as f:
+ test_utterances = json.load(f)
+
+ for utt in tqdm(train_utterances):
+ train.append(utt)
+ train_total_duration += utt["Duration"]
+ singer_names.add("{}_{}".format(utt["Dataset"], utt["Singer"]))
+ utt2singer.write(
+ "{}_{}\t{}_{}\n".format(
+ utt["Dataset"], utt["Uid"], utt["Dataset"], utt["Singer"]
+ )
+ )
+
+ for utt in test_utterances:
+ test.append(utt)
+ test_total_duration += utt["Duration"]
+ singer_names.add("{}_{}".format(utt["Dataset"], utt["Singer"]))
+ utt2singer.write(
+ "{}_{}\t{}_{}\n".format(
+ utt["Dataset"], utt["Uid"], utt["Dataset"], utt["Singer"]
+ )
+ )
+
+ utt2singer.close()
+
+ train.sort(key=take_duration)
+ test.sort(key=take_duration)
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Singer Look Up Table
+ singer_names = list(singer_names)
+ singer_names.sort()
+ singer_lut = {name: i for i, name in enumerate(singer_names)}
+ print("#Singers: {}\n".format(len(singer_lut)))
+
+ # Save
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
+
+ # Save meta info
+ meta_info = {
+ "datasets": datasets,
+ "train": {"size": len(train), "hours": round(train_total_duration / 3600, 4)},
+ "test": {"size": len(test), "hours": round(test_total_duration / 3600, 4)},
+ "singers": {"size": len(singer_lut)},
+ }
+ singer2mins = defaultdict(float)
+ for utt in train:
+ dataset, singer, duration = utt["Dataset"], utt["Singer"], utt["Duration"]
+ singer2mins["{}_{}".format(dataset, singer)] += duration / 60
+ singer2mins = sorted(singer2mins.items(), key=lambda x: x[1], reverse=True)
+ singer2mins = dict(
+ zip([i[0] for i in singer2mins], [round(i[1], 2) for i in singer2mins])
+ )
+ meta_info["singers"]["training_minutes"] = singer2mins
+
+ with open(os.path.join(save_dir, "meta_info.json"), "w") as f:
+ json.dump(meta_info, f, indent=4, ensure_ascii=False)
+
+ for singer, min in singer2mins.items():
+ print("Singer {}: {} mins".format(singer, min))
+ print("-" * 10, "\n")
diff --git a/preprocessors/cdmusiceval.py b/preprocessors/cdmusiceval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c4af97c17153196501200811f4e96cfcfa2f4f1
--- /dev/null
+++ b/preprocessors/cdmusiceval.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from glob import glob
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from collections import defaultdict
+
+from utils.util import has_existed, remove_and_create
+from utils.audio_slicer import split_utterances_from_audio
+
+
+def split_to_utterances(input_dir, output_dir):
+ print("Splitting to utterances for {}...".format(input_dir))
+
+ files_list = glob("*", root_dir=input_dir)
+ files_list.sort()
+ for wav_file in tqdm(files_list):
+ # # Load waveform
+ # waveform, fs = torchaudio.load(os.path.join(input_dir, wav_file))
+
+ # Singer name, Song name
+ song_name, singer_name = wav_file.split("_")[2].split("-")
+ save_dir = os.path.join(output_dir, singer_name, song_name)
+
+ split_utterances_from_audio(
+ os.path.join(input_dir, wav_file), save_dir, max_duration_of_utterance=10
+ )
+
+ # # Split
+ # slicer = Slicer(sr=fs, threshold=-30.0, max_sil_kept=3000, min_interval=1000)
+ # chunks = slicer.slice(waveform)
+
+ # for i, chunk in enumerate(chunks):
+ # save_dir = os.path.join(output_dir, singer_name, song_name)
+ # os.makedirs(save_dir, exist_ok=True)
+
+ # output_file = os.path.join(save_dir, "{:04d}.wav".format(i))
+ # save_audio(output_file, chunk, fs)
+
+
+def _main(dataset_path):
+ """
+ Split to utterances
+ """
+ utterance_dir = os.path.join(dataset_path, "utterances")
+ remove_and_create(utterance_dir)
+ split_to_utterances(os.path.join(dataset_path, "vocal"), utterance_dir)
+
+
+def statistics(utterance_dir):
+ singers = []
+ songs = []
+ singers2songs = defaultdict(lambda: defaultdict(list))
+
+ singer_infos = glob(utterance_dir + "/*")
+
+ for singer_info in singer_infos:
+ singer = singer_info.split("/")[-1]
+
+ song_infos = glob(singer_info + "/*")
+
+ for song_info in song_infos:
+ song = song_info.split("/")[-1]
+
+ singers.append(singer)
+ songs.append(song)
+
+ utts = glob(song_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ singers2songs[singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "Statistics: {} singers, {} utterances ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singers2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing samples for CD Music Eval...\n")
+
+ if not os.path.exists(os.path.join(dataset_path, "utterances")):
+ print("Spliting into utterances...\n")
+ _main(dataset_path)
+
+ save_dir = os.path.join(output_path, "cdmusiceval")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ utt_path = os.path.join(dataset_path, "utterances")
+ singers2songs, unique_singers = statistics(utt_path)
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for singer, songs in tqdm(singers2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "cdmusiceval",
+ "Singer": singer,
+ "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}/{}/{}.wav".format(singer, chosen_song, chosen_uid)
+ res["Path"] = os.path.join(utt_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if duration <= 1e-8:
+ continue
+
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/coco.py b/preprocessors/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ac462f32d7a58ed245f58d13a44e02b4efef06b
--- /dev/null
+++ b/preprocessors/coco.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def get_test_songs():
+ return ["007Di Da Di"]
+
+
+def coco_statistics(data_dir):
+ song2utts = defaultdict(list)
+
+ song_infos = glob(data_dir + "/*")
+
+ for song in song_infos:
+ song_name = song.split("/")[-1]
+ utts = glob(song + "/*.wav")
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ song2utts[song_name].append(uid)
+
+ print("Coco: {} songs".format(len(song_infos)))
+ return song2utts
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing datasets for Coco...\n")
+
+ save_dir = os.path.join(output_path, "coco")
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ if has_existed(test_output_file):
+ return
+
+ # Load
+ song2utts = coco_statistics(dataset_path)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for song_name, uids in tqdm(song2utts.items()):
+ for chosen_uid in uids:
+ res = {
+ "Dataset": "coco",
+ "Singer": "coco",
+ "Song": song_name,
+ "Uid": "{}_{}".format(song_name, chosen_uid),
+ }
+ res["Path"] = "{}/{}.wav".format(song_name, chosen_uid)
+ res["Path"] = os.path.join(dataset_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if song_name in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/cocoeval.py b/preprocessors/cocoeval.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cb3604c9d540d4e8f166f459a1ac992a1bd7a17
--- /dev/null
+++ b/preprocessors/cocoeval.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from utils.audio_slicer import split_utterances_from_audio
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def _split_utts():
+ raw_dir = "/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/xueyaozhang/dataset/李玟/cocoeval/raw"
+ output_root = "/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/xueyaozhang/dataset/李玟/cocoeval/utterances"
+
+ if os.path.exists(output_root):
+ os.system("rm -rf {}".format(output_root))
+
+ vocal_files = glob(os.path.join(raw_dir, "*/vocal.wav"))
+ for vocal_f in tqdm(vocal_files):
+ song_name = vocal_f.split("/")[-2]
+
+ output_dir = os.path.join(output_root, song_name)
+ os.makedirs(output_dir, exist_ok=True)
+
+ split_utterances_from_audio(vocal_f, output_dir, min_interval=300)
+
+
+def cocoeval_statistics(data_dir):
+ song2utts = defaultdict(list)
+
+ song_infos = glob(data_dir + "/*")
+
+ for song in song_infos:
+ song_name = song.split("/")[-1]
+ utts = glob(song + "/*.wav")
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ song2utts[song_name].append(uid)
+
+ print("Cocoeval: {} songs".format(len(song_infos)))
+ return song2utts
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing datasets for Cocoeval...\n")
+
+ save_dir = os.path.join(output_path, "cocoeval")
+ test_output_file = os.path.join(save_dir, "test.json")
+ if has_existed(test_output_file):
+ return
+
+ # Load
+ song2utts = cocoeval_statistics(dataset_path)
+
+ train, test = [], []
+ train_index_count, test_index_count = 0, 0
+ train_total_duration, test_total_duration = 0.0, 0.0
+
+ for song_name, uids in tqdm(song2utts.items()):
+ for chosen_uid in uids:
+ res = {
+ "Dataset": "cocoeval",
+ "Singer": "TBD",
+ "Song": song_name,
+ "Uid": "{}_{}".format(song_name, chosen_uid),
+ }
+ res["Path"] = "{}/{}.wav".format(song_name, chosen_uid)
+ res["Path"] = os.path.join(dataset_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/csd.py b/preprocessors/csd.py
new file mode 100644
index 0000000000000000000000000000000000000000..645a8b3d1e2bc556ffa7ee36d9b889422e500619
--- /dev/null
+++ b/preprocessors/csd.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import os
+import glob
+from tqdm import tqdm
+import torchaudio
+import pandas as pd
+from glob import glob
+from collections import defaultdict
+
+from utils.io import save_audio
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def save_utterance(output_file, waveform, fs, start, end, overlap=0.1):
+ """
+ waveform: [#channel, audio_len]
+ start, end, overlap: seconds
+ """
+ start = int((start - overlap) * fs)
+ end = int((end + overlap) * fs)
+ utterance = waveform[:, start:end]
+ save_audio(output_file, utterance, fs)
+
+
+def split_to_utterances(language_dir, output_dir):
+ print("Splitting to utterances for {}...".format(language_dir))
+ wav_dir = os.path.join(language_dir, "wav")
+ phoneme_dir = os.path.join(language_dir, "txt")
+ annot_dir = os.path.join(language_dir, "csv")
+
+ pitches = set()
+ for wav_file in tqdm(glob("{}/*.wav".format(wav_dir))):
+ # Load waveform
+ song_name = wav_file.split("/")[-1].split(".")[0]
+ waveform, fs = torchaudio.load(wav_file)
+
+ # Load utterances
+ phoneme_file = os.path.join(phoneme_dir, "{}.txt".format(song_name))
+ with open(phoneme_file, "r") as f:
+ lines = f.readlines()
+ utterances = [l.strip().split() for l in lines]
+ utterances = [utt for utt in utterances if len(utt) > 0]
+
+ # Load annotation
+ annot_file = os.path.join(annot_dir, "{}.csv".format(song_name))
+ annot_df = pd.read_csv(annot_file)
+ pitches = pitches.union(set(annot_df["pitch"]))
+ starts = annot_df["start"].tolist()
+ ends = annot_df["end"].tolist()
+ syllables = annot_df["syllable"].tolist()
+
+ # Split
+ curr = 0
+ for i, phones in enumerate(utterances):
+ sz = len(phones)
+ assert phones[0] == syllables[curr]
+ assert phones[-1] == syllables[curr + sz - 1]
+
+ s = starts[curr]
+ e = ends[curr + sz - 1]
+ curr += sz
+
+ save_dir = os.path.join(output_dir, song_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ output_file = os.path.join(save_dir, "{:04d}.wav".format(i))
+ save_utterance(output_file, waveform, fs, start=s, end=e)
+
+
+def _main(dataset_path):
+ """
+ Split to utterances
+ """
+ utterance_dir = os.path.join(dataset_path, "utterances")
+
+ for lang in ["english", "korean"]:
+ split_to_utterances(os.path.join(dataset_path, lang), utterance_dir)
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["csd"]
+ # every item is a tuple (language, song)
+ golden_songs = [s.split("_")[:2] for s in golden_samples]
+ # language_song, eg: en_001a
+ return golden_songs
+
+
+def csd_statistics(data_dir):
+ languages = []
+ songs = []
+ languages2songs = defaultdict(lambda: defaultdict(list))
+
+ folder_infos = glob(data_dir + "/*")
+
+ for folder_info in folder_infos:
+ folder_info_split = folder_info.split("/")[-1]
+
+ language = folder_info_split[:2]
+ song = folder_info_split[2:]
+
+ languages.append(language)
+ songs.append(song)
+
+ utts = glob(folder_info + "/*")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ languages2songs[language][song].append(uid)
+
+ unique_languages = list(set(languages))
+ unique_songs = list(set(songs))
+ unique_languages.sort()
+ unique_songs.sort()
+
+ print(
+ "csd: {} languages, {} utterances ({} unique songs)".format(
+ len(unique_languages), len(songs), len(unique_songs)
+ )
+ )
+ print("Languages: \n{}".format("\t".join(unique_languages)))
+ return languages2songs
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for csd...\n")
+
+ if not os.path.exists(os.path.join(dataset_path, "utterances")):
+ print("Spliting into utterances...\n")
+ _main(dataset_path)
+
+ save_dir = os.path.join(output_path, "csd")
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ if has_existed(test_output_file):
+ return
+
+ # Load
+ csd_path = os.path.join(dataset_path, "utterances")
+
+ language2songs = csd_statistics(csd_path)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for language, songs in tqdm(language2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "csd",
+ "Singer": "Female1_{}".format(language),
+ "Uid": "{}_{}_{}".format(language, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}{}/{}.wav".format(language, chosen_song, chosen_uid)
+ res["Path"] = os.path.join(csd_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if [language, chosen_song] in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/customsvcdataset.py b/preprocessors/customsvcdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..84f204dea16fb7a00245786e5b3affb59c24b5eb
--- /dev/null
+++ b/preprocessors/customsvcdataset.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from glob import glob
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from collections import defaultdict
+
+from utils.util import has_existed
+
+
+def statistics(utterance_dir):
+ singers = []
+ songs = []
+ utts_all = []
+ singers2songs = defaultdict(lambda: defaultdict(list))
+
+ singer_infos = glob(utterance_dir + "/*")
+
+ for singer_info in singer_infos:
+ singer = singer_info.split("/")[-1]
+
+ song_infos = glob(singer_info + "/*")
+
+ for song_info in song_infos:
+ song = song_info.split("/")[-1]
+
+ singers.append(singer)
+ songs.append(song)
+
+ utts = glob(song_info + "/*.wav")
+ utts_all.extend(utts)
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ singers2songs[singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "Statistics: {} singers, {} utterances ({} unique songs)".format(
+ len(unique_singers), len(utts_all), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singers2songs, unique_singers
+
+
+def main(output_path, dataset_path, dataset_name):
+ print("-" * 10)
+ print("Preparing samples for {}...\n".format(dataset_name))
+
+ save_dir = os.path.join(output_path, dataset_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ singers2songs, unique_singers = statistics(dataset_path)
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+ test_songs = set()
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for singer, songs in singers2songs.items():
+ song_names = list(songs.keys())
+
+ print("Singer {}...".format(singer))
+ for chosen_song in tqdm(song_names):
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": dataset_name,
+ "Singer": singer,
+ "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}/{}/{}.wav".format(singer, chosen_song, chosen_uid)
+ res["Path"] = os.path.join(dataset_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ # Remove the utterance whose duration is shorter than 0.1s
+ if duration <= 1e-2:
+ continue
+
+ # Place into train or test
+ if "{}_{}".format(singer, chosen_song) not in test_songs:
+ test_songs.add("{}_{}".format(singer, chosen_song))
+
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/hifitts.py b/preprocessors/hifitts.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b069cadefc0d02510df13243a051e3e7755d3e
--- /dev/null
+++ b/preprocessors/hifitts.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing samples for hifitts...\n")
+
+ save_dir = os.path.join(output_path, "hifitts")
+ os.makedirs(save_dir, exist_ok=True)
+ print("Saving to ", save_dir)
+
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ valid_output_file = os.path.join(save_dir, "valid.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if has_existed(train_output_file):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ hifitts_path = dataset_path
+
+ speakers = []
+
+ train = []
+ test = []
+ valid = []
+
+ train_index_count = 0
+ test_index_count = 0
+ valid_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+ valid_total_duration = 0
+
+ distribution_infos = glob(hifitts_path + "/*.json")
+
+ for distribution_info in tqdm(
+ distribution_infos, desc="Extracting metadata from distributions"
+ ):
+ distribution = distribution_info.split("/")[-1].split(".")[0]
+ speaker_id = distribution.split("_")[0]
+ speakers.append(speaker_id)
+
+ with open(distribution_info, "r", encoding="utf-8") as file:
+ for line in file:
+ entry = json.loads(line)
+ utt_path = entry.get("audio_filepath")
+ chosen_book = utt_path.split("/")[-2]
+ chosen_uid = utt_path.split("/")[-1].split(".")[0]
+ duration = entry.get("duration")
+ text = entry.get("text_normalized")
+ path = os.path.join(hifitts_path, utt_path)
+ assert os.path.exists(path)
+
+ res = {
+ "Dataset": "hifitts",
+ "Singer": speaker_id,
+ "Uid": "{}#{}#{}#{}".format(
+ distribution, speaker_id, chosen_book, chosen_uid
+ ),
+ "Text": text,
+ "Path": path,
+ "Duration": duration,
+ }
+
+ if "train" in distribution:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ elif "test" in distribution:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+
+ elif "dev" in distribution:
+ res["index"] = valid_index_count
+ valid_total_duration += duration
+ valid.append(res)
+ valid_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ unique_speakers = list(set(speakers))
+ unique_speakers.sort()
+
+ print("Speakers: \n{}".format("\t".join(unique_speakers)))
+
+ print(
+ "#Train = {}, #Test = {}, #Valid = {}".format(len(train), len(test), len(valid))
+ )
+ print(
+ "#Train hours= {}, #Test hours= {}, #Valid hours= {}".format(
+ train_total_duration / 3600,
+ test_total_duration / 3600,
+ valid_total_duration / 3600,
+ )
+ )
+
+ # Save train.json, test.json, valid.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+ with open(valid_output_file, "w") as f:
+ json.dump(valid, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_speakers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/kising.py b/preprocessors/kising.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81b852aeb05a5d30fe22df07af4c4ab9defda2a
--- /dev/null
+++ b/preprocessors/kising.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def get_test_folders():
+ golden_samples = GOLDEN_TEST_SAMPLES["kising"]
+ # every item is a string
+ golden_folders = [s.split("_")[:1] for s in golden_samples]
+ # folder, eg: 422
+ return golden_folders
+
+
+def KiSing_statistics(data_dir):
+ folders = []
+ folders2utts = defaultdict(list)
+
+ folder_infos = glob(data_dir + "/*")
+
+ for folder_info in folder_infos:
+ folder = folder_info.split("/")[-1]
+
+ folders.append(folder)
+
+ utts = glob(folder_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ folders2utts[folder].append(uid)
+
+ unique_folders = list(set(folders))
+ unique_folders.sort()
+
+ print("KiSing: {} unique songs".format(len(unique_folders)))
+ return folders2utts
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for KiSing...\n")
+
+ save_dir = os.path.join(output_path, "kising")
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ if has_existed(test_output_file):
+ return
+
+ # Load
+ KiSing_dir = dataset_path
+
+ folders2utts = KiSing_statistics(KiSing_dir)
+ test_folders = get_test_folders()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ folder_names = list(folders2utts.keys())
+
+ for chosen_folder in folder_names:
+ for chosen_uid in folders2utts[chosen_folder]:
+ res = {
+ "Dataset": "kising",
+ "Singer": "female1",
+ "Uid": "{}_{}".format(chosen_folder, chosen_uid),
+ }
+ res["Path"] = "{}/{}.wav".format(chosen_folder, chosen_uid)
+ res["Path"] = os.path.join(KiSing_dir, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if ([chosen_folder]) in test_folders:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/librilight.py b/preprocessors/librilight.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef5d6f90b66c742aed76439c1f2e70a3636b5cf
--- /dev/null
+++ b/preprocessors/librilight.py
@@ -0,0 +1,329 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+from tqdm import tqdm
+import os
+import torchaudio
+import torch
+
+
+from utils.mfa_prepare import (
+ process_wav_files,
+ get_wav_files,
+ filter_wav_files_by_length,
+)
+from utils.cut_by_vad import cut_segments
+from utils.whisper_transcription import asr_main
+from utils.util import has_existed
+
+import subprocess
+import random
+from collections import defaultdict
+from glob import glob
+import shutil
+
+
+def librilight_statistics(data_dir):
+ """Get statistics for librilight dataset"""
+ distribution2speakers2utts = defaultdict(lambda: defaultdict(list))
+ distribution_infos = glob(data_dir + "/*")
+ for distribution_info in distribution_infos:
+ distribution = distribution_info.split("/")[-1]
+ print(distribution)
+ speaker_infos = glob(distribution_info + "/*")
+ if len(speaker_infos) == 0:
+ continue
+ for speaker_info in speaker_infos:
+ speaker = speaker_info.split("/")[-1]
+ utts = glob(speaker_info + "/*.wav")
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ distribution2speakers2utts[distribution][speaker].append(uid)
+ return distribution2speakers2utts
+
+
+def get_speakers_from_directory(directory):
+ return [
+ d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))
+ ]
+
+
+def split_dataset_by_speaker(base_dir, train_ratio=0.8, dev_ratio=0.1):
+ train_dir = os.path.join(base_dir, "train")
+ dev_dir = os.path.join(base_dir, "dev")
+ eval_dir = os.path.join(base_dir, "eval")
+
+ # Check if dataset is already split
+ if has_existed(train_dir) or has_existed(dev_dir) or has_existed(eval_dir):
+ print("Dataset already split. Calculating speakers...")
+ train_speakers = get_speakers_from_directory(train_dir)
+ dev_speakers = get_speakers_from_directory(dev_dir)
+ eval_speakers = get_speakers_from_directory(eval_dir)
+ all_speakers = train_speakers + dev_speakers + eval_speakers
+ unique_speakers = list(set(all_speakers))
+ unique_speakers.sort()
+ return unique_speakers
+
+ # List all directories in the base directory
+ all_speakers = [
+ d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
+ ]
+ random.shuffle(all_speakers)
+
+ # Calculate split sizes
+ total_speakers = len(all_speakers)
+ train_size = int(total_speakers * train_ratio)
+ dev_size = int(total_speakers * dev_ratio)
+ eval_size = total_speakers - train_size - dev_size
+ print("Total speakers:", total_speakers)
+ print("Train speakers:", train_size)
+ print("Dev speakers:", dev_size)
+ print("Eval speakers:", eval_size)
+
+ # Split directories
+ train_speakers = all_speakers[:train_size]
+ dev_speakers = all_speakers[train_size : train_size + dev_size]
+ eval_speakers = all_speakers[train_size + dev_size :]
+
+ # Function to move directories
+ def move_speakers(speakers, target_dir):
+ for speaker in speakers:
+ shutil.move(
+ os.path.join(base_dir, speaker), os.path.join(target_dir, speaker)
+ )
+
+ # Move directories
+ print("Moving directories...")
+ print("Moving Train speakers...")
+ move_speakers(train_speakers, train_dir)
+ print("Moving Dev speakers...")
+ move_speakers(dev_speakers, dev_dir)
+ print("Moving Eval speakers...")
+ move_speakers(eval_speakers, eval_dir)
+
+ unique_speakers = list(set(all_speakers))
+ unique_speakers.sort()
+ return unique_speakers
+
+
+def save_meta_data(save_dir, processed_dir, distribution2speakers2utts, speakers):
+ """Save metadata for librilight dataset"""
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ valid_output_file = os.path.join(save_dir, "dev.json")
+ test_output_file = os.path.join(save_dir, "eval.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ utt2singer = open(utt2singer_file, "w")
+ if has_existed(train_output_file):
+ print("Metadata already exists. Skipping...")
+ return
+
+ train = []
+ test = []
+ valid = []
+
+ train_index_count = 0
+ test_index_count = 0
+ valid_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+ valid_total_duration = 0
+
+ # Save metadata
+ for distribution, speakers2utts in tqdm(distribution2speakers2utts.items()):
+ for speaker, utts in tqdm(speakers2utts.items()):
+ for chosen_uid in utts:
+ res = {
+ "Dataset": "librilight",
+ "Singer": speaker,
+ "Uid": "{}#{}#{}".format(distribution, speaker, chosen_uid),
+ }
+ res["Path"] = "{}/{}/{}.wav".format(distribution, speaker, chosen_uid)
+ res["Path"] = os.path.join(processed_dir, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ text_file_path = os.path.join(
+ processed_dir,
+ distribution,
+ speaker,
+ chosen_uid + ".txt",
+ )
+ with open(text_file_path, "r") as f:
+ lines = f.readlines()
+ assert len(lines) == 1
+ text = lines[0].strip()
+ res["Text"] = text
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if "train" in distribution:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+ elif "dev" in distribution:
+ res["index"] = valid_index_count
+ valid_total_duration += duration
+ valid.append(res)
+ valid_index_count += 1
+ elif "eval" in distribution:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+ print("Done!")
+ print(
+ "Utterance count: train = {}, dev = {}, eval = {}".format(
+ len(train), len(valid), len(test)
+ )
+ )
+ print(
+ "#Train duration= {}, #Dev duration= {}, #Eval duration= {}".format(
+ train_total_duration / 3600,
+ valid_total_duration / 3600,
+ test_total_duration / 3600,
+ )
+ )
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+ with open(valid_output_file, "w") as f:
+ json.dump(valid, f, indent=4, ensure_ascii=False)
+ utt2singer.close()
+ singer_lut = {name: i for i, name in enumerate(speakers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
+ print("Metadata saved to", save_dir)
+
+
+def main(output_path, dataset_path, cfg):
+ """Preprocess librilight dataset"""
+ n_cpus = cfg.n_cpus # number of cpus to use for preprocessing
+ n_gpus = cfg.n_gpus # number of gpus to use for transcription
+ cut_length = cfg.cut_length # target length of utterance in seconds
+ max_length = cfg.max_length # max length of utterance in seconds
+
+ # MFA files
+ mfa_config_path = cfg.mfa_config_path # path to mfa config file
+ mfa_dict_path = cfg.mfa_dict_path # path to mfa dict file
+ mfa_model_path = cfg.mfa_model_path # path to mfa model file
+
+ # check if mfa files exist
+ if (
+ not os.path.exists(mfa_dict_path)
+ or not os.path.exists(mfa_model_path)
+ or not os.path.exists(mfa_config_path)
+ ):
+ raise Exception("MFA files not found.")
+
+ # Whisper model id
+ model_id = cfg.whisper_model_id # id of whisper model to use for transcription
+
+ subsets = [
+ d
+ for d in os.listdir(dataset_path)
+ if (
+ os.path.isdir(os.path.join(dataset_path, d))
+ and d in ["tiny", "small", "medium", "large"]
+ )
+ ]
+ print("Found subsets:", subsets)
+
+ if len(subsets) == 0:
+ print("No subsets found. Exiting...")
+ return
+ # Preprocess each subset
+ for subset in subsets:
+ # Construct paths based on the base path
+ print("Pre-proccessing Libri-light subset:", subset)
+ raw_dir = f"{dataset_path}/{subset}"
+ save_dir = f"{output_path}/{subset}"
+ processed_dir = f"{dataset_path}/processed/{subset}"
+ os.makedirs(processed_dir, exist_ok=True)
+ os.makedirs(save_dir, exist_ok=True)
+
+ # Step 1: Segmentation
+ print("-" * 10)
+ print("Step 1: Segmentation")
+ print("Cutting audio files...")
+
+ cut_segments(raw_dir, processed_dir, cut_length, n_cpus)
+
+ # Steps 2 & 3: Filter and Preprocess
+ print("-" * 10)
+ print("Step 2 & 3: Filter and Preprocess")
+ print("Filtering and preprocessing audio files...")
+
+ wav_files = get_wav_files(processed_dir)
+ filtered_wav_files = filter_wav_files_by_length(wav_files, max_length)
+ process_wav_files(filtered_wav_files, processed_dir, n_cpus)
+
+ # Step 4 & 5: Transcription & Text-preprocess
+ print("-" * 10)
+ print("Step 4 & 5: Transcription & Text-preprocess")
+ print("Transcribing audio files...")
+
+ n_gpus = min(n_gpus, torch.cuda.device_count())
+ asr_main(processed_dir, n_gpus, model_id)
+
+ # Step 6: MFA Align
+ print("-" * 10)
+ print("Step 6: MFA Align")
+ print("Aligning audio files...")
+
+ command = [
+ "mfa",
+ "align",
+ "-v",
+ "-j",
+ str(n_cpus),
+ "-c",
+ mfa_config_path,
+ processed_dir,
+ mfa_dict_path,
+ mfa_model_path,
+ processed_dir,
+ "--output_format",
+ "long_textgrid",
+ "--clean",
+ "--overwrite",
+ ]
+ subprocess.run(command, text=True)
+
+ # Step 7: train/dev/eval split
+ print("-" * 10)
+ print("Step 7: train/dev/eval split")
+ print("Splitting dataset by speaker...")
+
+ speakers = split_dataset_by_speaker(processed_dir)
+
+ # Step 8: Statistics
+ print("-" * 10)
+ print("Step 8: Statistics")
+ print("Calculating statistics...")
+
+ distribution2speakers2utts = librilight_statistics(processed_dir)
+
+ # Step 9: Save metadata
+ print("-" * 10)
+ print("Step 9: Save metadata")
+ print("Preparing Metadata for Librilight...")
+
+ save_meta_data(save_dir, processed_dir, distribution2speakers2utts, speakers)
+ print("Preprocessing subset", subset, "done!")
+ print("-" * 10)
+
+
+if __name__ == "__main__":
+ dataset_path = "/path/to/dataset/librilight"
+ output_path = "/path/to/output"
+ main(output_path, dataset_path)
diff --git a/preprocessors/libritts.py b/preprocessors/libritts.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc86f0cecc8101a191e8a59764edaa845541d53
--- /dev/null
+++ b/preprocessors/libritts.py
@@ -0,0 +1,171 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+
+
+def libritts_statistics(data_dir):
+ speakers = []
+ distribution2speakers2pharases2utts = defaultdict(
+ lambda: defaultdict(lambda: defaultdict(list))
+ )
+
+ distribution_infos = glob(data_dir + "/*")
+
+ for distribution_info in distribution_infos:
+ distribution = distribution_info.split("/")[-1]
+ print(distribution)
+
+ speaker_infos = glob(distribution_info + "/*")
+
+ if len(speaker_infos) == 0:
+ continue
+
+ for speaker_info in speaker_infos:
+ speaker = speaker_info.split("/")[-1]
+
+ speakers.append(speaker)
+
+ pharase_infos = glob(speaker_info + "/*")
+
+ for pharase_info in pharase_infos:
+ pharase = pharase_info.split("/")[-1]
+
+ utts = glob(pharase_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ distribution2speakers2pharases2utts[distribution][speaker][
+ pharase
+ ].append(uid)
+
+ unique_speakers = list(set(speakers))
+ unique_speakers.sort()
+
+ print("Speakers: \n{}".format("\t".join(unique_speakers)))
+ return distribution2speakers2pharases2utts, unique_speakers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing samples for libritts...\n")
+
+ save_dir = os.path.join(output_path, "libritts")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ valid_output_file = os.path.join(save_dir, "valid.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if has_existed(train_output_file):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ libritts_path = dataset_path
+
+ distribution2speakers2pharases2utts, unique_speakers = libritts_statistics(
+ libritts_path
+ )
+
+ # We select pharases of standard spekaer as test songs
+ train = []
+ test = []
+ valid = []
+
+ train_index_count = 0
+ test_index_count = 0
+ valid_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+ valid_total_duration = 0
+
+ for distribution, speakers2pharases2utts in tqdm(
+ distribution2speakers2pharases2utts.items()
+ ):
+ for speaker, pharases2utts in tqdm(speakers2pharases2utts.items()):
+ pharase_names = list(pharases2utts.keys())
+
+ for chosen_pharase in pharase_names:
+ for chosen_uid in pharases2utts[chosen_pharase]:
+ res = {
+ "Dataset": "libritts",
+ "Singer": speaker,
+ "Uid": "{}#{}#{}#{}".format(
+ distribution, speaker, chosen_pharase, chosen_uid
+ ),
+ }
+ res["Path"] = "{}/{}/{}/{}.wav".format(
+ distribution, speaker, chosen_pharase, chosen_uid
+ )
+ res["Path"] = os.path.join(libritts_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ text_file_path = os.path.join(
+ libritts_path,
+ distribution,
+ speaker,
+ chosen_pharase,
+ chosen_uid + ".normalized.txt",
+ )
+ with open(text_file_path, "r") as f:
+ lines = f.readlines()
+ assert len(lines) == 1
+ text = lines[0].strip()
+ res["Text"] = text
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if "test" in distribution:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ elif "train" in distribution:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+ elif "dev" in distribution:
+ res["index"] = valid_index_count
+ valid_total_duration += duration
+ valid.append(res)
+ valid_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print(
+ "#Train = {}, #Test = {}, #Valid = {}".format(len(train), len(test), len(valid))
+ )
+ print(
+ "#Train hours= {}, #Test hours= {}, #Valid hours= {}".format(
+ train_total_duration / 3600,
+ test_total_duration / 3600,
+ valid_total_duration / 3600,
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+ with open(valid_output_file, "w") as f:
+ json.dump(valid, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_speakers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/lijian.py b/preprocessors/lijian.py
new file mode 100644
index 0000000000000000000000000000000000000000..459986a8d68a5b0dc0811f4f8197cd818fdc263f
--- /dev/null
+++ b/preprocessors/lijian.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import glob
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from collections import defaultdict
+
+
+from utils.io import save_audio
+from utils.util import has_existed, remove_and_create
+from utils.audio_slicer import Slicer
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def split_to_utterances(input_dir, output_dir):
+ print("Splitting to utterances for {}...".format(input_dir))
+
+ files_list = glob.glob("*.flac", root_dir=input_dir)
+ files_list.sort()
+ for wav_file in tqdm(files_list):
+ # Load waveform
+ waveform, fs = torchaudio.load(os.path.join(input_dir, wav_file))
+
+ # Song name
+ filename = wav_file.replace(" ", "")
+ filename = filename.replace("(Live)", "")
+ song_id, filename = filename.split("李健-")
+
+ song_id = song_id.split("_")[0]
+ song_name = "{:03d}".format(int(song_id)) + filename.split("_")[0].split("-")[0]
+
+ # Split
+ slicer = Slicer(sr=fs, threshold=-30.0, max_sil_kept=3000)
+ chunks = slicer.slice(waveform)
+
+ save_dir = os.path.join(output_dir, song_name)
+ remove_and_create(save_dir)
+
+ for i, chunk in enumerate(chunks):
+ output_file = os.path.join(save_dir, "{:04d}.wav".format(i))
+ save_audio(output_file, chunk, fs)
+
+
+def _main(dataset_path):
+ """
+ Split to utterances
+ """
+ utterance_dir = os.path.join(dataset_path, "utterances")
+ split_to_utterances(os.path.join(dataset_path, "vocal_v2"), utterance_dir)
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["lijian"]
+ golden_songs = [s.split("_")[0] for s in golden_samples]
+ return golden_songs
+
+
+def statistics(utt_dir):
+ song2utts = defaultdict(list)
+
+ song_infos = glob.glob(utt_dir + "/*")
+ song_infos.sort()
+ for song in song_infos:
+ song_name = song.split("/")[-1]
+ utt_infos = glob.glob(song + "/*.wav")
+ utt_infos.sort()
+ for utt in utt_infos:
+ uid = utt.split("/")[-1].split(".")[0]
+ song2utts[song_name].append(uid)
+
+ utt_sum = sum([len(utts) for utts in song2utts.values()])
+ print("Li Jian: {} unique songs, {} utterances".format(len(song2utts), utt_sum))
+ return song2utts
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for Li Jian...\n")
+
+ if not os.path.exists(os.path.join(dataset_path, "utterances")):
+ print("Spliting into utterances...\n")
+ _main(dataset_path)
+
+ save_dir = os.path.join(output_path, "lijian")
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ if has_existed(test_output_file):
+ return
+
+ # Load
+ lijian_path = os.path.join(dataset_path, "utterances")
+ song2utts = statistics(lijian_path)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for chosen_song, utts in tqdm(song2utts.items()):
+ for chosen_uid in song2utts[chosen_song]:
+ res = {
+ "Dataset": "lijian",
+ "Singer": "lijian",
+ "Uid": "{}_{}".format(chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}/{}.wav".format(chosen_song, chosen_uid)
+ res["Path"] = os.path.join(lijian_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if duration <= 1e-8:
+ continue
+
+ if chosen_song in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/ljspeech.py b/preprocessors/ljspeech.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3c19be8e434f625c502fe0557b92c0297ca2068
--- /dev/null
+++ b/preprocessors/ljspeech.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+from tqdm import tqdm
+import os
+import torchaudio
+from utils import audio
+import csv
+import random
+
+from utils.util import has_existed
+from text import _clean_text
+import librosa
+import soundfile as sf
+from scipy.io import wavfile
+
+from pathlib import Path
+import numpy as np
+
+
+def textgird_extract(
+ corpus_directory,
+ output_directory,
+ mfa_path=os.path.join(
+ "pretrained", "mfa", "montreal-forced-aligner", "bin", "mfa_align"
+ ),
+ lexicon=os.path.join("text", "lexicon", "librispeech-lexicon.txt"),
+ acoustic_model_path=os.path.join(
+ "pretrained",
+ "mfa",
+ "montreal-forced-aligner",
+ "pretrained_models",
+ "english.zip",
+ ),
+ jobs="8",
+):
+ assert os.path.exists(
+ corpus_directory
+ ), "Please check the directionary contains *.wav, *.lab"
+ assert (
+ os.path.exists(mfa_path)
+ and os.path.exists(lexicon)
+ and os.path.exists(acoustic_model_path)
+ ), f"Please download the MFA tools to {mfa_path} firstly"
+ Path(output_directory).mkdir(parents=True, exist_ok=True)
+ print(f"MFA results are save in {output_directory}")
+ os.system(
+ f".{os.path.sep}{mfa_path} {corpus_directory} {lexicon} {acoustic_model_path} {output_directory} -j {jobs} --clean"
+ )
+
+
+def get_lines(file):
+ lines = []
+ with open(file, encoding="utf-8") as f:
+ for line in tqdm(f):
+ lines.append(line.strip())
+ return lines
+
+
+def get_uid2utt(ljspeech_path, dataset, cfg):
+ index_count = 0
+ total_duration = 0
+
+ uid2utt = []
+ for l in tqdm(dataset):
+ items = l.split("|")
+ uid = items[0]
+ text = items[2]
+
+ res = {
+ "Dataset": "LJSpeech",
+ "index": index_count,
+ "Singer": "LJSpeech",
+ "Uid": uid,
+ "Text": text,
+ }
+
+ # Duration in wav files
+ audio_file = os.path.join(ljspeech_path, "wavs/{}.wav".format(uid))
+
+ res["Path"] = audio_file
+
+ waveform, sample_rate = torchaudio.load(audio_file)
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ uid2utt.append(res)
+
+ index_count = index_count + 1
+ total_duration += duration
+
+ return uid2utt, total_duration / 3600
+
+
+def split_dataset(
+ lines, test_rate=0.05, valid_rate=0.05, test_size=None, valid_size=None
+):
+ if test_size == None:
+ test_size = int(len(lines) * test_rate)
+ if valid_size == None:
+ valid_size = int(len(lines) * valid_rate)
+ random.shuffle(lines)
+
+ train_set = []
+ test_set = []
+ valid_set = []
+
+ for line in lines[:test_size]:
+ test_set.append(line)
+ for line in lines[test_size : test_size + valid_size]:
+ valid_set.append(line)
+ for line in lines[test_size + valid_size :]:
+ train_set.append(line)
+ return train_set, test_set, valid_set
+
+
+max_wav_value = 32768.0
+
+
+def prepare_align(dataset, dataset_path, cfg, output_path):
+ in_dir = dataset_path
+ out_dir = os.path.join(output_path, dataset, cfg.raw_data)
+ sampling_rate = cfg.sample_rate
+ cleaners = cfg.text_cleaners
+ speaker = "LJSpeech"
+ with open(os.path.join(dataset_path, "metadata.csv"), encoding="utf-8") as f:
+ for line in tqdm(f):
+ parts = line.strip().split("|")
+ base_name = parts[0]
+ text = parts[2]
+ text = _clean_text(text, cleaners)
+
+ output_wav_path = os.path.join(out_dir, speaker, "{}.wav".format(base_name))
+ output_lab_path = os.path.join(out_dir, speaker, "{}.lab".format(base_name))
+
+ if os.path.exists(output_wav_path) and os.path.exists(output_lab_path):
+ continue
+
+ wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name))
+ if os.path.exists(wav_path):
+ os.makedirs(os.path.join(out_dir, speaker), exist_ok=True)
+ wav, _ = librosa.load(wav_path, sampling_rate)
+ wav = wav / max(abs(wav)) * max_wav_value
+
+ wavfile.write(
+ os.path.join(out_dir, speaker, "{}.wav".format(base_name)),
+ sampling_rate,
+ wav.astype(np.int16),
+ )
+
+ with open(
+ os.path.join(out_dir, speaker, "{}.lab".format(base_name)),
+ "w",
+ ) as f1:
+ f1.write(text)
+ # Extract textgird with MFA
+ textgird_extract(
+ corpus_directory=out_dir,
+ output_directory=os.path.join(output_path, dataset, "TextGrid"),
+ )
+
+
+def main(output_path, dataset_path, cfg):
+ print("-" * 10)
+ print("Dataset splits for {}...\n".format("LJSpeech"))
+
+ dataset = "LJSpeech"
+
+ save_dir = os.path.join(output_path, dataset)
+ os.makedirs(save_dir, exist_ok=True)
+ ljspeech_path = dataset_path
+
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ valid_output_file = os.path.join(save_dir, "valid.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+
+ speaker = "LJSpeech"
+ speakers = [dataset + "_" + speaker]
+ singer_lut = {name: i for i, name in enumerate(sorted(speakers))}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
+
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(valid_output_file)
+ ):
+ return
+
+ meta_file = os.path.join(ljspeech_path, "metadata.csv")
+ lines = get_lines(meta_file)
+
+ train_set, test_set, valid_set = split_dataset(lines)
+
+ res, hours = get_uid2utt(ljspeech_path, train_set, cfg)
+
+ # Save train
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(res, f, indent=4, ensure_ascii=False)
+
+ print("Train_hours= {}".format(hours))
+
+ res, hours = get_uid2utt(ljspeech_path, test_set, cfg)
+
+ # Save test
+ os.makedirs(save_dir, exist_ok=True)
+ with open(test_output_file, "w") as f:
+ json.dump(res, f, indent=4, ensure_ascii=False)
+
+ print("Test_hours= {}".format(hours))
+
+ # Save valid
+ os.makedirs(save_dir, exist_ok=True)
+ with open(valid_output_file, "w") as f:
+ json.dump(res, f, indent=4, ensure_ascii=False)
+
+ print("Valid_hours= {}".format(hours))
diff --git a/preprocessors/ljspeech_vocoder.py b/preprocessors/ljspeech_vocoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ef4fd2c88d85454a939e9fe3329e02b0be949b
--- /dev/null
+++ b/preprocessors/ljspeech_vocoder.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+
+from utils.util import has_existed
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Dataset splits for ljspeech...\n")
+
+ save_dir = os.path.join(output_path, "ljspeech")
+ ljspeech_path = dataset_path
+
+ wave_files = glob(ljspeech_path + "/wavs/*.wav")
+
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+
+ if has_existed(train_output_file):
+ return
+
+ utts = []
+
+ for wave_file in tqdm(wave_files):
+ res = {
+ "Dataset": "ljspeech",
+ "Singer": "female1",
+ "Uid": "{}".format(wave_file.split("/")[-1].split(".")[0]),
+ }
+ res["Path"] = wave_file
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if duration <= 1e-8:
+ continue
+
+ utts.append(res)
+
+ test_length = len(utts) // 20
+
+ train_utts = []
+ train_index_count = 0
+ train_total_duration = 0
+
+ for i in tqdm(range(len(utts) - test_length)):
+ tmp = utts[i]
+ tmp["index"] = train_index_count
+ train_index_count += 1
+ train_total_duration += tmp["Duration"]
+ train_utts.append(tmp)
+
+ test_utts = []
+ test_index_count = 0
+ test_total_duration = 0
+
+ for i in tqdm(range(len(utts) - test_length, len(utts))):
+ tmp = utts[i]
+ tmp["index"] = test_index_count
+ test_index_count += 1
+ test_total_duration += tmp["Duration"]
+ test_utts.append(tmp)
+
+ print("#Train = {}, #Test = {}".format(len(train_utts), len(test_utts)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(train_utts, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test_utts, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/m4singer.py b/preprocessors/m4singer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb8eb1715ce2af6d712efd9d93c5983dd58ba257
--- /dev/null
+++ b/preprocessors/m4singer.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import librosa
+from tqdm import tqdm
+from collections import defaultdict
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["m4singer"]
+ # every item is a tuple (singer, song)
+ golden_songs = [s.split("_")[:2] for s in golden_samples]
+ # singer_song, eg: Alto-1_美错
+ golden_songs = ["_".join(t) for t in golden_songs]
+ return golden_songs
+
+
+def m4singer_statistics(meta):
+ singers = []
+ songs = []
+ singer2songs = defaultdict(lambda: defaultdict(list))
+ for utt in meta:
+ p, s, uid = utt["item_name"].split("#")
+ singers.append(p)
+ songs.append(s)
+ singer2songs[p][s].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "M4Singer: {} singers, {} utterances ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singer2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for m4singer...\n")
+
+ save_dir = os.path.join(output_path, "m4singer")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ m4singer_dir = dataset_path
+ meta_file = os.path.join(m4singer_dir, "meta.json")
+ with open(meta_file, "r", encoding="utf-8") as f:
+ meta = json.load(f)
+
+ singer2songs, unique_singers = m4singer_statistics(meta)
+
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for singer, songs in tqdm(singer2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ chosen_song = chosen_song.replace(" ", "-")
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "m4singer",
+ "Singer": singer,
+ "Song": chosen_song,
+ "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
+ }
+
+ res["Path"] = os.path.join(
+ m4singer_dir, "{}#{}/{}.wav".format(singer, chosen_song, chosen_uid)
+ )
+ assert os.path.exists(res["Path"])
+
+ duration = librosa.get_duration(filename=res["Path"])
+ res["Duration"] = duration
+
+ if "_".join([singer, chosen_song]) in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/metadata.py b/preprocessors/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..8411ea9ab74883e454b4e9a3927d979bd1a64d1d
--- /dev/null
+++ b/preprocessors/metadata.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+from tqdm import tqdm
+
+
+def cal_metadata(cfg, dataset_types=["train", "test"]):
+ """
+ Dump metadata (singers.json, meta_info.json, utt2singer) for singer dataset or multi-datasets.
+ """
+ from collections import Counter
+
+ datasets = cfg.dataset
+
+ print("-" * 10)
+ print("Preparing metadata...")
+ print("Including: \n{}\n".format("\n".join(datasets)))
+
+ datasets.sort()
+
+ for dataset in tqdm(datasets):
+ save_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+ assert os.path.exists(save_dir)
+
+ # 'train.json' and 'test.json' and 'valid.json' of target dataset
+ meta_info = dict()
+ utterances_dict = dict()
+ all_utterances = list()
+ duration = dict()
+ total_duration = 0.0
+ for dataset_type in dataset_types:
+ metadata = os.path.join(save_dir, "{}.json".format(dataset_type))
+
+ # Sort the metadata as the duration order
+ with open(metadata, "r", encoding="utf-8") as f:
+ utterances = json.load(f)
+ utterances = sorted(utterances, key=lambda x: x["Duration"])
+ utterances_dict[dataset_type] = utterances
+ all_utterances.extend(utterances)
+
+ # Write back the sorted metadata
+ with open(metadata, "w") as f:
+ json.dump(utterances, f, indent=4, ensure_ascii=False)
+
+ # Get the total duration and singer names for train and test utterances
+ duration[dataset_type] = sum(utt["Duration"] for utt in utterances)
+ total_duration += duration[dataset_type]
+
+ # Paths of metadata needed to be generated
+ singer_dict_file = os.path.join(save_dir, cfg.preprocess.spk2id)
+ utt2singer_file = os.path.join(save_dir, cfg.preprocess.utt2spk)
+
+ singer_names = set(
+ f"{replace_augment_name(utt['Dataset'])}_{utt['Singer']}"
+ for utt in all_utterances
+ )
+
+ # Write the utt2singer file and sort the singer names
+ with open(utt2singer_file, "w", encoding="utf-8") as f:
+ for utt in all_utterances:
+ f.write(
+ f"{utt['Dataset']}_{utt['Uid']}\t{replace_augment_name(utt['Dataset'])}_{utt['Singer']}\n"
+ )
+
+ singer_names = sorted(singer_names)
+ singer_lut = {name: i for i, name in enumerate(singer_names)}
+
+ # dump singers.json
+ with open(singer_dict_file, "w", encoding="utf-8") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
+
+ meta_info = {
+ "dataset": dataset,
+ "statistics": {
+ "size": len(all_utterances),
+ "hours": round(total_duration / 3600, 4),
+ },
+ }
+
+ for dataset_type in dataset_types:
+ meta_info[dataset_type] = {
+ "size": len(utterances_dict[dataset_type]),
+ "hours": round(duration[dataset_type] / 3600, 4),
+ }
+
+ meta_info["singers"] = {"size": len(singer_lut)}
+
+ # Use Counter to count the minutes for each singer
+ total_singer2mins = Counter()
+ training_singer2mins = Counter()
+ for dataset_type in dataset_types:
+ for utt in utterances_dict[dataset_type]:
+ k = f"{replace_augment_name(utt['Dataset'])}_{utt['Singer']}"
+ if dataset_type == "train":
+ training_singer2mins[k] += utt["Duration"] / 60
+ total_singer2mins[k] += utt["Duration"] / 60
+
+ training_singer2mins = dict(
+ sorted(training_singer2mins.items(), key=lambda x: x[1], reverse=True)
+ )
+ training_singer2mins = {k: round(v, 2) for k, v in training_singer2mins.items()}
+ meta_info["singers"]["training_minutes"] = training_singer2mins
+
+ total_singer2mins = dict(
+ sorted(total_singer2mins.items(), key=lambda x: x[1], reverse=True)
+ )
+ total_singer2mins = {k: round(v, 2) for k, v in total_singer2mins.items()}
+ meta_info["singers"]["minutes"] = total_singer2mins
+
+ with open(os.path.join(save_dir, "meta_info.json"), "w") as f:
+ json.dump(meta_info, f, indent=4, ensure_ascii=False)
+
+ for singer, min in training_singer2mins.items():
+ print(f"Speaker/Singer {singer}: {min} mins for training")
+ print("-" * 10, "\n")
+
+
+def replace_augment_name(dataset: str) -> str:
+ """Replace the augmented dataset name with the original dataset name.
+ >>> print(replace_augment_name("dataset_equalizer"))
+ dataset
+ """
+ if "equalizer" in dataset:
+ dataset = dataset.replace("_equalizer", "")
+ elif "formant_shift" in dataset:
+ dataset = dataset.replace("_formant_shift", "")
+ elif "pitch_shift" in dataset:
+ dataset = dataset.replace("_pitch_shift", "")
+ elif "time_stretch" in dataset:
+ dataset = dataset.replace("_time_stretch", "")
+ else:
+ pass
+ return dataset
diff --git a/preprocessors/nus48e.py b/preprocessors/nus48e.py
new file mode 100644
index 0000000000000000000000000000000000000000..e780b3265b3921f21ee0efb5fcd9162e5555c529
--- /dev/null
+++ b/preprocessors/nus48e.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+
+from utils.io import save_audio
+from utils.util import has_existed
+from utils.audio_slicer import Slicer
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def split_to_utterances(dataset_path, singer, style, output_dir):
+ data_dir = os.path.join(dataset_path, singer, style)
+
+ print("Splitting to utterances for {}...".format(data_dir))
+
+ wave_files = glob(data_dir + "/*.wav")
+
+ for wav_file in tqdm(wave_files):
+ # Load waveform
+ song_name = wav_file.split("/")[-1].split(".")[0]
+ waveform, fs = torchaudio.load(wav_file)
+
+ # Split
+ slicer = Slicer(sr=fs, threshold=-40.0, max_sil_kept=4000)
+ chunks = slicer.slice(waveform)
+
+ for i, chunk in enumerate(chunks):
+ save_dir = os.path.join(output_dir, singer, style, song_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ output_file = os.path.join(save_dir, "{:04d}.wav".format(i))
+ save_audio(output_file, chunk, fs)
+
+
+def _main(dataset_path):
+ """
+ Split to utterances
+ """
+ utterance_dir = os.path.join(dataset_path, "utterances")
+
+ singer_infos = glob(dataset_path + "/*")
+
+ for singer_info in singer_infos:
+ singer = singer_info.split("/")[-1]
+
+ for style in ["read", "sing"]:
+ split_to_utterances(dataset_path, singer, style, utterance_dir)
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["nus48e"]
+ # every item is a tuple (singer, song)
+ golden_songs = [s.split("#")[:2] for s in golden_samples]
+ # singer_song, eg: Female1#Almost_lover_Amateur
+ return golden_songs
+
+
+def nus48e_statistics(data_dir):
+ singers = []
+ songs = []
+ singer2songs = defaultdict(lambda: defaultdict(list))
+
+ singer_infos = glob(data_dir + "/*")
+
+ for singer_info in singer_infos:
+ singer_info_split = singer_info.split("/")[-1]
+
+ style_infos = glob(singer_info + "/*")
+
+ for style_info in style_infos:
+ style_info_split = style_info.split("/")[-1]
+
+ singer = singer_info_split + "_" + style_info_split
+ singers.append(singer)
+
+ song_infos = glob(style_info + "/*")
+
+ for song_info in song_infos:
+ song = song_info.split("/")[-1]
+
+ songs.append(song)
+
+ utts = glob(song_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ singer2songs[singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "nus_48_e: {} singers, {} utterances ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singer2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for nus48e...\n")
+
+ if not os.path.exists(os.path.join(dataset_path, "utterances")):
+ print("Spliting into utterances...\n")
+ _main(dataset_path)
+
+ save_dir = os.path.join(output_path, "nus48e")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ nus48e_path = os.path.join(dataset_path, "utterances")
+
+ singer2songs, unique_singers = nus48e_statistics(nus48e_path)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for singer, songs in singer2songs.items():
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "nus48e",
+ "Singer": singer,
+ "Uid": "{}#{}#{}".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}/{}/{}/{}.wav".format(
+ singer.split("_")[0], singer.split("_")[-1], chosen_song, chosen_uid
+ )
+ res["Path"] = os.path.join(nus48e_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if duration <= 1e-8:
+ continue
+
+ if ([singer, chosen_song]) in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/opencpop.py b/preprocessors/opencpop.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1da6e303cbb06370f3215ce5c2cb4dd22e611b5
--- /dev/null
+++ b/preprocessors/opencpop.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+from tqdm import tqdm
+import os
+import librosa
+
+from utils.util import has_existed
+
+
+def get_lines(file):
+ with open(file, "r") as f:
+ lines = f.readlines()
+ lines = [l.strip() for l in lines]
+ return lines
+
+
+def get_uid2utt(opencpop_path, dataset, dataset_type):
+ index_count = 0
+ total_duration = 0
+
+ file = os.path.join(opencpop_path, "segments", "{}.txt".format(dataset_type))
+ lines = get_lines(file)
+
+ uid2utt = []
+ for l in tqdm(lines):
+ items = l.split("|")
+ uid = items[0]
+
+ res = {
+ "Dataset": dataset,
+ "index": index_count,
+ "Singer": "female1",
+ "Uid": uid,
+ }
+
+ # Duration in wav files
+ audio_file = os.path.join(opencpop_path, "segments/wavs/{}.wav".format(uid))
+ res["Path"] = audio_file
+
+ duration = librosa.get_duration(filename=res["Path"])
+ res["Duration"] = duration
+
+ uid2utt.append(res)
+
+ index_count = index_count + 1
+ total_duration += duration
+
+ return uid2utt, total_duration / 3600
+
+
+def main(dataset, output_path, dataset_path):
+ print("-" * 10)
+ print("Dataset splits for {}...\n".format(dataset))
+
+ save_dir = os.path.join(output_path, dataset)
+ opencpop_path = dataset_path
+ for dataset_type in ["train", "test"]:
+ output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
+ if has_existed(output_file):
+ continue
+
+ res, hours = get_uid2utt(opencpop_path, dataset, dataset_type)
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(output_file, "w") as f:
+ json.dump(res, f, indent=4, ensure_ascii=False)
+
+ print("{}_{}_hours= {}".format(dataset, dataset_type, hours))
diff --git a/preprocessors/opensinger.py b/preprocessors/opensinger.py
new file mode 100644
index 0000000000000000000000000000000000000000..93fc3f648df11a866f6c7017d7c6964c9b36cee6
--- /dev/null
+++ b/preprocessors/opensinger.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import os
+import json
+import librosa
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["opensinger"]
+ # every item is a tuple (singer, song)
+ golden_songs = [s.split("_")[:3] for s in golden_samples]
+ # singer_song, eg: Female1#Almost_lover_Amateur
+ return golden_songs
+
+
+def opensinger_statistics(data_dir):
+ singers = []
+ songs = []
+ singer2songs = defaultdict(lambda: defaultdict(list))
+
+ gender_infos = glob(data_dir + "/*")
+
+ for gender_info in gender_infos:
+ gender_info_split = gender_info.split("/")[-1][:-3]
+
+ singer_and_song_infos = glob(gender_info + "/*")
+
+ for singer_and_song_info in singer_and_song_infos:
+ singer_and_song_info_split = singer_and_song_info.split("/")[-1].split("_")
+ singer_id, song = (
+ singer_and_song_info_split[0],
+ singer_and_song_info_split[1],
+ )
+ singer = gender_info_split + "_" + singer_id
+ singers.append(singer)
+ songs.append(song)
+
+ utts = glob(singer_and_song_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split("_")[-1].split(".")[0]
+ singer2songs[singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "opensinger: {} singers, {} songs ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singer2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for opensinger...\n")
+
+ save_dir = os.path.join(output_path, "opensinger")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ opensinger_path = dataset_path
+
+ singer2songs, unique_singers = opensinger_statistics(opensinger_path)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for i, (singer, songs) in enumerate(singer2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in tqdm(
+ song_names, desc="Singer {}/{}".format(i, len(singer2songs))
+ ):
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "opensinger",
+ "Singer": singer,
+ "Song": chosen_song,
+ "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}Raw/{}_{}/{}_{}_{}.wav".format(
+ singer.split("_")[0],
+ singer.split("_")[1],
+ chosen_song,
+ singer.split("_")[1],
+ chosen_song,
+ chosen_uid,
+ )
+ res["Path"] = os.path.join(opensinger_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ duration = librosa.get_duration(filename=res["Path"])
+ res["Duration"] = duration
+
+ if duration > 30:
+ print(
+ "Wav file: {}, the duration = {:.2f}s > 30s, which has been abandoned.".format(
+ res["Path"], duration
+ )
+ )
+ continue
+
+ if (
+ [singer.split("_")[0], singer.split("_")[1], chosen_song]
+ ) in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/opera.py b/preprocessors/opera.py
new file mode 100644
index 0000000000000000000000000000000000000000..c421fbb8346945c5dec40b1c872cebcb49f1f606
--- /dev/null
+++ b/preprocessors/opera.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import os
+from tqdm import tqdm
+import torchaudio
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from utils.io import save_audio
+from utils.audio_slicer import Slicer
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def split_to_utterances(language_dir, output_dir):
+ print("Splitting to utterances for {}...".format(language_dir))
+
+ for wav_file in tqdm(glob("{}/*/*".format(language_dir))):
+ # Load waveform
+ singer_name, song_name = wav_file.split("/")[-2:]
+ song_name = song_name.split(".")[0]
+ waveform, fs = torchaudio.load(wav_file)
+
+ # Split
+ slicer = Slicer(sr=fs, threshold=-30.0, max_sil_kept=3000)
+ chunks = slicer.slice(waveform)
+
+ for i, chunk in enumerate(chunks):
+ save_dir = os.path.join(output_dir, singer_name, song_name)
+ os.makedirs(save_dir, exist_ok=True)
+
+ output_file = os.path.join(save_dir, "{:04d}.wav".format(i))
+ save_audio(output_file, chunk, fs)
+
+
+def _main(dataset_path):
+ """
+ Split to utterances
+ """
+ utterance_dir = os.path.join(dataset_path, "utterances")
+
+ for lang in ["chinese", "western"]:
+ split_to_utterances(os.path.join(dataset_path, lang), utterance_dir)
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["opera"]
+ # every item is a tuple (singer, song)
+ golden_songs = [s.split("#")[:2] for s in golden_samples]
+ # singer#song, eg:fem_01#neg_01
+ return golden_songs
+
+
+def opera_statistics(data_dir):
+ singers = []
+ songs = []
+ singers2songs = defaultdict(lambda: defaultdict(list))
+
+ singer_infos = glob(data_dir + "/*")
+
+ for singer_info in singer_infos:
+ singer = singer_info.split("/")[-1]
+
+ song_infos = glob(singer_info + "/*")
+
+ for song_info in song_infos:
+ song = song_info.split("/")[-1]
+
+ singers.append(singer)
+ songs.append(song)
+
+ utts = glob(song_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ singers2songs[singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "opera: {} singers, {} utterances ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singers2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for opera...\n")
+
+ if not os.path.exists(os.path.join(dataset_path, "utterances")):
+ print("Spliting into utterances...\n")
+ _main(dataset_path)
+
+ save_dir = os.path.join(output_path, "opera")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ opera_path = os.path.join(dataset_path, "utterances")
+
+ singers2songs, unique_singers = opera_statistics(opera_path)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for singer, songs in tqdm(singers2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "opera",
+ "Singer": singer,
+ "Uid": "{}#{}#{}".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}/{}/{}.wav".format(singer, chosen_song, chosen_uid)
+ res["Path"] = os.path.join(opera_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if duration <= 1e-8:
+ continue
+
+ if ([singer, chosen_song]) in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/pjs.py b/preprocessors/pjs.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d69bc56cae59b3bf512078e365853ccad053ff
--- /dev/null
+++ b/preprocessors/pjs.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from tqdm import tqdm
+import glob
+import json
+import torchaudio
+
+from utils.util import has_existed
+from utils.io import save_audio
+
+
+def get_splitted_utterances(
+ raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
+):
+ res = []
+ raw_song_files = glob.glob(
+ os.path.join(raw_wav_dir, "**/pjs*_song.wav"), recursive=True
+ )
+ trimed_song_files = glob.glob(
+ os.path.join(trimed_wav_dir, "**/*.wav"), recursive=True
+ )
+
+ if len(raw_song_files) * n_utterance_splits == len(trimed_song_files):
+ print("Splitted done...")
+ for wav_file in tqdm(trimed_song_files):
+ uid = wav_file.split("/")[-1].split(".")[0]
+ utt = {"Dataset": "pjs", "Singer": "male1", "Uid": uid, "Path": wav_file}
+
+ waveform, sample_rate = torchaudio.load(wav_file)
+ duration = waveform.size(-1) / sample_rate
+ utt["Duration"] = duration
+
+ res.append(utt)
+
+ else:
+ for wav_file in tqdm(raw_song_files):
+ song_id = wav_file.split("/")[-1].split(".")[0]
+
+ waveform, sample_rate = torchaudio.load(wav_file)
+ trimed_waveform = torchaudio.functional.vad(waveform, sample_rate)
+ trimed_waveform = torchaudio.functional.vad(
+ trimed_waveform.flip(dims=[1]), sample_rate
+ ).flip(dims=[1])
+
+ audio_len = trimed_waveform.size(-1)
+ lapping_len = overlapping * sample_rate
+
+ for i in range(n_utterance_splits):
+ start = i * audio_len // 3
+ end = start + audio_len // 3 + lapping_len
+ splitted_waveform = trimed_waveform[:, start:end]
+
+ utt = {
+ "Dataset": "pjs",
+ "Singer": "male1",
+ "Uid": "{}_{}".format(song_id, i),
+ }
+
+ # Duration
+ duration = splitted_waveform.size(-1) / sample_rate
+ utt["Duration"] = duration
+
+ # Save trimed wav
+ splitted_waveform_file = os.path.join(
+ trimed_wav_dir, "{}.wav".format(utt["Uid"])
+ )
+ save_audio(splitted_waveform_file, splitted_waveform, sample_rate)
+
+ # Path
+ utt["Path"] = splitted_waveform_file
+
+ res.append(utt)
+
+ res = sorted(res, key=lambda x: x["Uid"])
+ return res
+
+
+def main(output_path, dataset_path, n_utterance_splits=3, overlapping=1):
+ """
+ 1. Split one raw utterance to three splits (since some samples are too long)
+ 2. Overlapping of ajacent splits is 1 s
+ """
+ print("-" * 10)
+ print("Preparing training dataset for PJS...")
+
+ save_dir = os.path.join(output_path, "pjs")
+ raw_wav_dir = os.path.join(dataset_path, "PJS_corpus_ver1.1")
+
+ # Trim for silence
+ trimed_wav_dir = os.path.join(dataset_path, "trim")
+ os.makedirs(trimed_wav_dir, exist_ok=True)
+
+ # Total utterances
+ utterances = get_splitted_utterances(
+ raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
+ )
+ total_uids = [utt["Uid"] for utt in utterances]
+
+ # Test uids
+ n_test_songs = 3
+ test_uids = []
+ for i in range(1, n_test_songs + 1):
+ test_uids += [
+ "pjs00{}_song_{}".format(i, split_id)
+ for split_id in range(n_utterance_splits)
+ ]
+
+ # Train uids
+ train_uids = [uid for uid in total_uids if uid not in test_uids]
+
+ for dataset_type in ["train", "test"]:
+ output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
+ if has_existed(output_file):
+ continue
+
+ uids = eval("{}_uids".format(dataset_type))
+ res = [utt for utt in utterances if utt["Uid"] in uids]
+ for i in range(len(res)):
+ res[i]["index"] = i
+
+ time = sum([utt["Duration"] for utt in res])
+ print(
+ "{}, Total size: {}, Total Duraions = {} s = {:.2f} hour\n".format(
+ dataset_type, len(res), time, time / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(output_file, "w") as f:
+ json.dump(res, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/popbutfy.py b/preprocessors/popbutfy.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ba7bdd60f6d370a7ff39bbb90d695f1157a4b4
--- /dev/null
+++ b/preprocessors/popbutfy.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+import librosa
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["popbutfy"]
+ # every item is a tuple (singer, song)
+ golden_songs = [s.split("#")[:2] for s in golden_samples]
+ # singer#song, eg: Female1#Almost_lover_Amateur
+ return golden_songs
+
+
+def popbutfy_statistics(data_dir):
+ singers = []
+ songs = []
+ singer2songs = defaultdict(lambda: defaultdict(list))
+
+ data_infos = glob(data_dir + "/*")
+
+ for data_info in data_infos:
+ data_info_split = data_info.split("/")[-1].split("#")
+
+ singer, song = data_info_split[0], data_info_split[-1]
+ singers.append(singer)
+ songs.append(song)
+
+ utts = glob(data_info + "/*")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split("_")[-1].split(".")[0]
+ singer2songs[singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "PopBuTFy: {} singers, {} utterances ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return singer2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for popbutfy...\n")
+
+ save_dir = os.path.join(output_path, "popbutfy")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ popbutfy_dir = dataset_path
+
+ singer2songs, unique_singers = popbutfy_statistics(popbutfy_dir)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for singer, songs in tqdm(singer2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "popbutfy",
+ "Singer": singer,
+ "Song": chosen_song,
+ "Uid": "{}#{}#".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}#singing#{}/{}#singing#{}_{}.mp3".format(
+ singer, chosen_song, singer, chosen_song, chosen_uid
+ )
+ if not os.path.exists(os.path.join(popbutfy_dir, res["Path"])):
+ res["Path"] = "{}#singing#{}/{}#singing#{}_{}.wav".format(
+ singer, chosen_song, singer, chosen_song, chosen_uid
+ )
+ res["Path"] = os.path.join(popbutfy_dir, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ if res["Path"].split("/")[-1].split(".")[-1] == "wav":
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ else:
+ waveform, sample_rate = librosa.load(res["Path"])
+ duration = waveform.shape[-1] / sample_rate
+ res["Duration"] = duration
+
+ if ([singer, chosen_song]) in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/popcs.py b/preprocessors/popcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..9acf2127ec27f679d715108495f580a30f75d01a
--- /dev/null
+++ b/preprocessors/popcs.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def get_test_songs():
+ golden_samples = GOLDEN_TEST_SAMPLES["popcs"]
+ # every item is a string
+ golden_songs = [s.split("_")[:1] for s in golden_samples]
+ # song, eg: 万有引力
+ return golden_songs
+
+
+def popcs_statistics(data_dir):
+ songs = []
+ songs2utts = defaultdict(list)
+
+ song_infos = glob(data_dir + "/*")
+
+ for song_info in song_infos:
+ song_info_split = song_info.split("/")[-1].split("-")[-1]
+
+ songs.append(song_info_split)
+
+ utts = glob(song_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split("_")[0]
+ songs2utts[song_info_split].append(uid)
+
+ unique_songs = list(set(songs))
+ unique_songs.sort()
+
+ print(
+ "popcs: {} utterances ({} unique songs)".format(len(songs), len(unique_songs))
+ )
+ print("Songs: \n{}".format("\t".join(unique_songs)))
+ return songs2utts
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for popcs...\n")
+
+ save_dir = os.path.join(output_path, "popcs")
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ if has_existed(test_output_file):
+ return
+
+ # Load
+ popcs_dir = dataset_path
+
+ songs2utts = popcs_statistics(popcs_dir)
+ test_songs = get_test_songs()
+
+ # We select songs of standard samples as test songs
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ song_names = list(songs2utts.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs2utts[chosen_song]:
+ res = {
+ "Dataset": "popcs",
+ "Singer": "female1",
+ "Song": chosen_song,
+ "Uid": "{}_{}".format(chosen_song, chosen_uid),
+ }
+ res["Path"] = "popcs-{}/{}_wf0.wav".format(chosen_song, chosen_uid)
+ res["Path"] = os.path.join(popcs_dir, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ if ([chosen_song]) in test_songs:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save
+ os.makedirs(save_dir, exist_ok=True)
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/processor.py b/preprocessors/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..037ac6c58e7781f44322f7cac9a15e64b0e641fc
--- /dev/null
+++ b/preprocessors/processor.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+from preprocessors import (
+ m4singer,
+ opencpop,
+ svcc,
+ pjs,
+ popbutfy,
+ opensinger,
+ popcs,
+ kising,
+ csd,
+ opera,
+ nus48e,
+ svcceval,
+ vctk,
+ vctksample,
+ libritts,
+ lijian,
+ cdmusiceval,
+ ljspeech,
+ coco,
+ cocoeval,
+ customsvcdataset,
+ vocalist,
+ ljspeech_vocoder,
+ librilight,
+ hifitts,
+)
+
+
+def preprocess_dataset(
+ dataset, dataset_path, output_path, cfg, task_type, is_custom_dataset=False
+):
+ """Call specific function to handle specific dataset
+ Args:
+ dataset (str): name of a dataset, e.g. opencpop, m4singer
+ dataset_path (str): path to dataset
+ output_path (str): path to store preprocessing result files
+ """
+ if is_custom_dataset:
+ if task_type == "svc":
+ customsvcdataset.main(output_path, dataset_path, dataset_name=dataset)
+ else:
+ raise NotImplementedError(
+ "Custom dataset for {} task not implemented!".format(cfg.task_type)
+ )
+
+ if re.match("opencpop*", dataset):
+ opencpop.main(dataset, output_path, dataset_path)
+ if dataset == "m4singer":
+ m4singer.main(output_path, dataset_path)
+ if dataset == "svcc":
+ svcc.main(output_path, dataset_path)
+ if dataset == "pjs":
+ pjs.main(output_path, dataset_path)
+ if dataset == "popbutfy":
+ popbutfy.main(output_path, dataset_path)
+ if dataset == "opensinger":
+ opensinger.main(output_path, dataset_path)
+ if dataset == "popcs":
+ popcs.main(output_path, dataset_path)
+ if dataset == "kising":
+ kising.main(output_path, dataset_path)
+ if dataset == "csd":
+ csd.main(output_path, dataset_path)
+ if dataset == "opera":
+ opera.main(output_path, dataset_path)
+ if dataset == "nus48e":
+ nus48e.main(output_path, dataset_path)
+ if dataset == "vctk":
+ vctk.main(output_path, dataset_path)
+ if dataset == "svcceval":
+ svcceval.main(output_path, dataset_path)
+ if dataset == "libritts":
+ libritts.main(output_path, dataset_path)
+ if dataset == "lijian":
+ lijian.main(output_path, dataset_path)
+ if dataset == "cdmusiceval":
+ cdmusiceval.main(output_path, dataset_path)
+ if dataset == "LJSpeech":
+ ljspeech.main(output_path, dataset_path, cfg)
+ if dataset == "ljspeech":
+ ljspeech_vocoder.main(output_path, dataset_path)
+ if dataset == "coco":
+ coco.main(output_path, dataset_path)
+ if dataset == "cocoeval":
+ cocoeval.main(output_path, dataset_path)
+ if dataset == "vocalist":
+ vocalist.main(output_path, dataset_path)
+ if dataset == "librilight":
+ librilight.main(output_path, dataset_path, cfg)
+ if dataset == "hifitts":
+ hifitts.main(output_path, dataset_path)
+
+
+def prepare_align(dataset, dataset_path, cfg, output_path):
+ """Call specific function to handle specific dataset
+
+ Args:
+ dataset (str): name of a dataset, e.g. ljspeech
+ dataset_path (str): path to dataset
+ output_path (str): path to store preprocessing result files
+ """
+ if dataset == "LJSpeech":
+ ljspeech.prepare_align(dataset, dataset_path, cfg, output_path)
diff --git a/preprocessors/svcc.py b/preprocessors/svcc.py
new file mode 100644
index 0000000000000000000000000000000000000000..6afee6197146eff9464839ea7844edbb64e76211
--- /dev/null
+++ b/preprocessors/svcc.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import glob
+import librosa
+import json
+
+from utils.util import has_existed
+from preprocessors import GOLDEN_TEST_SAMPLES
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing training dataset for svcc...")
+
+ data_dir = os.path.join(dataset_path, "Data")
+ save_dir = os.path.join(output_path, "svcc")
+ os.makedirs(save_dir, exist_ok=True)
+
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load utterances
+ train = []
+ test = []
+ singers = []
+
+ for wav_file in glob.glob(os.path.join(data_dir, "*/*.wav")):
+ singer, filename = wav_file.split("/")[-2:]
+ uid = filename.split(".")[0]
+ utt = {
+ "Dataset": "svcc",
+ "Singer": singer,
+ "Uid": "{}_{}".format(singer, uid),
+ "Path": wav_file,
+ }
+
+ # Duration
+ duration = librosa.get_duration(filename=wav_file)
+ utt["Duration"] = duration
+
+ if utt["Uid"] in GOLDEN_TEST_SAMPLES["svcc"]:
+ test.append(utt)
+ else:
+ train.append(utt)
+
+ singers.append(singer)
+ utt2singer.write("{}\t{}\n".format(utt["Uid"], utt["Singer"]))
+
+ # Save singers.json
+ unique_singers = list(set(singers))
+ unique_singers.sort()
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
+
+ train_total_duration = sum([utt["Duration"] for utt in train])
+ test_total_duration = sum([utt["Duration"] for utt in test])
+
+ for dataset_type in ["train", "test"]:
+ output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
+ if has_existed(output_file):
+ continue
+
+ utterances = eval(dataset_type)
+ utterances = sorted(utterances, key=lambda x: x["Uid"])
+
+ for i in range(len(utterances)):
+ utterances[i]["index"] = i
+
+ print("{}: Total size: {}\n".format(dataset_type, len(utterances)))
+
+ # Save
+ with open(output_file, "w") as f:
+ json.dump(utterances, f, indent=4, ensure_ascii=False)
+
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
diff --git a/preprocessors/svcceval.py b/preprocessors/svcceval.py
new file mode 100644
index 0000000000000000000000000000000000000000..871e78956d8e9d323125aaa79770f02176e2426c
--- /dev/null
+++ b/preprocessors/svcceval.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import glob
+import librosa
+import json
+
+from utils.util import has_existed
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing training dataset for svcceval...")
+
+ data_dir = os.path.join(dataset_path, "Data")
+ save_dir = os.path.join(output_path, "svcceval")
+ os.makedirs(save_dir, exist_ok=True)
+
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load utterances
+ train = []
+ test = []
+ singers = []
+ for wav_file in glob.glob(os.path.join(data_dir, "*/*.wav")):
+ singer, filename = wav_file.split("/")[-2:]
+ uid = filename.split(".")[0]
+ utt = {
+ "Dataset": "svcceval",
+ "Singer": singer,
+ "Uid": "{}_{}".format(singer, uid),
+ "Path": wav_file,
+ }
+
+ # Duration
+ duration = librosa.get_duration(filename=wav_file)
+ utt["Duration"] = duration
+
+ test.append(utt)
+
+ singers.append(singer)
+ utt2singer.write("{}\t{}\n".format(utt["Uid"], utt["Singer"]))
+
+ # Save singers.json
+ unique_singers = list(set(singers))
+ unique_singers.sort()
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
+
+ train_total_duration = sum([utt["Duration"] for utt in train])
+ test_total_duration = sum([utt["Duration"] for utt in test])
+
+ for dataset_type in ["train", "test"]:
+ output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
+ if has_existed(output_file):
+ continue
+
+ utterances = eval(dataset_type)
+ utterances = sorted(utterances, key=lambda x: x["Uid"])
+
+ for i in range(len(utterances)):
+ utterances[i]["index"] = i
+
+ print("{}: Total size: {}\n".format(dataset_type, len(utterances)))
+
+ # Save
+ with open(output_file, "w") as f:
+ json.dump(utterances, f, indent=4, ensure_ascii=False)
+
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
diff --git a/preprocessors/vctk.py b/preprocessors/vctk.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa4c1c69414561f6e40c69acc653630ba783abd1
--- /dev/null
+++ b/preprocessors/vctk.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import librosa
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+
+
+def get_lines(file):
+ with open(file, "r") as f:
+ lines = f.readlines()
+ lines = [l.strip() for l in lines]
+ return lines
+
+
+def vctk_statistics(data_dir):
+ speakers = []
+ speakers2utts = defaultdict(list)
+
+ speaker_infos = glob(data_dir + "/wav48_silence_trimmed" + "/*")
+
+ for speaker_info in speaker_infos:
+ speaker = speaker_info.split("/")[-1]
+
+ if speaker == "log.txt":
+ continue
+
+ speakers.append(speaker)
+
+ utts = glob(speaker_info + "/*")
+
+ for utt in utts:
+ uid = (
+ utt.split("/")[-1].split("_")[1]
+ + "_"
+ + utt.split("/")[-1].split("_")[2].split(".")[0]
+ )
+ speakers2utts[speaker].append(uid)
+
+ unique_speakers = list(set(speakers))
+ unique_speakers.sort()
+
+ print("Speakers: \n{}".format("\t".join(unique_speakers)))
+ return speakers2utts, unique_speakers
+
+
+def vctk_speaker_infos(data_dir):
+ file = os.path.join(data_dir, "speaker-info.txt")
+ lines = get_lines(file)
+
+ ID2speakers = defaultdict()
+ for l in tqdm(lines):
+ items = l.replace(" ", "")
+
+ if items[:2] == "ID":
+ # The header line
+ continue
+
+ if items[0] == "p":
+ id = items[:4]
+ gender = items[6]
+ elif items[0] == "s":
+ id = items[:2]
+ gender = items[4]
+
+ if gender == "F":
+ speaker = "female_{}".format(id)
+ elif gender == "M":
+ speaker = "male_{}".format(id)
+
+ ID2speakers[id] = speaker
+
+ return ID2speakers
+
+
+def main(output_path, dataset_path, TEST_NUM_OF_EVERY_SPEAKER=3):
+ print("-" * 10)
+ print("Preparing test samples for vctk...")
+
+ save_dir = os.path.join(output_path, "vctk")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if has_existed(train_output_file):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ vctk_dir = dataset_path
+
+ ID2speakers = vctk_speaker_infos(vctk_dir)
+ speaker2utts, unique_speakers = vctk_statistics(vctk_dir)
+
+ # We select speakers of standard samples as test utts
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+ test_speaker_count = defaultdict(int)
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for i, speaker in enumerate(speaker2utts.keys()):
+ for chosen_uid in tqdm(
+ speaker2utts[speaker],
+ desc="Speaker {}/{}, #Train = {}, #Test = {}".format(
+ i + 1, len(speaker2utts), train_index_count, test_index_count
+ ),
+ ):
+ res = {
+ "Dataset": "vctk",
+ "Singer": ID2speakers[speaker],
+ "Uid": "{}#{}".format(ID2speakers[speaker], chosen_uid),
+ }
+ res["Path"] = "{}/{}_{}.flac".format(speaker, speaker, chosen_uid)
+ res["Path"] = os.path.join(vctk_dir, "wav48_silence_trimmed", res["Path"])
+ assert os.path.exists(res["Path"])
+
+ duration = librosa.get_duration(filename=res["Path"])
+ res["Duration"] = duration
+
+ if test_speaker_count[speaker] < TEST_NUM_OF_EVERY_SPEAKER:
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+ test_speaker_count[speaker] += 1
+ else:
+ res["index"] = train_index_count
+ train_total_duration += duration
+ train.append(res)
+ train_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_speakers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/preprocessors/vctkfewsinger.py b/preprocessors/vctkfewsinger.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4b6adc76597b18ccbbf065ef339ccff77032a52
--- /dev/null
+++ b/preprocessors/vctkfewsinger.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import pickle
+import glob
+from collections import defaultdict
+from tqdm import tqdm
+
+
+# Train: male 20 hours, female 10 hours
+TRAIN_MALE_MAX_SECONDS = 20 * 3600
+TRAIN_FEMALE_MAX_SECONDS = 10 * 3600
+TEST_MAX_NUM_EVERY_PERSON = 5
+
+
+def select_sample_idxs():
+ chosen_speakers = get_chosen_speakers()
+
+ with open(os.path.join(vctk_dir, "train.json"), "r") as f:
+ raw_train = json.load(f)
+ with open(os.path.join(vctk_dir, "test.json"), "r") as f:
+ raw_test = json.load(f)
+
+ train_idxs, test_idxs = [], []
+
+ # =========== Test ===========
+ test_nums = defaultdict(int)
+ for utt in tqdm(raw_train):
+ idx = utt["index"]
+ singer = utt["Singer"]
+
+ if singer in chosen_speakers and test_nums[singer] < TEST_MAX_NUM_EVERY_PERSON:
+ test_nums[singer] += 1
+ test_idxs.append("train_{}".format(idx))
+
+ for utt in tqdm(raw_test):
+ idx = utt["index"]
+ singer = utt["Singer"]
+
+ if singer in chosen_speakers and test_nums[singer] < TEST_MAX_NUM_EVERY_PERSON:
+ test_nums[singer] += 1
+ test_idxs.append("test_{}".format(idx))
+
+ # =========== Train ===========
+ for utt in tqdm(raw_train):
+ idx = utt["index"]
+ singer = utt["Singer"]
+
+ if singer in chosen_speakers and "train_{}".format(idx) not in test_idxs:
+ train_idxs.append("train_{}".format(idx))
+
+ for utt in tqdm(raw_test):
+ idx = utt["index"]
+ singer = utt["Singer"]
+
+ if singer in chosen_speakers and "test_{}".format(idx) not in test_idxs:
+ train_idxs.append("test_{}".format(idx))
+
+ train_idxs.sort()
+ test_idxs.sort()
+ return train_idxs, test_idxs, raw_train, raw_test
+
+
+def statistics_of_speakers():
+ speaker2time = defaultdict(float)
+ sex2time = defaultdict(float)
+
+ with open(os.path.join(vctk_dir, "train.json"), "r") as f:
+ train = json.load(f)
+ with open(os.path.join(vctk_dir, "test.json"), "r") as f:
+ test = json.load(f)
+
+ for utt in train + test:
+ # minutes
+ speaker2time[utt["Singer"]] += utt["Duration"]
+ # hours
+ sex2time[utt["Singer"].split("_")[0]] += utt["Duration"]
+
+ print(
+ "Female: {:.2f} hours, Male: {:.2f} hours.\n".format(
+ sex2time["female"] / 3600, sex2time["male"] / 3600
+ )
+ )
+
+ speaker2time = sorted(speaker2time.items(), key=lambda x: x[-1], reverse=True)
+ for singer, seconds in speaker2time:
+ print("{}\t{:.2f} mins".format(singer, seconds / 60))
+
+ return speaker2time
+
+
+def get_chosen_speakers():
+ speaker2time = statistics_of_speakers()
+
+ chosen_time = defaultdict(float)
+ chosen_speaker = defaultdict(list)
+ train_constrait = {
+ "male": TRAIN_MALE_MAX_SECONDS,
+ "female": TRAIN_FEMALE_MAX_SECONDS,
+ }
+
+ for speaker, seconds in speaker2time:
+ sex = speaker.split("_")[0]
+ if chosen_time[sex] < train_constrait[sex]:
+ chosen_time[sex] += seconds
+ chosen_speaker[sex].append(speaker)
+
+ speaker2time = dict(speaker2time)
+ chosen_speaker = chosen_speaker["male"] + chosen_speaker["female"]
+ print("\n#Chosen speakers = {}".format(len(chosen_speaker)))
+ for spk in chosen_speaker:
+ print("{}\t{:.2f} mins".format(spk, speaker2time[spk] / 60))
+
+ return chosen_speaker
+
+
+if __name__ == "__main__":
+ root_path = ""
+ vctk_dir = os.path.join(root_path, "vctk")
+ fewspeaker_dir = os.path.join(root_path, "vctkfewspeaker")
+ os.makedirs(fewspeaker_dir, exist_ok=True)
+
+ train_idxs, test_idxs, raw_train, raw_test = select_sample_idxs()
+ print("#Train = {}, #Test = {}".format(len(train_idxs), len(test_idxs)))
+
+ # There are no data leakage
+ assert len(set(train_idxs).intersection(set(test_idxs))) == 0
+ for idx in train_idxs + test_idxs:
+ # No test data of raw vctk
+ assert "test_" not in idx
+
+ for split, chosen_idxs in zip(["train", "test"], [train_idxs, test_idxs]):
+ print("{}: #chosen idx = {}\n".format(split, len(chosen_idxs)))
+
+ # Select features
+ feat_files = glob.glob("**/train.pkl", root_dir=vctk_dir, recursive=True)
+ for file in tqdm(feat_files):
+ raw_file = os.path.join(vctk_dir, file)
+ new_file = os.path.join(
+ fewspeaker_dir, file.replace("train.pkl", "{}.pkl".format(split))
+ )
+
+ new_dir = "/".join(new_file.split("/")[:-1])
+ os.makedirs(new_dir, exist_ok=True)
+
+ if "mel_min" in file or "mel_max" in file:
+ os.system("cp {} {}".format(raw_file, new_file))
+ continue
+
+ with open(raw_file, "rb") as f:
+ raw_feats = pickle.load(f)
+
+ print("file: {}, #raw_feats = {}".format(file, len(raw_feats)))
+ new_feats = []
+ for idx in chosen_idxs:
+ chosen_split_is_train, raw_idx = idx.split("_")
+ assert chosen_split_is_train == "train"
+ new_feats.append(raw_feats[int(raw_idx)])
+
+ with open(new_file, "wb") as f:
+ pickle.dump(new_feats, f)
+ print("New file: {}, #new_feats = {}".format(new_file, len(new_feats)))
+
+ # Utterance re-index
+ news_utts = [raw_train[int(idx.split("_")[-1])] for idx in chosen_idxs]
+ for i, utt in enumerate(news_utts):
+ utt["Dataset"] = "vctkfewsinger"
+ utt["index"] = i
+
+ with open(os.path.join(fewspeaker_dir, "{}.json".format(split)), "w") as f:
+ json.dump(news_utts, f, indent=4)
diff --git a/preprocessors/vctksample.py b/preprocessors/vctksample.py
new file mode 100644
index 0000000000000000000000000000000000000000..476790f3941eb07cfe6ed9051293004b4e588f24
--- /dev/null
+++ b/preprocessors/vctksample.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import pickle
+import glob
+from collections import defaultdict
+from tqdm import tqdm
+from preprocessors import get_golden_samples_indexes
+
+
+TRAIN_MAX_NUM_EVERY_PERSON = 250
+TEST_MAX_NUM_EVERY_PERSON = 25
+
+
+def select_sample_idxs():
+ # =========== Train ===========
+ with open(os.path.join(vctk_dir, "train.json"), "r") as f:
+ raw_train = json.load(f)
+
+ train_idxs = []
+ train_nums = defaultdict(int)
+ for utt in tqdm(raw_train):
+ idx = utt["index"]
+ singer = utt["Singer"]
+
+ if train_nums[singer] < TRAIN_MAX_NUM_EVERY_PERSON:
+ train_idxs.append(idx)
+ train_nums[singer] += 1
+
+ # =========== Test ===========
+ with open(os.path.join(vctk_dir, "test.json"), "r") as f:
+ raw_test = json.load(f)
+
+ # golden test
+ test_idxs = get_golden_samples_indexes(
+ dataset_name="vctk", split="test", dataset_dir=vctk_dir
+ )
+ test_nums = defaultdict(int)
+ for idx in test_idxs:
+ singer = raw_test[idx]["Singer"]
+ test_nums[singer] += 1
+
+ for utt in tqdm(raw_test):
+ idx = utt["index"]
+ singer = utt["Singer"]
+
+ if test_nums[singer] < TEST_MAX_NUM_EVERY_PERSON:
+ test_idxs.append(idx)
+ test_nums[singer] += 1
+
+ train_idxs.sort()
+ test_idxs.sort()
+ return train_idxs, test_idxs, raw_train, raw_test
+
+
+if __name__ == "__main__":
+ root_path = ""
+ vctk_dir = os.path.join(root_path, "vctk")
+ sample_dir = os.path.join(root_path, "vctksample")
+ os.makedirs(sample_dir, exist_ok=True)
+
+ train_idxs, test_idxs, raw_train, raw_test = select_sample_idxs()
+ print("#Train = {}, #Test = {}".format(len(train_idxs), len(test_idxs)))
+
+ for split, chosen_idxs, utterances in zip(
+ ["train", "test"], [train_idxs, test_idxs], [raw_train, raw_test]
+ ):
+ print(
+ "#{} = {}, #chosen idx = {}\n".format(
+ split, len(utterances), len(chosen_idxs)
+ )
+ )
+
+ # Select features
+ feat_files = glob.glob(
+ "**/{}.pkl".format(split), root_dir=vctk_dir, recursive=True
+ )
+ for file in tqdm(feat_files):
+ raw_file = os.path.join(vctk_dir, file)
+ new_file = os.path.join(sample_dir, file)
+
+ new_dir = "/".join(new_file.split("/")[:-1])
+ os.makedirs(new_dir, exist_ok=True)
+
+ if "mel_min" in file or "mel_max" in file:
+ os.system("cp {} {}".format(raw_file, new_file))
+ continue
+
+ with open(raw_file, "rb") as f:
+ raw_feats = pickle.load(f)
+
+ print("file: {}, #raw_feats = {}".format(file, len(raw_feats)))
+ new_feats = [raw_feats[idx] for idx in chosen_idxs]
+ with open(new_file, "wb") as f:
+ pickle.dump(new_feats, f)
+
+ # Utterance re-index
+ news_utts = [utterances[idx] for idx in chosen_idxs]
+ for i, utt in enumerate(news_utts):
+ utt["Dataset"] = "vctksample"
+ utt["index"] = i
+
+ with open(os.path.join(sample_dir, "{}.json".format(split)), "w") as f:
+ json.dump(news_utts, f, indent=4)
diff --git a/preprocessors/vocalist.py b/preprocessors/vocalist.py
new file mode 100644
index 0000000000000000000000000000000000000000..44de1ac00ebcf10d65e6af37a98afcfe6ae85b89
--- /dev/null
+++ b/preprocessors/vocalist.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import torchaudio
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+from utils.util import has_existed
+
+
+def vocalist_statistics(data_dir):
+ singers = []
+ songs = []
+ global2singer2songs = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+
+ global_infos = glob(data_dir + "/*")
+
+ for global_info in global_infos:
+ global_split = global_info.split("/")[-1]
+
+ singer_infos = glob(global_info + "/*")
+
+ for singer_info in singer_infos:
+ singer = singer_info.split("/")[-1]
+
+ singers.append(singer)
+
+ song_infos = glob(singer_info + "/*")
+ for song_info in song_infos:
+ song = song_info.split("/")[-1]
+
+ songs.append(song)
+
+ utts = glob(song_info + "/*.wav")
+
+ for utt in utts:
+ uid = utt.split("/")[-1].split(".")[0]
+ global2singer2songs[global_split][singer][song].append(uid)
+
+ unique_singers = list(set(singers))
+ unique_songs = list(set(songs))
+ unique_singers.sort()
+ unique_songs.sort()
+
+ print(
+ "vocalist: {} singers, {} songs ({} unique songs)".format(
+ len(unique_singers), len(songs), len(unique_songs)
+ )
+ )
+ print("Singers: \n{}".format("\t".join(unique_singers)))
+ return global2singer2songs, unique_singers
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Preparing test samples for vocalist...\n")
+
+ save_dir = os.path.join(output_path, "vocalist")
+ os.makedirs(save_dir, exist_ok=True)
+ train_output_file = os.path.join(save_dir, "train.json")
+ test_output_file = os.path.join(save_dir, "test.json")
+ singer_dict_file = os.path.join(save_dir, "singers.json")
+ utt2singer_file = os.path.join(save_dir, "utt2singer")
+ if (
+ has_existed(train_output_file)
+ and has_existed(test_output_file)
+ and has_existed(singer_dict_file)
+ and has_existed(utt2singer_file)
+ ):
+ return
+ utt2singer = open(utt2singer_file, "w")
+
+ # Load
+ vocalist_path = dataset_path
+
+ global2singer2songs, unique_singers = vocalist_statistics(vocalist_path)
+
+ train = []
+ test = []
+
+ train_index_count = 0
+ test_index_count = 0
+
+ train_total_duration = 0
+ test_total_duration = 0
+
+ for global_info, singer2songs in tqdm(global2singer2songs.items()):
+ for singer, songs in tqdm(singer2songs.items()):
+ song_names = list(songs.keys())
+
+ for chosen_song in song_names:
+ for chosen_uid in songs[chosen_song]:
+ res = {
+ "Dataset": "opensinger",
+ "Singer": singer,
+ "Song": chosen_song,
+ "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
+ }
+ res["Path"] = "{}/{}/{}/{}.wav".format(
+ global_info, singer, chosen_song, chosen_uid
+ )
+ res["Path"] = os.path.join(vocalist_path, res["Path"])
+ assert os.path.exists(res["Path"])
+
+ waveform, sample_rate = torchaudio.load(res["Path"])
+ duration = waveform.size(-1) / sample_rate
+ res["Duration"] = duration
+
+ res["index"] = test_index_count
+ test_total_duration += duration
+ test.append(res)
+ test_index_count += 1
+
+ utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))
+
+ print("#Train = {}, #Test = {}".format(len(train), len(test)))
+ print(
+ "#Train hours= {}, #Test hours= {}".format(
+ train_total_duration / 3600, test_total_duration / 3600
+ )
+ )
+
+ # Save train.json and test.json
+ with open(train_output_file, "w") as f:
+ json.dump(train, f, indent=4, ensure_ascii=False)
+ with open(test_output_file, "w") as f:
+ json.dump(test, f, indent=4, ensure_ascii=False)
+
+ # Save singers.json
+ singer_lut = {name: i for i, name in enumerate(unique_singers)}
+ with open(singer_dict_file, "w") as f:
+ json.dump(singer_lut, f, indent=4, ensure_ascii=False)
diff --git a/pretrained/README.md b/pretrained/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5320f52f4fa6da8d334ac318283dd5d60fcb0839
--- /dev/null
+++ b/pretrained/README.md
@@ -0,0 +1,227 @@
+# Pretrained Models Dependency
+
+The models dependency of Amphion are as follows (sort alphabetically):
+
+- [Pretrained Models Dependency](#pretrained-models-dependency)
+ - [Amphion Singing BigVGAN](#amphion-singing-bigvgan)
+ - [Amphion Speech HiFi-GAN](#amphion-speech-hifi-gan)
+ - [ContentVec](#contentvec)
+ - [WeNet](#wenet)
+ - [Whisper](#whisper)
+ - [RawNet3](#rawnet3)
+
+
+The instructions about how to download them is displayed as follows.
+
+## Amphion Singing BigVGAN
+
+We fine-tune the official BigVGAN pretrained model with over 120 hours singing voice data. The fine-tuned checkpoint can be downloaded [here](https://huggingface.co/amphion/BigVGAN_singing_bigdata). You need to download the `400000.pt` and `args.json` files into `Amphion/pretrained/bigvgan`:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ bivgan
+ ┃ ┃ ┣ 400000.pt
+ ┃ ┃ ┣ args.json
+```
+
+## Amphion Speech HiFi-GAN
+
+We trained our HiFi-GAN pretrained model with 685 hours speech data. Which can be downloaded [here](https://huggingface.co/amphion/hifigan_speech_bigdata). You need to download the whole folder of `hifigan_speech` into `Amphion/pretrained/hifigan`.
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ hifigan
+ ┃ ┃ ┣ hifigan_speech
+ ┃ ┃ ┃ ┣ log
+ ┃ ┃ ┃ ┣ result
+ ┃ ┃ ┃ ┣ checkpoint
+ ┃ ┃ ┃ ┣ args.json
+```
+
+## Amphion DiffWave
+
+We trained our DiffWave pretrained model with 125 hours speech data and around 80 hours of singing voice data. Which can be downloaded [here](https://huggingface.co/amphion/diffwave). You need to download the whole folder of `diffwave` into `Amphion/pretrained/diffwave`.
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ diffwave
+ ┃ ┃ ┣ diffwave_speech
+ ┃ ┃ ┃ ┣ samples
+ ┃ ┃ ┃ ┣ checkpoint
+ ┃ ┃ ┃ ┣ args.json
+```
+
+## ContentVec
+
+You can download the pretrained ContentVec model [here](https://github.com/auspicious3000/contentvec). Note that we use the `ContentVec_legacy-500 classes` checkpoint. Assume that you download the `checkpoint_best_legacy_500.pt` into the `Amphion/pretrained/contentvec`.
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ contentvec
+ ┃ ┃ ┣ checkpoint_best_legacy_500.pt
+```
+
+## WeNet
+
+You can download the pretrained WeNet model [here](https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.md). Take the `wenetspeech` pretrained checkpoint as an example, assume you download the `wenetspeech_u2pp_conformer_exp.tar` into the `Amphion/pretrained/wenet`. Unzip it and modify its configuration file as follows:
+
+```sh
+cd Amphion/pretrained/wenet
+
+### Unzip the expt dir
+tar -xvf wenetspeech_u2pp_conformer_exp.tar.gz
+
+### Specify the updated path in train.yaml
+cd 20220506_u2pp_conformer_exp
+vim train.yaml
+# TODO: Change the value of "cmvn_file" (Line 2) to the absolute path of the `global_cmvn` file. (Eg: [YourPath]/Amphion/pretrained/wenet/20220506_u2pp_conformer_exp/global_cmvn)
+```
+
+The final file struture tree is like:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ wenet
+ ┃ ┃ ┣ 20220506_u2pp_conformer_exp
+ ┃ ┃ ┃ ┣ final.pt
+ ┃ ┃ ┃ ┣ global_cmvn
+ ┃ ┃ ┃ ┣ train.yaml
+ ┃ ┃ ┃ ┣ units.txt
+```
+
+## Whisper
+
+The official pretrained whisper checkpoints can be available [here](https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/__init__.py#L17). In Amphion, we use the `medium` whisper model by default. You can download it as follows:
+
+```bash
+cd Amphion/pretrained
+mkdir whisper
+cd whisper
+
+wget https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt
+```
+
+The final file structure tree is like:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ whisper
+ ┃ ┃ ┣ medium.pt
+```
+
+## RawNet3
+
+The official pretrained RawNet3 checkpoints can be available [here](https://huggingface.co/jungjee/RawNet3). You need to download the `model.pt` file and put it in the folder.
+
+The final file structure tree is like:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ rawnet3
+ ┃ ┃ ┣ model.pt
+```
+
+
+# (Optional) Model Dependencies for Evaluation
+When utilizing Amphion's Evaluation Pipelines, terminals without access to `huggingface.co` may encounter error messages such as "OSError: Can't load tokenizer for ...". To work around this, the dependant models for evaluation can be pre-prepared and stored here, at `Amphion/pretrained`, and follow [this README](../egs/metrics/README.md#troubleshooting) to configure your environment to load local models.
+
+The dependant models of Amphion's evaluation pipeline are as follows (sort alphabetically):
+
+- [Evaluation Pipeline Models Dependency](#optional-model-dependencies-for-evaluation)
+ - [bert-base-uncased](#bert-base-uncased)
+ - [facebook/bart-base](#facebookbart-base)
+ - [roberta-base](#roberta-base)
+ - [wavlm](#wavlm)
+
+The instructions about how to download them is displayed as follows.
+
+## bert-base-uncased
+
+To load `bert-base-uncased` locally, follow [this link](https://huggingface.co/bert-base-uncased) to download all files for `bert-base-uncased` model, and store them under `Amphion/pretrained/bert-base-uncased`, conforming to the following file structure tree:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ bert-base-uncased
+ ┃ ┃ ┣ config.json
+ ┃ ┃ ┣ coreml
+ ┃ ┃ ┃ ┣ fill-mask
+ ┃ ┃ ┃ ┣ float32_model.mlpackage
+ ┃ ┃ ┃ ┣ Data
+ ┃ ┃ ┃ ┣ com.apple.CoreML
+ ┃ ┃ ┃ ┣ model.mlmodel
+ ┃ ┃ ┣ flax_model.msgpack
+ ┃ ┃ ┣ LICENSE
+ ┃ ┃ ┣ model.onnx
+ ┃ ┃ ┣ model.safetensors
+ ┃ ┃ ┣ pytorch_model.bin
+ ┃ ┃ ┣ README.md
+ ┃ ┃ ┣ rust_model.ot
+ ┃ ┃ ┣ tf_model.h5
+ ┃ ┃ ┣ tokenizer_config.json
+ ┃ ┃ ┣ tokenizer.json
+ ┃ ┃ ┣ vocab.txt
+```
+
+## facebook/bart-base
+
+To load `facebook/bart-base` locally, follow [this link](https://huggingface.co/facebook/bart-base) to download all files for `facebook/bart-base` model, and store them under `Amphion/pretrained/facebook/bart-base`, conforming to the following file structure tree:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ facebook
+ ┃ ┃ ┣ bart-base
+ ┃ ┃ ┃ ┣ config.json
+ ┃ ┃ ┃ ┣ flax_model.msgpack
+ ┃ ┃ ┃ ┣ gitattributes.txt
+ ┃ ┃ ┃ ┣ merges.txt
+ ┃ ┃ ┃ ┣ model.safetensors
+ ┃ ┃ ┃ ┣ pytorch_model.bin
+ ┃ ┃ ┃ ┣ README.txt
+ ┃ ┃ ┃ ┣ rust_model.ot
+ ┃ ┃ ┃ ┣ tf_model.h5
+ ┃ ┃ ┃ ┣ tokenizer.json
+ ┃ ┃ ┃ ┣ vocab.json
+```
+
+## roberta-base
+
+To load `roberta-base` locally, follow [this link](https://huggingface.co/roberta-base) to download all files for `roberta-base` model, and store them under `Amphion/pretrained/roberta-base`, conforming to the following file structure tree:
+
+```
+Amphion
+ ┣ pretrained
+ ┃ ┣ roberta-base
+ ┃ ┃ ┣ config.json
+ ┃ ┃ ┣ dict.txt
+ ┃ ┃ ┣ flax_model.msgpack
+ ┃ ┃ ┣ gitattributes.txt
+ ┃ ┃ ┣ merges.txt
+ ┃ ┃ ┣ model.safetensors
+ ┃ ┃ ┣ pytorch_model.bin
+ ┃ ┃ ┣ README.txt
+ ┃ ┃ ┣ rust_model.ot
+ ┃ ┃ ┣ tf_model.h5
+ ┃ ┃ ┣ tokenizer.json
+ ┃ ┃ ┣ vocab.json
+```
+
+## wavlm
+
+The official pretrained wavlm checkpoints can be available [here](https://huggingface.co/microsoft/wavlm-base-plus-sv). The file structure tree is as follows:
+
+```
+Amphion
+ ┣ wavlm
+ ┃ ┣ config.json
+ ┃ ┣ preprocessor_config.json
+ ┃ ┣ pytorch_model.bin
+```
\ No newline at end of file
diff --git a/pretrained/bert-base-uncased/README.md b/pretrained/bert-base-uncased/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..08ee7e591c070d21e72235b7380ef63107bd5a27
--- /dev/null
+++ b/pretrained/bert-base-uncased/README.md
@@ -0,0 +1,8 @@
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/bert-base-uncased)
+
+# Download
+
+- [Link](https://huggingface.co/bert-base-uncased)
+- Model: `bert-base-uncased`
+- Download the latest files under `Files and versions` tab.
+- Overwrite this file if necessary.
diff --git a/pretrained/bigvgan/README.md b/pretrained/bigvgan/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d2a1fbce49953c03508c0a9353cf914e8a66d5c
--- /dev/null
+++ b/pretrained/bigvgan/README.md
@@ -0,0 +1,7 @@
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/BigVGAN_singing_bigdata)
+
+# Download
+
+- [Link](https://huggingface.co/amphion/BigVGAN_singing_bigdata)
+- Model: `bigvgan_singing`
+- Datasets: VCTK + LibriTTS + LJSpeech
diff --git a/pretrained/contentvec/README.md b/pretrained/contentvec/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ea10938244c7282355be035c9489efa5bf08bdd
--- /dev/null
+++ b/pretrained/contentvec/README.md
@@ -0,0 +1,5 @@
+# Download
+
+- [Link](https://github.com/auspicious3000/contentvec)
+- Model: `ContentVec_legacy`
+- Classes: 500
diff --git a/pretrained/diffwave/README.md b/pretrained/diffwave/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ec39d5d1d0c313d887dc6972ed485224ae485c8
--- /dev/null
+++ b/pretrained/diffwave/README.md
@@ -0,0 +1,7 @@
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/diffwave)
+
+# Download
+
+- [Link](https://huggingface.co/amphion/diffwave)
+- Model: `diffwave`
+- Datasets: VCTK + LJSpeech + PJS + Opencpop + CSD + M4Singer + OpenSinger + PopCS
diff --git a/pretrained/facebook/bart-base/README.md b/pretrained/facebook/bart-base/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1a8dad4a599cf29bff049e3b06236f3ce06f14a5
--- /dev/null
+++ b/pretrained/facebook/bart-base/README.md
@@ -0,0 +1,8 @@
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/facebook/bart-base)
+
+# Download
+
+- [Link](https://huggingface.co/facebook/bart-base)
+- Model: `facebook/bart-base`
+- Download the latest files under `Files and versions` tab.
+- Overwrite this file if necessary.
diff --git a/pretrained/hifigan/README.md b/pretrained/hifigan/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f26ddfa785ae889b18f67ce59dd4220bc749d91e
--- /dev/null
+++ b/pretrained/hifigan/README.md
@@ -0,0 +1,7 @@
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/hifigan_speech_bigdata)
+
+# Download
+
+- [Link](https://huggingface.co/amphion/hifigan_speech_bigdata)
+- Model: `hifigan_speech`
+- Datasets: VCTK + LibriTTS + LJSpeech
diff --git a/pretrained/rawnet3/README.md b/pretrained/rawnet3/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8f4d94a9adb6e83e50f30a9d79c44f4ec2f56d42
--- /dev/null
+++ b/pretrained/rawnet3/README.md
@@ -0,0 +1,4 @@
+# Download
+
+- [Link](https://huggingface.co/jungjee/RawNet3)
+- Pretrained Datasets: `VoxCeleb1`, `VoxCeleb2`
diff --git a/pretrained/roberta-base/README.md b/pretrained/roberta-base/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..630b8908efc48a73f31117eb2581e22cb286084d
--- /dev/null
+++ b/pretrained/roberta-base/README.md
@@ -0,0 +1,8 @@
+[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/roberta-base)
+
+# Download
+
+- [Link](https://huggingface.co/roberta-base)
+- Model: `roberta-base`
+- Download the latest files under `Files and versions` tab.
+- Overwrite this file if necessary.
diff --git a/pretrained/wavlm/README.md b/pretrained/wavlm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f08a31080cdb4c6138480bb102525760e5542953
--- /dev/null
+++ b/pretrained/wavlm/README.md
@@ -0,0 +1,4 @@
+# Download
+
+- [Link](https://huggingface.co/microsoft/wavlm-base-plus-sv)
+- Pretrained Model: `wavlm-base-plus-sv`
diff --git a/pretrained/wenet/README.md b/pretrained/wenet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..34c91e30f4b3a3fae0fc89fbc92f59e3824260c1
--- /dev/null
+++ b/pretrained/wenet/README.md
@@ -0,0 +1,4 @@
+# Download
+
+- [Link](https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.md)
+- Pretrained Datasets: `wenetspeech`
diff --git a/processors/__init__.py b/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/processors/acoustic_extractor.py b/processors/acoustic_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4d9be76caa16d0c1ae6de6338974ce9f47da5c
--- /dev/null
+++ b/processors/acoustic_extractor.py
@@ -0,0 +1,1042 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+import json
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+from utils.io import save_feature, save_txt, save_torch_audio
+from utils.util import has_existed
+from utils.tokenizer import extract_encodec_token
+from utils.stft import TacotronSTFT
+from utils.dsp import compress, audio_to_label
+from utils.data_utils import remove_outlier
+from preprocessors.metadata import replace_augment_name
+from scipy.interpolate import interp1d
+from utils.mel import (
+ extract_mel_features,
+ extract_linear_features,
+ extract_mel_features_tts,
+)
+
+ZERO = 1e-12
+
+
+def extract_utt_acoustic_features_parallel(metadata, dataset_output, cfg, n_workers=1):
+ """Extract acoustic features from utterances using muliprocess
+
+ Args:
+ metadata (dict): dictionary that stores data in train.json and test.json files
+ dataset_output (str): directory to store acoustic features
+ cfg (dict): dictionary that stores configurations
+ n_workers (int, optional): num of processes to extract features in parallel. Defaults to 1.
+
+ Returns:
+ list: acoustic features
+ """
+ for utt in tqdm(metadata):
+ if cfg.task_type == "tts":
+ extract_utt_acoustic_features_tts(dataset_output, cfg, utt)
+ if cfg.task_type == "svc":
+ extract_utt_acoustic_features_svc(dataset_output, cfg, utt)
+ if cfg.task_type == "vocoder":
+ extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt)
+ if cfg.task_type == "tta":
+ extract_utt_acoustic_features_tta(dataset_output, cfg, utt)
+
+
+def avg_phone_feature(feature, duration, interpolation=False):
+ feature = feature[: sum(duration)]
+ if interpolation:
+ nonzero_ids = np.where(feature != 0)[0]
+ interp_fn = interp1d(
+ nonzero_ids,
+ feature[nonzero_ids],
+ fill_value=(feature[nonzero_ids[0]], feature[nonzero_ids[-1]]),
+ bounds_error=False,
+ )
+ feature = interp_fn(np.arange(0, len(feature)))
+
+ # Phoneme-level average
+ pos = 0
+ for i, d in enumerate(duration):
+ if d > 0:
+ feature[i] = np.mean(feature[pos : pos + d])
+ else:
+ feature[i] = 0
+ pos += d
+ feature = feature[: len(duration)]
+ return feature
+
+
+def extract_utt_acoustic_features_serial(metadata, dataset_output, cfg):
+ """Extract acoustic features from utterances (in single process)
+
+ Args:
+ metadata (dict): dictionary that stores data in train.json and test.json files
+ dataset_output (str): directory to store acoustic features
+ cfg (dict): dictionary that stores configurations
+
+ """
+ for utt in tqdm(metadata):
+ if cfg.task_type == "tts":
+ extract_utt_acoustic_features_tts(dataset_output, cfg, utt)
+ if cfg.task_type == "svc":
+ extract_utt_acoustic_features_svc(dataset_output, cfg, utt)
+ if cfg.task_type == "vocoder":
+ extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt)
+ if cfg.task_type == "tta":
+ extract_utt_acoustic_features_tta(dataset_output, cfg, utt)
+
+
+def __extract_utt_acoustic_features(dataset_output, cfg, utt):
+ """Extract acoustic features from utterances (in single process)
+
+ Args:
+ dataset_output (str): directory to store acoustic features
+ cfg (dict): dictionary that stores configurations
+ utt (dict): utterance info including dataset, singer, uid:{singer}_{song}_{index},
+ path to utternace, duration, utternace index
+
+ """
+ from utils import audio, f0, world, duration
+
+ uid = utt["Uid"]
+ wav_path = utt["Path"]
+ if os.path.exists(os.path.join(dataset_output, cfg.preprocess.raw_data)):
+ wav_path = os.path.join(
+ dataset_output, cfg.preprocess.raw_data, utt["Singer"], uid + ".wav"
+ )
+
+ with torch.no_grad():
+ # Load audio data into tensor with sample rate of the config file
+ wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate)
+ wav = wav_torch.cpu().numpy()
+
+ # extract features
+ if cfg.preprocess.extract_duration:
+ durations, phones, start, end = duration.get_duration(
+ utt, wav, cfg.preprocess
+ )
+ save_feature(dataset_output, cfg.preprocess.duration_dir, uid, durations)
+ save_txt(dataset_output, cfg.preprocess.lab_dir, uid, phones)
+ wav = wav[start:end].astype(np.float32)
+ wav_torch = torch.from_numpy(wav).to(wav_torch.device)
+
+ if cfg.preprocess.extract_linear_spec:
+ linear = extract_linear_features(wav_torch.unsqueeze(0), cfg.preprocess)
+ save_feature(
+ dataset_output, cfg.preprocess.linear_dir, uid, linear.cpu().numpy()
+ )
+
+ if cfg.preprocess.extract_mel:
+ if cfg.preprocess.mel_extract_mode == "taco":
+ _stft = TacotronSTFT(
+ sampling_rate=cfg.preprocess.sample_rate,
+ win_length=cfg.preprocess.win_size,
+ hop_length=cfg.preprocess.hop_size,
+ filter_length=cfg.preprocess.n_fft,
+ n_mel_channels=cfg.preprocess.n_mel,
+ mel_fmin=cfg.preprocess.fmin,
+ mel_fmax=cfg.preprocess.fmax,
+ )
+ mel = extract_mel_features(
+ wav_torch.unsqueeze(0), cfg.preprocess, taco=True, _stft=_stft
+ )
+ if cfg.preprocess.extract_duration:
+ mel = mel[:, : sum(durations)]
+ else:
+ mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess)
+ save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy())
+
+ if cfg.preprocess.extract_energy:
+ if (
+ cfg.preprocess.energy_extract_mode == "from_mel"
+ and cfg.preprocess.extract_mel
+ ):
+ energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy()
+ elif cfg.preprocess.energy_extract_mode == "from_waveform":
+ energy = audio.energy(wav, cfg.preprocess)
+ elif cfg.preprocess.energy_extract_mode == "from_tacotron_stft":
+ _stft = TacotronSTFT(
+ sampling_rate=cfg.preprocess.sample_rate,
+ win_length=cfg.preprocess.win_size,
+ hop_length=cfg.preprocess.hop_size,
+ filter_length=cfg.preprocess.n_fft,
+ n_mel_channels=cfg.preprocess.n_mel,
+ mel_fmin=cfg.preprocess.fmin,
+ mel_fmax=cfg.preprocess.fmax,
+ )
+ _, energy = audio.get_energy_from_tacotron(wav, _stft)
+ else:
+ assert cfg.preprocess.energy_extract_mode in [
+ "from_mel",
+ "from_waveform",
+ "from_tacotron_stft",
+ ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]"
+ if cfg.preprocess.extract_duration:
+ energy = energy[: sum(durations)]
+ phone_energy = avg_phone_feature(energy, durations)
+ save_feature(
+ dataset_output, cfg.preprocess.phone_energy_dir, uid, phone_energy
+ )
+
+ save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy)
+
+ if cfg.preprocess.extract_pitch:
+ pitch = f0.get_f0(wav, cfg.preprocess)
+ if cfg.preprocess.extract_duration:
+ pitch = pitch[: sum(durations)]
+ phone_pitch = avg_phone_feature(pitch, durations, interpolation=True)
+ save_feature(
+ dataset_output, cfg.preprocess.phone_pitch_dir, uid, phone_pitch
+ )
+ save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch)
+
+ if cfg.preprocess.extract_uv:
+ assert isinstance(pitch, np.ndarray)
+ uv = pitch != 0
+ save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv)
+
+ if cfg.preprocess.extract_audio:
+ save_feature(dataset_output, cfg.preprocess.audio_dir, uid, wav)
+
+ if cfg.preprocess.extract_label:
+ if cfg.preprocess.is_mu_law:
+ # compress audio
+ wav = compress(wav, cfg.preprocess.bits)
+ label = audio_to_label(wav, cfg.preprocess.bits)
+ save_feature(dataset_output, cfg.preprocess.label_dir, uid, label)
+
+ if cfg.preprocess.extract_acoustic_token:
+ if cfg.preprocess.acoustic_token_extractor == "Encodec":
+ codes = extract_encodec_token(wav_path)
+ save_feature(
+ dataset_output, cfg.preprocess.acoustic_token_dir, uid, codes
+ )
+
+
+# TODO: refactor extract_utt_acoustic_features_task function due to many duplicated code
+def extract_utt_acoustic_features_tts(dataset_output, cfg, utt):
+ """Extract acoustic features from utterances (in single process)
+
+ Args:
+ dataset_output (str): directory to store acoustic features
+ cfg (dict): dictionary that stores configurations
+ utt (dict): utterance info including dataset, singer, uid:{singer}_{song}_{index},
+ path to utternace, duration, utternace index
+
+ """
+ from utils import audio, f0, world, duration
+
+ uid = utt["Uid"]
+ wav_path = utt["Path"]
+ if os.path.exists(os.path.join(dataset_output, cfg.preprocess.raw_data)):
+ wav_path = os.path.join(
+ dataset_output, cfg.preprocess.raw_data, utt["Singer"], uid + ".wav"
+ )
+ if not os.path.exists(wav_path):
+ wav_path = os.path.join(
+ dataset_output, cfg.preprocess.raw_data, utt["Singer"], uid + ".flac"
+ )
+
+ assert os.path.exists(wav_path)
+
+ with torch.no_grad():
+ # Load audio data into tensor with sample rate of the config file
+ wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate)
+ wav = wav_torch.cpu().numpy()
+
+ # extract features
+ if cfg.preprocess.extract_duration:
+ durations, phones, start, end = duration.get_duration(
+ utt, wav, cfg.preprocess
+ )
+ save_feature(dataset_output, cfg.preprocess.duration_dir, uid, durations)
+ save_txt(dataset_output, cfg.preprocess.lab_dir, uid, phones)
+ wav = wav[start:end].astype(np.float32)
+ wav_torch = torch.from_numpy(wav).to(wav_torch.device)
+
+ if cfg.preprocess.extract_linear_spec:
+ from utils.mel import extract_linear_features
+
+ linear = extract_linear_features(wav_torch.unsqueeze(0), cfg.preprocess)
+ save_feature(
+ dataset_output, cfg.preprocess.linear_dir, uid, linear.cpu().numpy()
+ )
+
+ if cfg.preprocess.extract_mel:
+ from utils.mel import extract_mel_features
+
+ if cfg.preprocess.mel_extract_mode == "taco":
+ _stft = TacotronSTFT(
+ sampling_rate=cfg.preprocess.sample_rate,
+ win_length=cfg.preprocess.win_size,
+ hop_length=cfg.preprocess.hop_size,
+ filter_length=cfg.preprocess.n_fft,
+ n_mel_channels=cfg.preprocess.n_mel,
+ mel_fmin=cfg.preprocess.fmin,
+ mel_fmax=cfg.preprocess.fmax,
+ )
+ mel = extract_mel_features_tts(
+ wav_torch.unsqueeze(0), cfg.preprocess, taco=True, _stft=_stft
+ )
+ if cfg.preprocess.extract_duration:
+ mel = mel[:, : sum(durations)]
+ else:
+ mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess)
+ save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy())
+
+ if cfg.preprocess.extract_energy:
+ if (
+ cfg.preprocess.energy_extract_mode == "from_mel"
+ and cfg.preprocess.extract_mel
+ ):
+ energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy()
+ elif cfg.preprocess.energy_extract_mode == "from_waveform":
+ energy = audio.energy(wav, cfg.preprocess)
+ elif cfg.preprocess.energy_extract_mode == "from_tacotron_stft":
+ _stft = TacotronSTFT(
+ sampling_rate=cfg.preprocess.sample_rate,
+ win_length=cfg.preprocess.win_size,
+ hop_length=cfg.preprocess.hop_size,
+ filter_length=cfg.preprocess.n_fft,
+ n_mel_channels=cfg.preprocess.n_mel,
+ mel_fmin=cfg.preprocess.fmin,
+ mel_fmax=cfg.preprocess.fmax,
+ )
+ _, energy = audio.get_energy_from_tacotron(wav, _stft)
+ else:
+ assert cfg.preprocess.energy_extract_mode in [
+ "from_mel",
+ "from_waveform",
+ "from_tacotron_stft",
+ ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]"
+ if cfg.preprocess.extract_duration:
+ energy = energy[: sum(durations)]
+ phone_energy = avg_phone_feature(energy, durations)
+ save_feature(
+ dataset_output, cfg.preprocess.phone_energy_dir, uid, phone_energy
+ )
+
+ save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy)
+
+ if cfg.preprocess.extract_pitch:
+ pitch = f0.get_f0(wav, cfg.preprocess)
+ if cfg.preprocess.extract_duration:
+ pitch = pitch[: sum(durations)]
+ phone_pitch = avg_phone_feature(pitch, durations, interpolation=True)
+ save_feature(
+ dataset_output, cfg.preprocess.phone_pitch_dir, uid, phone_pitch
+ )
+ save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch)
+
+ if cfg.preprocess.extract_uv:
+ assert isinstance(pitch, np.ndarray)
+ uv = pitch != 0
+ save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv)
+
+ if cfg.preprocess.extract_audio:
+ save_torch_audio(
+ dataset_output,
+ cfg.preprocess.audio_dir,
+ uid,
+ wav_torch,
+ cfg.preprocess.sample_rate,
+ )
+
+ if cfg.preprocess.extract_label:
+ if cfg.preprocess.is_mu_law:
+ # compress audio
+ wav = compress(wav, cfg.preprocess.bits)
+ label = audio_to_label(wav, cfg.preprocess.bits)
+ save_feature(dataset_output, cfg.preprocess.label_dir, uid, label)
+
+ if cfg.preprocess.extract_acoustic_token:
+ if cfg.preprocess.acoustic_token_extractor == "Encodec":
+ codes = extract_encodec_token(wav_path)
+ save_feature(
+ dataset_output, cfg.preprocess.acoustic_token_dir, uid, codes
+ )
+
+
+def extract_utt_acoustic_features_svc(dataset_output, cfg, utt):
+ __extract_utt_acoustic_features(dataset_output, cfg, utt)
+
+
+def extract_utt_acoustic_features_tta(dataset_output, cfg, utt):
+ __extract_utt_acoustic_features(dataset_output, cfg, utt)
+
+
+def extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt):
+ """Extract acoustic features from utterances (in single process)
+
+ Args:
+ dataset_output (str): directory to store acoustic features
+ cfg (dict): dictionary that stores configurations
+ utt (dict): utterance info including dataset, singer, uid:{singer}_{song}_{index},
+ path to utternace, duration, utternace index
+
+ """
+ from utils import audio, f0, world, duration
+
+ uid = utt["Uid"]
+ wav_path = utt["Path"]
+
+ with torch.no_grad():
+ # Load audio data into tensor with sample rate of the config file
+ wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate)
+ wav = wav_torch.cpu().numpy()
+
+ # extract features
+ if cfg.preprocess.extract_mel:
+ from utils.mel import extract_mel_features
+
+ mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess)
+ save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy())
+
+ if cfg.preprocess.extract_energy:
+ if (
+ cfg.preprocess.energy_extract_mode == "from_mel"
+ and cfg.preprocess.extract_mel
+ ):
+ energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy()
+ elif cfg.preprocess.energy_extract_mode == "from_waveform":
+ energy = audio.energy(wav, cfg.preprocess)
+ else:
+ assert cfg.preprocess.energy_extract_mode in [
+ "from_mel",
+ "from_waveform",
+ ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]"
+
+ save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy)
+
+ if cfg.preprocess.extract_pitch:
+ pitch = f0.get_f0(wav, cfg.preprocess)
+ save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch)
+
+ if cfg.preprocess.extract_uv:
+ assert isinstance(pitch, np.ndarray)
+ uv = pitch != 0
+ save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv)
+
+ if cfg.preprocess.extract_amplitude_phase:
+ from utils.mel import amplitude_phase_spectrum
+
+ log_amplitude, phase, real, imaginary = amplitude_phase_spectrum(
+ wav_torch.unsqueeze(0), cfg.preprocess
+ )
+ save_feature(
+ dataset_output, cfg.preprocess.log_amplitude_dir, uid, log_amplitude
+ )
+ save_feature(dataset_output, cfg.preprocess.phase_dir, uid, phase)
+ save_feature(dataset_output, cfg.preprocess.real_dir, uid, real)
+ save_feature(dataset_output, cfg.preprocess.imaginary_dir, uid, imaginary)
+
+ if cfg.preprocess.extract_audio:
+ save_feature(dataset_output, cfg.preprocess.audio_dir, uid, wav)
+
+ if cfg.preprocess.extract_label:
+ if cfg.preprocess.is_mu_law:
+ # compress audio
+ wav = compress(wav, cfg.preprocess.bits)
+ label = audio_to_label(wav, cfg.preprocess.bits)
+ save_feature(dataset_output, cfg.preprocess.label_dir, uid, label)
+
+
+def cal_normalized_mel(mel, dataset_name, cfg):
+ """
+ mel: (n_mels, T)
+ """
+ # mel_min, mel_max: (n_mels)
+ mel_min, mel_max = load_mel_extrema(cfg, dataset_name)
+ mel_norm = normalize_mel_channel(mel, mel_min, mel_max)
+ return mel_norm
+
+
+def cal_mel_min_max(dataset, output_path, cfg, metadata=None):
+ dataset_output = os.path.join(output_path, dataset)
+
+ if metadata is None:
+ metadata = []
+ for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+ dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type))
+ with open(dataset_file, "r") as f:
+ metadata.extend(json.load(f))
+
+ tmp_mel_min = []
+ tmp_mel_max = []
+ for item in metadata:
+ mel_path = os.path.join(
+ dataset_output, cfg.preprocess.mel_dir, item["Uid"] + ".npy"
+ )
+ if not os.path.exists(mel_path):
+ continue
+ mel = np.load(mel_path)
+ if mel.shape[0] != cfg.preprocess.n_mel:
+ mel = mel.T
+ # mel: (n_mels, T)
+ assert mel.shape[0] == cfg.preprocess.n_mel
+
+ tmp_mel_min.append(np.min(mel, axis=-1))
+ tmp_mel_max.append(np.max(mel, axis=-1))
+
+ mel_min = np.min(tmp_mel_min, axis=0)
+ mel_max = np.max(tmp_mel_max, axis=0)
+
+ ## save mel min max data
+ mel_min_max_dir = os.path.join(dataset_output, cfg.preprocess.mel_min_max_stats_dir)
+ os.makedirs(mel_min_max_dir, exist_ok=True)
+
+ mel_min_path = os.path.join(mel_min_max_dir, "mel_min.npy")
+ mel_max_path = os.path.join(mel_min_max_dir, "mel_max.npy")
+ np.save(mel_min_path, mel_min)
+ np.save(mel_max_path, mel_max)
+
+
+def denorm_for_pred_mels(cfg, dataset_name, split, pred):
+ """
+ Args:
+ pred: a list whose every element is (frame_len, n_mels)
+ Return:
+ similar like pred
+ """
+ mel_min, mel_max = load_mel_extrema(cfg.preprocess, dataset_name)
+ recovered_mels = [
+ denormalize_mel_channel(mel.T, mel_min, mel_max).T for mel in pred
+ ]
+
+ return recovered_mels
+
+
+def load_mel_extrema(cfg, dataset_name):
+ data_dir = os.path.join(cfg.processed_dir, dataset_name, cfg.mel_min_max_stats_dir)
+
+ min_file = os.path.join(data_dir, "mel_min.npy")
+ max_file = os.path.join(data_dir, "mel_max.npy")
+
+ mel_min = np.load(min_file)
+ mel_max = np.load(max_file)
+
+ return mel_min, mel_max
+
+
+def denormalize_mel_channel(mel, mel_min, mel_max):
+ mel_min = np.expand_dims(mel_min, -1)
+ mel_max = np.expand_dims(mel_max, -1)
+ return (mel + 1) / 2 * (mel_max - mel_min + ZERO) + mel_min
+
+
+def normalize_mel_channel(mel, mel_min, mel_max):
+ """
+ mel: (n_mels, T)
+ mel_min, mel_max: (n_mels)
+ """
+ mel_min = np.expand_dims(mel_min, -1)
+ mel_max = np.expand_dims(mel_max, -1)
+ return (mel - mel_min) / (mel_max - mel_min + ZERO) * 2 - 1
+
+
+def normalize(dataset, feat_dir, cfg):
+ dataset_output = os.path.join(cfg.preprocess.processed_dir, dataset)
+ print(f"normalize {feat_dir}")
+
+ max_value = np.finfo(np.float64).min
+ min_value = np.finfo(np.float64).max
+
+ scaler = StandardScaler()
+ feat_files = os.listdir(os.path.join(dataset_output, feat_dir))
+
+ for feat_file in tqdm(feat_files):
+ feat_file = os.path.join(dataset_output, feat_dir, feat_file)
+ if not feat_file.endswith(".npy"):
+ continue
+ feat = np.load(feat_file)
+ max_value = max(max_value, max(feat))
+ min_value = min(min_value, min(feat))
+ scaler.partial_fit(feat.reshape((-1, 1)))
+ mean = scaler.mean_[0]
+ std = scaler.scale_[0]
+ stat = np.array([min_value, max_value, mean, std])
+ stat_npy = os.path.join(dataset_output, f"{feat_dir}_stat.npy")
+ np.save(stat_npy, stat)
+ return mean, std, min_value, max_value
+
+
+def load_normalized(feat_dir, dataset_name, cfg):
+ dataset_output = os.path.join(cfg.preprocess.processed_dir, dataset_name)
+ stat_npy = os.path.join(dataset_output, f"{feat_dir}_stat.npy")
+ min_value, max_value, mean, std = np.load(stat_npy)
+ return mean, std, min_value, max_value
+
+
+def cal_pitch_statistics_svc(dataset, output_path, cfg, metadata=None):
+ # path of dataset
+ dataset_dir = os.path.join(output_path, dataset)
+ save_dir = os.path.join(dataset_dir, cfg.preprocess.pitch_dir)
+ os.makedirs(save_dir, exist_ok=True)
+ if has_existed(os.path.join(save_dir, "statistics.json")):
+ return
+
+ if metadata is None:
+ # load singers and ids
+ singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r"))
+
+ # combine train and test metadata
+ metadata = []
+ for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+ dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+ with open(dataset_file, "r") as f:
+ metadata.extend(json.load(f))
+ else:
+ singers = list(set([item["Singer"] for item in metadata]))
+ singers = {
+ "{}_{}".format(dataset, name): idx for idx, name in enumerate(singers)
+ }
+
+ # use different scalers for each singer
+ pitch_scalers = [[] for _ in range(len(singers))]
+ total_pitch_scalers = [[] for _ in range(len(singers))]
+
+ for utt_info in tqdm(metadata, desc="Loading F0..."):
+ # utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}'
+ singer = utt_info["Singer"]
+ pitch_path = os.path.join(
+ dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy"
+ )
+ # total_pitch contains all pitch including unvoiced frames
+ if not os.path.exists(pitch_path):
+ continue
+ total_pitch = np.load(pitch_path)
+ assert len(total_pitch) > 0
+ # pitch contains only voiced frames
+ pitch = total_pitch[total_pitch != 0]
+ spkid = singers[f"{replace_augment_name(dataset)}_{singer}"]
+
+ # update pitch scalers
+ pitch_scalers[spkid].extend(pitch.tolist())
+ # update total pitch scalers
+ total_pitch_scalers[spkid].extend(total_pitch.tolist())
+
+ # save pitch statistics for each singer in dict
+ sta_dict = {}
+ for singer in tqdm(singers, desc="Singers statistics"):
+ spkid = singers[singer]
+ # voiced pitch statistics
+ mean, std, min, max, median = (
+ np.mean(pitch_scalers[spkid]),
+ np.std(pitch_scalers[spkid]),
+ np.min(pitch_scalers[spkid]),
+ np.max(pitch_scalers[spkid]),
+ np.median(pitch_scalers[spkid]),
+ )
+
+ # total pitch statistics
+ mean_t, std_t, min_t, max_t, median_t = (
+ np.mean(total_pitch_scalers[spkid]),
+ np.std(total_pitch_scalers[spkid]),
+ np.min(total_pitch_scalers[spkid]),
+ np.max(total_pitch_scalers[spkid]),
+ np.median(total_pitch_scalers[spkid]),
+ )
+ sta_dict[singer] = {
+ "voiced_positions": {
+ "mean": mean,
+ "std": std,
+ "median": median,
+ "min": min,
+ "max": max,
+ },
+ "total_positions": {
+ "mean": mean_t,
+ "std": std_t,
+ "median": median_t,
+ "min": min_t,
+ "max": max_t,
+ },
+ }
+
+ # save statistics
+ with open(os.path.join(save_dir, "statistics.json"), "w") as f:
+ json.dump(sta_dict, f, indent=4, ensure_ascii=False)
+
+
+def cal_pitch_statistics(dataset, output_path, cfg):
+ # path of dataset
+ dataset_dir = os.path.join(output_path, dataset)
+ if cfg.preprocess.use_phone_pitch:
+ pitch_dir = cfg.preprocess.phone_pitch_dir
+ else:
+ pitch_dir = cfg.preprocess.pitch_dir
+ save_dir = os.path.join(dataset_dir, pitch_dir)
+
+ os.makedirs(save_dir, exist_ok=True)
+ if has_existed(os.path.join(save_dir, "statistics.json")):
+ return
+ # load singers and ids
+ singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r"))
+
+ # combine train and test metadata
+ metadata = []
+ for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+ dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+ with open(dataset_file, "r") as f:
+ metadata.extend(json.load(f))
+
+ # use different scalers for each singer
+ pitch_scalers = [[] for _ in range(len(singers))]
+ total_pitch_scalers = [[] for _ in range(len(singers))]
+
+ for utt_info in metadata:
+ utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}'
+ singer = utt_info["Singer"]
+ pitch_path = os.path.join(dataset_dir, pitch_dir, utt_info["Uid"] + ".npy")
+ # total_pitch contains all pitch including unvoiced frames
+ if not os.path.exists(pitch_path):
+ continue
+ total_pitch = np.load(pitch_path)
+ assert len(total_pitch) > 0
+ # pitch contains only voiced frames
+ # pitch = total_pitch[total_pitch != 0]
+ if cfg.preprocess.pitch_remove_outlier:
+ pitch = remove_outlier(total_pitch)
+ spkid = singers[f"{replace_augment_name(dataset)}_{singer}"]
+
+ # update pitch scalers
+ pitch_scalers[spkid].extend(pitch.tolist())
+ # update total pitch scalers
+ total_pitch_scalers[spkid].extend(total_pitch.tolist())
+
+ # save pitch statistics for each singer in dict
+ sta_dict = {}
+ for singer in singers:
+ spkid = singers[singer]
+ # voiced pitch statistics
+ mean, std, min, max, median = (
+ np.mean(pitch_scalers[spkid]),
+ np.std(pitch_scalers[spkid]),
+ np.min(pitch_scalers[spkid]),
+ np.max(pitch_scalers[spkid]),
+ np.median(pitch_scalers[spkid]),
+ )
+
+ # total pitch statistics
+ mean_t, std_t, min_t, max_t, median_t = (
+ np.mean(total_pitch_scalers[spkid]),
+ np.std(total_pitch_scalers[spkid]),
+ np.min(total_pitch_scalers[spkid]),
+ np.max(total_pitch_scalers[spkid]),
+ np.median(total_pitch_scalers[spkid]),
+ )
+ sta_dict[singer] = {
+ "voiced_positions": {
+ "mean": mean,
+ "std": std,
+ "median": median,
+ "min": min,
+ "max": max,
+ },
+ "total_positions": {
+ "mean": mean_t,
+ "std": std_t,
+ "median": median_t,
+ "min": min_t,
+ "max": max_t,
+ },
+ }
+
+ # save statistics
+ with open(os.path.join(save_dir, "statistics.json"), "w") as f:
+ json.dump(sta_dict, f, indent=4, ensure_ascii=False)
+
+
+def cal_energy_statistics(dataset, output_path, cfg):
+ # path of dataset
+ dataset_dir = os.path.join(output_path, dataset)
+ if cfg.preprocess.use_phone_energy:
+ energy_dir = cfg.preprocess.phone_energy_dir
+ else:
+ energy_dir = cfg.preprocess.energy_dir
+ save_dir = os.path.join(dataset_dir, energy_dir)
+ os.makedirs(save_dir, exist_ok=True)
+ print(os.path.join(save_dir, "statistics.json"))
+ if has_existed(os.path.join(save_dir, "statistics.json")):
+ return
+ # load singers and ids
+ singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r"))
+
+ # combine train and test metadata
+ metadata = []
+ for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+ dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+ with open(dataset_file, "r") as f:
+ metadata.extend(json.load(f))
+
+ # use different scalers for each singer
+ energy_scalers = [[] for _ in range(len(singers))]
+ total_energy_scalers = [[] for _ in range(len(singers))]
+
+ for utt_info in metadata:
+ utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}'
+ singer = utt_info["Singer"]
+ energy_path = os.path.join(dataset_dir, energy_dir, utt_info["Uid"] + ".npy")
+ # total_energy contains all energy including unvoiced frames
+ if not os.path.exists(energy_path):
+ continue
+ total_energy = np.load(energy_path)
+ assert len(total_energy) > 0
+ # energy contains only voiced frames
+ # energy = total_energy[total_energy != 0]
+ if cfg.preprocess.energy_remove_outlier:
+ energy = remove_outlier(total_energy)
+ spkid = singers[f"{replace_augment_name(dataset)}_{singer}"]
+
+ # update energy scalers
+ energy_scalers[spkid].extend(energy.tolist())
+ # update total energyscalers
+ total_energy_scalers[spkid].extend(total_energy.tolist())
+
+ # save energy statistics for each singer in dict
+ sta_dict = {}
+ for singer in singers:
+ spkid = singers[singer]
+ # voiced energy statistics
+ mean, std, min, max, median = (
+ np.mean(energy_scalers[spkid]),
+ np.std(energy_scalers[spkid]),
+ np.min(energy_scalers[spkid]),
+ np.max(energy_scalers[spkid]),
+ np.median(energy_scalers[spkid]),
+ )
+
+ # total energy statistics
+ mean_t, std_t, min_t, max_t, median_t = (
+ np.mean(total_energy_scalers[spkid]),
+ np.std(total_energy_scalers[spkid]),
+ np.min(total_energy_scalers[spkid]),
+ np.max(total_energy_scalers[spkid]),
+ np.median(total_energy_scalers[spkid]),
+ )
+ sta_dict[singer] = {
+ "voiced_positions": {
+ "mean": mean,
+ "std": std,
+ "median": median,
+ "min": min,
+ "max": max,
+ },
+ "total_positions": {
+ "mean": mean_t,
+ "std": std_t,
+ "median": median_t,
+ "min": min_t,
+ "max": max_t,
+ },
+ }
+
+ # save statistics
+ with open(os.path.join(save_dir, "statistics.json"), "w") as f:
+ json.dump(sta_dict, f, indent=4, ensure_ascii=False)
+
+
+def copy_acoustic_features(metadata, dataset_dir, src_dataset_dir, cfg):
+ """Copy acoustic features from src_dataset_dir to dataset_dir
+
+ Args:
+ metadata (dict): dictionary that stores data in train.json and test.json files
+ dataset_dir (str): directory to store acoustic features
+ src_dataset_dir (str): directory to store acoustic features
+ cfg (dict): dictionary that stores configurations
+
+ """
+
+ if cfg.preprocess.extract_mel:
+ if not has_existed(os.path.join(dataset_dir, cfg.preprocess.mel_dir)):
+ os.makedirs(
+ os.path.join(dataset_dir, cfg.preprocess.mel_dir), exist_ok=True
+ )
+ print(
+ "Copying mel features from {} to {}...".format(
+ src_dataset_dir, dataset_dir
+ )
+ )
+ for utt_info in tqdm(metadata):
+ src_mel_path = os.path.join(
+ src_dataset_dir, cfg.preprocess.mel_dir, utt_info["Uid"] + ".npy"
+ )
+ dst_mel_path = os.path.join(
+ dataset_dir, cfg.preprocess.mel_dir, utt_info["Uid"] + ".npy"
+ )
+ # create soft-links
+ if not os.path.exists(dst_mel_path):
+ os.symlink(src_mel_path, dst_mel_path)
+ if cfg.preprocess.extract_energy:
+ if not has_existed(os.path.join(dataset_dir, cfg.preprocess.energy_dir)):
+ os.makedirs(
+ os.path.join(dataset_dir, cfg.preprocess.energy_dir), exist_ok=True
+ )
+ print(
+ "Copying energy features from {} to {}...".format(
+ src_dataset_dir, dataset_dir
+ )
+ )
+ for utt_info in tqdm(metadata):
+ src_energy_path = os.path.join(
+ src_dataset_dir, cfg.preprocess.energy_dir, utt_info["Uid"] + ".npy"
+ )
+ dst_energy_path = os.path.join(
+ dataset_dir, cfg.preprocess.energy_dir, utt_info["Uid"] + ".npy"
+ )
+ # create soft-links
+ if not os.path.exists(dst_energy_path):
+ os.symlink(src_energy_path, dst_energy_path)
+ if cfg.preprocess.extract_pitch:
+ if not has_existed(os.path.join(dataset_dir, cfg.preprocess.pitch_dir)):
+ os.makedirs(
+ os.path.join(dataset_dir, cfg.preprocess.pitch_dir), exist_ok=True
+ )
+ print(
+ "Copying pitch features from {} to {}...".format(
+ src_dataset_dir, dataset_dir
+ )
+ )
+ for utt_info in tqdm(metadata):
+ src_pitch_path = os.path.join(
+ src_dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy"
+ )
+ dst_pitch_path = os.path.join(
+ dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy"
+ )
+ # create soft-links
+ if not os.path.exists(dst_pitch_path):
+ os.symlink(src_pitch_path, dst_pitch_path)
+ if cfg.preprocess.extract_uv:
+ if not has_existed(os.path.join(dataset_dir, cfg.preprocess.uv_dir)):
+ os.makedirs(
+ os.path.join(dataset_dir, cfg.preprocess.uv_dir), exist_ok=True
+ )
+ print(
+ "Copying uv features from {} to {}...".format(
+ src_dataset_dir, dataset_dir
+ )
+ )
+ for utt_info in tqdm(metadata):
+ src_uv_path = os.path.join(
+ src_dataset_dir, cfg.preprocess.uv_dir, utt_info["Uid"] + ".npy"
+ )
+ dst_uv_path = os.path.join(
+ dataset_dir, cfg.preprocess.uv_dir, utt_info["Uid"] + ".npy"
+ )
+ # create soft-links
+ if not os.path.exists(dst_uv_path):
+ os.symlink(src_uv_path, dst_uv_path)
+ if cfg.preprocess.extract_audio:
+ if not has_existed(os.path.join(dataset_dir, cfg.preprocess.audio_dir)):
+ os.makedirs(
+ os.path.join(dataset_dir, cfg.preprocess.audio_dir), exist_ok=True
+ )
+ print(
+ "Copying audio features from {} to {}...".format(
+ src_dataset_dir, dataset_dir
+ )
+ )
+ for utt_info in tqdm(metadata):
+ if cfg.task_type == "tts":
+ src_audio_path = os.path.join(
+ src_dataset_dir,
+ cfg.preprocess.audio_dir,
+ utt_info["Uid"] + ".wav",
+ )
+ else:
+ src_audio_path = os.path.join(
+ src_dataset_dir,
+ cfg.preprocess.audio_dir,
+ utt_info["Uid"] + ".npy",
+ )
+ if cfg.task_type == "tts":
+ dst_audio_path = os.path.join(
+ dataset_dir, cfg.preprocess.audio_dir, utt_info["Uid"] + ".wav"
+ )
+ else:
+ dst_audio_path = os.path.join(
+ dataset_dir, cfg.preprocess.audio_dir, utt_info["Uid"] + ".npy"
+ )
+ # create soft-links
+ if not os.path.exists(dst_audio_path):
+ os.symlink(src_audio_path, dst_audio_path)
+ if cfg.preprocess.extract_label:
+ if not has_existed(os.path.join(dataset_dir, cfg.preprocess.label_dir)):
+ os.makedirs(
+ os.path.join(dataset_dir, cfg.preprocess.label_dir), exist_ok=True
+ )
+ print(
+ "Copying label features from {} to {}...".format(
+ src_dataset_dir, dataset_dir
+ )
+ )
+ for utt_info in tqdm(metadata):
+ src_label_path = os.path.join(
+ src_dataset_dir, cfg.preprocess.label_dir, utt_info["Uid"] + ".npy"
+ )
+ dst_label_path = os.path.join(
+ dataset_dir, cfg.preprocess.label_dir, utt_info["Uid"] + ".npy"
+ )
+ # create soft-links
+ if not os.path.exists(dst_label_path):
+ os.symlink(src_label_path, dst_label_path)
+
+
+def align_duration_mel(dataset, output_path, cfg):
+ print("align the duration and mel")
+
+ dataset_dir = os.path.join(output_path, dataset)
+ metadata = []
+ for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+ dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+ with open(dataset_file, "r") as f:
+ metadata.extend(json.load(f))
+
+ utt2dur = {}
+ for index in tqdm(range(len(metadata))):
+ utt_info = metadata[index]
+ dataset = utt_info["Dataset"]
+ uid = utt_info["Uid"]
+ utt = "{}_{}".format(dataset, uid)
+
+ mel_path = os.path.join(dataset_dir, cfg.preprocess.mel_dir, uid + ".npy")
+ mel = np.load(mel_path).transpose(1, 0)
+ duration_path = os.path.join(
+ dataset_dir, cfg.preprocess.duration_dir, uid + ".npy"
+ )
+ duration = np.load(duration_path)
+ if sum(duration) != mel.shape[0]:
+ duration_sum = sum(duration)
+ mel_len = mel.shape[0]
+ mismatch = abs(duration_sum - mel_len)
+ assert mismatch <= 5, "duration and mel length mismatch!"
+ cloned = np.array(duration, copy=True)
+ if duration_sum > mel_len:
+ for j in range(1, len(duration) - 1):
+ if mismatch == 0:
+ break
+ dur_val = cloned[-j]
+ if dur_val >= mismatch:
+ cloned[-j] -= mismatch
+ mismatch -= dur_val
+ break
+ else:
+ cloned[-j] = 0
+ mismatch -= dur_val
+
+ elif duration_sum < mel_len:
+ cloned[-1] += mismatch
+ duration = cloned
+ utt2dur[utt] = duration
+ np.save(duration_path, duration)
+
+ return utt2dur
diff --git a/processors/audio_features_extractor.py b/processors/audio_features_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e38bd5eb69af5e439e7493a9c8b8121dcac2ae6
--- /dev/null
+++ b/processors/audio_features_extractor.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2023 Amphion.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+
+This module aims to be an entrance that integrates all the functions for extracting features from raw audio.
+
+The common audio features include:
+1. Acoustic features such as Mel Spectrogram, F0, Energy, etc.
+2. Content features such as phonetic posteriorgrams (PPG) and bottleneck features (BNF) from pretrained models
+
+Note:
+All the features extraction are designed to utilize GPU to the maximum extent, which can ease the on-the-fly extraction for large-scale dataset.
+
+"""
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from utils.mel import extract_mel_features
+from utils.f0 import get_f0 as extract_f0_features
+from processors.content_extractor import (
+ WhisperExtractor,
+ ContentvecExtractor,
+ WenetExtractor,
+)
+
+
+class AudioFeaturesExtractor:
+ def __init__(self, cfg):
+ """
+ Args:
+ cfg: Amphion config that would be used to specify the processing parameters
+ """
+ self.cfg = cfg
+
+ def get_mel_spectrogram(self, wavs):
+ """Get Mel Spectrogram Features
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+
+ Returns:
+ Tensor whose shape is (B, n_mels, n_frames)
+ """
+ return extract_mel_features(y=wavs, cfg=self.cfg.preprocess)
+
+ def get_f0(self, wavs, wav_lens=None, use_interpolate=False, return_uv=False):
+ """Get F0 Features
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+
+ Returns:
+ Tensor whose shape is (B, n_frames)
+ """
+ device = wavs.device
+
+ f0s = []
+ uvs = []
+ for i, w in enumerate(wavs):
+ if wav_lens is not None:
+ w = w[: wav_lens[i]]
+
+ f0, uv = extract_f0_features(
+ # Use numpy to extract
+ w.cpu().numpy(),
+ self.cfg.preprocess,
+ use_interpolate=use_interpolate,
+ return_uv=True,
+ )
+ f0s.append(torch.as_tensor(f0, device=device))
+ uvs.append(torch.as_tensor(uv, device=device, dtype=torch.long))
+
+ # (B, n_frames)
+ f0s = pad_sequence(f0s, batch_first=True, padding_value=0)
+ uvs = pad_sequence(uvs, batch_first=True, padding_value=0)
+
+ if return_uv:
+ return f0s, uvs
+
+ return f0s
+
+ def get_energy(self, wavs, mel_spec=None):
+ """Get Energy Features
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+ mel_spec: Tensor whose shape is (B, n_mels, n_frames)
+
+ Returns:
+ Tensor whose shape is (B, n_frames)
+ """
+ if mel_spec is None:
+ mel_spec = self.get_mel_spectrogram(wavs)
+
+ energies = (mel_spec.exp() ** 2).sum(dim=1).sqrt()
+ return energies
+
+ def get_whisper_features(self, wavs, target_frame_len):
+ """Get Whisper Features
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+ target_frame_len: int
+
+ Returns:
+ Tensor whose shape is (B, target_frame_len, D)
+ """
+ if not hasattr(self, "whisper_extractor"):
+ self.whisper_extractor = WhisperExtractor(self.cfg)
+ self.whisper_extractor.load_model()
+
+ whisper_feats = self.whisper_extractor.extract_content_features(wavs)
+ whisper_feats = self.whisper_extractor.ReTrans(whisper_feats, target_frame_len)
+ return whisper_feats
+
+ def get_contentvec_features(self, wavs, target_frame_len):
+ """Get ContentVec Features
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+ target_frame_len: int
+
+ Returns:
+ Tensor whose shape is (B, target_frame_len, D)
+ """
+ if not hasattr(self, "contentvec_extractor"):
+ self.contentvec_extractor = ContentvecExtractor(self.cfg)
+ self.contentvec_extractor.load_model()
+
+ contentvec_feats = self.contentvec_extractor.extract_content_features(wavs)
+ contentvec_feats = self.contentvec_extractor.ReTrans(
+ contentvec_feats, target_frame_len
+ )
+ return contentvec_feats
+
+ def get_wenet_features(self, wavs, target_frame_len, wav_lens=None):
+ """Get WeNet Features
+
+ Args:
+ wavs: Tensor whose shape is (B, T)
+ target_frame_len: int
+ wav_lens: Tensor whose shape is (B)
+
+ Returns:
+ Tensor whose shape is (B, target_frame_len, D)
+ """
+ if not hasattr(self, "wenet_extractor"):
+ self.wenet_extractor = WenetExtractor(self.cfg)
+ self.wenet_extractor.load_model()
+
+ wenet_feats = self.wenet_extractor.extract_content_features(wavs, lens=wav_lens)
+ wenet_feats = self.wenet_extractor.ReTrans(wenet_feats, target_frame_len)
+ return wenet_feats
diff --git a/processors/content_extractor.py b/processors/content_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..34b54917a8d672a91f25fc6a54453e5a15d1296d
--- /dev/null
+++ b/processors/content_extractor.py
@@ -0,0 +1,626 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+import yaml
+import copy
+from tqdm import tqdm
+from torchaudio.compliance import kaldi
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+from fairseq import checkpoint_utils
+from transformers import AutoModel, Wav2Vec2FeatureExtractor
+
+from utils.io_optim import (
+ TorchaudioDataset,
+ LibrosaDataset,
+ FFmpegDataset,
+ collate_batch,
+)
+import whisper
+from modules.wenet_extractor.utils.init_model import init_model
+from modules.wenet_extractor.utils.checkpoint import load_checkpoint
+
+"""
+ Extractor for content features
+ 1. whisper
+ 2. contentvec
+ 3. wenet
+ 4. mert
+
+ Pipeline:
+ in preprocess.py:
+ call extract_utt_content_features() to extract content features for each utterance
+ extract_utt_content_features() envelopes the following steps:
+ 1. load the model (whisper, contentvec, wenet)
+ 2. extract the content features
+ 3. save the content features into files
+ in svc_dataset.py:
+ call offline_align() to align the content features to the given target length
+
+"""
+
+"""
+ Extractor Usage:
+ 1. initialize an instance of extractor
+ extractor = WhisperExtractor(cfg)
+ 2. load the specified model
+ extractor.load_model()
+ 3. extract the content features
+ extractor.extract_content(utt) for single utterance
+ extractor.extract_content_batch(utts) for batch utterances
+ 4. save the content features
+ extractor.save_feature(utt, content_feature) for single utterance
+"""
+
+
+class AudioPretrainedModelFeaturesExtractor:
+ def __init__(self, cfg, extractor_type):
+ self.cfg = cfg
+ self.extractor_type = extractor_type
+ self.model = None
+ self.init_for_retrans()
+
+ def init_for_retrans(self):
+ target_hop = self.cfg.preprocess.hop_size
+
+ assert self.extractor_type in ["whisper", "contentvec", "wenet"]
+ if self.extractor_type == "whisper":
+ source_hop = (
+ self.cfg.preprocess.whisper_frameshift
+ * self.cfg.preprocess.whisper_downsample_rate
+ * self.cfg.preprocess.sample_rate
+ )
+ elif self.extractor_type == "contentvec":
+ source_hop = (
+ self.cfg.preprocess.contentvec_frameshift
+ * self.cfg.preprocess.sample_rate
+ )
+ elif self.extractor_type == "wenet":
+ source_hop = (
+ self.cfg.preprocess.wenet_frameshift
+ * self.cfg.preprocess.wenet_downsample_rate
+ * self.cfg.preprocess.sample_rate
+ )
+ source_hop = int(source_hop)
+ factor = np.gcd(source_hop, target_hop)
+ source_hop //= factor
+ target_hop //= factor
+
+ self.source_hop = source_hop
+ self.target_hop = target_hop
+
+ def offline_resolution_transformation(self, content, target_len):
+ """
+ args:
+ content: (source_len, dim)
+ target_len: target length
+ return:
+ mapped_feature: (target_len, dim)
+ """
+ source_hop = self.source_hop
+ target_hop = self.target_hop
+
+ # (source_len, 256)
+ _, width = content.shape
+ # slice the content from padded feature
+ source_len = min(target_len * target_hop // source_hop + 1, len(content))
+
+ # const ~= target_len * target_hop
+ const = source_len * source_hop // target_hop * target_hop
+
+ # (source_len * source_hop, dim)
+ up_sampling_feats = np.repeat(content, source_hop, axis=0)
+ # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+ down_sampling_feats = np.average(
+ up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+ )
+
+ err = abs(target_len - len(down_sampling_feats))
+ if err > 8:
+ # err_log_dir is indeterminate
+ err_log_dir = os.path.join(
+ self.cfg.preprocess.processed_dir, "align_max_err.log"
+ )
+ try:
+ with open(err_log_dir, "r") as f:
+ err_num = int(f.read())
+ except:
+ with open(err_log_dir, "w") as f:
+ f.write("0")
+ err_num = 0
+ if err > err_num:
+ with open(err_log_dir, "w") as f:
+ f.write(str(err))
+
+ if len(down_sampling_feats) < target_len:
+ # (1, dim) -> (err, dim)
+ end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
+ down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)
+
+ # (target_len, dim)
+ mapped_feature = down_sampling_feats[:target_len]
+
+ return mapped_feature
+
+ def log_for_ReTrans(self, err):
+ err_log_dir = os.path.join(
+ self.cfg.preprocess.processed_dir, "align_max_err.log"
+ )
+ try:
+ with open(err_log_dir, "r") as f:
+ err_num = int(f.read())
+ except:
+ with open(err_log_dir, "w") as f:
+ f.write("0")
+ err_num = 0
+ if err > err_num:
+ with open(err_log_dir, "w") as f:
+ f.write(str(err))
+
+ def ReTrans(self, source_feats, padded_target_len):
+ """
+ Resolution Transformation for mismatched frames alginment.
+
+ TODO: Merge the offline resolution_transformation into one
+
+ args:
+ source_feats: Tensor, (B, padded_source_len, D)
+ padded_target_len: int, the maximum target length in a batch
+ return:
+ mapped_feature: Tensor, (B, padded_target_len, D)
+ """
+ source_hop = self.source_hop
+ target_hop = self.target_hop
+
+ # (B, padded_source_len, D)
+ B, padded_source_len, D = source_feats.shape
+
+ # select the valid content from padded feature
+ source_len = min(
+ padded_target_len * target_hop // source_hop + 1, padded_source_len
+ )
+
+ # const ~= padded_target_len * target_hop (padded wav's duration)
+ const = source_len * source_hop // target_hop * target_hop
+
+ # (B, padded_source_len, D) -> (B, padded_source_len * source_hop, D) -> (B, const, D)
+ up_sampling_feats = torch.repeat_interleave(source_feats, source_hop, dim=1)[
+ :, :const
+ ]
+ # (B, const, D) -> (B, const/target_hop, target_hop, D) -> (B, const/target_hop, D)
+ down_sampling_feats = torch.mean(
+ up_sampling_feats.reshape(B, -1, target_hop, D), dim=2
+ )
+
+ err = abs(padded_target_len - down_sampling_feats.shape[1])
+ if err > 8:
+ self.log_for_ReTrans(err)
+
+ if down_sampling_feats.shape[1] < padded_target_len:
+ # (B, 1, D) -> (B, err, D)
+ end = down_sampling_feats[:, -1, :][:, None, :].repeat_interleave(
+ err, dim=1
+ )
+ # -> (B, padded_target_len, D)
+ down_sampling_feats = torch.cat([down_sampling_feats, end], dim=1)
+
+ # (B, padded_target_len, D)
+ mapped_feature = down_sampling_feats[:, :padded_target_len]
+ return mapped_feature
+
+ def get_valid_features(self, utt, content_feature):
+ # only keep effective parts
+ duration = utt["Duration"]
+ if self.extractor_type == "whisper":
+ frameshift = (
+ self.cfg.preprocess.whisper_frameshift
+ * self.cfg.preprocess.whisper_downsample_rate
+ ) # 20ms
+ elif self.extractor_type == "contentvec":
+ frameshift = self.cfg.preprocess.contentvec_frameshift # 20ms
+ elif self.extractor_type == "wenet":
+ frameshift = (
+ self.cfg.preprocess.wenet_frameshift
+ * self.cfg.preprocess.wenet_downsample_rate
+ ) # 40ms
+ elif self.extractor_type == "mert":
+ frameshift = self.cfg.preprocess.mert_frameshift
+ else:
+ raise NotImplementedError
+
+ # calculate the number of valid frames
+ num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
+ assert (
+ len(content_feature.shape) == 2
+ ), "content feature shape error, it should be (num_frames, dim)"
+ content_feature = content_feature[:num_frames, :]
+ return content_feature
+
+ def save_feature(self, utt, content_feature):
+ """Save a single utternace to path {cfg.preprocess.processed_dir}
+
+ Args:
+ utt (dict): one item in metadata, containing information for one utterance
+ content_feature (tensor): content feature of one utterance
+ """
+ uid = utt["Uid"]
+ assert self.extractor_type != None
+ out_dir = os.path.join(
+ self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
+ )
+ os.makedirs(out_dir, exist_ok=True)
+ save_path = os.path.join(out_dir, uid + ".npy")
+
+ content_feature = self.get_valid_features(utt, content_feature)
+ np.save(save_path, content_feature.cpu().detach().numpy())
+
+
+class WhisperExtractor(AudioPretrainedModelFeaturesExtractor):
+ def __init__(self, config):
+ super(WhisperExtractor, self).__init__(config, extractor_type="whisper")
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def load_model(self):
+ # load whisper checkpoint
+ print("Loading Whisper Model...")
+
+ if "whisper_model_path" in self.cfg.preprocess:
+ if os.path.isfile(self.cfg.preprocess.whisper_model_path):
+ # "pretrained/whisper/medium.pt"
+ download_root = os.path.dirname(self.cfg.preprocess.whisper_model_path)
+ elif os.path.isdir(self.cfg.preprocess.whisper_model_path):
+ # "pretrained/whisper"
+ download_root = self.cfg.preprocess.whisper_model_path
+ else:
+ # if the path does not exist, download the model to the path
+ download_root = self.cfg.preprocess.whisper_model_path
+ if download_root.endswith(".pt"):
+ download_root = os.path.dirname(download_root)
+ else:
+ download_root = None
+
+ model = whisper.load_model(
+ self.cfg.preprocess.whisper_model, self.device, download_root
+ )
+ if torch.cuda.is_available():
+ print("Using GPU...\n")
+ model = model.cuda()
+ else:
+ print("Using CPU...\n")
+
+ self.model = model.eval()
+
+ def extract_content_features(self, wavs):
+ """extract content features from a batch of dataloader
+ Args:
+ wavs: tensor (batch_size, T)
+ """
+ # wavs: (batch, max_len)
+ wavs = whisper.pad_or_trim(wavs)
+ # batch_mel: (batch, 80, 3000)
+ batch_mel = whisper.log_mel_spectrogram(wavs, device=self.model.device)
+ with torch.no_grad():
+ # (batch, 1500, 1024)
+ features = self.model.embed_audio(batch_mel)
+ return features
+
+
+class ContentvecExtractor(AudioPretrainedModelFeaturesExtractor):
+ def __init__(self, cfg):
+ super(ContentvecExtractor, self).__init__(cfg, extractor_type="contentvec")
+
+ def load_model(self):
+ assert self.model == None
+ # Load model
+ ckpt_path = self.cfg.preprocess.contentvec_file
+ print("Load Contentvec Model...")
+
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+ [ckpt_path],
+ suffix="",
+ )
+ model = models[0]
+ model.eval()
+
+ if torch.cuda.is_available():
+ # print("Using GPU...\n")
+ model = model.cuda()
+
+ self.model = model
+
+ def extract_content_features(self, wavs):
+ """extract content features from a batch of dataloader
+ Args:
+ wavs: tensor (batch, T)
+ """
+ device = next(self.model.parameters()).device
+ wavs = wavs.to(device) # (batch, max_len)
+ padding_mask = torch.eq(wavs, torch.zeros_like(wavs)).to(device)
+ with torch.no_grad():
+ logits = self.model.extract_features(
+ source=wavs, padding_mask=padding_mask, output_layer=12
+ )
+ # feats: (batch, T, 256)
+ feats = self.model.final_proj(logits[0])
+ return feats
+
+
+class WenetExtractor(AudioPretrainedModelFeaturesExtractor):
+ def __init__(self, config):
+ super(WenetExtractor, self).__init__(config, extractor_type="wenet")
+
+ def load_model(self):
+ wenet_cfg = self.cfg.preprocess.wenet_config
+ wenet_model_path = self.cfg.preprocess.wenet_model_path
+ # load Wenet config
+ with open(wenet_cfg, "r") as w:
+ wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
+ self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
+ print("Loading Wenet Model...")
+ self.model = init_model(wenet_configs)
+ load_checkpoint(self.model, wenet_model_path)
+
+ if torch.cuda.is_available():
+ print("Using GPU...\n")
+ self.model = self.model.cuda()
+ else:
+ print("Using CPU...\n")
+
+ self.model = self.model.eval()
+
+ def extract_content_features(self, wavs, lens):
+ """extract content features from a batch of dataloader
+ Args:
+ wavs: tensor, whose shape is (B, T)
+ lens: list
+ """
+ feats_list = []
+ lengths_list = []
+
+ device = next(self.model.parameters()).device
+ # Extract fbank/mfcc features by kaldi
+ assert self.extract_conf is not None, "load model first!"
+ feats_type = self.extract_conf.get("feats_type", "fbank")
+ assert feats_type in ["fbank", "mfcc"]
+
+ for idx, wav in enumerate(wavs):
+ # wav: (T)
+ wav = wav[: lens[idx]].to(device)
+
+ # pad one frame to compensate for the frame cut off after feature extraction
+ pad_tensor = torch.zeros(160, device=wav.device)
+ wav = torch.cat((wav, pad_tensor), dim=-1)
+ wav *= 1 << 15
+
+ wav = wav.unsqueeze(0) # (T) -> (1, T)
+ if feats_type == "fbank":
+ fbank_conf = self.extract_conf.get("fbank_conf", {})
+ feat = kaldi.fbank(
+ wav,
+ sample_frequency=16000,
+ num_mel_bins=fbank_conf["num_mel_bins"],
+ frame_length=fbank_conf["frame_length"],
+ frame_shift=fbank_conf["frame_shift"],
+ dither=fbank_conf["dither"],
+ )
+ elif feats_type == "mfcc":
+ mfcc_conf = self.extract_conf.get("mfcc", {})
+ feat = kaldi.mfcc(
+ wav,
+ sample_frequency=16000,
+ num_mel_bins=mfcc_conf["num_mel_bins"],
+ frame_length=mfcc_conf["frame_length"],
+ frame_shift=mfcc_conf["frame_shift"],
+ dither=mfcc_conf["dither"],
+ num_ceps=mfcc_conf.get("num_ceps", 40),
+ high_freq=mfcc_conf.get("high_freq", 0.0),
+ low_freq=mfcc_conf.get("low_freq", 20.0),
+ )
+ feats_list.append(feat)
+ lengths_list.append(feat.shape[0])
+
+ feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
+ feats_tensor = pad_sequence(feats_list, batch_first=True).to(
+ device
+ ) # (batch, len, 80)
+
+ features = self.model.encoder_extractor(
+ feats_tensor,
+ feats_lengths,
+ decoding_chunk_size=-1,
+ num_decoding_left_chunks=-1,
+ simulate_streaming=False,
+ )
+ return features
+
+
+class MertExtractor(AudioPretrainedModelFeaturesExtractor):
+ def __init__(self, cfg):
+ super(MertExtractor, self).__init__(cfg, extractor_type="mert")
+ self.preprocessor = None
+
+ def load_model(self):
+ assert self.model == None
+ assert self.preprocessor == None
+
+ print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)
+
+ model_name = self.cfg.preprocess.mert_model
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+
+ if torch.cuda.is_available():
+ model = model.cuda()
+ preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
+ model_name, trust_remote_code=True
+ )
+
+ self.model = model
+ self.preprocessor = preprocessor
+
+ def extract_content_features(self, wavs):
+ """extract content features from a batch of dataloader
+ Args:
+ wavs: tensor (batch, T)
+ """
+ with torch.no_grad():
+ sample_rate = self.preprocessor.sampling_rate
+ device = next(self.model.parameters()).device
+ assert (
+ sample_rate == self.cfg.preprocess.mert_sample_rate
+ ), "mert sample rate mismatch, expected {}, got {}".format(
+ self.cfg.preprocess.mert_sample_rate, sample_rate
+ )
+ mert_features = []
+ # wav: (len)
+ for wav in wavs:
+ # {input_values: tensor, attention_mask: tensor}
+ inputs = self.preprocessor(
+ wavs, sampling_rate=sample_rate, return_tensors="pt"
+ ).to(device)
+
+ outputs = self.model(**inputs, output_hidden_states=True)
+ # (25 layers, time steps, 1024 feature_dim)
+ all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+ # (1, frame_len, 1024) -> (frame_len, 1024)
+ feature = outputs.hidden_states[
+ self.cfg.preprocess.mert_feature_layer
+ ].squeeze(0)
+ mert_features.append(feature)
+
+ return mert_features
+
+
+def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
+ dataset_name = metadata[0]["Dataset"]
+ with torch.no_grad():
+ if cfg.preprocess.extract_whisper_feature:
+ feat_dir = os.path.join(
+ cfg.preprocess.processed_dir, dataset_name, "whisper"
+ )
+ os.makedirs(feat_dir, exist_ok=True)
+ feat_files_num = len(os.listdir(feat_dir))
+
+ if feat_files_num != len(metadata):
+ whisper_waveforms = FFmpegDataset(
+ cfg,
+ dataset_name,
+ cfg.preprocess.whisper_sample_rate,
+ metadata=metadata,
+ )
+ data_loader = DataLoader(
+ whisper_waveforms,
+ num_workers=num_workers,
+ shuffle=False,
+ pin_memory=cfg.preprocess.pin_memory,
+ batch_size=cfg.preprocess.content_feature_batch_size,
+ collate_fn=collate_batch,
+ drop_last=False,
+ )
+ extractor = WhisperExtractor(cfg)
+ extractor.load_model()
+ for batch_idx, items in enumerate(tqdm(data_loader)):
+ _metadata, wavs, lens = items
+
+ batch_content_features = extractor.extract_content_features(wavs)
+ for index, utt in enumerate(_metadata):
+ extractor.save_feature(utt, batch_content_features[index])
+
+ if cfg.preprocess.extract_contentvec_feature:
+ feat_dir = os.path.join(
+ cfg.preprocess.processed_dir, dataset_name, "contentvec"
+ )
+ os.makedirs(feat_dir, exist_ok=True)
+ feat_files_num = len(os.listdir(feat_dir))
+
+ if feat_files_num != len(metadata):
+ contentvec_waveforms = LibrosaDataset(
+ cfg,
+ dataset_name,
+ cfg.preprocess.contentvec_sample_rate,
+ metadata=metadata,
+ )
+ data_loader = DataLoader(
+ contentvec_waveforms,
+ num_workers=num_workers,
+ shuffle=False,
+ pin_memory=cfg.preprocess.pin_memory,
+ batch_size=cfg.preprocess.content_feature_batch_size,
+ collate_fn=collate_batch,
+ drop_last=False,
+ )
+ extractor = ContentvecExtractor(cfg)
+ extractor.load_model()
+ for batch_idx, items in enumerate(tqdm(data_loader)):
+ _metadata, wavs, lens = items
+
+ batch_content_features = extractor.extract_content_features(wavs)
+ for index, utt in enumerate(_metadata):
+ extractor.save_feature(utt, batch_content_features[index])
+
+ if cfg.preprocess.extract_wenet_feature:
+ feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "wenet")
+ os.makedirs(feat_dir, exist_ok=True)
+ feat_files_num = len(os.listdir(feat_dir))
+
+ if feat_files_num != len(metadata):
+ wenet_waveforms = TorchaudioDataset(
+ cfg,
+ dataset_name,
+ cfg.preprocess.wenet_sample_rate,
+ metadata=metadata,
+ )
+ data_loader = DataLoader(
+ wenet_waveforms,
+ num_workers=num_workers,
+ shuffle=False,
+ pin_memory=cfg.preprocess.pin_memory,
+ batch_size=cfg.preprocess.content_feature_batch_size,
+ collate_fn=collate_batch,
+ drop_last=False,
+ )
+ extractor = WenetExtractor(cfg)
+ extractor.load_model()
+ for batch_idx, items in enumerate(tqdm(data_loader)):
+ _metadata, wavs, lens = items
+
+ batch_content_features = extractor.extract_content_features(
+ wavs,
+ lens,
+ )
+ for index, utt in enumerate(_metadata):
+ extractor.save_feature(utt, batch_content_features[index])
+
+ if cfg.preprocess.extract_mert_feature:
+ feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "mert")
+ os.makedirs(feat_dir, exist_ok=True)
+ feat_files_num = len(os.listdir(feat_dir))
+
+ if feat_files_num != len(metadata):
+ mert_waveforms = TorchaudioDataset(
+ cfg,
+ dataset_name,
+ cfg.preprocess.mert_sample_rate,
+ metadata=metadata,
+ )
+ data_loader = DataLoader(
+ mert_waveforms,
+ num_workers=num_workers,
+ shuffle=False,
+ pin_memory=cfg.preprocess.pin_memory,
+ batch_size=cfg.preprocess.content_feature_batch_size,
+ collate_fn=collate_batch,
+ drop_last=False,
+ )
+ extractor = MertExtractor(cfg)
+ extractor.load_model()
+ for batch_idx, items in enumerate(tqdm(data_loader)):
+ _metadata, wavs, lens = items
+
+ batch_content_features = extractor.extract_content_features(wavs)
+ for index, utt in enumerate(_metadata):
+ extractor.save_feature(utt, batch_content_features[index])
diff --git a/processors/data_augment.py b/processors/data_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fc183361d4bcfd454693ee0b7ffdd9758c09312
--- /dev/null
+++ b/processors/data_augment.py
@@ -0,0 +1,378 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+import os
+import json
+
+import numpy as np
+import parselmouth
+import torch
+import torchaudio
+from tqdm import tqdm
+
+from audiomentations import TimeStretch
+
+from pedalboard import (
+ Pedalboard,
+ HighShelfFilter,
+ LowShelfFilter,
+ PeakFilter,
+ PitchShift,
+)
+
+from utils.util import has_existed
+
+PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT = 0.0
+PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT = 1.0
+PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT = 1.0
+PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT = 1.0
+PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT = 1.0
+
+
+def wav_to_Sound(wav, sr: int) -> parselmouth.Sound:
+ """Convert a waveform to a parselmouth.Sound object
+
+ Args:
+ wav (np.ndarray/torch.Tensor): waveform of shape (n_channels, n_samples)
+ sr (int, optional): sampling rate.
+
+ Returns:
+ parselmouth.Sound: a parselmouth.Sound object
+ """
+ assert wav.shape == (1, len(wav[0])), "wav must be of shape (1, n_samples)"
+ sound = None
+ if isinstance(wav, np.ndarray):
+ sound = parselmouth.Sound(wav[0], sampling_frequency=sr)
+ elif isinstance(wav, torch.Tensor):
+ sound = parselmouth.Sound(wav[0].numpy(), sampling_frequency=sr)
+ assert sound is not None, "wav must be either np.ndarray or torch.Tensor"
+ return sound
+
+
+def get_pitch_median(wav, sr: int):
+ """Get the median pitch of a waveform
+
+ Args:
+ wav (np.ndarray/torch.Tensor): waveform of shape (n_channels, n_samples)
+ sr (int, optional): sampling rate.
+
+ Returns:
+ parselmouth.Pitch, float: a parselmouth.Pitch object and the median pitch
+ """
+ if not isinstance(wav, parselmouth.Sound):
+ sound = wav_to_Sound(wav, sr)
+ else:
+ sound = wav
+ pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+
+ # To Pitch: Time step(s)(standard value: 0.0), Pitch floor (Hz)(standard value: 75), Pitch ceiling (Hz)(standard value: 600.0)
+ pitch = parselmouth.praat.call(sound, "To Pitch", 0.8 / 75, 75, 600)
+ # Get quantile: From time (s), To time (s), Quantile(0.5 is then the 50% quantile, i.e., the median), Units (Hertz or Bark)
+ pitch_median = parselmouth.praat.call(pitch, "Get quantile", 0.0, 0.0, 0.5, "Hertz")
+
+ return pitch, pitch_median
+
+
+def change_gender(
+ sound,
+ pitch=None,
+ formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT,
+ new_pitch_median: float = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT,
+ pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT,
+ duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT,
+) -> parselmouth.Sound:
+ """Invoke change gender function in praat
+
+ Args:
+ sound (parselmouth.Sound): a parselmouth.Sound object
+ pitch (parselmouth.Pitch, optional): a parselmouth.Pitch object. Defaults to None.
+ formant_shift_ratio (float, optional): formant shift ratio. A value of 1.0 means no change. Greater than 1.0 means higher pitch. Less than 1.0 means lower pitch.
+ new_pitch_median (float, optional): new pitch median.
+ pitch_range_ratio (float, optional): pitch range ratio. A value of 1.0 means no change. Greater than 1.0 means higher pitch range. Less than 1.0 means lower pitch range.
+ duration_factor (float, optional): duration factor. A value of 1.0 means no change. Greater than 1.0 means longer duration. Less than 1.0 means shorter duration.
+
+ Returns:
+ parselmouth.Sound: a parselmouth.Sound object
+ """
+ if pitch is None:
+ new_sound = parselmouth.praat.call(
+ sound,
+ "Change gender",
+ 75,
+ 600,
+ formant_shift_ratio,
+ new_pitch_median,
+ pitch_range_ratio,
+ duration_factor,
+ )
+ else:
+ new_sound = parselmouth.praat.call(
+ (sound, pitch),
+ "Change gender",
+ formant_shift_ratio,
+ new_pitch_median,
+ pitch_range_ratio,
+ duration_factor,
+ )
+ return new_sound
+
+
+def apply_formant_and_pitch_shift(
+ sound: parselmouth.Sound,
+ formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT,
+ pitch_shift_ratio: float = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT,
+ pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT,
+ duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT,
+) -> parselmouth.Sound:
+ """use Praat "Changer gender" command to manipulate pitch and formant
+ "Change gender": Praat -> Sound Object -> Convert -> Change gender
+ refer to Help of Praat for more details
+ # https://github.com/YannickJadoul/Parselmouth/issues/25#issuecomment-608632887 might help
+ """
+ pitch = None
+ new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+ if pitch_shift_ratio != 1.0:
+ pitch, pitch_median = get_pitch_median(sound, sound.sampling_frequency)
+ new_pitch_median = pitch_median * pitch_shift_ratio
+
+ # refer to https://github.com/praat/praat/issues/1926#issuecomment-974909408
+ pitch_minimum = parselmouth.praat.call(
+ pitch, "Get minimum", 0.0, 0.0, "Hertz", "Parabolic"
+ )
+ new_median = pitch_median * pitch_shift_ratio
+ scaled_minimum = pitch_minimum * pitch_shift_ratio
+ result_minimum = new_median + (scaled_minimum - new_median) * pitch_range_ratio
+ if result_minimum < 0:
+ new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+ pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT
+
+ if math.isnan(new_pitch_median):
+ new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+ pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT
+
+ new_sound = change_gender(
+ sound,
+ pitch,
+ formant_shift_ratio,
+ new_pitch_median,
+ pitch_range_ratio,
+ duration_factor,
+ )
+ return new_sound
+
+
+# Function used in EQ
+def pedalboard_equalizer(wav: np.ndarray, sr: int) -> np.ndarray:
+ """Use pedalboard to do equalizer"""
+ board = Pedalboard()
+
+ cutoff_low_freq = 60
+ cutoff_high_freq = 10000
+
+ q_min = 2
+ q_max = 5
+
+ random_all_freq = True
+ num_filters = 10
+ if random_all_freq:
+ key_freqs = [random.uniform(1, 12000) for _ in range(num_filters)]
+ else:
+ key_freqs = [
+ power_ratio(float(z) / (num_filters - 1), cutoff_low_freq, cutoff_high_freq)
+ for z in range(num_filters)
+ ]
+ q_values = [
+ power_ratio(random.uniform(0, 1), q_min, q_max) for _ in range(num_filters)
+ ]
+ gains = [random.uniform(-12, 12) for _ in range(num_filters)]
+ # low-shelving filter
+ board.append(
+ LowShelfFilter(
+ cutoff_frequency_hz=key_freqs[0], gain_db=gains[0], q=q_values[0]
+ )
+ )
+ # peaking filters
+ for i in range(1, 9):
+ board.append(
+ PeakFilter(
+ cutoff_frequency_hz=key_freqs[i], gain_db=gains[i], q=q_values[i]
+ )
+ )
+ # high-shelving filter
+ board.append(
+ HighShelfFilter(
+ cutoff_frequency_hz=key_freqs[9], gain_db=gains[9], q=q_values[9]
+ )
+ )
+
+ # Apply the pedalboard to the audio
+ processed_audio = board(wav, sr)
+ return processed_audio
+
+
+def power_ratio(r: float, a: float, b: float):
+ return a * math.pow((b / a), r)
+
+
+def audiomentations_time_stretch(wav: np.ndarray, sr: int) -> np.ndarray:
+ """Use audiomentations to do time stretch"""
+ transform = TimeStretch(
+ min_rate=0.8, max_rate=1.25, leave_length_unchanged=False, p=1.0
+ )
+ augmented_wav = transform(wav, sample_rate=sr)
+ return augmented_wav
+
+
+def formant_and_pitch_shift(
+ sound: parselmouth.Sound, fs: bool, ps: bool
+) -> parselmouth.Sound:
+ """ """
+ formant_shift_ratio = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT
+ pitch_shift_ratio = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT
+ pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT
+
+ assert fs != ps, "fs, ps are mutually exclusive"
+
+ if fs:
+ formant_shift_ratio = random.uniform(1.0, 1.4)
+ use_reciprocal = random.uniform(-1, 1) > 0
+ if use_reciprocal:
+ formant_shift_ratio = 1.0 / formant_shift_ratio
+ # only use praat to change formant
+ new_sound = apply_formant_and_pitch_shift(
+ sound,
+ formant_shift_ratio=formant_shift_ratio,
+ )
+ return new_sound
+
+ if ps:
+ board = Pedalboard()
+ board.append(PitchShift(random.uniform(-12, 12)))
+ wav_numpy = sound.values
+ wav_numpy = board(wav_numpy, sound.sampling_frequency)
+ # use pedalboard to change pitch
+ new_sound = parselmouth.Sound(
+ wav_numpy, sampling_frequency=sound.sampling_frequency
+ )
+ return new_sound
+
+
+def wav_manipulation(
+ wav: torch.Tensor,
+ sr: int,
+ aug_type: str = "None",
+ formant_shift: bool = False,
+ pitch_shift: bool = False,
+ time_stretch: bool = False,
+ equalizer: bool = False,
+) -> torch.Tensor:
+ assert aug_type == "None" or aug_type in [
+ "formant_shift",
+ "pitch_shift",
+ "time_stretch",
+ "equalizer",
+ ], "aug_type must be one of formant_shift, pitch_shift, time_stretch, equalizer"
+
+ assert aug_type == "None" or (
+ formant_shift == False
+ and pitch_shift == False
+ and time_stretch == False
+ and equalizer == False
+ ), "if aug_type is specified, other argument must be False"
+
+ if aug_type != "None":
+ if aug_type == "formant_shift":
+ formant_shift = True
+ if aug_type == "pitch_shift":
+ pitch_shift = True
+ if aug_type == "equalizer":
+ equalizer = True
+ if aug_type == "time_stretch":
+ time_stretch = True
+
+ wav_numpy = wav.numpy()
+
+ if equalizer:
+ wav_numpy = pedalboard_equalizer(wav_numpy, sr)
+
+ if time_stretch:
+ wav_numpy = audiomentations_time_stretch(wav_numpy, sr)
+
+ sound = wav_to_Sound(wav_numpy, sr)
+
+ if formant_shift or pitch_shift:
+ sound = formant_and_pitch_shift(sound, formant_shift, pitch_shift)
+
+ wav = torch.from_numpy(sound.values).float()
+ # shape (1, n_samples)
+ return wav
+
+
+def augment_dataset(cfg, dataset) -> list:
+ """Augment dataset with formant_shift, pitch_shift, time_stretch, equalizer
+
+ Args:
+ cfg (dict): configuration
+ dataset (str): dataset name
+
+ Returns:
+ list: augmented dataset names
+ """
+ # load metadata
+ dataset_path = os.path.join(cfg.preprocess.processed_dir, dataset)
+ split = ["train", "test"] if "eval" not in dataset else ["test"]
+ augment_datasets = []
+ aug_types = [
+ "formant_shift" if cfg.preprocess.use_formant_shift else None,
+ "pitch_shift" if cfg.preprocess.use_pitch_shift else None,
+ "time_stretch" if cfg.preprocess.use_time_stretch else None,
+ "equalizer" if cfg.preprocess.use_equalizer else None,
+ ]
+ aug_types = filter(None, aug_types)
+ for aug_type in aug_types:
+ print("Augmenting {} with {}...".format(dataset, aug_type))
+ new_dataset = dataset + "_" + aug_type
+ augment_datasets.append(new_dataset)
+ new_dataset_path = os.path.join(cfg.preprocess.processed_dir, new_dataset)
+
+ for dataset_type in split:
+ metadata_path = os.path.join(dataset_path, "{}.json".format(dataset_type))
+ augmented_metadata = []
+ new_metadata_path = os.path.join(
+ new_dataset_path, "{}.json".format(dataset_type)
+ )
+ os.makedirs(new_dataset_path, exist_ok=True)
+ new_dataset_wav_dir = os.path.join(new_dataset_path, "wav")
+ os.makedirs(new_dataset_wav_dir, exist_ok=True)
+
+ if has_existed(new_metadata_path):
+ continue
+
+ with open(metadata_path, "r") as f:
+ metadata = json.load(f)
+
+ for utt in tqdm(metadata):
+ original_wav_path = utt["Path"]
+ original_wav, sr = torchaudio.load(original_wav_path)
+ new_wav = wav_manipulation(original_wav, sr, aug_type=aug_type)
+ new_wav_path = os.path.join(new_dataset_wav_dir, utt["Uid"] + ".wav")
+ torchaudio.save(new_wav_path, new_wav, sr)
+ new_utt = {
+ "Dataset": utt["Dataset"] + "_" + aug_type,
+ "index": utt["index"],
+ "Singer": utt["Singer"],
+ "Uid": utt["Uid"],
+ "Path": new_wav_path,
+ "Duration": utt["Duration"],
+ }
+ augmented_metadata.append(new_utt)
+ new_metadata_path = os.path.join(
+ new_dataset_path, "{}.json".format(dataset_type)
+ )
+ with open(new_metadata_path, "w") as f:
+ json.dump(augmented_metadata, f, indent=4, ensure_ascii=False)
+ return augment_datasets
diff --git a/processors/descriptive_text_features_extractor.py b/processors/descriptive_text_features_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e9cf05bfc0ae935711df3f3ea483dcf77022307
--- /dev/null
+++ b/processors/descriptive_text_features_extractor.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023 Amphion.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+TODO:
+
+This module aims to be an entrance that integrates all the "descriptive text" features extraction functions.
+
+The common descriptive text features include:
+1. Global semantic guidance features that extracted some pretrained text models like T5. It can be adopted to TTA, TTM, etc.
+
+Note:
+All the features extraction are designed to utilize GPU to the maximum extent, which can ease the on-the-fly extraction for large-scale dataset.
+
+"""
diff --git a/processors/phone_extractor.py b/processors/phone_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f55081922cd3702cf96837a9e3998ac152e54d1e
--- /dev/null
+++ b/processors/phone_extractor.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from tqdm import tqdm
+from text.g2p_module import G2PModule, LexiconModule
+from text.symbol_table import SymbolTable
+
+"""
+ phoneExtractor: extract phone from text
+"""
+
+
+class phoneExtractor:
+ def __init__(self, cfg, dataset_name=None, phone_symbol_file=None):
+ """
+ Args:
+ cfg: config
+ dataset_name: name of dataset
+ """
+ self.cfg = cfg
+
+ # phone symbols dict
+ self.phone_symbols = set()
+
+ # phone symbols dict file
+ if phone_symbol_file is not None:
+ self.phone_symbols_file = phone_symbol_file
+ elif dataset_name is not None:
+ self.dataset_name = dataset_name
+ self.phone_symbols_file = os.path.join(
+ cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
+ )
+
+ # initialize g2p module
+ if cfg.preprocess.phone_extractor in [
+ "espeak",
+ "pypinyin",
+ "pypinyin_initials_finals",
+ ]:
+ self.g2p_module = G2PModule(
+ backend=cfg.preprocess.phone_extractor, language=cfg.preprocess.language
+ )
+ elif cfg.preprocess.phone_extractor == "lexicon":
+ assert cfg.preprocess.lexicon_path != ""
+ self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path)
+ else:
+ print("No support to", cfg.preprocess.phone_extractor)
+ raise
+
+ def extract_phone(self, text):
+ """
+ Extract phone from text
+ Args:
+
+ text: text of utterance
+
+ Returns:
+ phone_symbols: set of phone symbols
+ phone_seq: list of phone sequence of each utterance
+ """
+
+ if self.cfg.preprocess.phone_extractor in [
+ "espeak",
+ "pypinyin",
+ "pypinyin_initials_finals",
+ ]:
+ text = text.replace("”", '"').replace("“", '"')
+ phone = self.g2p_module.g2p_conversion(text=text)
+ self.phone_symbols.update(phone)
+ phone_seq = [phn for phn in phone]
+
+ elif self.cfg.preprocess.phone_extractor == "lexicon":
+ phone_seq = self.g2p_module.g2p_conversion(text)
+ phone = phone_seq
+ if not isinstance(phone_seq, list):
+ phone_seq = phone_seq.split()
+
+ return phone_seq
+
+ def save_dataset_phone_symbols_to_table(self):
+ # load and merge saved phone symbols
+ if os.path.exists(self.phone_symbols_file):
+ phone_symbol_dict_saved = SymbolTable.from_file(
+ self.phone_symbols_file
+ )._sym2id.keys()
+ self.phone_symbols.update(set(phone_symbol_dict_saved))
+
+ # save phone symbols
+ phone_symbol_dict = SymbolTable()
+ for s in sorted(list(self.phone_symbols)):
+ phone_symbol_dict.add(s)
+ phone_symbol_dict.to_file(self.phone_symbols_file)
+
+
+def extract_utt_phone_sequence(dataset, cfg, metadata):
+ """
+ Extract phone sequence from text
+ Args:
+ dataset (str): name of dataset, e.g. opencpop
+ cfg: config
+ metadata: list of dict, each dict contains "Uid", "Text"
+
+ """
+
+ dataset_name = dataset
+
+ # output path
+ out_path = os.path.join(
+ cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir
+ )
+ os.makedirs(out_path, exist_ok=True)
+
+ phone_extractor = phoneExtractor(cfg, dataset_name)
+
+ for utt in tqdm(metadata):
+ uid = utt["Uid"]
+ text = utt["Text"]
+
+ phone_seq = phone_extractor.extract_phone(text)
+
+ phone_path = os.path.join(out_path, uid + ".phone")
+ with open(phone_path, "w") as fin:
+ fin.write(" ".join(phone_seq))
+
+ if cfg.preprocess.phone_extractor != "lexicon":
+ phone_extractor.save_dataset_phone_symbols_to_table()
+
+
+def save_all_dataset_phone_symbols_to_table(self, cfg, dataset):
+ # phone symbols dict
+ phone_symbols = set()
+
+ for dataset_name in dataset:
+ phone_symbols_file = os.path.join(
+ cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
+ )
+
+ # load and merge saved phone symbols
+ assert os.path.exists(phone_symbols_file)
+ phone_symbol_dict_saved = SymbolTable.from_file(
+ phone_symbols_file
+ )._sym2id.keys()
+ phone_symbols.update(set(phone_symbol_dict_saved))
+
+ # save all phone symbols to each dataset
+ phone_symbol_dict = SymbolTable()
+ for s in sorted(list(phone_symbols)):
+ phone_symbol_dict.add(s)
+ for dataset_name in dataset:
+ phone_symbols_file = os.path.join(
+ cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
+ )
+ phone_symbol_dict.to_file(phone_symbols_file)
diff --git a/processors/text_features_extractor.py b/processors/text_features_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b67f3dec526874d22431c3e11a3d2743667f901e
--- /dev/null
+++ b/processors/text_features_extractor.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2023 Amphion.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+TODO:
+
+This module aims to be an entrance that integrates all the "text" features extraction functions.
+
+The common text features include:
+1. phone features that are used for TTS, SVS, etc.
+
+Note:
+All the features extraction are designed to utilize GPU to the maximum extent, which can ease the on-the-fly extraction for large-scale dataset.
+
+"""
diff --git a/schedulers/__init__.py b/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/schedulers/scheduler.py b/schedulers/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbde0e53679440b08d968c7ea0a4dc88f3db5ffa
--- /dev/null
+++ b/schedulers/scheduler.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.optim import Optimizer
+from typing import List, Optional, Tuple, Union
+
+
+def calc_lr(step, dim_embed, warmup_steps):
+ return dim_embed ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))
+
+
+# The function is modified from
+# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/modules/scheduler.py
+class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(
+ self,
+ base_lr: float,
+ optimizer: torch.optim.Optimizer,
+ dim_embed: int,
+ warmup_steps: int,
+ last_epoch: int = -1,
+ verbose: bool = False,
+ ) -> None:
+ self.dim_embed = dim_embed
+ self.base_lr = base_lr
+ self.warmup_steps = warmup_steps
+ self.num_param_groups = len(optimizer.param_groups)
+
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def get_lr(self) -> float:
+ lr = self.base_lr * calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
+ return [lr] * self.num_param_groups
+
+ def set_step(self, step: int):
+ self._step_count = step
+
+
+class LRScheduler(object):
+ """
+ Base-class for learning rate schedulers where the learning-rate depends on both the
+ batch and the epoch.
+ """
+
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
+ # Attach optimizer
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+ self.optimizer = optimizer
+ self.verbose = verbose
+
+ for group in optimizer.param_groups:
+ group.setdefault("base_lr", group["lr"])
+
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
+
+ self.epoch = 0
+ self.batch = 0
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {
+ "base_lrs": self.base_lrs,
+ "epoch": self.epoch,
+ "batch": self.batch,
+ }
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def get_last_lr(self) -> List[float]:
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
+ return self._last_lr
+
+ def get_lr(self):
+ # Compute list of learning rates from self.epoch and self.batch and
+ # self.base_lrs; this must be overloaded by the user.
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
+ raise NotImplementedError
+
+ def step_batch(self, batch: Optional[int] = None) -> None:
+ # Step the batch index, or just set it. If `batch` is specified, it
+ # must be the batch index from the start of training, i.e. summed over
+ # all epochs.
+ # You can call this in any order; if you don't provide 'batch', it should
+ # of course be called once per batch.
+ if batch is not None:
+ self.batch = batch
+ else:
+ self.batch = self.batch + 1
+ self._set_lrs()
+
+ def step_epoch(self, epoch: Optional[int] = None):
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
+ # arg, you should call it at the end of the epoch.
+ if epoch is not None:
+ self.epoch = epoch
+ else:
+ self.epoch = self.epoch + 1
+ self._set_lrs()
+
+ def _set_lrs(self):
+ values = self.get_lr()
+ assert len(values) == len(self.optimizer.param_groups)
+
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+ param_group, lr = data
+ param_group["lr"] = lr
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+
+
+class Eden(LRScheduler):
+ """
+ Eden scheduler.
+ The basic formula (before warmup) is:
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
+ where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
+ and then stays constant at 1.
+
+
+ E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ lr_epochs: the number of epochs after which we start significantly
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
+ 20 to 40 epochs, but may need smaller number if dataset is huge
+ and you will do few epochs.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ lr_epochs: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ verbose: bool = False,
+ ):
+ super(Eden, self).__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.lr_epochs = lr_epochs
+ self.warmup_batches = warmup_batches
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.25 * (
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
+ )
+ warmup_factor = (
+ 1.0
+ if self.batch >= self.warmup_batches
+ else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+ )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]