import os
import spaces
import torch
from transformers import pipeline, WhisperTokenizer
import torchaudio
import gradio as gr
# NOTE: the import below monkey-patches Whisper's LANGUAGES mapping to add Bambara.
# This is not the cleanest approach, but it works; see bambara_utils for details.
from bambara_utils import BambaraWhisperTokenizer
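
# For context, the override in bambara_utils presumably amounts to something like the
# sketch below (an illustrative assumption, not the actual module; check bambara_utils):
#
#   from transformers.models.whisper.tokenization_whisper import LANGUAGES, TO_LANGUAGE_CODE
#   LANGUAGES["bm"] = "bambara"          # register Bambara under its ISO 639-1 code
#   TO_LANGUAGE_CODE["bambara"] = "bm"   # so language="bambara" resolves to a token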
# Determine the appropriate device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Define the model checkpoint and language
#model_checkpoint = "oza75/whisper-bambara-asr-002"
#revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
model_checkpoint = "oza75/whisper-bambara-asr-005"
revision = "6a92cd0f19985d12739c2f6864607627115e015d"  # first good checkpoint for Bambara
#revision = "fb69a5750182933868397543366dbb63747cf40c"  # this one only translates to English
#revision = "129f9e68ead6cc854e7754b737b93aa78e0e61e1"  # supports both transcription and translation
#revision = "cb8e351b35d6dc524066679d9646f4a947300b27"
#revision = "5f143f6070b64412a44fea08e912e1b7312e9ae9"  # supports both tasks without overfitting
#model_checkpoint = "oza75/whisper-bambara-asr-006"
#revision = "96535debb4ce0b7af7c9c186d09d088825f63840"
#revision = "4549778c08f29ed2e033cc9a497a187488b6bf56"
# language = "bambara"
# We pass "icelandic" because the model was fine-tuned with Bambara data mapped onto the Icelandic language token.
language = "icelandic"
# Load the tokenizer (the custom Bambara tokenizer is kept for reference) and build the ASR pipeline.
# Note: tokenizers run on CPU, so no device argument is needed here.
#tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language)
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language=language)
pipe = pipeline("automatic-speech-recognition", model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)


def resample_audio(audio_path, target_sample_rate=16000):
    """
    Loads an audio file and resamples it to the target sample rate (16000 Hz by default).

    Args:
        audio_path (str): Path to the audio file.
        target_sample_rate (int): The desired sample rate.

    Returns:
        tuple: The resampled waveform tensor and the target sample rate.
    """
waveform, original_sample_rate = torchaudio.load(audio_path)
if original_sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
waveform = resampler(waveform)
return waveform, target_sample_rate
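
# Example usage (hypothetical file path; any torchaudio-readable format works):
#   waveform, sr = resample_audio("./examples/sample.wav")
#   assert sr == 16000 and waveform.dim() == 2  # (channels, samples)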


@spaces.GPU()
def transcribe(audio, task_type):
    """
    Transcribes the provided audio file into text using the configured ASR pipeline.

    Args:
        audio (str): Path to the audio file to transcribe.
        task_type (str): The generation task, e.g. "transcribe".

    Returns:
        str: The transcribed text.
    """
    # Resample the audio to 16000 Hz, the rate Whisper expects
    waveform, sample_rate = resample_audio(audio)
    # The pipeline accepts a dict with a 1-D numpy array and its sampling rate.
    # Note: squeeze() assumes mono input; multi-channel audio would remain 2-D.
    sample = {"array": waveform.squeeze().numpy(), "sampling_rate": sample_rate}
    text = pipe(sample, generate_kwargs={"task": task_type, "language": language})["text"]
    return text
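
# Quick sanity check outside Gradio (hypothetical path; runs the full pipeline):
#   print(transcribe("./examples/sample.wav", "transcribe"))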


def get_wav_files(directory):
    """
    Builds Gradio example entries from all .wav files in the specified directory.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: [absolute_path, "transcribe"] pairs, one per .wav file.
    """
    # List all files in the directory, keep only .wav files, and make their paths absolute
    files = os.listdir(directory)
    wav_files = [os.path.abspath(os.path.join(directory, file)) for file in files if file.endswith('.wav')]
    # Pair each file with the default task so each entry matches the interface's two inputs
    wav_files = [[f, "transcribe"] for f in wav_files]
    return wav_files


def main():
    # Gather example entries from the ./examples directory
    example_files = get_wav_files("./examples")

    # Set up the Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(type="filepath", value=example_files[0][0]),
            gr.Radio(choices=["transcribe"], label="Task Type", value="transcribe"),
        ],
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Real-time demo for Bambara speech recognition, based on a fine-tuned Whisper model.",
        examples=example_files,
        cache_examples="lazy",
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()