import json
import os.path
import random
import re

import kaldiio
import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm

class BaseCollate:
    def __init__(self, n_frames_per_step=1):
        self.n_frames_per_step = n_frames_per_step

    def collate_text_mel(self, batch: list):
        """Zero-pad a batch of text/mel pairs to common lengths.

        :param batch: list of dicts, each with keys 'utt' (str), 'text'
            (LongTensor of token IDs) and 'mel' (FloatTensor of shape
            [num_mels, frames])
        :return: (dict of collated tensors, indices that sort the batch by
            descending text length)
        """
        utt = [x['utt'] for x in batch]
        # Sort by text length, longest first, so downstream RNNs can pack
        # the padded sequences
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x['text']) for x in batch]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]

        # Right zero-pad text
        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]]['text']
            text_padded[i, :text.size(0)] = text

        # Right zero-pad mel-spec, rounding the target length up to a
        # multiple of n_frames_per_step
        num_mels = batch[0]['mel'].size(0)
        max_target_len = max(x['mel'].size(1) for x in batch)
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
        assert max_target_len % self.n_frames_per_step == 0

        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]]['mel']
            mel_padded[i, :, :mel.size(1)] = mel
            output_lengths[i] = mel.size(1)

        # Reorder utterance names to match the sorted batch; keep a list
        # even when the batch has a single item
        utt_name = np.array(utt)[ids_sorted_decreasing].tolist()
        if isinstance(utt_name, str):
            utt_name = [utt_name]

        res = {
            "utt": utt_name,
            "text_padded": text_padded,
            "input_lengths": input_lengths,
            "mel_padded": mel_padded,
            "output_lengths": output_lengths,
        }
        return res, ids_sorted_decreasing
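
# Quick illustration of the padding behavior (not part of the original file;
# the shapes below are made up but follow the docstring's conventions):
#
#   collate = BaseCollate(n_frames_per_step=2)
#   batch = [
#       {"utt": "utt1", "text": torch.arange(5), "mel": torch.randn(80, 7)},
#       {"utt": "utt2", "text": torch.arange(3), "mel": torch.randn(80, 11)},
#   ]
#   res, order = collate.collate_text_mel(batch)
#   # res["text_padded"]: (2, 5); res["mel_padded"]: (2, 80, 12), since 11 is
#   # rounded up to a multiple of n_frames_per_step; res["utt"] comes back
#   # sorted by descending text length, i.e. ["utt1", "utt2"].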

class SpkIDCollate(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        # Reorder speaker IDs to match the length-sorted batch
        spk_ids = torch.LongTensor([x["spk_ids"] for x in batch])
        spk_ids = spk_ids[ids_sorted_decreasing]
        base_data.update({"spk_ids": spk_ids})
        return base_data

class SpkIDCollateWithEmo(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        # Reorder speaker and emotion IDs to match the length-sorted batch
        spk_ids = torch.LongTensor([x["spk_ids"] for x in batch])
        spk_ids = spk_ids[ids_sorted_decreasing]
        emo_ids = torch.LongTensor([x["emo_ids"] for x in batch])
        emo_ids = emo_ids[ids_sorted_decreasing]
        base_data.update({
            "spk_ids": spk_ids,
            "emo_ids": emo_ids,
        })
        return base_data

class XvectorCollate(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        # Stack per-utterance x-vectors into (batch, xvector_dim) and reorder
        xvectors = torch.stack([x["xvector"] for x in batch], dim=0)
        xvectors = xvectors[ids_sorted_decreasing]
        base_data.update({"xvector": xvectors})
        return base_data

class SpkIDCollateWithPE(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        spk_ids = torch.LongTensor([x["spk_ids"] for x in batch])
        spk_ids = spk_ids[ids_sorted_decreasing]
        # Right zero-pad the frame-level variance features ('var') the same
        # way collate_text_mel pads the mels
        num_var = batch[0]["var"].size(0)
        max_target_len = max(x["var"].size(1) for x in batch)
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
        assert max_target_len % self.n_frames_per_step == 0
        var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
        var_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            var = batch[ids_sorted_decreasing[i]]["var"]
            var_padded[i, :, :var.size(1)] = var
        base_data.update({
            "spk_ids": spk_ids,
            "var_padded": var_padded,
        })
        return base_data
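
# The variance-padding block above is repeated verbatim in XvectorCollateWithPE
# below. A minimal sketch of a shared helper that both __call__ methods could
# delegate to (the name `pad_var` is hypothetical, not part of the original):
def pad_var(batch, ids_sorted_decreasing, n_frames_per_step):
    """Right zero-pad frame-level 'var' features, mirroring the mel padding."""
    num_var = batch[0]["var"].size(0)
    max_target_len = max(x["var"].size(1) for x in batch)
    if max_target_len % n_frames_per_step != 0:
        max_target_len += n_frames_per_step - max_target_len % n_frames_per_step
    var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
    var_padded.zero_()
    for i in range(len(ids_sorted_decreasing)):
        var = batch[ids_sorted_decreasing[i]]["var"]
        var_padded[i, :, :var.size(1)] = var
    return var_padded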

class XvectorCollateWithPE(BaseCollate):
    def __call__(self, batch, *args, **kwargs):
        base_data, ids_sorted_decreasing = self.collate_text_mel(batch)
        xvectors = torch.stack([x["xvector"] for x in batch], dim=0)
        xvectors = xvectors[ids_sorted_decreasing]
        # Same variance-feature padding as SpkIDCollateWithPE
        num_var = batch[0]["var"].size(0)
        max_target_len = max(x["var"].size(1) for x in batch)
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
        assert max_target_len % self.n_frames_per_step == 0
        var_padded = torch.FloatTensor(len(batch), num_var, max_target_len)
        var_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            var = batch[ids_sorted_decreasing[i]]["var"]
            var_padded[i, :, :var.size(1)] = var
        base_data.update({
            "xvector": xvectors,
            "var_padded": var_padded,
        })
        return base_data
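
if __name__ == "__main__":
    # Minimal usage sketch with a synthetic dataset (illustrative only; a
    # real dataset would return actual token IDs and mel features, e.g. read
    # via kaldiio). The dict keys match what the collates above consume.
    class _ToyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {
                "utt": f"utt{idx}",
                "text": torch.randint(1, 40, (random.randint(3, 9),)),
                "mel": torch.randn(80, random.randint(50, 120)),
                "spk_ids": idx % 2,
            }

    loader = torch.utils.data.DataLoader(
        _ToyDataset(), batch_size=4,
        collate_fn=SpkIDCollate(n_frames_per_step=2))
    for batch in loader:
        print(batch["text_padded"].shape, batch["mel_padded"].shape,
              batch["spk_ids"])
        break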