Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,734 Bytes
8c92a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from models.tts.fastspeech2.fs2_trainer import FastSpeech2Trainer
from models.tts.vits.vits_trainer import VITSTrainer
from models.tts.valle.valle_trainer import VALLETrainer
from models.tts.naturalspeech2.ns2_trainer import NS2Trainer
from models.tts.valle_v2.valle_ar_trainer import ValleARTrainer as VALLE_V2_AR
from models.tts.valle_v2.valle_nar_trainer import ValleNARTrainer as VALLE_V2_NAR
from models.tts.jets.jets_trainer import JetsTrainer
from utils.util import load_config
def build_trainer(args, cfg):
    """Instantiate the trainer class registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments; forwarded to the trainer constructor.
        cfg: Experiment configuration. Its ``model_type`` attribute selects the
            trainer class, and the whole object is forwarded to the constructor.

    Returns:
        The trainer instance created as ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not a supported name. The message
            lists every supported name so misconfigured runs are easy to fix.
    """
    supported_trainer = {
        "FastSpeech2": FastSpeech2Trainer,
        "VITS": VITSTrainer,
        "VALLE": VALLETrainer,
        "NaturalSpeech2": NS2Trainer,
        "VALLE_V2_AR": VALLE_V2_AR,
        "VALLE_V2_NAR": VALLE_V2_NAR,
        "Jets": JetsTrainer,
    }
    try:
        trainer_class = supported_trainer[cfg.model_type]
    except KeyError:
        # Keep the KeyError type (callers may catch it) but make it actionable.
        raise KeyError(
            f"Unsupported model_type {cfg.model_type!r}; "
            f"expected one of: {', '.join(supported_trainer)}"
        ) from None
    return trainer_class(args, cfg)
def cuda_relevant(deterministic=False):
    """Configure global PyTorch CUDA / cuDNN backend flags before training.

    Releases cached CUDA allocator memory, enables TF32 for CUDA matmuls and
    cuDNN, and switches between autotuned (fast) and deterministic execution.

    Args:
        deterministic: If True, disable cuDNN autotuning (``benchmark``), request
            deterministic cuDNN kernels and call
            ``torch.use_deterministic_algorithms(True)``. If False (the default),
            enable cuDNN autotuning and allow non-deterministic algorithms.
    """
    import os

    torch.cuda.empty_cache()

    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True

    # Deterministic vs. autotuned execution. cudnn.benchmark is set only here,
    # as the inverse of `deterministic`.
    if deterministic:
        # On CUDA >= 10.2, deterministic cuBLAS requires this workspace config;
        # without it, use_deterministic_algorithms(True) makes cuBLAS ops raise.
        # setdefault preserves any value the user has already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
# (cfg.preprocess flag name, dataset-name suffix) pairs, in the order the
# augmented dataset names are appended to cfg.dataset.
_AUGMENTATION_SUFFIXES = (
    ("use_pitch_shift", "pitch_shift"),
    ("use_formant_shift", "formant_shift"),
    ("use_equalizer", "equalizer"),
    ("use_time_stretch", "time_stretch"),
)


def _build_arg_parser():
    """Return the command-line parser for the TTS train/test entry point."""
    parser = argparse.ArgumentParser()
    # --config and --exp_name are required, so no default is ever used.
    parser.add_argument(
        "--config",
        required=True,
        help="json files for configurations.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="random seed",
        required=False,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="A specific name to note the experiment",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume training or finetuning from a checkpoint.",
    )
    parser.add_argument(
        "--test", action="store_true", default=False, help="Test the model"
    )
    parser.add_argument(
        "--log_level", default="warning", help="logging level (debug, info, warning)"
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        default="resume",
        help="Resume training or finetuning.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Checkpoint for resume training or finetuning.",
    )
    parser.add_argument(
        "--resume_from_ckpt_path",
        type=str,
        default="",
        help="Checkpoint for resume training or finetuning.",
    )
    return parser


def _add_augmented_datasets(cfg):
    """Append augmented dataset names to ``cfg.dataset`` in place.

    For each name in ``cfg.preprocess.data_augment`` and each augmentation
    whose ``cfg.preprocess.use_*`` flag is truthy, ``f"{dataset}_{suffix}"`` is
    appended to ``cfg.dataset`` (grouped per dataset, in
    ``_AUGMENTATION_SUFFIXES`` order). Does nothing when ``cfg.preprocess`` or
    ``cfg.preprocess.data_augment`` is absent, or when ``data_augment`` is not a
    non-empty list.
    """
    if not hasattr(cfg, "preprocess"):
        return
    preprocess = cfg.preprocess
    augment = getattr(preprocess, "data_augment", None)
    if not isinstance(augment, list) or not augment:
        return

    enabled_suffixes = [
        suffix for flag, suffix in _AUGMENTATION_SUFFIXES if getattr(preprocess, flag)
    ]
    # Materialize first so extending cfg.dataset can never interact with the
    # iteration over data_augment.
    new_datasets_list = [
        f"{dataset}_{suffix}" for dataset in augment for suffix in enabled_suffixes
    ]
    cfg.dataset.extend(new_datasets_list)


def main():
    """Parse CLI args, load the config, build the trainer and run it.

    Calls ``trainer.test_loop()`` when ``--test`` is given, otherwise
    ``trainer.train_loop()``.
    """
    args = _build_arg_parser().parse_args()
    cfg = load_config(args.config)

    # Data Augmentation: register augmented variants of the configured datasets.
    _add_augmented_datasets(cfg)

    print("experiment name: ", args.exp_name)

    # CUDA settings
    cuda_relevant()

    # Build trainer
    print(f"Building {cfg.model_type} trainer")
    trainer = build_trainer(args, cfg)

    if args.test:
        print(f"Start testing {cfg.model_type} model")
        trainer.test_loop()
    else:
        print(f"Start training {cfg.model_type} model")
        trainer.train_loop()
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
|