import pickle

import numpy as np

from src.music.config import CHECKPOINTS_PATH

# Location of the pickled affect-cluster artifact.
# Per the original note, its contents can be computed from cocktail2affect.
cluster_model_path = (
    CHECKPOINTS_PATH
    + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
)


def _load_cluster_data():
    """Load and return the object stored in the cluster-model pickle file.

    The returned object is indexed by the keys 'cluster_model',
    'dimensions_weights' and 'cluster_centers' elsewhere in this module.
    """
    # NOTE(review): pickle.load can execute arbitrary code; this must only
    # ever read a trusted, locally produced checkpoint file.
    with open(cluster_model_path, 'rb') as f:
        return pickle.load(f)


def get_affect2affective_cluster():
    """Build a function that assigns affect coordinates to clusters.

    The pickle is read once here. The stored cluster model and the
    per-dimension weights are captured by the returned closure.

    Returns:
        A callable ``find_cluster(aff_coord)``.
        - ``aff_coord`` is array-like: either a single point (1-D) or a
          batch of points (2-D).
        - A 1-D input is promoted to a single-row 2-D array.
        - Each dimension is multiplied by the stored weights.
        - The callable returns ``model.predict`` applied to the weighted
          coordinates.
    """
    data = _load_cluster_data()
    model = data['cluster_model']
    # Convert once here instead of on every call of find_cluster.
    weights = np.array(data['dimensions_weights'])

    def find_cluster(aff_coord):
        # Accept lists/tuples as well as ndarrays; no copy for ndarrays.
        aff_coord = np.asarray(aff_coord)
        if aff_coord.ndim == 1:
            # presumably model.predict expects (n_samples, n_dims) - TODO confirm
            aff_coord = aff_coord.reshape(1, -1)
        return model.predict(aff_coord * weights)

    return find_cluster


def get_affective_cluster_centers():
    """Return the 'cluster_centers' entry stored in the cluster-model pickle."""
    return _load_cluster_data()['cluster_centers']