zhzluke96
update
32b2aaa
raw
history blame
1.27 kB
import logging
import random
from torch.utils.data import DataLoader
from ..hparams import HParams
from .dataset import Dataset
from .utils import mix_fg_bg, rglob_audio_files
logger = logging.getLogger(__name__)
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Build the training and validation datasets from files under ``hp.fg_dir``.

    The paths returned by ``rglob_audio_files(hp.fg_dir)`` are shuffled in
    place with a ``random.Random`` seeded by ``seed``. The last ``val_size``
    shuffled paths are held out for validation; the rest are used for
    training.

    Args:
        hp: Hyper-parameters. ``hp.fg_dir`` is searched for audio files and
            ``hp`` itself is forwarded to each ``Dataset``.
        mode: Forwarded unchanged to ``Dataset`` as ``mode``.
        val_size: Number of files to hold out for validation. ``0`` yields an
            empty validation split; a value larger than the number of files
            puts every file in the validation split.
        seed: Seed for the shuffle that decides the split.

    Returns:
        A ``(train_ds, val_ds)`` tuple of ``Dataset`` objects.

    Raises:
        ValueError: If ``val_size`` is negative.
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")

    paths = rglob_audio_files(hp.fg_dir)
    logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")

    # NOTE(review): the split is only reproducible across runs if
    # rglob_audio_files returns the paths in a stable order -- confirm.
    random.Random(seed).shuffle(paths)

    # Split on an explicit index. The negative-slice form
    # (paths[:-val_size] / paths[-val_size:]) is wrong for val_size == 0,
    # because -0 == 0 makes the training slice empty and the validation
    # slice the whole list.
    split = max(len(paths) - val_size, 0)
    train_paths = paths[:split]
    val_paths = paths[split:]

    if not train_paths:
        logger.warning(
            f"Training split is empty: {len(paths)} files found, val_size={val_size}"
        )

    train_ds = Dataset(train_paths, hp, training=True, mode=mode)
    val_ds = Dataset(val_paths, hp, training=False, mode=mode)
    logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
    return train_ds, val_ds
def create_dataloaders(hp: HParams, mode):
    """Create the training and validation ``DataLoader`` objects.

    The datasets come from ``_create_datasets(hp=hp, mode=mode)``. Each
    loader uses ``hp.nj`` workers and the ``collate_fn`` of its own dataset.

    Args:
        hp: Hyper-parameters; ``hp.batch_size_per_gpu`` and ``hp.nj`` are
            read here, and ``hp`` is forwarded to ``_create_datasets``.
        mode: Forwarded unchanged to ``_create_datasets``.

    Returns:
        A ``(train_dl, val_dl)`` tuple of ``DataLoader`` objects.
    """
    train_set, val_set = _create_datasets(hp=hp, mode=mode)

    # Training: shuffled batches of the configured per-GPU size; a trailing
    # incomplete batch is dropped.
    train_options = dict(
        batch_size=hp.batch_size_per_gpu,
        shuffle=True,
        num_workers=hp.nj,
        drop_last=True,
        collate_fn=train_set.collate_fn,
    )
    train_loader = DataLoader(train_set, **train_options)

    # Validation: one sample per batch, in dataset order, nothing dropped.
    val_options = dict(
        batch_size=1,
        shuffle=False,
        num_workers=hp.nj,
        drop_last=False,
        collate_fn=val_set.collate_fn,
    )
    val_loader = DataLoader(val_set, **val_options)

    return train_loader, val_loader