Hecheng0625's picture
Upload 409 files
c968fc3 verified
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import faster_whisper
from typing import List, Union, Optional, NamedTuple
import torch
import numpy as np
import tqdm
from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import TranscriptionResult, SingleSegment
from whisperx.asr import WhisperModel, FasterWhisperPipeline, find_numeral_symbol_tokens
class VadFreeFasterWhisperPipeline(FasterWhisperPipeline):
    """
    FasterWhisperPipeline variant that runs without an internal VAD model.

    Segmentation is supplied by the caller: ``transcribe`` receives a list of
    ``vad_segments`` (dicts with ``"start"``/``"end"`` in seconds and an
    optional ``"speaker"``) and transcribes each span as one pipeline input.
    """

    def __init__(
        self,
        model,
        options: NamedTuple,
        tokenizer=None,
        device: Union[int, str, "torch.device"] = -1,
        framework="pt",
        language: Optional[str] = None,
        suppress_numerals: bool = False,
        **kwargs,
    ):
        """
        Initialize the VadFreeFasterWhisperPipeline.

        Forwards everything to FasterWhisperPipeline with ``vad=None`` and
        empty ``vad_params``, since segmentation is provided per call.

        Args:
            model: The Whisper model instance.
            options: Transcription options.
            tokenizer: The tokenizer instance, or None to build one lazily
                during ``transcribe``.
            device: Device to run the model on.
            framework: The framework to use ('pt' for PyTorch).
            language: Fixed language for transcription, or None to detect
                the language per call.
            suppress_numerals: Whether to suppress numeral/symbol tokens.
            **kwargs: Additional keyword arguments for the parent pipeline.
        """
        super().__init__(
            model=model,
            vad=None,
            vad_params={},
            options=options,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
            language=language,
            suppress_numerals=suppress_numerals,
            **kwargs,
        )

    def _create_tokenizer(self, task, language):
        """Build a faster_whisper Tokenizer for ``task``/``language`` from this model."""
        return faster_whisper.tokenizer.Tokenizer(
            self.model.hf_tokenizer,
            self.model.model.is_multilingual,
            task=task,
            language=language,
        )

    def detect_language(self, audio: np.ndarray):
        """
        Detect the language of the audio from a single N_SAMPLES window.

        When the audio is longer than N_SAMPLES, a window is chosen at a
        random offset, so repeated calls may return different results.
        Shorter audio is padded up to N_SAMPLES for the mel spectrogram.

        Args:
            audio (np.ndarray): The input audio signal (samples along axis 0).

        Returns:
            tuple: (language code, probability) for the top-ranked language.
        """
        model_n_mels = self.model.feat_kwargs.get("feature_size")
        n_audio = audio.shape[0]
        if n_audio > N_SAMPLES:
            # randint's upper bound is exclusive; "+ 1" makes the last valid
            # start offset (n_audio - N_SAMPLES) reachable.
            start_index = np.random.randint(0, n_audio - N_SAMPLES + 1)
            audio_sample = audio[start_index : start_index + N_SAMPLES]
        else:
            audio_sample = audio[:N_SAMPLES]
        padding = max(0, N_SAMPLES - n_audio)
        segment = log_mel_spectrogram(
            audio_sample,
            n_mels=model_n_mels if model_n_mels is not None else 80,
            padding=padding,
        )
        encoder_output = self.model.encode(segment)
        results = self.model.model.detect_language(encoder_output)
        language_token, language_probability = results[0][0]
        # Strip the two delimiter characters on each side of the token
        # (presumably of the form "<|xx|>" -- confirm against faster_whisper).
        language = language_token[2:-2]
        return language, language_probability

    def transcribe(
        self,
        audio: Union[str, np.ndarray],
        vad_segments: List[dict],
        batch_size=None,
        num_workers=0,
        language=None,
        task=None,
        chunk_size=30,
        print_progress=False,
        combined_progress=False,
    ) -> TranscriptionResult:
        """
        Transcribe each caller-provided segment of the audio into text.

        If the pipeline was built without a fixed language, a tokenizer is
        created for this call (detecting the language if none is given) and
        discarded afterwards. Numeral suppression, when enabled, temporarily
        extends ``options.suppress_tokens``. Both are restored even if
        transcription raises.

        Args:
            audio (Union[str, np.ndarray]): The input audio signal or path to
                an audio file.
            vad_segments (List[dict]): Segments with "start"/"end" times in
                seconds and an optional "speaker".
            batch_size (int, optional): Batch size for transcription. Defaults
                to the pipeline's configured batch size.
            num_workers (int, optional): Number of workers for loading data.
                Defaults to 0.
            language (str, optional): Language for transcription. Defaults to
                None (use the tokenizer's language or detect it).
            task (str, optional): Task type ('transcribe' or 'translate').
                Defaults to None.
            chunk_size (int, optional): Accepted for interface compatibility;
                unused here.
            print_progress (bool, optional): Whether to show a progress bar.
                Defaults to False.
            combined_progress (bool, optional): Accepted for interface
                compatibility; unused here.

        Returns:
            TranscriptionResult: {"segments": [...], "language": language}.
        """
        if isinstance(audio, str):
            audio = load_audio(audio)

        def data(audio, segments):
            # Cut each segment out of the waveform; times are in seconds.
            for seg in segments:
                f1 = int(seg["start"] * SAMPLE_RATE)
                f2 = int(seg["end"] * SAMPLE_RATE)
                yield {"inputs": audio[f1:f2]}

        # Snapshot so the options can be restored even on failure.
        previous_suppress_tokens = self.options.suppress_tokens
        segments: List[SingleSegment] = []
        try:
            if self.tokenizer is None:
                if not language:
                    # detect_language returns (language, probability); only
                    # the language code belongs in the tokenizer/result.
                    language, _ = self.detect_language(audio)
                task = task or "transcribe"
                self.tokenizer = self._create_tokenizer(task, language)
            else:
                language = language or self.tokenizer.language_code
                task = task or self.tokenizer.task
                if (
                    task != self.tokenizer.task
                    or language != self.tokenizer.language_code
                ):
                    self.tokenizer = self._create_tokenizer(task, language)

            if self.suppress_numerals:
                numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
                new_suppressed_tokens = list(
                    set(numeral_symbol_tokens + previous_suppress_tokens)
                )
                self.options = self.options._replace(
                    suppress_tokens=new_suppressed_tokens
                )

            batch_size = batch_size or self._batch_size
            with tqdm.tqdm(
                total=len(vad_segments),
                desc="Transcribing",
                disable=not print_progress,
            ) as progress:
                outputs = self.__call__(
                    data(audio, vad_segments),
                    batch_size=batch_size,
                    num_workers=num_workers,
                )
                for idx, out in enumerate(outputs):
                    progress.update(1)
                    text = out["text"]
                    # Unbatched outputs wrap the text in a one-element list.
                    if batch_size in [0, 1, None]:
                        text = text[0]
                    seg = vad_segments[idx]
                    segments.append(
                        {
                            "text": text,
                            "start": round(seg["start"], 3),
                            "end": round(seg["end"], 3),
                            "speaker": seg.get("speaker", None),
                        }
                    )
        finally:
            # Drop the per-call tokenizer when no language was fixed at
            # construction, so the next call re-detects/rebuilds it.
            if self.preset_language is None:
                self.tokenizer = None
            # Undo the temporary numeral suppression.
            if self.suppress_numerals:
                self.options = self.options._replace(
                    suppress_tokens=previous_suppress_tokens
                )

        return {"segments": segments, "language": language}
def load_asr_model(
    whisper_arch: str,
    device: str,
    device_index: int = 0,
    compute_type: str = "float16",
    asr_options: Optional[dict] = None,
    language: Optional[str] = None,
    vad_model=None,
    vad_options=None,
    model: Optional[WhisperModel] = None,
    task: str = "transcribe",
    download_root: Optional[str] = None,
    threads: int = 4,
) -> VadFreeFasterWhisperPipeline:
    """
    Load a Whisper model and wrap it in a VadFreeFasterWhisperPipeline.

    Args:
        whisper_arch (str): The name of the Whisper model to load. Names
            ending in ".en" force ``language="en"``.
        device (str): The device to load the model on.
        device_index (int, optional): The device index. Defaults to 0.
        compute_type (str, optional): The compute type to use for the model.
            Defaults to "float16".
        asr_options (Optional[dict], optional): Overrides merged on top of the
            default transcription options below (the caller's dict is not
            modified). The extra key "suppress_numerals" is routed to the
            pipeline instead of TranscriptionOptions. Defaults to None.
        language (Optional[str], optional): Fixed transcription language. If
            None, the language is detected per audio file. Defaults to None.
        vad_model: Unused; accepted for signature compatibility.
        vad_options: Unused; accepted for signature compatibility.
        model (Optional[WhisperModel], optional): A preloaded WhisperModel to
            use instead of loading ``whisper_arch``. Defaults to None.
        task (str, optional): 'transcribe' or 'translate'. Only applied here
            when a language is known; otherwise the pipeline builds its
            tokenizer during ``transcribe`` (where task defaults to
            "transcribe"). Defaults to "transcribe".
        download_root (Optional[str], optional): The root directory to
            download the model to. Defaults to None.
        threads (int, optional): The number of CPU threads to use per worker.
            Defaults to 4.

    Returns:
        VadFreeFasterWhisperPipeline: The loaded Whisper pipeline.

    Raises:
        Nothing is raised explicitly here. Errors from WhisperModel (e.g. an
        unknown model name) or from TranscriptionOptions (e.g. an
        unrecognized ``asr_options`` key) propagate unchanged.
    """
    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(
        whisper_arch,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        download_root=download_root,
        cpu_threads=threads,
    )

    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(
            model.hf_tokenizer,
            model.model.is_multilingual,
            task=task,
            language=language,
        )
    else:
        print(
            "No language specified, language will be detected for each audio file (increases inference time)."
        )
        tokenizer = None

    option_values = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }
    if asr_options is not None:
        option_values.update(asr_options)

    # "suppress_numerals" is a pipeline flag, not a TranscriptionOptions
    # field, so it is removed before building the options object.
    suppress_numerals = option_values.pop("suppress_numerals")
    transcription_options = faster_whisper.transcribe.TranscriptionOptions(
        **option_values
    )

    return VadFreeFasterWhisperPipeline(
        model=model,
        options=transcription_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
    )