# Hugging Face Spaces status: Runtime error
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, WhisperTokenizer
# Hugging Face Hub repository id of the fine-tuned (LoRA-merged) Whisper checkpoint.
model_name = "userdata/whisper-largeV2-03-ms-v11-LORA-Merged"

# Choose the device before loading so the weights can be created directly in a
# dtype suited to it: float16 halves weight memory and is fast on GPU, whereas
# float16 on CPU is slow (and some kernels lack it in older PyTorch releases),
# so CPU inference stays in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load the processor (audio feature extraction), the tokenizer (decoding) and the model.
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name)
# Move the weights onto `device` so they live where the input features are sent;
# leaving them on the CPU while inputs go to CUDA makes generate() fail.
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name, torch_dtype=torch_dtype).to(device)
model.eval()  # inference only: disable dropout
def chunk_audio(audio, chunk_length):
    """Split ``audio`` into consecutive slices of ``chunk_length`` items.

    Every slice holds exactly ``chunk_length`` items except possibly the last,
    which holds the remainder. An empty input yields an empty list.

    Args:
        audio: Any sliceable sequence with ``len()`` (e.g. a NumPy sample array).
        chunk_length: Number of items per slice.

    Returns:
        A list of slices of ``audio`` in their original order.
    """
    # Ceiling division: one extra slice when the length is not an exact multiple.
    n_chunks = -(-len(audio) // chunk_length)
    offsets = (index * chunk_length for index in range(n_chunks))
    return [audio[offset:offset + chunk_length] for offset in offsets]
def transcribe(audio_path: Optional[str], chunk_length: int = 16000 * 30):  # 30-second chunks at 16 kHz
    """Transcribe an audio file using the module-level Whisper model.

    The file is decoded with torchaudio, down-mixed to mono by averaging its
    channels, resampled to 16 kHz and split into ``chunk_length``-sample
    windows (30 s by default). Whisper's feature extractor pads or truncates
    each input to a fixed window, so longer audio is split first. Each window
    is transcribed independently. Non-empty results are stripped and joined
    with single spaces.

    Args:
        audio_path: Path to the audio file, as supplied by
            ``gr.Audio(type="filepath")``. ``None`` (the user submitted without
            uploading anything) yields an empty string.
        chunk_length: Window size in samples at 16 kHz.

    Returns:
        The concatenated transcription text.
    """
    if audio_path is None:
        return ""

    target_rate = 16000

    # torchaudio.load returns (channels, frames). Average the channels: squeeze()
    # alone leaves stereo audio as a 2-row array, which chunk_audio would then
    # split along the channel axis instead of along time.
    speech_array, sampling_rate = torchaudio.load(audio_path)
    mono = speech_array.mean(dim=0)

    resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
    speech = resampler(mono).numpy()

    # Features must match the weights' device and dtype (float16 on GPU,
    # float32 on CPU); hard-coding float16 fails against float32 weights.
    model_device = model.device
    model_dtype = model.dtype

    transcriptions = []
    for chunk in chunk_audio(speech, chunk_length):
        inputs = processor(chunk, sampling_rate=target_rate, return_tensors="pt")
        # Only the log-mel features are passed to generate(); convert just those.
        input_features = inputs["input_features"].to(model_device, dtype=model_dtype)
        with torch.no_grad():
            generated_ids = model.generate(input_features, max_length=448)
        # Decoded Whisper text usually starts with a space; strip it so the
        # join below does not produce doubled spaces, and drop empty chunks.
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
        if text:
            transcriptions.append(text)

    return " ".join(transcriptions)
# Web UI: an audio upload (handed to transcribe() as a file path) and a
# textbox showing the returned transcription.
audio_input = gr.Audio(type="filepath")
text_output = gr.Textbox()

iface = gr.Interface(
    fn=transcribe,
    inputs=audio_input,
    outputs=text_output,
    title="Audio Transcription App",
    description="Upload an audio file to get a transcription.",
)

# Start the Gradio app.
iface.launch()