Spaces:
Runtime error
Runtime error
import numpy as np | |
from sklearn.cluster import KMeans | |
from sklearn.neighbors import NearestNeighbors | |
from sklearn.manifold import TSNE | |
from src.music.utils import get_all_subfiles_with_extension | |
import matplotlib.pyplot as plt | |
import pickle | |
import random | |
# import umap | |
import os | |
from shutil import copy | |
# Dependency note: enabling UMAP (`import umap`, disabled) requires numba==0.51.2.
#
# Alternative (inactive) configuration:
#   keyword = '32_represented'
#   rep_path = f"/home/cedric/Documents/pianocktail/data/music/{keyword}/"
#   plot_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/plots/'
#   neighbors_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/neighbors/'

# Name of the directory that holds the representation files. It doubles as the
# split token used to parse playlist names out of representation file paths.
keyword = 'b256_r128_represented'

# Input directory of representation .txt files.
rep_path = f"/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/{keyword}/"

# Output directories (all end with '/' because file names are appended by concatenation).
plot_path = '/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/plots/'
neighbors_path = f'/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/neighbors_{keyword}/'
interpolation_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/interpolation/'

# Created at import time: sample_nn() copies MIDI files straight into this directory.
os.makedirs(neighbors_path, exist_ok=True)
def extract_all_reps(rep_path):
    """Collect the representation ``.txt`` files under ``rep_path`` and pickle them.

    Files whose path contains ``'mean_std'`` are skipped. Every other file is
    loaded with ``np.loadtxt`` and must contain exactly 128 values.

    Files written into ``rep_path``:
      * ``music_reps_unnormalized.pickle``: ``dict(reps=<(n, 128) array>,
        paths=<list of str>)``.
      * ``all_reps_unnormalized_sample{k}.pickle``: written for each k in
        (100, 200, 500, 1000, 2000, 5000) that is strictly smaller than n. Each
        holds a random subset whose ``paths`` (numpy array) are aligned
        row-for-row with ``reps``.

    Args:
        rep_path: directory searched via ``get_all_subfiles_with_extension``
            (called with max_depth=3). It must end with a path separator,
            because output names are appended by string concatenation.

    Raises:
        ValueError: if a representation file does not hold exactly 128 values.
    """
    candidate_paths = get_all_subfiles_with_extension(rep_path, max_depth=3, extension='.txt', current_depth=0)
    all_data = []
    kept_paths = []
    for r in candidate_paths:
        if 'mean_std' in r:
            continue  # not a per-track representation file
        rep = np.loadtxt(r)
        # Explicit check (an assert would be stripped under `python -O`).
        if rep.shape != (128,):
            raise ValueError(f'{r}: expected 128 values, got array of shape {rep.shape}')
        all_data.append(rep)
        kept_paths.append(r)
    # Keep a 2-D shape even when nothing was found so `reps.shape` still unpacks downstream.
    data = np.array(all_data) if all_data else np.empty((0, 128))

    with open(rep_path + 'music_reps_unnormalized.pickle', 'wb') as f:
        pickle.dump(dict(reps=data, paths=kept_paths), f)

    # Index the *filtered* path list: indexing the unfiltered one would misalign
    # paths and rows whenever 'mean_std' files were skipped.
    kept_paths_arr = np.array(kept_paths)
    for sample_size in [100, 200, 500, 1000, 2000, 5000]:
        if sample_size < len(data):
            inds = np.arange(len(data))
            np.random.shuffle(inds)
            chosen = inds[:sample_size]
            with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'wb') as f:
                pickle.dump(dict(reps=data[chosen], paths=kept_paths_arr[chosen]), f)
def load_reps(rep_path, sample_size=None):
    """Load a representation pickle produced by ``extract_all_reps``.

    Args:
        rep_path: directory containing the pickles (must end with a separator).
        sample_size: when truthy, read ``all_reps_unnormalized_sample{sample_size}.pickle``;
            otherwise read the full ``music_reps_unnormalized.pickle``.

    Returns:
        Tuple ``(reps, paths, playlists, n_data, dim_data)``:
          * ``playlists[i]`` is the first path component following the first
            occurrence of the module-level ``keyword`` in ``paths[i]``;
          * ``(n_data, dim_data) == reps.shape``.

    Note:
        ``pickle.load`` can execute arbitrary code, so only load files this script wrote.
    """
    if sample_size:
        file_name = f'all_reps_unnormalized_sample{sample_size}.pickle'
    else:
        file_name = 'music_reps_unnormalized.pickle'
    with open(rep_path + file_name, 'rb') as handle:
        loaded = pickle.load(handle)

    reps = loaded['reps']
    paths = loaded['paths']
    # Alternative layout once used: r.split(f'_{keyword}')[0].split('/')[-1]
    playlists = []
    for p in paths:
        tail = p.split(f'{keyword}')[1]  # e.g. '/<playlist>/<file>.txt'
        playlists.append(tail.split('/')[1])

    n_data, dim_data = reps.shape
    return reps, paths, playlists, n_data, dim_data
def plot_tsne(reps, playlist_indexes, playlist_colors):
    """Embed ``reps`` in 2-D with t-SNE and save a scatter plot coloured by playlist.

    Args:
        reps: array of representations, one row per track.
        playlist_indexes: mapping playlist name -> integer array of row indices into ``reps``.
        playlist_colors: mapping playlist name -> matplotlib colour.

    Returns:
        The t-SNE embedding, shape ``(len(reps), 2)``. Only this single array is
        returned (the UMAP embedding is disabled), so callers must not unpack two values.

    Side effects:
        Writes ``plot_path + f'tsne_{keyword}.png'``.
    """
    tsne_reps = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(reps)

    fig, ax = plt.subplots()
    # Every playlist is plotted, in name order, so legend order is stable across runs.
    for k in sorted(playlist_indexes):
        idx = playlist_indexes[k]
        ax.scatter(tsne_reps[idx, 0], tsne_reps[idx, 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
    ax.legend()
    fig.savefig(plot_path + f'tsne_{keyword}.png')
    plt.close(fig)

    # UMAP variant is disabled (needs `import umap`). If re-enabled, plot
    # umap_reps[idx, 0] against umap_reps[idx, 1]: the previous commented-out code
    # paired UMAP x-coordinates with t-SNE y-coordinates.
    return tsne_reps
def get_playlist_indexes(playlists):
    """Group row positions by playlist name and give each playlist a random colour.

    Args:
        playlists: sequence of playlist names, one per representation row.

    Returns:
        set_playlists: sorted list of the distinct playlist names.
        playlist_indexes: dict name -> np.ndarray of the positions where the name
            occurs (increasing order). Keys are in first-occurrence order.
        playlist_colors: dict name -> random ``'#RRGGBB'`` string. The generator is
            not seeded, so colours differ between runs.
    """
    # Take the length from `playlists` itself rather than from a module-level
    # n_data, which only exists when the file runs as a script.
    grouped = {}
    for i, name in enumerate(playlists):
        grouped.setdefault(name, []).append(i)
    playlist_indexes = {name: np.array(positions) for name, positions in grouped.items()}

    set_playlists = sorted(set(playlists))
    playlist_colors = dict(zip(set_playlists, ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(len(set_playlists))]))
    return set_playlists, playlist_indexes, playlist_colors
def convert_rep_path_midi_path(rep_path):
    """Return the processed MIDI file that corresponds to a representation file.

    The playlist and the file name are parsed from ``rep_path``:
      * the playlist is the first path component after the first occurrence of
        the module-level ``keyword``, with any ``'_represented'`` removed;
      * the file name is the next component, truncated before
        ``'_represented.txt'``, with ``'_processed.mid'`` appended.

    The result lives under the hard-coded
    ``.../dataset_representation/processed/<playlist>_processed/`` directory.

    Raises:
        FileNotFoundError: if the derived MIDI file does not exist. This is an
            explicit raise, so it still fires under ``python -O``.
    """
    # Alternative layout once used:
    #   playlist = rep_path.split(f'_{keyword}/')[0].split('/')[-1]
    #   filename = rep_path.split(f'_{keyword}/')[-1].split(f'_{keyword}')[0] + '_processed.mid'
    parts = rep_path.split(f'{keyword}')[1].split('/')  # ['', playlist_dir, file_name, ...]
    playlist = parts[1].replace('_represented', '')
    filename = parts[2].split('_represented.txt')[0] + '_processed.mid'
    midi_path = ("/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/processed/"
                 + playlist + '_processed/' + filename)
    if not os.path.exists(midi_path):
        raise FileNotFoundError(midi_path)
    return midi_path
def sample_nn(reps, rep_paths, playlists, n_samples=30, n_neighbors=5):
    """Copy the MIDI files of random tracks and of their nearest neighbours.

    Picks up to ``n_samples`` random rows of ``reps``. For each one, finds its
    ``n_neighbors`` closest *other* rows under cosine distance. The MIDI files are
    copied into ``neighbors_path`` as:
      * ``sample_{i}_playlist_{playlist}_target.mid`` for the query;
      * ``sample_{i}_playlist_{playlist}_neighbor_{j}.mid`` for each neighbour.
    Source files are resolved with ``convert_rep_path_midi_path``.

    Args:
        reps: 2-D numpy array, one representation per row.
        rep_paths: representation file paths aligned with ``reps``.
        playlists: playlist names aligned with ``reps``.
        n_samples: number of query tracks (capped at ``len(reps)``).
        n_neighbors: neighbours exported per query (default 5).
    """
    # Request one extra neighbour because the query row is part of the fitted set.
    # Clamp to the dataset size so small datasets do not make kneighbors fail.
    nn_model = NearestNeighbors(n_neighbors=min(n_neighbors + 1, len(reps)), metric='cosine')
    nn_model.fit(reps)
    indexes = np.arange(len(reps))
    np.random.shuffle(indexes)
    targets = indexes[:n_samples]
    if len(targets) == 0:
        return
    # One batched query instead of one kneighbors call per target.
    _, neighbor_rows = nn_model.kneighbors(reps[targets])
    for i, (ind, row) in enumerate(zip(targets, neighbor_rows)):
        # Drop the query by identity rather than by position: with exact duplicates
        # (distance 0 ties) the query is not guaranteed to be returned first.
        neighbors = [n for n in row if n != ind][:n_neighbors]
        midi_path = convert_rep_path_midi_path(rep_paths[ind])
        copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[ind]}_target.mid')
        for i_n, neighbor in enumerate(neighbors):
            midi_path = convert_rep_path_midi_path(rep_paths[neighbor])
            copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[neighbor]}_neighbor_{i_n}.mid')
def interpolate(reps, rep_paths, path):
    """Sonify a linear interpolation between two representations.

    ``path`` must contain exactly two bound files; files whose name contains
    ``'interpolation'`` (this function's outputs) are ignored.

    Five points ``alpha * b0 + (1 - alpha) * b1`` are built, for alpha in
    ``linspace(0, 1, 5)``:
      * The end points are the bounds themselves. Their MIDI files are copied to
        ``interpolation_0.mid`` (bound 1) and ``interpolation_1.mid`` (bound 0).
      * Each interior point is mapped to its nearest neighbour in ``reps``, or to
        the second-nearest if the nearest is an exact match. That track's MIDI file
        is copied to ``interpolation_{alpha}.mid``.

    Args:
        reps: 2-D array of dataset representations.
        rep_paths: representation file paths aligned with ``reps``.
        path: directory holding the two bound files; must end with a separator.

    Raises:
        ValueError: if ``path`` does not contain exactly two bound files.
    """
    # Sorted so the choice of bound 0 / bound 1 does not depend on os.listdir order.
    bounds = sorted(f for f in os.listdir(path) if 'interpolation' not in f)
    if len(bounds) != 2:
        raise ValueError(f'expected exactly 2 bound files in {path}, found {bounds}')
    b_reps = [np.loadtxt(path + f) for f in bounds]

    nn_model = NearestNeighbors(n_neighbors=6)
    nn_model.fit(reps)

    alphas = np.linspace(0, 1., 5)
    # Separate name so the dataset `reps` is not shadowed.
    interp_reps = [alpha * b_reps[0] + (1 - alpha) * b_reps[1] for alpha in alphas]

    copy(convert_rep_path_midi_path(path + bounds[1]), path + 'interpolation_0.mid')
    copy(convert_rep_path_midi_path(path + bounds[0]), path + 'interpolation_1.mid')
    for alpha, rep in zip(alphas[1:-1], interp_reps[1:-1]):
        dists, indexes = nn_model.kneighbors(rep.reshape(1, -1))
        # If the interpolated point coincides exactly with a dataset row, use the next one.
        nn = indexes.flatten()[1] if dists.flatten()[0] == 0 else indexes.flatten()[0]
        midi_path = convert_rep_path_midi_path(rep_paths[nn])
        copy(midi_path, path + f'interpolation_{alpha}.mid')
if __name__ == '__main__':
    extract_all_reps(rep_path)
    reps, rep_paths, playlists, n_data, dim_data = load_reps(rep_path)
    set_playlists, playlist_indexes, playlist_colors = get_playlist_indexes(playlists)
    # interpolate(reps, rep_paths, interpolation_path + 'trial_1/')
    sample_nn(reps, rep_paths, playlists)
    # plot_tsne returns only the t-SNE embedding (UMAP is disabled). Unpacking two
    # values from that (n, 2) array raised ValueError whenever n != 2.
    tsne_reps = plot_tsne(reps, playlist_indexes, playlist_colors)