freddyaboulton committed
Commit 9067154
1 Parent(s): 4a37dab
Files changed (1)
  1. utils/vad.py +290 -0
utils/vad.py ADDED
@@ -0,0 +1,290 @@
+import bisect
+import functools
+import os
+import warnings
+
+from typing import List, NamedTuple, Optional
+
+import numpy as np
+
+
+# The code below is adapted from https://github.com/snakers4/silero-vad.
+class VadOptions(NamedTuple):
+    """VAD options.
+
+    Attributes:
+      threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk;
+        probabilities ABOVE this value are considered SPEECH. It is better to tune this
+        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+      min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
+      max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
+        than max_speech_duration_s will be split at the timestamp of the last silence that
+        lasts more than 100 ms (if any), to prevent aggressive cutting. Otherwise, they will be
+        split aggressively just before max_speech_duration_s.
+      min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
+        before separating it.
+      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
+        WARNING! Silero VAD models were trained using 512, 1024, and 1536 samples for a 16000 sample rate.
+        Values other than these may affect model performance!
+      speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
+    """
+
+    threshold: float = 0.5
+    min_speech_duration_ms: int = 250
+    max_speech_duration_s: float = float("inf")
+    min_silence_duration_ms: int = 2000
+    window_size_samples: int = 1024
+    speech_pad_ms: int = 400
+
+
+def get_speech_timestamps(
+    audio: np.ndarray,
+    vad_options: Optional[VadOptions] = None,
+    **kwargs,
+) -> List[dict]:
+    """Splits long audio into speech chunks using silero VAD.
+
+    Args:
+      audio: One dimensional float array.
+      vad_options: Options for VAD processing.
+      kwargs: VAD options passed as keyword arguments for backward compatibility.
+
+    Returns:
+      List of dicts containing the begin and end samples of each speech chunk.
+    """
+    if vad_options is None:
+        vad_options = VadOptions(**kwargs)
+
+    threshold = vad_options.threshold
+    min_speech_duration_ms = vad_options.min_speech_duration_ms
+    max_speech_duration_s = vad_options.max_speech_duration_s
+    min_silence_duration_ms = vad_options.min_silence_duration_ms
+    window_size_samples = vad_options.window_size_samples
+    speech_pad_ms = vad_options.speech_pad_ms
+
+    if window_size_samples not in [512, 1024, 1536]:
+        warnings.warn(
+            "Unusual window_size_samples! Supported window_size_samples:\n"
+            " - [512, 1024, 1536] for 16000 sampling_rate"
+        )
+
+    sampling_rate = 16000
+    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+    max_speech_samples = (
+        sampling_rate * max_speech_duration_s
+        - window_size_samples
+        - 2 * speech_pad_samples
+    )
+    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+    audio_length_samples = len(audio)
+
+    model = get_vad_model()
+    state = model.get_initial_state(batch_size=1)
+
+    speech_probs = []
+    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+        chunk = audio[current_start_sample : current_start_sample + window_size_samples]
+        if len(chunk) < window_size_samples:
+            chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
+        speech_prob, state = model(chunk, state, sampling_rate)
+        speech_probs.append(speech_prob)
+
+    triggered = False
+    speeches = []
+    current_speech = {}
+    neg_threshold = threshold - 0.15
+
+    # to save potential segment end (and tolerate some silence)
+    temp_end = 0
+    # to save potential segment limits in case of maximum segment size reached
+    prev_end = next_start = 0
+
+    for i, speech_prob in enumerate(speech_probs):
+        if (speech_prob >= threshold) and temp_end:
+            temp_end = 0
+            if next_start < prev_end:
+                next_start = window_size_samples * i
+
+        if (speech_prob >= threshold) and not triggered:
+            triggered = True
+            current_speech["start"] = window_size_samples * i
+            continue
+
+        if (
+            triggered
+            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
+        ):
+            if prev_end:
+                current_speech["end"] = prev_end
+                speeches.append(current_speech)
+                current_speech = {}
+                # previously reached silence (< neg_thres) and is still not speech (< thres)
+                if next_start < prev_end:
+                    triggered = False
+                else:
+                    current_speech["start"] = next_start
+                prev_end = next_start = temp_end = 0
+            else:
+                current_speech["end"] = window_size_samples * i
+                speeches.append(current_speech)
+                current_speech = {}
+                prev_end = next_start = temp_end = 0
+                triggered = False
+                continue
+
+        if (speech_prob < neg_threshold) and triggered:
+            if not temp_end:
+                temp_end = window_size_samples * i
+            # condition to avoid cutting in very short silence
+            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
+                prev_end = temp_end
+            if (window_size_samples * i) - temp_end < min_silence_samples:
+                continue
+            else:
+                current_speech["end"] = temp_end
+                if (
+                    current_speech["end"] - current_speech["start"]
+                ) > min_speech_samples:
+                    speeches.append(current_speech)
+                current_speech = {}
+                prev_end = next_start = temp_end = 0
+                triggered = False
+                continue
+
+    if (
+        current_speech
+        and (audio_length_samples - current_speech["start"]) > min_speech_samples
+    ):
+        current_speech["end"] = audio_length_samples
+        speeches.append(current_speech)
+
+    for i, speech in enumerate(speeches):
+        if i == 0:
+            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
+        if i != len(speeches) - 1:
+            silence_duration = speeches[i + 1]["start"] - speech["end"]
+            if silence_duration < 2 * speech_pad_samples:
+                speech["end"] += int(silence_duration // 2)
+                speeches[i + 1]["start"] = int(
+                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
+                )
+            else:
+                speech["end"] = int(
+                    min(audio_length_samples, speech["end"] + speech_pad_samples)
+                )
+                speeches[i + 1]["start"] = int(
+                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
+                )
+        else:
+            speech["end"] = int(
+                min(audio_length_samples, speech["end"] + speech_pad_samples)
+            )
+
+    return speeches
+
+
+def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
+    """Collects and concatenates audio chunks."""
+    if not chunks:
+        return np.array([], dtype=np.float32)
+
+    return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
+
+
+class SpeechTimestampsMap:
+    """Helper class to restore original speech timestamps."""
+
+    def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
+        self.sampling_rate = sampling_rate
+        self.time_precision = time_precision
+        self.chunk_end_sample = []
+        self.total_silence_before = []
+
+        previous_end = 0
+        silent_samples = 0
+
+        for chunk in chunks:
+            silent_samples += chunk["start"] - previous_end
+            previous_end = chunk["end"]
+
+            self.chunk_end_sample.append(chunk["end"] - silent_samples)
+            self.total_silence_before.append(silent_samples / sampling_rate)
+
+    def get_original_time(
+        self,
+        time: float,
+        chunk_index: Optional[int] = None,
+    ) -> float:
+        if chunk_index is None:
+            chunk_index = self.get_chunk_index(time)
+
+        total_silence_before = self.total_silence_before[chunk_index]
+        return round(total_silence_before + time, self.time_precision)
+
+    def get_chunk_index(self, time: float) -> int:
+        sample = int(time * self.sampling_rate)
+        return min(
+            bisect.bisect(self.chunk_end_sample, sample),
+            len(self.chunk_end_sample) - 1,
+        )
+
+
+@functools.lru_cache
+def get_vad_model():
+    """Returns the VAD model instance."""
+    asset_dir = os.path.join(os.path.dirname(__file__), "assets")
+    path = os.path.join(asset_dir, "silero_vad.onnx")
+    return SileroVADModel(path)
+
+
+class SileroVADModel:
+    def __init__(self, path):
+        try:
+            import onnxruntime
+        except ImportError as e:
+            raise RuntimeError(
+                "Applying the VAD filter requires the onnxruntime package"
+            ) from e
+
+        opts = onnxruntime.SessionOptions()
+        opts.inter_op_num_threads = 1
+        opts.intra_op_num_threads = 1
+        opts.log_severity_level = 4
+
+        self.session = onnxruntime.InferenceSession(
+            path,
+            providers=["CPUExecutionProvider"],
+            sess_options=opts,
+        )
+
+    def get_initial_state(self, batch_size: int):
+        h = np.zeros((2, batch_size, 64), dtype=np.float32)
+        c = np.zeros((2, batch_size, 64), dtype=np.float32)
+        return h, c
+
+    def __call__(self, x, state, sr: int):
+        if len(x.shape) == 1:
+            x = np.expand_dims(x, 0)
+        if len(x.shape) > 2:
+            raise ValueError(
+                f"Too many dimensions for input audio chunk {len(x.shape)}"
+            )
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        h, c = state
+
+        ort_inputs = {
+            "input": x,
+            "h": h,
+            "c": c,
+            "sr": np.array(sr, dtype="int64"),
+        }
+
+        out, h, c = self.session.run(None, ort_inputs)
+        state = (h, c)
+
+        return out, state
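
Usage note (not part of the diff): a minimal sketch of how these utilities fit together. It assumes the file is importable as utils.vad, that utils/assets/silero_vad.onnx and the onnxruntime package are available, and that the input is mono float32 audio already resampled to 16 kHz (the only rate this module supports). The file name, the placeholder audio, and the 2.5-second timestamp are illustrative only.

# example_vad_usage.py -- hypothetical usage sketch, not part of the commit
import numpy as np

from utils.vad import (
    SpeechTimestampsMap,
    VadOptions,
    collect_chunks,
    get_speech_timestamps,
)

SAMPLING_RATE = 16000  # the module assumes 16 kHz mono float audio throughout

# Placeholder input: replace with real speech loaded and resampled to 16 kHz.
# Pure silence will simply yield an empty list of chunks.
audio = np.zeros(10 * SAMPLING_RATE, dtype=np.float32)

# 1. Detect speech regions; each dict holds "start"/"end" sample offsets into `audio`.
options = VadOptions(threshold=0.5, min_silence_duration_ms=2000)
chunks = get_speech_timestamps(audio, options)

# 2. Drop the silence by concatenating only the detected chunks.
speech_only = collect_chunks(audio, chunks)
print(f"kept {len(speech_only) / SAMPLING_RATE:.2f} s of speech in {len(chunks)} chunk(s)")

# 3. Map a timestamp measured on the trimmed audio back to the original timeline,
#    e.g. a word that a downstream transcriber placed at 2.5 s into `speech_only`.
if chunks:
    ts_map = SpeechTimestampsMap(chunks, SAMPLING_RATE)
    print("2.5 s in the trimmed audio is", ts_map.get_original_time(2.5), "s in the original")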