Spaces:
Runtime error
Runtime error
# Copyright (c) 2022 NVIDIA CORPORATION. | |
# Licensed under the MIT license. | |
import os | |
import numpy as np | |
from scipy.io.wavfile import read as wavread | |
import warnings | |
warnings.filterwarnings("ignore") | |
import torch | |
from torch.utils.data import Dataset | |
from torch.utils.data.distributed import DistributedSampler | |
import random | |
random.seed(0) | |
torch.manual_seed(0) | |
np.random.seed(0) | |
from torchvision import datasets, models, transforms | |
import torchaudio | |
class CleanNoisyPairDataset(Dataset):
    """
    Create a Dataset of clean and noisy audio pairs.

    Each element is a tuple of the form (clean waveform, noisy waveform, file_id),
    where file_id is the (clean_path, noisy_path) tuple the pair was loaded from.

    Args:
        root: dataset root directory.
        subset: "training" reads ``<root>/training_set/{clean,noisy}/fileid_<i>.wav``;
            "testing" reads ``<root>/datasets/test_set/synthetic/no_reverb/{clean,noisy}``.
            ``None`` passes the initial check but then raises NotImplementedError.
        crop_length_sec: length in seconds of the random crop applied to
            non-testing pairs; 0 disables cropping. Forced to 0 for "testing".

    Raises:
        AssertionError: unknown subset, or clean/noisy file lists that do not pair up.
        NotImplementedError: subset is None.
    """

    def __init__(self, root='./', subset='training', crop_length_sec=0):
        # Zero-argument super() actually runs the parent initializer; the
        # one-argument form super(CleanNoisyPairDataset) builds an unbound super
        # object and never calls Dataset.__init__.
        super().__init__()
        assert subset is None or subset in ["training", "testing"]
        self.crop_length_sec = crop_length_sec
        self.subset = subset

        if subset == "training":
            self.files = self._training_files(root)
        elif subset == "testing":
            self.files = self._testing_files(root)
            # Evaluation always uses the full-length utterances.
            self.crop_length_sec = 0
        else:
            raise NotImplementedError

    @staticmethod
    def _training_files(root):
        """Return (clean, noisy) path pairs for the training split, ordered by file id."""
        clean_dir = os.path.join(root, 'training_set/clean')
        noisy_dir = os.path.join(root, 'training_set/noisy')
        N_clean = len(os.listdir(clean_dir))
        N_noisy = len(os.listdir(noisy_dir))
        assert N_clean == N_noisy
        # Files are assumed to be named fileid_0.wav ... fileid_{N-1}.wav in both
        # directories; only the counts are checked here, not the names.
        return [(os.path.join(clean_dir, 'fileid_{}.wav'.format(i)),
                 os.path.join(noisy_dir, 'fileid_{}.wav'.format(i)))
                for i in range(N_clean)]

    @staticmethod
    def _testing_files(root):
        """Return (clean, noisy) path pairs for the DNS synthetic no-reverb test split."""

        def sortkey(name):
            # Specific to DNS test-sample names: the last two "_"-separated
            # fields (e.g. "fileid_0.wav") identify the utterance in both dirs.
            return '_'.join(name.split('_')[-2:])

        _p = os.path.join(root, 'datasets/test_set/synthetic/no_reverb')  # path for DNS
        clean_files = sorted(os.listdir(os.path.join(_p, 'clean')), key=sortkey)
        noisy_files = sorted(os.listdir(os.path.join(_p, 'noisy')), key=sortkey)
        # zip() below would silently drop unmatched files; fail loudly instead.
        assert len(clean_files) == len(noisy_files)

        files = []
        for _c, _n in zip(clean_files, noisy_files):
            assert sortkey(_c) == sortkey(_n)
            files.append((os.path.join(_p, 'clean', _c),
                          os.path.join(_p, 'noisy', _n)))
        return files

    def __getitem__(self, n):
        """
        Load the n-th clean/noisy pair.

        Returns:
            (clean_audio, noisy_audio, fileid): both waveforms carry a leading
            dimension of size 1; fileid is the (clean_path, noisy_path) tuple.
        """
        fileid = self.files[n]
        clean_audio, sample_rate = torchaudio.load(fileid[0])
        noisy_audio, sample_rate = torchaudio.load(fileid[1])
        # squeeze(0) drops the channel axis only when it has size 1, so this
        # presumably expects mono files -- TODO confirm for the data in use.
        clean_audio, noisy_audio = clean_audio.squeeze(0), noisy_audio.squeeze(0)
        assert len(clean_audio) == len(noisy_audio)

        crop_length = int(self.crop_length_sec * sample_rate)
        # A crop exactly as long as the clip is valid (the only start is 0),
        # so the bound is inclusive.
        assert crop_length <= len(clean_audio)

        # Random crop, applied to the same window in both signals.
        if self.subset != 'testing' and crop_length > 0:
            start = np.random.randint(low=0, high=len(clean_audio) - crop_length + 1)
            clean_audio = clean_audio[start:(start + crop_length)]
            noisy_audio = noisy_audio[start:(start + crop_length)]

        clean_audio, noisy_audio = clean_audio.unsqueeze(0), noisy_audio.unsqueeze(0)
        return (clean_audio, noisy_audio, fileid)

    def __len__(self):
        """Return the number of clean/noisy pairs."""
        return len(self.files)
def load_CleanNoisyPairDataset(root, subset, crop_length_sec, batch_size, sample_rate, num_gpus=1,
                               num_workers=4):
    """
    Get dataloader with distributed sampling.

    Args:
        root: dataset root, forwarded to CleanNoisyPairDataset.
        subset: "training" or "testing", forwarded to CleanNoisyPairDataset.
        crop_length_sec: crop length in seconds, forwarded to CleanNoisyPairDataset.
        batch_size: batch size of the returned DataLoader.
        sample_rate: accepted for config compatibility but unused here; the
            dataset reads the rate from each loaded file.
        num_gpus: when greater than 1, batches are drawn through a
            DistributedSampler; otherwise the DataLoader shuffles on its own.
        num_workers: number of DataLoader worker processes (defaults to 4,
            the previously hard-coded value).

    Returns:
        A torch.utils.data.DataLoader over the clean/noisy pair dataset.
    """
    dataset = CleanNoisyPairDataset(root=root, subset=subset, crop_length_sec=crop_length_sec)
    kwargs = {"batch_size": batch_size, "num_workers": num_workers,
              "pin_memory": False, "drop_last": False}

    if num_gpus > 1:
        # A sampler and shuffle=True are mutually exclusive in DataLoader; the
        # distributed sampler takes over ordering in this case.
        train_sampler = DistributedSampler(dataset)
        return torch.utils.data.DataLoader(dataset, sampler=train_sampler, **kwargs)

    return torch.utils.data.DataLoader(dataset, sampler=None, shuffle=True, **kwargs)
if __name__ == '__main__':
    # Smoke test: build both loaders from the config and print one training batch.
    import json

    with open('./configs/DNS-large-full.json') as f:
        config = json.load(f)
    # load_CleanNoisyPairDataset has no defaults for root, crop_length_sec and
    # sample_rate, so trainset_config must supply them (and must not repeat
    # subset / batch_size / num_gpus, which are passed explicitly below).
    trainset_config = config["trainset_config"]

    trainloader = load_CleanNoisyPairDataset(**trainset_config, subset='training', batch_size=2, num_gpus=1)
    testloader = load_CleanNoisyPairDataset(**trainset_config, subset='testing', batch_size=2, num_gpus=1)
    print(len(trainloader), len(testloader))

    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for clean_audio, noisy_audio, fileid in trainloader:
        clean_audio = clean_audio.to(device)
        noisy_audio = noisy_audio.to(device)
        print(clean_audio.shape, noisy_audio.shape, fileid)
        break