maskgct / bins /tts /inference.py
Hecheng0625's picture
Upload 167 files
8c92a11 verified
raw
history blame
4.68 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from argparse import ArgumentParser
import os
from models.tts.fastspeech2.fs2_inference import FastSpeech2Inference
from models.tts.vits.vits_inference import VitsInference
from models.tts.valle.valle_inference import VALLEInference
from models.tts.naturalspeech2.ns2_inference import NS2Inference
from models.tts.jets.jets_inference import JetsInference
from utils.util import load_config
import torch
def build_inference(args, cfg):
supported_inference = {
"FastSpeech2": FastSpeech2Inference,
"VITS": VitsInference,
"VALLE": VALLEInference,
"NaturalSpeech2": NS2Inference,
"Jets": JetsInference,
}
inference_class = supported_inference[cfg.model_type]
inference = inference_class(args, cfg)
return inference
def cuda_relevant(deterministic=False):
torch.cuda.empty_cache()
# TF32 on Ampere and above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
# Deterministic
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
torch.use_deterministic_algorithms(deterministic)
def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
type=str,
required=True,
help="JSON/YAML file for configurations.",
)
parser.add_argument(
"--dataset",
type=str,
help="convert from the source data",
default=None,
)
parser.add_argument(
"--testing_set",
type=str,
help="train, test, golden_test",
default="test",
)
parser.add_argument(
"--test_list_file",
type=str,
help="convert from the test list file",
default=None,
)
parser.add_argument(
"--speaker_name",
type=str,
default=None,
help="speaker name for multi-speaker synthesis, for single-sentence mode only",
)
parser.add_argument(
"--text",
help="Text to be synthesized.",
type=str,
default="",
)
parser.add_argument(
"--vocoder_dir",
type=str,
default=None,
help="Vocoder checkpoint directory. Searching behavior is the same as "
"the acoustics one.",
)
parser.add_argument(
"--acoustics_dir",
type=str,
default=None,
help="Acoustic model checkpoint directory. If a directory is given, "
"search for the latest checkpoint dir in the directory. If a specific "
"checkpoint dir is given, directly load the checkpoint.",
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Acoustic model checkpoint directory. If a directory is given, "
"search for the latest checkpoint dir in the directory. If a specific "
"checkpoint dir is given, directly load the checkpoint.",
)
parser.add_argument(
"--mode",
type=str,
choices=["batch", "single"],
required=True,
help="Synthesize a whole dataset or a single sentence",
)
parser.add_argument(
"--log_level",
type=str,
default="warning",
help="Logging level. Default: warning",
)
parser.add_argument(
"--pitch_control",
type=float,
default=1.0,
help="control the pitch of the whole utterance, larger value for higher pitch",
)
parser.add_argument(
"--energy_control",
type=float,
default=1.0,
help="control the energy of the whole utterance, larger value for larger volume",
)
parser.add_argument(
"--duration_control",
type=float,
default=1.0,
help="control the speed of the whole utterance, larger value for slower speaking rate",
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Output dir for saving generated results",
)
return parser
def main():
# Parse arguments
parser = build_parser()
VALLEInference.add_arguments(parser)
NS2Inference.add_arguments(parser)
args = parser.parse_args()
print(args)
# Parse config
cfg = load_config(args.config)
# CUDA settings
cuda_relevant()
# Build inference
inferencer = build_inference(args, cfg)
# Run inference
inferencer.inference()
if __name__ == "__main__":
main()