ZeroRVC / zerorvc /utils /data_utils.py
github-actions[bot]
Sync to HuggingFace Spaces
f80c5ec
raw
history blame
3.14 kB
import logging
import torch
import torch.utils.data
logger = logging.getLogger(__name__)
class TextAudioCollateMultiNSFsid:
"""Zero-pads model inputs and targets"""
def __init__(self):
pass
def __call__(self, batch):
"""Collate's training batch from normalized text and aduio
PARAMS
------
batch: [text_normalized, spec_normalized, wav_normalized]
"""
device = batch[0]["spec"].device
with device:
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.tensor([x["spec"].size(1) for x in batch], dtype=torch.long),
dim=0,
descending=True,
)
max_spec_len = max([x["spec"].size(1) for x in batch])
max_wave_len = max([x["wav_gt"]["array"].size(0) for x in batch])
spec_lengths = torch.zeros(len(batch), dtype=torch.long)
wave_lengths = torch.zeros(len(batch), dtype=torch.long)
spec_padded = torch.zeros(
len(batch), batch[0]["spec"].size(0), max_spec_len, dtype=torch.float32
)
wave_padded = torch.zeros(len(batch), 1, max_wave_len, dtype=torch.float32)
max_phone_len = max([x["hubert_feats"].size(0) for x in batch])
phone_lengths = torch.zeros(len(batch), dtype=torch.long)
phone_padded = torch.zeros(
len(batch),
max_phone_len,
batch[0]["hubert_feats"].shape[1],
dtype=torch.float32,
) # (spec, wav, phone, pitch)
pitch_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.long)
pitchf_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.float32)
# dv = torch.FloatTensor(len(batch), 256)#gin=256
sid = torch.zeros(len(batch), dtype=torch.long)
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
spec = row["spec"]
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wave = row["wav_gt"]["array"]
wave_padded[i, :, : wave.size(0)] = wave
wave_lengths[i] = wave.size(0)
phone = row["hubert_feats"]
phone_padded[i, : phone.size(0), :] = phone
phone_lengths[i] = phone.size(0)
pitch = row["f0"]
pitch_padded[i, : pitch.size(0)] = pitch
pitchf = row["f0nsf"]
pitchf_padded[i, : pitchf.size(0)] = pitchf
sid[i] = torch.tensor([0], dtype=torch.long)
return (
phone_padded,
phone_lengths,
pitch_padded,
pitchf_padded,
spec_padded,
spec_lengths,
wave_padded,
wave_lengths,
sid,
)