import os
import subprocess

import torch
from torch import nn
from huggingface_hub import hf_hub_download
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM

import utils
from models import SynthesizerTrn
from text.mappers import TextMapper, preprocess_char, preprocess_text


class CombinedModel(nn.Module):
    """Speech-to-speech translation pipeline: speech-to-text (wav2vec2 CTC) ->
    text-to-text translation (seq2seq NMT) -> text-to-speech (MMS/VITS)."""

    def __init__(self, stt_model_name, nmt_model_name, tts_checkpoint_path=None,
                 huggingface_checkpoint=None, tts_checkpoint_name="G_100000.pth",
                 language="eng", device="cuda"):
        super().__init__()

        # Speech-to-text: wav2vec2 CTC model and its feature processor.
        self.stt_processor = Wav2Vec2Processor.from_pretrained(stt_model_name)
        self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
        # Text-to-text translation: seq2seq model and tokenizer.
        self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
        self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)

        self.language = language
        self.device = device
        self.tts_checkpoint_name = tts_checkpoint_name

        # Resolve the TTS generator weights, vocab, and config.
        if huggingface_checkpoint:
            # Assumes the Hub repo ships vocab.txt and config.json alongside the
            # generator weights, matching the layout of the MMS TTS checkpoints.
            g_pth = hf_hub_download(huggingface_checkpoint, tts_checkpoint_name)
            vocab_file = hf_hub_download(huggingface_checkpoint, "vocab.txt")
            config_file = hf_hub_download(huggingface_checkpoint, "config.json")
            self.tts_checkpoint_path = os.path.dirname(g_pth)
        else:
            if not tts_checkpoint_path:
                tts_checkpoint_path = self.download_mms_tts(self.language)
            self.tts_checkpoint_path = tts_checkpoint_path
            g_pth = os.path.join(self.tts_checkpoint_path, self.tts_checkpoint_name)
            vocab_file = os.path.join(self.tts_checkpoint_path, "vocab.txt")
            config_file = os.path.join(self.tts_checkpoint_path, "config.json")

        # Text-to-speech: VITS-style synthesizer configured from the checkpoint.
        self.hps = utils.get_hparams_from_file(config_file)
        self.text_mapper = TextMapper(vocab_file)
        self.tts_synth = SynthesizerTrn(
            len(self.text_mapper.symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model)
        _ = utils.load_checkpoint(g_pth, self.tts_synth, None)


    def forward(self, batch, *args, **kwargs):
        device = self.device

        # 1. Speech-to-text: transcribe the input audio with the wav2vec2 CTC model.
        audio = torch.tensor(batch["audio"][0]).to(device)
        input_features = self.stt_processor(
            audio, sampling_rate=16000, return_tensors="pt",
            max_length=110000, padding=True, truncation=True)
        stt_output = self.stt_model(
            input_features.input_values.to(device),
            attention_mask=input_features.attention_mask.to(device))
        predicted_ids = stt_output.logits.argmax(dim=-1)
        transcription = self.stt_processor.decode(torch.squeeze(predicted_ids))

        # 2. Text-to-text: translate the transcription with the NMT model.
        input_nmt_tokens = self.nmt_tokenizer(
            transcription, return_tensors="pt", padding=True, truncation=True)
        output_nmt_output = self.nmt_model.generate(
            input_ids=input_nmt_tokens.input_ids.to(device),
            attention_mask=input_nmt_tokens.attention_mask.to(device))
        decoded_nmt_output = self.nmt_tokenizer.batch_decode(
            output_nmt_output, skip_special_tokens=True)

        # 3. Text-to-speech: synthesize audio for the translated text.
        txt = preprocess_text(decoded_nmt_output[0], self.text_mapper, self.hps, lang=self.language)
        txt = self.text_mapper.get_text(txt, self.hps)

        x_tst = txt.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([txt.size(0)]).to(device)
        # Single-speaker model: no speaker embedding is passed to the synthesizer.
        generated_audio = self.tts_synth.infer(
            x_tst, x_tst_lengths, noise_scale=.667,
            noise_scale_w=0.8, length_scale=1.0
        )[0][0].detach().cpu()

        return transcription, decoded_nmt_output, generated_audio

    @staticmethod
    def download_mms_tts(lang, tgt_dir="./"):
        lang_fn, lang_dir = os.path.join(tgt_dir, lang + '.tar.gz'), os.path.join(tgt_dir, lang)
        # Skip the download if the checkpoint directory is already present.
        if os.path.exists(lang_dir):
            return lang_dir
        cmd = ";".join([
                f"wget https://dl.fbaipublicfiles.com/mms/tts/{lang}.tar.gz -O {lang_fn}",
                f"tar zxvf {lang_fn} -C {tgt_dir}"
        ])
        print(f"Downloading MMS TTS model for language: {lang}")
        subprocess.check_output(cmd, shell=True)
        print(f"Model checkpoints in {lang_dir}: {os.listdir(lang_dir)}")
        return lang_dir



# Usage
#model = CombinedModel("ak3ra/wav2vec2-sunbird-speech-lug", "Sunbird/sunbird-mul-en-mbart-merged", device="cpu")
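#
# A minimal end-to-end sketch. Assumptions: the input is a 16 kHz mono waveform,
# `soundfile` is installed, "sample_16khz.wav" is a hypothetical input file, and
# the two checkpoint names are taken from the example above.
#
#   import soundfile as sf
#
#   model = CombinedModel(
#       "ak3ra/wav2vec2-sunbird-speech-lug",
#       "Sunbird/sunbird-mul-en-mbart-merged",
#       device="cpu",
#   )
#   waveform, _ = sf.read("sample_16khz.wav")
#   transcription, translation, audio_out = model({"audio": [waveform]})
#   # Write the synthesized speech at the sampling rate given by the TTS config.
#   sf.write("translated_speech.wav", audio_out.squeeze().numpy(),
#            model.hps.data.sampling_rate)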