File size: 3,533 Bytes
c968fc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random

import numpy as np
import torch
import torchaudio
import librosa
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.codec.codec_dataset import CodecDataset


class FAcodecDataset(torch.utils.data.Dataset):
    """Audio dataset yielding ``(wave, mel)`` pairs.

    Every file below ``cfg.dataset`` whose name ends with one of
    ``AUDIO_EXTENSIONS`` is collected recursively. Each item is loaded at
    ``cfg.preprocess_params.sr``, randomly cropped to a length drawn from
    ``duration_range``, peak-normalised, and paired with its normalised
    log-mel spectrogram.
    """

    AUDIO_EXTENSIONS = (".wav", ".mp3", ".opus", ".flac", ".m4a")

    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config; reads ``cfg.dataset`` (root directory to scan) and
                ``cfg.preprocess_params`` (``sr``, ``duration_range``,
                ``spect_params``).
            dataset: dataset name (currently unused; the root directory is
                taken from ``cfg.dataset``).
            is_valid: whether to use train or valid dataset (currently unused).
        """
        self.data_root_dir = cfg.dataset
        self.data_list = []
        # walk through the dataset directory recursively, save all files ends with .wav/.mp3/.opus/.flac/.m4a
        for root, _, files in os.walk(self.data_root_dir):
            for file in files:
                if file.endswith(self.AUDIO_EXTENSIONS):
                    self.data_list.append(os.path.join(root, file))
        # os.walk order is filesystem-dependent; sort so index -> file is reproducible.
        self.data_list.sort()
        self.sr = cfg.preprocess_params.sr
        # (min, max) integers passed to random.randint and multiplied by sr,
        # i.e. a clip length presumably in seconds.
        self.duration_range = cfg.preprocess_params.duration_range
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=cfg.preprocess_params.spect_params.n_mels,
            n_fft=cfg.preprocess_params.spect_params.n_fft,
            win_length=cfg.preprocess_params.spect_params.win_length,
            hop_length=cfg.preprocess_params.spect_params.hop_length,
        )
        # Fixed statistics used to standardise the log-mel in ``preprocess``.
        self.mean, self.std = -4, 4

    def preprocess(self, wave):
        """Return the normalised log-mel of ``wave``.

        Args:
            wave: waveform as a numpy array (converted to float32) or a tensor.

        Returns:
            ``(log(1e-5 + to_mel(wave)) - mean) / std`` with an extra leading
            dimension of size 1.
        """
        wave_tensor = (
            torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
        )
        mel_tensor = self.to_mel(wave_tensor)
        # 1e-5 floor keeps log() finite for silent frames.
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
        return mel_tensor

    def __len__(self):
        """Number of audio files discovered under the dataset root."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Load file ``index`` and return ``(wave, mel)``.

        ``wave`` is a float32 tensor of at most ``sr * randint(*duration_range)``
        samples, scaled so its peak absolute value is 1 (silent clips are left
        as-is). ``mel`` is ``preprocess(wave)`` with the leading size-1
        dimension removed.
        """
        wave, _ = librosa.load(self.data_list[index], sr=self.sr)
        max_len = self.sr * random.randint(*self.duration_range)
        wave = self._random_crop(wave, max_len)
        wave = self._peak_normalize(wave)
        mel = self.preprocess(wave).squeeze(0)
        wave = torch.from_numpy(wave).float()
        return wave, mel

    @staticmethod
    def _random_crop(wave, max_len):
        """Return a random contiguous window of at most ``max_len`` samples."""
        if len(wave) <= max_len:
            return wave
        start = random.randint(0, len(wave) - max_len)
        return wave[start : start + max_len]

    @staticmethod
    def _peak_normalize(wave):
        """Scale ``wave`` to unit peak; empty or all-zero input is returned unchanged."""
        # np.max raises on empty arrays and a zero peak would produce NaNs.
        peak = np.max(np.abs(wave)) if wave.size else 0.0
        if peak == 0:
            return wave
        return wave / peak


class FAcodecCollator(object):
    """Collate ``(wave, mel)`` items into padded batch tensors.

    Items are reordered by mel frame count, longest first. Waveforms are
    padded with 0.0 and mel spectrograms with -10.0 up to the longest item.
    """

    def __init__(self, cfg):
        # Stored for interface compatibility; ``__call__`` does not read it.
        self.cfg = cfg

    def __call__(self, batch):
        """
        Args:
            batch: non-empty list of ``(wave, mel)`` pairs, ``wave`` indexed as
                ``(samples,)`` and ``mel`` as ``(n_mels, frames)``.

        Returns:
            ``(waves, mels, wave_lengths, mel_lengths)``: float32
            ``(B, max_samples)``, float32 ``(B, n_mels, max_frames)``, and two
            int64 ``(B,)`` tensors holding each item's unpadded length.
        """
        frame_counts = [item[1].shape[1] for item in batch]
        order = np.argsort(frame_counts)[::-1]
        ordered = [batch[i] for i in order]

        count = len(batch)
        n_mels = ordered[0][1].size(0)
        longest_mel = max(mel.shape[1] for _, mel in ordered)
        longest_wave = max(wave.size(0) for wave, _ in ordered)

        waves = torch.zeros((count, longest_wave), dtype=torch.float32)
        # Mel padding value is -10 rather than 0.
        mels = torch.full((count, n_mels, longest_mel), -10.0, dtype=torch.float32)

        for row, (wave, mel) in enumerate(ordered):
            waves[row, : wave.size(0)] = wave
            mels[row, :, : mel.size(1)] = mel

        wave_lengths = torch.tensor(
            [wave.size(0) for wave, _ in ordered], dtype=torch.long
        )
        mel_lengths = torch.tensor(
            [mel.size(1) for _, mel in ordered], dtype=torch.long
        )
        return waves, mels, wave_lengths, mel_lengths