"""Torch datasets over precomputed cocktail and music representations.

Every dataset converts its (NumPy) representations to tensors, shuffles them
once at construction time with NumPy's global RNG and stores them on
``device``.
"""
from torch.utils.data import Dataset
import numpy as np
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
from src.music2cocktailrep.analysis.explore import get_alignment_dataset

# Fraction of the items (overall, or per class for labeled datasets) that goes
# to the 'train' split; the remainder forms the 'test' split.
_TRAIN_FRACTION = 0.9


def _unknown_split_error(split):
    """Build the ValueError raised when `split` is not a supported split name."""
    return ValueError(f"Unknown split {split!r}; expected 'train', 'test' or 'all'.")


def _shuffled_indexes(n_items):
    """Return a random permutation of ``range(n_items)`` from NumPy's global RNG."""
    indexes = np.arange(n_items)
    np.random.shuffle(indexes)
    return indexes


def _contiguous_split(array, split):
    """Return a copy of the leading 90% ('train'), trailing 10% ('test') or all rows."""
    n_train = int(_TRAIN_FRACTION * array.shape[0])
    if split == 'train':
        return array[:n_train].copy()
    if split == 'test':
        return array[n_train:].copy()
    if split == 'all':
        return array.copy()
    raise _unknown_split_error(split)


def _stratified_split_indexes(item_labels, n_classes):
    """Split item positions 90/10 inside each class; return (train, test) index arrays."""
    train, test = [], []
    for class_id in range(n_classes):
        members = np.flatnonzero(item_labels == class_id)
        n_train = int(_TRAIN_FRACTION * len(members))
        train.extend(members[:n_train])
        test.extend(members[n_train:])
    # An explicit integer dtype keeps empty index arrays usable for indexing
    # (np.array([]) would be float and is rejected as an index).
    train = np.array(train, dtype=np.intp)
    test = np.array(test, dtype=np.intp)
    assert np.intersect1d(train, test).size == 0, 'train and test splits overlap'
    return train, test


def _select_split(split, train_indexes, test_indexes, *arrays):
    """Return copies of `arrays` restricted to the rows belonging to `split`."""
    if split == 'train':
        return tuple(array[train_indexes].copy() for array in arrays)
    if split == 'test':
        return tuple(array[test_indexes].copy() for array in arrays)
    if split == 'all':
        return tuple(array.copy() for array in arrays)
    raise _unknown_split_error(split)


def _first_group_labels(items, groups):
    """Label every item with the position of the first group that contains it."""
    first_group = {}
    for group_id, members in enumerate(groups):
        for member in members:
            first_group.setdefault(member, group_id)
    return np.array([first_group[item] for item in items])


def _labeled_cocktails(dataset, labels, cocktail_reps):
    """Gather the cocktails listed in the alignment dataset with their class ids.

    Returns the rows of `cocktail_reps` indexed by ``dataset['cocktail'][k]``
    (concatenated over `labels` in order) and the matching array of class ids.
    """
    groups = [dataset['cocktail'][k] for k in labels]
    all_cocktails = [cocktail for group in groups for cocktail in group]
    cocktail_labels = _first_group_labels(all_cocktails, groups)
    cocktail_reps = cocktail_reps[np.array(all_cocktails)]
    assert len(cocktail_labels) == len(cocktail_reps)
    return cocktail_reps, cocktail_labels


def _labeled_music(dataset, labels, music_reps, music_rep_paths):
    """Gather the music pieces listed in the alignment dataset with their class ids.

    Each filename in ``dataset['music'][k]`` is mapped to the first entry of
    `music_rep_paths` ending with the corresponding representation filename;
    the row of `music_reps` at that position is selected.
    """
    groups = [dataset['music'][k] for k in labels]
    all_music_filenames = [filename for group in groups for filename in group]
    assert len(all_music_filenames) == len(set(all_music_filenames)), \
        'a music file is listed more than once in the alignment dataset'

    indexes = []
    for music_filename in all_music_filenames:
        rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt')
        match = next(
            (i for i, rep_path in enumerate(music_rep_paths) if rep_path.endswith(rep_name)),
            None,
        )
        assert match is not None, f'no music representation found for {music_filename!r}'
        indexes.append(match)
    music_reps = music_reps[np.array(indexes, dtype=np.intp)]

    music_labels = _first_group_labels(all_music_filenames, groups)
    assert len(music_labels) == len(music_reps)
    return music_reps, music_labels


# Add your custom dataset class here
class CocktailDataset(Dataset):
    """Unlabeled cocktail representations, oversampling egg and bubble cocktails.

    Args:
        split: 'train' (first 90% of rows), 'test' (last 10%) or 'all'.
        cocktail_reps: 2-D NumPy array, one row per cocktail. Rows whose last
            column is positive are treated as containing egg and rows whose
            third-from-last column is positive as containing bubbles.

    Raises:
        ValueError: if `split` is not one of the supported names.
    """

    def __init__(self, split, cocktail_reps):
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape
        reps = torch.FloatTensor(_contiguous_split(cocktail_reps, split))

        # oversample cocktails with eggs and bubbles
        # This is done on the CPU, before the move to `device`: the previous
        # implementation handed a device tensor to np.argwhere, which cannot
        # consume CUDA tensors.
        n_copies = 4
        egg_rows = reps[reps[:, -1] > 0]
        bubble_rows = reps[reps[:, -3] > 0]
        egg_copies = torch.tile(egg_rows, dims=(n_copies * 3, 1))
        bubbles_copies = torch.tile(bubble_rows, dims=(n_copies, 1))
        reps = torch.cat([reps, egg_copies, bubbles_copies], dim=0)

        self.n_cocktails = reps.shape[0]
        order = torch.as_tensor(_shuffled_indexes(self.n_cocktails), dtype=torch.long)
        self.cocktail_reps = reps[order].to(device)
        # There is a single (dummy) class, so every label is 0.
        self.labels = torch.zeros(self.n_cocktails, dtype=torch.long, device=device)
        self.contains_egg = self.cocktail_reps[:, -1] > 0
        self.contains_bubbles = self.cocktail_reps[:, -3] > 0

    def __len__(self):
        """Return the number of cocktails, oversampled copies included."""
        return self.cocktail_reps.shape[0]

    def __getitem__(self, idx):
        """Return (representation, label, contains_egg, contains_bubbles) at `idx`."""
        return self.cocktail_reps[idx], self.labels[idx], self.contains_egg[idx], self.contains_bubbles[idx]


class CocktailLabeledDataset(Dataset):
    """Cocktail representations labeled with their alignment-dataset class.

    Classes are the sorted keys of ``get_alignment_dataset()['cocktail']``;
    the train/test split is 90/10 within each class.

    Args:
        split: 'train', 'test' or 'all'.
        cocktail_reps: 2-D NumPy array indexed by the cocktail ids stored in
            the alignment dataset.

    Raises:
        ValueError: if `split` is not one of the supported names.
    """

    def __init__(self, split, cocktail_reps):
        dataset = get_alignment_dataset()
        labels = sorted(dataset['cocktail'].keys())
        self.n_labels = len(labels)

        cocktail_reps, cocktail_labels = _labeled_cocktails(dataset, labels, cocktail_reps)
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape

        indexes_train, indexes_test = _stratified_split_indexes(cocktail_labels, len(labels))
        reps, item_labels = _select_split(
            split, indexes_train, indexes_test, cocktail_reps, cocktail_labels)

        self.n_cocktails = reps.shape[0]
        order = _shuffled_indexes(self.n_cocktails)
        self.cocktail_reps = torch.FloatTensor(reps[order]).to(device)
        self.labels = torch.LongTensor(item_labels[order]).to(device)

    def __len__(self):
        """Return the number of cocktails in the selected split."""
        return self.cocktail_reps.shape[0]

    def __getitem__(self, idx):
        """Return (representation, class id) at `idx`."""
        return self.cocktail_reps[idx], self.labels[idx]


class MusicDataset(Dataset):
    """Unlabeled music representations.

    Args:
        split: 'train' (first 90% of rows), 'test' (last 10%) or 'all'.
        music_reps: 2-D NumPy array, one row per music piece.
        music_rep_paths: accepted for signature parity with the labeled music
            datasets; not used here.

    Raises:
        ValueError: if `split` is not one of the supported names.
    """

    def __init__(self, split, music_reps, music_rep_paths):
        self.n_music, self.dim_music = music_reps.shape
        reps = _contiguous_split(music_reps, split)

        self.n_music = reps.shape[0]
        order = _shuffled_indexes(self.n_music)
        self.music_reps = torch.FloatTensor(reps[order]).to(device)
        # There is a single (dummy) class, so every label is 0.
        self.labels = torch.zeros(self.n_music, dtype=torch.long, device=device)

    def __len__(self):
        """Return the number of music pieces in the selected split."""
        return self.music_reps.shape[0]

    def __getitem__(self, idx):
        """Return (representation, label) at `idx`."""
        return self.music_reps[idx], self.labels[idx]


class RegressedGroundingDataset(Dataset):
    """Pairs of (music representation, cocktail representation) of the same class.

    Every music piece listed in the alignment dataset is paired once, at
    construction time, with a cocktail drawn uniformly at random among the
    cocktails of its class. The train/test split is 90/10 within each class of
    music pieces.

    Args:
        split: 'train', 'test' or 'all'.
        music_reps: 2-D NumPy array of music representations.
        music_rep_paths: paths of the representation files, aligned with the
            rows of `music_reps`.
        cocktail_reps: 2-D NumPy array indexed by the cocktail ids stored in
            the alignment dataset.

    Raises:
        ValueError: if `split` is not one of the supported names.
    """

    def __init__(self, split, music_reps, music_rep_paths, cocktail_reps):
        dataset = get_alignment_dataset()
        labels = sorted(dataset['cocktail'].keys())
        self.n_labels = len(labels)

        music_reps, music_labels = _labeled_music(dataset, labels, music_reps, music_rep_paths)
        self.n_music, self.dim_music = music_reps.shape
        self.classes = labels

        cocktail_reps, cocktail_labels = _labeled_cocktails(dataset, labels, cocktail_reps)
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape

        # One random same-class cocktail per music piece, drawn in music order
        # so results stay reproducible under a fixed NumPy seed.
        partners = [
            np.random.choice(np.where(cocktail_labels == music_label)[0])
            for music_label in music_labels
        ]
        cocktail_reps_matching_music_reps = cocktail_reps[np.array(partners, dtype=np.intp)]

        indexes_train, indexes_test = _stratified_split_indexes(music_labels, len(labels))
        music_split, cocktail_split = _select_split(
            split, indexes_train, indexes_test, music_reps, cocktail_reps_matching_music_reps)

        self.n_music = music_split.shape[0]
        order = _shuffled_indexes(self.n_music)
        self.music_reps = torch.FloatTensor(music_split[order]).to(device)
        self.cocktail_reps = torch.FloatTensor(cocktail_split[order]).to(device)

    def __len__(self):
        """Return the number of (music, cocktail) pairs in the selected split."""
        return self.music_reps.shape[0]

    def __getitem__(self, idx):
        """Return (music representation, paired cocktail representation) at `idx`."""
        return self.music_reps[idx], self.cocktail_reps[idx]


class MusicLabeledDataset(Dataset):
    """Music representations labeled with their alignment-dataset class.

    Classes are the sorted keys of ``get_alignment_dataset()['cocktail']``;
    the train/test split is 90/10 within each class.

    Args:
        split: 'train', 'test' or 'all'.
        music_reps: 2-D NumPy array of music representations.
        music_rep_paths: paths of the representation files, aligned with the
            rows of `music_reps`.

    Raises:
        ValueError: if `split` is not one of the supported names.
    """

    def __init__(self, split, music_reps, music_rep_paths):
        dataset = get_alignment_dataset()
        labels = sorted(dataset['cocktail'].keys())
        self.n_labels = len(labels)

        music_reps, music_labels = _labeled_music(dataset, labels, music_reps, music_rep_paths)
        self.n_music, self.dim_music = music_reps.shape
        self.classes = labels

        indexes_train, indexes_test = _stratified_split_indexes(music_labels, len(labels))
        reps, item_labels = _select_split(
            split, indexes_train, indexes_test, music_reps, music_labels)

        self.n_music = reps.shape[0]
        order = _shuffled_indexes(self.n_music)
        self.music_reps = torch.FloatTensor(reps[order]).to(device)
        self.labels = torch.LongTensor(item_labels[order]).to(device)

    def __len__(self):
        """Return the number of music pieces in the selected split."""
        return self.music_reps.shape[0]

    def __getitem__(self, idx):
        """Return (representation, class id) at `idx`."""
        return self.music_reps[idx], self.labels[idx]