KdaiP's picture
Upload 80 files
3dd84f8 verified
import os
from dataclasses import asdict
import torch
import torchaudio
from torch.utils.data import Dataset
from utils.audio import LogMelSpectrogram
from config import MelConfig
class VocosDataset(Dataset):
def __init__(self, filelist_path, segment_size: int, mel_config: MelConfig):
self.filelist_path = filelist_path
self.segment_size = segment_size
self.sample_rate = mel_config.sample_rate
self.mel_extractor = LogMelSpectrogram(**asdict(mel_config))
self.filelist = self._load_filelist(filelist_path)
def _load_filelist(self, filelist_path):
if os.path.isdir(filelist_path):
print('scanning dir to get audio files')
filelist = find_audio_files(filelist_path)
else:
with open(filelist_path, 'r', encoding='utf-8') as f:
filelist = [line.strip() for line in f if os.path.exists(line.strip())]
return filelist
def __len__(self):
return len(self.filelist)
def __getitem__(self, idx):
audio = load_and_pad_audio(self.filelist[idx], self.sample_rate, self.segment_size)
start_index = torch.randint(0, audio.size(-1) - self.segment_size + 1, (1,)).item()
audio = audio[:, start_index:start_index + self.segment_size] # shape: [1, segment_size]
mel = self.mel_extractor(audio).squeeze(0) # shape: [n_mels, segment_size // hop_length]
return audio, mel
def load_and_pad_audio(audio_path, target_sr, segment_size):
y, sr = torchaudio.load(audio_path)
if y.size(0) > 1:
y = y[0, :].unsqueeze(0)
if sr != target_sr:
y = torchaudio.functional.resample(y, sr, target_sr)
if y.size(-1) < segment_size:
y = torch.nn.functional.pad(y, (0, segment_size - y.size(-1)), "constant", 0)
return y
def find_audio_files(directory):
audio_files = []
valid_extensions = ('.wav', '.ogg', '.opus', '.mp3', '.flac')
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith(valid_extensions):
audio_files.append(os.path.join(root, file))
return audio_files