Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import os | |
import re | |
import time | |
from pathlib import Path | |
import torch | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from models.vocoders.vocoder_inference import synthesis | |
from torch.utils.data import DataLoader | |
from utils.util import set_all_random_seed | |
from utils.util import load_config | |
def parse_vocoder(vocoder_dir): | |
r"""Parse vocoder config""" | |
vocoder_dir = os.path.abspath(vocoder_dir) | |
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")] | |
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True) | |
ckpt_path = str(ckpt_list[0]) | |
vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True) | |
vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder | |
return vocoder_cfg, ckpt_path | |
class BaseInference(object): | |
def __init__(self, cfg, args): | |
self.cfg = cfg | |
self.args = args | |
self.model_type = cfg.model_type | |
self.avg_rtf = list() | |
set_all_random_seed(10086) | |
os.makedirs(args.output_dir, exist_ok=True) | |
if torch.cuda.is_available(): | |
self.device = torch.device("cuda") | |
else: | |
self.device = torch.device("cpu") | |
torch.set_num_threads(10) # inference on 1 core cpu. | |
# Load acoustic model | |
self.model = self.create_model().to(self.device) | |
state_dict = self.load_state_dict() | |
self.load_model(state_dict) | |
self.model.eval() | |
# Load vocoder model if necessary | |
if self.args.checkpoint_dir_vocoder is not None: | |
self.get_vocoder_info() | |
def create_model(self): | |
raise NotImplementedError | |
def load_state_dict(self): | |
self.checkpoint_file = self.args.checkpoint_file | |
if self.checkpoint_file is None: | |
assert self.args.checkpoint_dir is not None | |
checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint") | |
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip() | |
self.checkpoint_file = os.path.join( | |
self.args.checkpoint_dir, checkpoint_filename | |
) | |
self.checkpoint_dir = os.path.split(self.checkpoint_file)[0] | |
print("Restore acoustic model from {}".format(self.checkpoint_file)) | |
raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device) | |
self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0] | |
return raw_state_dict | |
def load_model(self, model): | |
raise NotImplementedError | |
def get_vocoder_info(self): | |
self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder | |
self.vocoder_cfg = os.path.join( | |
os.path.dirname(self.checkpoint_dir_vocoder), "args.json" | |
) | |
self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True) | |
self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1] | |
self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0] | |
def build_test_utt_data(self): | |
raise NotImplementedError | |
def build_testdata_loader(self, args, target_speaker=None): | |
datasets, collate = self.build_test_dataset() | |
self.test_dataset = datasets(self.cfg, args, target_speaker) | |
self.test_collate = collate(self.cfg) | |
self.test_batch_size = min( | |
self.cfg.train.batch_size, len(self.test_dataset.metadata) | |
) | |
test_loader = DataLoader( | |
self.test_dataset, | |
collate_fn=self.test_collate, | |
num_workers=self.args.num_workers, | |
batch_size=self.test_batch_size, | |
shuffle=False, | |
) | |
return test_loader | |
def inference_each_batch(self, batch_data): | |
raise NotImplementedError | |
def inference_for_batches(self, args, target_speaker=None): | |
###### Construct test_batch ###### | |
loader = self.build_testdata_loader(args, target_speaker) | |
n_batch = len(loader) | |
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) | |
print( | |
"Model eval time: {}, batch_size = {}, n_batch = {}".format( | |
now, self.test_batch_size, n_batch | |
) | |
) | |
self.model.eval() | |
###### Inference for each batch ###### | |
pred_res = [] | |
with torch.no_grad(): | |
for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)): | |
# Put the data to device | |
for k, v in batch_data.items(): | |
batch_data[k] = batch_data[k].to(self.device) | |
y_pred, stats = self.inference_each_batch(batch_data) | |
pred_res += y_pred | |
return pred_res | |
def inference(self, feature): | |
raise NotImplementedError | |
def synthesis_by_vocoder(self, pred): | |
audios_pred = synthesis( | |
self.vocoder_cfg, | |
self.checkpoint_dir_vocoder, | |
len(pred), | |
pred, | |
) | |
return audios_pred | |
def __call__(self, utt): | |
feature = self.build_test_utt_data(utt) | |
start_time = time.time() | |
with torch.no_grad(): | |
outputs = self.inference(feature)[0] | |
time_used = time.time() - start_time | |
rtf = time_used / ( | |
outputs.shape[1] | |
* self.cfg.preprocess.hop_size | |
/ self.cfg.preprocess.sample_rate | |
) | |
print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf)) | |
self.avg_rtf.append(rtf) | |
audios = outputs.cpu().squeeze().numpy().reshape(-1, 1) | |
return audios | |
def base_parser(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--config", default="config.json", help="json files for configurations." | |
) | |
parser.add_argument("--use_ddp_inference", default=False) | |
parser.add_argument("--n_workers", default=1, type=int) | |
parser.add_argument("--local_rank", default=-1, type=int) | |
parser.add_argument( | |
"--batch_size", default=1, type=int, help="Batch size for inference" | |
) | |
parser.add_argument( | |
"--num_workers", | |
default=1, | |
type=int, | |
help="Worker number for inference dataloader", | |
) | |
parser.add_argument( | |
"--checkpoint_dir", | |
type=str, | |
default=None, | |
help="Checkpoint dir including model file and configuration", | |
) | |
parser.add_argument( | |
"--checkpoint_file", help="checkpoint file", type=str, default=None | |
) | |
parser.add_argument( | |
"--test_list", help="test utterance list for testing", type=str, default=None | |
) | |
parser.add_argument( | |
"--checkpoint_dir_vocoder", | |
help="Vocoder's checkpoint dir including model file and configuration", | |
type=str, | |
default=None, | |
) | |
parser.add_argument( | |
"--output_dir", | |
type=str, | |
default=None, | |
help="Output dir for saving generated results", | |
) | |
return parser | |
if __name__ == "__main__": | |
parser = base_parser() | |
args = parser.parse_args() | |
cfg = load_config(args.config) | |
# Build inference | |
inference = BaseInference(cfg, args) | |
inference() | |