# maskgct/models/tts/jets/jets_inference.py
# Hecheng0625's picture
# Upload 409 files
# c968fc3 verified
# raw
# history blame
# 2.38 kB
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
from tqdm import tqdm
from collections import OrderedDict
from models.tts.base.tts_inferece import TTSInference
from models.tts.jets.jets_dataset import JetsTestDataset, JetsTestCollator
from utils.util import load_config
from utils.io import save_audio
from models.tts.jets.jets import Jets
from models.vocoders.vocoder_inference import synthesis
from pathlib import Path
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
import numpy as np
import json
import time
class JetsInference(TTSInference):
    """Inference front-end that builds a ``Jets`` model and runs it on test batches.

    Only model construction, the dataset/collator choice and the batch loop
    are defined here. ``self.test_dataloader`` and ``self.test_batch_size``
    are read but never assigned in this class; presumably ``TTSInference``
    sets them up (verify against the base class).
    """

    def __init__(self, args, cfg):
        """Initialise the base inference object and keep the run settings.

        Args:
            args: Run arguments; ``args.mode`` is stored as ``self.infer_type``.
            cfg: Configuration object; forwarded to the base class and later
                passed to the ``Jets`` constructor.
        """
        super().__init__(args, cfg)
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
        """Create the ``Jets`` model from ``self.cfg``, store it and return it."""
        self.model = Jets(self.cfg)
        return self.model

    def _build_test_dataset(self):
        """Return the dataset and collator to use for testing.

        Returns:
            tuple: ``(JetsTestDataset, JetsTestCollator)`` -- the imported
            objects themselves, not instances created here.
        """
        return JetsTestDataset, JetsTestCollator

    def inference_for_batches(self):
        """Run the model over ``self.test_dataloader`` and collect trimmed audio.

        For every item of every batch, the predicted durations are summed,
        multiplied by ``cfg.preprocess.hop_size`` and used as the number of
        leading samples to keep from the generated audio.

        Returns:
            list: One float CPU tensor per item, in dataloader order.
        """
        n_batch = len(self.test_dataloader)
        # strftime() formats the current local time when no time tuple is given.
        now = time.strftime("%Y-%m-%d %H:%M:%S")
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        # Loop-invariant: read the hop size once instead of once per item.
        hop_size = self.cfg.preprocess.hop_size

        pred_res = []
        with torch.no_grad():
            # A progress bar is only shown when there is more than one batch.
            batches = (
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            )
            for batch_data in batches:
                audios, d_predictions = self.model.inference(batch_data)
                # NOTE(review): the indexing below implies audios is rank 3 and
                # d_predictions is rank 2 before the unsqueeze; presumably
                # (batch, 1, samples) and (batch, tokens) -- confirm against
                # Jets.inference.
                d_predictions = d_predictions.unsqueeze(-1)
                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].detach().cpu().float()
                    total_duration = d_predictions[idx, :, :].sum([0, 1]).long()
                    # Convert to a plain int so it can be used directly as a
                    # slice bound.
                    audio_length = int(total_duration * hop_size)
                    pred_res.append(audio[:audio_length])
        return pred_res