# Source: https://github.com/snakers4/silero-vad
#
# Copyright (c) 2024 snakers4
#
# This code is from an MIT-licensed repository. The full license text is available at the root of the source repository.
#
# Note: This code has been modified to fit the context of this repository.
import librosa
import numpy as np
import torch

# Maximum segment length in seconds: diarization rows no longer than this are
# kept whole, and longer rows are split at silence gaps toward this target.
VAD_THRESHOLD = 20
# Sampling rate (Hz) expected by the Silero VAD model.
SAMPLING_RATE = 16000

class SileroVAD:
"""
Voice Activity Detection (VAD) using Silero-VAD.
"""

    def __init__(self, local=False, model="silero_vad", device=torch.device("cpu")):
        """
        Initialize the VAD object.

        Args:
            local (bool, optional): Whether to load the model from a local path instead of GitHub. Defaults to False.
            model (str, optional): The VAD model name to load. Defaults to "silero_vad".
            device (torch.device, optional): The device to run the model on. Defaults to CPU. Currently unused by this wrapper.

        Returns:
            None

        Raises:
            RuntimeError: If loading the model fails.
        """
try:
vad_model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad" if not local else "vad/silero-vad",
model=model,
force_reload=False,
onnx=True,
source="github" if not local else "local",
)
            self.vad_model = vad_model
            # torch.hub returns a tuple of helpers:
            # (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks);
            # only the first is needed here.
            (get_speech_timestamps, _, _, _, _) = utils
            self.get_speech_timestamps = get_speech_timestamps
        except Exception as e:
            raise RuntimeError(f"Failed to load VAD model: {e}") from e

    def segment_speech(self, audio_segment, start_time, end_time, sampling_rate):
        """
        Segment speech from an audio segment and return a list of timestamps.

        Args:
            audio_segment (np.ndarray): The audio segment to be segmented.
            start_time (int): The start time of the audio segment, in samples.
            end_time (int): The end time of the audio segment, in samples (currently unused).
            sampling_rate (int): The sampling rate of the audio segment.

        Returns:
            list: A list of [start, end] pairs, in samples, for the detected speech segments.

        Raises:
            ValueError: If the audio segment is invalid.
        """
if audio_segment is None or not isinstance(audio_segment, (np.ndarray, list)):
raise ValueError("Invalid audio segment")
speech_timestamps = self.get_speech_timestamps(
audio_segment, self.vad_model, sampling_rate=sampling_rate
)
        # Shift timestamps from segment-relative to absolute sample positions.
        adjusted_timestamps = [
            (ts["start"] + start_time, ts["end"] + start_time)
            for ts in speech_timestamps
        ]
        if not adjusted_timestamps:
            return []
        # Silence gap (in samples) between each pair of consecutive speech runs.
        intervals = [
            end[0] - start[1]
            for start, end in zip(adjusted_timestamps[:-1], adjusted_timestamps[1:])
        ]
segments = []
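
        # Recursively bisect any run of speech whose total span exceeds the
        # 20 s threshold, cutting at the largest internal silence gap so that
        # splits land on natural pauses. Indices refer to adjusted_timestamps.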
def split_timestamps(start_index, end_index):
            if (
                start_index == end_index
                or adjusted_timestamps[end_index][1]
                - adjusted_timestamps[start_index][0]
                < VAD_THRESHOLD * sampling_rate
            ):
                segments.append([start_index, end_index])
else:
if not intervals[start_index:end_index]:
return
max_interval_index = intervals[start_index:end_index].index(
max(intervals[start_index:end_index])
)
split_index = start_index + max_interval_index
split_timestamps(start_index, split_index)
split_timestamps(split_index + 1, end_index)
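
        # Worked example (hypothetical numbers): with speech runs at samples
        # [(0, 150_000), (160_000, 500_000)] and sampling_rate=16_000, the
        # total span (500_000 samples ≈ 31 s) exceeds 20 s, so the run is
        # bisected at the single 10_000-sample gap, yielding the two segments
        # [[0, 150_000], [160_000, 500_000]].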
        split_timestamps(0, len(adjusted_timestamps) - 1)
        # Convert index pairs back to absolute sample timestamps.
        merged_timestamps = [
            [adjusted_timestamps[start][0], adjusted_timestamps[end][1]]
            for start, end in segments
        ]
        return merged_timestamps

    def vad(self, speakerdia, audio):
        """
        Run VAD over the audio according to the given speaker diarization dataframe.

        Args:
            speakerdia (pd.DataFrame): The diarization dataframe with "start", "end", and "speaker" columns (times in seconds).
            audio (dict): A dictionary with a "waveform" (np.ndarray) and its "sample_rate" (int).

        Returns:
            list: A list of dictionaries, each with "index", "start", "end" (in seconds), and "speaker".
        """
sampling_rate = audio["sample_rate"]
audio_data = audio["waveform"]
out = []
        last_end = 0
        speakers_seen = set()
        count_id = 0
        for index, row in speakerdia.iterrows():
            start = float(row["start"])
            end = float(row["end"])
            # Skip rows that end inside an already-processed span.
            if end <= last_end:
                continue
            last_end = end
            start_frame = int(start * sampling_rate)
            end_frame = int(end * sampling_rate)
            # Track speakers encountered (informational only; not used below).
            if row["speaker"] not in speakers_seen:
                speakers_seen.add(row["speaker"])
            # Short segments are kept whole, without further VAD splitting.
            if end - start <= VAD_THRESHOLD:
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start,  # in seconds
                        "end": end,
                        "speaker": row["speaker"],  # same for all
                    }
                )
                count_id += 1
continue
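
            # Longer segments: run Silero VAD on the slice and emit one entry
            # per detected speech run.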
            temp_audio = audio_data[start_frame:end_frame]
            # Resample to the 16 kHz rate expected by the Silero VAD model.
            temp_audio_resampled = librosa.resample(
                temp_audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE
            )
for start_frame_sub, end_frame_sub in self.segment_speech(
temp_audio_resampled,
int(start * SAMPLING_RATE),
int(end * SAMPLING_RATE),
SAMPLING_RATE,
):
out.append(
{
"index": str(count_id).zfill(5),
"start": start_frame_sub / SAMPLING_RATE, # in seconds
"end": end_frame_sub / SAMPLING_RATE,
"speaker": row["speaker"], # same for all
}
)
count_id += 1
return out
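

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how this wrapper might be driven end to end.
# The diarization column layout ("start", "end", "speaker") matches what
# `vad` reads above; the file name, the 24 kHz load rate, and the speaker
# labels are illustrative assumptions, not part of the source repository.
if __name__ == "__main__":
    import pandas as pd

    detector = SileroVAD()
    waveform, sr = librosa.load("example.wav", sr=24000)  # hypothetical input file
    diarization = pd.DataFrame(
        [
            {"start": 0.0, "end": 12.5, "speaker": "SPEAKER_00"},
            {"start": 12.5, "end": 47.0, "speaker": "SPEAKER_01"},
        ]
    )
    segments = detector.vad(diarization, {"waveform": waveform, "sample_rate": sr})
    for seg in segments:
        print(seg["index"], f"{seg['start']:.2f}s-{seg['end']:.2f}s", seg["speaker"])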