Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import pathlib
import string
import time
from multiprocessing import Lock, Pool, Value

import torch
import whisper
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
# Shared progress state used by asr_wav_files: a multiprocessing Value holding
# the number of finished files, and a Lock held around its read-modify-write.
# NOTE(review): pool workers inherit these exact objects only under the "fork"
# start method; with "spawn"/"forkserver" each worker re-imports this module
# and gets its own independent copy unless the objects are handed to the
# workers explicitly (e.g. through a Pool initializer).
processed_files_count = Value("i", 0) # count of processed files
lock = Lock() # lock for the count
def preprocess_text(text):
    """Normalize an ASR transcript.

    Lower-cases *text*, then deletes every character listed in
    ``string.punctuation``; whitespace and all other characters are kept.
    """
    drop_punctuation = str.maketrans("", "", string.punctuation)
    lowered = text.lower()
    return lowered.translate(drop_punctuation)
def transcribe_audio(model, processor, audio_file, device, max_new_tokens=128):
    """Transcribe one audio file and return the normalized text.

    The audio is loaded and padded/trimmed with the ``whisper`` helpers,
    turned into a log-Mel spectrogram, decoded with ``model.generate`` and
    finally normalized with ``preprocess_text``.

    Args:
        model: speech seq2seq model exposing ``generate`` and ``dtype``
            (as returned by ``init_whisper``).
        processor: processor whose ``batch_decode`` turns token ids into text.
        audio_file: path of the audio file to transcribe.
        device: device the model lives on, e.g. ``"cuda:0"`` or ``"cpu"``.
        max_new_tokens: upper bound on generated tokens (defaults to the
            previously hard-coded 128).

    Returns:
        The lower-cased, punctuation-free transcription.
    """
    audio = whisper.load_audio(audio_file)  # load from path
    audio = whisper.pad_or_trim(audio)  # default 30 seconds
    mel = whisper.log_mel_spectrogram(audio)
    # Add a batch dimension and match the model's device AND dtype. The old
    # `.type(torch.cuda.HalfTensor)` forced a CUDA fp16 tensor, which fails on
    # the CPU fallback (no CUDA, float32 model).
    inputs = mel.unsqueeze(0).to(device=device, dtype=model.dtype)
    outputs = model.generate(inputs=inputs, max_new_tokens=max_new_tokens)
    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return preprocess_text(transcription)
def write_transcription(audio_file, transcription):
    """Write *transcription* next to *audio_file* as a UTF-8 ``.txt`` file.

    The output path is *audio_file* with its suffix replaced by ``.txt``
    (``dir/clip.wav`` -> ``dir/clip.txt``); an existing file is overwritten.

    Args:
        audio_file: path (``str`` or ``pathlib.Path``) of the source audio.
        transcription: text to write.
    """
    # Path() also accepts plain strings, not only Path objects.
    txt_file = pathlib.Path(audio_file).with_suffix(".txt")
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) cannot
    # encode every character a transcription may contain.
    txt_file.write_text(transcription, encoding="utf-8")
def init_whisper(model_id, device):
    """Load a speech seq2seq model and its processor onto *device*.

    Args:
        model_id: Hugging Face model id or local path,
            e.g. ``"distil-whisper/distil-large-v2"``.
        device: target device, e.g. ``"cuda:0"`` or ``"cpu"``.

    Returns:
        ``(model, processor)`` with the model already moved to *device*.
    """
    # Choose the dtype from the device the model will actually run on: fp16 on
    # CUDA, fp32 otherwise. Checking torch.cuda.is_available() instead would
    # load fp16 weights for an explicit "cpu" device on a CUDA machine.
    on_cuda = torch.device(device).type == "cuda"
    torch_dtype = torch.float16 if on_cuda else torch.float32
    print(f"Loading model {model_id}")
    # NOTE(review): use_safetensors=False makes transformers load the
    # pickle-based PyTorch checkpoint; only load model ids you trust.
    distil_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=False
    )
    distil_model = distil_model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return distil_model, processor
def asr_wav_files(file_list, gpu_id, total_files, model_id):
    """Transcribe a list of audio files on a single device (pool worker).

    The model is loaded once with ``init_whisper`` on ``cuda:<gpu_id>``, or on
    ``cpu`` when CUDA is unavailable. Each file is transcribed with
    ``transcribe_audio`` and saved next to the audio by ``write_transcription``.

    Progress is counted in the module-level ``processed_files_count`` while
    holding ``lock``. On every 5th processed file a status line is printed.
    Its ETA is extrapolated from this worker's elapsed time and the counter
    value. An exception raised for a single file is printed and the loop
    moves on to the next file.

    Args:
        file_list: paths of the audio files handled by this worker.
        gpu_id: CUDA device index.
        total_files: number of files in the whole job, used for the ETA.
        model_id: model id forwarded to ``init_whisper``.
    """
    if torch.cuda.is_available():
        device = f"cuda:{gpu_id}"
    else:
        device = "cpu"
    model, proc = init_whisper(model_id, device)
    print(f"Processing on {device} starts")
    began = time.time()

    for wav_path in file_list:
        try:
            text = transcribe_audio(model, proc, wav_path, device)
            write_transcription(wav_path, text)
            with lock:
                processed_files_count.value += 1
                done = processed_files_count.value
                if done % 5:
                    continue  # only report on every 5th file
                now = time.time()
                per_file = (now - began) / done
                seconds_left = per_file * (total_files - done)
                eta = time.strftime("%H:%M:%S", time.gmtime(seconds_left))
                stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
                print(
                    f"Processed {done}/{total_files} files, time: {stamp}, Estimated time remaining: {eta}"
                )
        except Exception as err:
            print(f"Error processing file {wav_path}: {err}")
def _init_worker(counter, counter_lock):
    """Pool initializer: make a worker use the parent's progress counter/lock.

    Under the "spawn"/"forkserver" start methods each worker re-imports this
    module and would otherwise create its own ``processed_files_count`` and
    ``lock``. Rebinding them to the parent's objects keeps the count global.
    Under "fork" this simply rebinds to the same inherited objects.
    """
    global processed_files_count, lock
    processed_files_count = counter
    lock = counter_lock


def asr_main(input_dir, num_gpus, model_id):
    """Transcribe every ``*.wav`` file under *input_dir* with a process pool.

    Files are found recursively and distributed round-robin over
    ``min(num_gpus, cpu_count)`` worker processes (at least one). Worker ``i``
    runs ``asr_wav_files`` on GPU ``i % num_gpus``, and each ``.wav`` gets a
    sibling ``.txt`` transcription.

    Args:
        input_dir: directory searched recursively for ``*.wav`` files.
        num_gpus: number of GPUs to spread the work over; values below 1 fall
            back to a single worker.
        model_id: model id forwarded to the workers.
    """
    # os.cpu_count() may return None; always run at least one worker.
    num_processes = max(1, min(num_gpus, os.cpu_count() or 1))
    print(f"Using {num_processes} GPUs for transcription")
    wav_files = list(pathlib.Path(input_dir).rglob("*.wav"))
    total_files = len(wav_files)
    print(f"Found {total_files} wav files in {input_dir}")
    if total_files == 0:
        # Nothing to do: don't spawn workers that would only load the model.
        print("Done!")
        return

    # Round-robin split so every file lands in exactly one chunk. The previous
    # floor-division slicing dropped the remainder, and dropped ALL files
    # when there were fewer files than processes.
    chunks = [wav_files[i::num_processes] for i in range(num_processes)]
    files_per_process = max(len(chunk) for chunk in chunks)
    print(f"Processing {files_per_process} files per process")

    gpu_count = max(num_gpus, 1)  # avoid modulo-by-zero when num_gpus < 1
    tasks = [
        (chunk, i % gpu_count, total_files, model_id)
        for i, chunk in enumerate(chunks)
        if chunk  # fewer files than processes: skip idle workers
    ]
    with Pool(
        len(tasks),
        initializer=_init_worker,
        initargs=(processed_files_count, lock),
    ) as p:
        p.starmap(asr_wav_files, tasks)
    print("Done!")
if __name__ == "__main__":
    # Command-line entry point. The defaults are the previously hard-coded
    # values, so running without arguments behaves exactly as before; pass
    # options instead of editing the file.
    parser = argparse.ArgumentParser(
        description="Transcribe every .wav file under a directory."
    )
    parser.add_argument(
        "--input_dir",
        default="/path/to/output/directory",
        help="directory searched recursively for .wav files",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=2,
        help="number of GPUs (one worker process per GPU)",
    )
    parser.add_argument(
        "--model_id",
        default="distil-whisper/distil-large-v2",
        help="Hugging Face model id or local path",
    )
    args = parser.parse_args()
    input_dir = args.input_dir
    num_gpus = args.num_gpus
    model_id = args.model_id
    asr_main(input_dir, num_gpus, model_id)