import os.path

from dataset import CocktailDataset, MusicDataset, CocktailLabeledDataset, MusicLabeledDataset, RegressedGroundingDataset
import torch
import numpy as np
from torch.utils.data.sampler import RandomSampler
from torch.utils.data import DataLoader
from src.music.utils import get_all_subfiles_with_extension
import pickle

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def wasserstein1d(x, y):
    """Squared 2-Wasserstein distance between 1D empirical distributions.

    Sorting along dim 0 aligns the empirical quantiles of each column of x and
    y (which must have the same shape); the mean squared difference of the
    sorted values is the squared W2 distance, averaged over columns.
    """
    x1, _ = torch.sort(x, dim=0)
    y1, _ = torch.sort(y, dim=0)
    z = (x1 - y1).view(-1)
    n = z.size(0)
    return torch.dot(z, z) / n
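
# Worked example (toy values, for illustration only): for x = [0., 1., 2.] and
# y = [1., 2., 3.], sorting aligns the quantiles, z = [-1, -1, -1], and the
# result is (1 + 1 + 1) / 3 = 1, i.e. the squared size of the unit shift.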

def compute_swd_loss(minibatch1, minibatch2, latent_dim, n_projections=10000):
    """Sliced Wasserstein distance between two minibatches of latent vectors.

    Projects both minibatches onto n_projections random directions on the unit
    sphere and averages the squared 1D Wasserstein cost over the projections.
    """
    # sample random projection directions and normalize them to unit norm
    theta = torch.randn((latent_dim, n_projections),
                        requires_grad=False,
                        device=device)
    theta = theta / torch.norm(theta, dim=0)[None, :]

    # project both minibatches onto the random directions: (bs, n_projections)
    proj1 = minibatch1 @ theta
    proj2 = minibatch2 @ theta

    # compute sliced wasserstein distance on projected features
    gloss = wasserstein1d(proj1, proj2)
    return gloss
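
# Illustrative check (toy shapes, not taken from any config in this repo):
# two standard-normal minibatches give a loss near zero, while a minibatch
# shifted by a constant gives a clearly larger one.
#
#     a = torch.randn(256, 8, device=device)
#     b = torch.randn(256, 8, device=device)
#     c = torch.randn(256, 8, device=device) + 3.
#     compute_swd_loss(a, b, latent_dim=8, n_projections=100)  # close to 0
#     compute_swd_loss(a, c, latent_dim=8, n_projections=100)  # clearly larger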



def get_dataloaders(cocktail_rep_path, music_rep_path, batch_size, train_epoch_size, test_epoch_size):
    assert train_epoch_size % batch_size == 0, 'epoch size is expressed in samples and must be a multiple of the batch size'
    assert test_epoch_size % batch_size == 0, 'epoch size is expressed in samples and must be a multiple of the batch size'
    assert music_rep_path.endswith('.pickle')

    # Build the cache of normalized music representations if it does not exist
    # yet: load every .txt representation under the parent folder, standardize
    # each dimension to zero mean and unit std, and pickle the path->rep
    # mapping together with the normalization stats.
    if not os.path.exists(music_rep_path):
        music_rep_paths = get_all_subfiles_with_extension(music_rep_path.replace('music_reps_normalized_meanstd.pickle', ''), max_depth=3, extension='.txt', current_depth=0)
        music_reps = np.array([np.loadtxt(p) for p in music_rep_paths])
        mean_std = np.array([music_reps.mean(axis=0), music_reps.std(axis=0)])
        music_reps = (music_reps - mean_std[0]) / mean_std[1]
        assert len(music_rep_paths) == len(music_reps), 'normalization should preserve the number of representations'
        data = dict(zip(music_rep_paths, music_reps))
        to_save = dict(musicpath2musicrep=data,
                       mean_std=mean_std)
        with open(music_rep_path, 'wb') as f:
            pickle.dump(to_save, f)

    # Load the cached music representations, sorted by path for a deterministic order.
    with open(music_rep_path, 'rb') as f:
        data = pickle.load(f)
    mean_std_music_rep = data['mean_std']
    music_rep_paths = sorted(data['musicpath2musicrep'].keys())
    music_reps = np.array([data['musicpath2musicrep'][k] for k in music_rep_paths])

    # Standardize the cocktail representations and keep the stats so the
    # normalization can be inverted later.
    cocktail_reps = np.loadtxt(cocktail_rep_path)
    mean_std_cocktail_rep_norm11 = np.array([cocktail_reps.mean(axis=0), cocktail_reps.std(axis=0)])
    cocktail_reps = (cocktail_reps - mean_std_cocktail_rep_norm11[0]) / mean_std_cocktail_rep_norm11[1]

    train_data_cocktail = CocktailDataset(split='train', cocktail_reps=cocktail_reps)
    test_data_cocktail = CocktailDataset(split='test', cocktail_reps=cocktail_reps)
    train_data_music = MusicDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths)
    test_data_music = MusicDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths)

    train_sampler_cocktail = FixedLenRandomSampler(train_data_cocktail, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_cocktail = FixedLenRandomSampler(test_data_cocktail, bs=batch_size, epoch_size=test_epoch_size)
    train_sampler_music = FixedLenRandomSampler(train_data_music, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_music = FixedLenRandomSampler(test_data_music, bs=batch_size, epoch_size=test_epoch_size)

    train_data_cocktail = DataLoader(train_data_cocktail, batch_sampler=train_sampler_cocktail)
    test_data_cocktail = DataLoader(test_data_cocktail, batch_sampler=test_sampler_cocktail)
    train_data_music = DataLoader(train_data_music, batch_sampler=train_sampler_music)
    test_data_music = DataLoader(test_data_music, batch_sampler=test_sampler_music)

    train_data_cocktail_labeled = CocktailLabeledDataset(split='train', cocktail_reps=cocktail_reps)
    test_data_cocktail_labeled = CocktailLabeledDataset(split='test', cocktail_reps=cocktail_reps)
    train_data_music_labeled = MusicLabeledDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths)
    test_data_music_labeled = MusicLabeledDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths)

    train_sampler_cocktail_labeled = FixedLenRandomSampler(train_data_cocktail_labeled, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_cocktail_labeled = FixedLenRandomSampler(test_data_cocktail_labeled, bs=batch_size, epoch_size=test_epoch_size)
    train_sampler_music_labeled = FixedLenRandomSampler(train_data_music_labeled, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_music_labeled = FixedLenRandomSampler(test_data_music_labeled, bs=batch_size, epoch_size=test_epoch_size)

    train_data_cocktail_labeled = DataLoader(train_data_cocktail_labeled, batch_sampler=train_sampler_cocktail_labeled)
    test_data_cocktail_labeled = DataLoader(test_data_cocktail_labeled, batch_sampler=test_sampler_cocktail_labeled)
    train_data_music_labeled = DataLoader(train_data_music_labeled, batch_sampler=train_sampler_music_labeled)
    test_data_music_labeled = DataLoader(test_data_music_labeled, batch_sampler=test_sampler_music_labeled)

    train_data_grounding = RegressedGroundingDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps)
    test_data_grounding = RegressedGroundingDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps)

    train_sampler_grounding = FixedLenRandomSampler(train_data_grounding, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_grounding = FixedLenRandomSampler(test_data_grounding, bs=batch_size, epoch_size=test_epoch_size)

    train_data_grounding = DataLoader(train_data_grounding, batch_sampler=train_sampler_grounding)
    test_data_grounding = DataLoader(test_data_grounding, batch_sampler=test_sampler_grounding)

    data_loaders = dict(music_train=train_data_music,
                        music_test=test_data_music,
                        cocktail_train=train_data_cocktail,
                        cocktail_test=test_data_cocktail,
                        music_labeled_train=train_data_music_labeled,
                        music_labeled_test=test_data_music_labeled,
                        cocktail_labeled_train=train_data_cocktail_labeled,
                        cocktail_labeled_test=test_data_cocktail_labeled,
                        reg_grounding_train=train_data_grounding,
                        reg_grounding_test=test_data_grounding
                        )
    for name, loader in data_loaders.items():
        print(f'Dataset {name}, size: {len(loader.dataset)}')
    assert data_loaders['cocktail_labeled_train'].dataset.n_labels == data_loaders['music_labeled_train'].dataset.n_labels, \
        'cocktail and music labeled datasets must share the same set of labels'
    stats = dict(mean_std_music_rep=mean_std_music_rep.tolist(), mean_std_cocktail_rep_norm11=mean_std_cocktail_rep_norm11.tolist())
    return data_loaders, data_loaders['music_labeled_train'].dataset.n_labels, stats
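
# Usage sketch (the paths below are hypothetical placeholders, not files from
# this repo): returns the dict of ten loaders, the number of labels shared by
# both labeled datasets, and the normalization stats needed to map
# representations back to their raw scales.
#
#     data_loaders, n_labels, stats = get_dataloaders(
#         cocktail_rep_path='data/cocktail_reps.txt',                  # assumed path
#         music_rep_path='data/music_reps_normalized_meanstd.pickle',  # assumed path
#         batch_size=32, train_epoch_size=3200, test_epoch_size=320)
#     for batch in data_loaders['music_train']:
#         ...  # train_epoch_size // batch_size batches per epoch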


class FixedLenRandomSampler(RandomSampler):
    """Batch sampler that yields a fixed number of batches per epoch.

    Indices are drawn without replacement from a pool of not-yet-sampled items
    that carries over across epochs and resets once exhausted, so epochs can be
    shorter or longer than the dataset while every item still gets visited.
    Adapted from mnpinto - Miguel:
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.not_sampled = np.array([True] * len(data_source))  # pool of indices not drawn yet
        self.update_epoch_size_and_batch(epoch_size, bs)

    def update_epoch_size_and_batch(self, epoch_size, bs):
        self.epoch_size = epoch_size  # expressed in samples, not batches
        self.bs = bs
        self.size_to_sample = self.epoch_size
        self.nb_batches_per_epoch = self.epoch_size // self.bs

    def _reset_state(self):
        self.not_sampled[:] = True

    def reset_and_sample(self, idx, total_to_sample):
        """Draw total_to_sample indices, preferring items not sampled yet.

        When the pool runs out mid-draw it is reset, so indices drawn earlier
        in the same epoch may reappear (unavoidable when the epoch size
        exceeds the dataset size).
        """
        n_to_sample = total_to_sample - len(idx)
        ns = self.not_sampled.sum()
        if ns == 0:
            # pool exhausted: reset it and draw again
            self._reset_state()
            return self.reset_and_sample(idx, total_to_sample)
        elif ns >= n_to_sample:
            # enough unseen items left: draw the remainder without replacement
            new_idx = np.random.choice(np.where(self.not_sampled)[0], size=n_to_sample, replace=False).tolist()
            new_idx = [*idx, *new_idx]
            assert len(new_idx) == total_to_sample
            return new_idx
        else:
            # not enough left: take everything remaining, reset the pool, and
            # recurse to draw the rest
            idx_last = np.where(self.not_sampled)[0].tolist()
            new_idx = [*idx, *idx_last]
            self._reset_state()
            return self.reset_and_sample(new_idx, total_to_sample)

    def __iter__(self):
        idx = self.reset_and_sample(idx=[], total_to_sample=self.size_to_sample)
        assert len(idx) == self.size_to_sample
        self.not_sampled[idx] = False  # mark this epoch's draws as seen
        # split the flat index list into nb_batches_per_epoch batches of size bs
        out = [idx[i * self.bs:(i + 1) * self.bs] for i in range(self.nb_batches_per_epoch)]
        return iter(out)

    def __len__(self):
        return self.nb_batches_per_epoch
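

# Minimal sanity check for the sampler (toy sizes, illustrative only): a
# 10-item dataset sampled in batches of 4 with epochs of 8 samples. The unseen
# pool carries over across epochs, so every index appears by the second epoch.
if __name__ == '__main__':
    toy_dataset = list(range(10))
    sampler = FixedLenRandomSampler(toy_dataset, bs=4, epoch_size=8)
    seen = set()
    for epoch in range(3):
        batches = list(sampler)  # 8 // 4 = 2 batches per epoch
        print(f'epoch {epoch}: {batches}')
        seen.update(i for batch in batches for i in batch)
    assert seen == set(range(10)), 'all indices visited across epochs'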