import json
import os
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from models.tts.base.tts_inferece import TTSInference
from models.tts.fastspeech2.fs2 import FastSpeech2
from models.tts.fastspeech2.fs2_dataset import FS2TestCollator, FS2TestDataset
from models.vocoders.vocoder_inference import synthesis
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
from utils.io import save_audio
from utils.util import load_config


class FastSpeech2Inference(TTSInference):
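    """Inference wrapper for FastSpeech2: loads a trained acoustic model,
    predicts mel spectrograms (batched or for a single utterance), and
    renders waveforms with a separately trained vocoder.
    """
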
    def __init__(self, args, cfg):
        TTSInference.__init__(self, args, cfg)
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
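        """Instantiate the FastSpeech2 acoustic model from the config."""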
        self.model = FastSpeech2(self.cfg)
        return self.model

    def load_model(self, state_dict):
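        """Load model weights, stripping the "module." prefix that
        DistributedDataParallel prepends to parameter names during training.
        """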
        raw_dict = state_dict["model"]
        clean_dict = OrderedDict()
        for k, v in raw_dict.items():
            if k.startswith("module."):
                clean_dict[k[7:]] = v
            else:
                clean_dict[k] = v

        self.model.load_state_dict(clean_dict)

    def _build_test_dataset(self):
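        """Return the test dataset and collator classes for batch inference."""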
        return FS2TestDataset, FS2TestCollator

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse the vocoder config and locate the newest vocoder checkpoint
        in `vocoder_dir`, judged by the step count encoded in the checkpoint
        filename.
        """
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = list(Path(vocoder_dir).glob("*.pt"))
        # Newest checkpoint first: the step count is the last "-"-separated
        # field of the second-to-last "_"-separated field of the stem.
        ckpt_list.sort(
            key=lambda x: int(x.stem.split("_")[-2].split("-")[-1]), reverse=True
        )
        ckpt_path = str(ckpt_list[0])
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    @torch.inference_mode()
    def inference_for_batches(self):
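        """Synthesize the whole test set: cache one predicted mel per
        utterance, vocode all of them in one pass, and write one .wav per
        Uid into args.output_dir.
        """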
        # Run the acoustic model batch by batch, caching each predicted mel
        # spectrogram to disk keyed by its utterance id (Uid).
        for i, batch in enumerate(tqdm(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)
            y_ls = y_pred.chunk(self.test_batch_size)
            tgt_ls = mel_lens.chunk(self.test_batch_size)
            for j, (it, l) in enumerate(zip(y_ls, tgt_ls)):
                l = l.item()
                # Trim padding frames and move the mel off the device.
                it = it.squeeze(0)[:l].detach().cpu()

                # Assumes the dataloader preserves the metadata order.
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))

        # Vocode all cached mels in one pass, then remove the temporary .pt
        # files once the waveforms are written.
        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(self.args.output_dir, f"{item['Uid']}.pt")
                ).numpy()
                for item in self.test_dataset.metadata
            ],
        )
        for it, wav in zip(self.test_dataset.metadata, res):
            uid = it["Uid"]
            save_audio(
                os.path.join(self.args.output_dir, f"{uid}.wav"),
                wav.numpy(),
                self.cfg.preprocess.sample_rate,
                add_silence=True,
                turn_up=True,
            )
            os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))

    @torch.inference_mode()
    def _inference_each_batch(self, batch_data):
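        """Run the acoustic model on one batch and return the postnet mel
        predictions, their frame lengths, and a placeholder loss of 0.
        """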
        device = self.accelerator.device
        pitch_control, energy_control, duration_control = (
            self.args.pitch_control,
            self.args.energy_control,
            self.args.duration_control,
        )
        # Move every tensor in the batch onto the inference device.
        for k, v in batch_data.items():
            batch_data[k] = v.to(device)

        output = self.model(
            batch_data,
            p_control=pitch_control,
            e_control=energy_control,
            d_control=duration_control,
        )
        pred_res = output["postnet_output"]
        mel_lens = output["mel_lens"].cpu()
        return pred_res, mel_lens, 0

    def inference_for_single_utterance(self):
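        """Synthesize args.text end to end (text -> phones -> phone ids ->
        mel spectrogram -> waveform) and return the waveform.
        """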
        text = self.args.text
        pitch_control, energy_control, duration_control = (
            self.args.pitch_control,
            self.args.energy_control,
            self.args.duration_control,
        )

        # Locate the phone symbol dictionary unless a lexicon handles the
        # text-to-phone mapping.
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)

        # Text -> phone sequence -> phone id sequence.
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        # Wrap the sequence in the "{" / "}" boundary tokens the frontend expects.
        phone_seq = ["{"] + phone_seq + ["}"]
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
        phone_id_seq = torch.from_numpy(np.array(phone_id_seq))

        # Resolve the speaker id for multi-speaker models.
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r") as f:
                spk2id = json.load(f)
            speaker_id = spk2id[self.args.speaker_name]
            speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32))
        else:
            # Placeholder speaker id (0) for single-speaker setups.
            speaker_id = torch.zeros(1, dtype=torch.int32)

        with torch.no_grad():
            x_tst = phone_id_seq.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
            speaker_id = speaker_id.to(self.device)

            data = {
                "texts": x_tst,
                "text_len": x_tst_lengths,
                "spk_id": speaker_id,
            }

            output = self.model(
                data,
                p_control=pitch_control,
                e_control=energy_control,
                d_control=duration_control,
            )
            pred_res = output["postnet_output"]
            vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
            audio = synthesis(
                cfg=vocoder_cfg,
                vocoder_weight_file=vocoder_ckpt,
                n_samples=None,
                pred=pred_res,
            )
        return audio[0]
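

# Illustrative usage sketch, not the repo's actual entry point (which builds
# `args` through its own CLI). Field names mirror those referenced above; the
# config/checkpoint paths and speaker key are hypothetical, and TTSInference
# may require additional fields that the real CLI supplies.
#
#     from argparse import Namespace
#
#     cfg = load_config("egs/tts/FastSpeech2/exp_config.json")  # hypothetical path
#     args = Namespace(
#         mode="single",
#         text="Hello world.",
#         output_dir="result",
#         vocoder_dir="ckpts/vocoder",   # hypothetical path
#         speaker_name="LJSpeech",       # hypothetical spk2id key
#         pitch_control=1.0,
#         energy_control=1.0,
#         duration_control=1.0,
#     )
#     inference = FastSpeech2Inference(args, cfg)
#     audio = inference.inference_for_single_utterance()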