# Copyright (c) 2023 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import random import torch from torch.nn.utils.rnn import pad_sequence from utils.data_utils import * from tqdm import tqdm from g2p_en import G2p import librosa from torch.utils.data import Dataset import pandas as pd import time import io SAMPLE_RATE = 16000 # g2p from .g2p_processor import G2pProcessor phonemizer_g2p = G2pProcessor() class VALLEDataset(Dataset): def __init__(self, args): print(f"Initializing VALLEDataset") self.dataset_list = args.dataset_list print(f"using sampling rate {SAMPLE_RATE}") # set dataframe clumn name book_col_name = [ "ID", "Original_text", "Normalized_text", "Aligned_or_not", "Start_time", "End_time", "Signal_to_noise_ratio", ] trans_col_name = [ "ID", "Original_text", "Normalized_text", "Dir_path", "Duration", ] self.metadata_cache = pd.DataFrame(columns=book_col_name) self.trans_cache = pd.DataFrame(columns=trans_col_name) # dataset_cache_dir = args.cache_dir # cache_dir # print(f"args.cache_dir = ", args.cache_dir) # os.makedirs(dataset_cache_dir, exist_ok=True) ######## add data dir to dataset2dir ########## self.dataset2dir = { "dev-clean": f"{args.data_dir}/dev-clean", "dev-other": f"{args.data_dir}/dev-other", "test-clean": f"{args.data_dir}/test-clean", "test-other": f"{args.data_dir}/test-other", "train-clean-100": f"{args.data_dir}/train-clean-100", "train-clean-360": f"{args.data_dir}/train-clean-360", "train-other-500": f"{args.data_dir}/train-other-500", } ###### load metadata and transcripts ##### for dataset_name in self.dataset_list: print("Initializing dataset: ", dataset_name) # get [book,transcripts,audio] files list self.book_files_list = self.get_metadata_files( self.dataset2dir[dataset_name] ) self.trans_files_list = self.get_trans_files(self.dataset2dir[dataset_name]) ## create metadata_cache (book.tsv file is not filtered, some file is not exist, but contain Duration and Signal_to_noise_ratio) print("reading paths for dataset...") for book_path in tqdm(self.book_files_list): tmp_cache = pd.read_csv( book_path, sep="\t", names=book_col_name, quoting=3 ) self.metadata_cache = pd.concat( [self.metadata_cache, tmp_cache], ignore_index=True ) self.metadata_cache.set_index("ID", inplace=True) ## create transcripts (the trans.tsv file) print("creating transcripts for dataset...") for trans_path in tqdm(self.trans_files_list): tmp_cache = pd.read_csv( trans_path, sep="\t", names=trans_col_name, quoting=3 ) tmp_cache["Dir_path"] = os.path.dirname(trans_path) self.trans_cache = pd.concat( [self.trans_cache, tmp_cache], ignore_index=True ) self.trans_cache.set_index("ID", inplace=True) ## calc duration self.trans_cache["Duration"] = ( self.metadata_cache.End_time[self.trans_cache.index] - self.metadata_cache.Start_time[self.trans_cache.index] ) ## add fullpath # self.trans_cache['Full_path'] = os.path.join(self.dataset2dir[dataset_name],self.trans_cache['ID']) # filter_by_duration: filter_out files with duration < 3.0 or > 15.0 print(f"Filtering files with duration between 3.0 and 15.0 seconds") print(f"Before filtering: {len(self.trans_cache)}") self.trans_cache = self.trans_cache[ (self.trans_cache["Duration"] >= 3.0) & (self.trans_cache["Duration"] <= 15.0) ] print(f"After filtering: {len(self.trans_cache)}") def get_metadata_files(self, directory): book_files = [] for root, _, files in os.walk(directory): for file in files: if file.endswith(".book.tsv") and file[0] != ".": rel_path = os.path.join(root, file) book_files.append(rel_path) return book_files def get_trans_files(self, directory): trans_files = [] for root, _, files in os.walk(directory): for file in files: if file.endswith(".trans.tsv") and file[0] != ".": rel_path = os.path.join(root, file) trans_files.append(rel_path) return trans_files def get_audio_files(self, directory): audio_files = [] for root, _, files in os.walk(directory): for file in files: if file.endswith((".flac", ".wav", ".opus")): rel_path = os.path.relpath(os.path.join(root, file), directory) audio_files.append(rel_path) return audio_files def get_num_frames(self, index): # get_num_frames(durations) by index duration = self.meta_data_cache["Duration"][index] # num_frames = duration * SAMPLE_RATE num_frames = int(duration * 75) # file_rel_path = self.meta_data_cache['relpath'][index] # uid = file_rel_path.rstrip('.flac').split('/')[-1] # num_frames += len(self.transcripts[uid]) return num_frames def __len__(self): return len(self.trans_cache) def __getitem__(self, idx): # Get the file rel path file_dir_path = self.trans_cache["Dir_path"].iloc[idx] # Get uid uid = self.trans_cache.index[idx] # Get the file name from cache uid file_name = uid + ".wav" # Get the full file path full_file_path = os.path.join(file_dir_path, file_name) # get phone phone = self.trans_cache["Normalized_text"][uid] phone = phonemizer_g2p(phone, "en")[1] # load speech speech, _ = librosa.load(full_file_path, sr=SAMPLE_RATE) # if self.resample_to_24k: # speech = librosa.resample(speech, orig_sr=SAMPLE_RATE, target_sr=24000) # speech = torch.tensor(speech, dtype=torch.float32) # pad speech to multiples of 200 # remainder = speech.size(0) % 200 # if remainder > 0: # pad = 200 - remainder # speech = torch.cat([speech, torch.zeros(pad, dtype=torch.float32)], dim=0) # inputs = self._get_reference_vc(speech, hop_length=200) inputs = {} # Get the speaker id # speaker = self.meta_data_cache['speaker'][idx] # speaker_id = self.speaker2id[speaker] # inputs["speaker_id"] = speaker_id inputs["speech"] = speech # 24khz speech, [T] inputs["phone"] = phone # [T] return inputs def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): if len(batch) == 0: return 0 if len(batch) == max_sentences: return 1 if num_tokens > max_tokens: return 1 return 0 def batch_by_size( indices, num_tokens_fn, max_tokens=None, max_sentences=None, required_batch_size_multiple=1, ): """ Yield mini-batches of indices bucketed by size. Batches may contain sequences of different lengths. Args: indices (List[int]): ordered list of dataset indices num_tokens_fn (callable): function that returns the number of tokens at a given index max_tokens (int, optional): max number of tokens in each batch (default: None). max_sentences (int, optional): max number of sentences in each batch (default: None). required_batch_size_multiple (int, optional): require batch size to be a multiple of N (default: 1). """ bsz_mult = required_batch_size_multiple sample_len = 0 sample_lens = [] batch = [] batches = [] for i in range(len(indices)): idx = indices[i] num_tokens = num_tokens_fn(idx) sample_lens.append(num_tokens) sample_len = max(sample_len, num_tokens) assert ( sample_len <= max_tokens ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format( idx, sample_len, max_tokens ) num_tokens = (len(batch) + 1) * sample_len if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): mod_len = max( bsz_mult * (len(batch) // bsz_mult), len(batch) % bsz_mult, ) batches.append(batch[:mod_len]) batch = batch[mod_len:] sample_lens = sample_lens[mod_len:] sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 batch.append(idx) if len(batch) > 0: batches.append(batch) return batches def test(): from utils.util import load_config cfg = load_config("./egs/tts/VALLE_V2/exp_ar_libritts.json") dataset = VALLEDataset(cfg.dataset) metadata_cache = dataset.metadata_cache trans_cache = dataset.trans_cache print(trans_cache.head(10)) # print(dataset.book_files_list) breakpoint() if __name__ == "__main__": test()