|
|
|
|
|
|
|
|
|
|
|
import os |
|
import pathlib |
|
import string |
|
import time |
|
from multiprocessing import Pool, Value, Lock |
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor |
|
import torch |
|
import whisper |
|
|
|
# Shared progress counter and its guard, created at import time so that
# Pool workers inherit the *same* underlying objects.
# NOTE(review): this sharing relies on the "fork" start method (the Linux
# default) — under "spawn" (Windows/macOS default) each worker would get
# an independent copy and the progress/ETA output would be wrong. Confirm
# the target platform.
processed_files_count = Value("i", 0)

lock = Lock()
|
|
|
|
|
def preprocess_text(text):
    """Normalize an ASR transcript: lowercase it and strip all ASCII punctuation.

    Args:
        text: raw transcription string.

    Returns:
        The lowercased string with every character from
        ``string.punctuation`` removed.
    """
    strip_punctuation = str.maketrans("", "", string.punctuation)
    lowered = text.lower()
    return lowered.translate(strip_punctuation)
|
|
|
|
|
def transcribe_audio(model, processor, audio_file, device):
    """Transcribe a single audio file with a Whisper-style seq2seq model.

    Loads the audio, pads/trims it to Whisper's fixed window, computes a
    log-mel spectrogram, runs generation, and returns the decoded text
    normalized by ``preprocess_text``.

    Args:
        model: a loaded AutoModelForSpeechSeq2Seq (already on ``device``).
        processor: the matching AutoProcessor (used for decoding).
        audio_file: path to the audio file to transcribe.
        device: torch device string the model lives on (e.g. "cuda:0", "cpu").

    Returns:
        The lowercased, punctuation-stripped transcription string.
    """
    audio = whisper.load_audio(audio_file)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(device=device)
    # Match the model's own parameter dtype instead of the previous
    # hard-coded torch.cuda.HalfTensor, which crashed on CPU-only runs
    # (no CUDA tensor types available) and ignored the requested device.
    model_dtype = next(model.parameters()).dtype
    mel = mel.unsqueeze(0).to(dtype=model_dtype)

    outputs = model.generate(inputs=mel, max_new_tokens=128)
    # batch_decode returns one string per batch element; batch size is 1.
    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return preprocess_text(transcription)
|
|
|
|
|
def write_transcription(audio_file, transcription):
    """Persist a transcription next to its audio file.

    Args:
        audio_file: pathlib.Path of the source .wav file.
        transcription: text to save; written to the same path with a
            ``.txt`` suffix instead of ``.wav``.
    """
    audio_file.with_suffix(".txt").write_text(transcription)
|
|
|
|
|
def init_whisper(model_id, device):
    """Load a speech-to-text seq2seq model and its processor.

    Uses fp16 weights when CUDA is available, fp32 otherwise, and moves
    the model to ``device`` before returning.

    Args:
        model_id: Hugging Face model identifier.
        device: torch device string to place the model on.

    Returns:
        Tuple of (model on ``device``, matching AutoProcessor).
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"Loading model {model_id}")
    # NOTE(review): use_safetensors=False forces the pickle-based weight
    # format — presumably because this checkpoint lacks safetensors;
    # confirm, since safetensors is the safer default when available.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        use_safetensors=False,
    ).to(device)
    return model, AutoProcessor.from_pretrained(model_id)
|
|
|
|
|
def asr_wav_files(file_list, gpu_id, total_files, model_id):

    """Transcribe wav files in a list.

    Worker entry point for one Pool process: loads a private model copy
    on ``cuda:<gpu_id>`` (CPU if CUDA is unavailable), transcribes each
    file in ``file_list``, and writes a .txt next to each .wav.  Every
    5th file (globally) it prints progress and an estimated time
    remaining.

    Args:
        file_list: pathlib.Path objects of the .wav files for this worker.
        gpu_id: CUDA device index assigned to this worker.
        total_files: file count across all workers (used for the ETA).
        model_id: Hugging Face model id forwarded to init_whisper.
    """

    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"

    whisper_model, processor = init_whisper(model_id, device)

    print(f"Processing on {device} starts")

    start_time = time.time()

    for audio_file in file_list:

        try:

            transcription = transcribe_audio(

                whisper_model, processor, audio_file, device

            )

            write_transcription(audio_file, transcription)

            # NOTE(review): processed_files_count and lock are module-level
            # globals shared across Pool workers; this relies on the "fork"
            # start method (Linux default) — confirm if running where
            # "spawn" is the default (Windows/macOS).
            with lock:

                processed_files_count.value += 1

                if processed_files_count.value % 5 == 0:

                    current_time = time.time()

                    # ETA mixes the *global* processed count with this
                    # worker's own start_time, so the estimate is only
                    # approximate when workers start at different moments.
                    avg_time_per_file = (current_time - start_time) / (

                        processed_files_count.value

                    )

                    remaining_files = total_files - processed_files_count.value

                    estimated_time_remaining = avg_time_per_file * remaining_files

                    remaining_time_formatted = time.strftime(

                        "%H:%M:%S", time.gmtime(estimated_time_remaining)

                    )

                    print(

                        f"Processed {processed_files_count.value}/{total_files} files, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}, Estimated time remaining: {remaining_time_formatted}"

                    )

        except Exception as e:

            # Best-effort: a single bad file is reported and skipped so the
            # remaining files in this worker's list still get transcribed.
            print(f"Error processing file {audio_file}: {e}")
|
|
|
|
|
def asr_main(input_dir, num_gpus, model_id):
    """Transcribe all .wav files under a directory in parallel.

    Spawns one worker process per GPU (capped at the CPU count); each
    worker loads its own model and transcribes its share of the files.

    Args:
        input_dir: directory searched recursively for .wav files.
        num_gpus: number of GPUs (and hence worker processes) to use.
        model_id: Hugging Face model id forwarded to init_whisper.
    """
    num_processes = min(num_gpus, os.cpu_count())
    print(f"Using {num_processes} GPUs for transcription")
    wav_files = list(pathlib.Path(input_dir).rglob("*.wav"))
    total_files = len(wav_files)
    print(f"Found {total_files} wav files in {input_dir}")
    # Round-robin striding covers *every* file. The previous scheme of
    # num_processes equal-sized contiguous slices of length
    # total // num_processes silently dropped up to num_processes - 1
    # remainder files (they fell past the last slice).
    chunks = [wav_files[i::num_processes] for i in range(num_processes)]
    print(f"Processing up to {len(chunks[0])} files per process")
    with Pool(num_processes) as p:
        p.starmap(
            asr_wav_files,
            [
                (chunk, i % num_gpus, total_files, model_id)
                for i, chunk in enumerate(chunks)
            ],
        )
    print("Done!")
|
|
|
|
|
if __name__ == "__main__":
    # Script configuration: edit the paths/counts below before running.
    asr_main(
        input_dir="/path/to/output/directory",
        num_gpus=2,
        model_id="distil-whisper/distil-large-v2",
    )
|
|