import logging
import random
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torchaudio.functional as AF
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset as DatasetBase

from ..hparams import HParams
from .distorter import Distorter
from .utils import rglob_audio_files

logger = logging.getLogger(__name__)


def _normalize(x):
    """Peak-normalize ``x`` so its largest absolute value is just under 1.

    The small epsilon avoids division by zero, so an all-zero signal stays all
    zeros. An empty array raises ``ValueError`` (NumPy cannot take the max of a
    zero-size array).
    """
    return x / (np.abs(x).max() + 1e-7)


def _collate(batch, key, tensor=True, pad=True):
    """Gather ``sample[key]`` from every sample in ``batch`` into one batch value.

    Args:
        batch: Sequence of per-sample dicts as returned by ``Dataset.__getitem__``.
        key: Dict key to gather.
        tensor: Convert each NumPy array to a ``torch.Tensor``.
        pad: Right-pad the tensors with zeros to the longest one and stack them
            along a new leading batch dimension. Requires ``tensor``.

    Returns:
        ``None`` when the first sample's value is ``None``; otherwise the padded
        tensor, or a list of tensors/arrays when ``pad`` is False.
    """
    items = [sample[key] for sample in batch]
    # Only the first sample is checked: in ``Dataset`` an optional field is None
    # for every sample or for none (it depends only on ``hp.load_fg_only``).
    if items[0] is None:
        return None
    if tensor:
        items = [torch.from_numpy(x) for x in items]
    if pad:
        assert tensor, "Can't pad non-tensor"
        items = pad_sequence(items, batch_first=True)
    return items


def praat_augment(wav, sr):
    """Randomly alter the apparent speaker of ``wav`` with Praat's "Change gender".

    A formant shift ratio is drawn from U(1.1, 1.5) and a pitch range factor
    from U(0.5, 2.0). The pitch analysis range is fixed at 75-600 Hz, the new
    pitch median argument is 0 and the duration factor is 1.0.

    Args:
        wav: 1-D waveform.
        sr: Sample rate of ``wav`` in Hz.

    Returns:
        The first channel of the Praat output as a float32 NumPy array.

    Raises:
        ImportError: If ``parselmouth`` is not installed.
    """
    try:
        import parselmouth
    except ImportError as e:
        raise ImportError(
            "Please install parselmouth>=0.5.0 to use Praat augmentation"
        ) from e

    # Alternative pinned install spec:
    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
    # https://github.com/YannickJadoul/Parselmouth/issues/68
    # NOTE: this function may hang with parselmouth 0.4.3 (see the issue above).

    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"

    sound = parselmouth.Sound(wav, sr)

    formant_shift_ratio = random.uniform(1.1, 1.5)
    pitch_range_factor = random.uniform(0.5, 2.0)
    sound = parselmouth.praat.call(
        sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0
    )

    wav = np.array(sound.values)[0].astype(np.float32)
    return wav


class Dataset(DatasetBase):
    """Foreground/background audio dataset.

    Each item is a dict with:
        fg_wav: peak-normalized foreground waveform (during training it is
            occasionally replaced by silence, see ``silent_fg_prob``).
        bg_wav: background waveform cropped/padded to ``len(fg_wav)``.
        fg_dwav / bg_dwav: ``Distorter`` outputs for fg/bg, peak-normalized and
            cast to float32.
    ``bg_wav``, ``fg_dwav`` and ``bg_dwav`` are ``None`` when ``hp.load_fg_only``
    is set.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        """
        Args:
            fg_paths: Foreground audio files; the dataset has one item per path.
            hp: Hyper-parameters; ``hp.bg_dir`` is scanned for background files.
            training: Enables random cropping, random background choice, silent
                foreground and Praat augmentation. When False, the background
                choice is deterministic per index.
            max_retries: Number of load attempts ``__getitem__`` makes before
                giving up (at least one attempt is always made).
            silent_fg_prob: Probability (training only) of replacing the
                foreground with silence.
            mode: ``"enhancer"`` or ``"denoiser"``; also passed to ``Distorter``.
                It must be given explicitly, because the ``False`` default fails
                the check.

        Raises:
            ValueError: If no foreground or no background files are available.
        """
        super().__init__()

        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"

        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(
            "Found %d foreground files and %d background files",
            len(self.fg_paths),
            len(self.bg_paths),
        )

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob
        self.mode = mode

        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load ``path`` as a mono, peak-normalized waveform at ``hp.wav_rate``.

        Args:
            path: Audio file passed to ``torchaudio.load``.
            length: Target number of samples. It defaults to
                ``hp.training_seconds * hp.wav_rate`` in training and to the
                whole file otherwise. Longer audio is cropped and shorter audio
                is zero-padded at the end.
            random_crop: Crop a random window instead of the leading samples.

        Returns:
            1-D NumPy array.
        """
        wav, sr = torchaudio.load(path)

        # Kaiser-windowed sinc resampling; these parameters match resampy's
        # "kaiser_best" setting.
        wav = AF.resample(
            waveform=wav,
            orig_freq=sr,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

        wav = wav.float().numpy()

        if wav.ndim == 2:
            # Down-mix (channels, time) to mono.
            wav = np.mean(wav, axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            if random_crop:
                start = random.randint(0, max(0, len(wav) - length))
                wav = wav[start : start + length]
            else:
                wav = wav[:length]
            if len(wav) < length:
                wav = np.pad(wav, (0, length - len(wav)))

        wav = _normalize(wav)

        return wav

    def _getitem_unsafe(self, index: int):
        """Build item ``index`` with no error handling; exceptions propagate."""
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            fg_wav = np.zeros(
                int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32
            )
        else:
            fg_wav = self._load_wav(fg_path)
            if random.random() < self.hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, self.hp.wav_rate)

        if self.hp.load_fg_only:
            bg_wav = None
            fg_dwav = None
            bg_dwav = None
        else:
            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
                np.float32
            )
            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                # Deterministic for validation
                bg_path = self.bg_paths[index % len(self.bg_paths)]
            # Match the background length to the (possibly augmented) foreground.
            bg_wav = self._load_wav(
                bg_path, length=len(fg_wav), random_crop=self.training
            )
            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(
                np.float32
            )

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        """Return item ``index``, substituting random other items on failure.

        Up to ``max_retries`` attempts are made (always at least one). Each
        failure is logged at DEBUG level, and the next attempt uses an index
        drawn uniformly at random.

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If every attempt fails; chained to the last error.
        """
        if not -len(self) <= index < len(self):
            # Out-of-range indices are a caller error, not a bad file: fail fast
            # instead of retrying with a random substitute.
            raise IndexError(
                f"Index {index} out of range for dataset of size {len(self)}"
            )

        requested_index = index
        # Guard max_retries <= 0: otherwise the loop never runs and None is
        # returned silently.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                if attempt == attempts - 1:
                    raise RuntimeError(
                        f"Failed to load item {requested_index} after {attempts} "
                        f"attempts (last tried {self.fg_paths[index]})"
                    ) from e
                logger.debug(
                    "Error loading %s: %s, skipping", self.fg_paths[index], e
                )
                # Use the stdlib RNG like the rest of this class so a single
                # ``random.seed`` controls every sampling decision.
                index = random.randrange(len(self))

    def __len__(self):
        """Number of foreground files."""
        return len(self.fg_paths)

    @staticmethod
    def collate_fn(batch):
        """Collate items into zero-padded (batch, time) tensors.

        Fields that were not loaded are returned as ``None``.
        """
        return dict(
            fg_wavs=_collate(batch, "fg_wav"),
            bg_wavs=_collate(batch, "bg_wav"),
            fg_dwavs=_collate(batch, "fg_dwav"),
            bg_dwavs=_collate(batch, "bg_dwav"),
        )