import os.path
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

from dataset import CocktailDataset, MusicDataset, CocktailLabeledDataset, MusicLabeledDataset, RegressedGroundingDataset
from src.music.utils import get_all_subfiles_with_extension
# Note: this imported name is shadowed by the FixedLenRandomSampler class defined at the bottom of this file.
from src.music.utilities.representation_learning_utilities.sampler import FixedLenRandomSampler

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def wasserstein1d(x, y):
    """Mean squared difference between the sorted samples of x and y (column-wise), i.e. an
    empirical squared 1D Wasserstein-2 distance averaged over columns."""
    x1, _ = torch.sort(x, dim=0)
    y1, _ = torch.sort(y, dim=0)
    z = (x1 - y1).view(-1)
    n = z.size(0)
    return torch.dot(z, z) / n
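
# Usage sketch (illustrative values): the distance is ~0 for identical samples and equals the
# squared shift for a constant offset, e.g.
#   >>> a = torch.randn(64, 1)
#   >>> float(wasserstein1d(a, a))        # 0.0
#   >>> float(wasserstein1d(a, a + 1.))   # 1.0
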
def compute_swd_loss(minibatch1, minibatch2, latent_dim, n_projections=10000):
    # sample random projection directions and normalize them to unit norm
    theta = torch.randn((latent_dim, n_projections),
                        requires_grad=False,
                        device=device)
    theta = theta / torch.norm(theta, dim=0)[None, :]
    proj1 = minibatch1 @ theta
    proj2 = minibatch2 @ theta
    # compute the sliced wasserstein distance on the projected features
    gloss = wasserstein1d(proj1, proj2)
    return gloss
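
# Usage sketch (illustrative shapes): align two minibatches of latent codes of the same dimension,
# e.g. cocktail and music embeddings of size 32:
#   >>> b1 = torch.randn(128, 32, device=device)
#   >>> b2 = torch.randn(128, 32, device=device)
#   >>> loss = compute_swd_loss(b1, b2, latent_dim=32, n_projections=1000)
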
def get_dataloaders(cocktail_rep_path, music_rep_path, batch_size, train_epoch_size, test_epoch_size):
    assert train_epoch_size % batch_size == 0, 'epoch size is expressed in steps, must be a multiple of batch size'
    assert test_epoch_size % batch_size == 0, 'epoch size is expressed in steps, must be a multiple of batch size'
    assert '.pickle' in music_rep_path

    # if the pickled music representations do not exist yet, build them from the raw .txt files and cache them
    if not os.path.exists(music_rep_path):
        music_rep_paths = get_all_subfiles_with_extension(music_rep_path.replace('music_reps_normalized_meanstd.pickle', ''), max_depth=3, extension='.txt', current_depth=0)
        music_reps = []
        for p in music_rep_paths:
            music_reps.append(np.loadtxt(p))
        music_reps = np.array(music_reps)
        # standardize music representations and keep the statistics for later use
        mean_std = np.array([music_reps.mean(axis=0), music_reps.std(axis=0)])
        music_reps = (music_reps - mean_std[0]) / mean_std[1]
        assert len(music_rep_paths) == len(music_reps), 'check bug with mean_std'
        data = dict(zip(music_rep_paths, music_reps))
        to_save = dict(musicpath2musicrep=data,
                       mean_std=mean_std)
        with open(music_rep_path, 'wb') as f:
            pickle.dump(to_save, f)

    # load the cached music representations and their normalization statistics
    with open(music_rep_path, 'rb') as f:
        data = pickle.load(f)
    mean_std_music_rep = data['mean_std']
    music_rep_paths = sorted(data['musicpath2musicrep'].keys())
    music_reps = np.array([data['musicpath2musicrep'][k] for k in music_rep_paths])

    # load and standardize the cocktail representations
    cocktail_reps = np.loadtxt(cocktail_rep_path)
    mean_std_cocktail_rep_norm11 = np.array([cocktail_reps.mean(axis=0), cocktail_reps.std(axis=0)])
    cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0)

    # unlabeled datasets: dataset -> fixed-length random sampler -> data loader
    train_data_cocktail = CocktailDataset(split='train', cocktail_reps=cocktail_reps)
    test_data_cocktail = CocktailDataset(split='test', cocktail_reps=cocktail_reps)
    train_data_music = MusicDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths)
    test_data_music = MusicDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths)
    train_sampler_cocktail = FixedLenRandomSampler(train_data_cocktail, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_cocktail = FixedLenRandomSampler(test_data_cocktail, bs=batch_size, epoch_size=test_epoch_size)
    train_sampler_music = FixedLenRandomSampler(train_data_music, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_music = FixedLenRandomSampler(test_data_music, bs=batch_size, epoch_size=test_epoch_size)
    train_data_cocktail = DataLoader(train_data_cocktail, batch_sampler=train_sampler_cocktail)
    test_data_cocktail = DataLoader(test_data_cocktail, batch_sampler=test_sampler_cocktail)
    train_data_music = DataLoader(train_data_music, batch_sampler=train_sampler_music)
    test_data_music = DataLoader(test_data_music, batch_sampler=test_sampler_music)

    # labeled datasets, same recipe
    train_data_cocktail_labeled = CocktailLabeledDataset(split='train', cocktail_reps=cocktail_reps)
    test_data_cocktail_labeled = CocktailLabeledDataset(split='test', cocktail_reps=cocktail_reps)
    train_data_music_labeled = MusicLabeledDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths)
    test_data_music_labeled = MusicLabeledDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths)
    train_sampler_cocktail_labeled = FixedLenRandomSampler(train_data_cocktail_labeled, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_cocktail_labeled = FixedLenRandomSampler(test_data_cocktail_labeled, bs=batch_size, epoch_size=test_epoch_size)
    train_sampler_music_labeled = FixedLenRandomSampler(train_data_music_labeled, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_music_labeled = FixedLenRandomSampler(test_data_music_labeled, bs=batch_size, epoch_size=test_epoch_size)
    train_data_cocktail_labeled = DataLoader(train_data_cocktail_labeled, batch_sampler=train_sampler_cocktail_labeled)
    test_data_cocktail_labeled = DataLoader(test_data_cocktail_labeled, batch_sampler=test_sampler_cocktail_labeled)
    train_data_music_labeled = DataLoader(train_data_music_labeled, batch_sampler=train_sampler_music_labeled)
    test_data_music_labeled = DataLoader(test_data_music_labeled, batch_sampler=test_sampler_music_labeled)

    # paired (music, cocktail) grounding datasets for the regression objective
    train_data_grounding = RegressedGroundingDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps)
    test_data_grounding = RegressedGroundingDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps)
    train_sampler_grounding = FixedLenRandomSampler(train_data_grounding, bs=batch_size, epoch_size=train_epoch_size)
    test_sampler_grounding = FixedLenRandomSampler(test_data_grounding, bs=batch_size, epoch_size=test_epoch_size)
    train_data_grounding = DataLoader(train_data_grounding, batch_sampler=train_sampler_grounding)
    test_data_grounding = DataLoader(test_data_grounding, batch_sampler=test_sampler_grounding)

    data_loaders = dict(music_train=train_data_music,
                        music_test=test_data_music,
                        cocktail_train=train_data_cocktail,
                        cocktail_test=test_data_cocktail,
                        music_labeled_train=train_data_music_labeled,
                        music_labeled_test=test_data_music_labeled,
                        cocktail_labeled_train=train_data_cocktail_labeled,
                        cocktail_labeled_test=test_data_cocktail_labeled,
                        reg_grounding_train=train_data_grounding,
                        reg_grounding_test=test_data_grounding)
    for k in data_loaders.keys():
        print(f'Dataset {k}, size: {len(data_loaders[k].dataset)}')
    assert data_loaders['cocktail_labeled_train'].dataset.n_labels == data_loaders['music_labeled_train'].dataset.n_labels
    stats = dict(mean_std_music_rep=mean_std_music_rep.tolist(), mean_std_cocktail_rep_norm11=mean_std_cocktail_rep_norm11.tolist())
    return data_loaders, data_loaders['music_labeled_train'].dataset.n_labels, stats
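
# Usage sketch (hypothetical paths, to be adapted to the actual data layout):
#   >>> loaders, n_labels, stats = get_dataloaders(
#   ...     cocktail_rep_path='data/cocktail_reps.txt',
#   ...     music_rep_path='data/music_reps_normalized_meanstd.pickle',
#   ...     batch_size=32, train_epoch_size=3200, test_epoch_size=320)
#   >>> for batch in loaders['music_train']:
#   ...     pass  # each epoch yields train_epoch_size // batch_size batches
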
class FixedLenRandomSampler(RandomSampler):
    """
    Batch sampler that yields a fixed number of batches per epoch, sampling indices without
    replacement across epochs (indices are only reused once the whole dataset has been seen).
    Code from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.not_sampled = np.array([True] * len(data_source))
        self.update_epoch_size_and_batch(epoch_size, bs)

    def update_epoch_size_and_batch(self, epoch_size, bs):
        self.epoch_size = epoch_size
        self.bs = bs
        self.size_to_sample = self.epoch_size
        self.nb_batches_per_epoch = self.epoch_size // self.bs

    def _reset_state(self):
        self.not_sampled[:] = True

    def reset_and_sample(self, idx, total_to_sample):
        # draw indices from the not-yet-sampled pool; when the pool is exhausted, reset it and keep drawing
        n_to_sample = total_to_sample - len(idx)
        ns = sum(self.not_sampled)
        if ns == 0:
            self._reset_state()
            return self.reset_and_sample(idx, total_to_sample)
        elif ns >= n_to_sample:
            new_idx = np.random.choice(np.where(self.not_sampled)[0], size=n_to_sample, replace=False).tolist()
            new_idx = [*idx, *new_idx]
            assert len(new_idx) == total_to_sample
            return new_idx
        else:
            idx_last = np.where(self.not_sampled)[0].tolist()
            new_idx = [*idx, *idx_last]
            self._reset_state()
            return self.reset_and_sample(new_idx, total_to_sample)

    def __iter__(self):
        idx = self.reset_and_sample(idx=[], total_to_sample=self.size_to_sample)
        assert len(idx) == self.size_to_sample
        self.not_sampled[idx] = False
        # print(ns, len(idx), len(idx_last)) # debug
        # split the sampled indices into nb_batches_per_epoch batches of size bs
        out = []
        i_idx = 0
        for i in range(self.nb_batches_per_epoch):
            batch = []
            for j in range(self.bs):
                batch.append(idx[i_idx])
                i_idx += 1
            out.append(batch)
        return iter(out)

    def __len__(self):
        return self.nb_batches_per_epoch
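
if __name__ == '__main__':
    # Illustrative sanity check (a sketch with toy data): estimate the SWD loss between two random
    # batches of latent codes, and verify that FixedLenRandomSampler yields a fixed number of
    # fixed-size batches regardless of the dataset length.
    from torch.utils.data import TensorDataset

    z1 = torch.randn(128, 16, device=device)
    z2 = torch.randn(128, 16, device=device)
    print('swd loss:', compute_swd_loss(z1, z2, latent_dim=16, n_projections=100).item())

    toy_data = TensorDataset(torch.arange(10))
    sampler = FixedLenRandomSampler(toy_data, bs=4, epoch_size=24)
    loader = DataLoader(toy_data, batch_sampler=sampler)
    batches = list(loader)
    assert len(batches) == len(sampler) == 6
    assert all(len(b[0]) == 4 for b in batches)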