TastyPiano / src /cocktails /pipeline /get_affect2affective_cluster.py
Cédric Colas
initial commit
e775f6d
raw
history blame
781 Bytes
from src.music.config import CHECKPOINTS_PATH
import pickle
import numpy as np
# Relative location (under CHECKPOINTS_PATH) of the pickled affect-clustering
# checkpoint; the pickle can be computed from cocktail2affect.
_CLUSTER_MODEL_SUBPATH = "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
cluster_model_path = CHECKPOINTS_PATH + _CLUSTER_MODEL_SUBPATH
def get_affect2affective_cluster(model_path=None):
    """Build a function mapping affective coordinates through the cluster model.

    The pickled checkpoint is loaded once, here; the returned closure reuses
    its ``cluster_model`` and ``dimensions_weights`` entries on every call.

    Args:
        model_path: Path of the pickled checkpoint dict. When ``None`` (the
            default), the module-level ``cluster_model_path`` is used.

    Returns:
        A callable ``find_cluster(aff_coord)`` that scales ``aff_coord`` by the
        stored dimension weights and returns ``cluster_model.predict`` on it
        (presumably cluster label(s) -- depends on the pickled model).
    """
    if model_path is None:
        model_path = cluster_model_path
    # NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(model_path, 'rb') as f:
        data = pickle.load(f)
    model = data['cluster_model']
    # Convert the weights once instead of rebuilding the array on every call.
    weights = np.array(data['dimensions_weights'])

    def find_cluster(aff_coord):
        """Apply the weighted cluster model to one point or a batch of points.

        A 1-D input is treated as a single point and promoted to a one-row
        batch; inputs of any other rank are passed through unchanged.
        Accepts anything ``np.asarray`` understands (arrays, lists, tuples).
        """
        aff_coord = np.asarray(aff_coord)
        if aff_coord.ndim == 1:
            aff_coord = aff_coord.reshape(1, -1)
        return model.predict(aff_coord * weights)

    return find_cluster
def get_affective_cluster_centers(model_path=None):
    """Return the ``'cluster_centers'`` entry of the clustering checkpoint.

    Args:
        model_path: Path of the pickled checkpoint dict. When ``None`` (the
            default), the module-level ``cluster_model_path`` is used.

    Returns:
        The object stored under ``'cluster_centers'`` in the pickled dict,
        returned as-is (its type depends on how the checkpoint was written).

    Raises:
        KeyError: if the checkpoint has no ``'cluster_centers'`` entry.
    """
    if model_path is None:
        model_path = cluster_model_path
    # NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(model_path, 'rb') as f:
        data = pickle.load(f)
    return data['cluster_centers']