diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..858018e750da7be7b271bb7307e68d159ed67ef6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Shivam Mehta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..c013140cdfb9de19c4d4e73c73a44e33f33fa871 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,14 @@ +include README.md +include LICENSE.txt +include requirements.*.txt +include *.cff +include requirements.txt +include matcha/VERSION +recursive-include matcha *.json +recursive-include matcha *.html +recursive-include matcha *.png +recursive-include matcha *.md +recursive-include matcha *.py +recursive-include matcha *.pyx +recursive-exclude tests * +prune tests* diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..4b523dd17b13a19617c9cc9d9dad7f7d8d4c24a0 --- /dev/null +++ b/Makefile @@ -0,0 +1,42 @@ + +help: ## Show help + @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +clean: ## Clean autogenerated files + rm -rf dist + find . -type f -name "*.DS_Store" -ls -delete + find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf + find . | grep -E ".pytest_cache" | xargs rm -rf + find . 
| grep -E ".ipynb_checkpoints" | xargs rm -rf + rm -f .coverage + +clean-logs: ## Clean logs + rm -rf logs/** + +create-package: ## Create wheel and tar gz + rm -rf dist/ + python setup.py bdist_wheel --plat-name=manylinux1_x86_64 + python setup.py sdist + python -m twine upload dist/* --verbose --skip-existing + +format: ## Run pre-commit hooks + pre-commit run -a + +sync: ## Merge changes from main branch to your current branch + git pull + git pull origin main + +test: ## Run not slow tests + pytest -k "not slow" + +test-full: ## Run all tests + pytest + +train-ljspeech: ## Train the model + python matcha/train.py experiment=ljspeech + +train-ljspeech-min: ## Train the model with minimum memory + python matcha/train.py experiment=ljspeech_min_memory + +start_app: ## Start the app + python matcha/app.py diff --git a/README.md:Zone.Identifier b/README.md:Zone.Identifier new file mode 100644 index 0000000000000000000000000000000000000000..cfb6db5ae8db553aced535d78635eb42932097b0 --- /dev/null +++ b/README.md:Zone.Identifier @@ -0,0 +1,4 @@ +[ZoneTransfer] +ZoneId=3 +ReferrerUrl=https://huggingface.co/spaces/the-cramer-project/AkylAI_TTS_small/tree/main +HostUrl=https://huggingface.co/spaces/the-cramer-project/AkylAI_TTS_small/resolve/main/README.md?download=true diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..52cd1e923c03660b036c6cdf7057d9e758f45932 --- /dev/null +++ b/app.py @@ -0,0 +1,184 @@ +from pathlib import Path +import argparse +import soundfile as sf +import torch +import io +import argparse +from matcha.hifigan.config import v1 +from matcha.hifigan.denoiser import Denoiser +from matcha.hifigan.env import AttrDict +from matcha.hifigan.models import Generator as HiFiGAN +from matcha.models.matcha_tts import MatchaTTS +from matcha.text import sequence_to_text, text_to_sequence +from matcha.utils.utils import intersperse +import gradio as gr +import requests + +def download_file(url, save_path): + """Download url to save_path, creating the parent directory and raising on an HTTP error status.""" + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + response = 
requests.get(url, timeout=60) # timeout bounds connect/read stalls, not the total transfer time + response.raise_for_status() # fail here instead of saving an error page as the checkpoint + with open(save_path, 'wb') as file: + file.write(response.content) + +url_checkpoint = 'https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Matcha-TTS/checkpoint_epoch.499.ckpt' +save_checkpoint_path = './checkpoints/checkpoint.ckpt' +url_generator = 'https://github.com/simonlobgromov/AkylAI_Matcha_HiFiGan/releases/download/Generator/generator_v1' +save_generator_path = './checkpoints/generator' + +download_file(url_checkpoint, save_checkpoint_path) +download_file(url_generator, save_generator_path) + +def load_matcha( checkpoint_path, device): + """Load a MatchaTTS checkpoint mapped onto device and switch it to eval mode.""" + model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device) + _ = model.eval() + return model + +def load_hifigan(checkpoint_path, device): + """Build the v1 HiFi-GAN generator, load the checkpoint's 'generator' state dict and remove weight norm.""" + h = AttrDict(v1) + hifigan = HiFiGAN(h).to(device) + hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"]) # NOTE(review): torch.load unpickles the downloaded file; keep the download URL trusted + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + +def load_vocoder(checkpoint_path, device): + """Return the HiFi-GAN vocoder and a Denoiser (mode="zeros") built from it.""" + vocoder = load_hifigan(checkpoint_path, device) + denoiser = Denoiser(vocoder, mode="zeros") + return vocoder, denoiser + +def process_text(i: int, text: str, device: torch.device): + """Convert text into a batch-of-one id tensor interspersed with 0, its length, and the decoded symbols; i only tags the log lines.""" + print(f"[{i}] - Input text: {text}") + x = torch.tensor( + intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0), + dtype=torch.long, + device=device, + )[None] + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + x_phones = sequence_to_text(x.squeeze(0).tolist()) + print(f"[{i}] - Phonetised text: {x_phones[1::2]}") + return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones} + +def to_waveform(mel, vocoder, denoiser=None): + """Vocode mel to audio clamped to [-1, 1], optionally denoise it, and return it squeezed on the CPU.""" + audio = vocoder(mel).clamp(-1, 1) + if denoiser is not None: + audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze() + return audio.cpu().squeeze() + +@torch.inference_mode() +def process_text_gradio(text): + """Run process_text on the module-level device; return (every second symbol of x_phones, x, x_lengths).""" + output = process_text(1, text, device) + return output["x_phones"][1::2], 
output["x"], output["x_lengths"] + +@torch.inference_mode() +def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1): + """Synthesise with the module-level model and return the vocoded waveform as a NumPy array; a negative spk passes spks=None.""" + spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None + output = model.synthesise( + text, + text_length, + n_timesteps=n_timesteps, + temperature=temperature, + spks=spk, + length_scale=length_scale, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + return output["waveform"].numpy() + +def get_inference(text, n_timesteps=20, mel_temp = 0.667, length_scale=0.8, spk=-1): + """Process text and run exactly one synthesis pass, returning the waveform array.""" + _phones, text, text_lengths = process_text_gradio(text) + return synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk) + + +device = torch.device("cpu") +model_path = './checkpoints/checkpoint.ckpt' +vocoder_path = './checkpoints/generator' +model = load_matcha(model_path, device) +vocoder, denoiser = load_vocoder(vocoder_path, device) + +def gen_tts(text, speaking_rate): + """Gradio callback: return (22050, waveform), passing speaking_rate to get_inference as length_scale.""" + return 22050, get_inference(text = text, length_scale = speaking_rate) + +default_text = "Баарыңарга салам, менин атым Акылай." 
+ +css = """ + #share-btn-container { + display: flex; + padding-left: 0.5rem !important; + padding-right: 0.5rem !important; + background-color: #000000; + justify-content: center; + align-items: center; + border-radius: 9999px !important; + width: 13rem; + margin-top: 10px; + margin-left: auto; + flex: unset !important; + } + #share-btn { + all: initial; + color: #ffffff; + font-weight: 600; + cursor: pointer; + font-family: 'IBM Plex Sans', sans-serif; + margin-left: 0.5rem !important; + padding-top: 0.25rem !important; + padding-bottom: 0.25rem !important; + right:0; + } + #share-btn * { + all: unset !important; + } + #share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; + } + #share-btn-container .wrap { + display: none !important; + } +""" +with gr.Blocks(css=css) as block: + gr.HTML( + """ +
+
+

+ Akyl-AI TTS +

+
+
+ """ + ) + with gr.Row(): + image_path = "./photo_2024-04-07_15-59-52.png" + gr.Image(image_path, label=None, width=660, height=315, show_label=False) + with gr.Row(): + with gr.Column(): + input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text") + speaking_rate = gr.Slider(label='Speaking rate', minimum=0.5, maximum=1, step=0.05, value=0.8, interactive=True, show_label=True, elem_id="speaking_rate") # gen_tts forwards this value as length_scale + + + run_button = gr.Button("Generate Audio", variant="primary") + with gr.Column(): + audio_out = gr.Audio(label="Akyl-AI TTS generation", type="numpy", elem_id="audio_out") # gen_tts returns a (sample_rate, array) pair + + inputs = [input_text, speaking_rate] + outputs = [audio_out] + run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True) + + +block.queue() +block.launch(share=True) diff --git a/checkpoints/info.txt b/checkpoints/info.txt new file mode 100644 index 0000000000000000000000000000000000000000..103c4df94caa6aceba547da00799f9bec5f6f407 --- /dev/null +++ b/checkpoints/info.txt @@ -0,0 +1 @@ +Забудь дорогу всяк сюда входящий! 
diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56bf7f4aa4906bc0f997132708cc0826c198e4aa --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1 @@ +# this file is needed here to include configs when building project as a package diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebaa3ed31a7f626bc62f90184dc4b25b631e52a9 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,5 @@ +defaults: + - model_checkpoint.yaml + - model_summary.yaml + - rich_progress_bar.yaml + - _self_ diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5de653632376777bc112d708a6f238aa3696d0b4 --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.output_dir}/checkpoints # directory to save the model file + filename: checkpoint_{epoch:03d} # checkpoint filename + monitor: epoch # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save k best models (determined by above metric) + mode: "max" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: 10 # number of epochs 
between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e5368d0e94298cce6d5421365b4583bd763ba92 --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 3 # the maximum depth of layer nesting that the summary will include diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de6f1ccb11205a4db93645fb6f297e50205de172 --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/data/akylai.yaml b/configs/data/akylai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..530b9f6591e95112ad8f758a2d16e4b852a2b47b --- /dev/null +++ b/configs/data/akylai.yaml @@ -0,0 +1,21 @@ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: akylai +train_filelist_path: ./Kany_dataset_mk4_v1/Kany_dataset_mk4_v1_filelist_train.txt +valid_filelist_path: ./Kany_dataset_mk4_v1/Kany_dataset_mk4_v1_filelist_test.txt +batch_size: 12 +num_workers: 12 +pin_memory: True +cleaners: [kyrgyz_cleaners] +add_blank: True +n_spks: 1 +n_fft: 1024 +n_feats: 80 +sample_rate: 22050 +hop_length: 256 +win_length: 1024 +f_min: 0 +f_max: 8000 
+data_statistics: # Computed for ljspeech dataset + mel_mean: -5.638045310974121 + mel_std: 2.6814498901367188 +seed: ${seed} diff --git a/configs/data/akylai_multi.yaml b/configs/data/akylai_multi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1afb6f351bcca738d53031555a23bd36140bfacc --- /dev/null +++ b/configs/data/akylai_multi.yaml @@ -0,0 +1,21 @@ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: akylai_multi +train_filelist_path: ./akylai_multi_dataset/akylai_mlspk_filelist_train.txt +valid_filelist_path: ./akylai_multi_dataset/akylai_mlspk_filelist_test.txt +batch_size: 32 +num_workers: 20 +pin_memory: True +cleaners: [kyrgyz_cleaners] +add_blank: True +n_spks: 2 +n_fft: 1024 +n_feats: 80 +sample_rate: 22050 +hop_length: 256 +win_length: 1024 +f_min: 0 +f_max: 8000 +data_statistics: + mel_mean: -5.6814561 + mel_std: 2.7337122 +seed: ${seed} diff --git a/configs/data/hi-fi_en-US_female.yaml b/configs/data/hi-fi_en-US_female.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1269f9b3b421d27a204bb0697e2f27a0fa0864a3 --- /dev/null +++ b/configs/data/hi-fi_en-US_female.yaml @@ -0,0 +1,14 @@ +defaults: + - ljspeech + - _self_ + +# Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: hi-fi_en-US_female +train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt +valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt +batch_size: 32 +cleaners: [english_cleaners_piper] +data_statistics: # Computed for this dataset + mel_mean: -6.38385 + mel_std: 2.541796 diff --git a/configs/data/ljspeech.yaml b/configs/data/ljspeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..569f4d86f2293b90ab8ff990b06b4f042add7680 --- /dev/null +++ b/configs/data/ljspeech.yaml @@ -0,0 +1,22 @@ +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: ljspeech +train_filelist_path: 
/content/kany_dataset/kany_filelist_train.txt +valid_filelist_path: /content/kany_dataset/kany_filelist_test.txt +batch_size: 16 +num_workers: 20 +pin_memory: True +cleaners: [kyrgyz_cleaners] +add_blank: True +n_spks: 1 +n_fft: 1024 +n_feats: 80 +sample_rate: 22050 +hop_length: 256 +win_length: 1024 +f_min: 0 +f_max: 8000 +data_statistics: # Computed for ljspeech dataset + mel_mean: -5.68145561 + mel_std: 2.7337122 +seed: ${seed} + diff --git a/configs/data/vctk.yaml b/configs/data/vctk.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba11cc63371ad6308d6711513268de7efe50eed9 --- /dev/null +++ b/configs/data/vctk.yaml @@ -0,0 +1,14 @@ +defaults: + - ljspeech + - _self_ + +_target_: matcha.data.text_mel_datamodule.TextMelDataModule +name: vctk +train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt +valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt +batch_size: 32 +add_blank: True +n_spks: 109 +data_statistics: # Computed for vctk dataset + mel_mean: -6.630575 + mel_std: 2.482914 diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3932c82585fbe44047c1569a5cfe9ee9895c71a --- /dev/null +++ b/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +# callbacks: null +# logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # 
debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f2d34fa37c31017e749d5a4fc5ae6763e688b46 --- /dev/null +++ b/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..514d77fbd1475b03fff0372e3da3c2fa7ea7d190 --- /dev/null +++ b/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9906586a67a12aa81ff69138f589a366dbe2222f --- /dev/null +++ b/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..266295f15e0166e1d1b58b88caa7673f4b6493b5 --- /dev/null +++ b/configs/debug/profiler.yaml @@ -0,0 +1,15 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + # profiler: "simple" + profiler: "advanced" + # profiler: "pytorch" + accelerator: gpu + + limit_train_batches: 0.02 diff --git 
a/configs/eval.yaml b/configs/eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52cc42ffec5ab0d5373fb14ef68165b9c165aa53 --- /dev/null +++ b/configs/eval.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - _self_ + - data: akylai # choose datamodule with `test_dataloader()` for evaluation + - model: matcha + - logger: null + - trainer: default + - paths: default + - extras: default + - hydra: default + +task_name: "eval" + +tags: ["dev"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? diff --git a/configs/experiment/akylai.yaml b/configs/experiment/akylai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..081dabc08372062d3fbff25ce2bcd7b9b52273a8 --- /dev/null +++ b/configs/experiment/akylai.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=akylai + +defaults: + - override /data: akylai.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["akylai"] + +run_name: akylai diff --git a/configs/experiment/akylai_multi.yaml b/configs/experiment/akylai_multi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d347e4044438c043dedcbafce44e4c1abd3fa79 --- /dev/null +++ b/configs/experiment/akylai_multi.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=akylai_multi + +defaults: + - override /data: akylai_multi.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["akylai_multi"] + +run_name: akylai_multi \ No newline at end of file diff --git a/configs/experiment/hifi_dataset_piper_phonemizer.yaml b/configs/experiment/hifi_dataset_piper_phonemizer.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..7e6c57a0d0a399f7463f4ff2d96e1928c435779b --- /dev/null +++ b/configs/experiment/hifi_dataset_piper_phonemizer.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=hifi_dataset_piper_phonemizer + +defaults: + - override /data: hi-fi_en-US_female.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] + +run_name: hi-fi_en-US_female_piper_phonemizer diff --git a/configs/experiment/ljspeech.yaml b/configs/experiment/ljspeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5723f42cf3552226c42bd91202cc18818b685f0 --- /dev/null +++ b/configs/experiment/ljspeech.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=ljspeech + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech diff --git a/configs/experiment/ljspeech_min_memory.yaml b/configs/experiment/ljspeech_min_memory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef554dc633c392b1592d90d9d7734f2329264fdd --- /dev/null +++ b/configs/experiment/ljspeech_min_memory.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=ljspeech_min_memory + +defaults: + - override /data: ljspeech.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["ljspeech"] + +run_name: ljspeech_min + + +model: + out_size: 172 diff --git a/configs/experiment/multispeaker.yaml b/configs/experiment/multispeaker.yaml new file mode 100644 
index 0000000000000000000000000000000000000000..553842f4e2168db0fee4e44db11b5d086295b044 --- /dev/null +++ b/configs/experiment/multispeaker.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=multispeaker + +defaults: + - override /data: vctk.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["multispeaker"] + +run_name: multispeaker diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/configs/hparams_search/mnist_optuna.yaml b/configs/hparams_search/mnist_optuna.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1391183ebcdec3d8f5eb61374e0719d13c7545da --- /dev/null +++ b/configs/hparams_search/mnist_optuna.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! 
+optimized_metric: "val/acc_best" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: maximize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + data.batch_size: choice(32, 64, 128, 256) + model.net.lin1_size: choice(64, 128, 256) + model.net.lin2_size: choice(64, 128, 256) + model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1533136b22802a4f81e5387b74e407289edce94d --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: 
${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f9f6adad7feb2780c2efd5ddb0ed053621e05f8 --- /dev/null +++ b/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. 
+ log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0789274e2137ee6c97ca37a5d56c2b8abaf0aaa --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa028e9c146430c319101ffdfce466514338591c --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd586800bdccb4e8f4b0236a181b7ddd756ba9ab --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8fb7e685fa27fc8141387a421b90a0b9b492d9e --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + 
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8233c140018ecce6ab62971beed269991d31c89b --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd31f6d8ba68d1f5c36a804885d5b9f9c1a9302 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ece165889b3d0d9dc750a8f3c7454188cfdf12b7 --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! 
+ anonymous: null # enable anonymous logging + project: "lightning-hydra-template" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/configs/model/cfm/default.yaml b/configs/model/cfm/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d1d9609e2d05c7b0a12a26115520340ac18e584 --- /dev/null +++ b/configs/model/cfm/default.yaml @@ -0,0 +1,3 @@ +name: CFM +solver: euler +sigma_min: 1e-4 diff --git a/configs/model/decoder/default.yaml b/configs/model/decoder/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aaa00e63402ade5c76247a2f1d6b294ec3c61e63 --- /dev/null +++ b/configs/model/decoder/default.yaml @@ -0,0 +1,7 @@ +channels: [256, 256] +dropout: 0.05 +attention_head_dim: 64 +n_blocks: 1 +num_mid_blocks: 2 +num_heads: 2 +act_fn: snakebeta diff --git a/configs/model/encoder/default.yaml b/configs/model/encoder/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4d5e5adee8f707bd384b682a3ad9a116c40c6ed --- /dev/null +++ b/configs/model/encoder/default.yaml @@ -0,0 +1,18 @@ +encoder_type: RoPE Encoder +encoder_params: + n_feats: ${model.n_feats} + n_channels: 192 + filter_channels: 768 + filter_channels_dp: 256 + n_heads: 2 + n_layers: 6 + kernel_size: 3 + p_dropout: 0.1 + spk_emb_dim: 64 + n_spks: 1 + prenet: true + +duration_predictor_params: + filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} + kernel_size: 3 + p_dropout: ${model.encoder.encoder_params.p_dropout} diff --git a/configs/model/matcha.yaml b/configs/model/matcha.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36f6eafbdcaa324f7494a4b97a7590da7824f357 --- /dev/null +++ b/configs/model/matcha.yaml @@ -0,0 +1,15 @@ +defaults: + - _self_ + - encoder: default.yaml + - decoder: default.yaml + - cfm: default.yaml + - 
optimizer: adam.yaml + +_target_: matcha.models.matcha_tts.MatchaTTS +n_vocab: 178 +n_spks: ${data.n_spks} +spk_emb_dim: 64 +n_feats: 80 +data_statistics: ${data.data_statistics} +out_size: null # Must be divisible by 4 +prior_loss: true diff --git a/configs/model/optimizer/adam.yaml b/configs/model/optimizer/adam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42795577474eaee5b0b96845a95e1a11c9152385 --- /dev/null +++ b/configs/model/optimizer/adam.yaml @@ -0,0 +1,4 @@ +_target_: torch.optim.Adam +_partial_: true +lr: 1e-4 +weight_decay: 0.0 diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec81db2d34712909a79be3e42e65efe08c35ecee --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0cd014d3c33a73d5cf46e6161cb184fc55fdc39 --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: akylai + - model: matcha + - callbacks: default + - logger: tensorboard # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) + - trainer: default + - paths: default + - extras: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path +task_name: "train" + +run_name: ??? + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: False + +# simply provide checkpoint path to resume training +ckpt_path: "https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Matcha-TTS/checkpoint_epoch.499.ckpt" + +# seed for random number generators in pytorch, numpy and python.random +seed: 1234 diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..94b43e20ca7bf1f2ea92627fd46906e4f0a273a1 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu 
+devices: [0,1] +num_nodes: 1 +sync_batchnorm: True diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f --- /dev/null +++ b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..631701fa235e62b81dc0beed29314946722ce9dd --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,20 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +max_epochs: 710 + +accelerator: gpu +devices: [0, 1, 2, 3] + +# mixed precision for extra speed-up +precision: 16-mixed + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False + +gradient_clip_val: 5.0 diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2389510a90f5f0161cff6ccfcb4a96097ddf9a1 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/data b/data new file mode 120000 index 0000000000000000000000000000000000000000..18e4b1a408c735470895e67a9e5151721c71d271 --- /dev/null +++ b/data @@ -0,0 +1 @@ +/home/smehta/Projects/Speech-Backbones/Grad-TTS/data \ No newline at end of file diff --git 
a/gitattributes:Zone.Identifier b/gitattributes:Zone.Identifier new file mode 100644 index 0000000000000000000000000000000000000000..fb4b5404c39f2dff8a33bb4346afcebe48fd9e65 --- /dev/null +++ b/gitattributes:Zone.Identifier @@ -0,0 +1,4 @@ +[ZoneTransfer] +ZoneId=3 +ReferrerUrl=https://huggingface.co/spaces/the-cramer-project/AkylAI_TTS_small/tree/main +HostUrl=https://huggingface.co/spaces/the-cramer-project/AkylAI_TTS_small/resolve/main/.gitattributes?download=true diff --git a/inference_script.py b/inference_script.py new file mode 100644 index 0000000000000000000000000000000000000000..530ea820ce67ccc66fd13c88d05be2a98e131f10 --- /dev/null +++ b/inference_script.py @@ -0,0 +1,114 @@ +from pathlib import Path +import argparse +import soundfile as sf +import torch +import io +import argparse +from matcha.hifigan.config import v1 +from matcha.hifigan.denoiser import Denoiser +from matcha.hifigan.env import AttrDict +from matcha.hifigan.models import Generator as HiFiGAN +from matcha.models.matcha_tts import MatchaTTS +from matcha.text import sequence_to_text, text_to_sequence +from matcha.utils.utils import intersperse + + + + +def load_matcha( checkpoint_path, device): + model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device) + _ = model.eval() + return model + +def load_hifigan(checkpoint_path, device): + h = AttrDict(v1) + hifigan = HiFiGAN(h).to(device) + hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"]) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + +def load_vocoder(checkpoint_path, device): + vocoder = None + vocoder = load_hifigan(checkpoint_path, device) + denoiser = Denoiser(vocoder, mode="zeros") + return vocoder, denoiser + +def process_text(i: int, text: str, device: torch.device): + print(f"[{i}] - Input text: {text}") + x = torch.tensor( + intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0), + dtype=torch.long, + device=device, + )[None] + x_lengths = 
torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + x_phones = sequence_to_text(x.squeeze(0).tolist()) + print(f"[{i}] - Phonetised text: {x_phones[1::2]}") + return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones} + +def to_waveform(mel, vocoder, denoiser=None): + audio = vocoder(mel).clamp(-1, 1) + if denoiser is not None: + audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze() + return audio.cpu().squeeze() + +@torch.inference_mode() +def process_text_gradio(text): + output = process_text(1, text, device) + return output["x_phones"][1::2], output["x"], output["x_lengths"] + +@torch.inference_mode() +def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1): + spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None + output = model.synthesise( + text, + text_length, + n_timesteps=n_timesteps, + temperature=temperature, + spks=spk, + length_scale=length_scale, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + sf.write('./output/out_audio.wav', output["waveform"], 22050, "PCM_24") + +def get_inference(text, n_timesteps=20, mel_temp = 0.667, length_scale=0.8, spk=-1): + phones, text, text_lengths = process_text_gradio(text) + synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk) + + +def tensor_to_wav_bytes(tensor_audio, sample_rate=22050): # Байтовый формат + waveform = tensor_audio.cpu().numpy() + with io.BytesIO() as buffer: + sf.write(buffer, waveform, sample_rate, format='WAV') + wav_bytes = buffer.getvalue() + return wav_bytes + + + +device = torch.device("cpu") +model_path = './checkpoints/checkpoint.ckpt' +vocoder_path = './checkpoints/generator' +model = load_matcha(model_path, device) +vocoder, denoiser = load_vocoder(vocoder_path, device) + +def main(): + + parser = argparse.ArgumentParser( + description="Если возжелаете параметры которые вам угодны, Сэр))" + ) + parser.add_argument("--text", type=str, 
default=None, help="Text to synthesize") + parser.add_argument( + "--speaking_rate", + type=float, + default=0.8, + help="change the speaking rate, a higher value means slower speaking rate (default: 0.8)", + ) + args = parser.parse_args() + + get_inference(text = args.text, length_scale=args.speaking_rate) + + + + +if __name__ == "__main__": + main() diff --git a/matcha/VERSION b/matcha/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..442b1138f7851df1c22deb15fd5d6ff5b742e550 --- /dev/null +++ b/matcha/VERSION @@ -0,0 +1 @@ +0.0.5.1 diff --git a/matcha/__init__.py b/matcha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/data/__init__.py b/matcha/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/data/components/__init__.py b/matcha/data/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/data/text_mel_datamodule.py b/matcha/data/text_mel_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..704f93629f1874b88efd07609409653ffbb8338a --- /dev/null +++ b/matcha/data/text_mel_datamodule.py @@ -0,0 +1,231 @@ +import random +from typing import Any, Dict, Optional + +import torch +import torchaudio as ta +from lightning import LightningDataModule +from torch.utils.data.dataloader import DataLoader + +from matcha.text import text_to_sequence +from matcha.utils.audio import mel_spectrogram +from matcha.utils.model import fix_len_compatibility, normalize +from matcha.utils.utils import intersperse + + +def parse_filelist(filelist_path, split_char="|"): + with open(filelist_path, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split_char) for line in f] + return filepaths_and_text + + +class 
TextMelDataModule(LightningDataModule): + def __init__( # pylint: disable=unused-argument + self, + name, + train_filelist_path, + valid_filelist_path, + batch_size, + num_workers, + pin_memory, + cleaners, + add_blank, + n_spks, + n_fft, + n_feats, + sample_rate, + hop_length, + win_length, + f_min, + f_max, + data_statistics, + seed, + ): + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be + careful not to execute things like random split twice! + """ + # load and split datasets only if not loaded already + + self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.train_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + ) + self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.valid_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + ) + + def train_dataloader(self): + return DataLoader( + dataset=self.trainset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + 
collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.validset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass # pylint: disable=unnecessary-pass + + def state_dict(self): # pylint: disable=no-self-use + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass # pylint: disable=unnecessary-pass + + +class TextMelDataset(torch.utils.data.Dataset): + def __init__( + self, + filelist_path, + n_spks, + cleaners, + add_blank=True, + n_fft=1024, + n_mels=80, + sample_rate=22050, + hop_length=256, + win_length=1024, + f_min=0.0, + f_max=8000, + data_parameters=None, + seed=None, + ): + self.filepaths_and_text = parse_filelist(filelist_path) + self.n_spks = n_spks + self.cleaners = cleaners + self.add_blank = add_blank + self.n_fft = n_fft + self.n_mels = n_mels + self.sample_rate = sample_rate + self.hop_length = hop_length + self.win_length = win_length + self.f_min = f_min + self.f_max = f_max + if data_parameters is not None: + self.data_parameters = data_parameters + else: + self.data_parameters = {"mel_mean": 0, "mel_std": 1} + random.seed(seed) + random.shuffle(self.filepaths_and_text) + + def get_datapoint(self, filepath_and_text): + if self.n_spks > 1: + filepath, spk, text = ( + filepath_and_text[0], + int(filepath_and_text[1]), + filepath_and_text[2], + ) + else: + filepath, text = filepath_and_text[0], filepath_and_text[1] + spk = None + + text = self.get_text(text, add_blank=self.add_blank) + mel = self.get_mel(filepath) + + return {"x": text, "y": mel, "spk": spk} + + def get_mel(self, filepath): + audio, sr = ta.load(filepath) + assert sr == 
self.sample_rate + mel = mel_spectrogram( + audio, + self.n_fft, + self.n_mels, + self.sample_rate, + self.hop_length, + self.win_length, + self.f_min, + self.f_max, + center=False, + ).squeeze() + mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) + return mel + + def get_text(self, text, add_blank=True): + text_norm = text_to_sequence(text, self.cleaners) + if self.add_blank: + text_norm = intersperse(text_norm, 0) + text_norm = torch.IntTensor(text_norm) + return text_norm + + def __getitem__(self, index): + datapoint = self.get_datapoint(self.filepaths_and_text[index]) + return datapoint + + def __len__(self): + return len(self.filepaths_and_text) + + +class TextMelBatchCollate: + def __init__(self, n_spks): + self.n_spks = n_spks + + def __call__(self, batch): + B = len(batch) + y_max_length = max([item["y"].shape[-1] for item in batch]) + y_max_length = fix_len_compatibility(y_max_length) + x_max_length = max([item["x"].shape[-1] for item in batch]) + n_feats = batch[0]["y"].shape[-2] + + y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) + x = torch.zeros((B, x_max_length), dtype=torch.long) + y_lengths, x_lengths = [], [] + spks = [] + for i, item in enumerate(batch): + y_, x_ = item["y"], item["x"] + y_lengths.append(y_.shape[-1]) + x_lengths.append(x_.shape[-1]) + y[i, :, : y_.shape[-1]] = y_ + x[i, : x_.shape[-1]] = x_ + spks.append(item["spk"]) + + y_lengths = torch.tensor(y_lengths, dtype=torch.long) + x_lengths = torch.tensor(x_lengths, dtype=torch.long) + spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None + + return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths, "spks": spks} diff --git a/matcha/hifigan/LICENSE b/matcha/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..91751daed806f63ac594cf077a3065f719a41662 --- /dev/null +++ b/matcha/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + 
+Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/matcha/hifigan/README.md b/matcha/hifigan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5db25850451a794b1db1b15b08e82c1d802edbb3 --- /dev/null +++ b/matcha/hifigan/README.md @@ -0,0 +1,101 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis + +### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae + +In our [paper](https://arxiv.org/abs/2010.05646), +we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
+We provide our implementation and pretrained models as open source in this repository. + +**Abstract :** +Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. +Although such methods improve the sampling efficiency and memory usage, +their sample quality has not yet reached that of autoregressive and flow-based generative models. +In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. +As speech audio consists of sinusoidal signals with various periods, +we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. +A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method +demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than +real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen +speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times +faster than real-time on CPU with comparable quality to an autoregressive counterpart. + +Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. + +## Pre-requisites + +1. Python >= 3.6 +2. Clone this repository. +3. Install python requirements. Please refer to [requirements.txt](requirements.txt). +4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), + then move all wav files to `LJSpeech-1.1/wavs`. + +## Training + +``` +python train.py --config config_v1.json +``` + +To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
+Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
+You can change the path by adding the `--checkpoint_path` option. + +Validation loss during training with the V1 generator.
+![validation loss](./validation_loss.png) + +## Pretrained Model + +You can also use pretrained models we provide.
+[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
+Details of each folder are as follows: + +| Folder Name | Generator | Dataset | Fine-Tuned | +| ------------ | --------- | --------- | ------------------------------------------------------ | +| LJ_V1 | V1 | LJSpeech | No | +| LJ_V2 | V2 | LJSpeech | No | +| LJ_V3 | V3 | LJSpeech | No | +| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| VCTK_V1 | V1 | VCTK | No | +| VCTK_V2 | V2 | VCTK | No | +| VCTK_V3 | V3 | VCTK | No | +| UNIVERSAL_V1 | V1 | Universal | No | + +We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. + +## Fine-Tuning + +1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
+ Example: + ` Audio File : LJ001-0001.wav +Mel-Spectrogram File : LJ001-0001.npy` +2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
+3. Run the following command. + ``` + python train.py --fine_tuning True --config config_v1.json + ``` + For other command line options, please refer to the training section. + +## Inference from wav file + +1. Make a `test_files` directory and copy wav files into the directory. +2. Run the following command. + ` python inference.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files` by default.
+ You can change the path by adding the `--output_dir` option. + +## Inference for end-to-end speech synthesis + +1. Make a `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), + [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. +2. Run the following command. + ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files_from_mel` by default.
+ You can change the path by adding `--output_dir` option. + +## Acknowledgements + +We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) +and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/matcha/hifigan/__init__.py b/matcha/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/hifigan/config.py b/matcha/hifigan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b3abea9e151a08864353d32066bd4935e24b82e7 --- /dev/null +++ b/matcha/hifigan/config.py @@ -0,0 +1,28 @@ +v1 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0004, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 256, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, +} diff --git a/matcha/hifigan/denoiser.py b/matcha/hifigan/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd33312a09b1940374a0e29a97fe3a1a1dac7d2 --- /dev/null +++ b/matcha/hifigan/denoiser.py @@ -0,0 +1,64 @@ +# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py + +"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" +import torch + + +class Denoiser(torch.nn.Module): + """Removes model bias from audio 
produced with waveglow""" + + def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): + super().__init__() + self.filter_length = filter_length + self.hop_length = int(filter_length / n_overlap) + self.win_length = win_length + + dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device + self.device = device + if mode == "zeros": + mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) + elif mode == "normal": + mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) + else: + raise Exception(f"Mode {mode} if not supported") + + def stft_fn(audio, n_fft, hop_length, win_length, window): + spec = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + spec = torch.view_as_real(spec) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) + + self.stft = lambda x: stft_fn( + audio=x, + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + self.istft = lambda x, y: torch.istft( + torch.complex(x * torch.cos(y), x * torch.sin(y)), + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + + with torch.no_grad(): + bias_audio = vocoder(mel_input).float().squeeze(0) + bias_spec, _ = self.stft(bias_audio) + + self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) + + @torch.inference_mode() + def forward(self, audio, strength=0.0005): + audio_spec, audio_angles = self.stft(audio) + audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.istft(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/matcha/hifigan/env.py b/matcha/hifigan/env.py new file mode 100644 
index 0000000000000000000000000000000000000000..9ea4f948a3f002921bf9bc24f52cbc1c0b1fc2ec --- /dev/null +++ b/matcha/hifigan/env.py @@ -0,0 +1,17 @@ +""" from https://github.com/jik876/hifi-gan """ + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/matcha/hifigan/meldataset.py b/matcha/hifigan/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b43ea7965e04a52d5427a485ee911b743057c4a --- /dev/null +++ b/matcha/hifigan/meldataset.py @@ -0,0 +1,217 @@ +""" from https://github.com/jik876/hifi-gan """ + +import math +import os +import random + +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from librosa.util import normalize +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < 
-1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + + with open(a.input_validation_file, encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + 
self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: + mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] + 
else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/matcha/hifigan/models.py b/matcha/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d209d9a4e99ec29e4167a5a2eaa62d72b3eff694 --- /dev/null +++ b/matcha/hifigan/models.py @@ -0,0 +1,368 @@ +""" from https://github.com/jik876/hifi-gan """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .xutils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + 
padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = 
h.upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // 
self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def 
forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/matcha/hifigan/xutils.py b/matcha/hifigan/xutils.py new file mode 100644 index 0000000000000000000000000000000000000000..eefadcb7a1d0bf9015e636b88fee3e22c9771bc5 --- /dev/null +++ b/matcha/hifigan/xutils.py @@ -0,0 +1,60 @@ +""" from https://github.com/jik876/hifi-gan """ + +import glob +import os + +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def 
apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/matcha/models/__init__.py b/matcha/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/models/baselightningmodule.py b/matcha/models/baselightningmodule.py new file mode 100644 index 0000000000000000000000000000000000000000..3724888090e36b5f55445d33a87fcdae687b35a5 --- /dev/null +++ b/matcha/models/baselightningmodule.py @@ -0,0 +1,209 @@ +""" +This is a base lightning module that can be used to train a model. +The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. 
+""" +import inspect +from abc import ABC +from typing import Any, Dict + +import torch +from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm + +from matcha import utils +from matcha.utils.utils import plot_tensor + +log = utils.get_pylogger(__name__) + + +class BaseLightningClass(LightningModule, ABC): + def update_data_statistics(self, data_statistics): + if data_statistics is None: + data_statistics = { + "mel_mean": 0.0, + "mel_std": 1.0, + } + + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + + def configure_optimizers(self) -> Any: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler not in (None, {}): + scheduler_args = {} + # Manage last epoch for exponential schedulers + if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: + if hasattr(self, "ckpt_loaded_epoch"): + current_epoch = self.ckpt_loaded_epoch - 1 + else: + current_epoch = -1 + + scheduler_args.update({"optimizer": optimizer}) + scheduler = self.hparams.scheduler.scheduler(**scheduler_args) + scheduler.last_epoch = current_epoch + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": self.hparams.scheduler.lightning_args.interval, + "frequency": self.hparams.scheduler.lightning_args.frequency, + "name": "learning_rate", + }, + } + + return {"optimizer": optimizer} + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + + dur_loss, prior_loss, diff_loss = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + out_size=self.out_size, + ) + return { + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + } + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.ckpt_loaded_epoch = 
checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init + + def training_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "step", + float(self.global_step), + on_step=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + self.log( + "sub_loss/train_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/train", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return {"loss": total_loss, "log": loss_dict} + + def validation_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "sub_loss/val_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/val", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return total_loss + + def on_validation_end(self) -> None: + if self.trainer.is_global_zero: + one_batch = next(iter(self.trainer.val_dataloaders)) + if self.current_epoch == 0: + log.debug("Plotting original samples") + for i in range(2): + y = one_batch["y"][i].unsqueeze(0).to(self.device) + self.logger.experiment.add_image( + f"original/{i}", + plot_tensor(y.squeeze().cpu()), + self.current_epoch, + 
dataformats="HWC", + ) + + log.debug("Synthesising...") + for i in range(2): + x = one_batch["x"][i].unsqueeze(0).to(self.device) + x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) + spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None + output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) + y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] + attn = output["attn"] + self.logger.experiment.add_image( + f"generated_enc/{i}", + plot_tensor(y_enc.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"generated_dec/{i}", + plot_tensor(y_dec.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"alignment/{i}", + plot_tensor(attn.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + def on_before_optimizer_step(self, optimizer): + self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) diff --git a/matcha/models/components/__init__.py b/matcha/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/models/components/decoder.py b/matcha/models/components/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1137cd7008e9d07b4f306926a82e44c2b2cddbdf --- /dev/null +++ b/matcha/models/components/decoder.py @@ -0,0 +1,443 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat + +from matcha.models.components.transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, 
"SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + 
else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + 
conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, 
downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block 
def initialize_weights(self):
    """Re-initialise every Conv1d, Linear and GroupNorm found in ``self.modules()``.

    Conv1d and Linear weights get Kaiming-normal initialisation (ReLU gain) and
    a zero bias when a bias exists; GroupNorm is reset to weight 1 / bias 0.
    All other module types are left as they are.
    """
    for module in self.modules():
        if isinstance(module, nn.GroupNorm):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
            continue

        if not isinstance(module, (nn.Conv1d, nn.Linear)):
            continue

        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
class BASECFM(torch.nn.Module, ABC):
    """Base class for conditional flow matching with a pluggable vector-field estimator.

    Subclasses must set ``self.estimator`` to a callable invoked as
    ``estimator(x, mask, mu, t, spks, cond)`` that returns a tensor shaped like ``x``.
    """

    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        """Store the configuration.

        Args:
            n_feats: number of feature channels.
            cfm_params: object exposing ``solver`` and, optionally, ``sigma_min``.
            n_spks: number of speakers.
            spk_emb_dim: speaker embedding width.
        """
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        # Fall back to 1e-4 when the config does not define sigma_min.
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Sample by integrating the estimated vector field from t=0 to t=1.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler steps
            temperature (float, optional): scale of the initial noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker conditioning, passed through to
                the estimator. Defaults to None.
            cond: passed through to the estimator; kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """Fixed-step Euler solver for the flow ODE.

        Args:
            x (torch.Tensor): initial state (random noise)
            t_span (torch.Tensor): integration grid
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker conditioning for the estimator
            cond: passed through to the estimator

        Returns:
            The state after the last step. A grid with fewer than two points
            describes no step at all, so ``x`` is returned unchanged (the
            original raised ``IndexError`` in that case).
        """
        n_points = len(t_span)
        if n_points < 2:
            return x

        t = t_span[0]
        dt = t_span[1] - t_span[0]
        last_step = n_points - 1

        # Only the running state is kept: the original appended every
        # intermediate tensor to a list and returned its last element.
        for step in range(1, n_points):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            if step < last_step:
                dt = t_span[step + 1] - t

        return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Compute the conditional flow matching loss.

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker conditioning. Defaults to None.
            cond: passed through to the estimator, as ``solve_euler`` does.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b = mu.shape[0]

        # random timestep, one per batch element
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # reshape(b) keeps the time tensor 1-D even for batch size 1, where a
        # bare squeeze() would collapse it to a 0-dim scalar.
        pred = self.estimator(y, mask, mu, t.reshape(b), spks, cond)
        loss = F.mse_loss(pred, u, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
class ConvReluNorm(nn.Module):
    """Residual stack of (Conv1d -> LayerNorm -> ReLU -> Dropout) layers.

    The stack output goes through a zero-initialised 1x1 projection and is added
    to the input, so a freshly constructed module returns ``x * x_mask``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        same_padding = kernel_size // 2
        # The first layer maps in -> hidden, the remaining n_layers - 1 map hidden -> hidden.
        layer_inputs = [in_channels] + [hidden_channels] * (n_layers - 1)
        for width_in in layer_inputs:
            self.conv_layers.append(
                torch.nn.Conv1d(width_in, hidden_channels, kernel_size, padding=same_padding)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))

        # Zero init makes the residual branch a no-op at the start of training.
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the stack to ``x`` with ``x_mask`` multiplied in before every conv and at the end."""
        residual = x
        hidden = x
        for layer_idx in range(self.n_layers):
            hidden = self.conv_layers[layer_idx](hidden * x_mask)
            hidden = self.relu_drop(self.norm_layers[layer_idx](hidden))
        return (residual + self.proj(hidden)) * x_mask
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ that get rotated (truncated to an int)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for an input laid out as `[seq_len, batch, heads, d]`.

        The cache is rebuilt when it is missing, too short for `x`, or lives on a
        different device than `x` (the original kept a stale cache after the
        input moved to another device).
        """
        seq_len = x.shape[0]
        cached = self.cos_cached
        if cached is not None and seq_len <= cached.shape[0] and cached.device == x.device:
            return

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Row $m$ holds $[m \theta_0, ..., m \theta_{d/2}, m \theta_0, ..., m \theta_{d/2}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        r"""Return $[-x^{(d/2 + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(d/2)}]$ along the last axis."""
        d_2 = self.d // 2
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is a key or query tensor with shape `[batch_size, n_heads, seq_len, d_head]`

        Returns a tensor of the same shape whose first `self.d` features are
        rotated; any remaining features are passed through unchanged.
        """
        # Work in `[seq_len, batch, heads, d]` so the cache broadcasts over batch and heads.
        x = x.permute(2, 0, 1, 3)

        self._build_cache(x)

        # Rotary embeddings may be applied to only a prefix of the features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        neg_half_x = self._neg_half(x_rope)
        seq_len = x.shape[0]
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        # Back to `[batch, heads, seq_len, d]`.
        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)
def forward(self, x, c, attn_mask=None):
    """Attend from ``x`` (queries) to ``c`` (keys and values).

    The attention weights returned by ``self.attention`` are stored on
    ``self.attn``; the attended features go through the output projection.
    """
    query = self.conv_q(x)
    key, value = self.conv_k(c), self.conv_v(c)

    attended, self.attn = self.attention(query, key, value, mask=attn_mask)

    return self.conv_o(attended)
class FFN(nn.Module):
    """Two-layer convolutional feed-forward block with masking around every conv."""

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        same_padding = kernel_size // 2
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_padding)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=same_padding)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Return ``conv_2(drop(relu(conv_1(x * m))) * m) * m`` with ``m = x_mask``."""
        hidden = self.drop(torch.relu(self.conv_1(x * x_mask)))
        return self.conv_2(hidden * x_mask) * x_mask
def forward(self, x, x_mask):
    """Run ``self.n_layers`` post-norm (attention -> FFN) layers over ``x``.

    ``x_mask`` is multiplied into ``x`` before every layer and once more at the
    end; the pairwise attention mask is the outer product of ``x_mask`` with itself.
    """
    pair_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
    for layer_idx in range(self.n_layers):
        x = x * x_mask

        attended = self.drop(self.attn_layers[layer_idx](x, x, pair_mask))
        x = self.norm_layers_1[layer_idx](x + attended)

        transformed = self.drop(self.ffn_layers[layer_idx](x, x_mask))
        x = self.norm_layers_2[layer_idx](x + transformed)
    return x * x_mask
def forward(self, x, x_lengths, spks=None):
    """Run forward pass to the transformer based encoder and duration predictor

    Args:
        x (torch.Tensor): text input
            shape: (batch_size, max_text_length)
        x_lengths (torch.Tensor): text input lengths
            shape: (batch_size,)
        spks (torch.Tensor, optional): speaker embeddings (not ids): they are
            broadcast along time and concatenated to the channels, so the
            encoder input width is ``n_channels + spk_emb_dim``.
            Required when ``n_spks > 1``; ignored otherwise.
            shape: (batch_size, spk_emb_dim)

    Raises:
        ValueError: if the model is multi-speaker but ``spks`` is None.

    Returns:
        mu (torch.Tensor): average output of the encoder
            shape: (batch_size, n_feats, max_text_length)
        logw (torch.Tensor): log duration predicted by the duration predictor
            shape: (batch_size, 1, max_text_length)
        x_mask (torch.Tensor): mask for the text input
            shape: (batch_size, 1, max_text_length)
    """
    multi_speaker = self.n_spks > 1
    if multi_speaker and spks is None:
        # Fail fast with a readable message instead of an AttributeError on None.
        raise ValueError("`spks` must be provided when the encoder is built with n_spks > 1")

    x = self.emb(x) * math.sqrt(self.n_channels)
    x = torch.transpose(x, 1, -1)
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

    x = self.prenet(x, x_mask)
    if multi_speaker:
        # expand() broadcasts over time without the intermediate copy repeat() makes;
        # cat() materialises the result either way.
        x = torch.cat([x, spks.unsqueeze(-1).expand(-1, -1, x.shape[-1])], dim=1)
    x = self.encoder(x, x_mask)
    mu = self.proj_m(x) * x_mask

    # The duration predictor is trained on detached encoder states.
    x_dp = torch.detach(x)
    logw = self.proj_w(x_dp, x_mask)

    return mu, logw, x_mask
def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
    """
    Initialization.
    INPUT:
        - in_features: input width of the linear projection
        - out_features: output width of the projection; alpha and beta hold
          one value per output feature
        - alpha: scale multiplied into the initial alpha and beta values
        - alpha_trainable: whether alpha and beta receive gradients
        - alpha_logscale: if True alpha/beta start from zeros (and are
          exponentiated in the forward pass), otherwise they start from ones
    """
    super().__init__()
    feature_shape = out_features if isinstance(out_features, list) else [out_features]
    self.in_features = feature_shape
    self.proj = LoRACompatibleLinear(in_features, out_features)

    # initialize alpha and beta: zeros in log scale, ones in linear scale
    self.alpha_logscale = alpha_logscale
    make_initial = torch.zeros if self.alpha_logscale else torch.ones
    self.alpha = nn.Parameter(make_initial(self.in_features) * alpha)
    self.beta = nn.Parameter(make_initial(self.in_features) * alpha)

    for parameter in (self.alpha, self.beta):
        parameter.requires_grad = alpha_trainable

    self.no_div_by_zero = 0.000000001
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"snakebeta"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: if `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # A single elif chain with an explicit failure branch: previously an
        # unknown name fell through every test and crashed later with an
        # UnboundLocalError on `act_fn`.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            raise ValueError(
                f"Unknown activation_fn {activation_fn!r}; expected one of 'gelu', 'gelu-approximate', "
                "'geglu', 'geglu-approximate', 'snakebeta'"
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply the sub-modules of `self.net` in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. 
def forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Apply self-attention, optional second attention and feed-forward, each as a pre-norm residual.

    Args:
        hidden_states: input features; the output has the same shape.
        attention_mask: mask given to the first attention layer, unless the
            block was built with `only_cross_attention`.
        encoder_hidden_states: context for the second attention layer (and for
            the first one when `only_cross_attention` is set).
        encoder_attention_mask: mask used together with `encoder_hidden_states`.
        timestep: passed to the adaptive norm layers; unused with plain LayerNorm.
        cross_attention_kwargs: extra keyword arguments forwarded to both
            attention layers; `None` is treated as an empty dict.
        class_labels: only passed to the "ada_norm_zero" norm layer.

    Raises:
        ValueError: if feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # Notice that normalization is always applied before the real computation in the following blocks.
    # 1. Self-Attention
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        # This branch is the only place gate_msa / shift_mlp / scale_mlp / gate_mlp
        # are bound; every later use is guarded by the same flag.
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

    attn_output = self.attn1(
        norm_hidden_states,
        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
        attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
        **cross_attention_kwargs,
    )
    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2. Cross-Attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )

        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 3. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)

    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )

        # Run the feed-forward on equal slices along the chunk dimension and stitch them back together.
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output

    hidden_states = ff_output + hidden_states

    return hidden_states
out_channel=encoder.encoder_params.n_feats, + cfm_params=cfm, + decoder_params=decoder, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + self.update_data_statistics(data_statistics) + + @torch.inference_mode() + def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): + """ + Generates mel-spectrogram from text. Returns: + 1. encoder outputs + 2. decoder outputs + 3. generated alignment + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + n_timesteps (int): number of steps to use for reverse diffusion in decoder. + temperature (float, optional): controls variance of terminal distribution. + spks (bool, optional): speaker ids. + shape: (batch_size,) + length_scale (float, optional): controls speech pace. + Increase value to slow down generated speech and vice versa. + + Returns: + dict: { + "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Average mel spectrogram generated by the encoder + "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Refined mel spectrogram improved by the CFM + "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), + # Alignment map between text and mel spectrogram + "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Denormalized mel spectrogram + "mel_lengths": torch.Tensor, shape: (batch_size,), + # Lengths of mel spectrograms + "rtf": float, + # Real-time factor + """ + # For RTF computation + t = dt.datetime.now() + + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks.long()) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = 
torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = y_lengths.max() + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Generate sample tracing the probability flow + decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) + decoder_outputs = decoder_outputs[:, :, :y_max_length] + + t = (dt.datetime.now() - t).total_seconds() + rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) + + return { + "encoder_outputs": encoder_outputs, + "decoder_outputs": decoder_outputs, + "attn": attn[:, :, :y_max_length], + "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), + "mel_lengths": y_lengths, + "rtf": rtf, + } + + def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None): + """ + Computes 3 losses: + 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). + 2. prior loss: loss between mel-spectrogram and encoder outputs. + 3. flow matching loss: loss between mel-spectrogram and decoder outputs. + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + y (torch.Tensor): batch of corresponding mel-spectrograms. + shape: (batch_size, n_feats, max_mel_length) + y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. 
+ shape: (batch_size,) + out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. + Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. + spks (torch.Tensor, optional): speaker ids. + shape: (batch_size,) + """ + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + y_max_length = y.shape[-1] + + y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + + # Use MAS to find most likely alignment `attn` between text and mel-spectrogram + with torch.no_grad(): + const = -0.5 * math.log(2 * math.pi) * self.n_feats + factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) + y_square = torch.matmul(factor.transpose(1, 2), y**2) + y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) + mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) + log_prior = y_square - y_mu_double + mu_square + const + + attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) + attn = attn.detach() + + # Compute loss between predicted log-scaled durations and those obtained from MAS + # refered to as prior loss in the paper + logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask + dur_loss = duration_loss(logw, logw_, x_lengths) + + # Cut a small segment of mel-spectrogram in order to increase batch size + # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it + # - Do not need this hack for Matcha-TTS, but it works with it as well + if not isinstance(out_size, type(None)): + max_offset = (y_lengths - out_size).clamp(0) + offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) + out_offset = torch.LongTensor( + 
[torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges] + ).to(y_lengths) + attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) + y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) + + y_cut_lengths = [] + for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): + y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) + y_cut_lengths.append(y_cut_length) + cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length + y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] + attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] + + y_cut_lengths = torch.LongTensor(y_cut_lengths) + y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) + + attn = attn_cut + y = y_cut + y_mask = y_cut_mask + + # Align encoded text with mel-spectrogram and get mu_y segment + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + # Compute loss of the decoder + diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) + + if self.prior_loss: + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + else: + prior_loss = 0 + + return dur_loss, prior_loss, diff_loss diff --git a/matcha/onnx/__init__.py b/matcha/onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/matcha/onnx/export.py b/matcha/onnx/export.py new file mode 100644 index 0000000000000000000000000000000000000000..9b795086158e1ad8a4bb5cd92306f3fa765f71ea --- /dev/null +++ b/matcha/onnx/export.py @@ -0,0 +1,181 @@ +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +from lightning import LightningModule + +from matcha.cli import VOCODER_URLS, load_matcha, 
def get_inputs(is_multi_speaker):
    """
    Build the dummy tensors used to trace the model, with their input names.

    Returns a ``(inputs, names)`` pair: ``inputs`` is a tuple of tensors and
    ``names`` the matching list of input names. A ``spks`` entry is appended
    only when ``is_multi_speaker`` is truthy.
    """
    seq_len = 50

    # Insertion order defines the positional order of the traced inputs.
    named_inputs = {
        "x": torch.randint(low=0, high=20, size=(1, seq_len), dtype=torch.long),
        "x_lengths": torch.LongTensor([seq_len]),
        # [temperature, length_scale]
        "scales": torch.Tensor([0.667, 1.0]),
    }

    if is_multi_speaker:
        named_inputs["spks"] = torch.LongTensor([1])

    return tuple(named_inputs.values()), list(named_inputs)
def validate_args(args):
    """Check the parsed CLI arguments and return them unchanged.

    Args:
        args: namespace exposing ``text``, ``file``, ``temperature`` and
            ``speaking_rate``.

    Returns:
        The same ``args`` object.

    Raises:
        AssertionError: if neither ``text`` nor ``file`` is set, if
            ``temperature`` is negative, or if ``speaking_rate`` is not
            strictly positive.
    """
    # Raised explicitly (instead of `assert`) so the checks still run under
    # `python -O`, while keeping the exception type callers already see.
    if not (args.text or args.file):
        raise AssertionError(
            "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
        )
    if not args.temperature >= 0:
        raise AssertionError("Sampling temperature cannot be negative")
    # Strictly positive: a rate of 0 was previously accepted, contradicting the message.
    if not args.speaking_rate > 0:
        raise AssertionError("Speaking rate must be greater than 0")
    return args
def main():
    """CLI entry point: run an exported Matcha-TTS ONNX model on text.

    Reads the text from ``--text`` or ``--file`` (one utterance per line),
    converts each line with ``process_text``, pads the batch and runs the
    ONNX session. Waveforms are written when the graph embeds a vocoder or an
    external vocoder model is given; otherwise the mel-spectrograms are
    written.
    """
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    # The flag opts in to the GPU; the previous help text stated the opposite.
    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: use CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )

    args = parser.parse_args()
    args = validate_args(args)

    if args.gpu:
        # ONNX Runtime has no "GPUExecutionProvider"; the CUDA provider is the
        # GPU backend. CPU is listed second as a fallback.
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)

    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())

    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()

    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
    x = [line["x"].squeeze() for line in processed_lines]
    # Pad
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }
    # A fourth graph input is taken to be the speaker id.
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)
def _clean_text(text, cleaner_names):
    """Run ``text`` through each named cleaner, in order.

    Args:
        text: string to clean.
        cleaner_names: iterable of attribute names looked up on the
            ``cleaners`` module.

    Returns:
        The text after all cleaners have been applied.

    Raises:
        Exception: if a name does not resolve to a cleaner.
    """
    for name in cleaner_names:
        # Default to None so an unknown name reaches the explicit error below
        # instead of escaping as a bare AttributeError from getattr().
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception("Unknown cleaner: %s" % name)
        text = cleaner(text)
    return text
+""" + +import logging +import re + +import phonemizer +import piper_phonemize +from unidecode import unidecode + +# To avoid excessive logging we set the log level of the phonemizer package to Critical +critical_logger = logging.getLogger("phonemizer") +critical_logger.setLevel(logging.CRITICAL) + +# Intializing the phonemizer globally significantly reduces the speed +# now the phonemizer is not initialising at every call +# Might be less flexible, but it is much-much faster +global_phonemizer = phonemizer.backend.EspeakBackend( + language="ky", + preserve_punctuation=True, + with_stress=True, + language_switch="remove-flags", + logger=critical_logger, +) + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = 
def kyrgyz_cleaners(text):
    """Lowercase the text, phonemize it with the module-level phonemizer, and collapse whitespace.

    No ASCII transliteration or abbreviation expansion is applied.
    """
    lowered = lowercase(text)
    # The backend works on a batch; send a single-item list and take its only result.
    phonemized = global_phonemizer.phonemize([lowered], strip=True, njobs=1)[0]
    return collapse_whitespace(phonemized)
def _expand_number(m):
    """Spell out the integer matched by ``m`` as words.

    Values strictly between 1000 and 3000 get special treatment (presumably so
    they read like years): 2000 and 2001-2009 use "two thousand ...", exact
    hundreds use "<n> hundred", and everything else is read in two-digit
    groups. All other values use the plain cardinal form without "and".
    """
    value = int(m.group(0))

    if not 1000 < value < 3000:
        return _inflect.number_to_words(value, andword="")

    if value == 2000:
        return "two thousand"
    if 2000 < value < 2010:
        return "two thousand " + _inflect.number_to_words(value % 100)

    hundreds, remainder = divmod(value, 100)
    if remainder == 0:
        return _inflect.number_to_words(hundreds) + " hundred"

    grouped = _inflect.number_to_words(value, andword="", zero="oh", group=2)
    return grouped.replace(", ", " ")
+""" +_pad = "_" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = ( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑-1['̩'ᵻ" +) + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") + diff --git a/matcha/utils/__init__.py b/matcha/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..074db6461184e8cbb86d977cb41d9ebd918e958a --- /dev/null +++ b/matcha/utils/__init__.py @@ -0,0 +1,5 @@ +from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers +from matcha.utils.logging_utils import log_hyperparameters +from matcha.utils.pylogger import get_pylogger +from matcha.utils.rich_utils import enforce_tags, print_config_tree +from matcha.utils.utils import extras, get_metric_value, task_wrapper diff --git a/matcha/utils/audio.py b/matcha/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcd74df47fb006f68deb5a5f4a4c2fb0aa84f57 --- /dev/null +++ b/matcha/utils/audio.py @@ -0,0 +1,82 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def 
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a normalized mel-spectrogram of ``y``.

    The mel filterbank and the Hann window are cached in the module-level
    ``mel_basis`` and ``hann_window`` dicts so they are built once per
    parameter set and device.

    Args:
        y: waveform tensor; values outside [-1, 1] are reported via ``print``.
            Assumed to be (batch, samples) - TODO confirm against callers.
        n_fft: FFT size.
        num_mels: number of mel bands.
        sampling_rate: sample rate used to build the mel filterbank.
        hop_size: STFT hop length.
        win_size: STFT window length.
        fmin, fmax: frequency range of the mel filterbank.
        center: forwarded to ``torch.stft``.

    Returns:
        The output of ``spectral_normalize_torch`` applied to the mel-projected
        magnitude spectrogram.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    device_name = str(y.device)
    # Key each cache on every parameter that shapes its entry. Keying the basis
    # on (fmax, device) only, and creating the window only on a basis miss,
    # reused stale entries when n_fft / num_mels / sampling_rate / fmin /
    # win_size changed between calls.
    basis_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device_name)
    window_key = (win_size, device_name)
    if basis_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).float().to(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    # Reflect-pad manually by (n_fft - hop_size) / 2 on each side.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude from the (real, imag) pair; the epsilon keeps sqrt away from 0.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
+ +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from tqdm.auto import tqdm + +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.utils.logging_utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): + """Generate data mean and standard deviation helpful in data normalisation + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + """ + total_mel_sum = 0 + total_mel_sq_sum = 0 + total_mel_len = 0 + + for batch in tqdm(data_loader, leave=False): + mels = batch["y"] + mel_lengths = batch["y_lengths"] + + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + + return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="vctk.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="256", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + args = parser.parse_args() + output_file = Path(args.input_config).with_suffix(".json") + + if os.path.exists(output_file) and not args.force: + print("File already exists. 
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build the callbacks described by a Hydra config group.

    Only entries that are themselves a ``DictConfig`` and carry a ``_target_``
    key are instantiated; everything else is silently ignored.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks (empty when the config is falsy).
    :raises TypeError: If a non-empty config is not a ``DictConfig``.
    """
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    instantiated: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating callback <{entry._target_}>")  # pylint: disable=protected-access
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build the loggers described by a Hydra config group.

    Only entries that are themselves a ``DictConfig`` and carry a ``_target_``
    key are instantiated; everything else is silently ignored.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers (empty when the config is falsy).
    :raises TypeError: If a non-empty config is not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    instantiated: List[Logger] = []
    for _name, entry in logger_cfg.items():
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating logger <{entry._target_}>")  # pylint: disable=protected-access
        instantiated.append(hydra.utils.instantiate(entry))

    return instantiated
def sequence_mask(length, max_length=None):
    """Return a boolean mask that is True at positions ``< length`` per row.

    Args:
        length: 1-D tensor of sequence lengths.
        max_length: Width of the mask; defaults to ``length.max()``.

    Returns:
        Boolean tensor of shape ``(len(length), max_length)``.
    """
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the next multiple of ``2 ** num_downsamplings_in_unet``.

    Returns a Python int in normal execution; during ONNX export the tensor is
    returned untouched so the operation stays in the exported graph.
    """
    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
    length = (length / factor).ceil() * factor
    if not torch.onnx.is_in_onnx_export():
        return length.int().item()
    else:
        return length


def convert_pad_shape(pad_shape):
    """Flatten a per-dimension ``[[before, after], ...]`` list into the
    last-dimension-first flat list that ``torch.nn.functional.pad`` expects."""
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    """Expand per-token durations into a monotonic alignment path.

    Args:
        duration: Tensor of shape ``(b, t_x)`` with the number of output frames
            assigned to each input position.
        mask: Tensor of shape ``(b, t_x, t_y)``; its dtype is the dtype of the
            result and it is multiplied into the path at the end.

    Returns:
        Tensor of shape ``(b, t_x, t_y)`` in which entry ``[b, x, y]`` is 1 when
        frame ``y`` falls inside the duration span of position ``x``.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtract the mask of the previous position (shifted down by one along t_x)
    # so that only the frames belonging to this position remain set.
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    """Sum of squared differences between ``logw`` and ``logw_``, divided by ``sum(lengths)``."""
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


def _prepare_statistic(stat, data):
    """Make a mean/std value usable in arithmetic with ``data``.

    Python scalars (``float`` or ``int``) are returned unchanged. Lists, tensors
    and NumPy arrays are moved to ``data``'s device as tensors and get a
    trailing singleton axis appended.
    """
    if isinstance(stat, (float, int)):
        return stat
    if isinstance(stat, list):
        stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
    elif isinstance(stat, torch.Tensor):
        stat = stat.to(data.device)
    elif isinstance(stat, np.ndarray):
        stat = torch.from_numpy(stat).to(data.device)
    else:
        raise TypeError(f"Unsupported statistic type: {type(stat).__name__}")
    return stat.unsqueeze(-1)


def normalize(data, mu, std):
    """Return ``(data - mu) / std``.

    ``mu`` and ``std`` may each be a Python scalar, a list, a tensor or a NumPy
    array; non-scalar values get a trailing singleton axis before broadcasting.
    """
    mu = _prepare_statistic(mu, data)
    std = _prepare_statistic(std, data)
    return (data - mu) / std


def denormalize(data, mu, std):
    """Return ``data * std + mu``, the inverse of :func:`normalize`.

    Accepts the same ``mu`` / ``std`` types as :func:`normalize`, including
    plain ``int`` values.
    """
    mu = _prepare_statistic(mu, data)
    std = _prepare_statistic(std, data)
    return data * std + mu
def maximum_path(value, mask):
    """Run the Cython monotonic-alignment search on a batch.

    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]

    The masked scores are copied to the CPU as float32 and handed to
    ``maximum_path_c`` together with the per-item extents, which are read from
    the first column / first row of the mask sums. The resulting 0/1 path is
    returned on the device and in the dtype of ``value * mask``.
    """
    masked = value * mask
    out_device, out_dtype = masked.device, masked.dtype

    value_np = masked.data.cpu().numpy().astype(np.float32)
    mask_np = mask.data.cpu().numpy()
    # Filled in place by the Cython routine.
    path_np = np.zeros_like(value_np).astype(np.int32)

    t_x_max = mask_np.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask_np.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path_np, value_np, t_x_max, t_y_max)

    return torch.from_numpy(path_np).to(device=out_device, dtype=out_dtype)
def get_pylogger(name: str = __name__) -> logging.Logger:
    """Create a command line logger whose methods are wrapped with ``rank_zero_only``.

    Each logging method is replaced on the logger instance by its
    ``rank_zero_only``-decorated version; otherwise every GPU process in a
    multi-GPU setup would emit its own copy of each message.

    :param name: The name of the logger, defaults to ``__name__``.

    :return: A logger object.
    """
    wrapped_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")

    logger = logging.getLogger(name)
    for method_name in wrapped_levels:
        original_method = getattr(logger, method_name)
        setattr(logger, method_name, rank_zero_only(original_method))

    return logger
+ """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + _ = ( + queue.append(field) + if field in cfg + else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. 
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    Each utility runs only when the matching flag under ``cfg.extras`` is truthy.

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    The returned function keeps the name, docstring and signature metadata of
    ``task_func`` (via ``functools.wraps``).

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """
    from functools import wraps  # pylint: disable=import-outside-toplevel

    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb  # pylint: disable=import-outside-toplevel

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after every element of ``lst``.

    Used to add the blank symbol around each entry; the result has
    ``2 * len(lst) + 1`` elements.
    """
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def save_figure_to_numpy(fig):
    """Convert an already-drawn figure into a ``(height, width, 3)`` uint8 RGB array.

    ``np.frombuffer`` replaces the binary mode of ``np.fromstring``, which NumPy
    has deprecated and later removed. When the canvas no longer offers
    ``tostring_rgb`` (dropped by recent Matplotlib releases), the RGB channels
    are taken from ``buffer_rgba()`` instead.

    Args:
        fig: Object whose ``canvas`` provides ``get_width_height()`` and either
            ``tostring_rgb()`` or ``buffer_rgba()``.

    Returns:
        numpy.ndarray: A writable uint8 array of shape ``(height, width, 3)``.
    """
    canvas = fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(canvas.get_width_height()[::-1] + (3,))
    else:
        # Drop the alpha channel of the RGBA buffer.
        data = np.asarray(canvas.buffer_rgba())[..., :3]
    # frombuffer/asarray return views on the canvas data; copy so that the
    # result is an independent, writable array as it was with fromstring.
    return np.array(data, dtype=np.uint8)


def plot_tensor(tensor):
    """Render a 2-D array as an image plot with a colour bar and return it as an RGB array."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def save_plot(tensor, savepath):
    """Render a 2-D array as an image plot with a colour bar and save it to ``savepath``."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    plt.savefig(savepath)
    plt.close()


def to_numpy(tensor):
    """Convert a NumPy array, torch tensor or list to a NumPy array.

    Arrays are returned as-is, tensors are detached and moved to the CPU first.

    Raises:
        TypeError: For any other input type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, list):
        return np.array(tensor)
    else:
        raise TypeError("Unsupported type for conversion to numpy array")
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Download a checkpoint to ``checkpoint_path`` unless a file already exists there.

    The status message is both logged and printed. When the file is missing it
    is fetched from ``url`` with ``wget`` by default, or with ``gdown`` when
    ``use_wget`` is false.

    Args:
        checkpoint_path: Destination path of the checkpoint (str or Path).
        url: Location to download the checkpoint from.
        use_wget (bool): Select ``wget`` (True) or ``gdown`` (False).
    """
    already_there = Path(checkpoint_path).exists()
    if already_there:
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        print(f"[+] Model already present at {checkpoint_path}!")
        return

    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    print(f"[-] Model not found at {checkpoint_path}! Will download it")

    # Both download helpers are given a plain string destination.
    destination = str(checkpoint_path)
    if use_wget:
        wget.download(url=url, out=destination)
    else:
        gdown.download(url=url, output=destination, quiet=False, fuzzy=True)
#!/usr/bin/env python
import os

import numpy
from Cython.Build import cythonize
from setuptools import Extension, find_packages, setup

# Cython extension implementing the monotonic alignment search.
exts = [
    Extension(
        name="matcha.utils.monotonic_align.core",
        sources=["matcha/utils/monotonic_align/core.pyx"],
    )
]

with open("README.md", encoding="utf-8") as readme_file:
    README = readme_file.read()

cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "matcha", "VERSION"), encoding="utf-8") as fin:
    version = fin.read().strip()

# Read the requirement lines inside a context manager so the file handle is
# closed deterministically; the lines are passed on unchanged.
with open(os.path.join(cwd, "requirements.txt"), encoding="utf-8") as requirements_file:
    install_requires = [str(r) for r in requirements_file]

setup(
    name="matcha-tts",
    version=version,
    description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching",
    long_description=README,
    long_description_content_type="text/markdown",
    author="Shivam Mehta",
    author_email="shivam.mehta25@gmail.com",
    url="https://shivammehta25.github.io/Matcha-TTS",
    install_requires=install_requires,
    include_dirs=[numpy.get_include()],
    include_package_data=True,
    packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]),
    # use this to customize global commands available in the terminal after installing the package
    ext_modules=cythonize(exts, language_level=3),
    python_requires=">=3.9.0",
)