File size: 10,909 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faster_whisper
from typing import List, Union, Optional, NamedTuple
import torch
import numpy as np
import tqdm
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import TranscriptionResult, SingleSegment
from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens


class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
    """
    FasterWhisperPipeline variant that does not run its own VAD.

    The parent pipeline is constructed with ``vad=None``; callers supply the
    speech segments to ``transcribe`` explicitly.
    """

    def __init__(
        self,
        model,
        options: NamedTuple,
        tokenizer=None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """
        Initialize the VadFreeFasterWhisperPipeline.

        All arguments are forwarded to ``FasterWhisperPipeline.__init__``
        together with ``vad=None`` and ``vad_params={}``.

        Args:
            model: The Whisper model instance.
            options: Transcription options.
            tokenizer: The tokenizer instance.
            device: Device to run the model on.
            framework: The framework to use ('pt' for PyTorch).
            language: The language for transcription.
            suppress_numerals: Whether to suppress numeral tokens.
            **kwargs: Additional keyword arguments.

        Returns:
            None
        """
        super().__init__(
            model=model,
            vad=None,
            vad_params={},
            options=options,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
            language=language,
            suppress_numerals=suppress_numerals,
            **kwargs,
        )

    def detect_language(self, audio: np.ndarray):
        """
        Detect the language of the audio.

        A window of at most ``N_SAMPLES`` samples is analysed. Longer audio
        is sampled at a random offset; shorter audio is padded up to
        ``N_SAMPLES`` when the mel spectrogram is computed.

        Args:
            audio (np.ndarray): The input audio signal.

        Returns:
            tuple: Detected language and its probability.
        """
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        n_audio = audio.shape[0]
        if n_audio > N_SAMPLES:
            # Randomly sample N_SAMPLES from the audio array
            start_index = np.random.randint(0, n_audio - N_SAMPLES)
            audio_sample = audio[start_index : start_index + N_SAMPLES]
            padding = 0
        else:
            audio_sample = audio[:N_SAMPLES]
            padding = N_SAMPLES - n_audio
        segment = log_mel_spectrogram(
            audio_sample,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=padding,
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        # Top candidate of the first (only) batch item.
        language_token, language_probability = results[0][0]
        # Drop the two-character wrappers around the code (e.g. "<|en|>").
        language = language_token[2:-2]
        return language, language_probability

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        vad_segments: List[dict],
        batch_size=None,
        num_workers=0,
        language=None,
        task=None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
    ) -> TranscriptionResult:
        """
        Transcribe the audio into text.

        Each entry of ``vad_segments`` is cut out of ``audio`` using its
        ``"start"``/``"end"`` times (multiplied by ``SAMPLE_RATE``) and
        transcribed as one unit.

        Args:
            audio (Union[str, np.ndarray]): The input audio signal or path to audio file.
            vad_segments (List[dict]): List of VAD segments. Each needs
                "start" and "end"; an optional "speaker" is copied to the output.
            batch_size (int, optional): Batch size for transcription. Defaults to None,
                in which case the pipeline's own batch size is used.
            num_workers (int, optional): Number of workers for loading data. Defaults to 0.
            language (str, optional): Language for transcription. Defaults to None.
                When no tokenizer is set and no language is given, the
                language is detected from ``audio``.
            task (str, optional): Task type ('transcribe' or 'translate'). Defaults to None.
            chunk_size (int, optional): Size of chunks for processing. Defaults to 30.
                Accepted for interface compatibility; not used here.
            print_progress (bool, optional): Whether to print progress. Defaults to False.
            combined_progress (bool, optional): Whether to combine progress. Defaults to False.
                Accepted for interface compatibility; not used here.

        Returns:
            TranscriptionResult: The transcription result containing segments and language.
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            # Lazily yield the waveform slice for every segment.
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {"inputs": audio[f1:f2]}

        if self.tokenizer is None:
            if not language:
                # detect_language returns (language, probability); only the
                # language code may be handed to the tokenizer.
                language, _ = self.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        previous_suppress_tokens = None
        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
            new_suppressed_tokens = list(set(new_suppressed_tokens))
            self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)

        segments: List[SingleSegment] = []
        batch_size = batch_size or self._batch_size
        # The bar is disabled (no output at all) unless progress was requested.
        progress = tqdm.tqdm(
            total=len(vad_segments),
            desc="Transcribing",
            disable=not print_progress,
        )
        try:
            outputs = self.__call__(
                data(audio, vad_segments),
                batch_size=batch_size,
                num_workers=num_workers,
            )
            for vad_segment, out in zip(vad_segments, outputs):
                progress.update(1)
                text = out["text"]
                if batch_size in [0, 1, None]:
                    text = text[0]
                segments.append(
                    {
                        "text": text,
                        "start": round(vad_segment["start"], 3),
                        "end": round(vad_segment["end"], 3),
                        "speaker": vad_segment.get("speaker", None),
                    }
                )
        finally:
            # Restore pipeline state even when inference fails, so one bad
            # call does not leak its tokenizer/options into the next.
            progress.close()

            # revert the tokenizer if multilingual inference is enabled
            if self.preset_language is None:
                self.tokenizer = None

            # revert suppressed tokens if suppress_numerals is enabled
            if self.suppress_numerals:
                self.options = self.options._replace(
                    suppress_tokens=previous_suppress_tokens
                )

        return {"segments": segments, "language": language}


def load_asr_model(
    whisper_arch: str,
    device: str,
    device_index: int = 0,
    compute_type: str = "float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model=None,
    vad_options=None,
    model: Optional[WhisperModel] = None,
    task: str = "transcribe",
    download_root: Optional[str] = None,
    threads: int = 4,
) -> VadFreeFasterWhisperPipeline:
    """
    Build a VAD-free Whisper pipeline ready for inference.

    Args:
        whisper_arch (str): Name of the Whisper model to load. Names ending
            in ".en" force ``language`` to "en".
        device (str): Device to load the model on.
        device_index (int, optional): Device index. Defaults to 0.
        compute_type (str, optional): Compute type for the model. Defaults to "float16".
        asr_options (Optional[dict], optional): Overrides merged on top of the
            built-in default transcription options. Defaults to None.
        language (Optional[str], optional): Transcription language. When None,
            no tokenizer is created up front. Defaults to None.
        vad_model: Not used by this function. Defaults to None.
        vad_options: Not used by this function. Defaults to None.
        model (Optional[WhisperModel], optional): Existing model to reuse
            instead of constructing a new one. Defaults to None.
        task (str, optional): Task type ('transcribe' or 'translate'). Defaults to "transcribe".
        download_root (Optional[str], optional): Directory to download the model to. Defaults to None.
        threads (int, optional): Number of CPU threads per worker. Defaults to 4.

    Returns:
        VadFreeFasterWhisperPipeline: The assembled pipeline.
    """
    # English-only checkpoints pin the language.
    if whisper_arch.endswith(".en"):
        language = "en"

    # Reuse the caller's model when one was supplied.
    if not model:
        model = WhisperModel(
            whisper_arch,
            device=device,
            device_index=device_index,
            compute_type=compute_type,
            download_root=download_root,
            cpu_threads=threads,
        )

    tokenizer = None
    if language is None:
        print(
            "No language specified, language will be detected for each audio file (increases inference time)."
        )
    else:
        tokenizer = faster_whisper.tokenizer.Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )

    # Defaults first; caller-provided values win.
    option_values = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }
    if asr_options is not None:
        option_values.update(asr_options)

    # "suppress_numerals" is a pipeline argument, not a TranscriptionOptions
    # field, so it is taken out before the options object is built.
    suppress_numerals = option_values.pop("suppress_numerals")

    transcription_options = faster_whisper.transcribe.TranscriptionOptions(
        **option_values
    )

    return VadFreeFasterWhisperPipeline(
        model=model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
    )