DiffSpeech / tasks /vocoder /dataset_utils.py
RayeRen's picture
init
d1b91e7
raw
history blame
5.54 kB
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from utils.commons.dataset_utils import BaseDataset, collate_1d, collate_2d
from utils.commons.hparams import hparams
from utils.commons.indexed_datasets import IndexedDataset
class EndlessDistributedSampler(DistributedSampler):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.shuffle = shuffle
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = [i for _ in range(1000) for i in torch.randperm(
len(self.dataset), generator=g).tolist()]
else:
indices = [i for _ in range(1000) for i in list(range(len(self.dataset)))]
indices = indices[:len(indices) // self.num_replicas * self.num_replicas]
indices = indices[self.rank::self.num_replicas]
self.indices = indices
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
class VocoderDataset(BaseDataset):
def __init__(self, prefix, shuffle=False):
super().__init__(shuffle)
self.hparams = hparams
self.prefix = prefix
self.data_dir = hparams['binary_data_dir']
self.is_infer = prefix == 'test'
self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size']
self.hop_size = hparams['hop_size']
self.indexed_ds = None
self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
self.avail_idxs = [idx for idx, s in enumerate(self.sizes) if s > self.batch_max_frames]
print(f"| {len(self.sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.")
self.sizes = [s for idx, s in enumerate(self.sizes) if s > self.batch_max_frames]
def _get_item(self, index):
if self.indexed_ds is None:
self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
item = self.indexed_ds[index]
return item
def __getitem__(self, index):
index = self.avail_idxs[index]
item = self._get_item(index)
sample = {
"id": index,
"item_name": item['item_name'],
"mel": torch.FloatTensor(item['mel']),
"wav": torch.FloatTensor(item['wav'].astype(np.float32)),
"pitch": torch.LongTensor(item['pitch']),
"f0": torch.FloatTensor(item['f0'])
}
return sample
def collater(self, batch):
if len(batch) == 0:
return {}
y_batch, c_batch, p_batch, f0_batch = [], [], [], []
item_name = []
for idx in range(len(batch)):
item_name.append(batch[idx]['item_name'])
x, c = batch[idx]['wav'], batch[idx]['mel']
p, f0 = batch[idx]['pitch'], batch[idx]['f0']
self._assert_ready_for_upsampling(x, c, self.hop_size)
if len(c) > self.batch_max_frames:
# randomly pickup with the batch_max_steps length of the part
batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 else len(c) - 1
batch_max_steps = batch_max_frames * self.hop_size
interval_start = 0
interval_end = len(c) - batch_max_frames
start_frame = np.random.randint(interval_start, interval_end)
start_step = start_frame * self.hop_size
y = x[start_step: start_step + batch_max_steps]
c = c[start_frame: start_frame + batch_max_frames]
p = p[start_frame: start_frame + batch_max_frames]
f0 = f0[start_frame: start_frame + batch_max_frames]
self._assert_ready_for_upsampling(y, c, self.hop_size)
else:
print(f"Removed short sample from batch (length={len(x)}).")
continue
y_batch += [y.reshape(-1, 1)] # [(T, 1), (T, 1), ...]
c_batch += [c] # [(T' C), (T' C), ...]
p_batch += [p] # [(T' C), (T' C), ...]
f0_batch += [f0] # [(T' C), (T' C), ...]
# convert each batch to tensor, asuume that each item in batch has the same length
y_batch = collate_2d(y_batch, 0).transpose(2, 1) # (B, 1, T)
c_batch = collate_2d(c_batch, 0).transpose(2, 1) # (B, C, T')
p_batch = collate_1d(p_batch, 0) # (B, T')
f0_batch = collate_1d(f0_batch, 0) # (B, T')
# make input noise signal batch tensor
z_batch = torch.randn(y_batch.size()) # (B, 1, T)
return {
'z': z_batch,
'mels': c_batch,
'wavs': y_batch,
'pitches': p_batch,
'f0': f0_batch,
'item_name': item_name
}
@staticmethod
def _assert_ready_for_upsampling(x, c, hop_size):
"""Assert the audio and feature lengths are correctly adjusted for upsamping."""
assert len(x) == (len(c)) * hop_size