import importlib
import os

import torch
from resemblyzer import VoiceEncoder

from data_gen.tts.data_gen_utils import is_sil_phoneme
from data_gen.tts.emotion import inference as EmotionEncoder
from data_gen.tts.emotion.inference import embed_utterance as Embed_utterance
from data_gen.tts.emotion.inference import preprocess_wav
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.GenerSpeech.model.generspeech import GenerSpeech
from utils import audio
from utils.ckpt_utils import load_ckpt


class GenerSpeechInfer(BaseTTSInfer):
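    """Inference wrapper for GenerSpeech: phonemizes the input text, extracts
    style (speaker/emotion) embeddings and alignments from a reference
    utterance, and synthesizes a waveform with the model and vocoder."""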
    def build_model(self):
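        # Build GenerSpeech over the phoneme vocabulary and restore weights
        # from the checkpoint saved under the experiment's work_dir.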
        model = GenerSpeech(self.ph_encoder)
        model.eval()
        load_ckpt(model, self.hparams['work_dir'], 'model')
        return model

    def preprocess_input(self, inp):
        """
        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return:
        """
        # processed text
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
        text_raw = inp['text']
        item_name = inp.get('item_name', '<ITEM_NAME>')
        ph, txt, word, ph2word, ph_gb_word = preprocessor.txt_to_ph(preprocessor.txt_processor, text_raw, preprocess_args)
        ph_token = self.ph_encoder.encode(ph)

        # processed ref audio
        ref_audio = inp['ref_audio']
        processed_ref_audio = 'example/temp.wav'
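        # Global style extractors: Resemblyzer provides the speaker (timbre)
        # embedding; the pretrained emotion encoder provides the emotion embedding.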
        voice_encoder = VoiceEncoder().cuda()
        encoder = [self.ph_encoder, self.word_encoder]
        EmotionEncoder.load_model(self.hparams['emotion_encoder_path'])
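        # Resolve the binarizer class from its dotted path so the reference
        # clip is featurized the same way as the training data.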
        binarizer_cls = self.hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg = ".".join(binarizer_cls.split(".")[:-1])
        cls_name = binarizer_cls.split(".")[-1]
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)

        ref_audio_raw, ref_text_raw = self.asr(ref_audio)  # transcribe the reference audio
        ph_ref, txt_ref, word_ref, ph2word_ref, ph_gb_word_ref = preprocessor.txt_to_ph(preprocessor.txt_processor, ref_text_raw, preprocess_args)
        ph_gb_word_nosil = ["_".join([p for p in w.split("_") if not is_sil_phoneme(p)]) for w in ph_gb_word_ref.split(" ") if not is_sil_phoneme(w)]
        phs_for_align = ['SIL'] + ph_gb_word_nosil + ['SIL']
        phs_for_align = " ".join(phs_for_align)

        # prepare files for alignment
        os.system('rm -rf example/; mkdir -p example/')  # recreate the scratch dir for alignment
        audio.save_wav(ref_audio_raw, processed_ref_audio, self.hparams['audio_sample_rate'])
        with open('example/temp.lab', 'w') as f_txt:
            f_txt.write(phs_for_align)
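        # Run forced alignment; assumes the Montreal Forced Aligner (`mfa`) CLI
        # is on PATH and the dictionary/acoustic model live in binary_data_dir.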
        os.system(f'mfa align example/ {self.hparams["binary_data_dir"]}/mfa_dict.txt {self.hparams["binary_data_dir"]}/mfa_model.zip example/textgrid/ --clean')
        item2tgfn = 'example/textgrid/temp.TextGrid'  # TextGrid produced by the aligner

        item = binarizer_cls.process_item(item_name, ph_ref, txt_ref, item2tgfn, processed_ref_audio, 0, 0, encoder, self.hparams['binarization_args'])
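        # Utterance-level style embeddings computed from the reference audio.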
        item['emo_embed'] = Embed_utterance(preprocess_wav(item['wav_fn']))
        item['spk_embed'] = voice_encoder.embed_utterance(item['wav'])

        item.update({
            'ref_ph': item['ph'],
            'ph': ph,
            'ph_token': ph_token,
            'text': txt
        })
        return item

    def input_to_batch(self, item):
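        # Collate the single processed item into a batch of size 1 on self.device.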
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]

        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        mels = torch.FloatTensor(item['mel'])[None, :].to(self.device)
        f0 = torch.FloatTensor(item['f0'])[None, :].to(self.device)
        # uv = torch.FloatTensor(item['uv']).to(self.device)
        mel2ph = torch.LongTensor(item['mel2ph'])[None, :].to(self.device)
        spk_embed = torch.FloatTensor(item['spk_embed'])[None, :].to(self.device)
        emo_embed = torch.FloatTensor(item['emo_embed'])[None, :].to(self.device)

        ph2word = torch.LongTensor(item['ph2word'])[None, :].to(self.device)
        mel2word = torch.LongTensor(item['mel2word'])[None, :].to(self.device)
        word_tokens = torch.LongTensor(item['word_tokens'])[None, :].to(self.device)

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'mels': mels,
            'f0': f0,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_embed': spk_embed,
            'emo_embed': emo_embed,
            'mel2ph': mel2ph,
            'ph2word': ph2word,
            'mel2word': mel2word,
            'word_tokens': word_tokens,
        }
        return batch

    def forward_model(self, inp):
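        # Run the model in inference mode and vocode the predicted mel-spectrogram;
        # global_steps=300000 mimics a fully-trained step count for any
        # step-gated components inside the model.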
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        with torch.no_grad():
            output = self.model(txt_tokens, ref_mel2ph=sample['mel2ph'], ref_mel2word=sample['mel2word'], ref_mels=sample['mels'],
                                spk_embed=sample['spk_embed'], emo_embed=sample['emo_embed'], global_steps=300000, infer=True)
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.squeeze().cpu().numpy()
        return wav_out


if __name__ == '__main__':
    inp = {
        'text': 'here we go',
        'ref_audio': 'assets/0011_001570.wav'
    }
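    # example_run (inherited from BaseTTSInfer) is expected to build the
    # inferer, call infer_once(inp), and save the synthesized wav to disk.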
    GenerSpeechInfer.example_run(inp)