# Source: https://github.com/snakers4/silero-vad
#
# Copyright (c) 2024 snakers4
#
# This code is from an MIT-licensed repository. The full license text is
# available at the root of the source repository.
#
# Note: This code has been modified to fit the context of this repository.

import librosa
import numpy as np
import torch

# Diarized segments longer than this many seconds are re-segmented with VAD.
VAD_THRESHOLD = 20
# Target rate in Hz; Silero-VAD expects 8 kHz or 16 kHz input.
SAMPLING_RATE = 16000


class SileroVAD:
    """
    Voice Activity Detection (VAD) using Silero-VAD.
    """

    def __init__(self, local=False, model="silero_vad", device=torch.device("cpu")):
        """
        Initialize the VAD object.

        Args:
            local (bool, optional): Whether to load the model from a local
                checkout instead of GitHub. Defaults to False.
            model (str, optional): The VAD model name to load. Defaults to
                "silero_vad".
            device (torch.device, optional): The device to run the model on.
                Defaults to CPU. Currently unused: the ONNX backend runs on CPU.

        Raises:
            RuntimeError: If loading the model fails.
        """
        try:
            vad_model, utils = torch.hub.load(
                repo_or_dir="snakers4/silero-vad" if not local else "vad/silero-vad",
                model=model,
                force_reload=False,
                onnx=True,  # use the ONNX Runtime backend
                source="github" if not local else "local",
            )
            self.vad_model = vad_model
            # utils = (get_speech_timestamps, save_audio, read_audio,
            #          VADIterator, collect_chunks); only the first is needed here.
            (get_speech_timestamps, _, _, _, _) = utils
            self.get_speech_timestamps = get_speech_timestamps
        except Exception as e:
            raise RuntimeError(f"Failed to load VAD model: {e}") from e

    def segment_speech(self, audio_segment, start_time, end_time, sampling_rate):
        """
        Segment speech from an audio segment and return a list of timestamps.

        Args:
            audio_segment (np.ndarray): The audio segment to be segmented.
            start_time (int): The start time of the audio segment, in samples.
            end_time (int): The end time of the audio segment, in samples.
                (Currently unused; kept for API symmetry.)
            sampling_rate (int): The sampling rate of the audio segment.

        Returns:
            list: A list of [start, end] pairs, in samples, covering the
                detected speech, merged so that no segment spans more than
                roughly VAD_THRESHOLD seconds.

        Raises:
            ValueError: If the audio segment is invalid.
        """
        if audio_segment is None or not isinstance(audio_segment, (np.ndarray, list)):
            raise ValueError("Invalid audio segment")

        speech_timestamps = self.get_speech_timestamps(
            audio_segment, self.vad_model, sampling_rate=sampling_rate
        )

        # Shift the per-chunk timestamps so they are relative to the start of
        # the full recording rather than the start of this segment.
        adjusted_timestamps = [
            (ts["start"] + start_time, ts["end"] + start_time)
            for ts in speech_timestamps
        ]
        if not adjusted_timestamps:
            return []

        # Silence gap (in samples) between each pair of consecutive segments.
        intervals = [
            end[0] - start[1]
            for start, end in zip(adjusted_timestamps[:-1], adjusted_timestamps[1:])
        ]

        segments = []

        def split_timestamps(start_index, end_index):
            # Keep the span whole if it is a single segment or already shorter
            # than VAD_THRESHOLD seconds; otherwise split at the longest
            # silence gap and recurse on both halves.
            if (
                start_index == end_index
                or adjusted_timestamps[end_index][1]
                - adjusted_timestamps[start_index][0]
                < VAD_THRESHOLD * sampling_rate
            ):
                segments.append([start_index, end_index])
            else:
                if not intervals[start_index:end_index]:
                    return
                max_interval_index = intervals[start_index:end_index].index(
                    max(intervals[start_index:end_index])
                )
                split_index = start_index + max_interval_index
                split_timestamps(start_index, split_index)
                split_timestamps(split_index + 1, end_index)

        split_timestamps(0, len(adjusted_timestamps) - 1)

        merged_timestamps = [
            [adjusted_timestamps[start][0], adjusted_timestamps[end][1]]
            for start, end in segments
        ]
        return merged_timestamps

    def vad(self, speakerdia, audio):
        """
        Process the audio based on the given speaker diarization dataframe.

        Args:
            speakerdia (pd.DataFrame): The diarization dataframe containing
                "start", "end" (in seconds), and "speaker" columns.
            audio (dict): A dictionary containing the audio "waveform" and its
                "sample_rate".

        Returns:
            list: A list of dictionaries, each with the index, start and end
                times (in seconds), and speaker of a processed audio segment.
        """
        sampling_rate = audio["sample_rate"]
        audio_data = audio["waveform"]

        out = []
        last_end = 0
        speakers_seen = set()  # tracked for bookkeeping; not used below
        count_id = 0

        for _, row in speakerdia.iterrows():
            start = float(row["start"])
            end = float(row["end"])
            # Skip rows fully contained in an already-processed span.
            if end <= last_end:
                continue
            last_end = end

            start_frame = int(start * sampling_rate)
            end_frame = int(end * sampling_rate)
            if row["speaker"] not in speakers_seen:
                speakers_seen.add(row["speaker"])

            # Short segments are emitted as-is, without re-running VAD.
            if end - start <= VAD_THRESHOLD:
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start,  # in seconds
                        "end": end,
                        "speaker": row["speaker"],  # same speaker for the whole segment
                    }
                )
                count_id += 1
                continue

            temp_audio = audio_data[start_frame:end_frame]

            # Resample to 16 kHz for Silero-VAD (the source audio is 24 kHz
            # in this pipeline).
            temp_audio_resampled = librosa.resample(
                temp_audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE
            )

            for start_frame_sub, end_frame_sub in self.segment_speech(
                temp_audio_resampled,
                int(start * SAMPLING_RATE),
                int(end * SAMPLING_RATE),
                SAMPLING_RATE,
            ):
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start_frame_sub / SAMPLING_RATE,  # in seconds
                        "end": end_frame_sub / SAMPLING_RATE,
                        "speaker": row["speaker"],  # same speaker for the whole segment
                    }
                )
                count_id += 1

        return out