# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import librosa
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestCollator,
    BaseTestDataset,
)
from utils.data_utils import *


class AudioLDMDataset(BaseOfflineDataset):
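    """Offline training dataset for AudioLDM.

    Depending on ``cfg.preprocess.use_melspec`` / ``use_wav`` / ``use_caption``,
    each item carries a precomputed mel spectrogram (.npy), the raw 16 kHz
    waveform, and/or the text caption from the metadata.
    """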
    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)

        self.cfg = cfg

        # utt2melspec: utterance key -> path of its precomputed mel spectrogram (.npy)
        if cfg.preprocess.use_melspec:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,
                    uid + ".npy",
                )

        # utt2wav: utterance key -> path of its waveform (.wav)
        if cfg.preprocess.use_wav:
            self.utt2wav_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2wav_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.wav_dir,
                    uid + ".wav",
                )

        # utt2caption: utterance key -> text caption
        if cfg.preprocess.use_caption:
            self.utt2caption = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2caption[utt] = utt_info["Caption"]

    def __getitem__(self, index):
        # melspec: (n_mels, T)
        # wav: (T,)

        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        if self.cfg.preprocess.use_melspec:
            single_feature["melspec"] = np.load(self.utt2melspec_path[utt])

        if self.cfg.preprocess.use_wav:
            # NOTE: the sample rate is hard-coded to 16 kHz here.
            wav, _ = librosa.load(self.utt2wav_path[utt], sr=16000)
            single_feature["wav"] = wav

        if self.cfg.preprocess.use_caption:
            # Caption dropout for classifier-free guidance: with probability
            # cond_mask_prob (e.g. 0.1), replace the caption with the empty
            # string so the model also learns the unconditional distribution.
            cond_mask = np.random.choice(
                [1, 0],
                p=[
                    self.cfg.preprocess.cond_mask_prob,
                    1 - self.cfg.preprocess.cond_mask_prob,
                ],
            )
            if cond_mask:
                single_feature["caption"] = ""
            else:
                single_feature["caption"] = self.utt2caption[utt]

        return single_feature

    def __len__(self):
        return len(self.metadata)


class AudioLDMCollator(BaseOfflineCollator):
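    """Batches AudioLDM features: stacks mel spectrograms (cropped to a fixed
    number of frames), zero-pads waveforms, and tokenizes captions with T5.
    """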
    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)

        # T5 tokenizer used to encode captions; sequences are capped at 512 tokens.
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)

    def __call__(self, batch):
        # mel: (B, n_mels, T)
        # wav (option): (B, T)
        # text_input_ids: (B, L)
        # text_attention_mask: (B, L)

        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "melspec":
                packed_batch_features["melspec"] = torch.from_numpy(
                    np.array([b["melspec"][:, :624] for b in batch])
                )

            if key == "wav":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

            if key == "caption":
                captions = [b[key] for b in batch]
                text_input = self.tokenizer(
                    captions, return_tensors="pt", truncation=True, padding="longest"
                )
                text_input_ids = text_input["input_ids"]
                text_attention_mask = text_input["attention_mask"]

                packed_batch_features["text_input_ids"] = text_input_ids
                packed_batch_features["text_attention_mask"] = text_attention_mask

        return packed_batch_features


class AudioLDMTestDataset(BaseTestDataset): ...


class AudioLDMTestCollator(BaseTestCollator): ...
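

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module): wiring the dataset and
# collator into a PyTorch DataLoader. ``cfg`` and the dataset name
# ("audiocaps") are placeholders -- substitute your own experiment config.
#
#   from torch.utils.data import DataLoader
#
#   train_set = AudioLDMDataset(cfg, "audiocaps", is_valid=False)
#   collator = AudioLDMCollator(cfg)
#   loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collator)
#   batch = next(iter(loader))
#   batch["melspec"].shape              # (B, n_mels, 624)
#   batch["text_input_ids"].shape       # (B, L)
#   batch["text_attention_mask"].shape  # (B, L)
# ---------------------------------------------------------------------------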