# maskgct/bins/tta/inference.py
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
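# Text-to-audio (TTA) inference entry point: it loads an experiment config,
# builds the AudioLDM inference pipeline, and synthesizes audio from the
# --text prompt. A typical invocation looks like the sketch below; the config
# and checkpoint paths are illustrative placeholders, not files guaranteed to
# ship with this repository:
#
#   python bins/tta/inference.py \
#       --config egs/tta/audioldm/exp_config.json \
#       --text "A dog barking in the distance" \
#       --checkpoint_path path/to/audioldm/checkpoint \
#       --vocoder_path path/to/vocoder/checkpoint \
#       --vocoder_config_path path/to/vocoder/config.json \
#       --output_dir output \
#       --num_steps 200 \
#       --guidance_scale 4.0
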
import argparse

import torch

from models.tta.ldm.audioldm_inference import AudioLDMInference
from utils.util import load_config


def build_inference(args, cfg):
    # Map the model_type declared in the experiment config to its inference class.
    supported_inference = {
        "AudioLDM": AudioLDMInference,
    }

    inference_class = supported_inference[cfg.model_type]
    inference = inference_class(args, cfg)
    return inference
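
# Design note: build_inference dispatches on cfg.model_type, so this script
# only serves models registered in the supported_inference mapping above
# (currently just "AudioLDM"). If additional TTA models were added, they
# would be registered there, e.g. "MyTTAModel": MyTTAModelInference
# (hypothetical names, shown for illustration only).
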

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--text",
        help="Text to be synthesized",
        type=str,
        default="Text to be synthesized.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint path of the model used for inference",
    )
    parser.add_argument(
        "--vocoder_path", type=str, help="Checkpoint path of the vocoder"
    )
    parser.add_argument(
        "--vocoder_config_path", type=str, help="Config path of the vocoder"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=200,
        help="The total number of denoising steps",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=4.0,
        help="The scale of classifier-free guidance",
    )
    parser.add_argument("--local_rank", default=-1, type=int)

    return parser

def main():
    # Parse arguments
    args = build_parser().parse_args()

    # Parse config
    cfg = load_config(args.config)

    # Reuse args.local_rank to carry the device used by the inference pipeline.
    if torch.cuda.is_available():
        args.local_rank = torch.device("cuda")
    else:
        args.local_rank = torch.device("cpu")
    print("args: ", args)

    # Build inference
    inferencer = build_inference(args, cfg)

    # Run inference
    inferencer.inference()


if __name__ == "__main__":
    main()