# Captured from a Hugging Face Space; the page reported "Runtime error" at capture time.
import functools
import json
import os
import tempfile
import uuid

import gradio as gr
import librosa
import scipy.io.wavfile as wav
import soundfile as sf
import torch
from nemo.collections.asr.models import EncDecMultiTaskModel
from transformers import VitsModel, AutoTokenizer, set_seed
# Constants
# Sample rate (Hz) that input audio is converted to before it is handed to
# the ASR model.
SAMPLE_RATE = 16_000

# Load ASR model
# The checkpoint is fetched once at import time and shared by every request.
canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')

# Re-apply the model's own decoding config with the beam width set to 1.
decode_cfg = canary_model.cfg.decoding
decode_cfg.beam.beam_size = 1
canary_model.change_decoding_strategy(decode_cfg)
# Function to convert audio to text using ASR
def gen_text(audio_filepath, action, source_lang, target_lang):
    """Run the Canary model on one audio file and return its prediction.

    The audio is loaded as mono, resampled to ``SAMPLE_RATE`` when needed and
    written to a temporary WAV file. A one-entry JSON manifest describing the
    request is written next to it and passed to ``canary_model.transcribe``.

    Args:
        audio_filepath: Path of the input audio, or ``None`` if nothing was
            provided.
        action: Task name stored in the manifest. When it equals ``"asr"``
            the manifest's target language is forced to ``source_lang``.
        source_lang: Language value stored as ``source_lang``.
        target_lang: Language value stored as ``target_lang`` for any action
            other than ``"asr"``.

    Returns:
        The first item returned by ``canary_model.transcribe``.

    Raises:
        gr.Error: If ``audio_filepath`` is ``None``.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio.")

    clip_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as workdir:
        # Convert to 16 kHz: load at the file's native rate, then resample
        # only when that rate differs from what the model is fed.
        samples, native_rate = librosa.load(audio_filepath, sr=None, mono=True)
        if native_rate != SAMPLE_RATE:
            samples = librosa.resample(
                samples, orig_sr=native_rate, target_sr=SAMPLE_RATE
            )
        wav_path = os.path.join(workdir, f"{clip_id}.wav")
        sf.write(wav_path, samples, SAMPLE_RATE)

        # Plain transcription keeps the output in the source language.
        if action == "asr":
            manifest_target = source_lang
        else:
            manifest_target = target_lang

        seconds = len(samples) / SAMPLE_RATE
        manifest_entry = {
            "audio_filepath": wav_path,
            "taskname": action,
            "source_lang": source_lang,
            "target_lang": manifest_target,
            "pnc": "no",
            "answer": "predict",
            "duration": str(seconds),
        }

        manifest_path = os.path.join(workdir, f"{clip_id}.json")
        with open(manifest_path, 'w') as manifest_file:
            json.dump(manifest_entry, manifest_file)

        # Transcribe or translate while the temporary files still exist.
        return canary_model.transcribe(manifest_path)[0]
@functools.lru_cache(maxsize=4)
def _load_tts(model_id):
    """Load the VITS model and tokenizer for ``model_id``, caching the pair.

    Caching avoids re-reading the checkpoint on every request; the cache is
    bounded so it holds at most four language models at once.
    """
    tts_model = VitsModel.from_pretrained(model_id)
    tts_tokenizer = AutoTokenizer.from_pretrained(model_id)
    return tts_model, tts_tokenizer


# Function to convert text to speech using TTS
def gen_speech(text, lang):
    """Synthesize ``text`` with the ``facebook/mms-tts-<lang>`` checkpoint.

    Args:
        text: Text to speak.
        lang: Language suffix of the MMS-TTS checkpoint name (the value is
            inserted verbatim into ``facebook/mms-tts-{lang}``).

    Returns:
        A ``(sample_rate, waveform)`` tuple, where ``sample_rate`` is read
        from the loaded model's config and ``waveform`` is the first
        waveform of the model output as a NumPy array.

    Raises:
        gr.Error: If ``text`` is an empty or whitespace-only string.
    """
    # Nothing to synthesize; report it to the user instead of running the
    # model on blank input.
    if isinstance(text, str) and not text.strip():
        raise gr.Error("No text was produced to synthesize.")

    tts_model, tts_tokenizer = _load_tts(f"facebook/mms-tts-{lang}")
    input_text = tts_tokenizer(text, return_tensors="pt")

    # Seed immediately before inference so the result is deterministic
    # whether or not the model was just loaded or came from the cache.
    set_seed(555)
    with torch.no_grad():
        outputs = tts_model(**input_text)
    waveform_np = outputs.waveform[0].cpu().numpy()

    # Report the rate the TTS model actually generates at rather than the
    # ASR input rate constant.
    return tts_model.config.sampling_rate, waveform_np
# Canary language codes keyed by the display names offered in the UI.
_CANARY_LANG_CODES = {
    "English": "en",
    "German": "de",
    "Spanish": "es",
    "French": "fr",
}


# Main function for speech-to-speech translation
def speech_to_speech_translation(audio_filepath, source_lang, target_lang):
    """Translate spoken audio and return the translation as synthesized speech.

    The language arguments may be the display names used by the UI
    (for example ``"English"``); they are converted to the code each model
    expects. Values that are not known display names are passed through
    unchanged, so callers that already supply codes keep working.

    Args:
        audio_filepath: Path of the input audio, or ``None``.
        source_lang: Language spoken in the input audio.
        target_lang: Language of the generated speech.

    Returns:
        The ``(sample_rate, waveform)`` tuple produced by ``gen_speech``.
    """
    # The ASR model and the TTS checkpoints use different code schemes, so
    # each display name is resolved separately for each of them.
    asr_source = _CANARY_LANG_CODES.get(source_lang, source_lang)
    asr_target = _CANARY_LANG_CODES.get(target_lang, target_lang)
    tts_lang = LANGUAGES.get(target_lang, target_lang)

    # Same language on both sides is plain transcription, not translation.
    # NOTE(review): the Canary-1B model card lists translation only between
    # English and German/Spanish/French; other pairs are forwarded as-is --
    # confirm whether they should be rejected.
    action = "asr" if asr_source == asr_target else "s2t_translation"

    translation = gen_text(audio_filepath, action, asr_source, asr_target)
    return gen_speech(translation, tts_lang)
# Define supported languages
# Maps the display name shown in the UI to a three-letter language code.
LANGUAGES = dict(
    English="eng",
    German="deu",
    Spanish="spa",
    French="fra",
)
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multilingual Speech-to-Speech Translation")
    gr.Markdown("Translate speech from one language to another.")

    # Language pickers; both offer the display names from LANGUAGES.
    with gr.Row():
        source_lang = gr.Dropdown(
            label="Source Language",
            value="English",
            choices=list(LANGUAGES),
        )
        target_lang = gr.Dropdown(
            label="Target Language",
            value="French",
            choices=list(LANGUAGES),
        )

    # One tab per input method, each with its own output player and button.
    # NOTE(review): the `source=` keyword of gr.Audio depends on the installed
    # Gradio version -- confirm it matches the version this app is deployed on.
    with gr.Tabs():
        with gr.TabItem("Microphone"):
            mic_input = gr.Audio(source="microphone", type="filepath")
            mic_output = gr.Audio(label="Generated Speech", type="numpy")
            mic_button = gr.Button("Translate")
        with gr.TabItem("Audio File"):
            file_input = gr.Audio(source="upload", type="filepath")
            file_output = gr.Audio(label="Generated Speech", type="numpy")
            file_button = gr.Button("Translate")

    # Both buttons run the same pipeline; only the audio widgets differ.
    for _button, _audio_in, _audio_out in (
        (mic_button, mic_input, mic_output),
        (file_button, file_input, file_output),
    ):
        _button.click(
            speech_to_speech_translation,
            inputs=[_audio_in, source_lang, target_lang],
            outputs=_audio_out,
        )

demo.launch()