Spaces:
Sleeping
Sleeping
import os | |
import random | |
import re | |
import numpy as np | |
import librosa | |
import torch | |
import random | |
from utils import repeat_expand_2d | |
from tqdm import tqdm | |
from torch.utils.data import Dataset | |
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under ``root_dir`` by extension.

    Args:
        root_dir: Directory to walk (a trailing path separator is allowed).
        extensions: Iterable of extensions without the leading dot,
            e.g. ``['wav', 'flac']``.
        amount: If not None, stop once this many paths have been collected.
            The limit is checked when the next matching file is met, so the
            kept subset follows ``os.walk`` order, not sorted order.
        str_include: If not None, keep only paths containing this substring.
        str_exclude: If not None, drop paths containing this substring.
        is_pure: If True, return paths relative to ``root_dir`` instead of
            paths prefixed with ``root_dir``.
        is_sort: If True, sort the result before returning it.
        is_ext: If False, strip the final ``.ext`` from every returned path.

    Returns:
        A list of path strings.
    """
    # str.endswith accepts a tuple, so build the suffix set once.
    suffixes = tuple(f".{ext}" for ext in extensions)
    file_list = []
    cnt = 0
    for root, _, files in os.walk(root_dir):
        for file in files:
            if not file.endswith(suffixes):
                continue

            # amount
            if (amount is not None) and (cnt == amount):
                if is_sort:
                    file_list.sort()
                return file_list

            # path
            mix_path = os.path.join(root, file)
            # relpath (rather than slicing by len(root_dir) + 1) stays correct
            # when root_dir already ends with a separator.
            pure_path = os.path.relpath(mix_path, root_dir) if is_pure else mix_path

            # check string
            if (str_include is not None) and (str_include not in pure_path):
                continue
            if (str_exclude is not None) and (str_exclude in pure_path):
                continue

            if not is_ext:
                # The file name matched ".<ext>", so a dot is always present.
                pure_path = pure_path.rsplit('.', 1)[0]
            file_list.append(pure_path)
            cnt += 1
    if is_sort:
        file_list.sort()
    return file_list
def get_data_loaders(args, whole_audio=False):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        args: Configuration object; the ``data``, ``train``, ``model`` and
            ``spk`` attributes read below must be present.
        whole_audio: If True, the training set yields whole files and the
            training batch size is forced to 1.

    Returns:
        ``(loader_train, loader_valid)``.
    """
    # Training set: augmentation enabled, cached on the configured device.
    train_set = AudioDataset(
        filelists=args.data.training_files,
        waveform_sec=args.data.duration,
        hop_size=args.data.block_size,
        sample_rate=args.data.sampling_rate,
        load_all_data=args.train.cache_all_data,
        whole_audio=whole_audio,
        extensions=args.data.extensions,
        n_spk=args.model.n_spk,
        spk=args.spk,
        device=args.train.cache_device,
        fp16=args.train.cache_fp16,
        use_aug=True,
    )

    # Worker processes and pinned memory are only used when the cache lives
    # on the CPU; otherwise everything is loaded in the main process.
    cache_on_cpu = args.train.cache_device == 'cpu'
    if cache_on_cpu:
        train_workers = args.train.num_workers
        keep_workers = args.train.num_workers > 0
    else:
        train_workers = 0
        keep_workers = False
    train_batch = 1 if whole_audio else args.train.batch_size

    loader_train = torch.utils.data.DataLoader(
        train_set,
        batch_size=train_batch,
        shuffle=True,
        num_workers=train_workers,
        persistent_workers=keep_workers,
        pin_memory=cache_on_cpu,
    )

    # Validation set: always whole files, one per batch, in file-list order.
    valid_set = AudioDataset(
        filelists=args.data.validation_files,
        waveform_sec=args.data.duration,
        hop_size=args.data.block_size,
        sample_rate=args.data.sampling_rate,
        load_all_data=args.train.cache_all_data,
        whole_audio=True,
        spk=args.spk,
        extensions=args.data.extensions,
        n_spk=args.model.n_spk,
    )
    loader_valid = torch.utils.data.DataLoader(
        valid_set,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    return loader_train, loader_valid
class AudioDataset(Dataset):
    """Dataset that serves feature segments read from precomputed side-car files.

    ``filelists`` is a text file with one audio path per line. For each path
    the constructor reads ``<path>.f0.npy``, ``<path>.vol.npy``,
    ``<path>.aug_vol.npy`` and ``<path>.aug_mel.npy``. With ``load_all_data``
    it also caches ``<path>.mel.npy`` and ``<path>.soft.pt``; otherwise those
    two are read from disk on every ``get_data`` call.
    """

    def __init__(
        self,
        filelists,
        waveform_sec,
        hop_size,
        sample_rate,
        spk,
        load_all_data=True,
        whole_audio=False,
        extensions=('wav',),
        n_spk=1,
        device='cpu',
        fp16=False,
        use_aug=False,
    ):
        """Index the file list and pre-load per-file features.

        Args:
            filelists: Path of the text file listing one audio path per line.
            waveform_sec: Segment length in seconds served when
                ``whole_audio`` is False.
            hop_size: Samples per feature frame.
            sample_rate: Audio sampling rate passed to librosa.
            spk: Mapping from speaker directory name to integer speaker id.
            load_all_data: Cache mel / augmented mel / units as well.
            whole_audio: Serve every file in full instead of a random segment.
            extensions: Accepted for interface compatibility; not used here.
            n_spk: Number of speakers; ids are looked up only when > 1.
            device: Device the cached tensors are moved to.
            fp16: Store cached mel / augmented mel / units in half precision.
            use_aug: Allow ``get_data`` to randomly pick augmented features.

        Raises:
            ValueError: If a looked-up speaker id is outside ``[0, n_spk)``.
        """
        super().__init__()

        self.waveform_sec = waveform_sec
        self.sample_rate = sample_rate
        self.hop_size = hop_size
        self.filelists = filelists
        self.whole_audio = whole_audio
        self.use_aug = use_aug
        self.data_buffer = {}
        self.pitch_aug_dict = {}
        if load_all_data:
            print('Load all the data filelists:', filelists)
        else:
            print('Load the f0, volume data filelists:', filelists)
        with open(filelists, "r") as f:
            self.paths = f.read().splitlines()
        for name_ext in tqdm(self.paths, total=len(self.paths)):
            path_audio = name_ext
            # NOTE(review): newer librosa releases deprecate `filename=` in
            # favour of `path=`; kept as-is for the installed version - confirm.
            duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate)

            # The f0 file holds a pair; only the first element is used here.
            f0, _ = np.load(name_ext + ".f0.npy", allow_pickle=True)
            f0 = torch.from_numpy(np.array(f0, dtype=float)).float().unsqueeze(-1).to(device)

            volume = np.load(name_ext + ".vol.npy")
            volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)

            aug_vol = np.load(name_ext + ".aug_vol.npy")
            aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)

            if n_spk is not None and n_spk > 1:
                # The speaker name is the parent directory of the audio file.
                spk_name = name_ext.split("/")[-2]
                spk_id = spk[spk_name] if spk_name in spk else 0
                if spk_id < 0 or spk_id >= n_spk:
                    raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 0 to n_spk-1 ')
            else:
                spk_id = 0
            spk_id = torch.LongTensor(np.array([spk_id])).to(device)

            aug_mel, keyshift = self._load_aug_mel(name_ext)
            self.pitch_aug_dict[name_ext] = keyshift

            if load_all_data:
                # Cast to float32 so cached tensors have the same dtype as the
                # ones get_data builds when it reads from disk.
                mel = torch.from_numpy(np.load(name_ext + ".mel.npy")).float().to(device)
                aug_mel = torch.from_numpy(np.array(aug_mel, dtype=float)).float().to(device)

                units = torch.load(name_ext + ".soft.pt").to(device)
                units = units[0]
                units = repeat_expand_2d(units, f0.size(0)).transpose(0, 1)

                if fp16:
                    mel = mel.half()
                    aug_mel = aug_mel.half()
                    units = units.half()

                self.data_buffer[name_ext] = {
                    'duration': duration,
                    'mel': mel,
                    'aug_mel': aug_mel,
                    'units': units,
                    'f0': f0,
                    'volume': volume,
                    'aug_vol': aug_vol,
                    'spk_id': spk_id
                }
            else:
                # Only the light-weight features are kept in memory.
                self.data_buffer[name_ext] = {
                    'duration': duration,
                    'f0': f0,
                    'volume': volume,
                    'aug_vol': aug_vol,
                    'spk_id': spk_id
                }

    @staticmethod
    def _load_aug_mel(name_ext):
        """Read ``<name_ext>.aug_mel.npy`` and return ``(aug_mel, keyshift)``."""
        # SECURITY: allow_pickle=True unpickles the file content; only use
        # feature files from a trusted preprocessing run.
        aug_mel, keyshift = np.load(name_ext + ".aug_mel.npy", allow_pickle=True)
        return aug_mel, keyshift

    def __getitem__(self, file_idx):
        """Return the sample for ``file_idx``, skipping files that are too short.

        A file shorter than ``waveform_sec + 0.1`` seconds is skipped and the
        next file (wrapping around) is tried instead.

        Raises:
            IndexError: If ``file_idx`` is out of range.
            RuntimeError: If every file in the list is too short.
        """
        # Index without modulo first so out-of-range indices still raise
        # IndexError, as sequence-style iteration expects.
        name_ext = self.paths[file_idx]
        n_files = len(self.paths)
        # Bounded loop instead of recursion: terminates even when no file is
        # long enough.
        for _ in range(n_files):
            data_buffer = self.data_buffer[name_ext]
            if data_buffer['duration'] < (self.waveform_sec + 0.1):
                file_idx = (file_idx + 1) % n_files
                name_ext = self.paths[file_idx]
                continue
            return self.get_data(name_ext, data_buffer)
        raise RuntimeError(
            'No audio file in the file list is at least waveform_sec + 0.1 seconds long')

    def get_data(self, name_ext, data_buffer):
        """Cut one segment (or the whole file) out of the features of ``name_ext``.

        Args:
            name_ext: Audio path as listed in the file list.
            data_buffer: The entry of ``self.data_buffer`` for ``name_ext``.

        Returns:
            A dict with keys ``mel``, ``f0``, ``volume``, ``units``,
            ``spk_id``, ``aug_shift``, ``name`` and ``name_ext``.
        """
        name = os.path.splitext(name_ext)[0]
        frame_resolution = self.hop_size / self.sample_rate
        duration = data_buffer['duration']
        waveform_sec = duration if self.whole_audio else self.waveform_sec

        # Pick the segment start; the 0.1 s margin keeps it inside the file.
        idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
        start_frame = int(idx_from / frame_resolution)
        units_frame_len = int(waveform_sec / frame_resolution)
        frames = slice(start_frame, start_frame + units_frame_len)
        # Augmented features are chosen with probability 1/2 when enabled.
        aug_flag = random.choice([True, False]) and self.use_aug

        # load mel
        mel = data_buffer.get('aug_mel' if aug_flag else 'mel')
        if mel is None:
            # Not cached: read from disk, choosing the augmented mel when
            # augmentation is active so it matches the shifted f0 / aug_vol.
            if aug_flag:
                aug_mel, _ = self._load_aug_mel(name_ext)
                mel = np.array(aug_mel, dtype=float)
            else:
                mel = np.load(name_ext + ".mel.npy")
            mel = torch.from_numpy(mel[frames]).float()
        else:
            mel = mel[frames]

        # load f0
        f0 = data_buffer.get('f0')
        aug_shift = 0
        if aug_flag:
            aug_shift = self.pitch_aug_dict[name_ext]
        # Scale f0 by 2 ** (aug_shift / 12).
        f0_frames = 2 ** (aug_shift / 12) * f0[frames]

        # load units
        units = data_buffer.get('units')
        if units is None:
            units = torch.load(name_ext + ".soft.pt")
            units = units[0]
            units = repeat_expand_2d(units, f0.size(0)).transpose(0, 1)
        units = units[frames]

        # load volume
        volume = data_buffer.get('aug_vol' if aug_flag else 'volume')
        volume_frames = volume[frames]

        # load spk_id
        spk_id = data_buffer.get('spk_id')

        # load shift
        aug_shift = torch.from_numpy(np.array([[aug_shift]])).float()

        return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext)

    def __len__(self):
        """Return the number of paths in the file list."""
        return len(self.paths)