import os
import spaces
import torch
from transformers import pipeline, WhisperTokenizer
import torchaudio
import gradio as gr
# Note: the import below overrides Whisper's LANGUAGES mapping to add Bambara.
# This is not the cleanest approach, but it works; see bambara_utils for details.
from bambara_utils import BambaraWhisperTokenizer
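
# For reference, a minimal sketch of the kind of override bambara_utils is assumed to
# perform (not verified against that module): registering Bambara in the language maps
# that transformers' Whisper tokenizer consults when resolving the language token.
#
#   from transformers.models.whisper import tokenization_whisper
#   tokenization_whisper.LANGUAGES["bm"] = "bambara"
#   tokenization_whisper.TO_LANGUAGE_CODE["bambara"] = "bm"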

# Determine the appropriate device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the model checkpoint and language
# model_checkpoint = "oza75/whisper-bambara-asr-002"
# revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
model_checkpoint = "oza75/whisper-bambara-asr-005"
revision = "6a92cd0f19985d12739c2f6864607627115e015d"  # first good checkpoint for Bambara
# revision = "fb69a5750182933868397543366dbb63747cf40c"  # this one only translates into English
# revision = "129f9e68ead6cc854e7754b737b93aa78e0e61e1"  # supports transcription and translation
# revision = "cb8e351b35d6dc524066679d9646f4a947300b27"
# revision = "5f143f6070b64412a44fea08e912e1b7312e9ae9"  # supports both tasks without overfitting
# model_checkpoint = "oza75/whisper-bambara-asr-006"
# revision = "96535debb4ce0b7af7c9c186d09d088825f63840"
# revision = "4549778c08f29ed2e033cc9a497a187488b6bf56"

# language = "bambara"
# Icelandic is used here because the model was fine-tuned with Bambara data in place of Icelandic.
language = "icelandic"

# Load the tokenizer and the ASR pipeline.
# The custom Bambara tokenizer is kept for reference but is not needed with this checkpoint.
# tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)


def resample_audio(audio_path, target_sample_rate=16000):
    """
    Resamples the audio file to the target sampling rate (16000 Hz by default).

    Args:
        audio_path (str): Path to the audio file.
        target_sample_rate (int): The desired sample rate.

    Returns:
        A tuple of (resampled waveform tensor, target sample rate).
    """
    waveform, original_sample_rate = torchaudio.load(audio_path)

    if original_sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    return waveform, target_sample_rate
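

# The Space header reports "Running on Zero" and the `spaces` package is imported above
# but never used. On ZeroGPU Spaces, the GPU-bound entry point is normally wrapped with
# the `spaces.GPU` decorator so a GPU is attached for the duration of the call. Adding
# it here is an assumption based on that header, not something present in the original.
@spaces.GPU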
def transcribe(audio, task_type):
    """
    Transcribes the provided audio file into text using the configured ASR pipeline.

    Args:
        audio: Path to the audio file to transcribe.
        task_type: The generation task, e.g. "transcribe".

    Returns:
        A string containing the transcribed text.
    """
    # Resample the audio to 16000 Hz
    waveform, sample_rate = resample_audio(audio)

    # Run the pipeline on the resampled waveform
    sample = {"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate}
    text = pipe(sample, generate_kwargs={"task": task_type, "language": language})["text"]

    return text
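
# Example usage outside Gradio (hypothetical path; any local .wav file works):
#   print(transcribe("./examples/sample.wav", "transcribe"))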


def get_wav_files(directory):
    """
    Builds Gradio example entries for all .wav files in the specified directory.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: A list of [absolute_path, "transcribe"] pairs, one per .wav file.
    """
    # List all files in the directory
    files = os.listdir(directory)

    # Filter for .wav files and build absolute paths
    wav_files = [os.path.abspath(os.path.join(directory, file)) for file in files if file.endswith('.wav')]

    # Pair each file with the default task so it can be used directly as a Gradio example
    wav_files = [[f, "transcribe"] for f in wav_files]

    return wav_files


def main():
    # Collect all .wav files in the examples directory to use as Gradio examples
    example_files = get_wav_files("./examples")

    # Set up the Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(type="filepath", value=example_files[0][0]),
            gr.Radio(choices=["transcribe"], label="Task Type", value="transcribe"),
        ],
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Real-time demo for Bambara speech recognition based on a fine-tuned Whisper model.",
        examples=example_files,
        cache_examples="lazy",
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()