# NOTE(review): the three lines below were non-Python paste artifacts;
# commented out so the module can be imported.
# Spaces:
# Runtime error
# Runtime error
import os.path
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

from dataset import CocktailDataset, MusicDataset, CocktailLabeledDataset, MusicLabeledDataset, RegressedGroundingDataset
from src.music.utilities.representation_learning_utilities.sampler import FixedLenRandomSampler
from src.music.utils import get_all_subfiles_with_extension
# Module-wide default torch device: CUDA when available, else CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def wasserstein1d(x, y):
    """Mean squared difference between the sorted samples of ``x`` and ``y``.

    Sorting both inputs along dim 0 pairs order statistics (the optimal
    1-D transport coupling); the result is averaged over all elements.
    Used as the per-projection term of the sliced Wasserstein loss.
    """
    x_sorted = torch.sort(x, dim=0).values
    y_sorted = torch.sort(y, dim=0).values
    diff = (x_sorted - y_sorted).reshape(-1)
    return diff.dot(diff) / diff.numel()
def compute_swd_loss(minibatch1, minibatch2, latent_dim, n_projections=10000):
    """Sliced Wasserstein distance between two minibatches of latent vectors.

    Draws ``n_projections`` random unit directions in R^``latent_dim``,
    projects both minibatches onto them, and averages the 1-D squared
    Wasserstein distances of the projections (via ``wasserstein1d``).

    Fix: the projection matrix is now allocated on the device of the
    inputs instead of the module-level ``device`` global, which raised a
    device-mismatch error whenever the minibatches lived elsewhere.
    """
    # sample random projections on the same device as the data
    theta = torch.randn((latent_dim, n_projections),
                        requires_grad=False,
                        device=minibatch1.device)
    # normalize each column so every projection direction has unit norm
    theta = theta / torch.norm(theta, dim=0)[None, :]
    proj1 = minibatch1 @ theta
    proj2 = minibatch2 @ theta
    # compute sliced wasserstein distance on projected features
    gloss = wasserstein1d(proj1, proj2)
    return gloss
def _build_music_rep_cache(music_rep_path):
    """Build and pickle normalized music representations at ``music_rep_path``.

    Scans for ``.txt`` representation files under the directory containing
    ``music_rep_path``, z-normalizes them per dimension, and saves a dict
    with keys ``musicpath2musicrep`` (path -> normalized rep) and
    ``mean_std`` (stacked mean and std arrays).
    """
    rep_dir = music_rep_path.replace('music_reps_normalized_meanstd.pickle', '')
    music_rep_paths = get_all_subfiles_with_extension(rep_dir, max_depth=3, extension='.txt', current_depth=0)
    music_reps = np.array([np.loadtxt(p) for p in music_rep_paths])
    mean_std = np.array([music_reps.mean(axis=0), music_reps.std(axis=0)])
    # NOTE(review): assumes every dimension has non-zero std — confirm upstream.
    music_reps = (music_reps - mean_std[0]) / mean_std[1]
    assert len(music_rep_paths) == len(music_reps), 'check bug with mean_std'
    to_save = dict(musicpath2musicrep=dict(zip(music_rep_paths, music_reps)),
                   mean_std=mean_std)
    with open(music_rep_path, 'wb') as f:
        pickle.dump(to_save, f)

def _make_loader(dataset, batch_size, epoch_size):
    """Wrap ``dataset`` in a DataLoader driven by a FixedLenRandomSampler."""
    sampler = FixedLenRandomSampler(dataset, bs=batch_size, epoch_size=epoch_size)
    return DataLoader(dataset, batch_sampler=sampler)

def get_dataloaders(cocktail_rep_path, music_rep_path, batch_size, train_epoch_size, test_epoch_size):
    """Construct all train/test DataLoaders for the music-cocktail experiments.

    Args:
        cocktail_rep_path: text file of cocktail representations (np.loadtxt-able).
        music_rep_path: pickle of normalized music reps; built on first call.
        batch_size: batch size for every loader.
        train_epoch_size / test_epoch_size: epoch sizes in steps; must be
            multiples of ``batch_size``.

    Returns:
        (data_loaders dict, n_labels, stats dict with normalization mean/std).
    """
    assert train_epoch_size % batch_size == 0, 'epoch size is expressed in steps, must be a multiple of batch size'
    assert test_epoch_size % batch_size == 0, 'epoch size is expressed in steps, must be a multiple of batch size'
    assert '.pickle' in music_rep_path
    if not os.path.exists(music_rep_path):
        _build_music_rep_cache(music_rep_path)
    with open(music_rep_path, 'rb') as f:
        data = pickle.load(f)
    mean_std_music_rep = data['mean_std']
    # sort for a deterministic path -> rep ordering
    music_rep_paths = sorted(data['musicpath2musicrep'].keys())
    music_reps = np.array([data['musicpath2musicrep'][k] for k in music_rep_paths])

    cocktail_reps = np.loadtxt(cocktail_rep_path)
    mean_std_cocktail_rep_norm11 = np.array([cocktail_reps.mean(axis=0), cocktail_reps.std(axis=0)])
    # reuse the stats computed above (previously mean/std were computed twice)
    cocktail_reps = (cocktail_reps - mean_std_cocktail_rep_norm11[0]) / mean_std_cocktail_rep_norm11[1]

    music_kwargs = dict(music_reps=music_reps, music_rep_paths=music_rep_paths)
    cocktail_kwargs = dict(cocktail_reps=cocktail_reps)

    def _train_test_loaders(dataset_cls, **dataset_kwargs):
        # One train + one test loader for a dataset class (was copy-pasted x5).
        train = _make_loader(dataset_cls(split='train', **dataset_kwargs), batch_size, train_epoch_size)
        test = _make_loader(dataset_cls(split='test', **dataset_kwargs), batch_size, test_epoch_size)
        return train, test

    cocktail_train, cocktail_test = _train_test_loaders(CocktailDataset, **cocktail_kwargs)
    music_train, music_test = _train_test_loaders(MusicDataset, **music_kwargs)
    cocktail_labeled_train, cocktail_labeled_test = _train_test_loaders(CocktailLabeledDataset, **cocktail_kwargs)
    music_labeled_train, music_labeled_test = _train_test_loaders(MusicLabeledDataset, **music_kwargs)
    grounding_train, grounding_test = _train_test_loaders(RegressedGroundingDataset, **music_kwargs, **cocktail_kwargs)

    data_loaders = dict(music_train=music_train,
                        music_test=music_test,
                        cocktail_train=cocktail_train,
                        cocktail_test=cocktail_test,
                        music_labeled_train=music_labeled_train,
                        music_labeled_test=music_labeled_test,
                        cocktail_labeled_train=cocktail_labeled_train,
                        cocktail_labeled_test=cocktail_labeled_test,
                        reg_grounding_train=grounding_train,
                        reg_grounding_test=grounding_test
                        )
    for name, loader in data_loaders.items():
        print(f'Dataset {name}, size: {len(loader.dataset)}')
    n_labels = data_loaders['music_labeled_train'].dataset.n_labels
    assert data_loaders['cocktail_labeled_train'].dataset.n_labels == n_labels
    stats = dict(mean_std_music_rep=mean_std_music_rep.tolist(),
                 mean_std_cocktail_rep_norm11=mean_std_cocktail_rep_norm11.tolist())
    return data_loaders, n_labels, stats
class FixedLenRandomSampler(RandomSampler):
    """Batch sampler yielding epochs of a fixed length, without replacement
    until the data source is exhausted (then the pool is reset).

    Code adapted from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10

    Fix: ``reset_and_sample`` was recursive, so an epoch much larger than
    the dataset could exceed the recursion limit, and an empty data source
    recursed forever; the sampling loop is now iterative and an empty
    data source fails fast.
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        if len(data_source) == 0:
            raise ValueError('data_source must be non-empty')
        super().__init__(data_source)
        # Tracks which indices have not yet been drawn in the current pass.
        self.not_sampled = np.array([True] * len(data_source))
        self.update_epoch_size_and_batch(epoch_size, bs)

    def update_epoch_size_and_batch(self, epoch_size, bs):
        """Set epoch size (in samples) and batch size; derive batches/epoch."""
        self.epoch_size = epoch_size
        self.bs = bs
        self.size_to_sample = self.epoch_size
        self.nb_batches_per_epoch = self.epoch_size // self.bs

    def _reset_state(self):
        # Mark every index as available again.
        self.not_sampled[:] = True

    def reset_and_sample(self, idx, total_to_sample):
        """Extend ``idx`` to ``total_to_sample`` indices, drawing without
        replacement from the not-yet-sampled pool and resetting the pool
        whenever it runs dry."""
        idx = list(idx)
        while len(idx) < total_to_sample:
            remaining = np.where(self.not_sampled)[0]
            needed = total_to_sample - len(idx)
            if remaining.size == 0:
                self._reset_state()
            elif remaining.size >= needed:
                idx.extend(np.random.choice(remaining, size=needed, replace=False).tolist())
            else:
                # Not enough left for the request: take the whole tail, reset,
                # and keep drawing on the next iteration.
                idx.extend(remaining.tolist())
                self._reset_state()
        assert len(idx) == total_to_sample
        return idx

    def __iter__(self):
        idx = self.reset_and_sample(idx=[], total_to_sample=self.size_to_sample)
        assert len(idx) == self.size_to_sample
        self.not_sampled[idx] = False
        # Chop the flat index list into fixed-size batches.
        batches = [idx[i * self.bs:(i + 1) * self.bs] for i in range(self.nb_batches_per_epoch)]
        return iter(batches)

    def __len__(self):
        return self.nb_batches_per_epoch