import spaces
import torch
import gradio as gr
import numpy as np
import os
import uuid
import scipy.io.wavfile
import time
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
import subprocess
# Install FlashAttention at runtime. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# the CUDA availability check during the build step (a common Hugging Face
# Spaces pattern, since the build environment has no GPU).
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 requires a GPU; fall back to fp32 on CPU so the pipeline still runs.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
MODEL_NAME = "openai/whisper-large-v3-turbo"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    # FlashAttention 2 is CUDA-only; use PyTorch SDPA when no GPU is available.
    attn_implementation="flash_attention_2" if device == "cuda" else "sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME, language="en")

pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=25,  # short generations keep per-chunk latency low for streaming
    torch_dtype=torch_dtype,
    device=device,
)
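# Quick sanity check for the pipeline (hypothetical file path, not used by the app):
#   print(pipe("example.wav")["text"])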
@spaces.GPU
def transcribe(inputs, previous_transcription):
    start_time = time.time()
    try:
        sample_rate, audio_data = inputs
        duration = len(audio_data) / sample_rate  # duration in seconds

        if duration > 5:
            # Split the audio into 5-second chunks so each pipeline call stays short.
            chunk_size = 5 * sample_rate  # number of samples per 5-second chunk
            num_chunks = int(np.ceil(len(audio_data) / chunk_size))
            transcriptions = []
            for i in range(num_chunks):
                start_index = i * chunk_size
                end_index = min(start_index + chunk_size, len(audio_data))
                chunk_data = audio_data[start_index:end_index]

                # Write the chunk to a temporary file, transcribe it, then delete it.
                chunk_filename = f"{uuid.uuid4().hex}_chunk.wav"
                scipy.io.wavfile.write(chunk_filename, sample_rate, chunk_data)
                try:
                    transcriptions.append(pipe(chunk_filename)["text"])
                finally:
                    os.remove(chunk_filename)
            # Combine all chunk transcriptions.
            previous_transcription += " ".join(transcriptions)
        else:
            # Audio is 5 seconds or less: transcribe it in one pass.
            filename = f"{uuid.uuid4().hex}.wav"
            scipy.io.wavfile.write(filename, sample_rate, audio_data)
            try:
                previous_transcription += pipe(filename)["text"]
            finally:
                os.remove(filename)

        latency = time.time() - start_time
        return previous_transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during transcription: {e}")
        return previous_transcription, "Error"
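# Note: with a numpy-typed gr.Audio input, Gradio passes `inputs` as a
# (sample_rate, numpy_array) tuple, which is what transcribe() unpacks above.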
def clear():
    return ""
with gr.Blocks() as microphone:
    with gr.Column():
        gr.Markdown(f"# Realtime Whisper Large V3 Turbo\nTranscribe audio in realtime. This demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n\nNote: the first token takes about 5 seconds to appear; transcription is smooth after that.")
        with gr.Row():
            input_audio_microphone = gr.Audio(sources=["microphone"], streaming=True)
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            clear_button = gr.Button("Clear Output")
        # stream_every=2 sends a new audio chunk every 2 seconds; time_limit=45
        # caps each streaming session at 45 seconds.
        input_audio_microphone.stream(transcribe, [input_audio_microphone, output], [output, latency_textbox], time_limit=45, stream_every=2, concurrency_limit=None)
        clear_button.click(clear, outputs=[output])
with gr.Blocks() as file_transcribe:
    with gr.Column():
        gr.Markdown(f"# Realtime Whisper Large V3 Turbo\nTranscribe audio in realtime. This demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n\nNote: the first token takes about 5 seconds to appear; transcription is smooth after that.")
        with gr.Row():
            input_audio_file = gr.Audio(sources=["upload"], type="numpy")
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
        with gr.Row():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Output")
        submit_button.click(transcribe, [input_audio_file, output], [output, latency_textbox], concurrency_limit=None)
        clear_button.click(clear, outputs=[output])
with gr.Blocks() as demo:
    gr.TabbedInterface([microphone, file_transcribe], ["Microphone", "Audio file"])

demo.launch()
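# When self-hosting outside Spaces (an assumption, not part of the original
# deployment), the server may need to bind to all interfaces, e.g.:
#   demo.launch(server_name="0.0.0.0")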