File size: 4,658 Bytes
7ee3434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import pathlib
import string
import time
from multiprocessing import Pool, Value, Lock
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
import whisper

# Shared progress state for the transcription workers. Child processes share
# these objects only if they inherit them (e.g. via fork, or by receiving them
# as Pool initializer arguments); a freshly re-imported module gets its own,
# unshared copies.
lock = Lock()  # guards the read-modify-write of the counter below
processed_files_count = Value("i", 0)  # number of files transcribed so far


def preprocess_text(text):
    """Normalize ASR output: lowercase it and strip ASCII punctuation.

    Only characters listed in ``string.punctuation`` are removed; any other
    symbol (for example non-ASCII quotes or dashes) is kept unchanged.
    """
    lowered = text.lower()
    return "".join(ch for ch in lowered if ch not in string.punctuation)


def transcribe_audio(model, processor, audio_file, device):
    """Transcribe one audio file and return the normalized transcript.

    The audio is loaded with ``whisper.load_audio``, padded/trimmed to
    whisper's default window (30 s), converted to a log-Mel spectrogram and
    decoded with ``model.generate`` (at most 128 new tokens). Audio beyond
    that window is not transcribed.

    Args:
        model: Seq2seq speech model, expected to already live on ``device``.
        processor: Matching processor, used to decode the generated token ids.
        audio_file: Path of the audio file to transcribe.
        device: Device the model lives on, e.g. ``"cuda:0"`` or ``"cpu"``.

    Returns:
        The transcript, lowercased with ASCII punctuation removed
        (see ``preprocess_text``).
    """
    audio = whisper.load_audio(audio_file)
    audio = whisper.pad_or_trim(audio)  # default 30-second window
    mel = whisper.log_mel_spectrogram(audio)
    # Add a batch dimension and match both the device and the dtype the model
    # was loaded with. Hard-coding torch.cuda.HalfTensor here broke the CPU
    # path (float32 model, no CUDA) that the rest of the script supports.
    inputs = mel.unsqueeze(0).to(device=device, dtype=model.dtype)

    outputs = model.generate(inputs=inputs, max_new_tokens=128)
    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return preprocess_text(transcription)


def write_transcription(audio_file, transcription):
    """Write ``transcription`` to a UTF-8 text file next to ``audio_file``.

    The output path is ``audio_file`` with its suffix replaced by ``.txt``
    (e.g. ``a/b.wav`` -> ``a/b.txt``); an existing file is overwritten.

    Args:
        audio_file: Path (``pathlib.Path`` or ``str``) of the source audio.
        transcription: Text to write.
    """
    txt_file = pathlib.Path(audio_file).with_suffix(".txt")
    # Explicit encoding: the locale default (e.g. cp1252 on Windows) can raise
    # on, or write non-UTF-8 bytes for, non-ASCII transcripts.
    txt_file.write_text(transcription, encoding="utf-8")


def init_whisper(model_id, device):
    """Load a seq2seq speech model and its processor, placing the model on ``device``.

    Args:
        model_id: Hugging Face model id or local path,
            e.g. ``"distil-whisper/distil-large-v2"``.
        device: Target device, as a string (``"cuda:0"``, ``"cpu"``) or a
            ``torch.device``.

    Returns:
        ``(model, processor)``. The model weights are float16 when ``device``
        is a CUDA device and float32 otherwise.
    """
    # Choose the dtype from the device the model will actually run on, not
    # from whether *some* GPU exists: with device="cpu" on a GPU host the old
    # check loaded float16 weights onto the CPU, where half precision may be
    # slow or unsupported for some ops.
    on_cuda = torch.device(device).type == "cuda"
    torch_dtype = torch.float16 if on_cuda else torch.float32
    print(f"Loading model {model_id}")
    # NOTE(review): use_safetensors=False forces the pickle-based
    # pytorch_model.bin checkpoint; consider allowing safetensors if the
    # checkpoint provides them.
    distil_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=False
    )
    distil_model = distil_model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return distil_model, processor


def _format_duration(seconds):
    """Format ``seconds`` as ``HH:MM:SS``, letting the hour field exceed 24.

    ``time.strftime("%H:%M:%S", time.gmtime(s))`` wraps at one day, so an ETA
    of 25 hours would print as ``01:00:00``. Negative inputs clamp to zero.
    """
    total = max(0, int(seconds))
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def asr_wav_files(file_list, gpu_id, total_files, model_id):
    """Transcribe a list of audio files on one device, writing a .txt per file.

    Loads the model once for this worker, then runs ``transcribe_audio`` and
    ``write_transcription`` for every file. Each success increments the
    module-level ``processed_files_count`` under ``lock``; whenever that
    counter reaches a multiple of 5, a progress line with an ETA is printed.

    Args:
        file_list: Audio file paths assigned to this worker.
        gpu_id: CUDA device index to use (ignored when CUDA is unavailable).
        total_files: Total number of files across all workers, for progress/ETA.
        model_id: Model id/path passed to ``init_whisper``.

    Per-file errors are printed and skipped so one bad file does not stop
    the rest of the batch.
    """
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    whisper_model, processor = init_whisper(model_id, device)
    print(f"Processing on {device} starts")
    start_time = time.time()
    for audio_file in file_list:
        try:
            transcription = transcribe_audio(
                whisper_model, processor, audio_file, device
            )
            write_transcription(audio_file, transcription)
        except Exception as e:
            # Deliberately best-effort: report the failure and move on.
            print(f"Error processing file {audio_file}: {e}")
            continue
        with lock:
            processed_files_count.value += 1
            done = processed_files_count.value
            if done % 5 == 0:
                # Throughput since this worker started; when the counter is
                # shared between workers it also counts their finished files.
                avg_time_per_file = (time.time() - start_time) / done
                remaining_time_formatted = _format_duration(
                    avg_time_per_file * (total_files - done)
                )
                print(
                    f"Processed {done}/{total_files} files, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}, Estimated time remaining: {remaining_time_formatted}"
                )


def _init_worker(shared_count, shared_lock):
    """Pool initializer: bind this worker's progress globals to the parent's objects."""
    # With the "spawn"/"forkserver" start methods each worker re-imports this
    # module and would otherwise get its own fresh Value/Lock, so progress
    # would not be shared. Handing them over via Pool(initargs=...) is the
    # supported way to share synchronized objects with pool workers.
    global processed_files_count, lock
    processed_files_count = shared_count
    lock = shared_lock


def _split_evenly(items, num_parts):
    """Split ``items`` into ``num_parts`` contiguous slices whose sizes differ by at most one."""
    base, extra = divmod(len(items), num_parts)
    chunks = []
    start = 0
    for part in range(num_parts):
        end = start + base + (1 if part < extra else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


def asr_main(input_dir, num_gpus, model_id):
    """Transcribe every ``*.wav`` file under ``input_dir`` with a pool of workers.

    Files are found recursively, split into near-equal contiguous chunks (one
    per worker, each worker bound to one GPU index) and handed to
    ``asr_wav_files``, which writes a ``.txt`` transcript next to each file.

    Args:
        input_dir: Directory searched recursively for ``*.wav`` files.
        num_gpus: Number of GPUs to use. At most one worker is started per
            GPU, and never more than the CPU count or the number of files.
        model_id: Model id/path forwarded to every worker.

    Raises:
        ValueError: If ``num_gpus`` is less than 1.
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")
    wav_files = list(pathlib.Path(input_dir).rglob("*.wav"))
    total_files = len(wav_files)
    print(f"Found {total_files} wav files in {input_dir}")
    if not wav_files:
        # Nothing to do: avoid starting workers that would each load a model.
        print("Done!")
        return
    # os.cpu_count() may return None; never start more workers than files.
    num_processes = min(num_gpus, os.cpu_count() or 1, total_files)
    print(f"Using {num_processes} GPUs for transcription")
    # Spread the remainder over the first chunks so no file is dropped when
    # the file count is not divisible by the worker count.
    chunks = _split_evenly(wav_files, num_processes)
    print(f"Processing up to {len(chunks[0])} files per process")
    # Reset progress so repeated calls in one interpreter report correct counts.
    with lock:
        processed_files_count.value = 0
    with Pool(
        num_processes,
        initializer=_init_worker,
        initargs=(processed_files_count, lock),
    ) as pool:
        pool.starmap(
            asr_wav_files,
            [
                (chunk, i % num_gpus, total_files, model_id)
                for i, chunk in enumerate(chunks)
            ],
        )
    print("Done!")


if __name__ == "__main__":
    # Script entry point: edit these settings before running. The directory is
    # scanned recursively for *.wav files and a .txt transcript is written
    # next to each one.
    asr_main(
        input_dir="/path/to/output/directory",
        num_gpus=2,
        model_id="distil-whisper/distil-large-v2",
    )