Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2024 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import faster_whisper | |
from typing import List, Union, Optional, NamedTuple | |
import torch | |
import numpy as np | |
import tqdm | |
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram | |
from whisperx.types import TranscriptionResult, SingleSegment | |
from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens | |
class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
    """
    FasterWhisperModel without VAD.

    The parent ``FasterWhisperPipeline`` is constructed with ``vad=None`` and
    empty VAD parameters; callers pass pre-computed segments (dicts with
    ``"start"``/``"end"`` in seconds) to :meth:`transcribe` instead.
    """

    def __init__(
        self,
        model,
        options: NamedTuple,
        tokenizer=None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """
        Initialize the VadFreeFasterWhisperPipeline.

        Args:
            model: The Whisper model instance.
            options: Transcription options (``faster_whisper`` transcription options).
            tokenizer: The tokenizer instance, or None to build one lazily per call.
            device: Device passed to the parent pipeline.
            framework: The framework to use ('pt' for PyTorch).
            language: Preset transcription language; None enables per-call detection.
            suppress_numerals: Whether to suppress numeral/symbol tokens.
            **kwargs: Additional keyword arguments forwarded to the parent.

        Returns:
            None
        """
        super().__init__(
            model=model,
            vad=None,  # no VAD model: segments are supplied by the caller
            vad_params={},
            options=options,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
            language=language,
            suppress_numerals=suppress_numerals,
            **kwargs,
        )

    def detect_language(self, audio: np.ndarray):
        """
        Detect the language of the audio.

        When the audio is longer than ``N_SAMPLES``, a window of ``N_SAMPLES``
        samples starting at a random offset is analysed, so repeated calls on
        long audio may yield different results. Shorter audio is used whole and
        ``log_mel_spectrogram`` is asked for ``N_SAMPLES - len(audio)`` samples
        of padding.

        Args:
            audio (np.ndarray): The input audio signal (1-D sample array).

        Returns:
            tuple: ``(language_code, probability)`` of the top-ranked language.
        """
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        num_samples = audio.shape[0]
        if num_samples > N_SAMPLES:
            # Randomly sample N_SAMPLES from the audio array
            start_index = np.random.randint(0, num_samples - N_SAMPLES)
            audio_sample = audio[start_index : start_index + N_SAMPLES]
        else:
            audio_sample = audio[:N_SAMPLES]
        padding = 0 if num_samples >= N_SAMPLES else N_SAMPLES - num_samples
        segment = log_mel_spectrogram(
            audio_sample,
            # Fall back to 80 mel bins when the model does not report a feature size.
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=padding,
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        # First result of the first (only) item in the batch is the best guess.
        language_token, language_probability = results[0][0]
        # Strip the surrounding "<|" and "|>" (assumes tokens shaped like "<|en|>").
        language = language_token[2:-2]
        return language, language_probability

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        vad_segments: List[dict],
        batch_size=None,
        num_workers=0,
        language=None,
        task=None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
    ) -> TranscriptionResult:
        """
        Transcribe the given segments of the audio into text.

        Args:
            audio (Union[str, np.ndarray]): The input audio signal or path to an audio file.
            vad_segments (List[dict]): Segments to transcribe; each needs ``"start"`` and
                ``"end"`` (seconds) and may carry ``"speaker"``.
            batch_size (int, optional): Batch size; falls back to the pipeline default.
            num_workers (int, optional): Number of workers for loading data. Defaults to 0.
            language (str, optional): Language code. If no tokenizer is set and this is
                falsy, the language is detected from the audio.
            task (str, optional): 'transcribe' or 'translate'. Defaults to the tokenizer's
                task, or 'transcribe' when no tokenizer exists yet.
            chunk_size (int, optional): Unused here; kept for interface compatibility.
            print_progress (bool, optional): Whether to show a tqdm progress bar.
            combined_progress (bool, optional): Unused here; kept for interface compatibility.

        Returns:
            TranscriptionResult: ``{"segments": [...], "language": language}`` where each
            segment has ``"text"``, ``"start"``, ``"end"`` (rounded to ms) and ``"speaker"``.
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            # Convert each segment's second-based bounds to sample indices.
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {"inputs": audio[f1:f2]}

        if self.tokenizer is None:
            if not language:
                # detect_language returns (language, probability); keep only the code.
                language, _ = self.detect_language(audio)
            task = task or "transcribe"
            self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                self.model.hf_tokenizer,
                self.model.model.is_multilingual,
                task=task,
                language=language,
            )
        else:
            language = language or self.tokenizer.language_code
            task = task or self.tokenizer.task
            # Rebuild the tokenizer only when the requested task/language differ.
            if task != self.tokenizer.task or language != self.tokenizer.language_code:
                self.tokenizer = faster_whisper.tokenizer.Tokenizer(
                    self.model.hf_tokenizer,
                    self.model.model.is_multilingual,
                    task=task,
                    language=language,
                )

        # Captured unconditionally so the finally-block can always restore it.
        previous_suppress_tokens = self.options.suppress_tokens
        segments: List[SingleSegment] = []
        batch_size = batch_size or self._batch_size
        try:
            if self.suppress_numerals:
                numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
                new_suppressed_tokens = list(
                    set(numeral_symbol_tokens + self.options.suppress_tokens)
                )
                self.options = self.options._replace(
                    suppress_tokens=new_suppressed_tokens
                )

            with tqdm.tqdm(
                total=len(vad_segments),
                desc="Transcribing",
                disable=not print_progress,
            ) as progress:
                for idx, out in enumerate(
                    self.__call__(
                        data(audio, vad_segments),
                        batch_size=batch_size,
                        num_workers=num_workers,
                    )
                ):
                    progress.update(1)
                    text = out["text"]
                    # Unbatched pipeline output wraps the text in a list.
                    if batch_size in [0, 1, None]:
                        text = text[0]
                    segment = vad_segments[idx]
                    segments.append(
                        {
                            "text": text,
                            "start": round(segment["start"], 3),
                            "end": round(segment["end"], 3),
                            "speaker": segment.get("speaker", None),
                        }
                    )
        finally:
            # Restore per-call state even if transcription fails, so later calls
            # do not silently reuse this call's tokenizer or suppressed tokens.
            # revert the tokenizer if multilingual inference is enabled
            if self.preset_language is None:
                self.tokenizer = None
            # revert suppressed tokens if suppress_numerals is enabled
            if self.suppress_numerals:
                self.options = self.options._replace(
                    suppress_tokens=previous_suppress_tokens
                )

        return {"segments": segments, "language": language}
def load_asr_model(
    whisper_arch: str,
    device: str,
    device_index: int = 0,
    compute_type: str = "float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model=None,
    vad_options=None,
    model: Optional[WhisperModel] = None,
    task: str = "transcribe",
    download_root: Optional[str] = None,
    threads: int = 4,
) -> VadFreeFasterWhisperPipeline:
    """
    Load a Whisper model wrapped in a VAD-free transcription pipeline.

    Args:
        whisper_arch (str): The name of the Whisper model to load. Names ending in
            ".en" force ``language="en"``, overriding any ``language`` argument.
        device (str): The device to load the model on.
        device_index (int, optional): The device index. Defaults to 0.
        compute_type (str, optional): The compute type to use for the model. Defaults to "float16".
        asr_options (Optional[dict], optional): Overrides merged over the default ASR
            options; may include ``"suppress_numerals"``. Defaults to None.
        language (Optional[str], optional): Fixed transcription language. If None, the
            language is detected per call (slower). Defaults to None.
        vad_model: Unused by this function (no VAD is applied). Defaults to None.
        vad_options: Unused by this function. Defaults to None.
        model (Optional[WhisperModel], optional): Pre-built WhisperModel to use instead
            of loading one. Defaults to None.
        task (str, optional): 'transcribe' or 'translate'. Only applied here when a
            language is known; otherwise the pipeline picks the task per call.
            Defaults to "transcribe".
        download_root (Optional[str], optional): The root directory to download the model to. Defaults to None.
        threads (int, optional): The number of CPU threads to use per worker. Defaults to 4.

    Returns:
        VadFreeFasterWhisperPipeline: The loaded Whisper pipeline.

    Raises:
        This function raises nothing itself. Errors from constructing ``WhisperModel``
        or ``TranscriptionOptions`` (e.g. an unknown model name or option key)
        propagate unchanged.
    """
    # English-only checkpoints cannot transcribe other languages.
    if whisper_arch.endswith(".en"):
        language = "en"
    model = model or WhisperModel(
        whisper_arch,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        download_root=download_root,
        cpu_threads=threads,
    )
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )
    else:
        print(
            "No language specified, language will be detected for each audio file (increases inference time)."
        )
        # The pipeline builds a tokenizer lazily once the language is known.
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }
    if asr_options is not None:
        # Merged into the local dict; the caller's dict is not modified.
        default_asr_options.update(asr_options)

    # "suppress_numerals" is a pipeline setting, not a TranscriptionOptions field.
    suppress_numerals = default_asr_options.pop("suppress_numerals")
    transcription_options = faster_whisper.transcribe.TranscriptionOptions(
        **default_asr_options
    )

    return VadFreeFasterWhisperPipeline(
        model=model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
    )