import os.path
# from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob
from src.music.utils import load_audio
from src.music.config import FPS
import pretty_midi as pm
import numpy as np
from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH
from sklearn.neighbors import NearestNeighbors
from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
# from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
from src.music.utils import get_all_subfiles_with_extension
import os
import pickle
import pandas as pd
import time
keyword = 'b256_r128_represented'
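# Load the pickled music representations together with their file paths, and recover the playlist
# name encoded in each path. Returns the representation matrix, paths, playlists and the data shape.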
def load_reps(rep_path, sample_size=None):
if sample_size:
with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
data = pickle.load(f)
else:
        with open(rep_path + 'music_reps_unnormalized.pickle', 'rb') as f:
data = pickle.load(f)
reps = data['reps']
# playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
n_data, dim_data = reps.shape
return reps, data['paths'], playlists, n_data, dim_data
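# Collects debugging information for one music -> cocktail pipeline run: nearest songs in the music
# representation space, nearest cocktails in the cocktail representation space, and basic audio/MIDI stats.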
class Debugger:
def __init__(self, verbose=True):
if verbose: print('Setting up debugger.')
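        # Build the nearest-neighbour index over music representations and cache it to disk,
        # or reload the cached index if it already exists.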
if not os.path.exists(MUSIC_NN_PATH):
reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle'
if not os.path.exists(reps_path):
all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt', current_depth=0)
all_data = []
new_all_rep_path = []
for i_r, r in enumerate(all_rep_path):
if 'mean_std' not in r:
all_data.append(np.loadtxt(r))
assert len(all_data[-1]) == 128
new_all_rep_path.append(r)
data = np.array(all_data)
to_save = dict(reps=data,
paths=new_all_rep_path)
with open(reps_path, 'wb') as f:
pickle.dump(to_save, f)
reps, self.rep_paths, playlists, n_data, self.dim_rep_music = load_reps(MUSIC_REP_PATH)
self.nn_model_music = NearestNeighbors(n_neighbors=6, metric='cosine')
self.nn_model_music.fit(reps)
to_save = dict(nn_model=self.nn_model_music,
rep_paths=self.rep_paths,
dim_rep_music=self.dim_rep_music)
with open(MUSIC_NN_PATH, 'wb') as f:
pickle.dump(to_save, f)
else:
with open(MUSIC_NN_PATH, 'rb') as f:
data = pickle.load(f)
self.nn_model_music = data['nn_model']
self.rep_paths = data['rep_paths']
self.dim_rep_music = data['dim_rep_music']
if verbose: print(f' {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}')
self.rep_paths = np.array(self.rep_paths)
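        # Same pattern for cocktails: fit a nearest-neighbour index over cocktail representations once,
        # then reload it from disk on subsequent runs.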
if not os.path.exists(COCKTAIL_NN_PATH):
cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
# cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0)
self.nn_model_cocktail = NearestNeighbors(n_neighbors=6)
self.nn_model_cocktail.fit(cocktail_reps)
self.dim_rep_cocktail = cocktail_reps.shape[1]
self.n_cocktails = cocktail_reps.shape[0]
to_save = dict(nn_model=self.nn_model_cocktail,
dim_rep_cocktail=self.dim_rep_cocktail,
n_cocktails=self.n_cocktails)
with open(COCKTAIL_NN_PATH, 'wb') as f:
pickle.dump(to_save, f)
else:
with open(COCKTAIL_NN_PATH, 'rb') as f:
data = pickle.load(f)
self.nn_model_cocktail = data['nn_model']
self.dim_rep_cocktail = data['dim_rep_cocktail']
self.n_cocktails = data['n_cocktails']
if verbose: print(f' {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}')
self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
# self.affective_cluster_centers = get_affective_cluster_centers()
self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls',
'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len', 'piano_solo_prob', 'recipe_score', 'cocktail_rep']
# 'affect', 'affective_cluster_id', 'affective_cluster_center',
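    # Return the file names and cosine distances of the five nearest songs to a music representation.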
def get_nearest_songs(self, music_rep):
dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1))
        indexes = indexes.flatten()[:5]  # keep only the 5 nearest songs
        rep_paths = [r.split('/')[-1] for r in self.rep_paths[indexes]]
        return rep_paths, dists.flatten()[:5].tolist()
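    # Return indexes, names, urls, printable recipes and ingredient strings of the nearest cocktails.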
def get_nearest_cocktails(self, cocktail_rep):
dists, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1))
indexes = indexes.flatten()
        nn_names = np.array(self.cocktail_data['names'])[indexes].tolist()
        nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist()
        ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes]
        nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in ing_strs]
        nn_ing_strs = ing_strs.tolist()
return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs
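    # Gather all debug information for one pipeline run into a dictionary, stored in self.debug_dict.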
def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score, verbose=False, level=0):
if verbose: print(' ' * level + 'Extracting debug info..')
init_time = time.time()
debug_dict = dict()
debug_dict['all_paths'] = all_paths
debug_dict['recipe_score'] = recipe_score
        if all_paths['audio_path'] is not None:
# is it piano?
            debug_dict['piano_solo_prob'] = None  # float(calculate_piano_solo_prob(all_paths['audio_path'])[0])
# how long is the audio
(audio, _) = load_audio(all_paths['audio_path'], sr=FPS, mono=True)
debug_dict['audio_len'] = int(len(audio) / FPS)
else:
debug_dict['piano_solo_prob'] = None
debug_dict['audio_len'] = None
# how many notes?
midi = pm.PrettyMIDI(all_paths['processed_path'])
debug_dict['nb_notes'] = len(midi.instruments[0].notes)
# dimension of music rep
representation = np.loadtxt(all_paths['representation_path'])
debug_dict['dim_rep'] = representation.shape[0]
# closest songs in dataset
debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation)
# get affective cluster info
# debug_dict['affective_cluster_id'] = affective_cluster_id[0]
# debug_dict['affective_cluster_center'] = self.affective_cluster_centers[affective_cluster_id].flatten().tolist()
# debug_dict['affect'] = affect.flatten().tolist()
indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep)
debug_dict['cocktail_rep'] = cocktail_rep.copy().tolist()
debug_dict['nearest_cocktail_indexes'] = indexes.tolist()
debug_dict['nn_ing_strs'] = nn_ing_strs
debug_dict['nearest_cocktail_names'] = nn_names
debug_dict['nearest_cocktail_urls'] = nn_urls
debug_dict['nearest_cocktail_recipes'] = nn_recipes
debug_dict['music_reconstruction'] = music_reconstruction.tolist()
        debug_dict['mse_reconstruction'] = float(((music_reconstruction - representation) ** 2).mean())
self.debug_dict = debug_dict
if verbose: print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.')
return self.debug_dict
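    # Pretty-print the subset of debug_dict entries listed in self.keys_to_print.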
def print_debug(self, level=0):
print(' ' * level + '__DEBUGGING INFO__')
for k in self.keys_to_print:
to_print = self.debug_dict[k]
if k == 'nearest_cocktail_recipes':
to_print = self.debug_dict[k].copy()
for i in range(len(to_print)):
to_print[i] = to_print[i].replace('\n', '').replace('\t', '').replace('()', '')
if k == "nn_music":
to_print = self.debug_dict[k].copy()
for i in range(len(to_print)):
to_print[i] = to_print[i].replace('encoded_new_structured_', '').replace('_represented.txt', '')
to_print_str = f'{to_print}'
if isinstance(to_print, float):
to_print_str = f'{to_print:.2f}'
            elif isinstance(to_print, list) and len(to_print) > 0:
                if isinstance(to_print[0], float):
to_print_str = '['
for element in to_print:
to_print_str += f'{element:.2f}, '
to_print_str = to_print_str[:-2] + ']'
            print(' ' * (level + 2) + f'{k} : ' + to_print_str)
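# Minimal usage sketch (not part of the original module). The file paths and representations below are
# placeholders: in the real pipeline, `all_paths` and the cocktail/music representations are produced by
# the upstream processing steps before extract_info is called.
if __name__ == '__main__':
    debugger = Debugger(verbose=True)
    all_paths = dict(audio_path=None,                                 # no raw audio in this sketch
                     processed_path='example_processed.mid',          # placeholder MIDI file
                     representation_path='example_represented.txt')   # placeholder 128-dim representation
    cocktail_rep = np.zeros(debugger.dim_rep_cocktail)                # placeholder cocktail representation
    music_reconstruction = np.zeros(debugger.dim_rep_music)           # placeholder reconstruction
    debugger.extract_info(all_paths, affective_cluster_id=None, affect=None,
                          cocktail_rep=cocktail_rep, music_reconstruction=music_reconstruction,
                          recipe_score=0.0, verbose=True)
    debugger.print_debug()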