import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE
from src.music.utils import get_all_subfiles_with_extension
import matplotlib.pyplot as plt
import pickle
import random
# import umap
import os
from shutil import copy
# requires numba==0.51.2
# keyword = '32_represented'
# rep_path = f"/home/cedric/Documents/pianocktail/data/music/{keyword}/"
# plot_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/plots/'
# neighbors_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/neighbors/'
interpolation_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/interpolation/'
keyword = 'b256_r128_represented'
rep_path = f"/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/{keyword}/"
plot_path = '/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/plots/'
neighbors_path = f'/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/neighbors_{keyword}/'
os.makedirs(neighbors_path, exist_ok=True)
def extract_all_reps(rep_path):
    # Collect every representation file (skipping the mean/std statistics files),
    # stack them into one array and save it to a pickle, plus random subsamples.
    all_rep_path = get_all_subfiles_with_extension(rep_path, max_depth=3, extension='.txt', current_depth=0)
    all_data = []
    new_all_rep_path = []
    for r in all_rep_path:
        if 'mean_std' not in r:
            all_data.append(np.loadtxt(r))
            assert len(all_data[-1]) == 128
            new_all_rep_path.append(r)
    data = np.array(all_data)
    to_save = dict(reps=data,
                   paths=new_all_rep_path)
    with open(rep_path + 'music_reps_unnormalized.pickle', 'wb') as f:
        pickle.dump(to_save, f)

    for sample_size in [100, 200, 500, 1000, 2000, 5000]:
        if sample_size < len(data):
            inds = np.arange(len(data))
            np.random.shuffle(inds)
            # use the filtered path list so reps and paths stay aligned
            to_save = dict(reps=data[inds[:sample_size]],
                           paths=np.array(new_all_rep_path)[inds[:sample_size]])
            with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'wb') as f:
                pickle.dump(to_save, f)

def load_reps(rep_path, sample_size=None):
    # Load either the full set of representations or a pre-computed random subsample.
    if sample_size:
        with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
            data = pickle.load(f)
    else:
        with open(rep_path + 'music_reps_unnormalized.pickle', 'rb') as f:
            data = pickle.load(f)
    reps = data['reps']
    # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
    # playlist name = directory just below the keyword folder
    playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
    n_data, dim_data = reps.shape
    return reps, data['paths'], playlists, n_data, dim_data
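
# Hedged sketch (optional step, not used by the pipeline below): the pickles above are
# explicitly 'unnormalized', so a simple per-dimension z-scoring of the loaded
# representations could look like this. The eps guard against zero-variance dimensions
# is an assumption, not something taken from the original script.
def normalize_reps(reps, eps=1e-8):
    mean = reps.mean(axis=0)
    std = reps.std(axis=0)
    return (reps - mean) / (std + eps)
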
def plot_tsne(reps, playlist_indexes, playlist_colors):
    # Project the representations to 2D with t-SNE and plot one scatter group per playlist.
    tsne_reps = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(reps)
    plt.figure()
    # keys_to_print = ['spot_piano_solo_blues', 'itsremco', 'piano_solo_classical',
    #                  'piano_solo_pop', 'piano_jazz_unspecified', 'spot_piano_solo_jazz_1', 'piano_solo_jazz_latin']
    keys_to_print = playlist_indexes.keys()  # plot all playlists
    for k in sorted(keys_to_print):
        if k in playlist_indexes.keys():
            # plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, label=k, alpha=0.5)
            plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
    plt.legend()
    plt.savefig(plot_path + f'tsne_{keyword}.png')
    fig = plt.gcf()
    plt.close(fig)

    # umap_reps = umap.UMAP().fit_transform(reps)
    # plt.figure()
    # for k in sorted(keys_to_print):
    #     if k in playlist_indexes.keys():
    #         plt.scatter(umap_reps[playlist_indexes[k], 0], umap_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
    # plt.legend()
    # plt.savefig(plot_path + f'umap_{keyword}.png')
    # fig = plt.gcf()
    # plt.close(fig)
    return tsne_reps  # , umap_reps
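
# Hedged sketch (not part of the original pipeline): KMeans is imported above but never
# used; this shows one way it could be applied to cluster the representations. The number
# of clusters and the random_state are arbitrary assumptions for illustration.
def cluster_reps(reps, n_clusters=8):
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    labels = kmeans.fit_predict(reps)  # one cluster label per track
    return labels, kmeans.cluster_centers_
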
def get_playlist_indexes(playlists):
    # Group data indexes by playlist and assign a random color to each playlist.
    playlist_indexes = dict()
    for i in range(len(playlists)):  # use the argument length instead of the global n_data
        if playlists[i] not in playlist_indexes.keys():
            playlist_indexes[playlists[i]] = [i]
        else:
            playlist_indexes[playlists[i]].append(i)
    for k in playlist_indexes.keys():
        playlist_indexes[k] = np.array(playlist_indexes[k])
    set_playlists = sorted(set(playlists))
    playlist_colors = dict(zip(set_playlists, ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(len(set_playlists))]))
    return set_playlists, playlist_indexes, playlist_colors

def convert_rep_path_midi_path(rep_path):
    # Map a representation .txt path back to the corresponding processed MIDI file:
    # .../{keyword}/<playlist>_represented/<track>_represented.txt -> .../processed/<playlist>_processed/<track>_processed.mid
    # playlist = rep_path.split(f'_{keyword}/')[0].split('/')[-1]
    playlist = rep_path.split(f'{keyword}')[1].split('/')[1].replace('_represented', '')
    midi_path = "/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/processed/" + playlist + '_processed/'
    filename = rep_path.split(f'{keyword}')[1].split('/')[2].split('_represented.txt')[0] + '_processed.mid'
    # filename = rep_path.split(f'_{keyword}/')[-1].split(f'_{keyword}')[0] + '_processed.mid'
    midi_path = midi_path + filename
    assert os.path.exists(midi_path), midi_path
    return midi_path

def sample_nn(reps, rep_paths, playlists, n_samples=30):
    # For n_samples random tracks, copy the target MIDI file and its 5 nearest
    # neighbours (cosine distance in representation space) to neighbors_path.
    nn_model = NearestNeighbors(n_neighbors=6, metric='cosine')
    nn_model.fit(reps)
    indexes = np.arange(len(reps))
    np.random.shuffle(indexes)
    for i, ind in enumerate(indexes[:n_samples]):
        out = nn_model.kneighbors(reps[ind].reshape(1, -1))[1][0][1:]  # drop the query itself
        midi_path = convert_rep_path_midi_path(rep_paths[ind])
        copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[ind]}_target.mid')
        for i_n, neighbor in enumerate(out):
            midi_path = convert_rep_path_midi_path(rep_paths[neighbor])
            copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[neighbor]}_neighbor_{i_n}.mid')

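
# Hedged diagnostic sketch (not called anywhere in the original script): report the cosine
# distances of the nearest neighbours of a single representation, using the same metric as
# sample_nn. The function name and the query_index parameter are hypothetical additions.
def print_neighbor_distances(reps, query_index=0, n_neighbors=5):
    nn_model = NearestNeighbors(n_neighbors=n_neighbors + 1, metric='cosine')
    nn_model.fit(reps)
    dists, indexes = nn_model.kneighbors(reps[query_index].reshape(1, -1))
    # Skip the first entry: it is the query itself at distance 0.
    for d, i in zip(dists.flatten()[1:], indexes.flatten()[1:]):
        print(f'neighbor {i}: cosine distance {d:.3f}')
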
def interpolate(reps, rep_paths, path):
    # Linearly interpolate between two representation files found in `path` and, for each
    # intermediate point, copy the MIDI of its nearest neighbour in the dataset.
    files = os.listdir(path)
    bounds = [f for f in files if 'interpolation' not in f]
    b_reps = [np.loadtxt(path + f) for f in bounds]
    nn_model = NearestNeighbors(n_neighbors=6)
    nn_model.fit(reps)
    interp_reps = [alpha * b_reps[0] + (1 - alpha) * b_reps[1] for alpha in np.linspace(0, 1., 5)]
    copy(convert_rep_path_midi_path(path + bounds[1]), path + 'interpolation_0.mid')  # alpha = 0 corresponds to bounds[1]
    copy(convert_rep_path_midi_path(path + bounds[0]), path + 'interpolation_1.mid')  # alpha = 1 corresponds to bounds[0]
    for alpha, rep in zip(np.linspace(0, 1, 5)[1:-1], interp_reps[1:-1]):
        dists, indexes = nn_model.kneighbors(rep.reshape(1, -1))
        # Skip the first neighbour if it is the query point itself (distance 0).
        if dists.flatten()[0] == 0:
            nn = indexes.flatten()[1]
        else:
            nn = indexes.flatten()[0]
        midi_path = convert_rep_path_midi_path(rep_paths[nn])
        copy(midi_path, path + f'interpolation_{alpha}.mid')

if __name__ == '__main__':
    extract_all_reps(rep_path)
    reps, rep_paths, playlists, n_data, dim_data = load_reps(rep_path)
    set_playlists, playlist_indexes, playlist_colors = get_playlist_indexes(playlists)
    # interpolate(reps, rep_paths, interpolation_path + 'trial_1/')
    sample_nn(reps, rep_paths, playlists)
    tsne_reps = plot_tsne(reps, playlist_indexes, playlist_colors)  # plot_tsne returns only the t-SNE embedding (UMAP is commented out)