Spaces:
Runtime error
Runtime error
import os | |
from huggingface_hub import hf_hub_download | |
from models import SynthesizerTrn | |
import subprocess | |
import torch | |
from torch import nn | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM | |
from text.mappers import TextMapper, preprocess_char, preprocess_text | |
import utils | |
class CombinedModel(nn.Module):
    """Speech-to-speech pipeline chaining three models.

    1. A Wav2Vec2 CTC model transcribes input audio to text.
    2. A seq2seq model translates the transcription.
    3. A ``SynthesizerTrn`` text-to-speech model turns the translation into audio.

    The constructor does not move any sub-model to ``device``; ``forward`` only
    moves its *inputs* there. Callers must therefore call ``.to(device)`` on the
    module themselves so that weights and inputs live on the same device.
    """

    def __init__(self, stt_model_name, nmt_model_name, tts_checkpoint_path=False,
                 huggingface_checkpoint=False, tts_checkpoint_name="G_100000.pth",
                 language="eng", device="cuda"):
        """Load the speech-to-text, translation and text-to-speech models.

        Args:
            stt_model_name: Identifier passed to ``Wav2Vec2Processor.from_pretrained``
                and ``Wav2Vec2ForCTC.from_pretrained``.
            nmt_model_name: Identifier passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModelForSeq2SeqLM.from_pretrained``.
            tts_checkpoint_path: Local directory holding ``vocab.txt``,
                ``config.json`` and the weights file. When falsy (and no
                ``huggingface_checkpoint`` is given) the checkpoint for
                ``language`` is fetched with ``download_mms_tts``.
            huggingface_checkpoint: Hub repository id passed to
                ``hf_hub_download``. Takes precedence over ``tts_checkpoint_path``.
            tts_checkpoint_name: File name of the TTS weights.
            language: Language code used for the MMS download and for
                ``preprocess_text``.
            device: Device onto which ``forward`` moves its input tensors.
        """
        super().__init__()
        self.stt_processor = Wav2Vec2Processor.from_pretrained(stt_model_name)
        self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
        self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
        self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)
        self.language = language
        self.device = device
        # Stored so the weights file name stays inspectable after construction.
        self.tts_checkpoint_name = tts_checkpoint_name

        if huggingface_checkpoint:
            # In the hub case ``tts_checkpoint_path`` is the weights *file* itself.
            self.tts_checkpoint_path = hf_hub_download(huggingface_checkpoint, tts_checkpoint_name)
            vocab_file = hf_hub_download(huggingface_checkpoint, "vocab.txt")
            config_file = hf_hub_download(huggingface_checkpoint, "config.json")
            g_pth = self.tts_checkpoint_path
        else:
            # In the local case ``tts_checkpoint_path`` is a *directory*; fall back
            # to downloading the MMS checkpoint when no directory was supplied.
            self.tts_checkpoint_path = tts_checkpoint_path or download_mms_tts(self.language)
            g_pth, vocab_file, config_file = self._resolve_tts_files(
                self.tts_checkpoint_path, tts_checkpoint_name
            )

        self.hps = utils.get_hparams_from_file(config_file)
        self.text_mapper = TextMapper(vocab_file)
        self.tts_synth = SynthesizerTrn(
            len(self.text_mapper.symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model)
        # No optimizer is restored (third argument is None); only model weights.
        _ = utils.load_checkpoint(g_pth, self.tts_synth, None)

    @staticmethod
    def _resolve_tts_files(checkpoint_dir, checkpoint_name):
        """Return ``(weights, vocab, config)`` paths inside a local TTS checkpoint directory."""
        return (
            f"{checkpoint_dir}/{checkpoint_name}",
            f"{checkpoint_dir}/vocab.txt",
            f"{checkpoint_dir}/config.json",
        )

    def forward(self, batch, *args, **kwargs):
        """Transcribe, translate and re-synthesize the first audio clip in ``batch``.

        Args:
            batch: Mapping with an ``"audio"`` entry; only ``batch["audio"][0]``
                is used. The audio is fed to the processor with
                ``sampling_rate=16000`` (presumably 16 kHz mono samples — confirm
                against the caller).
            *args, **kwargs: Accepted and ignored.

        Returns:
            A tuple ``(transcription, decoded_nmt_output, generated_audio)``:
            the decoded STT output, the result of ``batch_decode`` on the
            translation, and the synthesized audio detached and moved to CPU.
        """
        device = self.device

        # Use stt_model to transcribe the audio to text
        audio = torch.tensor(batch["audio"][0]).to(self.device)
        input_features = self.stt_processor(
            audio, sampling_rate=16000, return_tensors="pt",
            max_length=110000, padding=True, truncation=True,
        )
        stt_output = self.stt_model(
            input_features.input_values.to(device),
            attention_mask=input_features.attention_mask.to(device),
        )
        # Greedy CTC decoding: take the arg-max token at every frame.
        transcription = self.stt_processor.decode(
            torch.squeeze(stt_output.logits.argmax(axis=-1)).to(device)
        )

        # Translate the transcription with the seq2seq model.
        input_nmt_tokens = self.nmt_tokenizer(
            transcription, return_tensors="pt", padding=True, truncation=True
        )
        output_nmt_output = self.nmt_model.generate(
            input_ids=input_nmt_tokens.input_ids.to(device),
            attention_mask=input_nmt_tokens.attention_mask.to(device),
        )
        decoded_nmt_output = self.nmt_tokenizer.batch_decode(
            output_nmt_output, skip_special_tokens=True
        )

        # Synthesize speech for the first translated sentence only.
        txt = preprocess_text(decoded_nmt_output[0], self.text_mapper, self.hps, lang=self.language)
        txt = self.text_mapper.get_text(txt, self.hps)
        x_tst = txt.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([txt.size(0)]).to(device)
        # No speaker embedding
        generated_audio = self.tts_synth.infer(
            x_tst, x_tst_lengths, noise_scale=.667,
            noise_scale_w=0.8, length_scale=1.0
        )[0][0].detach().cpu()

        return transcription, decoded_nmt_output, generated_audio
def download_mms_tts(lang, tgt_dir="./"):
    """Download and unpack the MMS-TTS checkpoint for ``lang`` unless already present.

    Args:
        lang: Language code; the archive fetched is ``{lang}.tar.gz``.
        tgt_dir: Directory in which the archive is stored and extracted.

    Returns:
        Path ``os.path.join(tgt_dir, lang)``. The archive is assumed to unpack
        into a top-level ``{lang}/`` directory — TODO confirm for new languages.

    Raises:
        subprocess.CalledProcessError: If ``wget`` or ``tar`` exits non-zero.
        FileNotFoundError: If ``wget`` or ``tar`` is not installed.
    """
    lang_dir = os.path.join(tgt_dir, lang)
    # Skip the download only when the checkpoint directory really exists.
    # (Testing the joined path string itself would always be truthy.)
    if os.path.isdir(lang_dir):
        return lang_dir

    extract_root = tgt_dir or "."
    os.makedirs(extract_root, exist_ok=True)
    lang_fn = os.path.join(tgt_dir, lang + '.tar.gz')
    url = f"https://dl.fbaipublicfiles.com/mms/tts/{lang}.tar.gz"

    print(f"Download model for language: {lang}")
    # Argument lists (no shell) so a failed download raises instead of being
    # masked by the following command, and ``lang`` is never shell-interpreted.
    subprocess.check_output(["wget", url, "-O", lang_fn])
    # Extract into tgt_dir rather than the current working directory so the
    # returned lang_dir is where the files actually are.
    subprocess.check_output(["tar", "zxvf", lang_fn, "-C", extract_root])
    print(f"Model checkpoints in {lang_dir}: {os.listdir(lang_dir)}")
    return lang_dir
# Usage | |
#model = CombinedModel("ak3ra/wav2vec2-sunbird-speech-lug", "Sunbird/sunbird-mul-en-mbart-merged", device="cpu") | |