import pickle

import numpy as np

from src.music.config import CHECKPOINTS_PATH

# Directory holding the music->affect artifacts; these can be generated by the fit_glm script.
_MUSIC2AFFECTS_DIR = CHECKPOINTS_PATH + '/music2cocktails/music2affects/'

# Per-feature bounds of the hand-coded representations, used for min-max scaling.
min_handcoded_reps = np.loadtxt(_MUSIC2AFFECTS_DIR + 'min_handcoded_reps.txt')
# Fix: this path previously omitted the '/' after CHECKPOINTS_PATH, unlike every other path here.
max_handcoded_reps = np.loadtxt(_MUSIC2AFFECTS_DIR + 'max_handcoded_reps.txt')
affective_models_path = _MUSIC2AFFECTS_DIR + 'music2affect_models.pickle'
final_keys_path = _MUSIC2AFFECTS_DIR + 'final_best_keys.pickle'


def sigmoid(x, shift, beta):
    """Squash ``x`` into (-1, 1) with a shifted, scaled logistic curve.

    Computes ``2 * (1 / (1 + exp(-(x + shift) * beta)) - 0.5)``: a logistic centred
    at ``x = -shift`` with steepness ``beta``, rescaled from (0, 1) to (-1, 1).
    """
    return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2


def normalize_handcoded_reps(handcoded_rep):
    """Min-max scale hand-coded features with the stored per-feature bounds.

    Features whose stored min equals max (zero range) are divided by 1 instead of 0,
    so they yield ``handcoded_rep - min`` rather than NaN/inf.
    """
    value_range = max_handcoded_reps - min_handcoded_reps
    # Avoid division by zero for features that were constant when the bounds were computed.
    value_range = np.where(value_range == 0, 1, value_range)
    return (handcoded_rep - min_handcoded_reps) / value_range


def setup_pretrained_affective_models():
    """Load the pretrained music->affect models and build a prediction function.

    Returns:
        (music2affect, keys): ``music2affect`` maps hand-coded representation(s) to an
        array of shape (n_samples, 3) holding (valence, arousal, dominance) in (-1, 1);
        ``keys`` is the sorted list of feature names whose order the representation's
        columns must follow.
    """
    # NOTE: pickle.load can execute arbitrary code; these checkpoint files must be trusted.
    with open(final_keys_path, 'rb') as f:
        best_keys = pickle.load(f)
    # Union of the features selected for the three affect dimensions, in a stable order.
    keys = sorted(set(best_keys[0] + best_keys[1] + best_keys[2]))
    key_to_index = {k: i for i, k in enumerate(keys)}
    # Column indexes (into ``keys``) of the features each dimension's model was trained on.
    bestkeys_indexes = [np.array([key_to_index[k] for k in bk]) for bk in best_keys]

    with open(affective_models_path, 'rb') as f:
        music2affect_models = pickle.load(f)

    # Ratings each model predicts over; assumes predict_proba columns are ordered 1..10
    # (same assumption as before) — TODO confirm against model.classes_.
    ratings = np.arange(1, 11)

    def music2affect(handcoded_rep):
        """Predict (valence, arousal, dominance) in (-1, 1) for hand-coded representation(s).

        Args:
            handcoded_rep: 1-D array (single sample) or 2-D array (n_samples, len(keys)).

        Raises:
            ValueError: if the number of columns does not match ``len(keys)``.
        """
        if handcoded_rep.ndim == 1:
            handcoded_rep = handcoded_rep.reshape(1, -1)
        if handcoded_rep.shape[1] != len(keys):
            raise ValueError(
                f'expected {len(keys)} features per sample, got {handcoded_rep.shape[1]}')
        handcoded_rep = normalize_handcoded_reps(handcoded_rep)

        affects = []
        for i_dim, dim in enumerate(['valence', 'arousal', 'dominance']):
            model = music2affect_models[dim]
            probas = model.predict_proba(handcoded_rep[:, bestkeys_indexes[i_dim]])
            # Exact expected rating under the predicted distribution; replaces the former
            # noisy Monte Carlo estimate (mean of 1000 np.random.choice draws per row).
            affects.append(probas @ ratings)
        affects = np.array(affects).transpose()

        affects = ((affects - 1) / 9 - 0.5) * 2  # map ratings [1, 10] to [-1, 1]
        # Per-dimension sigmoid stretch for a wider output distribution.
        affects[:, 0] = sigmoid(affects[:, 0], shift=0, beta=7)
        affects[:, 1] = sigmoid(affects[:, 1], shift=-0.05, beta=5)
        affects[:, 2] = sigmoid(affects[:, 2], shift=0.05, beta=8)
        return affects

    return music2affect, keys