Spaces:
Runtime error
Runtime error
import os
import pickle

import numpy as np

from src.music.config import CHECKPOINTS_PATH
# Location of the pickled affective-cluster data. The pickle is a dict read
# below under the keys 'cluster_model', 'dimensions_weights' and
# 'cluster_centers'. Per the original note, it can be computed from
# cocktail2affect.
cluster_model_path = os.path.join(
    CHECKPOINTS_PATH,
    "music2cocktails",
    "affects2affect_cluster",
    "cluster_model.pickle",
)
def get_affect2affective_cluster(path=None):
    """Load the affective-cluster model and return a cluster-lookup function.

    Args:
        path: Optional path to the pickled cluster data. When omitted, the
            module-level ``cluster_model_path`` is used (looked up at call
            time, so existing callers behave exactly as before).

    Returns:
        A function ``find_cluster(aff_coord)``. It:

        - accepts an array-like of affect coordinates (list, tuple or
          ndarray);
        - treats a 1-D input as a single sample;
        - multiplies the coordinates element-wise by the stored
          ``dimensions_weights``;
        - returns ``model.predict`` of the scaled coordinates.

    Note:
        The file is read with ``pickle``, which can execute arbitrary code
        while loading; only point this at trusted checkpoint files.
    """
    if path is None:
        path = cluster_model_path
    with open(path, 'rb') as f:
        data = pickle.load(f)
    model = data['cluster_model']
    # Loop-invariant: convert the weights once instead of on every lookup.
    weights = np.asarray(data['dimensions_weights'])

    def find_cluster(aff_coord):
        # Normalize so plain lists/tuples work (they have no `.ndim`).
        aff_coord = np.asarray(aff_coord)
        if aff_coord.ndim == 1:
            # A single coordinate vector becomes a one-row 2-D batch.
            aff_coord = aff_coord.reshape(1, -1)
        return model.predict(aff_coord * weights)

    return find_cluster
def get_affective_cluster_centers(path=None):
    """Return the cluster centers stored alongside the affective-cluster model.

    Args:
        path: Optional path to the pickled cluster data. When omitted, the
            module-level ``cluster_model_path`` is used (looked up at call
            time, so existing callers behave exactly as before).

    Returns:
        The object stored under the ``'cluster_centers'`` key of the pickled
        dict. Its exact type is not visible here; it is presumably an array
        of shape (n_clusters, n_dims). TODO confirm against the script that
        writes the pickle.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    if path is None:
        path = cluster_model_path
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data['cluster_centers']