import json
import os
import random

import numpy as np
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence


class JetsDataset(BaseOfflineDataset):
    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess

        # Map "{dataset}_{uid}" to the path of its precomputed duration file.
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2duration_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.duration_dir,
                uid + ".npy",
            )
        self.utt2dur = self.read_duration()

        # Load precomputed energy features (frame- or phone-level) along with
        # their normalization statistics.
        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # Load precomputed pitch features analogously.
        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        self.speaker_map = {}
        spk2id_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_path):
            with open(spk2id_path) as f:
                self.speaker_map = json.load(f)

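        # "spk2id.json" is assumed to map speaker names to integer ids,
        # e.g. {"speaker_a": 0, "speaker_b": 1} (hypothetical entries).
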
        self.metadata = self.check_metadata()
        if cfg.use_audios:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.extract_audio:
                    self.utt2audio_path[utt] = os.path.join(
                        cfg.processed_dir,
                        dataset,
                        cfg.audio_dir,
                        uid + ".wav",
                    )
                else:
                    self.utt2audio_path[utt] = utt_info["Path"]

    def __getitem__(self, index):
        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        duration = self.utt2dur[utt]

        # Read the phone sequence from the first line of the .lab text file.
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readline().strip()

        phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
        text_len = len(phones_ids)

        # The config is expected to enable one of the frame/phone variants for
        # pitch and one for energy; otherwise these names stay unbound.
        if self.cfg.preprocess.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif self.cfg.preprocess.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]

        if self.cfg.preprocess.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif self.cfg.preprocess.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "durations": duration,
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
                "pitch": pitch,
                "energy": energy,
                "uid": uid,
            }
        )

        if self.cfg.preprocess.use_audios:
            audio, sr = torchaudio.load(self.utt2audio_path[utt])
            audio = audio.cpu().numpy().squeeze()
            single_feature["audio"] = audio
            single_feature["audio_len"] = audio.shape[0]
        return self.clip_if_too_long(single_feature)

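    # Besides the base features (e.g. "mel" and "target_len" from
    # BaseOfflineDataset), each sample thus carries "durations", "texts",
    # "text_len", "spk_id", "pitch", "energy", "uid", and, when use_audios
    # is set, "audio" / "audio_len".
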
    def read_duration(self):
        utt2dur = {}
        for index in range(len(self.metadata)):
            utt_info = self.metadata[index]
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
                self.utt2duration_path[utt]
            ):
                continue

            # Sanity check: the phone durations must sum to the mel length.
            mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
            duration = np.load(self.utt2duration_path[utt])
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration
        return utt2dur

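    # For example, a duration array of [3, 5, 2] frames per phone is only
    # accepted when the corresponding mel spectrogram has 3 + 5 + 2 = 10
    # frames.
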
    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """
        ending_ts: to avoid invalid whisper features for over 30s audios
            2812 = 30 * 24000 // 256
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        ts = min(ts, ending_ts - max_seq_len)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

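    # Illustration of the arithmetic above: with feature_seq_len=3000 and
    # max_seq_len=1000, ts = min(max(3000 - 1000, 0), 2812 - 1000) = 1812, so
    # the window [start, start + 1000) begins at a random frame in [0, 1812]
    # and always ends by frame 2812 (30 s at a 24 kHz sample rate with a
    # 256-sample hop).
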
    def clip_if_too_long(self, sample, max_seq_len=1000):
        """
        sample :
            {
                'spk_id': (1,),
                'target_len': int
                'mel': (seq_len, dim),
                'frame_pitch': (seq_len,)
                'frame_energy': (seq_len,)
                'content_vector_feat': (seq_len, dim)
            }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        # Slice every sequence feature to the selected window; scalar entries
        # such as "spk_id" and "target_len" are left untouched.
        for k in sample.keys():
            if k not in ["spk_id", "target_len"]:
                sample[k] = sample[k][start:end]

        return sample

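    # A minimal sketch of the clipping behavior (shapes are illustrative):
    #   sample = {"spk_id": 3, "target_len": 1500, "mel": mel}  # mel: (1500, 80)
    #   sample = dataset.clip_if_too_long(sample)
    #   sample["target_len"]  # -> 1000
    #   sample["mel"].shape   # -> (1000, 80)
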
    def check_metadata(self):
        # Keep only utterances whose duration and mel files both exist.
        new_metadata = []
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
                self.utt2mel_path[utt]
            ):
                continue
            new_metadata.append(utt_info)
        return new_metadata


class JetsCollator(BaseOfflineCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "text_len":
                packed_batch_features["text_len"] = torch.LongTensor(
                    [b["text_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["text_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed_batch_features["spk_id"] = torch.LongTensor(
                    [b["spk_id"] for b in batch]
                )
            elif key == "uid":
                packed_batch_features[key] = [b["uid"] for b in batch]
            elif key == "audio_len":
                packed_batch_features["audio_len"] = torch.LongTensor(
                    [b["audio_len"] for b in batch]
                )
            else:
                # Everything else is a variable-length numpy array; zero-pad
                # it to the longest sample in the batch.
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
        return packed_batch_features


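# A minimal usage sketch (hypothetical dataset name; assumes a standard
# PyTorch DataLoader):
#   from torch.utils.data import DataLoader
#   dataset = JetsDataset(cfg, "LJSpeech", is_valid=False)
#   loader = DataLoader(
#       dataset,
#       batch_size=cfg.train.batch_size,
#       collate_fn=JetsCollator(cfg),
#   )
#   batch = next(iter(loader))  # dict of padded tensors plus "mask"/"text_mask"

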
class JetsTestDataset(BaseTestDataset):
    def __init__(self, args, cfg, infer_type=None):
        datasets = cfg.dataset
        cfg = cfg.preprocess
        is_bigdata = False

        assert len(datasets) >= 1
        if len(datasets) > 1:
            datasets.sort()
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2spk_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2spk_path, datasets)

        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        self.speaker_map = {}
        spk2id_json_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_json_path):
            with open(spk2id_json_path) as f:
                self.speaker_map = json.load(f)

    def __getitem__(self, index):
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        # Read the phone sequence from the first line of the .lab text file.
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readline().strip()

        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
            }
        )

        return single_feature

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata


class JetsTestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "text_len":
                packed_batch_features["text_len"] = torch.LongTensor(
                    [b["text_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["text_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed_batch_features["spk_id"] = torch.LongTensor(
                    [b["spk_id"] for b in batch]
                )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features