# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import pickle
import random

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel
from processors.acoustic_extractor import load_normalized
from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence
from text.cmudict import valid_symbols


class NS2Dataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        assert isinstance(dataset, str)

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)

        meta_file = (
            cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        )  # train.json
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        self.cfg = cfg

        assert cfg.preprocess.use_mel == False
        if cfg.preprocess.use_mel:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,  # mel
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_code == True
        if cfg.preprocess.use_code:
            self.utt2code_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2code_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.code_dir,  # code
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_spkid == True
        if cfg.preprocess.use_spkid:
            self.utt2spkid = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2spkid[utt] = utt_info["speaker"]

        assert cfg.preprocess.use_pitch == True
        if cfg.preprocess.use_pitch:
            self.utt2pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,  # pitch
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_duration == True
        if cfg.preprocess.use_duration:
            self.utt2duration_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2duration_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.duration_dir,  # duration
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_phone == True
        if cfg.preprocess.use_phone:
            self.utt2phone = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2phone[utt] = utt_info["phones"]

        assert cfg.preprocess.use_len == True
        if cfg.preprocess.use_len:
            self.utt2len = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2len[utt] = utt_info["num_frames"]

        # for cross reference
        if cfg.preprocess.use_cross_reference:
            self.spkid2utt = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                spkid = utt_info["speaker"]
                if spkid not in self.spkid2utt:
                    self.spkid2utt[spkid] = []
                self.spkid2utt[spkid].append(utt)
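
        # Illustrative note: each utt_info entry read from the metadata file
        # is assumed to look roughly like
        #     {"Dataset": "...", "Uid": "...", "speaker": "...",
        #      "phones": "{ AH0 N ... }", "num_frames": 123}
        # which matches the keys consumed above and in __getitem__.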

        # get phone to id / id to phone map
        self.phone2id, self.id2phone = self.get_phone_map()

        self.all_num_frames = []
        for i in range(len(self.metadata)):
            self.all_num_frames.append(self.metadata[i]["num_frames"])
        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
            )
        )

    def __len__(self):
        return len(self.metadata)

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        print("metadata len: ", len(metadata))

        return metadata

    def get_phone_map(self):
        symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        phone2id = {s: i for i, s in enumerate(symbols)}
        id2phone = {i: s for s, i in phone2id.items()}
        return phone2id, id2phone

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.read_metadata:
            metadata_uid_path = os.path.join(
                self.cfg.preprocess.processed_dir,
                self.cfg.preprocess.metadata_dir,
                dataset,
                # utt_info["speaker"],
                uid + ".pkl",
            )
            with open(metadata_uid_path, "rb") as f:
                metadata_uid = pickle.load(f)
            # code
            code = metadata_uid["code"]
            # frame_nums
            frame_nums = code.shape[1]
            # pitch
            pitch = metadata_uid["pitch"]
            # duration
            duration = metadata_uid["duration"]
            # phone_id
            phone_id = np.array(
                [
                    *map(
                        self.phone2id.get,
                        self.utt2phone[utt].replace("{", "").replace("}", "").split(),
                    )
                ]
            )

        else:
            # code
            code = np.load(self.utt2code_path[utt])
            # frame_nums
            frame_nums = code.shape[1]
            # pitch
            pitch = np.load(self.utt2pitch_path[utt])
            # duration
            duration = np.load(self.utt2duration_path[utt])
            # phone_id
            phone_id = np.array(
                [
                    *map(
                        self.phone2id.get,
                        self.utt2phone[utt].replace("{", "").replace("}", "").split(),
                    )
                ]
            )

        # align length
        code, pitch, duration, phone_id, frame_nums = self.align_length(
            code, pitch, duration, phone_id, frame_nums
        )

        # spkid
        spkid = self.utt2spkid[utt]

        # get target and reference
        out = self.get_target_and_reference(code, pitch, duration, phone_id, frame_nums)
        code, ref_code = out["code"], out["ref_code"]
        pitch, ref_pitch = out["pitch"], out["ref_pitch"]
        duration, ref_duration = out["duration"], out["ref_duration"]
        phone_id, ref_phone_id = out["phone_id"], out["ref_phone_id"]
        frame_nums, ref_frame_nums = out["frame_nums"], out["ref_frame_nums"]

        # phone_id_frame: expand each phone id to one entry per frame
        assert len(phone_id) == len(duration)
        phone_id_frame = []
        for i in range(len(phone_id)):
            phone_id_frame.extend([phone_id[i] for _ in range(duration[i])])
        phone_id_frame = np.array(phone_id_frame)

        # ref_phone_id_frame
        assert len(ref_phone_id) == len(ref_duration)
        ref_phone_id_frame = []
        for i in range(len(ref_phone_id)):
            ref_phone_id_frame.extend(
                [ref_phone_id[i] for _ in range(ref_duration[i])]
            )
        ref_phone_id_frame = np.array(ref_phone_id_frame)

        single_feature.update(
            {
                "code": code,
                "frame_nums": frame_nums,
                "pitch": pitch,
                "duration": duration,
                "phone_id": phone_id,
                "phone_id_frame": phone_id_frame,
                "ref_code": ref_code,
                "ref_frame_nums": ref_frame_nums,
                "ref_pitch": ref_pitch,
                "ref_duration": ref_duration,
                "ref_phone_id": ref_phone_id,
                "ref_phone_id_frame": ref_phone_id_frame,
                "spkid": spkid,
            }
        )

        return single_feature

    def get_num_frames(self, index):
        utt_info = self.metadata[index]
        return utt_info["num_frames"]
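
    # Worked example for align_length below (illustrative): with code of shape
    # (16, 100), pitch of length 98, and durations summing to 103,
    # min_len = min(100, 103) = 100, so code keeps all 100 frames, pitch is
    # edge-padded from 98 to 100, and the last phone's duration is reduced by
    # 3 so that sum(duration) == frame_nums == 100.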
    def align_length(self, code, pitch, duration, phone_id, frame_nums):
        # align lengths of code, pitch, duration, phone_id, and frame_nums
        code_len = code.shape[1]
        pitch_len = len(pitch)
        dur_sum = sum(duration)
        min_len = min(code_len, dur_sum)
        code = code[:, :min_len]
        if pitch_len >= min_len:
            pitch = pitch[:min_len]
        else:
            pitch = np.pad(pitch, (0, min_len - pitch_len), mode="edge")
        frame_nums = min_len
        if dur_sum > min_len:
            assert (duration[-1] - (dur_sum - min_len)) >= 0
            duration[-1] = duration[-1] - (dur_sum - min_len)
            assert duration[-1] >= 0

        return code, pitch, duration, phone_id, frame_nums

    def get_target_and_reference(self, code, pitch, duration, phone_id, frame_nums):
        phone_nums = len(phone_id)
        clip_phone_nums = np.random.randint(
            int(phone_nums * 0.1), int(phone_nums * 0.5) + 1
        )
        clip_phone_nums = max(clip_phone_nums, 1)
        assert clip_phone_nums < phone_nums and clip_phone_nums >= 1
        if self.cfg.preprocess.clip_mode == "mid":
            start_idx = np.random.randint(0, phone_nums - clip_phone_nums)
        elif self.cfg.preprocess.clip_mode == "start":
            if duration[0] == 0 and clip_phone_nums == 1:
                start_idx = 1
            else:
                start_idx = 0
        else:
            assert self.cfg.preprocess.clip_mode in ["mid", "start"]
        end_idx = start_idx + clip_phone_nums
        start_frames = sum(duration[:start_idx])
        end_frames = sum(duration[:end_idx])

        new_code = np.concatenate(
            (code[:, :start_frames], code[:, end_frames:]), axis=1
        )
        ref_code = code[:, start_frames:end_frames]

        new_pitch = np.append(pitch[:start_frames], pitch[end_frames:])
        ref_pitch = pitch[start_frames:end_frames]

        new_duration = np.append(duration[:start_idx], duration[end_idx:])
        ref_duration = duration[start_idx:end_idx]

        new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])
        ref_phone_id = phone_id[start_idx:end_idx]

        new_frame_nums = frame_nums - (end_frames - start_frames)
        ref_frame_nums = end_frames - start_frames

        return {
            "code": new_code,
            "ref_code": ref_code,
            "pitch": new_pitch,
            "ref_pitch": ref_pitch,
            "duration": new_duration,
            "ref_duration": ref_duration,
            "phone_id": new_phone_id,
            "ref_phone_id": ref_phone_id,
            "frame_nums": new_frame_nums,
            "ref_frame_nums": ref_frame_nums,
        }
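
# Illustrative note on get_target_and_reference: a random contiguous span
# covering roughly 10%-50% of the phones is cut out of the utterance and used
# as the reference prompt (the ref_* fields); the remainder is the target.
# For example, with 20 phones, between 2 and 10 consecutive phones (and their
# frames) are moved into ref_code / ref_pitch / ref_duration / ref_phone_id.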

class NS2Collator(BaseOfflineCollator):
    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)

    def __call__(self, batch):
        packed_batch_features = dict()

        # code: (B, 16, T)
        # frame_nums: (B,) not used
        # pitch: (B, T)
        # duration: (B, N)
        # phone_id: (B, N)
        # phone_id_frame: (B, T)
        # ref_code: (B, 16, T')
        # ref_frame_nums: (B,) not used
        # ref_pitch: (B, T') not used
        # ref_duration: (B, N') not used
        # ref_phone_id: (B, N') not used
        # ref_phone_id_frame: (B, T') not used
        # spkid: (B,) not used
        # phone_mask: (B, N)
        # mask: (B, T)
        # ref_mask: (B, T')

        for key in batch[0].keys():
            if key == "phone_id":
                phone_ids = [torch.LongTensor(b["phone_id"]) for b in batch]
                phone_masks = [torch.ones(len(b["phone_id"])) for b in batch]
                packed_batch_features["phone_id"] = pad_sequence(
                    phone_ids,
                    batch_first=True,
                    padding_value=0,
                )
                packed_batch_features["phone_mask"] = pad_sequence(
                    phone_masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "phone_id_frame":
                phone_id_frames = [torch.LongTensor(b["phone_id_frame"]) for b in batch]
                masks = [torch.ones(len(b["phone_id_frame"])) for b in batch]
                packed_batch_features["phone_id_frame"] = pad_sequence(
                    phone_id_frames,
                    batch_first=True,
                    padding_value=0,
                )
                packed_batch_features["mask"] = pad_sequence(
                    masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "ref_code":
                ref_codes = [
                    torch.from_numpy(b["ref_code"]).transpose(0, 1) for b in batch
                ]
                ref_masks = [torch.ones(max(b["ref_code"].shape[1], 1)) for b in batch]
                packed_batch_features["ref_code"] = pad_sequence(
                    ref_codes,
                    batch_first=True,
                    padding_value=0,
                ).transpose(1, 2)
                packed_batch_features["ref_mask"] = pad_sequence(
                    ref_masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "code":
                codes = [torch.from_numpy(b["code"]).transpose(0, 1) for b in batch]
                masks = [torch.ones(max(b["code"].shape[1], 1)) for b in batch]
                packed_batch_features["code"] = pad_sequence(
                    codes,
                    batch_first=True,
                    padding_value=0,
                ).transpose(1, 2)
                packed_batch_features["mask"] = pad_sequence(
                    masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "pitch":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=50.0
                )
            elif key == "duration":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
            elif key == "frame_nums":
                packed_batch_features["frame_nums"] = torch.LongTensor(
                    [b["frame_nums"] for b in batch]
                )
            elif key == "ref_frame_nums":
                packed_batch_features["ref_frame_nums"] = torch.LongTensor(
                    [b["ref_frame_nums"] for b in batch]
                )
            else:
                pass

        return packed_batch_features


def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    if len(batch) == 0:
        return 0
    if len(batch) == max_sentences:
        return 1
    if num_tokens > max_tokens:
        return 1
    return 0


def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens
            at a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each batch
            (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
    """
    bsz_mult = required_batch_size_multiple

    sample_len = 0
    sample_lens = []
    batch = []
    batches = []
    for i in range(len(indices)):
        idx = indices[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert (
            sample_len <= max_tokens
        ), "sentence at index {} of size {} exceeds max_tokens limit of {}!".format(
            idx, sample_len, max_tokens
        )
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
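
# Usage sketch (illustrative only; the dataset name and cfg fields such as
# cfg.train.max_tokens are hypothetical placeholders, not guaranteed to exist
# in the Amphion config):
#
#     dataset = NS2Dataset(cfg, "libritts", is_valid=False)
#     collator = NS2Collator(cfg)
#     batches = batch_by_size(
#         dataset.num_frame_indices,
#         dataset.get_num_frames,
#         max_tokens=cfg.train.max_tokens,
#         max_sentences=cfg.train.max_sentences,
#     )
#     loader = torch.utils.data.DataLoader(
#         dataset, collate_fn=collator, batch_sampler=batches
#     )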