maskgct/bins/tta/train_tta.py
Hecheng0625's picture
Upload 167 files
8c92a11 verified
raw
history blame
2.02 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import torch
from models.tta.autoencoder.autoencoder_trainer import AutoencoderKLTrainer
from models.tta.ldm.audioldm_trainer import AudioLDMTrainer
from utils.util import load_config
def build_trainer(args, cfg):
    """Instantiate the trainer registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded unchanged to the
            trainer constructor.
        cfg: Loaded experiment configuration; its ``model_type`` attribute
            selects which trainer class is used.

    Returns:
        The newly constructed trainer instance.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the registered names.
    """
    # Registry of every supported ``model_type`` and its trainer class.
    trainer_registry = dict(
        AutoencoderKL=AutoencoderKLTrainer,
        AudioLDM=AudioLDMTrainer,
    )
    return trainer_registry[cfg.model_type](args, cfg)
def main(argv=None):
    """Parse command-line options, build the configured trainer and train.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so calling
            ``main()`` behaves exactly as before.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an option cannot be parsed.
    """
    parser = argparse.ArgumentParser()
    # Required options carry no ``default``: argparse never falls back to a
    # default for an option that must be supplied.
    parser.add_argument(
        "--config",
        help="json files for configurations.",
        required=True,
    )
    parser.add_argument(
        "--num_workers", type=int, default=6, help="Number of dataloader workers."
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="The model name to restore",
    )
    parser.add_argument(
        "--log_level", default="info", help="logging level (info, debug, warning)"
    )
    parser.add_argument("--stdout_interval", default=5, type=int)
    # PyTorch >= 2.0 launchers pass the dashed ``--local-rank`` while older
    # ones pass ``--local_rank``; accept both and store them in the same
    # attribute so either launcher works.
    parser.add_argument(
        "--local_rank", "--local-rank", dest="local_rank", default=-1, type=int
    )
    args = parser.parse_args(argv)

    cfg = load_config(args.config)
    cfg.exp_name = args.exp_name

    # Model saving dir
    args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
    os.makedirs(args.log_dir, exist_ok=True)

    if not cfg.train.ddp:
        # Outside DDP the rank slot is overwritten with a CUDA device object.
        args.local_rank = torch.device("cuda")

    # Build trainer
    trainer = build_trainer(args, cfg)

    # Restore models
    # NOTE(review): ``--resume`` only acts as an on/off switch here; its value
    # is not passed to ``restore()``. Presumably the trainer reads
    # ``args.resume`` itself -- confirm against the trainer implementation.
    if args.resume:
        trainer.restore()

    trainer.train()
# Start training only when this file is executed as a script; importing the
# module must not trigger a run.
if __name__ == "__main__":
    main()