File size: 1,094 Bytes
53fa903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
# from inference.tts.fs import FastSpeechInfer
# from modules.tts.fs2_orig import FastSpeech2Orig
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams


class DiffSpeechInfer(BaseTTSInfer):
    """Inference wrapper for the shallow-diffusion ``GaussianDiffusion`` TTS model.

    Supplies the two hooks used by ``BaseTTSInfer``:
    ``build_model`` (construct the network and load its weights) and
    ``forward_model`` (text input -> waveform via the vocoder).
    """

    def build_model(self):
        """Construct a ``GaussianDiffusion`` model and load its checkpoint.

        The vocabulary size is taken from ``self.ph_encoder``. Both the model
        configuration and the checkpoint directory (``'work_dir'``) are read
        from ``self.hparams``, so the loaded weights always belong to the
        configuration this instance was created with.

        Returns:
            The model, switched to eval mode, with weights loaded via
            ``load_ckpt(model, work_dir, 'model')`` (the ``'model'`` argument
            presumably selects the checkpoint entry — confirm in ``load_ckpt``).
        """
        dict_size = len(self.ph_encoder)
        model = GaussianDiffusion(dict_size, self.hparams)
        model.eval()
        # Use the same config object that built the model; the module-level
        # ``hparams`` can differ when this instance was given another config,
        # which would load weights from the wrong work directory.
        load_ckpt(model, self.hparams['work_dir'], 'model')
        return model

    def forward_model(self, inp):
        """Synthesize a waveform for one input item.

        Args:
            inp: Raw input, converted to a model batch by
                ``self.input_to_batch`` (its exact schema is defined by
                ``BaseTTSInfer`` — TODO confirm).

        Returns:
            The first entry along the batch axis of the vocoder output,
            as a NumPy array.
        """
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        # Speaker ids are optional: ``.get`` yields None when the batch has none.
        spk_id = sample.get('spk_ids')
        with torch.no_grad():
            # No reference mel is supplied; ``infer=True`` requests the model's
            # inference path (see GaussianDiffusion.forward).
            output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True)
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        # Move to host memory before converting; ``.numpy()`` needs a CPU tensor.
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]

if __name__ == "__main__":
    # Executed as a script (not imported): hand off to the example driver,
    # which DiffSpeechInfer inherits from BaseTTSInfer.
    entry_point = DiffSpeechInfer.example_run
    entry_point()