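# NOTE: the following three lines are web-page residue from the hosting page,
# not part of the module: "Spaces:" / "Runtime error" / "Runtime error"
#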
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os | |
import torch | |
from tqdm import tqdm | |
from collections import OrderedDict | |
from models.tts.base.tts_inferece import TTSInference | |
from models.tts.jets.jets_dataset import JetsTestDataset, JetsTestCollator | |
from utils.util import load_config | |
from utils.io import save_audio | |
from models.tts.jets.jets import Jets | |
from models.vocoders.vocoder_inference import synthesis | |
from pathlib import Path | |
from processors.phone_extractor import phoneExtractor | |
from text.text_token_collation import phoneIDCollation | |
import numpy as np | |
import json | |
import time | |
class JetsInference(TTSInference):
    """Batch inference driver for the ``Jets`` text-to-speech model.

    Builds a ``Jets`` model from the configuration and runs it over the test
    dataloader, returning one trimmed waveform per utterance.

    NOTE(review): ``self.test_dataloader`` and ``self.test_batch_size`` are
    read here but never assigned in this class; presumably they are set up
    by ``TTSInference`` -- confirm against the base class.
    """

    def __init__(self, args, cfg):
        """Initialise the base inference class and keep ``args`` / ``cfg``.

        Args:
            args: Run options; ``args.mode`` is stored as ``infer_type``.
            cfg: Configuration object; passed to ``Jets`` and read for
                ``cfg.preprocess.hop_size`` during inference.
        """
        super().__init__(args, cfg)
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
        """Instantiate the ``Jets`` model from ``self.cfg`` and return it."""
        self.model = Jets(self.cfg)
        return self.model

    def _build_test_dataset(self):
        """Return the dataset class and collator class used for testing."""
        return JetsTestDataset, JetsTestCollator

    def inference_for_batches(self):
        """Synthesize audio for every batch in ``self.test_dataloader``.

        Returns:
            list: One 1-D CPU float tensor per utterance, in dataloader
            order. Each waveform is cut to
            ``sum(predicted durations) * cfg.preprocess.hop_size`` samples.
        """
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        # time.strftime formats the current local time when no tuple is given.
        now = time.strftime("%Y-%m-%d %H:%M:%S")
        print(
            f"Model eval time: {now}, "
            f"batch_size = {self.test_batch_size}, n_batch = {n_batch}"
        )
        self.model.eval()

        # A progress bar is only shown when there is more than one batch.
        batches = (
            self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
        )

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for batch_data in batches:
                audios, d_predictions = self.model.inference(batch_data)
                # The indexing below treats audios as 3-D (item, channel,
                # sample) and the durations as 2-D before this unsqueeze.
                d_predictions = d_predictions.unsqueeze(-1)
                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].detach().cpu().float()
                    duration = d_predictions[idx, :, :]
                    # Total predicted frames times hop size gives the number
                    # of valid samples; the remainder of the row is dropped.
                    audio_length = int(
                        duration.sum([0, 1]).long() * self.cfg.preprocess.hop_size
                    )
                    pred_res.append(audio[:audio_length])
        return pred_res