Spaces:
Running
Running
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import os | |
import torch | |
from models.vocoders.vocoder_inference import VocoderInference | |
from utils.util import load_config | |
def build_inference(args, cfg, infer_type="infer_from_dataset"):
    """Instantiate the inference runner that matches ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded unchanged to the
            inference class.
        cfg: Loaded configuration object; its ``model_type`` attribute
            selects the inference class.
        infer_type: Inference mode string, forwarded unchanged to the
            inference class.

    Returns:
        An instance of the selected inference class (currently
        ``VocoderInference`` for every supported type).

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the supported vocoder
            types. The message lists the accepted values.
    """
    # GAN- and diffusion-based vocoders share the same inference driver.
    supported_inference = {
        "GANVocoder": VocoderInference,
        "DiffusionVocoder": VocoderInference,
    }
    model_type = cfg.model_type
    try:
        inference_class = supported_inference[model_type]
    except KeyError:
        # Keep KeyError (callers may already catch it) but make the failure
        # actionable instead of echoing only the offending key.
        raise KeyError(
            f"Unsupported model_type {model_type!r}; expected one of "
            f"{sorted(supported_inference)}"
        ) from None
    return inference_class(args, cfg, infer_type)
def cuda_relevant(deterministic=False):
    """Configure global PyTorch CUDA / cuDNN backend flags.

    Empties the CUDA caching allocator, enables TF32 for CUDA matmul and
    cuDNN, and switches determinism on or off.

    Args:
        deterministic: If True, force deterministic cuDNN kernels, disable
            cuDNN autotuning (``benchmark``) and enable
            ``torch.use_deterministic_algorithms``. If False, autotuning is
            enabled and deterministic algorithms are disabled.
    """
    torch.cuda.empty_cache()

    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True

    # Deterministic
    if deterministic:
        # With use_deterministic_algorithms(True), PyTorch raises a
        # RuntimeError from cuBLAS-backed ops on CUDA >= 10.2 unless this
        # variable is set. setdefault preserves any value the user exported.
        # NOTE(review): it must be in place before the first cuBLAS handle
        # is created to take effect, so call this early in the process.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = deterministic
    # Autotuning picks kernels by timing, which can vary run to run.
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
def build_parser():
    r"""Build argument parser for inference.py.

    Anything else should be put in an extra config YAML file.

    Returns:
        argparse.ArgumentParser: Parser for the vocoder-inference options.
        ``--config`` and ``--vocoder_dir`` are required; every other option
        is optional.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--infer_mode",
        type=str,
        # `required` takes a bool; the previous `required=None` only worked
        # because None is falsy. The default stays None as before.
        required=False,
        default=None,
        help='Inference mode passed to the inference class '
        '(e.g. "infer_from_dataset"). Default: None',
    )
    parser.add_argument(
        "--infer_datasets",
        nargs="+",
        default=None,
    )
    parser.add_argument(
        "--feature_folder",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--audio_folder",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--vocoder_dir",
        type=str,
        required=True,
        help="Vocoder checkpoint directory. Searching behavior is the same as "
        "the acoustics one.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="result",
        help="Output directory. Default: ./result",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="warning",
        help="Logging level. Default: warning",
    )
    parser.add_argument(
        "--keep_cache",
        action="store_true",
        default=False,
        help="Keep cache files. Only applicable to inference from files.",
    )
    return parser
def main():
    """Command-line entry point for vocoder inference.

    Parses the CLI arguments, loads the configuration file named by
    ``--config``, applies the global CUDA backend settings with their
    default (non-deterministic) mode, then builds the inference object for
    the configured model type and runs its ``inference()`` method.
    """
    cli_args = build_parser().parse_args()
    config = load_config(cli_args.config)

    # Backend flags are applied before the inference object is constructed.
    cuda_relevant()

    build_inference(cli_args, config, cli_args.infer_mode).inference()


if __name__ == "__main__":
    main()