# Scraped Hugging Face Spaces page header (original text: "Spaces: Runtime error").
import numpy as np
import torch
from torch.utils.data import Dataset

from src.music2cocktailrep.analysis.explore import get_alignment_dataset

# Device shared by the dataset classes below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Add your custom dataset class here
class CocktailDataset(Dataset):
    """Cocktail representations with dummy (all-zero) labels.

    Splits the representation matrix 90/10 into train/test by row order, then
    oversamples cocktails containing eggs (12 extra copies each) and bubbles
    (4 extra copies each) before shuffling.

    Args:
        split: one of 'train', 'test', 'all'.
        cocktail_reps: 2D numpy array (n_cocktails, dim_cocktail).
            Column -1 is assumed to encode egg content and column -3
            bubbles content (positive value = present).

    Raises:
        ValueError: if ``split`` is not one of the three accepted values.
    """

    def __init__(self, split, cocktail_reps):
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape
        n_train = int(0.9 * self.n_cocktails)
        if split == 'train':
            reps = cocktail_reps[:n_train, :].copy()
        elif split == 'test':
            reps = cocktail_reps[n_train:, :].copy()
        elif split == 'all':
            reps = cocktail_reps.copy()
        else:
            raise ValueError(f"unknown split: {split!r}")
        # Oversample cocktails with eggs and bubbles. This is done in numpy
        # BEFORE the tensor is moved to the device: the previous version
        # called np.argwhere on the device tensor, which fails on CUDA.
        ind_egg = np.argwhere(reps[:, -1] > 0).flatten()
        ind_bubbles = np.argwhere(reps[:, -3] > 0).flatten()
        n_copies = 4
        egg_copies = np.tile(reps[ind_egg, :], (n_copies * 3, 1))
        bubbles_copies = np.tile(reps[ind_bubbles, :], (n_copies, 1))
        reps = np.concatenate([reps, egg_copies, bubbles_copies], axis=0)
        self.n_cocktails = reps.shape[0]
        # Shuffle so the oversampled copies are spread across the dataset.
        indexes = np.arange(self.n_cocktails)
        np.random.shuffle(indexes)
        reps = reps[indexes]
        # Resolve the device locally so the class works without the module global.
        dev = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.cocktail_reps = torch.FloatTensor(reps).to(dev)
        # Dummy labels: every sample belongs to class 0.
        self.labels = torch.LongTensor(np.zeros([self.n_cocktails])).to(dev)
        self.contains_egg = self.cocktail_reps[:, -1] > 0
        self.contains_bubbles = self.cocktail_reps[:, -3] > 0

    def __len__(self):
        return self.cocktail_reps.shape[0]

    def __getitem__(self, idx):
        return self.cocktail_reps[idx], self.labels[idx], self.contains_egg[idx], self.contains_bubbles[idx]
class CocktailLabeledDataset(Dataset):
    """Cocktail representations labeled by their alignment category.

    Labels come from ``get_alignment_dataset()['cocktail']``: each key is a
    class and maps to the cocktail indices belonging to it. The train/test
    split is stratified: the first 90% of each class goes to train, the rest
    to test.

    Args:
        split: one of 'train', 'test', 'all'.
        cocktail_reps: 2D numpy array of representations, indexed by the
            cocktail indices stored in the alignment dataset.

    Raises:
        ValueError: if ``split`` is not one of the three accepted values.
    """

    def __init__(self, split, cocktail_reps):
        dataset = get_alignment_dataset()
        labels = sorted(dataset['cocktail'].keys())
        self.n_labels = len(labels)
        # Gather all cocktail indices, grouped by label key.
        all_cocktails = []
        for k in labels:
            all_cocktails += dataset['cocktail'][k]
        all_cocktails = np.array(all_cocktails)
        cocktail_reps = cocktail_reps[all_cocktails]
        # Recover the class index of each cocktail.
        cocktail_labels = []
        for i in all_cocktails:
            for i_k, k in enumerate(labels):
                if i in dataset['cocktail'][k]:
                    cocktail_labels.append(i_k)
                    break
        cocktail_labels = np.array(cocktail_labels)
        assert len(cocktail_labels) == len(cocktail_reps)
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape
        # Stratified split: 90% of each class to train, the remainder to test.
        indexes_train = []
        indexes_test = []
        for i_k, k in enumerate(labels):
            indexes_k = np.argwhere(cocktail_labels == i_k).flatten()
            n_train = int(0.9 * len(indexes_k))
            indexes_train += list(indexes_k[:n_train])
            indexes_test += list(indexes_k[n_train:])
        indexes_train = np.array(indexes_train)
        indexes_test = np.array(indexes_test)
        assert len(set(indexes_train) & set(indexes_test)) == 0
        if split == 'train':
            self.cocktail_reps = cocktail_reps[indexes_train].copy()
            self.labels = cocktail_labels[indexes_train].copy()
        elif split == 'test':
            self.cocktail_reps = cocktail_reps[indexes_test].copy()
            self.labels = cocktail_labels[indexes_test].copy()
        elif split == 'all':
            self.cocktail_reps = cocktail_reps.copy()
            self.labels = cocktail_labels.copy()
        else:
            raise ValueError(f"unknown split: {split!r}")
        self.n_cocktails = self.cocktail_reps.shape[0]
        # Shuffle samples so classes are interleaved.
        indexes = np.arange(self.n_cocktails)
        np.random.shuffle(indexes)
        self.cocktail_reps = torch.FloatTensor(self.cocktail_reps[indexes]).to(device)
        self.labels = torch.LongTensor(self.labels[indexes]).to(device)

    def __len__(self):
        return self.cocktail_reps.shape[0]

    def __getitem__(self, idx):
        return self.cocktail_reps[idx], self.labels[idx]
class MusicDataset(Dataset):
    """Unlabeled music representations with dummy (all-zero) labels.

    Splits the representation matrix 90/10 into train/test by row order,
    then shuffles the selected rows.

    Args:
        split: one of 'train', 'test', 'all'.
        music_reps: 2D numpy array (n_music, dim_music).
        music_rep_paths: kept for interface compatibility; unused here.

    Raises:
        ValueError: if ``split`` is not one of the three accepted values.
    """

    def __init__(self, split, music_reps, music_rep_paths):
        self.n_music, self.dim_music = music_reps.shape
        n_train = int(0.9 * self.n_music)
        if split == 'train':
            reps = music_reps[:n_train, :].copy()
        elif split == 'test':
            reps = music_reps[n_train:, :].copy()
        elif split == 'all':
            reps = music_reps.copy()
        else:
            raise ValueError(f"unknown split: {split!r}")
        self.n_music = reps.shape[0]
        # Shuffle rows so sequential batches are not in file order.
        indexes = np.arange(self.n_music)
        np.random.shuffle(indexes)
        # Resolve the device locally so the class works without the module global.
        dev = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.music_reps = torch.FloatTensor(reps[indexes]).to(dev)
        # Dummy labels: every sample belongs to class 0.
        self.labels = torch.LongTensor(np.zeros([self.n_music])).to(dev)

    def __len__(self):
        return self.music_reps.shape[0]

    def __getitem__(self, idx):
        return self.music_reps[idx], self.labels[idx]
class RegressedGroundingDataset(Dataset):
    """Music/cocktail representation pairs for regression-based grounding.

    For every music sample it draws a random cocktail representation sharing
    the same alignment label, yielding (music_rep, cocktail_rep) pairs. The
    train/test split is stratified per label: the first 90% of each class
    goes to train, the rest to test.

    Args:
        split: one of 'train', 'test', 'all'.
        music_reps: 2D numpy array of music representations.
        music_rep_paths: paths of the representation files, used to map the
            alignment-dataset filenames to rows of ``music_reps``.
        cocktail_reps: 2D numpy array of cocktail representations, indexed
            by the cocktail indices stored in the alignment dataset.

    Raises:
        ValueError: if ``split`` is not one of the three accepted values.
    """

    def __init__(self, split, music_reps, music_rep_paths, cocktail_reps):
        dataset = get_alignment_dataset()
        labels = sorted(dataset['cocktail'].keys())
        self.n_labels = len(labels)
        n_music = np.sum([len(dataset['music'][k]) for k in labels])
        all_music_filenames = []
        for k in labels:
            all_music_filenames += dataset['music'][k]
        assert n_music == len(set(all_music_filenames)), "duplicate music filenames across labels"
        all_music_filenames = np.array(all_music_filenames)
        all_cocktails = []
        for k in labels:
            all_cocktails += dataset['cocktail'][k]
        all_cocktails = np.array(all_cocktails)
        # Map every music filename to the row of its representation file.
        indexes = []
        for music_filename in all_music_filenames:
            rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt')
            found = False
            for i, rep_path in enumerate(music_rep_paths):
                if rep_path.endswith(rep_name):
                    indexes.append(i)
                    found = True
                    break
            assert found, f"no representation found for {music_filename}"
        music_reps = music_reps[np.array(indexes)]
        # Recover the class index of each music file.
        music_labels = []
        for music_filename in all_music_filenames:
            for i_k, k in enumerate(labels):
                if music_filename in dataset['music'][k]:
                    music_labels.append(i_k)
                    break
        assert len(music_labels) == len(music_reps)
        music_labels = np.array(music_labels)
        self.n_music, self.dim_music = music_reps.shape
        self.classes = labels
        cocktail_reps = cocktail_reps[all_cocktails]
        # Recover the class index of each cocktail.
        cocktail_labels = []
        for i in all_cocktails:
            for i_k, k in enumerate(labels):
                if i in dataset['cocktail'][k]:
                    cocktail_labels.append(i_k)
                    break
        cocktail_labels = np.array(cocktail_labels)
        assert len(cocktail_labels) == len(cocktail_reps)
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape
        # For each music sample, sample a random cocktail with the same label.
        cocktail_reps_matching_music_reps = []
        for l in music_labels:
            ind_cocktails = np.where(cocktail_labels == l)[0]
            cocktail_reps_matching_music_reps.append(cocktail_reps[np.random.choice(ind_cocktails)])
        cocktail_reps_matching_music_reps = np.array(cocktail_reps_matching_music_reps)
        # Stratified split: 90% of each class to train, the remainder to test.
        indexes_train = []
        indexes_test = []
        for i_k, k in enumerate(labels):
            indexes_k = np.argwhere(music_labels == i_k).flatten()
            n_train = int(0.9 * len(indexes_k))
            indexes_train += list(indexes_k[:n_train])
            indexes_test += list(indexes_k[n_train:])
        indexes_train = np.array(indexes_train)
        indexes_test = np.array(indexes_test)
        assert len(set(indexes_train) & set(indexes_test)) == 0
        if split == 'train':
            self.music_reps = music_reps[indexes_train].copy()
            self.cocktail_reps = cocktail_reps_matching_music_reps[indexes_train].copy()
        elif split == 'test':
            self.music_reps = music_reps[indexes_test].copy()
            self.cocktail_reps = cocktail_reps_matching_music_reps[indexes_test].copy()
        elif split == 'all':
            self.music_reps = music_reps.copy()
            self.cocktail_reps = cocktail_reps_matching_music_reps.copy()
        else:
            raise ValueError(f"unknown split: {split!r}")
        self.n_music = self.music_reps.shape[0]
        # Shuffle pairs together so classes are interleaved.
        indexes = np.arange(self.n_music)
        np.random.shuffle(indexes)
        self.music_reps = torch.FloatTensor(self.music_reps[indexes]).to(device)
        self.cocktail_reps = torch.FloatTensor(self.cocktail_reps[indexes]).to(device)

    def __len__(self):
        return self.music_reps.shape[0]

    def __getitem__(self, idx):
        return self.music_reps[idx], self.cocktail_reps[idx]
class MusicLabeledDataset(Dataset):
    """Music representations labeled by their alignment category.

    Labels come from ``get_alignment_dataset()['music']``: each key is a
    class and maps to a list of music filenames. Filenames are matched to
    rows of ``music_reps`` through ``music_rep_paths``. The train/test split
    is stratified: the first 90% of each class goes to train, the rest to
    test.

    Args:
        split: one of 'train', 'test', 'all'.
        music_reps: 2D numpy array of music representations.
        music_rep_paths: paths of the representation files, used to map the
            alignment-dataset filenames to rows of ``music_reps``.

    Raises:
        ValueError: if ``split`` is not one of the three accepted values.
    """

    def __init__(self, split, music_reps, music_rep_paths):
        dataset = get_alignment_dataset()
        labels = sorted(dataset['cocktail'].keys())
        self.n_labels = len(labels)
        n_music = np.sum([len(dataset['music'][k]) for k in labels])
        all_music_filenames = []
        for k in labels:
            all_music_filenames += dataset['music'][k]
        assert n_music == len(set(all_music_filenames)), "duplicate music filenames across labels"
        all_music_filenames = np.array(all_music_filenames)
        # Map every music filename to the row of its representation file.
        indexes = []
        for music_filename in all_music_filenames:
            rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt')
            found = False
            for i, rep_path in enumerate(music_rep_paths):
                if rep_path.endswith(rep_name):
                    indexes.append(i)
                    found = True
                    break
            assert found, f"no representation found for {music_filename}"
        music_reps = music_reps[np.array(indexes)]
        # Recover the class index of each music file.
        music_labels = []
        for music_filename in all_music_filenames:
            for i_k, k in enumerate(labels):
                if music_filename in dataset['music'][k]:
                    music_labels.append(i_k)
                    break
        assert len(music_labels) == len(music_reps)
        music_labels = np.array(music_labels)
        self.n_music, self.dim_music = music_reps.shape
        self.classes = labels
        # Stratified split: 90% of each class to train, the remainder to test.
        indexes_train = []
        indexes_test = []
        for i_k, k in enumerate(labels):
            indexes_k = np.argwhere(music_labels == i_k).flatten()
            n_train = int(0.9 * len(indexes_k))
            indexes_train += list(indexes_k[:n_train])
            indexes_test += list(indexes_k[n_train:])
        indexes_train = np.array(indexes_train)
        indexes_test = np.array(indexes_test)
        assert len(set(indexes_train) & set(indexes_test)) == 0
        if split == 'train':
            self.music_reps = music_reps[indexes_train].copy()
            self.labels = music_labels[indexes_train].copy()
        elif split == 'test':
            self.music_reps = music_reps[indexes_test].copy()
            self.labels = music_labels[indexes_test].copy()
        elif split == 'all':
            self.music_reps = music_reps.copy()
            self.labels = music_labels.copy()
        else:
            raise ValueError(f"unknown split: {split!r}")
        self.n_music = self.music_reps.shape[0]
        # Shuffle samples so classes are interleaved.
        indexes = np.arange(self.n_music)
        np.random.shuffle(indexes)
        self.music_reps = torch.FloatTensor(self.music_reps[indexes]).to(device)
        self.labels = torch.LongTensor(self.labels[indexes]).to(device)

    def __len__(self):
        return self.music_reps.shape[0]

    def __getitem__(self, idx):
        return self.music_reps[idx], self.labels[idx]