Spaces:
Build error
Build error
import torch | |
import os | |
class TTSInference:
    """Single-utterance text-to-speech inference wrapper.

    Loads a ``PortaSpeech`` acoustic model and a ``HifiGanGenerator`` vocoder
    from checkpoint directories, and exposes :meth:`infer_once`, which turns a
    text request into a waveform array.

    The checkpoint locations are class attributes so that a subclass can point
    at different checkpoints without re-implementing the loading logic. They
    are relative paths, so they are resolved against the current working
    directory.
    """

    # Directory holding the acoustic-model checkpoint and its ``config.yaml``.
    MODEL_DIR = 'text_to_speech/checkpoints/ljspeech/ps_adv_baseline'
    # Directory holding the phoneme/word dictionaries and the speaker map.
    DATA_DIR = 'text_to_speech/checkpoints/ljspeech/data_info'
    # Directory holding the vocoder checkpoint and its ``config.yaml``.
    VOCODER_DIR = 'text_to_speech/checkpoints/hifi_lj'

    def __init__(self, device=None):
        """Load the preprocessor, dictionaries, acoustic model and vocoder.

        :param device: torch device string (e.g. ``'cpu'`` or ``'cuda'``).
            When ``None``, ``'cuda'`` is used if available, otherwise ``'cpu'``.
        """
        # Project imports stay function-local, as in the original code, so that
        # importing this module does not pull in the whole package.
        from .tasks.tts.tts_utils import load_data_preprocessor
        from .utils.commons.hparams import set_hparams

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Report the resolved device rather than the raw (possibly None) argument.
        print(f"Initializing TTS model to {device}")

        self.hparams = set_hparams(f"{self.MODEL_DIR}/config.yaml")
        self.device = device
        self.data_dir = self.DATA_DIR
        self.preprocessor, self.preprocess_args = load_data_preprocessor()
        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
        self.spk_map = self.preprocessor.load_spk_map(self.data_dir)

        self.model = self.build_model()
        # build_model() already does this; repeating it is harmless and keeps
        # the guarantee even if a subclass overrides build_model().
        self.model.eval()
        self.model.to(self.device)

        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)
        print("TTS loaded!")

    def build_model(self):
        """Build the PortaSpeech model and load its checkpoint weights.

        :return: the model, moved to ``self.device`` and set to eval mode.
        """
        from .utils.commons.ckpt_utils import load_ckpt
        from .modules.tts.portaspeech.portaspeech import PortaSpeech

        ph_dict_size = len(self.ph_encoder)
        word_dict_size = len(self.word_encoder)
        model = PortaSpeech(ph_dict_size, word_dict_size, self.hparams)
        load_ckpt(model, self.MODEL_DIR, 'model')
        model.to(self.device)
        with torch.no_grad():
            # NOTE(review): presumably precomputes inverses used by the flow
            # layers at inference time - confirm against PortaSpeech.
            model.store_inverse_all()
        model.eval()
        return model

    def forward_model(self, inp):
        """Run the acoustic model and vocoder on one preprocessed item.

        :param inp: item dict as returned by :meth:`preprocess_input`.
        :return: the first (and only) waveform of the batch as a numpy array.
        """
        sample = self.input_to_batch(inp)
        # The vocoder call is inside no_grad as well: ``.numpy()`` below cannot
        # be called on a tensor that requires grad.
        with torch.no_grad():
            output = self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                infer=True,
                forward_post_glow=True,
                spk_id=sample.get('spk_ids'),
            )
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]

    def build_vocoder(self):
        """Build the HiFi-GAN generator and load its checkpoint weights.

        The vocoder config is loaded with ``global_hparams=False`` so that it
        does not replace the acoustic model's hyper-parameters.

        :return: the vocoder module (not yet moved to a device).
        """
        from .utils.commons.hparams import set_hparams
        from .modules.vocoder.hifigan.hifigan import HifiGanGenerator
        from .utils.commons.ckpt_utils import load_ckpt

        base_dir = self.VOCODER_DIR
        config_path = f'{base_dir}/config.yaml'
        config = set_hparams(config_path, global_hparams=False)
        vocoder = HifiGanGenerator(config)
        load_ckpt(vocoder, base_dir, 'model_gen')
        return vocoder

    def run_vocoder(self, c):
        """Convert a mel tensor into waveforms with the vocoder.

        :param c: 3-D tensor; dims 1 and 2 are swapped before the vocoder call
            (presumably (batch, frames, mel_bins) -> (batch, mel_bins, frames)
            - TODO confirm).
        :return: index 0 of the vocoder output's second dimension.
        """
        c = c.transpose(2, 1)
        y = self.vocoder(c)[:, 0]
        return y

    def preprocess_input(self, inp):
        """Turn a raw request into a tokenized item.

        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return: dict with keys ``item_name``, ``text``, ``ph``, ``spk_id``,
            ``ph_token``, ``word_token``, ``ph2word``, ``ph_words``, ``words``
            and ``ph_len``.
        :raises KeyError: if ``inp`` has no ``'text'`` or the speaker name is
            not in the speaker map.
        """
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
        text_raw = inp['text']
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', '<SINGLE_SPK>')
        ph, txt, word, ph2word, ph_gb_word = preprocessor.txt_to_ph(
            preprocessor.txt_processor, text_raw, preprocess_args)
        word_token = self.word_encoder.encode(word)
        ph_token = self.ph_encoder.encode(ph)
        spk_id = self.spk_map[spk_name]
        item = {
            'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id,
            'ph_token': ph_token, 'word_token': word_token, 'ph2word': ph2word,
            'ph_words': ph_gb_word, 'words': word,
        }
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Wrap one item into a batch of size 1 with tensors on ``self.device``.

        :param item: dict as returned by :meth:`preprocess_input`.
        :return: dict with list fields ``item_name``, ``text``, ``ph`` and
            tensor fields ``txt_tokens`` (1, n_ph), ``txt_lengths`` (1,),
            ``word_tokens`` (1, n_word), ``word_lengths`` (1,),
            ``ph2word`` (1, n_ph) and ``spk_ids`` (1,).
        """
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        word_tokens = torch.LongTensor(item['word_token'])[None, :].to(self.device)
        # The word length must come from the word tokens, not the phoneme tokens.
        word_lengths = torch.LongTensor([word_tokens.shape[1]]).to(self.device)
        ph2word = torch.LongTensor(item['ph2word'])[None, :].to(self.device)
        # spk_id is a scalar id; wrap it in a list because torch.LongTensor(int)
        # would allocate an uninitialized tensor of that *size* instead of a
        # tensor holding the id.
        # NOTE(review): shape (1,) assumes the model wants one id per batch
        # element - confirm against the model's speaker-embedding code.
        spk_ids = torch.LongTensor([item['spk_id']]).to(self.device)
        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'word_tokens': word_tokens,
            'word_lengths': word_lengths,
            'ph2word': ph2word,
            'spk_ids': spk_ids,
        }
        return batch

    def postprocess_output(self, output):
        """Hook for post-processing the waveform; returns it unchanged."""
        return output

    def infer_once(self, inp):
        """Synthesize speech for one request.

        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return: the result of :meth:`postprocess_output` applied to the
            waveform produced by :meth:`forward_model`.
        """
        inp = self.preprocess_input(inp)
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output