File size: 2,231 Bytes
3dd84f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
from dataclasses import asdict
import torch
import torchaudio
from torch.utils.data import Dataset

from utils.audio import LogMelSpectrogram
from config import MelConfig
    
class VocosDataset(Dataset):
    """Dataset of random fixed-length audio segments paired with their mel features.

    Each item is produced by loading one file from the file list, cutting a
    random window of ``segment_size`` samples out of it and passing that
    window through a ``LogMelSpectrogram`` built from ``mel_config``.
    """

    def __init__(self, filelist_path, segment_size: int, mel_config: MelConfig):
        """Store the configuration, build the mel extractor and load the file list.

        Args:
            filelist_path: Either a directory to scan for audio files or a
                text file listing one audio path per line.
            segment_size: Number of samples in every returned audio segment.
            mel_config: Mel settings; its fields are forwarded as keyword
                arguments to ``LogMelSpectrogram`` and its ``sample_rate`` is
                used as the loading sample rate.
        """
        self.filelist_path = filelist_path
        self.segment_size = segment_size
        self.sample_rate = mel_config.sample_rate

        mel_kwargs = asdict(mel_config)
        self.mel_extractor = LogMelSpectrogram(**mel_kwargs)

        self.filelist = self._load_filelist(filelist_path)

    def _load_filelist(self, filelist_path):
        """Return the list of audio paths described by ``filelist_path``.

        A directory is scanned with ``find_audio_files``. Anything else is
        opened as a UTF-8 text file with one path per line; lines whose
        stripped content does not exist on disk are dropped.
        """
        if not os.path.isdir(filelist_path):
            existing_paths = []
            with open(filelist_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    candidate = raw_line.strip()
                    if os.path.exists(candidate):
                        existing_paths.append(candidate)
            return existing_paths

        print('scanning dir to get audio files')
        return find_audio_files(filelist_path)

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Return ``(audio, mel)`` for a random segment of the ``idx``-th file.

        ``audio`` is a ``[1, segment_size]`` slice of the loaded waveform and
        ``mel`` is the extractor output for that slice with its leading
        dimension squeezed away (the original author notes its shape as
        ``[n_mels, segment_size // hop_length]``).
        """
        waveform = load_and_pad_audio(self.filelist[idx], self.sample_rate, self.segment_size)

        # The loader pads short files up to segment_size, so there is always
        # at least one valid start offset.
        num_offsets = waveform.size(-1) - self.segment_size + 1
        offset = torch.randint(0, num_offsets, (1,)).item()
        stop = offset + self.segment_size

        segment = waveform[:, offset:stop]  # shape: [1, segment_size]
        mel = self.mel_extractor(segment).squeeze(0)
        return segment, mel
    
def load_and_pad_audio(audio_path, target_sr, segment_size):
    """Load an audio file as a single-channel waveform of at least ``segment_size`` samples.

    Args:
        audio_path: Path of the audio file, passed to ``torchaudio.load``.
        target_sr: Sample rate the returned waveform should have.
        segment_size: Minimum number of samples in the returned waveform.

    Returns:
        The loaded waveform with only its first channel kept, resampled to
        ``target_sr`` when the file uses a different rate, and zero-padded on
        the right when it is shorter than ``segment_size``.
    """
    waveform, source_sr = torchaudio.load(audio_path)

    # Multi-channel input: keep the first channel only (no down-mixing).
    if waveform.size(0) > 1:
        waveform = waveform[0:1, :]

    if source_sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, source_sr, target_sr)

    # Append trailing zeros so a full segment can always be cut out.
    shortfall = segment_size - waveform.size(-1)
    if shortfall > 0:
        waveform = torch.nn.functional.pad(waveform, (0, shortfall), "constant", 0)

    return waveform

def find_audio_files(directory):
    """Recursively collect the audio files located under ``directory``.

    A file counts as audio when its name ends with one of the supported
    extensions (``.wav``, ``.ogg``, ``.opus``, ``.mp3``, ``.flac``). The
    comparison ignores case, so ``clip.WAV`` is found just like ``clip.wav``.

    Args:
        directory: Root directory to walk.

    Returns:
        A sorted list of file paths, each formed by joining the walked
        directory with the file name. The list is empty when nothing matches.
    """
    valid_extensions = ('.wav', '.ogg', '.opus', '.mp3', '.flac')
    audio_files = []

    for root, _dirs, files in os.walk(directory):
        for name in files:
            # Lower-case only for the comparison; the stored path keeps the
            # file's real spelling.
            if name.lower().endswith(valid_extensions):
                audio_files.append(os.path.join(root, name))

    # os.walk yields entries in filesystem-dependent order; sorting keeps the
    # result (and therefore dataset indices) stable across runs and machines.
    audio_files.sort()
    return audio_files