import os.path
from dataset import CocktailDataset, MusicDataset, CocktailLabeledDataset, MusicLabeledDataset, RegressedGroundingDataset
import torch
import numpy as np
from torch.utils.data.sampler import RandomSampler
from torch.utils.data import DataLoader
from src.music.utils import get_all_subfiles_with_extension
import pickle
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def wasserstein1d(x, y):
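    """Squared 1D Wasserstein-2 distance between equal-size samples.

    Sorting along dim 0 yields the optimal 1D transport plan (rank matching),
    so the distance reduces to the mean squared difference of the sorted
    values, averaged over all columns (projections).
    """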
x1, _ = torch.sort(x, dim=0)
y1, _ = torch.sort(y, dim=0)
z = (x1-y1).view(-1)
n = z.size(0)
return torch.dot(z, z) / n
def compute_swd_loss(minibatch1, minibatch2, latent_dim, n_projections=10000):
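    """Sliced Wasserstein distance between two minibatches of latent codes:
    project both batches onto random unit directions, then average the squared
    1D Wasserstein-2 distances between the projected samples.
    """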
# sample random projections
theta = torch.randn((latent_dim, n_projections),
requires_grad=False,
device=device)
theta = theta/torch.norm(theta, dim=0)[None, :]
proj1 = minibatch1@theta
proj2 = minibatch2@theta
# compute sliced wasserstein distance on projected features
gloss = wasserstein1d(proj1, proj2)
return gloss
def get_dataloaders(cocktail_rep_path, music_rep_path, batch_size, train_epoch_size, test_epoch_size):
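    """Build train/test DataLoaders for the unlabeled, labeled and regressed-grounding
    versions of the cocktail and music datasets, and return them together with the
    number of labels and the normalization statistics of both representation spaces.
    """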
    assert train_epoch_size % batch_size == 0, 'epoch size counts samples and must be a multiple of the batch size'
    assert test_epoch_size % batch_size == 0, 'epoch size counts samples and must be a multiple of the batch size'
assert '.pickle' in music_rep_path
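    # if the cached music representations are missing, build them from the raw
    # .txt representation files, standardize them, and cache everything to disk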
if not os.path.exists(music_rep_path):
music_rep_paths = get_all_subfiles_with_extension(music_rep_path.replace('music_reps_normalized_meanstd.pickle', ''), max_depth=3, extension='.txt', current_depth=0)
music_reps = []
for p in music_rep_paths:
music_reps.append(np.loadtxt(p))
music_reps = np.array(music_reps)
mean_std = np.array([music_reps.mean(axis=0), music_reps.std(axis=0)])
music_reps = (music_reps - mean_std[0]) / mean_std[1]
        assert len(music_rep_paths) == len(music_reps), 'mismatch between number of paths and number of representations'
data = dict(zip(music_rep_paths, music_reps))
to_save = dict(musicpath2musicrep=data,
mean_std=mean_std)
with open(music_rep_path, 'wb') as f:
pickle.dump(to_save, f)
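    # load the cached music representations and their normalization statistics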
with open(music_rep_path, 'rb') as f:
data = pickle.load(f)
mean_std_music_rep = data['mean_std']
music_rep_paths = sorted(data['musicpath2musicrep'].keys())
music_reps = np.array([data['musicpath2musicrep'][k] for k in music_rep_paths])
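    # load and standardize the cocktail representations (zero mean, unit std per dimension)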
cocktail_reps = np.loadtxt(cocktail_rep_path)
    mean_std_cocktail_rep_norm11 = np.array([cocktail_reps.mean(axis=0), cocktail_reps.std(axis=0)])
    cocktail_reps = (cocktail_reps - mean_std_cocktail_rep_norm11[0]) / mean_std_cocktail_rep_norm11[1]
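    # unlabeled datasets and fixed-length loaders, one pair per modality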
train_data_cocktail = CocktailDataset(split='train', cocktail_reps=cocktail_reps)
test_data_cocktail = CocktailDataset(split='test', cocktail_reps=cocktail_reps)
train_data_music = MusicDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths)
test_data_music = MusicDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths)
train_sampler_cocktail = FixedLenRandomSampler(train_data_cocktail, bs=batch_size, epoch_size=train_epoch_size)
test_sampler_cocktail = FixedLenRandomSampler(test_data_cocktail, bs=batch_size, epoch_size=test_epoch_size)
train_sampler_music = FixedLenRandomSampler(train_data_music, bs=batch_size, epoch_size=train_epoch_size)
test_sampler_music = FixedLenRandomSampler(test_data_music, bs=batch_size, epoch_size=test_epoch_size)
train_data_cocktail = DataLoader(train_data_cocktail, batch_sampler=train_sampler_cocktail)
test_data_cocktail = DataLoader(test_data_cocktail, batch_sampler=test_sampler_cocktail)
train_data_music = DataLoader(train_data_music, batch_sampler=train_sampler_music)
test_data_music = DataLoader(test_data_music, batch_sampler=test_sampler_music)
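    # labeled variants of the same datasets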
train_data_cocktail_labeled = CocktailLabeledDataset(split='train', cocktail_reps=cocktail_reps)
test_data_cocktail_labeled = CocktailLabeledDataset(split='test', cocktail_reps=cocktail_reps)
train_data_music_labeled = MusicLabeledDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths)
test_data_music_labeled = MusicLabeledDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths)
train_sampler_cocktail_labeled = FixedLenRandomSampler(train_data_cocktail_labeled, bs=batch_size, epoch_size=train_epoch_size)
test_sampler_cocktail_labeled = FixedLenRandomSampler(test_data_cocktail_labeled, bs=batch_size, epoch_size=test_epoch_size)
train_sampler_music_labeled = FixedLenRandomSampler(train_data_music_labeled, bs=batch_size, epoch_size=train_epoch_size)
test_sampler_music_labeled = FixedLenRandomSampler(test_data_music_labeled, bs=batch_size, epoch_size=test_epoch_size)
train_data_cocktail_labeled = DataLoader(train_data_cocktail_labeled, batch_sampler=train_sampler_cocktail_labeled)
test_data_cocktail_labeled = DataLoader(test_data_cocktail_labeled, batch_sampler=test_sampler_cocktail_labeled)
train_data_music_labeled = DataLoader(train_data_music_labeled, batch_sampler=train_sampler_music_labeled)
test_data_music_labeled = DataLoader(test_data_music_labeled, batch_sampler=test_sampler_music_labeled)
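    # grounding dataset pairing music representations with cocktail representations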
train_data_grounding = RegressedGroundingDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps)
test_data_grounding = RegressedGroundingDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps)
train_sampler_grounding = FixedLenRandomSampler(train_data_grounding, bs=batch_size, epoch_size=train_epoch_size)
test_sampler_grounding = FixedLenRandomSampler(test_data_grounding, bs=batch_size, epoch_size=test_epoch_size)
train_data_grounding = DataLoader(train_data_grounding, batch_sampler=train_sampler_grounding)
test_data_grounding = DataLoader(test_data_grounding, batch_sampler=test_sampler_grounding)
data_loaders = dict(music_train=train_data_music,
music_test=test_data_music,
cocktail_train=train_data_cocktail,
cocktail_test=test_data_cocktail,
music_labeled_train=train_data_music_labeled,
music_labeled_test=test_data_music_labeled,
cocktail_labeled_train=train_data_cocktail_labeled,
cocktail_labeled_test=test_data_cocktail_labeled,
reg_grounding_train=train_data_grounding,
reg_grounding_test=test_data_grounding
)
    for k, loader in data_loaders.items():
        print(f'Dataset {k}, size: {len(loader.dataset)}')
    assert data_loaders['cocktail_labeled_train'].dataset.n_labels == data_loaders['music_labeled_train'].dataset.n_labels, \
        'labeled cocktail and music datasets must share the same label set'
stats = dict(mean_std_music_rep=mean_std_music_rep.tolist(), mean_std_cocktail_rep_norm11=mean_std_cocktail_rep_norm11.tolist())
return data_loaders, data_loaders['music_labeled_train'].dataset.n_labels, stats
class FixedLenRandomSampler(RandomSampler):
"""
Code from mnpinto - Miguel
https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10
"""
def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
super().__init__(data_source)
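        # boolean mask tracking which indices have not yet been sampled in the current cycle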
self.not_sampled = np.array([True]*len(data_source))
self.update_epoch_size_and_batch(epoch_size, bs)
def update_epoch_size_and_batch(self, epoch_size, bs):
self.epoch_size = epoch_size
self.bs = bs
self.size_to_sample = self.epoch_size
self.nb_batches_per_epoch = self.epoch_size // self.bs
def _reset_state(self):
self.not_sampled[:] = True
def reset_and_sample(self, idx, total_to_sample):
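        """Recursively draw total_to_sample indices without replacement, resetting
        the pool of available indices whenever it is exhausted."""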
        n_to_sample = total_to_sample - len(idx)
        ns = int(self.not_sampled.sum())  # indices still available in the current cycle
if ns == 0:
self._reset_state()
return self.reset_and_sample(idx, total_to_sample)
elif ns >= n_to_sample:
new_idx = np.random.choice(np.where(self.not_sampled)[0], size=n_to_sample, replace=False).tolist()
new_idx = [*idx, *new_idx]
assert len(new_idx) == total_to_sample
return new_idx
else:
idx_last = np.where(self.not_sampled)[0].tolist()
new_idx = [*idx, *idx_last]
self._reset_state()
return self.reset_and_sample(new_idx, total_to_sample)
def __iter__(self):
idx = self.reset_and_sample(idx=[], total_to_sample=self.size_to_sample)
assert len(idx) == self.size_to_sample
self.not_sampled[idx] = False
        # split the flat list of indices into consecutive batches of size bs
        out = [idx[i * self.bs:(i + 1) * self.bs] for i in range(self.nb_batches_per_epoch)]
        return iter(out)
def __len__(self):
return self.nb_batches_per_epoch
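

# Minimal usage sketch (illustration only, not part of the training pipeline).
# It exercises compute_swd_loss and FixedLenRandomSampler on random data; all
# shapes, batch sizes and epoch sizes below are arbitrary assumptions.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    latent_dim = 16
    # two random minibatches with a mean offset: the SWD loss should be clearly positive
    batch1 = torch.randn(64, latent_dim, device=device)
    batch2 = torch.randn(64, latent_dim, device=device) + 1.
    print('swd loss:', compute_swd_loss(batch1, batch2, latent_dim, n_projections=100).item())

    # fixed-length epochs over a toy dataset: 4 batches of 8 samples per epoch
    toy_data = TensorDataset(torch.randn(100, latent_dim))
    sampler = FixedLenRandomSampler(toy_data, bs=8, epoch_size=32)
    loader = DataLoader(toy_data, batch_sampler=sampler)
    for (x,) in loader:
        print(x.shape)  # torch.Size([8, 16])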