# app.py by susnato (revision cb99941, "Update app.py", 2.69 kB)
import torch
import numpy as np
import gradio as gr
from transformers import AutoProcessor, SpeechT5ForTextToSpeech, pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, SpeechT5HifiGan
from datasets import load_dataset
# All models and tensors in this file are placed on this device.
device = "cpu"

# Speech recognition: Whisper base checkpoint wrapped in an ASR pipeline.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    device=device,
)

# Text-to-speech: the processor and the model are loaded from the same
# checkpoint; the HiFi-GAN vocoder is handed to generate_speech later on.
_tts_checkpoint = "susnato/speecht5_finetuned_voxpopuli_nl"
tts_processor = AutoProcessor.from_pretrained(_tts_checkpoint)
tts_model = SpeechT5ForTextToSpeech.from_pretrained(_tts_checkpoint).to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Speaker embedding: entry 7306 of the validation split, converted to a
# tensor with an extra leading dimension.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
_xvector = embeddings_dataset[7306]["xvector"]
speaker_embeddings = torch.tensor(_xvector).unsqueeze(0)
def transcribe(audio):
    """Run the ASR pipeline on ``audio`` and return the decoded text.

    The pipeline is called with task ``"transcribe"``, language ``"nl"``,
    caching enabled and a limit of 128 new tokens.

    Args:
        audio: Input forwarded unchanged to ``asr_pipe``. The Gradio audio
            inputs in this file are configured with ``type="filepath"``.

    Returns:
        The ``"text"`` entry of the pipeline output.
    """
    generation_options = {
        "task": "transcribe",
        "language": "nl",
        "use_cache": True,
        "max_new_tokens": 128,
    }
    result = asr_pipe(audio, generate_kwargs=generation_options)
    return result["text"]
def synthesise(text):
    """Generate speech for ``text`` with the text-to-speech model.

    Args:
        text: Text to speak; it is tokenised with truncation enabled.

    Returns:
        The generated speech moved to the CPU and converted to a NumPy array.
    """
    encoded = tts_processor(text=text, truncation=True, return_tensors="pt")
    token_ids = encoded["input_ids"].to(device)
    voice = speaker_embeddings.to(device)
    waveform = tts_model.generate_speech(token_ids, voice, vocoder=vocoder)
    return waveform.cpu().numpy()
def speech_to_dutch_translation(audio):
    """Convert input speech into synthesised Dutch speech.

    The audio goes through ``transcribe`` and the resulting text through
    ``synthesise``.

    Args:
        audio: Input forwarded unchanged to ``transcribe``.

    Returns:
        A ``(16000, samples)`` tuple where ``samples`` is the waveform
        multiplied by 32767 and cast to ``int16``. The first element is
        presumably the sample rate in Hz expected by the Gradio audio
        output; verify against the TTS model configuration.
    """
    waveform = synthesise(transcribe(audio))
    # Scale to the 16-bit integer range; assumes the waveform lies in
    # [-1, 1] (TODO confirm against the vocoder output).
    int16_peak = 32767
    pcm = (waveform * int16_peak).astype(np.int16)
    return 16_000, pcm
# Heading and description passed to both Gradio interfaces.
# The pipeline in this file produces Dutch ("nl") speech, so the title names
# Dutch rather than Hindi.
title = "Speech-To-Speech-Translation for Dutch"
description = """
![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
"""
demo = gr.Blocks()


def _build_interface(source):
    """Create a translation interface whose audio input uses ``source``."""
    return gr.Interface(
        fn=speech_to_dutch_translation,
        inputs=gr.Audio(source=source, type="filepath"),
        outputs=gr.Audio(label="Generated Speech", type="numpy"),
        title=title,
        description=description,
    )


mic_translate = _build_interface("microphone")
# NOTE: examples=["./example.wav"] is disabled for the upload interface.
file_translate = _build_interface("upload")

# Show both interfaces as tabs inside the Blocks container.
with demo:
    gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])

demo.launch(debug=False)