File size: 7,382 Bytes
e775f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE
from src.music.utils import get_all_subfiles_with_extension
import matplotlib.pyplot as plt
import pickle
import random
# import umap
import os
from shutil import copy
# Dependency note (presumably for the disabled `umap` import above): install numba==0.51.2.
# --- Paths and settings: absolute, machine-specific; edit before running. ---
# Earlier configuration, kept for reference:
# keyword = '32_represented'
# rep_path = f"/home/cedric/Documents/pianocktail/data/music/{keyword}/"
# plot_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/plots/'
# neighbors_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/neighbors/'
# Base folder for interpolate(); in this file it is only referenced by the
# commented-out interpolate(...) call in __main__ (with a trial sub-folder appended).
interpolation_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/interpolation/'
# Name of the representation folder. Besides building rep_path, load_reps and
# convert_rep_path_midi_path split file paths on this token to recover the
# playlist folder and file name that follow it.
keyword = 'b256_r128_represented'
# Root searched for representation .txt files; the aggregated pickles are also
# written here. Paths are joined by string concatenation, hence the trailing '/'.
rep_path = f"/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/{keyword}/"
# Output folder for the t-SNE scatter plot written by plot_tsne.
plot_path = '/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/plots/'
# Output folder for the target/neighbour MIDI copies written by sample_nn.
neighbors_path = f'/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/neighbors_{keyword}/'
# Created at import time so the copies made by sample_nn have a destination.
os.makedirs(neighbors_path, exist_ok=True)
def extract_all_reps(rep_path, rep_dim=128):
    """Collect every representation ``.txt`` file under ``rep_path`` into pickles.

    Files whose path contains ``'mean_std'`` are skipped. Every other ``.txt``
    file (searched up to depth 3) is loaded with ``np.loadtxt`` and must hold
    ``rep_dim`` values.

    Files written into ``rep_path`` (joined by string concatenation, so
    ``rep_path`` should end with '/'):

    * ``music_reps_unnormalized.pickle``: ``dict(reps=array, paths=list)``
      holding every kept representation and its source path.
    * ``all_reps_unnormalized_sample{N}.pickle`` for each N in
      (100, 200, 500, 1000, 2000, 5000) that is smaller than the number of
      kept representations: a random subset, with ``paths`` stored as a
      numpy array aligned row-for-row with ``reps``.

    Args:
        rep_path: Folder to search and in which to write the pickles.
        rep_dim: Expected length of each representation vector.

    Raises:
        ValueError: If a loaded representation does not have ``rep_dim`` values.
    """
    all_rep_path = get_all_subfiles_with_extension(rep_path, max_depth=3, extension='.txt', current_depth=0)
    all_data = []
    new_all_rep_path = []
    for r in all_rep_path:
        # Presumably normalisation statistics stored next to the representations.
        if 'mean_std' in r:
            continue
        rep = np.loadtxt(r)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(rep) != rep_dim:
            raise ValueError(f'{r}: expected {rep_dim} values, got {len(rep)}')
        all_data.append(rep)
        new_all_rep_path.append(r)
    # reshape keeps a 2-D (0, rep_dim) array when nothing was found, so that
    # callers can still unpack `reps.shape` into two values.
    data = np.array(all_data).reshape(-1, rep_dim)
    to_save = dict(reps=data,
                   paths=new_all_rep_path)
    with open(rep_path + 'music_reps_unnormalized.pickle', 'wb') as f:
        pickle.dump(to_save, f)
    # Sample from the *filtered* path list. Indexing the unfiltered
    # `all_rep_path` would misalign paths and reps once any 'mean_std' file
    # has been skipped.
    kept_paths = np.array(new_all_rep_path)
    for sample_size in [100, 200, 500, 1000, 2000, 5000]:
        if sample_size < len(data):
            inds = np.arange(len(data))
            np.random.shuffle(inds)
            chosen = inds[:sample_size]
            to_save = dict(reps=data[chosen],
                           paths=kept_paths[chosen])
            with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'wb') as f:
                pickle.dump(to_save, f)

def load_reps(rep_path, sample_size=None):
    """Load representations pickled by ``extract_all_reps``.

    Only load pickles you produced yourself: ``pickle.load`` can execute
    arbitrary code from a crafted file.

    Args:
        rep_path: Folder holding the pickles (joined by string concatenation,
            so it should end with '/').
        sample_size: If truthy, read ``all_reps_unnormalized_sample{sample_size}.pickle``;
            otherwise read the full ``music_reps_unnormalized.pickle``.

    Returns:
        Tuple ``(reps, paths, playlists, n_data, dim_data)``:

        * ``reps``: the stored representation array.
        * ``paths``: the stored source paths.
        * ``playlists``: for each path, the path component immediately after
          the module-level ``keyword``.
        * ``n_data``, ``dim_data``: the two dimensions of ``reps``.
    """
    if sample_size:
        pickle_name = f'all_reps_unnormalized_sample{sample_size}.pickle'
    else:
        pickle_name = 'music_reps_unnormalized.pickle'
    with open(rep_path + pickle_name, 'rb') as handle:
        loaded = pickle.load(handle)

    reps = loaded['reps']
    paths = loaded['paths']
    playlists = []
    for p in paths:
        # ".../{keyword}/<playlist>/..." -> "<playlist>"
        after_keyword = p.split(f'{keyword}')[1]
        playlists.append(after_keyword.split('/')[1])

    n_data, dim_data = reps.shape
    return reps, paths, playlists, n_data, dim_data


def plot_tsne(reps, playlist_indexes, playlist_colors, keys_to_print=None):
    """Embed ``reps`` in 2-D with t-SNE and save a scatter plot coloured by playlist.

    The figure is written to ``plot_path + f'tsne_{keyword}.png'``, using the
    module-level settings.

    Args:
        reps: 2-D array of representations, one row per sample.
        playlist_indexes: Maps playlist name to an array of row indices into ``reps``.
        playlist_colors: Maps playlist name to a matplotlib colour.
        keys_to_print: Optional iterable of playlist names to draw. Names absent
            from ``playlist_indexes`` are ignored. Defaults to every playlist.
            A curated subset used previously was ``['spot_piano_solo_blues',
            'itsremco', 'piano_solo_classical', 'piano_solo_pop',
            'piano_jazz_unspecified', 'spot_piano_solo_jazz_1',
            'piano_solo_jazz_latin']``.

    Returns:
        The t-SNE embedding of ``reps`` (one 2-D point per row).
    """
    tsne_reps = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(reps)
    if keys_to_print is None:
        keys_to_print = playlist_indexes.keys()
    fig = plt.figure()
    try:
        for k in sorted(keys_to_print):
            if k not in playlist_indexes:
                continue
            idx = playlist_indexes[k]
            plt.scatter(tsne_reps[idx, 0], tsne_reps[idx, 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
        plt.legend()
        plt.savefig(plot_path + f'tsne_{keyword}.png')
    finally:
        # Close via the handle so the figure is released even if saving fails.
        plt.close(fig)
    # A UMAP variant of this plot was previously commented out here; it needs
    # the `umap` package, whose import is disabled at the top of the file.
    return tsne_reps

def get_playlist_indexes(playlists):
    """Group sample positions by playlist and give each playlist a random colour.

    Args:
        playlists: Sequence of playlist names, one per sample.

    Returns:
        Tuple ``(set_playlists, playlist_indexes, playlist_colors)``:

        * ``set_playlists``: sorted list of the distinct playlist names.
        * ``playlist_indexes``: maps each playlist to a numpy array of the
          positions where it occurs in ``playlists``; keys appear in order of
          first occurrence.
        * ``playlist_colors``: maps each playlist to a random ``'#RRGGBB'``
          hex colour string.
    """
    # Enumerate the argument itself instead of relying on a module-level
    # `n_data`, which only exists when this file runs as __main__.
    grouped = {}
    for i, playlist in enumerate(playlists):
        grouped.setdefault(playlist, []).append(i)
    playlist_indexes = {k: np.array(v) for k, v in grouped.items()}
    set_playlists = sorted(grouped)
    playlist_colors = dict(zip(set_playlists, ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(len(set_playlists))]))
    return set_playlists, playlist_indexes, playlist_colors

def convert_rep_path_midi_path(rep_path):
    """Map a representation ``.txt`` path to its processed MIDI file.

    The components following the module-level ``keyword`` in ``rep_path`` are
    read as ``/<playlist dir>/<file name>``. ``'_represented'`` is removed from
    the playlist dir, and the file name is cut at ``'_represented.txt'``. The
    result is ``<processed root>/<playlist>_processed/<name>_processed.mid``.

    Args:
        rep_path: Path of a representation file; must contain ``keyword``
            followed by at least two '/'-separated components.

    Returns:
        The path of the corresponding MIDI file.

    Raises:
        FileNotFoundError: If the derived MIDI file does not exist.
    """
    # Components after the keyword: ['', '<playlist dir>', '<file name>', ...]
    parts = rep_path.split(f'{keyword}')[1].split('/')
    playlist = parts[1].replace('_represented', '')
    filename = parts[2].split('_represented.txt')[0] + '_processed.mid'
    midi_path = ("/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/processed/"
                 + playlist + '_processed/' + filename)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(midi_path):
        raise FileNotFoundError(midi_path)
    return midi_path

def sample_nn(reps, rep_paths, playlists, n_samples=30, n_neighbors=5):
    """Copy the MIDI of random targets and of their nearest neighbours for listening.

    For up to ``n_samples`` randomly chosen rows of ``reps``, finds up to
    ``n_neighbors`` other rows closest in cosine distance. It then copies the
    matching MIDI files into the module-level ``neighbors_path`` as:

    * ``sample_{i}_playlist_{playlist}_target.mid`` for the target, and
    * ``sample_{i}_playlist_{playlist}_neighbor_{j}.mid`` for its neighbours.

    Args:
        reps: 2-D array of representations, one row per sample.
        rep_paths: Representation file path for each row of ``reps``.
        playlists: Playlist name for each row of ``reps``.
        n_samples: Number of random targets.
        n_neighbors: Neighbours per target, capped at ``len(reps) - 1``.
    """
    k = min(n_neighbors, len(reps) - 1)
    # Ask for one extra result so the query itself can be dropped.
    nn_model = NearestNeighbors(n_neighbors=k + 1, metric='cosine')
    nn_model.fit(reps)
    indexes = np.arange(len(reps))
    np.random.shuffle(indexes)
    for i, ind in enumerate(indexes[:n_samples]):
        found = nn_model.kneighbors(reps[ind].reshape(1, -1))[1][0]
        # Remove the query explicitly. With duplicate representations it is not
        # guaranteed to come first, so slicing off position 0 could keep it.
        neighbors = [n for n in found if n != ind][:k]
        midi_path = convert_rep_path_midi_path(rep_paths[ind])
        copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[ind]}_target.mid')
        for i_n, neighbor in enumerate(neighbors):
            midi_path = convert_rep_path_midi_path(rep_paths[neighbor])
            copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[neighbor]}_neighbor_{i_n}.mid')

def interpolate(reps, rep_paths, path, n_steps=5):
    """Walk linearly between two representations and save the closest real pieces.

    ``path`` (joined by string concatenation, so it should end with '/') must
    contain exactly two bound files: entries whose name does not contain
    ``'interpolation'``. Both are loaded with ``np.loadtxt`` in name order as
    ``b0`` and ``b1``. The walk visits ``alpha * b0 + (1 - alpha) * b1`` for
    ``n_steps`` values of ``alpha`` evenly spaced in [0, 1].

    Files written into ``path``:

    * ``interpolation_0.mid``: MIDI of ``b1``.
    * ``interpolation_1.mid``: MIDI of ``b0``.
    * ``interpolation_{alpha}.mid``: for each interior alpha, the MIDI of the
      nearest row of ``reps`` (Euclidean), skipping an exact match.

    NOTE(review): the bound files are resolved via
    ``convert_rep_path_midi_path(path + name)``, which only works if that
    string contains the module-level ``keyword`` followed by
    ``/<playlist>/<file>``. Confirm this against the interpolation folder layout.

    Args:
        reps: 2-D array of representations searched for nearest neighbours.
        rep_paths: Representation file path for each row of ``reps``.
        path: Folder holding the two bound files; outputs are written there too.
        n_steps: Number of interpolation points, including both endpoints.

    Raises:
        ValueError: If ``path`` does not hold exactly two bound files.
    """
    # Sort so which bound is "first" does not depend on os.listdir order.
    bounds = sorted(f for f in os.listdir(path) if 'interpolation' not in f)
    if len(bounds) != 2:
        raise ValueError(f'expected exactly 2 bound files in {path}, found {len(bounds)}: {bounds}')
    b_reps = [np.loadtxt(path + f) for f in bounds]
    # Only the two closest entries are ever inspected below.
    nn_model = NearestNeighbors(n_neighbors=min(2, len(reps)))
    nn_model.fit(reps)
    alphas = np.linspace(0, 1., n_steps)
    # Separate name so the `reps` argument (the search space) is not shadowed.
    interp_reps = [alpha * b_reps[0] + (1 - alpha) * b_reps[1] for alpha in alphas]
    copy(convert_rep_path_midi_path(path + bounds[1]), path + 'interpolation_0.mid')
    copy(convert_rep_path_midi_path(path + bounds[0]), path + 'interpolation_1.mid')
    for alpha, rep in zip(alphas[1:-1], interp_reps[1:-1]):
        dists, indexes = nn_model.kneighbors(rep.reshape(1, -1))
        # On an exact match (distance 0) take the next-closest entry instead.
        nn = indexes.flatten()[1] if dists.flatten()[0] == 0 else indexes.flatten()[0]
        midi_path = convert_rep_path_midi_path(rep_paths[nn])
        copy(midi_path, path + f'interpolation_{alpha}.mid')

if __name__ == '__main__':
    # Aggregate the per-file representations, reload them, then analyse.
    extract_all_reps(rep_path)
    reps, rep_paths, playlists, n_data, dim_data = load_reps(rep_path)
    set_playlists, playlist_indexes, playlist_colors = get_playlist_indexes(playlists)
    # interpolate(reps, rep_paths, interpolation_path + 'trial_1/')
    sample_nn(reps, rep_paths, playlists)
    # plot_tsne returns only the t-SNE embedding (the UMAP variant is disabled).
    # Unpacking it into two names would fail for any number of samples other than 2.
    tsne_reps = plot_tsne(reps, playlist_indexes, playlist_colors)