|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import re |
|
import time |
|
from pathlib import Path |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from models.vocoders.vocoder_inference import synthesis |
|
from torch.utils.data import DataLoader |
|
from utils.util import set_all_random_seed |
|
from utils.util import load_config |
|
|
|
|
|
def parse_vocoder(vocoder_dir):
    r"""Parse a vocoder experiment directory.

    Selects the checkpoint ``<N>.pt`` with the largest integer ``N`` in
    ``vocoder_dir`` and loads ``args.json`` from the same directory.

    Args:
        vocoder_dir: directory holding integer-named ``*.pt`` checkpoints and
            an ``args.json`` config.

    Returns:
        ``(vocoder_cfg, ckpt_path)``: the loaded config (with
        ``model.bigvgan`` set to its ``vocoder`` section) and the absolute path
        of the selected checkpoint as a string.

    Raises:
        FileNotFoundError: if no integer-named ``*.pt`` checkpoint exists.
    """
    vocoder_dir = os.path.abspath(vocoder_dir)

    # Only integer stems are comparable; skip e.g. "best.pt" instead of
    # crashing inside int().
    numbered = [p for p in Path(vocoder_dir).glob("*.pt") if p.stem.isdecimal()]
    if not numbered:
        raise FileNotFoundError(
            "No integer-named *.pt checkpoint found in {}".format(vocoder_dir)
        )
    ckpt_path = str(max(numbered, key=lambda p: int(p.stem)))

    vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
    vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
    return vocoder_cfg, ckpt_path
|
|
|
|
|
class BaseInference(object):
    """Skeleton for acoustic-model inference.

    Subclasses implement the hooks that raise ``NotImplementedError`` here:
    ``create_model``, ``load_model``, ``build_test_dataset``,
    ``build_test_utt_data``, ``inference_each_batch`` and ``inference``.

    The constructor does the following, in order:

    * seeds the RNGs;
    * creates ``args.output_dir``;
    * picks CUDA if available, otherwise CPU;
    * builds and restores the model, then puts it in eval mode;
    * if ``args.checkpoint_dir_vocoder`` is set, records vocoder information.
    """

    def __init__(self, cfg, args):
        self.cfg = cfg
        self.args = args
        self.model_type = cfg.model_type
        # Real-time factors appended by each ``__call__``.
        self.avg_rtf = list()

        set_all_random_seed(10086)
        os.makedirs(args.output_dir, exist_ok=True)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
            torch.set_num_threads(10)

        self.model = self.create_model().to(self.device)
        state_dict = self.load_state_dict()
        self.load_model(state_dict)
        self.model.eval()

        if self.args.checkpoint_dir_vocoder is not None:
            self.get_vocoder_info()

    def create_model(self):
        """Build and return the acoustic model (subclass hook)."""
        raise NotImplementedError

    def load_state_dict(self):
        """Locate and load the acoustic-model checkpoint.

        If ``args.checkpoint_file`` is ``None``, the checkpoint name is read
        from the last non-empty line of ``<args.checkpoint_dir>/checkpoint``.

        Sets the following attributes:

        * ``self.checkpoint_file``;
        * ``self.checkpoint_dir``;
        * ``self.am_restore_step``: the text between ``step-`` and ``_loss``
          in the checkpoint path, or ``None`` when the path has no such part.

        Returns:
            The object returned by ``torch.load`` (mapped to ``self.device``).

        Raises:
            ValueError: if neither a checkpoint file nor a checkpoint dir was
                given, or if the ``checkpoint`` index file lists no entries.
        """
        self.checkpoint_file = self.args.checkpoint_file
        if self.checkpoint_file is None:
            if self.args.checkpoint_dir is None:
                raise ValueError(
                    "Either --checkpoint_file or --checkpoint_dir must be provided"
                )
            checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
            # Use a context manager so the index file handle is closed.
            with open(checkpoint_path) as index_file:
                entries = [line.strip() for line in index_file if line.strip()]
            if not entries:
                raise ValueError(
                    "Checkpoint index file {} is empty".format(checkpoint_path)
                )
            # The last listed entry is used (presumably the newest checkpoint).
            # Blank trailing lines are ignored rather than resolving to the dir.
            self.checkpoint_file = os.path.join(self.args.checkpoint_dir, entries[-1])

        self.checkpoint_dir = os.path.dirname(self.checkpoint_file)

        print("Restore acoustic model from {}".format(self.checkpoint_file))
        raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
        # Checkpoint names without "step-..._loss" no longer crash here.
        step_match = re.search(r"step-(.+?)_loss", self.checkpoint_file)
        self.am_restore_step = step_match.group(1) if step_match else None

        return raw_state_dict

    def load_model(self, model):
        """Restore weights from the loaded checkpoint (subclass hook).

        Despite the parameter name, ``__init__`` passes the value returned by
        ``load_state_dict`` here.
        """
        raise NotImplementedError

    def get_vocoder_info(self):
        """Record information about the vocoder checkpoint.

        ``args.checkpoint_dir_vocoder`` is treated as a checkpoint file path
        whose directory also contains ``args.json``.

        Sets the following attributes:

        * ``self.checkpoint_dir_vocoder``: the argument value itself;
        * ``self.vocoder_cfg``: path of that ``args.json``;
        * ``self.cfg.vocoder``: the loaded vocoder config;
        * ``self.vocoder_tag``: the parent-directory name after its last ``:``;
        * ``self.vocoder_steps``: the file name up to its first ``.``.
        """
        self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
        self.vocoder_cfg = os.path.join(
            os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
        )
        self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
        # pathlib instead of split("/") so platform path separators work too.
        vocoder_ckpt = Path(self.checkpoint_dir_vocoder)
        self.vocoder_tag = vocoder_ckpt.parent.name.split(":")[-1]
        self.vocoder_steps = vocoder_ckpt.name.split(".")[0]

    def build_test_utt_data(self, utt=None):
        """Turn a single utterance into model input features (subclass hook).

        ``__call__`` passes the utterance it receives as ``utt``.
        """
        raise NotImplementedError

    def build_test_dataset(self):
        """Return ``(dataset_cls, collate_cls)`` for batch inference (subclass hook).

        ``build_testdata_loader`` calls these as
        ``dataset_cls(cfg, args, target_speaker)`` and ``collate_cls(cfg)``.
        """
        raise NotImplementedError

    def build_testdata_loader(self, args, target_speaker=None):
        """Build a non-shuffled DataLoader over the test dataset.

        The batch size is ``min(cfg.train.batch_size, len(dataset.metadata))``,
        clamped to at least 1. It is stored in ``self.test_batch_size``.
        """
        dataset_cls, collate_cls = self.build_test_dataset()
        self.test_dataset = dataset_cls(self.cfg, args, target_speaker)
        self.test_collate = collate_cls(self.cfg)
        # DataLoader rejects batch_size=0, which an empty test set would produce.
        self.test_batch_size = max(
            1, min(self.cfg.train.batch_size, len(self.test_dataset.metadata))
        )
        return DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=self.args.num_workers,
            batch_size=self.test_batch_size,
            shuffle=False,
        )

    def inference_each_batch(self, batch_data):
        """Run the model on one collated batch (subclass hook).

        Must return a ``(predictions, stats)`` pair. ``predictions`` must
        support list concatenation (``+=``) onto a list.
        """
        raise NotImplementedError

    def inference_for_batches(self, args, target_speaker=None):
        """Run inference over the whole test set.

        Returns:
            A list concatenating the predictions of every batch, in loader order.
        """
        loader = self.build_testdata_loader(args, target_speaker)

        n_batch = len(loader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        pred_res = []
        with torch.no_grad():
            # No progress bar when there is only a single batch.
            batches = loader if n_batch == 1 else tqdm(loader)
            for batch_data in batches:
                # Move device-movable entries; leave e.g. string ids untouched.
                for key, value in batch_data.items():
                    if hasattr(value, "to"):
                        batch_data[key] = value.to(self.device)

                y_pred, _ = self.inference_each_batch(batch_data)
                pred_res += y_pred

        return pred_res

    def inference(self, feature):
        """Run the model on prepared features (subclass hook).

        ``__call__`` uses element ``[0]`` of the result as the output tensor.
        """
        raise NotImplementedError

    def synthesis_by_vocoder(self, pred):
        """Vocode a list of predicted features into waveforms.

        Requires ``get_vocoder_info`` to have run, i.e.
        ``--checkpoint_dir_vocoder`` was given.
        """
        # ``self.vocoder_cfg`` is only the path to args.json. Pass the loaded
        # config object instead, matching what ``parse_vocoder`` returns.
        # NOTE(review): assumes ``synthesis`` expects a config object, not a path.
        audios_pred = synthesis(
            self.cfg.vocoder,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )
        return audios_pred

    def __call__(self, utt):
        """Synthesize one utterance.

        Prints and records the real-time factor in ``self.avg_rtf``.

        Returns:
            The squeezed output as a NumPy array reshaped to ``(-1, 1)``.
        """
        feature = self.build_test_utt_data(utt)

        start_time = time.time()
        with torch.no_grad():
            outputs = self.inference(feature)[0]
        time_used = time.time() - start_time

        # Generated duration in seconds: shape[1] (presumably frames) * hop / sr.
        audio_seconds = (
            outputs.shape[1]
            * self.cfg.preprocess.hop_size
            / self.cfg.preprocess.sample_rate
        )
        rtf = time_used / audio_seconds if audio_seconds > 0 else float("inf")
        print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
        self.avg_rtf.append(rtf)

        audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
        return audios
|
|
|
|
|
def base_parser():
    """Build the argument parser shared by inference entry points.

    Returns:
        An ``argparse.ArgumentParser`` with the common inference options.
    """

    def str2bool(value):
        """Parse a command-line boolean such as true/false, yes/no, 1/0."""
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value)
        )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", default="config.json", help="json files for configurations."
    )
    # Without a converter, "--use_ddp_inference False" stored the truthy
    # string "False". A bare "--use_ddp_inference" now means True.
    parser.add_argument(
        "--use_ddp_inference",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument("--n_workers", default=1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="Worker number for inference dataloader",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Checkpoint dir including model file and configuration",
    )
    parser.add_argument(
        "--checkpoint_file", help="checkpoint file", type=str, default=None
    )
    parser.add_argument(
        "--test_list", help="test utterance list for testing", type=str, default=None
    )
    parser.add_argument(
        "--checkpoint_dir_vocoder",
        help="Vocoder's checkpoint dir including model file and configuration",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    return parser
|
|
|
|
|
if __name__ == "__main__":
    parser = base_parser()
    args = parser.parse_args()
    cfg = load_config(args.config)

    # NOTE(review): BaseInference's hooks raise NotImplementedError
    # (create_model is called in __init__), so a concrete subclass is needed
    # for this entry point to get past construction.
    inference = BaseInference(cfg, args)
    # ``__call__`` requires an utterance argument. The CLI-driven path is the
    # batch API, which takes the parsed ``args``.
    inference.inference_for_batches(args)
|
|