|
import argparse |
|
import os |
|
import shutil |
|
from pathlib import Path |
|
|
|
import soundfile as sf |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from common.log import logger |
|
from common.stdout_wrapper import SAFE_STDOUT |
|
|
|
# Load the Silero VAD model (ONNX variant) through torch.hub when this module
# is imported; `utils` is the bundle of helper callables returned with it.
vad_model, utils = torch.hub.load(
    "snakers4/silero-vad",
    "silero_vad",
    onnx=True,
    trust_repo=True,
)

# Only the first and third helpers are used in this file; the others are
# discarded into `_`.
get_speech_timestamps, _, read_audio, *_ = utils
|
|
|
|
|
def get_stamps(
    audio_file, min_silence_dur_ms: int = 700, min_sec: float = 2, max_sec: float = 12
):
    """Detect speech segments in ``audio_file`` with the module-level VAD model.

    The audio is loaded with ``read_audio`` at 16 kHz and handed to
    ``get_speech_timestamps`` together with ``vad_model``.

    Args:
        audio_file: Audio source, passed to ``read_audio`` unchanged.
        min_silence_dur_ms: Silence lasting at least this many milliseconds
            is treated as a break between utterances; shorter pauses do not
            split the audio. Too small a value chops the audio into overly
            short pieces, too large a value makes each piece too long.
            Probably needs tuning per dataset.
        min_sec: Utterances shorter than this many seconds are ignored
            (forwarded, converted to milliseconds, as
            ``min_speech_duration_ms``).
        max_sec: Upper bound in seconds on an utterance (forwarded as
            ``max_speech_duration_s``).
            NOTE(review): the original note said longer utterances are
            ignored; the exact handling is decided by Silero VAD - confirm
            against its documentation.

    Returns:
        The value returned by ``get_speech_timestamps``, unchanged.
    """
    # Sampling rate used both for loading the audio and for the VAD call.
    vad_rate_hz = 16000

    vad_options = {
        "sampling_rate": vad_rate_hz,
        "min_silence_duration_ms": min_silence_dur_ms,
        "min_speech_duration_ms": int(min_sec * 1000),
        "max_speech_duration_s": max_sec,
    }

    waveform = read_audio(audio_file, sampling_rate=vad_rate_hz)
    return get_speech_timestamps(waveform, vad_model, **vad_options)
|
|
|
|
|
def split_wav(
    audio_file,
    target_dir="raw",
    min_sec=2,
    max_sec=12,
    min_silence_dur_ms=700,
):
    """Cut ``audio_file`` into one WAV file per detected speech segment.

    Segments come from ``get_stamps``. Each one is padded by a fixed margin
    on both sides (clamped to the bounds of the file), sliced out of the
    audio as read by ``soundfile`` at its native sampling rate, and written
    to ``target_dir`` as ``<stem>-<index>.wav``.

    Args:
        audio_file: Path of the audio file to split.
        target_dir: Directory that receives the slices; created if missing.
        min_sec: Forwarded to ``get_stamps`` as ``min_sec``.
        max_sec: Forwarded to ``get_stamps`` as ``max_sec``.
        min_silence_dur_ms: Forwarded to ``get_stamps`` as
            ``min_silence_dur_ms``.

    Returns:
        Total duration in seconds of all written slices, margins included.
    """
    # Padding in milliseconds kept before and after each detected segment.
    margin = 200
    # get_stamps runs the VAD at 16 kHz, so its sample indices convert to
    # milliseconds by dividing by 16 samples per millisecond.
    vad_samples_per_ms = 16

    speech_timestamps = get_stamps(
        audio_file,
        min_silence_dur_ms=min_silence_dur_ms,
        min_sec=min_sec,
        max_sec=max_sec,
    )

    data, sr = sf.read(audio_file)
    total_ms = len(data) / sr * 1000

    # Strip only the final extension. Cutting at the first dot would map
    # "take.1.wav" and "take.2.wav" to the same stem, so their slices would
    # overwrite each other.
    file_name = os.path.splitext(os.path.basename(audio_file))[0]
    os.makedirs(target_dir, exist_ok=True)

    total_time_ms = 0
    for i, ts in enumerate(speech_timestamps):
        # Pad the segment, keeping it inside the file.
        start_ms = max(ts["start"] / vad_samples_per_ms - margin, 0)
        end_ms = min(ts["end"] / vad_samples_per_ms + margin, total_ms)

        # Convert back to sample indices at the file's own sampling rate.
        start_sample = int(start_ms / 1000 * sr)
        end_sample = int(end_ms / 1000 * sr)
        segment = data[start_sample:end_sample]

        sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr)
        total_time_ms += end_ms - start_ms

    return total_time_ms / 1000
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: slice every wav file found under the input
    # directory and write the slices into a freshly recreated output directory.
    parser = argparse.ArgumentParser()

    # (flags, keyword options) for each supported command-line argument.
    cli_arguments = (
        (
            ("--min_sec", "-m"),
            {"type": float, "default": 2, "help": "Minimum seconds of a slice"},
        ),
        (
            ("--max_sec", "-M"),
            {"type": float, "default": 12, "help": "Maximum seconds of a slice"},
        ),
        (
            ("--input_dir", "-i"),
            {
                "type": str,
                "default": "inputs",
                "help": "Directory of input wav files",
            },
        ),
        (
            ("--output_dir", "-o"),
            {
                "type": str,
                "default": "raw",
                "help": "Directory of output wav files",
            },
        ),
        (
            ("--min_silence_dur_ms", "-s"),
            {
                "type": int,
                "default": 700,
                "help": "Silence above this duration (ms) is considered as a split point.",
            },
        ),
    )
    for flags, options in cli_arguments:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    output_dir = args.output_dir

    # Collect the inputs recursively before the output directory is touched.
    wav_files = [*Path(args.input_dir).glob("**/*.wav")]
    logger.info(f"Found {len(wav_files)} wav files.")

    # Any existing output directory is removed so the run starts clean.
    if os.path.exists(output_dir):
        logger.warning(f"Output directory {output_dir} already exists, deleting...")
        shutil.rmtree(output_dir)

    # Accumulate the total duration (seconds) of all slices written.
    total_sec = 0
    for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
        total_sec += split_wav(
            audio_file=str(wav_file),
            target_dir=output_dir,
            min_sec=args.min_sec,
            max_sec=args.max_sec,
            min_silence_dur_ms=args.min_silence_dur_ms,
        )

    logger.info(f"Slice done! Total time: {total_sec / 60:.2f} min.")
|
|