Spaces:
Runtime error
Runtime error
import logging | |
import torch | |
import torch.utils.data | |
logger = logging.getLogger(__name__) | |
class TextAudioCollateMultiNSFsid: | |
"""Zero-pads model inputs and targets""" | |
def __init__(self): | |
pass | |
def __call__(self, batch): | |
"""Collate's training batch from normalized text and aduio | |
PARAMS | |
------ | |
batch: [text_normalized, spec_normalized, wav_normalized] | |
""" | |
device = batch[0]["spec"].device | |
with device: | |
# Right zero-pad all one-hot text sequences to max input length | |
_, ids_sorted_decreasing = torch.sort( | |
torch.tensor([x["spec"].size(1) for x in batch], dtype=torch.long), | |
dim=0, | |
descending=True, | |
) | |
max_spec_len = max([x["spec"].size(1) for x in batch]) | |
max_wave_len = max([x["wav_gt"]["array"].size(0) for x in batch]) | |
spec_lengths = torch.zeros(len(batch), dtype=torch.long) | |
wave_lengths = torch.zeros(len(batch), dtype=torch.long) | |
spec_padded = torch.zeros( | |
len(batch), batch[0]["spec"].size(0), max_spec_len, dtype=torch.float32 | |
) | |
wave_padded = torch.zeros(len(batch), 1, max_wave_len, dtype=torch.float32) | |
max_phone_len = max([x["hubert_feats"].size(0) for x in batch]) | |
phone_lengths = torch.zeros(len(batch), dtype=torch.long) | |
phone_padded = torch.zeros( | |
len(batch), | |
max_phone_len, | |
batch[0]["hubert_feats"].shape[1], | |
dtype=torch.float32, | |
) # (spec, wav, phone, pitch) | |
pitch_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.long) | |
pitchf_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.float32) | |
# dv = torch.FloatTensor(len(batch), 256)#gin=256 | |
sid = torch.zeros(len(batch), dtype=torch.long) | |
for i in range(len(ids_sorted_decreasing)): | |
row = batch[ids_sorted_decreasing[i]] | |
spec = row["spec"] | |
spec_padded[i, :, : spec.size(1)] = spec | |
spec_lengths[i] = spec.size(1) | |
wave = row["wav_gt"]["array"] | |
wave_padded[i, :, : wave.size(0)] = wave | |
wave_lengths[i] = wave.size(0) | |
phone = row["hubert_feats"] | |
phone_padded[i, : phone.size(0), :] = phone | |
phone_lengths[i] = phone.size(0) | |
pitch = row["f0"] | |
pitch_padded[i, : pitch.size(0)] = pitch | |
pitchf = row["f0nsf"] | |
pitchf_padded[i, : pitchf.size(0)] = pitchf | |
sid[i] = torch.tensor([0], dtype=torch.long) | |
return ( | |
phone_padded, | |
phone_lengths, | |
pitch_padded, | |
pitchf_padded, | |
spec_padded, | |
spec_lengths, | |
wave_padded, | |
wave_lengths, | |
sid, | |
) | |