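"""Base class for TTS inference.

Wires together an acoustic model, a HiFi-GAN vocoder, and a wav2vec 2.0
ASR model used to transcribe reference audio.
"""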
import os
from string import punctuation

import librosa
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from tasks.tts.dataset_utils import FastSpeechWordDataset
from tasks.tts.tts_utils import load_data_preprocessor
from utils.ckpt_utils import load_ckpt
from utils.hparams import set_hparams
from utils.hparams import hparams as hp
from vocoders.hifigan import HifiGanGenerator

class BaseTTSInfer:
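    """Shared inference scaffolding for TTS models.

    Subclasses implement `build_model`, `forward_model`, `preprocess_input`,
    and `input_to_batch`.
    """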
    def __init__(self, hparams, device=None):
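        """Load the dictionaries, build the model/vocoder/ASR, and move them to `device`."""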
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device
        self.data_dir = hparams['binary_data_dir']
        self.preprocessor, self.preprocess_args = load_data_preprocessor()
        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
        self.ds_cls = FastSpeechWordDataset
        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)
        self.asr_processor, self.asr_model = self.build_asr()

    def build_model(self):
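        """Build and return the acoustic model. Implemented by subclasses."""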
        raise NotImplementedError

    def forward_model(self, inp):
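        """Run the acoustic model on a preprocessed input. Implemented by subclasses."""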
        raise NotImplementedError

    def build_asr(self):
        # Load a pretrained wav2vec 2.0 CTC model for transcribing reference audio.
        # An alternative checkpoint is facebook/wav2vec2-large-960h-lv60-self.
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(self.device)
        return processor, model

    def build_vocoder(self):
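        """Build a HiFi-GAN generator from the checkpoint directory given by `vocoder_ckpt`."""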
        base_dir = self.hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        config = set_hparams(config_path, global_hparams=False)
        vocoder = HifiGanGenerator(config)
        load_ckpt(vocoder, base_dir, 'model_gen')
        return vocoder

    def run_vocoder(self, c):
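        """Run the vocoder on a mel batch; the [B, T, C] input is transposed to the
        [B, C, T] layout HifiGanGenerator expects, and the channel dim is squeezed."""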
        c = c.transpose(2, 1)
        y = self.vocoder(c)[:, 0]
        return y

    def preprocess_input(self, inp):
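        """Convert raw user input into a model-ready item. Implemented by subclasses."""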
        raise NotImplementedError

    def input_to_batch(self, item):
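        """Collate a single preprocessed item into a batch. Implemented by subclasses."""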
        raise NotImplementedError

    def postprocess_output(self, output):
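        """Identity by default; subclasses may override to post-process model output."""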
        return output

    def infer_once(self, inp):
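        """Run the full pipeline: preprocess -> forward model -> postprocess."""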
        inp = self.preprocess_input(inp)
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls, inp):
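        """Build an inference instance from the global hparams, synthesize `inp`,
        and save the resulting waveform under `infer_out/`."""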
        from utils.audio import save_wav

        infer_ins = cls(hp)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
        print(f'Saved to infer_out/{hp["text"]}.wav.')

    def asr(self, file):
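        """Load `file`, resample it to the model rate, and transcribe it with wav2vec 2.0.

        Returns the (possibly resampled) waveform and the transcription with
        trailing punctuation stripped.
        """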
        sample_rate = self.hparams['audio_sample_rate']
        audio_input, source_sample_rate = sf.read(file)

        # Resample to the model's expected rate if the source rate differs.
        if sample_rate is not None and source_sample_rate != sample_rate:
            audio_input = librosa.resample(audio_input, orig_sr=source_sample_rate, target_sr=sample_rate)

        # Mix down to mono if needed; wav2vec 2.0 expects single-channel input.
        if audio_input.ndim > 1:
            audio_input = audio_input.mean(axis=-1)

        # Pad the input values and return a PyTorch tensor.
        input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values

        # Retrieve logits (no gradients needed at inference) and take the argmax.
        with torch.no_grad():
            logits = self.asr_model(input_values.to(self.device)).logits
        predicted_ids = torch.argmax(logits, dim=-1)

        # Decode the predicted ids to text and strip trailing punctuation.
        transcription = self.asr_processor.decode(predicted_ids[0])
        transcription = transcription.rstrip(punctuation)
        return audio_input, transcription