"""Dataset pairing foreground and background audio with distorted copies."""

import logging
import random
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torchaudio.functional as AF
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset as DatasetBase

from ..hparams import HParams
from .distorter import Distorter
from .utils import rglob_audio_files

# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)
def _normalize(x): |
return x / (np.abs(x).max() + 1e-7) |
def _collate(batch, key, tensor=True, pad=True): |
l = [d[key] for d in batch] |
if l[0] is None: |
return None |
if tensor: |
l = [torch.from_numpy(x) for x in l] |
if pad: |
assert tensor, "Can't pad non-tensor" |
l = pad_sequence(l, batch_first=True) |
return l |
def praat_augment(wav, sr):
    """Randomly shift formants and pitch range using Praat's "Change gender".

    Args:
        wav: 1-D waveform array.
        sr: Sampling rate of ``wav``, passed to ``parselmouth.Sound``.

    Returns:
        The augmented waveform as a 1-D ``float32`` array.

    Raises:
        ImportError: If ``parselmouth`` is not installed. The original import
            failure is attached as ``__cause__``.
        AssertionError: If ``wav`` is not one-dimensional.
    """
    # Imported lazily so parselmouth is only required when augmentation is used.
    try:
        import parselmouth
    except ImportError as err:
        raise ImportError(
            "Please install parselmouth>=0.5.0 to use Praat augmentation"
        ) from err

    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"

    sound = parselmouth.Sound(wav, sr)
    formant_shift_ratio = random.uniform(1.1, 1.5)
    pitch_range_factor = random.uniform(0.5, 2.0)

    # Positional arguments follow Praat's "Change gender" command (see the
    # Praat manual): pitch floor, pitch ceiling, formant shift ratio, new
    # pitch median (0 = unchanged), pitch range factor, duration factor.
    sound = parselmouth.praat.call(
        sound,
        "Change gender",
        75,
        600,
        formant_shift_ratio,
        0,
        pitch_range_factor,
        1.0,
    )

    # ``sound.values`` is indexed by channel first; keep the first channel.
    return np.array(sound.values)[0].astype(np.float32)
class Dataset(DatasetBase):
    """Map-style dataset of foreground/background waveforms and distorted copies.

    Each item is a dict with the keys ``fg_wav``, ``bg_wav``, ``fg_dwav`` and
    ``bg_dwav``. The last three are ``None`` when ``hp.load_fg_only`` is set.
    An item that fails to load is replaced by another, randomly chosen item.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        """Set up file lists and the distorter.

        Args:
            fg_paths: Foreground audio files; must not be empty.
            hp: Hyper-parameters. This class reads ``bg_dir``, ``fg_dir``
                (error message only), ``wav_rate``, ``training_seconds``,
                ``praat_augment_prob`` and ``load_fg_only``.
            training: Enables random cropping to ``hp.training_seconds``,
                silent foregrounds, Praat augmentation and random background
                selection.
            max_retries: Number of load attempts per ``__getitem__`` call. At
                least one attempt is always made.
            silent_fg_prob: Probability, in training, of replacing the
                foreground with silence.
            mode: ``"enhancer"`` or ``"denoiser"``; forwarded to ``Distorter``.
                The default ``False`` is not a valid mode, so callers must
                pass this explicitly.

        Raises:
            ValueError: If ``mode`` is invalid, or no foreground or background
                files are available.
        """
        super().__init__()

        # Explicit raise (not ``assert``) so the check survives ``python -O``
        # and matches the other argument checks below.
        if mode not in ("enhancer", "denoiser"):
            raise ValueError(f"Invalid mode: {mode}")

        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(
            "Found %d foreground files and %d background files",
            len(self.fg_paths),
            len(self.bg_paths),
        )

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob
        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load ``path`` as a mono, peak-normalized array at ``hp.wav_rate``.

        Args:
            path: Audio file to read with ``torchaudio.load``.
            length: Target length in samples. When ``None`` in training mode
                it defaults to ``hp.training_seconds * hp.wav_rate``; when
                ``None`` otherwise, the full file is returned.
            random_crop: Pick the crop start at random instead of taking the
                beginning of the file.

        Returns:
            1-D numpy array, cropped and zero-padded at the end to exactly
            ``length`` samples when a length applies.
        """
        wav, sr = torchaudio.load(path)

        # Kaiser-windowed sinc resampling to the project sample rate.
        wav = AF.resample(
            waveform=wav,
            orig_freq=sr,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )
        wav = wav.float().numpy()

        # Down-mix (channels, time) to mono by averaging the channels.
        if wav.ndim == 2:
            wav = np.mean(wav, axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            if random_crop:
                start = random.randint(0, max(0, len(wav) - length))
                wav = wav[start : start + length]
            else:
                wav = wav[:length]

            # Files shorter than the target are zero-padded at the end.
            if len(wav) < length:
                wav = np.pad(wav, (0, length - len(wav)))

        return _normalize(wav)

    def _getitem_unsafe(self, index: int):
        """Build the sample for ``index``; any failure propagates to the caller.

        Returns:
            Dict with ``fg_wav``, ``bg_wav``, ``fg_dwav`` and ``bg_dwav``. The
            last three are ``None`` when ``hp.load_fg_only`` is set.
        """
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            # Occasionally substitute an all-zero foreground.
            fg_wav = np.zeros(
                int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32
            )
        else:
            fg_wav = self._load_wav(fg_path)

        if random.random() < self.hp.praat_augment_prob and self.training:
            fg_wav = praat_augment(fg_wav, self.hp.wav_rate)

        if self.hp.load_fg_only:
            bg_wav = None
            fg_dwav = None
            bg_dwav = None
        else:
            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
                np.float32
            )

            # Random background in training; index-derived one otherwise.
            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                bg_path = self.bg_paths[index % len(self.bg_paths)]

            # The background is cropped/padded to the foreground's length.
            bg_wav = self._load_wav(
                bg_path, length=len(fg_wav), random_crop=self.training
            )
            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(
                np.float32
            )

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        """Return the sample at ``index``, retrying with random indices on error.

        Raises:
            RuntimeError: If every attempt fails; the last error is chained.
            IndexError: If ``index`` is out of range for ``fg_paths``.
        """
        # Always try at least once so a non-positive ``max_retries`` cannot
        # make this method fall through and return ``None``.
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                # An out-of-range ``index`` raises IndexError again on this
                # lookup, so it propagates instead of being retried.
                failed_path = self.fg_paths[index]
                if attempt == attempts - 1:
                    raise RuntimeError(
                        f"Failed to load {failed_path} after {self.max_retries} retries"
                    ) from e
                logger.debug("Error loading %s: %s, skipping", failed_path, e)
                # Same RNG module as the rest of the class.
                index = random.randrange(len(self))

    def __len__(self):
        """Return the number of foreground files."""
        return len(self.fg_paths)

    @staticmethod
    def collate_fn(batch):
        """Collate samples into a dict of batched fields via ``_collate``.

        Output keys are ``fg_wavs``, ``bg_wavs``, ``fg_dwavs`` and ``bg_dwavs``.
        """
        return dict(
            fg_wavs=_collate(batch, "fg_wav"),
            bg_wavs=_collate(batch, "bg_wav"),
            fg_dwavs=_collate(batch, "fg_dwav"),
            bg_dwavs=_collate(batch, "bg_dwav"),
        )