Spaces:
Runtime error
Runtime error
Upload 174 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/__init__.py +0 -0
- src/cocktails/__init__.py +0 -0
- src/cocktails/__pycache__/__init__.cpython-39.pyc +0 -0
- src/cocktails/__pycache__/config.cpython-39.pyc +0 -0
- src/cocktails/config.py +21 -0
- src/cocktails/pipeline/__init__.py +0 -0
- src/cocktails/pipeline/__pycache__/__init__.cpython-39.pyc +0 -0
- src/cocktails/pipeline/__pycache__/cocktail2affect.cpython-39.pyc +0 -0
- src/cocktails/pipeline/__pycache__/cocktailrep2recipe.cpython-39.pyc +0 -0
- src/cocktails/pipeline/__pycache__/get_affect2affective_cluster.cpython-39.pyc +0 -0
- src/cocktails/pipeline/__pycache__/get_cocktail2affective_cluster.cpython-39.pyc +0 -0
- src/cocktails/pipeline/cocktail2affect.py +372 -0
- src/cocktails/pipeline/cocktailrep2recipe.py +329 -0
- src/cocktails/pipeline/get_affect2affective_cluster.py +23 -0
- src/cocktails/pipeline/get_cocktail2affective_cluster.py +9 -0
- src/cocktails/representation_learning/__init__.py +0 -0
- src/cocktails/representation_learning/__pycache__/__init__.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/__pycache__/dataset.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/__pycache__/multihead_model.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/__pycache__/run.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/__pycache__/run_without_vae.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/__pycache__/simple_model.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/__pycache__/vae_model.cpython-39.pyc +0 -0
- src/cocktails/representation_learning/dataset.py +324 -0
- src/cocktails/representation_learning/multihead_model.py +148 -0
- src/cocktails/representation_learning/run.py +557 -0
- src/cocktails/representation_learning/run_simple_net.py +302 -0
- src/cocktails/representation_learning/run_without_vae.py +514 -0
- src/cocktails/representation_learning/simple_model.py +54 -0
- src/cocktails/representation_learning/vae_model.py +238 -0
- src/cocktails/utilities/__init__.py +0 -0
- src/cocktails/utilities/__pycache__/__init__.cpython-39.pyc +0 -0
- src/cocktails/utilities/__pycache__/cocktail_category_detection_utilities.cpython-39.pyc +0 -0
- src/cocktails/utilities/__pycache__/cocktail_utilities.cpython-39.pyc +0 -0
- src/cocktails/utilities/__pycache__/glass_and_volume_utilities.cpython-39.pyc +0 -0
- src/cocktails/utilities/__pycache__/ingredients_utilities.cpython-39.pyc +0 -0
- src/cocktails/utilities/__pycache__/other_scrubbing_utilities.cpython-39.pyc +0 -0
- src/cocktails/utilities/analysis_utilities.py +189 -0
- src/cocktails/utilities/cocktail_category_detection_utilities.py +221 -0
- src/cocktails/utilities/cocktail_generation_utilities/__init__.py +0 -0
- src/cocktails/utilities/cocktail_generation_utilities/__pycache__/__init__.cpython-39.pyc +0 -0
- src/cocktails/utilities/cocktail_generation_utilities/__pycache__/individual.cpython-39.pyc +0 -0
- src/cocktails/utilities/cocktail_generation_utilities/__pycache__/population.cpython-39.pyc +0 -0
- src/cocktails/utilities/cocktail_generation_utilities/individual.py +587 -0
- src/cocktails/utilities/cocktail_generation_utilities/population.py +213 -0
- src/cocktails/utilities/cocktail_utilities.py +220 -0
- src/cocktails/utilities/glass_and_volume_utilities.py +42 -0
- src/cocktails/utilities/ingredients_utilities.py +209 -0
- src/cocktails/utilities/other_scrubbing_utilities.py +240 -0
- src/debugger.py +180 -0
src/__init__.py
ADDED
File without changes
|
src/cocktails/__init__.py
ADDED
File without changes
|
src/cocktails/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (198 Bytes). View file
|
|
src/cocktails/__pycache__/config.cpython-39.pyc
ADDED
Binary file (961 Bytes). View file
|
|
src/cocktails/config.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/'
|
4 |
+
|
5 |
+
# QUADRUPLETS_PATH = REPO_PATH + 'checkpoints/cocktail_representation/quadruplets.pickle'
|
6 |
+
INGREDIENTS_LIST_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_list.csv'
|
7 |
+
# ING_MATCH_SCORE_Q_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_q.txt'
|
8 |
+
# ING_MATCH_SCORE_COUNT_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_count.txt'
|
9 |
+
# COCKTAIL_DATA_FOLDER_PATH = REPO_PATH + 'checkpoints/cocktail_representation/'
|
10 |
+
COCKTAILS_CSV_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.csv'
|
11 |
+
# COCKTAILS_PKL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.pkl'
|
12 |
+
# COCKTAILS_URL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_names_urls.pkl'
|
13 |
+
EXPERIMENT_PATH = REPO_PATH + 'experiments/cocktails/representation_learning/'
|
14 |
+
# ANALYSIS_PATH = REPO_PATH + 'experiments/cocktails/representation_analysis/'
|
15 |
+
# REPRESENTATIONS_PATH = REPO_PATH + 'experiments/cocktails/learned_representations/'
|
16 |
+
|
17 |
+
FULL_COCKTAIL_REP_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/cocktail_handcoded_reps_minmax_norm-1_1_dim13_customkeys.txt"
|
18 |
+
RECIPE2FEATURES_PATH = REPO_PATH + "/checkpoints/cocktail_representation/" # get this by running run_without_vae
|
19 |
+
COCKTAIL_REP_CHKPT_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/"
|
20 |
+
# FULL_COCKTAIL_REP_PATH = REPO_PATH + "experiments/cocktails/representation_analysis/affective_mapping/clustered_representations/all_cocktail_reps_norm-1_1_custom_keys_dim13.txt'
|
21 |
+
COCKTAIL_NN_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/nn_model.pickle"
|
src/cocktails/pipeline/__init__.py
ADDED
File without changes
|
src/cocktails/pipeline/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (207 Bytes). View file
|
|
src/cocktails/pipeline/__pycache__/cocktail2affect.cpython-39.pyc
ADDED
Binary file (13.5 kB). View file
|
|
src/cocktails/pipeline/__pycache__/cocktailrep2recipe.cpython-39.pyc
ADDED
Binary file (10.6 kB). View file
|
|
src/cocktails/pipeline/__pycache__/get_affect2affective_cluster.cpython-39.pyc
ADDED
Binary file (1.15 kB). View file
|
|
src/cocktails/pipeline/__pycache__/get_cocktail2affective_cluster.cpython-39.pyc
ADDED
Binary file (789 Bytes). View file
|
|
src/cocktails/pipeline/cocktail2affect.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
|
5 |
+
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
|
6 |
+
from src.cocktails.config import COCKTAILS_CSV_DATA
|
7 |
+
from src.music.config import CHECKPOINTS_PATH, EXPERIMENT_PATH
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from sklearn.cluster import KMeans
|
10 |
+
from sklearn.mixture import GaussianMixture
|
11 |
+
from sklearn.neighbors import NearestNeighbors
|
12 |
+
import pickle
|
13 |
+
import random
|
14 |
+
|
15 |
+
experiment_path = EXPERIMENT_PATH + '/cocktails/representation_analysis/affective_mapping/'
|
16 |
+
min_max_path = CHECKPOINTS_PATH + "/cocktail_representation/minmax/"
|
17 |
+
cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
|
18 |
+
affective_space_dimensions = ((-1, 1), (-1, 1), (-1, 1)) # valence, arousal, dominance
|
19 |
+
n_splits = (3, 3, 2) # number of bins per dimension
|
20 |
+
# dimensions_weights = [1, 1, 0.5]
|
21 |
+
dimensions_weights = [1, 1, 1]
|
22 |
+
total_n_clusters = np.prod(n_splits) # total number of bins
|
23 |
+
affective_boundaries = [np.arange(asd[0], asd[1]+1e-6, (asd[1] - asd[0]) / n_split) for asd, n_split in zip(affective_space_dimensions, n_splits)]
|
24 |
+
for af in affective_boundaries:
|
25 |
+
af[-1] += 1e-6
|
26 |
+
all_keys = get_bunch_of_rep_keys()['custom']
|
27 |
+
original_affective_keys = get_bunch_of_rep_keys()['affective']
|
28 |
+
affective_keys = [a.split(' ')[1] for a in original_affective_keys]
|
29 |
+
random.seed(0)
|
30 |
+
cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(total_n_clusters)]
|
31 |
+
|
32 |
+
clustering_method = 'k_means' # 'k_means', 'handcoded', 'agglo', 'spectral'
|
33 |
+
if clustering_method != 'handcoded':
|
34 |
+
total_n_clusters = 10
|
35 |
+
min_arousal = np.loadtxt(min_max_path + 'min_arousal.txt')
|
36 |
+
max_arousal = np.loadtxt(min_max_path + 'max_arousal.txt')
|
37 |
+
min_val = np.loadtxt(min_max_path + 'min_valence.txt')
|
38 |
+
max_val = np.loadtxt(min_max_path + 'max_valence.txt')
|
39 |
+
min_dom = np.loadtxt(min_max_path + 'min_dominance.txt')
|
40 |
+
max_dom = np.loadtxt(min_max_path + 'max_dominance.txt')
|
41 |
+
|
42 |
+
def get_cocktail_reps(path, save=False):
|
43 |
+
cocktail_data = pd.read_csv(path)
|
44 |
+
cocktail_reps = np.array([cocktail_data[k] for k in original_affective_keys]).transpose()
|
45 |
+
n_data, dim_rep = cocktail_reps.shape
|
46 |
+
# print(f'{n_data} data points of {dim_rep} dimensions: {affective_keys}')
|
47 |
+
cocktail_reps = normalize_cocktail_reps_affective(cocktail_reps, save=save)
|
48 |
+
if save:
|
49 |
+
np.savetxt(experiment_path + f'cocktail_reps_for_affective_mapping_-1_1_norm_sigmoid_rescaling_{dim_rep}_keys.txt', cocktail_reps)
|
50 |
+
return cocktail_reps
|
51 |
+
|
52 |
+
def sigmoid(x, shift, beta):
|
53 |
+
return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
|
54 |
+
|
55 |
+
def normalize_cocktail_reps_affective(cocktail_reps, save=False):
|
56 |
+
if save:
|
57 |
+
min_cr = cocktail_reps.min(axis=0)
|
58 |
+
max_cr = cocktail_reps.max(axis=0)
|
59 |
+
np.savetxt(min_max_path + 'min_cocktail_reps_affective.txt', min_cr)
|
60 |
+
np.savetxt(min_max_path + 'max_cocktail_reps_affective.txt', max_cr)
|
61 |
+
else:
|
62 |
+
min_cr = np.loadtxt(min_max_path + 'min_cocktail_reps_affective.txt')
|
63 |
+
max_cr = np.loadtxt(min_max_path + 'max_cocktail_reps_affective.txt')
|
64 |
+
cocktail_reps = ((cocktail_reps - min_cr) / (max_cr - min_cr) - 0.5) * 2
|
65 |
+
cocktail_reps[:, 0] = sigmoid(cocktail_reps[:, 0], shift=0.05, beta=4)
|
66 |
+
cocktail_reps[:, 1] = sigmoid(cocktail_reps[:, 1], shift=0.3, beta=5)
|
67 |
+
cocktail_reps[:, 2] = sigmoid(cocktail_reps[:, 2], shift=0.15, beta=3)
|
68 |
+
cocktail_reps[:, 3] = sigmoid(cocktail_reps[:, 3], shift=0.9, beta=20)
|
69 |
+
cocktail_reps[:, 4] = sigmoid(cocktail_reps[:, 4], shift=0, beta=4)
|
70 |
+
cocktail_reps[:, 5] = sigmoid(cocktail_reps[:, 5], shift=0.2, beta=3)
|
71 |
+
cocktail_reps[:, 6] = sigmoid(cocktail_reps[:, 6], shift=0.5, beta=5)
|
72 |
+
cocktail_reps[:, 7] = sigmoid(cocktail_reps[:, 7], shift=0.2, beta=6)
|
73 |
+
return cocktail_reps
|
74 |
+
|
75 |
+
def plot(cocktail_reps):
|
76 |
+
dim_rep = cocktail_reps.shape[1]
|
77 |
+
for i in range(dim_rep):
|
78 |
+
for j in range(i+1, dim_rep):
|
79 |
+
plt.figure()
|
80 |
+
plt.scatter(cocktail_reps[:, i], cocktail_reps[:, j], s=150, alpha=0.5)
|
81 |
+
plt.xlabel(affective_keys[i])
|
82 |
+
plt.ylabel(affective_keys[j])
|
83 |
+
plt.savefig(experiment_path + f'scatters/{affective_keys[i]}_vs_{affective_keys[j]}.png', dpi=300)
|
84 |
+
plt.close('all')
|
85 |
+
plt.figure()
|
86 |
+
plt.hist(cocktail_reps[:, i])
|
87 |
+
plt.xlabel(affective_keys[i])
|
88 |
+
plt.savefig(experiment_path + f'hists/{affective_keys[i]}.png', dpi=300)
|
89 |
+
plt.close('all')
|
90 |
+
|
91 |
+
def get_clusters(affective_coordinates, save=False):
|
92 |
+
if clustering_method in ['k_means', 'gmm',]:
|
93 |
+
if clustering_method == 'k_means': model = KMeans(n_clusters=total_n_clusters)
|
94 |
+
elif clustering_method == 'gmm': model = GaussianMixture(n_components=total_n_clusters, covariance_type="full")
|
95 |
+
model.fit(affective_coordinates * np.array(dimensions_weights))
|
96 |
+
|
97 |
+
def find_cluster(aff_coord):
|
98 |
+
if aff_coord.ndim == 1:
|
99 |
+
aff_coord = aff_coord.reshape(1, -1)
|
100 |
+
return model.predict(aff_coord * np.array(dimensions_weights))
|
101 |
+
cluster_centers = model.cluster_centers_ if clustering_method == 'k_means' else []
|
102 |
+
if save:
|
103 |
+
to_save = dict(cluster_model=model,
|
104 |
+
cluster_centers=cluster_centers,
|
105 |
+
nb_clusters=len(cluster_centers),
|
106 |
+
dimensions_weights=dimensions_weights)
|
107 |
+
with open(cluster_model_path, 'wb') as f:
|
108 |
+
pickle.dump(to_save, f)
|
109 |
+
stop= 1
|
110 |
+
|
111 |
+
elif clustering_method == 'handcoded':
|
112 |
+
def find_cluster(aff_coord):
|
113 |
+
if aff_coord.ndim == 1:
|
114 |
+
aff_coord = aff_coord.reshape(1, -1)
|
115 |
+
cluster_coordinates = []
|
116 |
+
for i in range(aff_coord.shape[0]):
|
117 |
+
cluster_coordinates.append([np.argwhere(affective_boundaries[j] <= aff_coord[i, j]).flatten()[-1] for j in range(3)])
|
118 |
+
cluster_coordinates = np.array(cluster_coordinates)
|
119 |
+
cluster_ids = cluster_coordinates[:, 0] * np.prod(n_splits[1:]) + cluster_coordinates[:, 1] * n_splits[-1] + cluster_coordinates[:, 2]
|
120 |
+
return cluster_ids
|
121 |
+
# find cluster centers
|
122 |
+
cluster_centers = []
|
123 |
+
for i in range(n_splits[0]):
|
124 |
+
asd = affective_space_dimensions[0]
|
125 |
+
x_coordinate = np.arange(asd[0] + 1 / n_splits[0], asd[1], (asd[1] - asd[0]) / n_splits[0])[i]
|
126 |
+
for j in range(n_splits[1]):
|
127 |
+
asd = affective_space_dimensions[1]
|
128 |
+
y_coordinate = np.arange(asd[0] + 1 / n_splits[1], asd[1], (asd[1] - asd[0]) / n_splits[1])[j]
|
129 |
+
for k in range(n_splits[2]):
|
130 |
+
asd = affective_space_dimensions[2]
|
131 |
+
z_coordinate = np.arange(asd[0] + 1 / n_splits[2], asd[1], (asd[1] - asd[0]) / n_splits[2])[k]
|
132 |
+
cluster_centers.append([x_coordinate, y_coordinate, z_coordinate])
|
133 |
+
cluster_centers = np.array(cluster_centers)
|
134 |
+
else:
|
135 |
+
raise NotImplemented
|
136 |
+
cluster_ids = find_cluster(affective_coordinates)
|
137 |
+
return cluster_ids, cluster_centers, find_cluster
|
138 |
+
|
139 |
+
|
140 |
+
def cocktail2affect(cocktail_reps, save=False):
|
141 |
+
if cocktail_reps.ndim == 1:
|
142 |
+
cocktail_reps = cocktail_reps.reshape(1, -1)
|
143 |
+
|
144 |
+
assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful']
|
145 |
+
all_weights = []
|
146 |
+
|
147 |
+
# valence
|
148 |
+
# + sweet - bitter - booze + colorful
|
149 |
+
weights = np.array([-1, 1, 0, 0, 0, -1, 0, 1])
|
150 |
+
valence = (cocktail_reps * weights).sum(axis=1)
|
151 |
+
if save:
|
152 |
+
min_ = valence.min()
|
153 |
+
max_ = valence.max()
|
154 |
+
np.savetxt(min_max_path + 'min_valence.txt', np.array([min_]))
|
155 |
+
np.savetxt(min_max_path + 'max_valence.txt', np.array([max_]))
|
156 |
+
else:
|
157 |
+
min_ = min_val
|
158 |
+
max_ = max_val
|
159 |
+
valence = 2 * ((valence - min_) / (max_ - min_) - 0.5)
|
160 |
+
valence = sigmoid(valence, shift=0.1, beta=3.5)
|
161 |
+
valence = valence.reshape(-1, 1)
|
162 |
+
all_weights.append(weights.copy())
|
163 |
+
|
164 |
+
# arousal
|
165 |
+
# + fizzy + sour + complex - sweet + spicy + bitter
|
166 |
+
# weights = np.array([0, -1, 1, 1, 1, 1, 1, 0])
|
167 |
+
weights = np.array([0.7, 0, 1.5, 1.5, 0.6, 0, 0.6, 0])
|
168 |
+
arousal = (cocktail_reps * weights).sum(axis=1)
|
169 |
+
if save:
|
170 |
+
min_ = arousal.min()
|
171 |
+
max_ = arousal.max()
|
172 |
+
np.savetxt(min_max_path + 'min_arousal.txt', np.array([min_]))
|
173 |
+
np.savetxt(min_max_path + 'max_arousal.txt', np.array([max_]))
|
174 |
+
else:
|
175 |
+
min_, max_ = min_arousal, max_arousal
|
176 |
+
arousal = 2 * ((arousal - min_) / (max_ - min_) - 0.5) # normalize to -1, 1
|
177 |
+
arousal = sigmoid(arousal, shift=0.3, beta=4)
|
178 |
+
arousal = arousal.reshape(-1, 1)
|
179 |
+
all_weights.append(weights.copy())
|
180 |
+
|
181 |
+
# dominance
|
182 |
+
# assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful']
|
183 |
+
# + booze + fizzy - complex - bitter - sweet
|
184 |
+
weights = np.array([1.5, -0.8, 0, 0.7, -1, -1.5, 0, 0])
|
185 |
+
dominance = (cocktail_reps * weights).sum(axis=1)
|
186 |
+
if save:
|
187 |
+
min_ = dominance.min()
|
188 |
+
max_ = dominance.max()
|
189 |
+
np.savetxt(min_max_path + 'min_dominance.txt', np.array([min_]))
|
190 |
+
np.savetxt(min_max_path + 'max_dominance.txt', np.array([max_]))
|
191 |
+
else:
|
192 |
+
min_, max_ = min_dom, max_dom
|
193 |
+
dominance = 2 * ((dominance - min_) / (max_ - min_) - 0.5)
|
194 |
+
dominance = sigmoid(dominance, shift=-0.05, beta=5)
|
195 |
+
dominance = dominance.reshape(-1, 1)
|
196 |
+
all_weights.append(weights.copy())
|
197 |
+
|
198 |
+
affective_coordinates = np.concatenate([valence, arousal, dominance], axis=1)
|
199 |
+
# if save:
|
200 |
+
# assert (affective_coordinates.min(axis=0) == np.array([ac[0] for ac in affective_space_dimensions])).all()
|
201 |
+
# assert (affective_coordinates.max(axis=0) == np.array([ac[1] for ac in affective_space_dimensions])).all()
|
202 |
+
return affective_coordinates, all_weights
|
203 |
+
|
204 |
+
def save_reps(path, affective_cluster_ids):
|
205 |
+
cocktail_data = pd.read_csv(path)
|
206 |
+
rep_keys = get_bunch_of_rep_keys()['custom']
|
207 |
+
cocktail_reps = np.array([cocktail_data[k] for k in rep_keys]).transpose()
|
208 |
+
np.savetxt(experiment_path + 'clustered_representations/' + f'min_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.min(axis=0))
|
209 |
+
np.savetxt(experiment_path + 'clustered_representations/' + f'max_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.max(axis=0))
|
210 |
+
cocktail_reps = ((cocktail_reps - cocktail_reps.min(axis=0)) / (cocktail_reps.max(axis=0) - cocktail_reps.min(axis=0)) - 0.5) * 2 # normalize in -1, 1
|
211 |
+
np.savetxt(experiment_path + 'clustered_representations/' + f'all_cocktail_reps_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps)
|
212 |
+
np.savetxt(experiment_path + 'clustered_representations/' + 'affective_cluster_ids.txt', affective_cluster_ids)
|
213 |
+
for cluster_id in sorted(set(affective_cluster_ids)):
|
214 |
+
indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
|
215 |
+
reps = cocktail_reps[indexes, :]
|
216 |
+
np.savetxt(experiment_path + 'clustered_representations/' + f'rep_cluster{cluster_id}_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', reps)
|
217 |
+
|
218 |
+
def study_affects(affective_coordinates, affective_cluster_ids):
|
219 |
+
plt.figure()
|
220 |
+
plt.hist(affective_cluster_ids, bins=total_n_clusters)
|
221 |
+
plt.xlabel('Affective cluster ids')
|
222 |
+
plt.xticks(np.arange(total_n_clusters))
|
223 |
+
plt.savefig(experiment_path + 'affective_cluster_distrib.png')
|
224 |
+
fig = plt.gcf()
|
225 |
+
plt.close(fig)
|
226 |
+
|
227 |
+
fig = plt.figure()
|
228 |
+
ax = fig.add_subplot(projection='3d')
|
229 |
+
ax.set_xlim([-1, 1])
|
230 |
+
ax.set_ylim([-1, 1])
|
231 |
+
ax.set_zlim([-1, 1])
|
232 |
+
for cluster_id in sorted(set(affective_cluster_ids)):
|
233 |
+
indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
|
234 |
+
ax.scatter(affective_coordinates[indexes, 0], affective_coordinates[indexes, 1], affective_coordinates[indexes, 2], c=cluster_colors[cluster_id], s=150)
|
235 |
+
ax.set_xlabel('Valence')
|
236 |
+
ax.set_ylabel('Arousal')
|
237 |
+
ax.set_zlabel('Dominance')
|
238 |
+
stop = 1
|
239 |
+
plt.savefig(experiment_path + 'scatters_affect/affective_mapping.png')
|
240 |
+
fig = plt.gcf()
|
241 |
+
plt.close(fig)
|
242 |
+
|
243 |
+
affects = ['Valence', 'Arousal', 'Dominance']
|
244 |
+
for i in range(3):
|
245 |
+
for j in range(i + 1, 3):
|
246 |
+
fig = plt.figure()
|
247 |
+
ax = fig.add_subplot()
|
248 |
+
for cluster_id in sorted(set(affective_cluster_ids)):
|
249 |
+
indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
|
250 |
+
ax.scatter(affective_coordinates[indexes, i], affective_coordinates[indexes, j], alpha=0.5, c=cluster_colors[cluster_id], s=150)
|
251 |
+
ax.set_xlabel(affects[i])
|
252 |
+
ax.set_ylabel(affects[j])
|
253 |
+
plt.savefig(experiment_path + f'scatters_affect/scatter_{affects[i]}_vs_{affects[j]}.png')
|
254 |
+
fig = plt.gcf()
|
255 |
+
plt.close(fig)
|
256 |
+
plt.figure()
|
257 |
+
plt.hist(affective_coordinates[:, i])
|
258 |
+
plt.xlabel(affects[i])
|
259 |
+
plt.savefig(experiment_path + f'hists_affect/hist_{affects[i]}.png')
|
260 |
+
fig = plt.gcf()
|
261 |
+
plt.close(fig)
|
262 |
+
plt.close('all')
|
263 |
+
stop = 1
|
264 |
+
|
265 |
+
def sample_clusters(path, cocktail_reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates, n_samples=4):
|
266 |
+
cocktail_data = pd.read_csv(path)
|
267 |
+
these_cocktail_reps = normalize_cocktail_reps_affective(np.array([cocktail_data[k] for k in original_affective_keys]).transpose())
|
268 |
+
|
269 |
+
names = cocktail_data['names']
|
270 |
+
urls = cocktail_data['urls']
|
271 |
+
ingr_str = cocktail_data['ingredients_str']
|
272 |
+
for cluster_id in sorted(set(affective_cluster_ids)):
|
273 |
+
indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
|
274 |
+
print('\n\n\n---------\n----------\n-----------\n')
|
275 |
+
cluster_str = ''
|
276 |
+
cluster_str += f'Affective cluster #{cluster_id}' + \
|
277 |
+
f'\n\tSize: {len(indexes)}' + \
|
278 |
+
f'\n\tCenter: ' + \
|
279 |
+
f'\n\t\tVal: {affective_cluster_centers[cluster_id][0]:.2f}, ' + \
|
280 |
+
f'\n\t\tArousal: {affective_cluster_centers[cluster_id][1]:.2f}, ' + \
|
281 |
+
f'\n\t\tDominance: {affective_cluster_centers[cluster_id][2]:.2f}'
|
282 |
+
print(cluster_str)
|
283 |
+
if affective_cluster_centers[cluster_id][2] == np.max(affective_cluster_centers[:, 2]):
|
284 |
+
stop = 1
|
285 |
+
sampled_idx = np.random.choice(indexes, size=min(len(indexes), n_samples), replace=False)
|
286 |
+
cocktail_str = ''
|
287 |
+
for i in sampled_idx:
|
288 |
+
assert np.sum(cocktail_reps[i] - these_cocktail_reps[i]) < 1e-9
|
289 |
+
cocktail_str += f'\n\n-------------'
|
290 |
+
cocktail_str += print_recipe(ingr_str[i], name=names[i], to_print=False)
|
291 |
+
cocktail_str += f'\nUrl: {urls[i]}'
|
292 |
+
cocktail_str += '\n\nRepresentation: ' + ', '.join([f'{af}: {cr:.2f}' for af, cr in zip(affective_keys, cocktail_reps[i])]) + '\n'
|
293 |
+
cocktail_str += '\n' + generate_explanation(cocktail_reps[i], all_weights, affective_coordinates[i])
|
294 |
+
print(cocktail_str)
|
295 |
+
stop = 1
|
296 |
+
cluster_str += '\n' + cocktail_str
|
297 |
+
with open(f"/home/cedric/Documents/pianocktail/experiments/cocktails/representation_analysis/affective_mapping/clusters/cluster_{cluster_id}", 'w') as f:
|
298 |
+
f.write(cluster_str)
|
299 |
+
stop = 1
|
300 |
+
|
301 |
+
def explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord):
|
302 |
+
names = ['valence', 'arousal', 'dominance']
|
303 |
+
weights = all_weights[i]
|
304 |
+
explanation_str = f'\n{names[i].capitalize()} explanation ({aff_coord[i]:.2f}):'
|
305 |
+
strengths = np.abs(weights * cocktail_rep)
|
306 |
+
strengths /= strengths.sum()
|
307 |
+
indexes = np.flip(np.argsort(strengths))
|
308 |
+
for ind in indexes:
|
309 |
+
if strengths[ind] != 0:
|
310 |
+
if np.sign(weights[ind]) == np.sign(cocktail_rep[ind]):
|
311 |
+
keyword = 'high' if cocktail_rep[ind] > 0 else 'low'
|
312 |
+
explanation_str += f'\n\t{int(strengths[ind]*100)}%: higher {names[i]} because {keyword} {affective_keys[ind]}'
|
313 |
+
else:
|
314 |
+
keyword = 'high' if cocktail_rep[ind] > 0 else 'low'
|
315 |
+
explanation_str += f'\n\t{int(strengths[ind]*100)}%: low {names[i]} because {keyword} {affective_keys[ind]}'
|
316 |
+
return explanation_str
|
317 |
+
|
318 |
+
def generate_explanation(cocktail_rep, all_weights, aff_coord):
|
319 |
+
explanation_str = ''
|
320 |
+
for i in range(3):
|
321 |
+
explanation_str += explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord)
|
322 |
+
return explanation_str
|
323 |
+
|
324 |
+
def cocktails2affect_clusters(cocktail_rep):
|
325 |
+
if cocktail_rep.ndim == 1:
|
326 |
+
cocktail_rep = cocktail_rep.reshape(1, -1)
|
327 |
+
affective_coordinates, _ = cocktail2affect(cocktail_rep)
|
328 |
+
affective_cluster_ids, _, _ = get_clusters(affective_coordinates)
|
329 |
+
return affective_cluster_ids
|
330 |
+
|
331 |
+
|
332 |
+
def setup_affective_space(path, save=False):
|
333 |
+
cocktail_data = pd.read_csv(path)
|
334 |
+
names = cocktail_data['names']
|
335 |
+
recipes = cocktail_data['ingredients_str']
|
336 |
+
urls = cocktail_data['urls']
|
337 |
+
reps = get_cocktail_reps(path)
|
338 |
+
affective_coordinates, all_weights = cocktail2affect(reps)
|
339 |
+
affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates, save=save)
|
340 |
+
nn_model = NearestNeighbors(n_neighbors=1)
|
341 |
+
nn_model.fit(affective_coordinates)
|
342 |
+
def cocktail2affect_cluster(cocktail_rep):
|
343 |
+
affective_coordinates, _ = cocktail2affect(cocktail_rep)
|
344 |
+
return find_cluster(affective_coordinates)
|
345 |
+
|
346 |
+
affective_clusters = dict(affective_coordinates=affective_coordinates, # coordinates of cocktail in affective space
|
347 |
+
affective_cluster_ids=affective_cluster_ids, # cluster id of cocktails
|
348 |
+
affective_cluster_centers=affective_cluster_centers, # cluster centers in affective space
|
349 |
+
affective_weights=all_weights, # weights to compute valence, arousal, dominance from cocktail representations
|
350 |
+
original_affective_keys=original_affective_keys,
|
351 |
+
cocktail_reps=reps, # cocktail representations from the dataset (normalized)
|
352 |
+
find_cluster=find_cluster, # function to retrieve a cluster from affective coordinates
|
353 |
+
nn_model=nn_model, # to predict the nearest neighbor affective space,
|
354 |
+
names=names, # names of cocktails in the dataset
|
355 |
+
urls=urls, # urls from the dataset
|
356 |
+
recipes=recipes, # recipes of the dataset
|
357 |
+
cocktail2affect=cocktail2affect, # function to compute affects from cocktail representations
|
358 |
+
cocktails2affect_clusters=cocktails2affect_clusters,
|
359 |
+
cocktail2affect_cluster=cocktail2affect_cluster
|
360 |
+
)
|
361 |
+
|
362 |
+
return affective_clusters
|
363 |
+
|
364 |
+
if __name__ == '__main__':
|
365 |
+
reps = get_cocktail_reps(COCKTAILS_CSV_DATA, save=True)
|
366 |
+
# plot(reps)
|
367 |
+
affective_coordinates, all_weights = cocktail2affect(reps, save=True)
|
368 |
+
affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates)
|
369 |
+
save_reps(COCKTAILS_CSV_DATA, affective_cluster_ids)
|
370 |
+
study_affects(affective_coordinates, affective_cluster_ids)
|
371 |
+
sample_clusters(COCKTAILS_CSV_DATA, reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates)
|
372 |
+
setup_affective_space(COCKTAILS_CSV_DATA, save=True)
|
src/cocktails/pipeline/cocktailrep2recipe.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import pickle
|
3 |
+
from src.cocktails.utilities.cocktail_generation_utilities.population import *
|
4 |
+
from src.cocktails.utilities.glass_and_volume_utilities import glass_volume
|
5 |
+
from src.cocktails.config import RECIPE2FEATURES_PATH
|
6 |
+
|
7 |
+
def test_mutation_params(cocktail_reps):
|
8 |
+
indexes = np.arange(cocktail_reps.shape[0])
|
9 |
+
np.random.shuffle(indexes)
|
10 |
+
perfs = []
|
11 |
+
mutated_perfs = []
|
12 |
+
pop_params = dict(mutation_params=dict(p_add_ing=0.7,
|
13 |
+
p_remove_ing=0.7,
|
14 |
+
p_switch_ing=0.5,
|
15 |
+
p_change_q=0.7,
|
16 |
+
delta_change_q=0.3,
|
17 |
+
asexual_rep=True,
|
18 |
+
crossover=True,
|
19 |
+
ingredient_addition=(0.1, 0.05)),
|
20 |
+
nb_generations=100,
|
21 |
+
pop_size=100,
|
22 |
+
nb_elites=10,
|
23 |
+
dist='mse',
|
24 |
+
n_neighbors=5)
|
25 |
+
|
26 |
+
for i in indexes[:20]:
|
27 |
+
target = cocktail_reps[i]
|
28 |
+
for j in range(100):
|
29 |
+
parent = IndividualCocktail(pop_params=pop_params,
|
30 |
+
target_affective_cluster=None,
|
31 |
+
target=target.copy())
|
32 |
+
perfs.append(parent.perf)
|
33 |
+
child = parent.get_child()[0]
|
34 |
+
# child.compute_cocktail_rep()
|
35 |
+
# child.compute_perf()
|
36 |
+
if perfs[-1] != child.perf:
|
37 |
+
mutated_perfs.append(child.perf)
|
38 |
+
else:
|
39 |
+
perfs.pop(-1)
|
40 |
+
filtered_children = np.argwhere(np.array(mutated_perfs)==-100).flatten()
|
41 |
+
non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs)!=-100, np.array(mutated_perfs)!=-100)).flatten()
|
42 |
+
print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%')
|
43 |
+
plt.figure()
|
44 |
+
plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5)
|
45 |
+
plt.xlabel('parent perf')
|
46 |
+
plt.ylabel('child perf')
|
47 |
+
print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1])
|
48 |
+
plt.show()
|
49 |
+
stop = 1
|
50 |
+
|
51 |
+
def test_crossover(cocktail_reps):
|
52 |
+
indexes = np.arange(cocktail_reps.shape[0])
|
53 |
+
np.random.shuffle(indexes)
|
54 |
+
perfs = []
|
55 |
+
mutated_perfs = []
|
56 |
+
pop_params = dict(mutation_params=dict(p_add_ing=0.7,
|
57 |
+
p_remove_ing=0.7,
|
58 |
+
p_switch_ing=0.5,
|
59 |
+
p_change_q=0.7,
|
60 |
+
delta_change_q=0.3,
|
61 |
+
asexual_rep=True,
|
62 |
+
crossover=True,
|
63 |
+
ingredient_addition=(0.1, 0.05)),
|
64 |
+
nb_generations=100,
|
65 |
+
pop_size=100,
|
66 |
+
nb_elites=10,
|
67 |
+
dist='mse',
|
68 |
+
n_neighbors=5)
|
69 |
+
for i in indexes[:20]:
|
70 |
+
for j in range(100):
|
71 |
+
target = cocktail_reps[i]
|
72 |
+
parent1 = IndividualCocktail(pop_params=pop_params,
|
73 |
+
target_affective_cluster=None,
|
74 |
+
target=target.copy())
|
75 |
+
parent2 = IndividualCocktail(pop_params=pop_params,
|
76 |
+
target_affective_cluster=None,
|
77 |
+
target=target.copy())
|
78 |
+
child = parent1.get_child_with(parent2)[0]
|
79 |
+
# child.compute_cocktail_rep()
|
80 |
+
# child.compute_perf()
|
81 |
+
perfs.append((parent1.perf + parent2.perf)/2)
|
82 |
+
if perfs[-1] != child.perf:
|
83 |
+
mutated_perfs.append(child.perf)
|
84 |
+
else:
|
85 |
+
perfs.pop(-1)
|
86 |
+
filtered_children = np.argwhere(np.array(mutated_perfs)==-100).flatten()
|
87 |
+
non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs)>-45, np.array(mutated_perfs)!=-100)).flatten()
|
88 |
+
print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%')
|
89 |
+
plt.figure()
|
90 |
+
plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5)
|
91 |
+
plt.xlabel('parent perf')
|
92 |
+
plt.ylabel('child perf')
|
93 |
+
print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1])
|
94 |
+
plt.show()
|
95 |
+
stop = 1
|
96 |
+
|
97 |
+
def run_comparisons():
|
98 |
+
np.random.seed(0)
|
99 |
+
indexes = np.arange(cocktail_reps.shape[0])
|
100 |
+
np.random.shuffle(indexes)
|
101 |
+
for n_neighbors in [0, 5]:
|
102 |
+
id_str_neigh = '5neigh_' if n_neighbors == 5 else '0_neigh_'
|
103 |
+
for asexual_rep in [True, False]:
|
104 |
+
id_str_as = id_str_neigh + 'asexual_' if asexual_rep else id_str_neigh
|
105 |
+
for crossover in [True, False]:
|
106 |
+
id_str = id_str_as + 'crossover_' if crossover else id_str_as
|
107 |
+
if crossover or asexual_rep:
|
108 |
+
mutation_params = dict(p_add_ing = 0.5,
|
109 |
+
p_remove_ing = 0.5,
|
110 |
+
p_change_q = 0.5,
|
111 |
+
delta_change_q = 0.3,
|
112 |
+
asexual_rep=asexual_rep,
|
113 |
+
crossover=crossover,
|
114 |
+
ingredient_addition = (0.1, 0.05))
|
115 |
+
nb_generations = 100
|
116 |
+
pop_size=100
|
117 |
+
nb_elites=10
|
118 |
+
dist = 'mse'
|
119 |
+
results = dict()
|
120 |
+
print(id_str)
|
121 |
+
for i, ind in enumerate(indexes[:30]):
|
122 |
+
print(i+1)
|
123 |
+
target_ing_str = data['ingredients_str'][ind]
|
124 |
+
target = cocktail_reps[ind]
|
125 |
+
population = Population(nb_generations=nb_generations, pop_size=pop_size, nb_elite=nb_elites,
|
126 |
+
target=target, dist=dist, mutation_params=mutation_params,
|
127 |
+
n_neighbors=n_neighbors, target_ing_str=target_ing_str, true_prep_type=data['category'][ind])
|
128 |
+
population.run_evolution(verbose=False)
|
129 |
+
best_scores, best_ind = population.get_best_score()
|
130 |
+
recipes = [ind.get_recipe()[3] for ind in best_ind[:5]]
|
131 |
+
results[str(ind)] = dict(best_scores=best_scores[:5], recipes=recipes, target=population.target_individual.get_recipe()[3])
|
132 |
+
with open(f'/home/cedric/Desktop/ga_tests_{id_str}.pickle', 'wb') as f:
|
133 |
+
pickle.dump(results, f)
|
134 |
+
|
135 |
+
def get_cocktail_distribution(cocktail_reps):
|
136 |
+
return (np.mean(cocktail_reps, axis=0), np.cov(cocktail_reps, rowvar=0))
|
137 |
+
|
138 |
+
def sample_cocktails(cocktail_reps, n=10, target_affective_cluster=None, to_print=True):
|
139 |
+
distrib = get_cocktail_distribution(cocktail_reps)
|
140 |
+
sampled_cocktail_reps = np.random.multivariate_normal(distrib[0], distrib[1], size=n)
|
141 |
+
recipes = []
|
142 |
+
closest_recipes = []
|
143 |
+
for i_c, cr in enumerate(sampled_cocktail_reps):
|
144 |
+
population = setup_recipe_generation(cr.copy(), target_affective_cluster=target_affective_cluster)
|
145 |
+
closest_recipes.append(population.nn_recipes[0])
|
146 |
+
best_scores, best_individuals = population.run_evolution()
|
147 |
+
recipes.append(best_individuals[0].get_recipe()[3])
|
148 |
+
if to_print:
|
149 |
+
print(f'Sample #{len(recipes)}:')
|
150 |
+
print(recipes[-1])
|
151 |
+
print('Closest from dataset:')
|
152 |
+
print(closest_recipes[-1])
|
153 |
+
stop = 1
|
154 |
+
return recipes, closest_recipes
|
155 |
+
|
156 |
+
def setup_recipe_generation(target, known_target_dict=None, target_affective_cluster=None):
|
157 |
+
# pop_params = dict(mutation_params=dict(p_add_ing=0.7,
|
158 |
+
# p_remove_ing=0.7,
|
159 |
+
# p_switch_ing=0.5,
|
160 |
+
# p_change_q=0.7,
|
161 |
+
# delta_change_q=0.3,
|
162 |
+
# asexual_rep=True,
|
163 |
+
# crossover=True,
|
164 |
+
# ingredient_addition=(0.1, 0.05)),
|
165 |
+
# nb_generations=2, #100
|
166 |
+
# pop_size=5, #100
|
167 |
+
# nb_elites=2, #10
|
168 |
+
# dist='mse',
|
169 |
+
# n_neighbors=3) #5
|
170 |
+
pop_params = dict(mutation_params=dict(p_add_ing=0.4,
|
171 |
+
p_remove_ing=1,
|
172 |
+
p_switch_ing=0.5,
|
173 |
+
p_change_q=1,
|
174 |
+
delta_change_q=0.3,
|
175 |
+
asexual_rep=True,
|
176 |
+
crossover=True,
|
177 |
+
ingredient_addition=(0.1, 0.05)),
|
178 |
+
nb_generations=100, # 100
|
179 |
+
pop_size=100, # 100
|
180 |
+
nb_elites=10, # 10
|
181 |
+
dist='mse',
|
182 |
+
n_neighbors=5) # 5
|
183 |
+
|
184 |
+
population = Population(target=target, target_affective_cluster=target_affective_cluster, known_target_dict=known_target_dict, pop_params=pop_params)
|
185 |
+
return population
|
186 |
+
|
187 |
+
def cocktailrep2recipe(cocktail_rep, unit='mL', target_affective_cluster=None, known_target_dict=None, n_output=1, return_ind=False, verbose=True, full_verbose=False, level=0):
|
188 |
+
init_time = time.time()
|
189 |
+
if verbose: print(' ' * level + 'Generating cocktail..')
|
190 |
+
if cocktail_rep.ndim > 1:
|
191 |
+
assert cocktail_rep.shape[0] == 1
|
192 |
+
cocktail_rep = cocktail_rep.flatten()
|
193 |
+
# target_affective_cluster = target_affective_cluster[0]
|
194 |
+
population = setup_recipe_generation(cocktail_rep.copy(), known_target_dict=known_target_dict, target_affective_cluster=target_affective_cluster)
|
195 |
+
if full_verbose:
|
196 |
+
print(' ' * (level + 2) + '3 nearest neighbors:')
|
197 |
+
for i, recipe, score in zip(range(3), population.nn_recipes[:3], population.nn_scores[:3]):
|
198 |
+
print(' ' * (level + 4) + f'#{i+1}, score: {score:.2f}')
|
199 |
+
print(' ' * (level + 4) + recipe[1:].replace('None ()', '').replace('\t\t', ' ' * (level + 6)))
|
200 |
+
best_scores, best_individuals = population.run_evolution(verbose=full_verbose, level=level+2)
|
201 |
+
for i in range(n_output):
|
202 |
+
best_individuals[i].make_recipe_fit_the_glass()
|
203 |
+
instructions = [ind.get_instructions() for ind in best_individuals[:n_output]]
|
204 |
+
recipes = [ind.get_recipe(unit=unit)[3] for ind in best_individuals[:n_output]]
|
205 |
+
glasses = [ind.glass for ind in best_individuals[:n_output]]
|
206 |
+
prep_types = [ind.prep_type for ind in best_individuals[:n_output]]
|
207 |
+
for i, g, p, inst in zip(range(len(recipes)), glasses, prep_types, instructions):
|
208 |
+
recipes[i] = recipes[i].replace('Recipe', 'Ingredients') + f'Serve in:\n {g.capitalize()} glass.\n' + inst
|
209 |
+
if full_verbose:
|
210 |
+
print(f'\n--------------\n{n_output} best results:')
|
211 |
+
for i, recipe, score in zip(range(n_output), recipes, best_scores[:n_output]):
|
212 |
+
print(f'#{i+1}, score: {score:.2f}')
|
213 |
+
print(recipe)
|
214 |
+
if verbose: print(' ' * (level + 2) + f'Generated in {int(time.time() - init_time)} seconds.')
|
215 |
+
if return_ind:
|
216 |
+
return recipes, best_scores[:n_output], best_individuals[:n_output]
|
217 |
+
else:
|
218 |
+
return recipes, best_scores[:n_output]
|
219 |
+
|
220 |
+
|
221 |
+
def interpolate(cocktail_rep1, cocktail_rep2, alpha, verbose=False):
|
222 |
+
recipe, score = cocktailrep2recipe(alpha * cocktail_rep1 + (1 - alpha) * cocktail_rep2, verbose=verbose)
|
223 |
+
return recipe[0], score
|
224 |
+
|
225 |
+
def interpolation_study(n_steps, cocktail_reps):
|
226 |
+
alphas = np.arange(0, 1 + 1e-6, 1/(n_steps + 1))
|
227 |
+
indexes = np.random.choice(np.arange(cocktail_reps.shape[0]), size=2, replace=False)
|
228 |
+
target_ing_str1, target_ing_str2 = data['ingredients_str'][indexes[0]], data['ingredients_str'][indexes[1]]
|
229 |
+
cocktail_rep1, cocktail_rep2 = cocktail_reps[indexes[0]], cocktail_reps[indexes[1]]
|
230 |
+
recipes, scores = [], []
|
231 |
+
for alpha in alphas:
|
232 |
+
recipe, score = interpolate(cocktail_rep1, cocktail_rep2, alpha)
|
233 |
+
recipes.append(recipe)
|
234 |
+
scores.append(score[0])
|
235 |
+
print('Point A:')
|
236 |
+
print_recipe(ingredient_str=target_ing_str2)
|
237 |
+
for i, alpha in enumerate(alphas):
|
238 |
+
print(f'Alpha = {alpha}, score = {scores[i]}')
|
239 |
+
print(recipes[i])
|
240 |
+
print('Point B:')
|
241 |
+
print_recipe(ingredient_str=target_ing_str1)
|
242 |
+
stop = 1
|
243 |
+
|
244 |
+
def test_robustness_affective_cluster(cocktail_reps):
|
245 |
+
indexes = np.arange(cocktail_reps.shape[0])
|
246 |
+
np.random.shuffle(indexes)
|
247 |
+
matches = []
|
248 |
+
for i in indexes:
|
249 |
+
target_ing_str = data['ingredients_str'][i]
|
250 |
+
true_prep_type = data['category'][i]
|
251 |
+
target = cocktail_reps[i]
|
252 |
+
# get affective cluster
|
253 |
+
recipes, best_scores, best_inds = cocktailrep2recipe(cocktail_rep=target, target_ing_str=target_ing_str, true_prep_type=true_prep_type, n_output=1, verbose=False,
|
254 |
+
return_ind=True)
|
255 |
+
|
256 |
+
matches.append(best_inds[0].does_affective_cluster_match())
|
257 |
+
print(np.mean(matches))
|
258 |
+
|
259 |
+
def test(cocktail_reps):
|
260 |
+
indexes = np.arange(these_cocktail_reps.shape[0])
|
261 |
+
unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose()
|
262 |
+
|
263 |
+
for i in indexes:
|
264 |
+
target_ing_str = data['ingredients_str'][i]
|
265 |
+
true_prep_type = data['category'][i]
|
266 |
+
target = these_cocktail_reps[i]
|
267 |
+
# print('preptype:', true_prep_type)
|
268 |
+
# print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i])
|
269 |
+
# print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i]))
|
270 |
+
# print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i])
|
271 |
+
# print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i])
|
272 |
+
|
273 |
+
population = setup_recipe_generation(target.copy(), target_ing_str=target_ing_str, target_affective_cluster=None, true_prep_type=true_prep_type)
|
274 |
+
target = population.target_individual
|
275 |
+
target.compute_perf()
|
276 |
+
if target.perf < -50:
|
277 |
+
print(i)
|
278 |
+
print_recipe(target_ing_str)
|
279 |
+
if not target.is_alcohol_present(): print('No alcohol')
|
280 |
+
if not target.is_total_volume_enough(): print('small volume')
|
281 |
+
if not target.does_fit_glass():
|
282 |
+
print(target.end_volume)
|
283 |
+
print(glass_volume[target.get_glass_type()] * 0.81)
|
284 |
+
print('too much volume')
|
285 |
+
if not target.is_alcohol_reasonable():
|
286 |
+
print(f'amount of alcohol too small or too large: {target.alcohol_precentage}')
|
287 |
+
stop = 1
|
288 |
+
|
289 |
+
|
290 |
+
if __name__ == '__main__':
|
291 |
+
these_cocktail_reps = COCKTAIL_REPS.copy()
|
292 |
+
# test_crossover(these_cocktail_reps)
|
293 |
+
# test_mutation_params(these_cocktail_reps)
|
294 |
+
# test(these_cocktail_reps)
|
295 |
+
# recipes, closest_recipes = sample_cocktails(these_cocktail_reps, n=10)
|
296 |
+
# interpolation_study(n_steps=4, cocktail_reps=these_cocktail_reps)
|
297 |
+
# test_robustness_affective_cluster(these_cocktail_reps)
|
298 |
+
indexes = np.arange(these_cocktail_reps.shape[0])
|
299 |
+
np.random.shuffle(indexes)
|
300 |
+
# test_crossover(mutation_params, dist)
|
301 |
+
# test_mutation_params(mutation_params, dist)
|
302 |
+
stop = 1
|
303 |
+
unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose()
|
304 |
+
for i in indexes:
|
305 |
+
print(i)
|
306 |
+
target_ing_str = data['ingredients_str'][i]
|
307 |
+
target_prep_type = data['category'][i]
|
308 |
+
target_glass = data['glass'][i]
|
309 |
+
|
310 |
+
print('preptype:', target_prep_type)
|
311 |
+
print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i])
|
312 |
+
print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i]))
|
313 |
+
print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i])
|
314 |
+
print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i])
|
315 |
+
print(i)
|
316 |
+
|
317 |
+
print('___________Target')
|
318 |
+
nn_model = NearestNeighbors()
|
319 |
+
nn_model.fit(these_cocktail_reps)
|
320 |
+
dists, indexes = nn_model.kneighbors(these_cocktail_reps[i].reshape(1, -1))
|
321 |
+
print(indexes)
|
322 |
+
print_recipe(target_ing_str)
|
323 |
+
target = these_cocktail_reps[i]
|
324 |
+
known_target_dict = dict(prep_type=target_prep_type,
|
325 |
+
ing_str=target_ing_str,
|
326 |
+
glass=target_glass)
|
327 |
+
recipes, best_scores = cocktailrep2recipe(cocktail_rep=target, known_target_dict=known_target_dict, n_output=1, verbose=True, full_verbose=True)
|
328 |
+
|
329 |
+
stop = 1
|
src/cocktails/pipeline/get_affect2affective_cluster.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.music.config import CHECKPOINTS_PATH
|
2 |
+
import pickle
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
# can be computed from cocktail2affect
|
6 |
+
cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
|
7 |
+
|
8 |
+
def get_affect2affective_cluster():
|
9 |
+
with open(cluster_model_path, 'rb') as f:
|
10 |
+
data = pickle.load(f)
|
11 |
+
model = data['cluster_model']
|
12 |
+
dimensions_weights = data['dimensions_weights']
|
13 |
+
def find_cluster(aff_coord):
|
14 |
+
if aff_coord.ndim == 1:
|
15 |
+
aff_coord = aff_coord.reshape(1, -1)
|
16 |
+
return model.predict(aff_coord * np.array(dimensions_weights))
|
17 |
+
return find_cluster
|
18 |
+
|
19 |
+
def get_affective_cluster_centers():
|
20 |
+
with open(cluster_model_path, 'rb') as f:
|
21 |
+
data = pickle.load(f)
|
22 |
+
return data['cluster_centers']
|
23 |
+
|
src/cocktails/pipeline/get_cocktail2affective_cluster.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster
|
2 |
+
from src.cocktails.pipeline.cocktail2affect import cocktail2affect
|
3 |
+
|
4 |
+
def get_cocktail2affective_cluster():
|
5 |
+
find_cluster = get_affect2affective_cluster()
|
6 |
+
def cocktail2affect_cluster(cocktail_rep):
|
7 |
+
affective_coordinates, _ = cocktail2affect(cocktail_rep)
|
8 |
+
return find_cluster(affective_coordinates)
|
9 |
+
return cocktail2affect_cluster
|
src/cocktails/representation_learning/__init__.py
ADDED
File without changes
|
src/cocktails/representation_learning/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (222 Bytes). View file
|
|
src/cocktails/representation_learning/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (8.77 kB). View file
|
|
src/cocktails/representation_learning/__pycache__/multihead_model.cpython-39.pyc
ADDED
Binary file (5.36 kB). View file
|
|
src/cocktails/representation_learning/__pycache__/run.cpython-39.pyc
ADDED
Binary file (16.1 kB). View file
|
|
src/cocktails/representation_learning/__pycache__/run_without_vae.cpython-39.pyc
ADDED
Binary file (15.7 kB). View file
|
|
src/cocktails/representation_learning/__pycache__/simple_model.cpython-39.pyc
ADDED
Binary file (1.96 kB). View file
|
|
src/cocktails/representation_learning/__pycache__/vae_model.cpython-39.pyc
ADDED
Binary file (8.28 kB). View file
|
|
src/cocktails/representation_learning/dataset.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
import pickle
|
3 |
+
from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles, ingredients_per_type
|
4 |
+
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index, params):
|
8 |
+
assert len(ingredients) == len(quantities)
|
9 |
+
ing, q = ingredients[index], quantities[index]
|
10 |
+
proportion = q / np.sum(quantities)
|
11 |
+
index_ing = ingredient_list.index(ing)
|
12 |
+
# add keys of profile
|
13 |
+
rep_ingredient = []
|
14 |
+
rep_ingredient += [ingredient_profiles[k][index_ing] for k in params['ing_keys']]
|
15 |
+
# add category encoding
|
16 |
+
# rep_ingredient += list(params['category_encodings'][ingredient_profiles['type'][index_ing]])
|
17 |
+
# add quantitiy and relative quantity
|
18 |
+
rep_ingredient += [q / max_q_per_ing[ing], proportion]
|
19 |
+
ing_one_hot = np.zeros(len(ingredient_list))
|
20 |
+
ing_one_hot[index_ing] = 1
|
21 |
+
rep_ingredient += list(ing_one_hot)
|
22 |
+
indexes_to_normalize = list(range(len(params['ing_keys'])))
|
23 |
+
#TODO: should we add ing one hot? Or make sure no 2 ing have same embedding
|
24 |
+
return np.array(rep_ingredient), indexes_to_normalize
|
25 |
+
|
26 |
+
def get_max_n_ingredients(data):
|
27 |
+
max_count = 0
|
28 |
+
ingredient_set = set()
|
29 |
+
alcohol_set = set()
|
30 |
+
liqueur_set = set()
|
31 |
+
ing_str = np.array(data['ingredients_str'])
|
32 |
+
for i in range(len(data['names'])):
|
33 |
+
ingredients, quantities = extract_ingredients(ing_str[i])
|
34 |
+
max_count = max(max_count, len(ingredients))
|
35 |
+
for ing in ingredients:
|
36 |
+
ingredient_set.add(ing)
|
37 |
+
if ing in ingredients_per_type['liquor']:
|
38 |
+
alcohol_set.add(ing)
|
39 |
+
if ing in ingredients_per_type['liqueur']:
|
40 |
+
liqueur_set.add(ing)
|
41 |
+
return max_count, ingredient_set, alcohol_set, liqueur_set
|
42 |
+
|
43 |
+
# Add your custom dataset class here
|
44 |
+
class MyDataset(Dataset):
|
45 |
+
def __init__(self, split, params):
|
46 |
+
data = params['raw_data']
|
47 |
+
self.dim_rep_ingredient = params['dim_rep_ingredient']
|
48 |
+
n_data = len(data["names"])
|
49 |
+
|
50 |
+
preparation_list = sorted(set(data['category']))
|
51 |
+
categories_list = sorted(set(data['subcategory']))
|
52 |
+
glasses_list = sorted(set(data['glass']))
|
53 |
+
|
54 |
+
max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
|
55 |
+
ingredient_set = sorted(ingredient_set)
|
56 |
+
self.ingredient_set = ingredient_set
|
57 |
+
|
58 |
+
ingredient_quantities = [] # output of our network
|
59 |
+
ingr_strs = np.array(data['ingredients_str'])
|
60 |
+
for i in range(n_data):
|
61 |
+
ingredients, quantities = extract_ingredients(ingr_strs[i])
|
62 |
+
# get ingredient presence and quantity
|
63 |
+
ingredient_q_rep = np.zeros([len(ingredient_set)])
|
64 |
+
for ing, q in zip(ingredients, quantities):
|
65 |
+
ingredient_q_rep[ingredient_set.index(ing)] = q
|
66 |
+
ingredient_quantities.append(ingredient_q_rep)
|
67 |
+
|
68 |
+
# take care of ingredient quantities (OUTPUTS)
|
69 |
+
ingredient_quantities = np.array(ingredient_quantities)
|
70 |
+
ingredients_presence = (ingredient_quantities>0).astype(np.int)
|
71 |
+
|
72 |
+
min_ing_quantities = np.min(ingredient_quantities, axis=0)
|
73 |
+
max_ing_quantities = np.max(ingredient_quantities, axis=0)
|
74 |
+
def normalize_ing_quantities(ing_quantities):
|
75 |
+
return ((ing_quantities - min_ing_quantities) / (max_ing_quantities - min_ing_quantities)).copy()
|
76 |
+
|
77 |
+
def denormalize_ing_quantities(normalized_ing_quantities):
|
78 |
+
return (normalized_ing_quantities * (max_ing_quantities - min_ing_quantities) + min_ing_quantities).copy()
|
79 |
+
ing_q_when_present = ingredient_quantities.copy()
|
80 |
+
for i in range(len(ing_q_when_present)):
|
81 |
+
ing_q_when_present[i, np.where(ing_q_when_present[i, :] == 0)] = np.nan
|
82 |
+
self.min_when_present_ing_quantities = np.nanmin(ing_q_when_present, axis=0)
|
83 |
+
|
84 |
+
|
85 |
+
def filter_decoder_output(output):
|
86 |
+
output_unnormalized = output * max_ing_quantities
|
87 |
+
if output.ndim == 1:
|
88 |
+
output_unnormalized[np.where(output_unnormalized<self.min_when_present_ing_quantities)] = 0
|
89 |
+
else:
|
90 |
+
for i in range(output.shape[0]):
|
91 |
+
output_unnormalized[i, np.where(output_unnormalized[i] < self.min_when_present_ing_quantities)] = 0
|
92 |
+
return output_unnormalized.copy()
|
93 |
+
self.filter_decoder_output = filter_decoder_output
|
94 |
+
# arg_mins = np.nanargmin(ing_q_when_present, axis=0)
|
95 |
+
#
|
96 |
+
# for ing, minq, argminq in zip(ingredient_set, self.min_when_present_ing_quantities, arg_mins):
|
97 |
+
# print(f'__\n{ing}: {minq}')
|
98 |
+
# print_recipe(ingr_strs[argminq])
|
99 |
+
# ingredients, quantities = extract_ingredients(ingr_strs[argminq])
|
100 |
+
# # get ingredient presence and quantity
|
101 |
+
# ingredient_q_rep = np.zeros([len(ingredient_set)])
|
102 |
+
# for ing, q in zip(ingredients, quantities):
|
103 |
+
# ingredient_q_rep[ingredient_set.index(ing)] = q
|
104 |
+
# print(np.array(data['urls'])[argminq])
|
105 |
+
# stop = 1
|
106 |
+
|
107 |
+
self.max_ing_quantities = max_ing_quantities
|
108 |
+
self.mean_ing_quantities = np.mean(ingredient_quantities, axis=0)
|
109 |
+
self.std_ing_quantities = np.std(ingredient_quantities, axis=0)
|
110 |
+
if split == 'train':
|
111 |
+
np.savetxt(params['save_path'] + 'min_when_present_ing_quantities.txt', self.min_when_present_ing_quantities)
|
112 |
+
np.savetxt(params['save_path'] + 'max_ing_quantities.txt', max_ing_quantities)
|
113 |
+
np.savetxt(params['save_path'] + 'mean_ing_quantities.txt', self.mean_ing_quantities)
|
114 |
+
np.savetxt(params['save_path'] + 'std_ing_quantities.txt', self.std_ing_quantities)
|
115 |
+
|
116 |
+
# print(ingredient_quantities[0])
|
117 |
+
# ingredient_quantities = (ingredient_quantities - self.mean_ing_quantities) / self.std_ing_quantities
|
118 |
+
# print(ingredient_quantities[0])
|
119 |
+
# print(ingredient_quantities[0] * self.std_ing_quantities + self.mean_ing_quantities )
|
120 |
+
ingredient_quantities = ingredient_quantities / max_ing_quantities#= normalize_ing_quantities(ingredient_quantities)
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
max_q_per_ing = dict(zip(ingredient_set, max_ing_quantities))
|
126 |
+
# print(ingredient_quantities[0])
|
127 |
+
#########
|
128 |
+
# Process input representation_analysis: list of ingredient representation_analysis
|
129 |
+
#########
|
130 |
+
input_data = [] # input of ingredient encoders
|
131 |
+
all_ing_reps = []
|
132 |
+
for i in range(n_data):
|
133 |
+
ingredients, quantities = extract_ingredients(ingr_strs[i])
|
134 |
+
# get ingredient presence and quantity
|
135 |
+
ingredient_q_rep = np.zeros([len(ingredient_set)])
|
136 |
+
for ing, q in zip(ingredients, quantities):
|
137 |
+
ingredient_q_rep[ingredient_set.index(ing)] = q
|
138 |
+
# get main liquor
|
139 |
+
cocktail_rep = []
|
140 |
+
for j in range(len(ingredients)):
|
141 |
+
cocktail_rep.append(get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index=j, params=params)[0])
|
142 |
+
all_ing_reps.append(cocktail_rep[-1].copy())
|
143 |
+
input_data.append(cocktail_rep)
|
144 |
+
|
145 |
+
|
146 |
+
all_ing_reps = np.array(all_ing_reps)
|
147 |
+
min_ing_reps = np.min(all_ing_reps[:, params['indexes_ing_to_normalize']], axis=0)
|
148 |
+
max_ing_reps = np.max(all_ing_reps[:, params['indexes_ing_to_normalize']], axis=0)
|
149 |
+
|
150 |
+
def normalize_ing_reps(ing_reps):
|
151 |
+
if ing_reps.ndim == 1:
|
152 |
+
ing_reps = ing_reps.reshape(1, -1)
|
153 |
+
out = ing_reps.copy()
|
154 |
+
out[:, params['indexes_ing_to_normalize']] = (out[:, params['indexes_ing_to_normalize']] - min_ing_reps) / (max_ing_reps - min_ing_reps)
|
155 |
+
return out
|
156 |
+
|
157 |
+
def denormalize_ing_reps(normalized_ing_reps):
|
158 |
+
if normalized_ing_reps.ndim == 1:
|
159 |
+
normalized_ing_reps = normalized_ing_reps.reshape(1, -1)
|
160 |
+
out = normalized_ing_reps.copy()
|
161 |
+
out[:, params['indexes_ing_to_normalize']] = out[:, params['indexes_ing_to_normalize']] * (max_ing_reps - min_ing_reps) + min_ing_reps
|
162 |
+
return out
|
163 |
+
|
164 |
+
|
165 |
+
# put everything in a big matrix
|
166 |
+
dim_cocktail_rep = max_ingredients * self.dim_rep_ingredient
|
167 |
+
input_data2 = []
|
168 |
+
nb_ingredients = []
|
169 |
+
for d in input_data:
|
170 |
+
cocktail_rep = np.zeros([dim_cocktail_rep])
|
171 |
+
cocktail_rep.fill(np.nan)
|
172 |
+
index = 0
|
173 |
+
nb_ingredients.append(len(d))
|
174 |
+
for dj in d:
|
175 |
+
cocktail_rep[index:index + self.dim_rep_ingredient] = normalize_ing_reps(dj)
|
176 |
+
index += self.dim_rep_ingredient
|
177 |
+
input_data2.append(cocktail_rep)
|
178 |
+
input_data = np.array(input_data2)
|
179 |
+
nb_ingredients = np.array(nb_ingredients)
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
# let us now extract various possible output we might want to predict:
|
186 |
+
#########
|
187 |
+
# Process output cocktail representation_analysis (computed from ingredient reps)
|
188 |
+
#########
|
189 |
+
# quantities_indexes = np.arange(20, 456, 57)
|
190 |
+
# qs = input_data[0, quantities_indexes]
|
191 |
+
# ingredient_quantities[0]
|
192 |
+
# get final volume
|
193 |
+
volumes = np.array(params['raw_data']['end volume'])
|
194 |
+
|
195 |
+
min_vol = volumes.min()
|
196 |
+
max_vol = volumes.max()
|
197 |
+
def normalize_vol(volume):
|
198 |
+
return (volume - min_vol) / (max_vol - min_vol)
|
199 |
+
|
200 |
+
def denormalize_vol(normalized_vol):
|
201 |
+
return normalized_vol * (max_vol - min_vol) + min_vol
|
202 |
+
|
203 |
+
volumes = normalize_vol(volumes)
|
204 |
+
|
205 |
+
|
206 |
+
# computed cocktail representation
|
207 |
+
computed_cocktail_reps = params['cocktail_reps']
|
208 |
+
self.dim_rep = computed_cocktail_reps.shape[1]
|
209 |
+
|
210 |
+
#########
|
211 |
+
# Process output sub categories
|
212 |
+
#########
|
213 |
+
categories = np.array([categories_list.index(sc) for sc in data['subcategory']])
|
214 |
+
counts = dict(zip(categories_list, [0] * len(categories)))
|
215 |
+
for c in data['subcategory']:
|
216 |
+
counts[c] += 1
|
217 |
+
for k in counts.keys():
|
218 |
+
counts[k] /= len(data['subcategory'])
|
219 |
+
self.categories = categories_list
|
220 |
+
self.categories_weights = []
|
221 |
+
for c in self.categories:
|
222 |
+
self.categories_weights.append(1/len(self.categories)/counts[c])
|
223 |
+
print(counts)
|
224 |
+
|
225 |
+
#########
|
226 |
+
# Process output glass type
|
227 |
+
#########
|
228 |
+
glasses = np.array([glasses_list.index(sc) for sc in data['glass']])
|
229 |
+
counts = dict(zip(glasses_list, [0] * len(set(data['glass']))))
|
230 |
+
for c in data['glass']:
|
231 |
+
counts[c] += 1
|
232 |
+
for k in counts.keys():
|
233 |
+
counts[k] /= len(data['glass'])
|
234 |
+
self.glasses = glasses_list
|
235 |
+
self.glasses_weights = []
|
236 |
+
for c in self.glasses:
|
237 |
+
self.glasses_weights.append(1 / len(self.glasses) / counts[c])
|
238 |
+
print(counts)
|
239 |
+
|
240 |
+
#########
|
241 |
+
# Process output preparation type
|
242 |
+
#########
|
243 |
+
prep_type = np.array([preparation_list.index(sc) for sc in data['category']])
|
244 |
+
counts = dict(zip(preparation_list, [0] * len(preparation_list)))
|
245 |
+
for c in data['category']:
|
246 |
+
counts[c] += 1
|
247 |
+
for k in counts.keys():
|
248 |
+
counts[k] /= len(data['category'])
|
249 |
+
self.prep_types = preparation_list
|
250 |
+
self.prep_types_weights = []
|
251 |
+
for c in self.prep_types:
|
252 |
+
self.prep_types_weights.append(1 / len(self.prep_types) / counts[c])
|
253 |
+
print(counts)
|
254 |
+
|
255 |
+
taste_reps = list(data['taste_rep'])
|
256 |
+
taste_rep_ground_truth = []
|
257 |
+
taste_rep_valid = []
|
258 |
+
for tr in taste_reps:
|
259 |
+
if len(tr) > 2:
|
260 |
+
taste_rep_valid.append(True)
|
261 |
+
taste_rep_ground_truth.append([float(tr.split('[')[1].split(',')[0]), float(tr.split(']')[0].split(',')[1][1:])])
|
262 |
+
else:
|
263 |
+
taste_rep_valid.append(False)
|
264 |
+
taste_rep_ground_truth.append([np.nan, np.nan])
|
265 |
+
taste_rep_ground_truth = np.array(taste_rep_ground_truth)
|
266 |
+
taste_rep_valid = np.array(taste_rep_valid)
|
267 |
+
taste_rep_ground_truth /= 10
|
268 |
+
|
269 |
+
auxiliary_data = dict(categories=categories,
|
270 |
+
glasses=glasses,
|
271 |
+
prep_type=prep_type,
|
272 |
+
cocktail_reps=computed_cocktail_reps,
|
273 |
+
ingredients_presence=ingredients_presence,
|
274 |
+
taste_reps=taste_rep_ground_truth,
|
275 |
+
volume=volumes,
|
276 |
+
ingredients_quantities=ingredient_quantities)
|
277 |
+
self.auxiliary_keys = sorted(params['auxiliaries_dict'].keys())
|
278 |
+
assert self.auxiliary_keys == sorted(auxiliary_data.keys())
|
279 |
+
|
280 |
+
data_preprocessing = dict(min_max_ing_quantities=(min_ing_quantities, max_ing_quantities),
|
281 |
+
min_max_ing_reps=(min_ing_reps, max_ing_reps),
|
282 |
+
min_max_vol=(min_vol, max_vol))
|
283 |
+
|
284 |
+
if split == 'train':
|
285 |
+
with open(params['save_path'] + 'normalization_funcs.pickle', 'wb') as f:
|
286 |
+
pickle.dump(data_preprocessing, f)
|
287 |
+
|
288 |
+
n_data = len(input_data)
|
289 |
+
assert len(ingredient_quantities) == n_data
|
290 |
+
for aux in self.auxiliary_keys:
|
291 |
+
assert len(auxiliary_data[aux]) == n_data
|
292 |
+
|
293 |
+
if split == 'train':
|
294 |
+
indexes = np.arange(int(0.9 * n_data))
|
295 |
+
elif split == 'test':
|
296 |
+
indexes = np.arange(int(0.9 * n_data), n_data)
|
297 |
+
elif split == 'all':
|
298 |
+
indexes = np.arange(n_data)
|
299 |
+
else:
|
300 |
+
raise ValueError
|
301 |
+
|
302 |
+
# np.random.shuffle(indexes)
|
303 |
+
self.taste_rep_valid = taste_rep_valid[indexes]
|
304 |
+
self.input_ingredients = input_data[indexes]
|
305 |
+
self.ingredient_quantities = ingredient_quantities[indexes]
|
306 |
+
self.computed_cocktail_reps = computed_cocktail_reps[indexes]
|
307 |
+
self.auxiliaries = dict()
|
308 |
+
for aux in self.auxiliary_keys:
|
309 |
+
self.auxiliaries[aux] = auxiliary_data[aux][indexes]
|
310 |
+
self.nb_ingredients = nb_ingredients[indexes]
|
311 |
+
|
312 |
+
def __len__(self):
|
313 |
+
return len(self.input_ingredients)
|
314 |
+
|
315 |
+
def get_auxiliary_data(self, idx):
|
316 |
+
out = dict()
|
317 |
+
for aux in self.auxiliary_keys:
|
318 |
+
out[aux] = self.auxiliaries[aux][idx]
|
319 |
+
return out
|
320 |
+
|
321 |
+
def __getitem__(self, idx):
|
322 |
+
assert self.nb_ingredients[idx] == np.argwhere(~np.isnan(self.input_ingredients[idx])).flatten().size / self.dim_rep_ingredient
|
323 |
+
return [self.nb_ingredients[idx], self.input_ingredients[idx], self.ingredient_quantities[idx], self.computed_cocktail_reps[idx], self.get_auxiliary_data(idx),
|
324 |
+
self.taste_rep_valid[idx]]
|
src/cocktails/representation_learning/multihead_model.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch; torch.manual_seed(0)
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.utils
|
5 |
+
import torch.distributions
|
6 |
+
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
+
from src.cocktails.representation_learning.simple_model import SimpleNet
|
8 |
+
|
9 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
10 |
+
|
11 |
+
def get_activation(activation):
|
12 |
+
if activation == 'tanh':
|
13 |
+
activ = F.tanh
|
14 |
+
elif activation == 'relu':
|
15 |
+
activ = F.relu
|
16 |
+
elif activation == 'mish':
|
17 |
+
activ = F.mish
|
18 |
+
elif activation == 'sigmoid':
|
19 |
+
activ = F.sigmoid
|
20 |
+
elif activation == 'leakyrelu':
|
21 |
+
activ = F.leaky_relu
|
22 |
+
elif activation == 'exp':
|
23 |
+
activ = torch.exp
|
24 |
+
else:
|
25 |
+
raise ValueError
|
26 |
+
return activ
|
27 |
+
|
28 |
+
class IngredientEncoder(nn.Module):
|
29 |
+
def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
|
30 |
+
super(IngredientEncoder, self).__init__()
|
31 |
+
self.linears = nn.ModuleList()
|
32 |
+
self.dropouts = nn.ModuleList()
|
33 |
+
dims = [input_dim] + hidden_dims + [deepset_latent_dim]
|
34 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
35 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
36 |
+
self.dropouts.append(nn.Dropout(dropout))
|
37 |
+
self.activation = get_activation(activation)
|
38 |
+
self.n_layers = len(self.linears)
|
39 |
+
self.layer_range = range(self.n_layers)
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
43 |
+
x = layer(x)
|
44 |
+
if i_layer != self.n_layers - 1:
|
45 |
+
x = self.activation(dropout(x))
|
46 |
+
return x # do not use dropout on last layer?
|
47 |
+
|
48 |
+
class DeepsetCocktailEncoder(nn.Module):
|
49 |
+
def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
|
50 |
+
hidden_dims_cocktail, latent_dim, aggregation, dropout):
|
51 |
+
super(DeepsetCocktailEncoder, self).__init__()
|
52 |
+
self.input_dim = input_dim # dimension of ingredient representation + quantity
|
53 |
+
self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout) # encode each ingredient separately
|
54 |
+
self.deepset_latent_dim = deepset_latent_dim # dimension of the deepset aggregation
|
55 |
+
self.aggregation = aggregation
|
56 |
+
self.latent_dim = latent_dim
|
57 |
+
# post aggregation network
|
58 |
+
self.linears = nn.ModuleList()
|
59 |
+
self.dropouts = nn.ModuleList()
|
60 |
+
dims = [deepset_latent_dim] + hidden_dims_cocktail
|
61 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
62 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
63 |
+
self.dropouts.append(nn.Dropout(dropout))
|
64 |
+
self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
|
65 |
+
self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
|
66 |
+
self.softplus = nn.Softplus()
|
67 |
+
|
68 |
+
self.activation = get_activation(activation)
|
69 |
+
self.n_layers = len(self.linears)
|
70 |
+
self.layer_range = range(self.n_layers)
|
71 |
+
|
72 |
+
def forward(self, nb_ingredients, x):
|
73 |
+
|
74 |
+
# reshape x in (batch size * nb ingredients, dim_ing_rep)
|
75 |
+
batch_size = x.shape[0]
|
76 |
+
all_ingredients = []
|
77 |
+
for i in range(batch_size):
|
78 |
+
for j in range(nb_ingredients[i]):
|
79 |
+
all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1))
|
80 |
+
x = torch.cat(all_ingredients, dim=0)
|
81 |
+
# encode ingredients in parallel
|
82 |
+
ingredients_encodings = self.ingredient_encoder(x)
|
83 |
+
assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim)
|
84 |
+
|
85 |
+
# aggregate
|
86 |
+
x = []
|
87 |
+
index_first = 0
|
88 |
+
for i in range(batch_size):
|
89 |
+
index_last = index_first + nb_ingredients[i]
|
90 |
+
# aggregate
|
91 |
+
if self.aggregation == 'sum':
|
92 |
+
x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
|
93 |
+
elif self.aggregation == 'mean':
|
94 |
+
x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
|
95 |
+
else:
|
96 |
+
raise ValueError
|
97 |
+
index_first = index_last
|
98 |
+
x = torch.cat(x, dim=0)
|
99 |
+
assert x.shape[0] == batch_size
|
100 |
+
|
101 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
102 |
+
x = self.activation(dropout(layer(x)))
|
103 |
+
mean = self.FC_mean(x)
|
104 |
+
logvar = self.FC_logvar(x)
|
105 |
+
return mean, logvar
|
106 |
+
|
107 |
+
|
108 |
+
class MultiHeadModel(nn.Module):
|
109 |
+
def __init__(self, encoder, auxiliaries_dict, activation, hidden_dims_decoder):
|
110 |
+
super(MultiHeadModel, self).__init__()
|
111 |
+
self.encoder = encoder
|
112 |
+
self.latent_dim = self.encoder.output_dim
|
113 |
+
self.auxiliaries_str = []
|
114 |
+
self.auxiliaries = nn.ModuleList()
|
115 |
+
for aux_str in sorted(auxiliaries_dict.keys()):
|
116 |
+
if aux_str == 'taste_reps':
|
117 |
+
self.taste_reps_decoder = SimpleNet(input_dim=self.latent_dim, hidden_dims=[], output_dim=auxiliaries_dict[aux_str]['dim_output'],
|
118 |
+
activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ'])
|
119 |
+
else:
|
120 |
+
self.auxiliaries_str.append(aux_str)
|
121 |
+
if aux_str == 'ingredients_quantities':
|
122 |
+
hd = hidden_dims_decoder
|
123 |
+
else:
|
124 |
+
hd = []
|
125 |
+
self.auxiliaries.append(SimpleNet(input_dim=self.latent_dim, hidden_dims=hd, output_dim=auxiliaries_dict[aux_str]['dim_output'],
|
126 |
+
activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ']))
|
127 |
+
|
128 |
+
def get_all_auxiliaries(self, x):
|
129 |
+
return [aux(x) for aux in self.auxiliaries]
|
130 |
+
|
131 |
+
def get_auxiliary(self, z, aux_str):
|
132 |
+
if aux_str == 'taste_reps':
|
133 |
+
return self.taste_reps_decoder(z)
|
134 |
+
else:
|
135 |
+
index = self.auxiliaries_str.index(aux_str)
|
136 |
+
return self.auxiliaries[index](z)
|
137 |
+
|
138 |
+
def forward(self, x, aux_str=None):
|
139 |
+
z = self.encoder(x)
|
140 |
+
if aux_str is not None:
|
141 |
+
return z, self.get_auxiliary(z, aux_str), [aux_str]
|
142 |
+
else:
|
143 |
+
return z, self.get_all_auxiliaries(z), self.auxiliaries_str
|
144 |
+
|
145 |
+
def get_multihead_model(input_dim, activation, hidden_dims_cocktail, latent_dim, dropout, auxiliaries_dict, hidden_dims_decoder):
|
146 |
+
encoder = SimpleNet(input_dim, hidden_dims_cocktail, latent_dim, activation, dropout)
|
147 |
+
model = MultiHeadModel(encoder, auxiliaries_dict, activation, hidden_dims_decoder)
|
148 |
+
return model
|
src/cocktails/representation_learning/run.py
ADDED
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch; torch.manual_seed(0)
|
2 |
+
import torch.utils
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import torch.distributions
|
5 |
+
import torch.nn as nn
|
6 |
+
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
+
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
|
8 |
+
import json
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from src.cocktails.representation_learning.vae_model import get_vae_model
|
13 |
+
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
|
14 |
+
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
|
15 |
+
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
|
16 |
+
from resource import getrusage
|
17 |
+
from resource import RUSAGE_SELF
|
18 |
+
import gc
|
19 |
+
gc.collect(2)
|
20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
|
22 |
+
def get_params():
|
23 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
24 |
+
max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
|
25 |
+
num_ingredients = len(ingredient_set)
|
26 |
+
rep_keys = get_bunch_of_rep_keys()['custom']
|
27 |
+
ing_keys = [k.split(' ')[1] for k in rep_keys]
|
28 |
+
ing_keys.remove('volume')
|
29 |
+
nb_ing_categories = len(set(ingredient_profiles['type']))
|
30 |
+
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
31 |
+
|
32 |
+
params = dict(trial_id='test',
|
33 |
+
save_path=EXPERIMENT_PATH + "/deepset_vae/",
|
34 |
+
nb_epochs=2000,
|
35 |
+
print_every=50,
|
36 |
+
plot_every=100,
|
37 |
+
batch_size=64,
|
38 |
+
lr=0.001,
|
39 |
+
dropout=0.,
|
40 |
+
nb_epoch_switch_beta=600,
|
41 |
+
latent_dim=10,
|
42 |
+
beta_vae=0.2,
|
43 |
+
ing_keys=ing_keys,
|
44 |
+
nb_ingredients=len(ingredient_set),
|
45 |
+
hidden_dims_ingredients=[128],
|
46 |
+
hidden_dims_cocktail=[32],
|
47 |
+
hidden_dims_decoder=[32],
|
48 |
+
agg='mean',
|
49 |
+
activation='relu',
|
50 |
+
auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
|
51 |
+
glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
|
52 |
+
prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
|
53 |
+
cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
|
54 |
+
volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1),
|
55 |
+
taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
|
56 |
+
ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
|
57 |
+
category_encodings=category_encodings
|
58 |
+
)
|
59 |
+
# params = dict(trial_id='test',
|
60 |
+
# save_path=EXPERIMENT_PATH + "/deepset_vae/",
|
61 |
+
# nb_epochs=1000,
|
62 |
+
# print_every=50,
|
63 |
+
# plot_every=100,
|
64 |
+
# batch_size=64,
|
65 |
+
# lr=0.001,
|
66 |
+
# dropout=0.,
|
67 |
+
# nb_epoch_switch_beta=500,
|
68 |
+
# latent_dim=64,
|
69 |
+
# beta_vae=0.3,
|
70 |
+
# ing_keys=ing_keys,
|
71 |
+
# nb_ingredients=len(ingredient_set),
|
72 |
+
# hidden_dims_ingredients=[128],
|
73 |
+
# hidden_dims_cocktail=[128, 128],
|
74 |
+
# hidden_dims_decoder=[128, 128],
|
75 |
+
# agg='mean',
|
76 |
+
# activation='mish',
|
77 |
+
# auxiliaries_dict=dict(categories=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
|
78 |
+
# glasses=dict(weight=0.03, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
|
79 |
+
# prep_type=dict(weight=0.02, type='classif', final_activ=None, dim_output=len(set(data['category']))),
|
80 |
+
# cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),
|
81 |
+
# volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),
|
82 |
+
# taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),
|
83 |
+
# ingredients_presence=dict(weight=1.5, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
|
84 |
+
# category_encodings=category_encodings
|
85 |
+
# )
|
86 |
+
water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
|
87 |
+
max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
|
88 |
+
params=params)
|
89 |
+
dim_rep_ingredient = water_rep.size
|
90 |
+
params['indexes_ing_to_normalize'] = indexes_to_normalize
|
91 |
+
params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
|
92 |
+
params['input_dim'] = dim_rep_ingredient
|
93 |
+
params['dim_rep_ingredient'] = dim_rep_ingredient
|
94 |
+
params = compute_expe_name_and_save_path(params)
|
95 |
+
del params['category_encodings'] # to dump
|
96 |
+
with open(params['save_path'] + 'params.json', 'w') as f:
|
97 |
+
json.dump(params, f)
|
98 |
+
|
99 |
+
params = complete_params(params)
|
100 |
+
return params
|
101 |
+
|
102 |
+
def complete_params(params):
|
103 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
104 |
+
cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
|
105 |
+
nb_ing_categories = len(set(ingredient_profiles['type']))
|
106 |
+
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
107 |
+
params['cocktail_reps'] = cocktail_reps
|
108 |
+
params['raw_data'] = data
|
109 |
+
params['category_encodings'] = category_encodings
|
110 |
+
return params
|
111 |
+
|
112 |
+
def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
|
113 |
+
losses = dict()
|
114 |
+
accuracies = dict()
|
115 |
+
other_metrics = dict()
|
116 |
+
for i_k, k in enumerate(auxiliaries_str):
|
117 |
+
# get ground truth
|
118 |
+
# compute loss
|
119 |
+
if k == 'volume':
|
120 |
+
outputs[i_k] = outputs[i_k].flatten()
|
121 |
+
ground_truth = auxiliaries[k]
|
122 |
+
if ground_truth.dtype == torch.float64:
|
123 |
+
losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float()
|
124 |
+
elif ground_truth.dtype == torch.int64:
|
125 |
+
if str(loss_functions[k]) != "BCEWithLogitsLoss()":
|
126 |
+
losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float()
|
127 |
+
else:
|
128 |
+
losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float()
|
129 |
+
else:
|
130 |
+
losses[k] = loss_functions[k](outputs[i_k], ground_truth).float()
|
131 |
+
# compute accuracies
|
132 |
+
if str(loss_functions[k]) == 'CrossEntropyLoss()':
|
133 |
+
bs, n_options = outputs[i_k].shape
|
134 |
+
predicted = outputs[i_k].argmax(dim=1).detach().numpy()
|
135 |
+
true = ground_truth.int().detach().numpy()
|
136 |
+
confusion_matrix = np.zeros([n_options, n_options])
|
137 |
+
for i in range(bs):
|
138 |
+
confusion_matrix[true[i], predicted[i]] += 1
|
139 |
+
acc = confusion_matrix.diagonal().sum() / bs
|
140 |
+
for i in range(n_options):
|
141 |
+
if confusion_matrix[i].sum() != 0:
|
142 |
+
confusion_matrix[i] /= confusion_matrix[i].sum()
|
143 |
+
other_metrics[k + '_confusion'] = confusion_matrix
|
144 |
+
accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy())
|
145 |
+
assert (acc - accuracies[k]) < 1e-5
|
146 |
+
|
147 |
+
elif str(loss_functions[k]) == 'BCEWithLogitsLoss()':
|
148 |
+
assert k == 'ingredients_presence'
|
149 |
+
outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
|
150 |
+
predicted_presence = (outputs_rescaled > 0).astype(bool)
|
151 |
+
presence = ground_truth.detach().numpy().astype(bool)
|
152 |
+
other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool)))
|
153 |
+
other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool)))
|
154 |
+
accuracies[k] = np.mean(predicted_presence == presence) # accuracy for multi class labeling
|
155 |
+
elif str(loss_functions[k]) == 'MSELoss()':
|
156 |
+
accuracies[k] = np.nan
|
157 |
+
else:
|
158 |
+
raise ValueError
|
159 |
+
return losses, accuracies, other_metrics
|
160 |
+
|
161 |
+
def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
|
162 |
+
ing_q = ingredient_quantities.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
|
163 |
+
ing_presence = (ing_q > 0)
|
164 |
+
x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
|
165 |
+
# abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
|
166 |
+
abs_diff = np.abs(ing_q - x_hat)
|
167 |
+
ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
|
168 |
+
for i in range(ingredient_quantities.shape[0]):
|
169 |
+
ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])]))
|
170 |
+
ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])]))
|
171 |
+
aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
|
172 |
+
aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
|
173 |
+
return aux_other_metrics
|
174 |
+
|
175 |
+
def run_epoch(opt, train, model, data, loss_functions, weights, params):
|
176 |
+
if train:
|
177 |
+
model.train()
|
178 |
+
else:
|
179 |
+
model.eval()
|
180 |
+
|
181 |
+
# prepare logging of losses
|
182 |
+
losses = dict(kld_loss=[],
|
183 |
+
mse_loss=[],
|
184 |
+
vae_loss=[],
|
185 |
+
volume_loss=[],
|
186 |
+
global_loss=[])
|
187 |
+
accuracies = dict()
|
188 |
+
other_metrics = dict()
|
189 |
+
for aux in params['auxiliaries_dict'].keys():
|
190 |
+
losses[aux] = []
|
191 |
+
accuracies[aux] = []
|
192 |
+
if train: opt.zero_grad()
|
193 |
+
|
194 |
+
for d in data:
|
195 |
+
nb_ingredients = d[0]
|
196 |
+
batch_size = nb_ingredients.shape[0]
|
197 |
+
x_ingredients = d[1].float()
|
198 |
+
ingredient_quantities = d[2]
|
199 |
+
cocktail_reps = d[3]
|
200 |
+
auxiliaries = d[4]
|
201 |
+
for k in auxiliaries.keys():
|
202 |
+
if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
|
203 |
+
taste_valid = d[-1]
|
204 |
+
x = x_ingredients.to(device)
|
205 |
+
x_hat, z, mean, log_var, outputs, auxiliaries_str = model.forward_direct(ingredient_quantities.float())
|
206 |
+
# get auxiliary losses and accuracies
|
207 |
+
aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data)
|
208 |
+
|
209 |
+
# compute vae loss
|
210 |
+
mse_loss = ((ingredient_quantities - x_hat) ** 2).mean().float()
|
211 |
+
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)).float()
|
212 |
+
vae_loss = mse_loss + params['beta_vae'] * (params['latent_dim'] / params['nb_ingredients']) * kld_loss
|
213 |
+
# compute total volume loss to train decoder
|
214 |
+
# volume_loss = ((ingredient_quantities.sum(dim=1) - x_hat.sum(dim=1)) ** 2).mean().float()
|
215 |
+
volume_loss = torch.FloatTensor([0])
|
216 |
+
|
217 |
+
aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat)
|
218 |
+
|
219 |
+
indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
|
220 |
+
if indexes_taste_valid.size > 0:
|
221 |
+
outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
|
222 |
+
gt = auxiliaries['taste_reps'][indexes_taste_valid]
|
223 |
+
factor_loss = indexes_taste_valid.size / (0.3 * batch_size)# factor on the loss: if same ratio as actual dataset factor = 1 if there is less data, then the factor decreases, more data, it increases
|
224 |
+
aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
|
225 |
+
else:
|
226 |
+
aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
|
227 |
+
aux_accuracies['taste_reps'] = 0
|
228 |
+
|
229 |
+
# aggregate losses
|
230 |
+
global_loss = torch.sum(torch.cat([torch.atleast_1d(vae_loss), torch.atleast_1d(volume_loss)] + [torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()]))
|
231 |
+
# for k in params['auxiliaries_dict'].keys():
|
232 |
+
# global_loss += aux_losses[k] * weights[k]
|
233 |
+
|
234 |
+
if train:
|
235 |
+
global_loss.backward()
|
236 |
+
opt.step()
|
237 |
+
opt.zero_grad()
|
238 |
+
|
239 |
+
# logging
|
240 |
+
losses['global_loss'].append(float(global_loss))
|
241 |
+
losses['mse_loss'].append(float(mse_loss))
|
242 |
+
losses['vae_loss'].append(float(vae_loss))
|
243 |
+
losses['volume_loss'].append(float(volume_loss))
|
244 |
+
losses['kld_loss'].append(float(kld_loss))
|
245 |
+
for k in params['auxiliaries_dict'].keys():
|
246 |
+
losses[k].append(float(aux_losses[k]))
|
247 |
+
accuracies[k].append(float(aux_accuracies[k]))
|
248 |
+
for k in aux_other_metrics.keys():
|
249 |
+
if k not in other_metrics.keys():
|
250 |
+
other_metrics[k] = [aux_other_metrics[k]]
|
251 |
+
else:
|
252 |
+
other_metrics[k].append(aux_other_metrics[k])
|
253 |
+
|
254 |
+
for k in losses.keys():
|
255 |
+
losses[k] = np.mean(losses[k])
|
256 |
+
for k in accuracies.keys():
|
257 |
+
accuracies[k] = np.mean(accuracies[k])
|
258 |
+
for k in other_metrics.keys():
|
259 |
+
other_metrics[k] = np.mean(other_metrics[k], axis=0)
|
260 |
+
return model, losses, accuracies, other_metrics
|
261 |
+
|
262 |
+
def prepare_data_and_loss(params):
|
263 |
+
train_data = MyDataset(split='train', params=params)
|
264 |
+
test_data = MyDataset(split='test', params=params)
|
265 |
+
|
266 |
+
train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
|
267 |
+
test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
|
268 |
+
|
269 |
+
loss_functions = dict()
|
270 |
+
weights = dict()
|
271 |
+
for k in sorted(params['auxiliaries_dict'].keys()):
|
272 |
+
if params['auxiliaries_dict'][k]['type'] == 'classif':
|
273 |
+
if k == 'glasses':
|
274 |
+
classif_weights = train_data.glasses_weights
|
275 |
+
elif k == 'prep_type':
|
276 |
+
classif_weights = train_data.prep_types_weights
|
277 |
+
elif k == 'categories':
|
278 |
+
classif_weights = train_data.categories_weights
|
279 |
+
else:
|
280 |
+
raise ValueError
|
281 |
+
loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
|
282 |
+
elif params['auxiliaries_dict'][k]['type'] == 'multiclassif':
|
283 |
+
loss_functions[k] = nn.BCEWithLogitsLoss()
|
284 |
+
elif params['auxiliaries_dict'][k]['type'] == 'regression':
|
285 |
+
loss_functions[k] = nn.MSELoss()
|
286 |
+
else:
|
287 |
+
raise ValueError
|
288 |
+
weights[k] = params['auxiliaries_dict'][k]['weight']
|
289 |
+
|
290 |
+
|
291 |
+
return loss_functions, train_data_loader, test_data_loader, weights
|
292 |
+
|
293 |
+
def print_losses(train, losses, accuracies, other_metrics):
|
294 |
+
keyword = 'Train' if train else 'Eval'
|
295 |
+
print(f'\t{keyword} logs:')
|
296 |
+
keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
|
297 |
+
for k in keys:
|
298 |
+
print(f'\t\t{k} - Loss: {losses[k]:.2f}')
|
299 |
+
for k in sorted(accuracies.keys()):
|
300 |
+
print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
|
301 |
+
for k in sorted(other_metrics.keys()):
|
302 |
+
if 'confusion' not in k:
|
303 |
+
print(f'\t\t{k} - {other_metrics[k]:.2f}')
|
304 |
+
|
305 |
+
|
306 |
+
def run_experiment(params, verbose=True):
|
307 |
+
loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)
|
308 |
+
params['filter_decoder_output'] = train_data_loader.dataset.filter_decoder_output
|
309 |
+
|
310 |
+
model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
|
311 |
+
"hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
|
312 |
+
"filter_decoder_output"]]
|
313 |
+
model = get_vae_model(*model_params)
|
314 |
+
opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
|
315 |
+
|
316 |
+
|
317 |
+
all_train_losses = []
|
318 |
+
all_eval_losses = []
|
319 |
+
all_train_accuracies = []
|
320 |
+
all_eval_accuracies = []
|
321 |
+
all_eval_other_metrics = []
|
322 |
+
all_train_other_metrics = []
|
323 |
+
best_loss = np.inf
|
324 |
+
model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
|
325 |
+
weights=weights, params=params)
|
326 |
+
all_eval_losses.append(eval_losses)
|
327 |
+
all_eval_accuracies.append(eval_accuracies)
|
328 |
+
all_eval_other_metrics.append(eval_other_metrics)
|
329 |
+
if verbose: print(f'\n--------\nEpoch #0')
|
330 |
+
if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
|
331 |
+
for epoch in range(params['nb_epochs']):
|
332 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
|
333 |
+
model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions,
|
334 |
+
weights=weights, params=params)
|
335 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
|
336 |
+
model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
|
337 |
+
weights=weights, params=params)
|
338 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
|
339 |
+
if eval_losses['global_loss'] < best_loss:
|
340 |
+
best_loss = eval_losses['global_loss']
|
341 |
+
if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
|
342 |
+
torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
|
343 |
+
|
344 |
+
# log
|
345 |
+
all_train_losses.append(train_losses)
|
346 |
+
all_train_accuracies.append(train_accuracies)
|
347 |
+
all_eval_losses.append(eval_losses)
|
348 |
+
all_eval_accuracies.append(eval_accuracies)
|
349 |
+
all_eval_other_metrics.append(eval_other_metrics)
|
350 |
+
all_train_other_metrics.append(train_other_metrics)
|
351 |
+
|
352 |
+
# if epoch == params['nb_epoch_switch_beta']:
|
353 |
+
# params['beta_vae'] = 2.5
|
354 |
+
# params['auxiliaries_dict']['prep_type']['weight'] /= 10
|
355 |
+
# params['auxiliaries_dict']['glasses']['weight'] /= 10
|
356 |
+
|
357 |
+
if (epoch + 1) % params['plot_every'] == 0:
|
358 |
+
|
359 |
+
plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
|
360 |
+
all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights)
|
361 |
+
|
362 |
+
return model
|
363 |
+
|
364 |
+
def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
|
365 |
+
all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
|
366 |
+
|
367 |
+
steps = np.arange(len(all_eval_accuracies))
|
368 |
+
|
369 |
+
loss_keys = sorted(all_train_losses[0].keys())
|
370 |
+
acc_keys = sorted(all_train_accuracies[0].keys())
|
371 |
+
metrics_keys = sorted(all_train_other_metrics[0].keys())
|
372 |
+
|
373 |
+
plt.figure()
|
374 |
+
plt.title('Train losses')
|
375 |
+
for k in loss_keys:
|
376 |
+
factor = 1 if k == 'mse_loss' else 1
|
377 |
+
if k not in weights.keys():
|
378 |
+
plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
|
379 |
+
else:
|
380 |
+
if weights[k] != 0:
|
381 |
+
plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
|
382 |
+
|
383 |
+
plt.legend()
|
384 |
+
plt.ylim([0, 4])
|
385 |
+
plt.savefig(plot_path + 'train_losses.png', dpi=200)
|
386 |
+
fig = plt.gcf()
|
387 |
+
plt.close(fig)
|
388 |
+
|
389 |
+
plt.figure()
|
390 |
+
plt.title('Train accuracies')
|
391 |
+
for k in acc_keys:
|
392 |
+
if weights[k] != 0:
|
393 |
+
plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k)
|
394 |
+
plt.legend()
|
395 |
+
plt.ylim([0, 1])
|
396 |
+
plt.savefig(plot_path + 'train_acc.png', dpi=200)
|
397 |
+
fig = plt.gcf()
|
398 |
+
plt.close(fig)
|
399 |
+
|
400 |
+
plt.figure()
|
401 |
+
plt.title('Train other metrics')
|
402 |
+
for k in metrics_keys:
|
403 |
+
if 'confusion' not in k and 'presence' in k:
|
404 |
+
plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
|
405 |
+
plt.legend()
|
406 |
+
plt.ylim([0, 1])
|
407 |
+
plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200)
|
408 |
+
fig = plt.gcf()
|
409 |
+
plt.close(fig)
|
410 |
+
|
411 |
+
plt.figure()
|
412 |
+
plt.title('Train other metrics')
|
413 |
+
for k in metrics_keys:
|
414 |
+
if 'confusion' not in k and 'presence' not in k:
|
415 |
+
plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
|
416 |
+
plt.legend()
|
417 |
+
plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200)
|
418 |
+
fig = plt.gcf()
|
419 |
+
plt.close(fig)
|
420 |
+
|
421 |
+
plt.figure()
|
422 |
+
plt.title('Eval losses')
|
423 |
+
for k in loss_keys:
|
424 |
+
factor = 1 if k == 'mse_loss' else 1
|
425 |
+
if k not in weights.keys():
|
426 |
+
plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
|
427 |
+
else:
|
428 |
+
if weights[k] != 0:
|
429 |
+
plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
|
430 |
+
plt.legend()
|
431 |
+
plt.ylim([0, 4])
|
432 |
+
plt.savefig(plot_path + 'eval_losses.png', dpi=200)
|
433 |
+
fig = plt.gcf()
|
434 |
+
plt.close(fig)
|
435 |
+
|
436 |
+
plt.figure()
|
437 |
+
plt.title('Eval accuracies')
|
438 |
+
for k in acc_keys:
|
439 |
+
if weights[k] != 0:
|
440 |
+
plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k)
|
441 |
+
plt.legend()
|
442 |
+
plt.ylim([0, 1])
|
443 |
+
plt.savefig(plot_path + 'eval_acc.png', dpi=200)
|
444 |
+
fig = plt.gcf()
|
445 |
+
plt.close(fig)
|
446 |
+
|
447 |
+
plt.figure()
|
448 |
+
plt.title('Eval other metrics')
|
449 |
+
for k in metrics_keys:
|
450 |
+
if 'confusion' not in k and 'presence' in k:
|
451 |
+
plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
|
452 |
+
plt.legend()
|
453 |
+
plt.ylim([0, 1])
|
454 |
+
plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200)
|
455 |
+
fig = plt.gcf()
|
456 |
+
plt.close(fig)
|
457 |
+
|
458 |
+
plt.figure()
|
459 |
+
plt.title('Eval other metrics')
|
460 |
+
for k in metrics_keys:
|
461 |
+
if 'confusion' not in k and 'presence' not in k:
|
462 |
+
plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
|
463 |
+
plt.legend()
|
464 |
+
plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200)
|
465 |
+
fig = plt.gcf()
|
466 |
+
plt.close(fig)
|
467 |
+
|
468 |
+
|
469 |
+
for k in metrics_keys:
|
470 |
+
if 'confusion' in k:
|
471 |
+
plt.figure()
|
472 |
+
plt.title(k)
|
473 |
+
plt.ylabel('True')
|
474 |
+
plt.xlabel('Predicted')
|
475 |
+
plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1)
|
476 |
+
plt.colorbar()
|
477 |
+
plt.savefig(plot_path + f'eval_{k}.png', dpi=200)
|
478 |
+
fig = plt.gcf()
|
479 |
+
plt.close(fig)
|
480 |
+
|
481 |
+
for k in metrics_keys:
|
482 |
+
if 'confusion' in k:
|
483 |
+
plt.figure()
|
484 |
+
plt.title(k)
|
485 |
+
plt.ylabel('True')
|
486 |
+
plt.xlabel('Predicted')
|
487 |
+
plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1)
|
488 |
+
plt.colorbar()
|
489 |
+
plt.savefig(plot_path + f'train_{k}.png', dpi=200)
|
490 |
+
fig = plt.gcf()
|
491 |
+
plt.close(fig)
|
492 |
+
|
493 |
+
plt.close('all')
|
494 |
+
|
495 |
+
|
496 |
+
def get_model(model_path):
|
497 |
+
|
498 |
+
with open(model_path + 'params.json', 'r') as f:
|
499 |
+
params = json.load(f)
|
500 |
+
params['save_path'] = model_path
|
501 |
+
max_ing_quantities = np.loadtxt(params['save_path'] + 'max_ing_quantities.txt')
|
502 |
+
mean_ing_quantities = np.loadtxt(params['save_path'] + 'mean_ing_quantities.txt')
|
503 |
+
std_ing_quantities = np.loadtxt(params['save_path'] + 'std_ing_quantities.txt')
|
504 |
+
min_when_present_ing_quantities = np.loadtxt(params['save_path'] + 'min_when_present_ing_quantities.txt')
|
505 |
+
def filter_decoder_output(output):
|
506 |
+
output = output.detach().numpy()
|
507 |
+
output_unnormalized = output * std_ing_quantities + mean_ing_quantities
|
508 |
+
if output.ndim == 1:
|
509 |
+
output_unnormalized[np.where(output_unnormalized < min_when_present_ing_quantities)] = 0
|
510 |
+
else:
|
511 |
+
for i in range(output.shape[0]):
|
512 |
+
output_unnormalized[i, np.where(output_unnormalized[i] < min_when_present_ing_quantities)] = 0
|
513 |
+
return output_unnormalized.copy()
|
514 |
+
params['filter_decoder_output'] = filter_decoder_output
|
515 |
+
model_chkpt = model_path + "checkpoint_best.save"
|
516 |
+
model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
|
517 |
+
"hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
|
518 |
+
"filter_decoder_output"]]
|
519 |
+
model = get_vae_model(*model_params)
|
520 |
+
model.load_state_dict(torch.load(model_chkpt))
|
521 |
+
model.eval()
|
522 |
+
return model, filter_decoder_output, params
|
523 |
+
|
524 |
+
|
525 |
+
def compute_expe_name_and_save_path(params):
|
526 |
+
weights_str = '['
|
527 |
+
for aux in params['auxiliaries_dict'].keys():
|
528 |
+
weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
|
529 |
+
weights_str = weights_str[:-2] + ']'
|
530 |
+
save_path = params['save_path'] + params["trial_id"]
|
531 |
+
save_path += f'_lr{params["lr"]}'
|
532 |
+
save_path += f'_betavae{params["beta_vae"]}'
|
533 |
+
save_path += f'_bs{params["batch_size"]}'
|
534 |
+
save_path += f'_latentdim{params["latent_dim"]}'
|
535 |
+
save_path += f'_hding{params["hidden_dims_ingredients"]}'
|
536 |
+
save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
|
537 |
+
save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
|
538 |
+
save_path += f'_agg{params["agg"]}'
|
539 |
+
save_path += f'_activ{params["activation"]}'
|
540 |
+
save_path += f'_w{weights_str}'
|
541 |
+
counter = 0
|
542 |
+
while os.path.exists(save_path + f"_{counter}"):
|
543 |
+
counter += 1
|
544 |
+
save_path = save_path + f"_{counter}" + '/'
|
545 |
+
params["save_path"] = save_path
|
546 |
+
os.makedirs(save_path)
|
547 |
+
os.makedirs(save_path + 'plots/')
|
548 |
+
params['plot_path'] = save_path + 'plots/'
|
549 |
+
print(f'logging to {save_path}')
|
550 |
+
return params
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
if __name__ == '__main__':
|
555 |
+
params = get_params()
|
556 |
+
run_experiment(params)
|
557 |
+
|
src/cocktails/representation_learning/run_simple_net.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch; torch.manual_seed(0)
|
2 |
+
import torch.utils
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import torch.distributions
|
5 |
+
import torch.nn as nn
|
6 |
+
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
+
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
|
8 |
+
import json
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from src.cocktails.representation_learning.simple_model import SimpleNet
|
13 |
+
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
|
14 |
+
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
|
15 |
+
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
|
16 |
+
from resource import getrusage
|
17 |
+
from resource import RUSAGE_SELF
|
18 |
+
import gc
|
19 |
+
gc.collect(2)
|
20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
|
22 |
+
def get_params():
|
23 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
24 |
+
max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
|
25 |
+
num_ingredients = len(ingredient_set)
|
26 |
+
rep_keys = get_bunch_of_rep_keys()['custom']
|
27 |
+
ing_keys = [k.split(' ')[1] for k in rep_keys]
|
28 |
+
ing_keys.remove('volume')
|
29 |
+
nb_ing_categories = len(set(ingredient_profiles['type']))
|
30 |
+
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
31 |
+
|
32 |
+
params = dict(trial_id='test',
|
33 |
+
save_path=EXPERIMENT_PATH + "/simple_net/",
|
34 |
+
nb_epochs=100,
|
35 |
+
print_every=50,
|
36 |
+
plot_every=50,
|
37 |
+
batch_size=128,
|
38 |
+
lr=0.001,
|
39 |
+
dropout=0.15,
|
40 |
+
output_keyword='glasses',
|
41 |
+
ing_keys=ing_keys,
|
42 |
+
nb_ingredients=len(ingredient_set),
|
43 |
+
hidden_dims=[16],
|
44 |
+
activation='sigmoid',
|
45 |
+
auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
|
46 |
+
glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
|
47 |
+
prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
|
48 |
+
cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
|
49 |
+
volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1),
|
50 |
+
taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
|
51 |
+
ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),
|
52 |
+
ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)),
|
53 |
+
|
54 |
+
category_encodings=category_encodings
|
55 |
+
)
|
56 |
+
params['output_dim'] = params['auxiliaries_dict'][params['output_keyword']]['dim_output']
|
57 |
+
water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
|
58 |
+
max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
|
59 |
+
params=params)
|
60 |
+
dim_rep_ingredient = water_rep.size
|
61 |
+
params['indexes_ing_to_normalize'] = indexes_to_normalize
|
62 |
+
params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
|
63 |
+
params['dim_rep_ingredient'] = dim_rep_ingredient
|
64 |
+
params['input_dim'] = params['nb_ingredients']
|
65 |
+
params = compute_expe_name_and_save_path(params)
|
66 |
+
del params['category_encodings'] # to dump
|
67 |
+
with open(params['save_path'] + 'params.json', 'w') as f:
|
68 |
+
json.dump(params, f)
|
69 |
+
|
70 |
+
params = complete_params(params)
|
71 |
+
return params
|
72 |
+
|
73 |
+
def complete_params(params):
|
74 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
75 |
+
cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
|
76 |
+
nb_ing_categories = len(set(ingredient_profiles['type']))
|
77 |
+
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
78 |
+
params['cocktail_reps'] = cocktail_reps
|
79 |
+
params['raw_data'] = data
|
80 |
+
params['category_encodings'] = category_encodings
|
81 |
+
return params
|
82 |
+
|
83 |
+
def compute_confusion_matrix_and_accuracy(predictions, ground_truth):
|
84 |
+
bs, n_options = predictions.shape
|
85 |
+
predicted = predictions.argmax(dim=1).detach().numpy()
|
86 |
+
true = ground_truth.int().detach().numpy()
|
87 |
+
confusion_matrix = np.zeros([n_options, n_options])
|
88 |
+
for i in range(bs):
|
89 |
+
confusion_matrix[true[i], predicted[i]] += 1
|
90 |
+
acc = confusion_matrix.diagonal().sum() / bs
|
91 |
+
for i in range(n_options):
|
92 |
+
if confusion_matrix[i].sum() != 0:
|
93 |
+
confusion_matrix[i] /= confusion_matrix[i].sum()
|
94 |
+
acc2 = np.mean(predicted == true)
|
95 |
+
assert (acc - acc2) < 1e-5
|
96 |
+
return confusion_matrix, acc
|
97 |
+
|
98 |
+
|
99 |
+
def run_epoch(opt, train, model, data, loss_function, params):
|
100 |
+
if train:
|
101 |
+
model.train()
|
102 |
+
else:
|
103 |
+
model.eval()
|
104 |
+
|
105 |
+
# prepare logging of losses
|
106 |
+
losses = []
|
107 |
+
accuracies = []
|
108 |
+
cf_matrices = []
|
109 |
+
if train: opt.zero_grad()
|
110 |
+
|
111 |
+
for d in data:
|
112 |
+
nb_ingredients = d[0]
|
113 |
+
batch_size = nb_ingredients.shape[0]
|
114 |
+
x_ingredients = d[1].float()
|
115 |
+
ingredient_quantities = d[2].float()
|
116 |
+
cocktail_reps = d[3].float()
|
117 |
+
auxiliaries = d[4]
|
118 |
+
for k in auxiliaries.keys():
|
119 |
+
if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
|
120 |
+
taste_valid = d[-1]
|
121 |
+
predictions = model(ingredient_quantities)
|
122 |
+
loss = loss_function(predictions, auxiliaries[params['output_keyword']].long()).float()
|
123 |
+
cf_matrix, accuracy = compute_confusion_matrix_and_accuracy(predictions, auxiliaries[params['output_keyword']])
|
124 |
+
if train:
|
125 |
+
loss.backward()
|
126 |
+
opt.step()
|
127 |
+
opt.zero_grad()
|
128 |
+
|
129 |
+
losses.append(float(loss))
|
130 |
+
cf_matrices.append(cf_matrix)
|
131 |
+
accuracies.append(accuracy)
|
132 |
+
|
133 |
+
return model, np.mean(losses), np.mean(accuracies), np.mean(cf_matrices, axis=0)
|
134 |
+
|
135 |
+
def prepare_data_and_loss(params):
|
136 |
+
train_data = MyDataset(split='train', params=params)
|
137 |
+
test_data = MyDataset(split='test', params=params)
|
138 |
+
|
139 |
+
train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
|
140 |
+
test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
|
141 |
+
|
142 |
+
|
143 |
+
if params['auxiliaries_dict'][params['output_keyword']]['type'] == 'classif':
|
144 |
+
if params['output_keyword'] == 'glasses':
|
145 |
+
classif_weights = train_data.glasses_weights
|
146 |
+
elif params['output_keyword'] == 'prep_type':
|
147 |
+
classif_weights = train_data.prep_types_weights
|
148 |
+
elif params['output_keyword'] == 'categories':
|
149 |
+
classif_weights = train_data.categories_weights
|
150 |
+
else:
|
151 |
+
raise ValueError
|
152 |
+
# classif_weights = (np.array(classif_weights) * 2 + np.ones(len(classif_weights))) / 3
|
153 |
+
loss_function = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
|
154 |
+
# loss_function = nn.CrossEntropyLoss()
|
155 |
+
|
156 |
+
elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'multiclassif':
|
157 |
+
loss_function = nn.BCEWithLogitsLoss()
|
158 |
+
elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'regression':
|
159 |
+
loss_function = nn.MSELoss()
|
160 |
+
else:
|
161 |
+
raise ValueError
|
162 |
+
|
163 |
+
return loss_function, train_data_loader, test_data_loader
|
164 |
+
|
165 |
+
def print_losses(train, loss, accuracy):
|
166 |
+
keyword = 'Train' if train else 'Eval'
|
167 |
+
print(f'\t{keyword} logs:')
|
168 |
+
print(f'\t\t Loss: {loss:.2f}, Acc: {accuracy:.2f}')
|
169 |
+
|
170 |
+
|
171 |
+
def run_experiment(params, verbose=True):
|
172 |
+
loss_function, train_data_loader, test_data_loader = prepare_data_and_loss(params)
|
173 |
+
|
174 |
+
model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout'])
|
175 |
+
opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
|
176 |
+
|
177 |
+
all_train_losses = []
|
178 |
+
all_eval_losses = []
|
179 |
+
all_eval_cf_matrices = []
|
180 |
+
all_train_accuracies = []
|
181 |
+
all_eval_accuracies = []
|
182 |
+
all_train_cf_matrices = []
|
183 |
+
best_loss = np.inf
|
184 |
+
model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params)
|
185 |
+
all_eval_losses.append(eval_loss)
|
186 |
+
all_eval_accuracies.append(eval_accuracy)
|
187 |
+
if verbose: print(f'\n--------\nEpoch #0')
|
188 |
+
if verbose: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss)
|
189 |
+
for epoch in range(params['nb_epochs']):
|
190 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
|
191 |
+
model, train_loss, train_accuracy, train_cf_matrix = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_function=loss_function, params=params)
|
192 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracy=train_accuracy, loss=train_loss)
|
193 |
+
model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params)
|
194 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss)
|
195 |
+
if eval_loss < best_loss:
|
196 |
+
best_loss = eval_loss
|
197 |
+
if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
|
198 |
+
torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
|
199 |
+
|
200 |
+
# log
|
201 |
+
all_train_losses.append(train_loss)
|
202 |
+
all_train_accuracies.append(train_accuracy)
|
203 |
+
all_eval_losses.append(eval_loss)
|
204 |
+
all_eval_accuracies.append(eval_accuracy)
|
205 |
+
all_eval_cf_matrices.append(eval_cf_matrix)
|
206 |
+
all_train_cf_matrices.append(train_cf_matrix)
|
207 |
+
|
208 |
+
if (epoch + 1) % params['plot_every'] == 0:
|
209 |
+
|
210 |
+
plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices,
|
211 |
+
all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, params['plot_path'])
|
212 |
+
|
213 |
+
return model
|
214 |
+
|
215 |
+
def plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices,
|
216 |
+
all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, plot_path):
|
217 |
+
|
218 |
+
steps = np.arange(len(all_eval_accuracies))
|
219 |
+
|
220 |
+
plt.figure()
|
221 |
+
plt.title('Losses')
|
222 |
+
plt.plot(steps[1:], all_train_losses, label='train')
|
223 |
+
plt.plot(steps, all_eval_losses, label='eval')
|
224 |
+
plt.legend()
|
225 |
+
plt.ylim([0, 4])
|
226 |
+
plt.savefig(plot_path + 'losses.png', dpi=200)
|
227 |
+
fig = plt.gcf()
|
228 |
+
plt.close(fig)
|
229 |
+
|
230 |
+
plt.figure()
|
231 |
+
plt.title('Accuracies')
|
232 |
+
plt.plot(steps[1:], all_train_accuracies, label='train')
|
233 |
+
plt.plot(steps, all_eval_accuracies, label='eval')
|
234 |
+
plt.legend()
|
235 |
+
plt.ylim([0, 1])
|
236 |
+
plt.savefig(plot_path + 'accs.png', dpi=200)
|
237 |
+
fig = plt.gcf()
|
238 |
+
plt.close(fig)
|
239 |
+
|
240 |
+
|
241 |
+
plt.figure()
|
242 |
+
plt.title('Train confusion matrix')
|
243 |
+
plt.ylabel('True')
|
244 |
+
plt.xlabel('Predicted')
|
245 |
+
plt.imshow(all_train_cf_matrices[-1], vmin=0, vmax=1)
|
246 |
+
plt.colorbar()
|
247 |
+
plt.savefig(plot_path + f'train_confusion_matrix.png', dpi=200)
|
248 |
+
fig = plt.gcf()
|
249 |
+
plt.close(fig)
|
250 |
+
|
251 |
+
plt.figure()
|
252 |
+
plt.title('Eval confusion matrix')
|
253 |
+
plt.ylabel('True')
|
254 |
+
plt.xlabel('Predicted')
|
255 |
+
plt.imshow(all_eval_cf_matrices[-1], vmin=0, vmax=1)
|
256 |
+
plt.colorbar()
|
257 |
+
plt.savefig(plot_path + f'eval_confusion_matrix.png', dpi=200)
|
258 |
+
fig = plt.gcf()
|
259 |
+
plt.close(fig)
|
260 |
+
|
261 |
+
plt.close('all')
|
262 |
+
|
263 |
+
|
264 |
+
def get_model(model_path):
|
265 |
+
with open(model_path + 'params.json', 'r') as f:
|
266 |
+
params = json.load(f)
|
267 |
+
params['save_path'] = model_path
|
268 |
+
model_chkpt = model_path + "checkpoint_best.save"
|
269 |
+
model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout'])
|
270 |
+
model.load_state_dict(torch.load(model_chkpt))
|
271 |
+
model.eval()
|
272 |
+
return model, params
|
273 |
+
|
274 |
+
|
275 |
+
def compute_expe_name_and_save_path(params):
|
276 |
+
weights_str = '['
|
277 |
+
for aux in params['auxiliaries_dict'].keys():
|
278 |
+
weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
|
279 |
+
weights_str = weights_str[:-2] + ']'
|
280 |
+
save_path = params['save_path'] + params["trial_id"]
|
281 |
+
save_path += f'_lr{params["lr"]}'
|
282 |
+
save_path += f'_bs{params["batch_size"]}'
|
283 |
+
save_path += f'_hd{params["hidden_dims"]}'
|
284 |
+
save_path += f'_activ{params["activation"]}'
|
285 |
+
save_path += f'_w{weights_str}'
|
286 |
+
counter = 0
|
287 |
+
while os.path.exists(save_path + f"_{counter}"):
|
288 |
+
counter += 1
|
289 |
+
save_path = save_path + f"_{counter}" + '/'
|
290 |
+
params["save_path"] = save_path
|
291 |
+
os.makedirs(save_path)
|
292 |
+
os.makedirs(save_path + 'plots/')
|
293 |
+
params['plot_path'] = save_path + 'plots/'
|
294 |
+
print(f'logging to {save_path}')
|
295 |
+
return params
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
if __name__ == '__main__':
|
300 |
+
params = get_params()
|
301 |
+
run_experiment(params)
|
302 |
+
|
src/cocktails/representation_learning/run_without_vae.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch; torch.manual_seed(0)
|
2 |
+
import torch.utils
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import torch.distributions
|
5 |
+
import torch.nn as nn
|
6 |
+
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
+
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
|
8 |
+
import json
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from src.cocktails.representation_learning.multihead_model import get_multihead_model
|
13 |
+
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
|
14 |
+
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
|
15 |
+
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
|
16 |
+
from resource import getrusage
|
17 |
+
from resource import RUSAGE_SELF
|
18 |
+
import gc
|
19 |
+
gc.collect(2)
|
20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
|
22 |
+
def get_params():
|
23 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
24 |
+
max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
|
25 |
+
num_ingredients = len(ingredient_set)
|
26 |
+
rep_keys = get_bunch_of_rep_keys()['custom']
|
27 |
+
ing_keys = [k.split(' ')[1] for k in rep_keys]
|
28 |
+
ing_keys.remove('volume')
|
29 |
+
nb_ing_categories = len(set(ingredient_profiles['type']))
|
30 |
+
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
31 |
+
|
32 |
+
params = dict(trial_id='test',
|
33 |
+
save_path=EXPERIMENT_PATH + "/multihead_model/",
|
34 |
+
nb_epochs=500,
|
35 |
+
print_every=50,
|
36 |
+
plot_every=50,
|
37 |
+
batch_size=128,
|
38 |
+
lr=0.001,
|
39 |
+
dropout=0.,
|
40 |
+
nb_epoch_switch_beta=600,
|
41 |
+
latent_dim=10,
|
42 |
+
beta_vae=0.2,
|
43 |
+
ing_keys=ing_keys,
|
44 |
+
nb_ingredients=len(ingredient_set),
|
45 |
+
hidden_dims_ingredients=[128],
|
46 |
+
hidden_dims_cocktail=[64],
|
47 |
+
hidden_dims_decoder=[32],
|
48 |
+
agg='mean',
|
49 |
+
activation='relu',
|
50 |
+
auxiliaries_dict=dict(categories=dict(weight=5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))), #0.5
|
51 |
+
glasses=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['glass']))), #0.1
|
52 |
+
prep_type=dict(weight=0.1, type='classif', final_activ=None, dim_output=len(set(data['category']))),#1
|
53 |
+
cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),#1
|
54 |
+
volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),#1
|
55 |
+
taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),#1
|
56 |
+
ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),#10
|
57 |
+
ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)),
|
58 |
+
category_encodings=category_encodings
|
59 |
+
)
|
60 |
+
water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
|
61 |
+
max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
|
62 |
+
params=params)
|
63 |
+
dim_rep_ingredient = water_rep.size
|
64 |
+
params['indexes_ing_to_normalize'] = indexes_to_normalize
|
65 |
+
params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
|
66 |
+
params['dim_rep_ingredient'] = dim_rep_ingredient
|
67 |
+
params['input_dim'] = params['nb_ingredients']
|
68 |
+
params = compute_expe_name_and_save_path(params)
|
69 |
+
del params['category_encodings'] # to dump
|
70 |
+
with open(params['save_path'] + 'params.json', 'w') as f:
|
71 |
+
json.dump(params, f)
|
72 |
+
|
73 |
+
params = complete_params(params)
|
74 |
+
return params
|
75 |
+
|
76 |
+
def complete_params(params):
|
77 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
78 |
+
cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
|
79 |
+
nb_ing_categories = len(set(ingredient_profiles['type']))
|
80 |
+
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
81 |
+
params['cocktail_reps'] = cocktail_reps
|
82 |
+
params['raw_data'] = data
|
83 |
+
params['category_encodings'] = category_encodings
|
84 |
+
return params
|
85 |
+
|
86 |
+
def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
|
87 |
+
losses = dict()
|
88 |
+
accuracies = dict()
|
89 |
+
other_metrics = dict()
|
90 |
+
for i_k, k in enumerate(auxiliaries_str):
|
91 |
+
# get ground truth
|
92 |
+
# compute loss
|
93 |
+
if k == 'volume':
|
94 |
+
outputs[i_k] = outputs[i_k].flatten()
|
95 |
+
ground_truth = auxiliaries[k]
|
96 |
+
if ground_truth.dtype == torch.float64:
|
97 |
+
losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float()
|
98 |
+
elif ground_truth.dtype == torch.int64:
|
99 |
+
if str(loss_functions[k]) != "BCEWithLogitsLoss()":
|
100 |
+
losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float()
|
101 |
+
else:
|
102 |
+
losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float()
|
103 |
+
else:
|
104 |
+
losses[k] = loss_functions[k](outputs[i_k], ground_truth).float()
|
105 |
+
# compute accuracies
|
106 |
+
if str(loss_functions[k]) == 'CrossEntropyLoss()':
|
107 |
+
bs, n_options = outputs[i_k].shape
|
108 |
+
predicted = outputs[i_k].argmax(dim=1).detach().numpy()
|
109 |
+
true = ground_truth.int().detach().numpy()
|
110 |
+
confusion_matrix = np.zeros([n_options, n_options])
|
111 |
+
for i in range(bs):
|
112 |
+
confusion_matrix[true[i], predicted[i]] += 1
|
113 |
+
acc = confusion_matrix.diagonal().sum() / bs
|
114 |
+
for i in range(n_options):
|
115 |
+
if confusion_matrix[i].sum() != 0:
|
116 |
+
confusion_matrix[i] /= confusion_matrix[i].sum()
|
117 |
+
other_metrics[k + '_confusion'] = confusion_matrix
|
118 |
+
accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy())
|
119 |
+
assert (acc - accuracies[k]) < 1e-5
|
120 |
+
|
121 |
+
elif str(loss_functions[k]) == 'BCEWithLogitsLoss()':
|
122 |
+
assert k == 'ingredients_presence'
|
123 |
+
outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
|
124 |
+
predicted_presence = (outputs_rescaled > 0).astype(bool)
|
125 |
+
presence = ground_truth.detach().numpy().astype(bool)
|
126 |
+
other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool)))
|
127 |
+
other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool)))
|
128 |
+
accuracies[k] = np.mean(predicted_presence == presence) # accuracy for multi class labeling
|
129 |
+
elif str(loss_functions[k]) == 'MSELoss()':
|
130 |
+
accuracies[k] = np.nan
|
131 |
+
else:
|
132 |
+
raise ValueError
|
133 |
+
return losses, accuracies, other_metrics
|
134 |
+
|
135 |
+
def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
|
136 |
+
ing_q = ingredient_quantities.detach().numpy()# * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
|
137 |
+
ing_presence = (ing_q > 0)
|
138 |
+
x_hat = x_hat.detach().numpy()
|
139 |
+
# x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
|
140 |
+
abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
|
141 |
+
# abs_diff = np.abs(ing_q - x_hat)
|
142 |
+
ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
|
143 |
+
for i in range(ingredient_quantities.shape[0]):
|
144 |
+
ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])]))
|
145 |
+
ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])]))
|
146 |
+
aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
|
147 |
+
aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
|
148 |
+
return aux_other_metrics
|
149 |
+
|
150 |
+
def run_epoch(opt, train, model, data, loss_functions, weights, params):
|
151 |
+
if train:
|
152 |
+
model.train()
|
153 |
+
else:
|
154 |
+
model.eval()
|
155 |
+
|
156 |
+
# prepare logging of losses
|
157 |
+
losses = dict(kld_loss=[],
|
158 |
+
mse_loss=[],
|
159 |
+
vae_loss=[],
|
160 |
+
volume_loss=[],
|
161 |
+
global_loss=[])
|
162 |
+
accuracies = dict()
|
163 |
+
other_metrics = dict()
|
164 |
+
for aux in params['auxiliaries_dict'].keys():
|
165 |
+
losses[aux] = []
|
166 |
+
accuracies[aux] = []
|
167 |
+
if train: opt.zero_grad()
|
168 |
+
|
169 |
+
for d in data:
|
170 |
+
nb_ingredients = d[0]
|
171 |
+
batch_size = nb_ingredients.shape[0]
|
172 |
+
x_ingredients = d[1].float()
|
173 |
+
ingredient_quantities = d[2]
|
174 |
+
cocktail_reps = d[3]
|
175 |
+
auxiliaries = d[4]
|
176 |
+
for k in auxiliaries.keys():
|
177 |
+
if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
|
178 |
+
taste_valid = d[-1]
|
179 |
+
z, outputs, auxiliaries_str = model.forward(ingredient_quantities.float())
|
180 |
+
# get auxiliary losses and accuracies
|
181 |
+
aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data)
|
182 |
+
|
183 |
+
# compute vae loss
|
184 |
+
aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, outputs[auxiliaries_str.index('ingredients_quantities')])
|
185 |
+
|
186 |
+
indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
|
187 |
+
if indexes_taste_valid.size > 0:
|
188 |
+
outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
|
189 |
+
gt = auxiliaries['taste_reps'][indexes_taste_valid]
|
190 |
+
factor_loss = indexes_taste_valid.size / (0.3 * batch_size)# factor on the loss: if same ratio as actual dataset factor = 1 if there is less data, then the factor decreases, more data, it increases
|
191 |
+
aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
|
192 |
+
else:
|
193 |
+
aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
|
194 |
+
aux_accuracies['taste_reps'] = 0
|
195 |
+
|
196 |
+
# aggregate losses
|
197 |
+
global_loss = torch.sum(torch.cat([torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()]))
|
198 |
+
# for k in params['auxiliaries_dict'].keys():
|
199 |
+
# global_loss += aux_losses[k] * weights[k]
|
200 |
+
|
201 |
+
if train:
|
202 |
+
global_loss.backward()
|
203 |
+
opt.step()
|
204 |
+
opt.zero_grad()
|
205 |
+
|
206 |
+
# logging
|
207 |
+
losses['global_loss'].append(float(global_loss))
|
208 |
+
for k in params['auxiliaries_dict'].keys():
|
209 |
+
losses[k].append(float(aux_losses[k]))
|
210 |
+
accuracies[k].append(float(aux_accuracies[k]))
|
211 |
+
for k in aux_other_metrics.keys():
|
212 |
+
if k not in other_metrics.keys():
|
213 |
+
other_metrics[k] = [aux_other_metrics[k]]
|
214 |
+
else:
|
215 |
+
other_metrics[k].append(aux_other_metrics[k])
|
216 |
+
|
217 |
+
for k in losses.keys():
|
218 |
+
losses[k] = np.mean(losses[k])
|
219 |
+
for k in accuracies.keys():
|
220 |
+
accuracies[k] = np.mean(accuracies[k])
|
221 |
+
for k in other_metrics.keys():
|
222 |
+
other_metrics[k] = np.mean(other_metrics[k], axis=0)
|
223 |
+
return model, losses, accuracies, other_metrics
|
224 |
+
|
225 |
+
def prepare_data_and_loss(params):
|
226 |
+
train_data = MyDataset(split='train', params=params)
|
227 |
+
test_data = MyDataset(split='test', params=params)
|
228 |
+
|
229 |
+
train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
|
230 |
+
test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
|
231 |
+
|
232 |
+
loss_functions = dict()
|
233 |
+
weights = dict()
|
234 |
+
for k in sorted(params['auxiliaries_dict'].keys()):
|
235 |
+
if params['auxiliaries_dict'][k]['type'] == 'classif':
|
236 |
+
if k == 'glasses':
|
237 |
+
classif_weights = train_data.glasses_weights
|
238 |
+
elif k == 'prep_type':
|
239 |
+
classif_weights = train_data.prep_types_weights
|
240 |
+
elif k == 'categories':
|
241 |
+
classif_weights = train_data.categories_weights
|
242 |
+
else:
|
243 |
+
raise ValueError
|
244 |
+
loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
|
245 |
+
elif params['auxiliaries_dict'][k]['type'] == 'multiclassif':
|
246 |
+
loss_functions[k] = nn.BCEWithLogitsLoss()
|
247 |
+
elif params['auxiliaries_dict'][k]['type'] == 'regression':
|
248 |
+
loss_functions[k] = nn.MSELoss()
|
249 |
+
else:
|
250 |
+
raise ValueError
|
251 |
+
weights[k] = params['auxiliaries_dict'][k]['weight']
|
252 |
+
|
253 |
+
|
254 |
+
return loss_functions, train_data_loader, test_data_loader, weights
|
255 |
+
|
256 |
+
def print_losses(train, losses, accuracies, other_metrics):
|
257 |
+
keyword = 'Train' if train else 'Eval'
|
258 |
+
print(f'\t{keyword} logs:')
|
259 |
+
keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
|
260 |
+
for k in keys:
|
261 |
+
print(f'\t\t{k} - Loss: {losses[k]:.2f}')
|
262 |
+
for k in sorted(accuracies.keys()):
|
263 |
+
print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
|
264 |
+
for k in sorted(other_metrics.keys()):
|
265 |
+
if 'confusion' not in k:
|
266 |
+
print(f'\t\t{k} - {other_metrics[k]:.2f}')
|
267 |
+
|
268 |
+
|
269 |
+
def run_experiment(params, verbose=True):
|
270 |
+
loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)
|
271 |
+
|
272 |
+
model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
|
273 |
+
model = get_multihead_model(*model_params)
|
274 |
+
opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
|
275 |
+
|
276 |
+
|
277 |
+
all_train_losses = []
|
278 |
+
all_eval_losses = []
|
279 |
+
all_train_accuracies = []
|
280 |
+
all_eval_accuracies = []
|
281 |
+
all_eval_other_metrics = []
|
282 |
+
all_train_other_metrics = []
|
283 |
+
best_loss = np.inf
|
284 |
+
model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
|
285 |
+
weights=weights, params=params)
|
286 |
+
all_eval_losses.append(eval_losses)
|
287 |
+
all_eval_accuracies.append(eval_accuracies)
|
288 |
+
all_eval_other_metrics.append(eval_other_metrics)
|
289 |
+
if verbose: print(f'\n--------\nEpoch #0')
|
290 |
+
if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
|
291 |
+
for epoch in range(params['nb_epochs']):
|
292 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
|
293 |
+
model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions,
|
294 |
+
weights=weights, params=params)
|
295 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
|
296 |
+
model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
|
297 |
+
weights=weights, params=params)
|
298 |
+
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
|
299 |
+
if eval_losses['global_loss'] < best_loss:
|
300 |
+
best_loss = eval_losses['global_loss']
|
301 |
+
if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
|
302 |
+
torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
|
303 |
+
|
304 |
+
# log
|
305 |
+
all_train_losses.append(train_losses)
|
306 |
+
all_train_accuracies.append(train_accuracies)
|
307 |
+
all_eval_losses.append(eval_losses)
|
308 |
+
all_eval_accuracies.append(eval_accuracies)
|
309 |
+
all_eval_other_metrics.append(eval_other_metrics)
|
310 |
+
all_train_other_metrics.append(train_other_metrics)
|
311 |
+
|
312 |
+
# if epoch == params['nb_epoch_switch_beta']:
|
313 |
+
# params['beta_vae'] = 2.5
|
314 |
+
# params['auxiliaries_dict']['prep_type']['weight'] /= 10
|
315 |
+
# params['auxiliaries_dict']['glasses']['weight'] /= 10
|
316 |
+
|
317 |
+
if (epoch + 1) % params['plot_every'] == 0:
|
318 |
+
|
319 |
+
plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
|
320 |
+
all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights)
|
321 |
+
|
322 |
+
return model
|
323 |
+
|
324 |
+
def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
|
325 |
+
all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
|
326 |
+
|
327 |
+
steps = np.arange(len(all_eval_accuracies))
|
328 |
+
|
329 |
+
loss_keys = sorted(all_train_losses[0].keys())
|
330 |
+
acc_keys = sorted(all_train_accuracies[0].keys())
|
331 |
+
metrics_keys = sorted(all_train_other_metrics[0].keys())
|
332 |
+
|
333 |
+
plt.figure()
|
334 |
+
plt.title('Train losses')
|
335 |
+
for k in loss_keys:
|
336 |
+
factor = 1 if k == 'mse_loss' else 1
|
337 |
+
if k not in weights.keys():
|
338 |
+
plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
|
339 |
+
else:
|
340 |
+
if weights[k] != 0:
|
341 |
+
plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
|
342 |
+
|
343 |
+
plt.legend()
|
344 |
+
plt.ylim([0, 4])
|
345 |
+
plt.savefig(plot_path + 'train_losses.png', dpi=200)
|
346 |
+
fig = plt.gcf()
|
347 |
+
plt.close(fig)
|
348 |
+
|
349 |
+
plt.figure()
|
350 |
+
plt.title('Train accuracies')
|
351 |
+
for k in acc_keys:
|
352 |
+
if weights[k] != 0:
|
353 |
+
plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k)
|
354 |
+
plt.legend()
|
355 |
+
plt.ylim([0, 1])
|
356 |
+
plt.savefig(plot_path + 'train_acc.png', dpi=200)
|
357 |
+
fig = plt.gcf()
|
358 |
+
plt.close(fig)
|
359 |
+
|
360 |
+
plt.figure()
|
361 |
+
plt.title('Train other metrics')
|
362 |
+
for k in metrics_keys:
|
363 |
+
if 'confusion' not in k and 'presence' in k:
|
364 |
+
plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
|
365 |
+
plt.legend()
|
366 |
+
plt.ylim([0, 1])
|
367 |
+
plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200)
|
368 |
+
fig = plt.gcf()
|
369 |
+
plt.close(fig)
|
370 |
+
|
371 |
+
plt.figure()
|
372 |
+
plt.title('Train other metrics')
|
373 |
+
for k in metrics_keys:
|
374 |
+
if 'confusion' not in k and 'presence' not in k:
|
375 |
+
plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
|
376 |
+
plt.legend()
|
377 |
+
plt.ylim([0, 15])
|
378 |
+
plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200)
|
379 |
+
fig = plt.gcf()
|
380 |
+
plt.close(fig)
|
381 |
+
|
382 |
+
plt.figure()
|
383 |
+
plt.title('Eval losses')
|
384 |
+
for k in loss_keys:
|
385 |
+
factor = 1 if k == 'mse_loss' else 1
|
386 |
+
if k not in weights.keys():
|
387 |
+
plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
|
388 |
+
else:
|
389 |
+
if weights[k] != 0:
|
390 |
+
plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
|
391 |
+
plt.legend()
|
392 |
+
plt.ylim([0, 4])
|
393 |
+
plt.savefig(plot_path + 'eval_losses.png', dpi=200)
|
394 |
+
fig = plt.gcf()
|
395 |
+
plt.close(fig)
|
396 |
+
|
397 |
+
plt.figure()
|
398 |
+
plt.title('Eval accuracies')
|
399 |
+
for k in acc_keys:
|
400 |
+
if weights[k] != 0:
|
401 |
+
plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k)
|
402 |
+
plt.legend()
|
403 |
+
plt.ylim([0, 1])
|
404 |
+
plt.savefig(plot_path + 'eval_acc.png', dpi=200)
|
405 |
+
fig = plt.gcf()
|
406 |
+
plt.close(fig)
|
407 |
+
|
408 |
+
plt.figure()
|
409 |
+
plt.title('Eval other metrics')
|
410 |
+
for k in metrics_keys:
|
411 |
+
if 'confusion' not in k and 'presence' in k:
|
412 |
+
plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
|
413 |
+
plt.legend()
|
414 |
+
plt.ylim([0, 1])
|
415 |
+
plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200)
|
416 |
+
fig = plt.gcf()
|
417 |
+
plt.close(fig)
|
418 |
+
|
419 |
+
plt.figure()
|
420 |
+
plt.title('Eval other metrics')
|
421 |
+
for k in metrics_keys:
|
422 |
+
if 'confusion' not in k and 'presence' not in k:
|
423 |
+
plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
|
424 |
+
plt.legend()
|
425 |
+
plt.ylim([0, 15])
|
426 |
+
plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200)
|
427 |
+
fig = plt.gcf()
|
428 |
+
plt.close(fig)
|
429 |
+
|
430 |
+
|
431 |
+
for k in metrics_keys:
|
432 |
+
if 'confusion' in k:
|
433 |
+
plt.figure()
|
434 |
+
plt.title(k)
|
435 |
+
plt.ylabel('True')
|
436 |
+
plt.xlabel('Predicted')
|
437 |
+
plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1)
|
438 |
+
plt.colorbar()
|
439 |
+
plt.savefig(plot_path + f'eval_{k}.png', dpi=200)
|
440 |
+
fig = plt.gcf()
|
441 |
+
plt.close(fig)
|
442 |
+
|
443 |
+
for k in metrics_keys:
|
444 |
+
if 'confusion' in k:
|
445 |
+
plt.figure()
|
446 |
+
plt.title(k)
|
447 |
+
plt.ylabel('True')
|
448 |
+
plt.xlabel('Predicted')
|
449 |
+
plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1)
|
450 |
+
plt.colorbar()
|
451 |
+
plt.savefig(plot_path + f'train_{k}.png', dpi=200)
|
452 |
+
fig = plt.gcf()
|
453 |
+
plt.close(fig)
|
454 |
+
|
455 |
+
plt.close('all')
|
456 |
+
|
457 |
+
|
458 |
+
def get_model(model_path):
|
459 |
+
|
460 |
+
with open(model_path + 'params.json', 'r') as f:
|
461 |
+
params = json.load(f)
|
462 |
+
params['save_path'] = model_path
|
463 |
+
model_chkpt = model_path + "checkpoint_best.save"
|
464 |
+
model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
|
465 |
+
model = get_multihead_model(*model_params)
|
466 |
+
model.load_state_dict(torch.load(model_chkpt))
|
467 |
+
model.eval()
|
468 |
+
max_ing_quantities = np.loadtxt(model_path + 'max_ing_quantities.txt')
|
469 |
+
def predict(ing_qs, aux_str):
|
470 |
+
ing_qs /= max_ing_quantities
|
471 |
+
input_model = torch.FloatTensor(ing_qs).reshape(1, -1)
|
472 |
+
_, outputs, auxiliaries_str = model.forward(input_model, )
|
473 |
+
if isinstance(aux_str, str):
|
474 |
+
return outputs[auxiliaries_str.index(aux_str)].detach().numpy()
|
475 |
+
elif isinstance(aux_str, list):
|
476 |
+
return [outputs[auxiliaries_str.index(aux)].detach().numpy() for aux in aux_str]
|
477 |
+
else:
|
478 |
+
raise ValueError
|
479 |
+
return predict, params
|
480 |
+
|
481 |
+
|
482 |
+
def compute_expe_name_and_save_path(params):
|
483 |
+
weights_str = '['
|
484 |
+
for aux in params['auxiliaries_dict'].keys():
|
485 |
+
weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
|
486 |
+
weights_str = weights_str[:-2] + ']'
|
487 |
+
save_path = params['save_path'] + params["trial_id"]
|
488 |
+
save_path += f'_lr{params["lr"]}'
|
489 |
+
save_path += f'_betavae{params["beta_vae"]}'
|
490 |
+
save_path += f'_bs{params["batch_size"]}'
|
491 |
+
save_path += f'_latentdim{params["latent_dim"]}'
|
492 |
+
save_path += f'_hding{params["hidden_dims_ingredients"]}'
|
493 |
+
save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
|
494 |
+
save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
|
495 |
+
save_path += f'_agg{params["agg"]}'
|
496 |
+
save_path += f'_activ{params["activation"]}'
|
497 |
+
save_path += f'_w{weights_str}'
|
498 |
+
counter = 0
|
499 |
+
while os.path.exists(save_path + f"_{counter}"):
|
500 |
+
counter += 1
|
501 |
+
save_path = save_path + f"_{counter}" + '/'
|
502 |
+
params["save_path"] = save_path
|
503 |
+
os.makedirs(save_path)
|
504 |
+
os.makedirs(save_path + 'plots/')
|
505 |
+
params['plot_path'] = save_path + 'plots/'
|
506 |
+
print(f'logging to {save_path}')
|
507 |
+
return params
|
508 |
+
|
509 |
+
|
510 |
+
|
511 |
+
if __name__ == '__main__':
|
512 |
+
params = get_params()
|
513 |
+
run_experiment(params)
|
514 |
+
|
src/cocktails/representation_learning/simple_model.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch; torch.manual_seed(0)
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.utils
|
5 |
+
import torch.distributions
|
6 |
+
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
+
|
8 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
9 |
+
|
10 |
+
def get_activation(activation):
|
11 |
+
if activation == 'tanh':
|
12 |
+
activ = F.tanh
|
13 |
+
elif activation == 'relu':
|
14 |
+
activ = F.relu
|
15 |
+
elif activation == 'mish':
|
16 |
+
activ = F.mish
|
17 |
+
elif activation == 'sigmoid':
|
18 |
+
activ = torch.sigmoid
|
19 |
+
elif activation == 'leakyrelu':
|
20 |
+
activ = F.leaky_relu
|
21 |
+
elif activation == 'exp':
|
22 |
+
activ = torch.exp
|
23 |
+
else:
|
24 |
+
raise ValueError
|
25 |
+
return activ
|
26 |
+
|
27 |
+
|
28 |
+
class SimpleNet(nn.Module):
|
29 |
+
def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None):
|
30 |
+
super(SimpleNet, self).__init__()
|
31 |
+
self.linears = nn.ModuleList()
|
32 |
+
self.dropouts = nn.ModuleList()
|
33 |
+
self.output_dim = output_dim
|
34 |
+
dims = [input_dim] + hidden_dims + [output_dim]
|
35 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
36 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
37 |
+
self.dropouts.append(nn.Dropout(dropout))
|
38 |
+
self.activation = get_activation(activation)
|
39 |
+
self.n_layers = len(self.linears)
|
40 |
+
self.layer_range = range(self.n_layers)
|
41 |
+
if final_activ != None:
|
42 |
+
self.final_activ = get_activation(final_activ)
|
43 |
+
self.use_final_activ = True
|
44 |
+
else:
|
45 |
+
self.use_final_activ = False
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
49 |
+
x = layer(x)
|
50 |
+
if i_layer != self.n_layers - 1:
|
51 |
+
x = self.activation(dropout(x))
|
52 |
+
if self.use_final_activ: x = self.final_activ(x)
|
53 |
+
return x
|
54 |
+
|
src/cocktails/representation_learning/vae_model.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch; torch.manual_seed(0)
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.utils
|
5 |
+
import torch.distributions
|
6 |
+
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
+
|
8 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
9 |
+
|
10 |
+
def get_activation(activation):
|
11 |
+
if activation == 'tanh':
|
12 |
+
activ = F.tanh
|
13 |
+
elif activation == 'relu':
|
14 |
+
activ = F.relu
|
15 |
+
elif activation == 'mish':
|
16 |
+
activ = F.mish
|
17 |
+
elif activation == 'sigmoid':
|
18 |
+
activ = F.sigmoid
|
19 |
+
elif activation == 'leakyrelu':
|
20 |
+
activ = F.leaky_relu
|
21 |
+
elif activation == 'exp':
|
22 |
+
activ = torch.exp
|
23 |
+
else:
|
24 |
+
raise ValueError
|
25 |
+
return activ
|
26 |
+
|
27 |
+
class IngredientEncoder(nn.Module):
|
28 |
+
def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
|
29 |
+
super(IngredientEncoder, self).__init__()
|
30 |
+
self.linears = nn.ModuleList()
|
31 |
+
self.dropouts = nn.ModuleList()
|
32 |
+
dims = [input_dim] + hidden_dims + [deepset_latent_dim]
|
33 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
34 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
35 |
+
self.dropouts.append(nn.Dropout(dropout))
|
36 |
+
self.activation = get_activation(activation)
|
37 |
+
self.n_layers = len(self.linears)
|
38 |
+
self.layer_range = range(self.n_layers)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
42 |
+
x = layer(x)
|
43 |
+
if i_layer != self.n_layers - 1:
|
44 |
+
x = self.activation(dropout(x))
|
45 |
+
return x # do not use dropout on last layer?
|
46 |
+
|
47 |
+
class DeepsetCocktailEncoder(nn.Module):
|
48 |
+
def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
|
49 |
+
hidden_dims_cocktail, latent_dim, aggregation, dropout):
|
50 |
+
super(DeepsetCocktailEncoder, self).__init__()
|
51 |
+
self.input_dim = input_dim # dimension of ingredient representation + quantity
|
52 |
+
self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout) # encode each ingredient separately
|
53 |
+
self.deepset_latent_dim = deepset_latent_dim # dimension of the deepset aggregation
|
54 |
+
self.aggregation = aggregation
|
55 |
+
self.latent_dim = latent_dim
|
56 |
+
# post aggregation network
|
57 |
+
self.linears = nn.ModuleList()
|
58 |
+
self.dropouts = nn.ModuleList()
|
59 |
+
dims = [deepset_latent_dim] + hidden_dims_cocktail
|
60 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
61 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
62 |
+
self.dropouts.append(nn.Dropout(dropout))
|
63 |
+
self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
|
64 |
+
self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
|
65 |
+
self.softplus = nn.Softplus()
|
66 |
+
|
67 |
+
self.activation = get_activation(activation)
|
68 |
+
self.n_layers = len(self.linears)
|
69 |
+
self.layer_range = range(self.n_layers)
|
70 |
+
|
71 |
+
def forward(self, nb_ingredients, x):
|
72 |
+
|
73 |
+
# reshape x in (batch size * nb ingredients, dim_ing_rep)
|
74 |
+
batch_size = x.shape[0]
|
75 |
+
all_ingredients = []
|
76 |
+
for i in range(batch_size):
|
77 |
+
for j in range(nb_ingredients[i]):
|
78 |
+
all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1))
|
79 |
+
x = torch.cat(all_ingredients, dim=0)
|
80 |
+
# encode ingredients in parallel
|
81 |
+
ingredients_encodings = self.ingredient_encoder(x)
|
82 |
+
assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim)
|
83 |
+
|
84 |
+
# aggregate
|
85 |
+
x = []
|
86 |
+
index_first = 0
|
87 |
+
for i in range(batch_size):
|
88 |
+
index_last = index_first + nb_ingredients[i]
|
89 |
+
# aggregate
|
90 |
+
if self.aggregation == 'sum':
|
91 |
+
x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
|
92 |
+
elif self.aggregation == 'mean':
|
93 |
+
x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
|
94 |
+
else:
|
95 |
+
raise ValueError
|
96 |
+
index_first = index_last
|
97 |
+
x = torch.cat(x, dim=0)
|
98 |
+
assert x.shape[0] == batch_size
|
99 |
+
|
100 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
101 |
+
x = self.activation(dropout(layer(x)))
|
102 |
+
mean = self.FC_mean(x)
|
103 |
+
logvar = self.FC_logvar(x)
|
104 |
+
return mean, logvar
|
105 |
+
|
106 |
+
class Decoder(nn.Module):
|
107 |
+
def __init__(self, latent_dim, hidden_dims, num_ingredients, activation, dropout, filter_output=None):
|
108 |
+
super(Decoder, self).__init__()
|
109 |
+
self.linears = nn.ModuleList()
|
110 |
+
self.dropouts = nn.ModuleList()
|
111 |
+
dims = [latent_dim] + hidden_dims + [num_ingredients]
|
112 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
113 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
114 |
+
self.dropouts.append(nn.Dropout(dropout))
|
115 |
+
self.activation = get_activation(activation)
|
116 |
+
self.n_layers = len(self.linears)
|
117 |
+
self.layer_range = range(self.n_layers)
|
118 |
+
self.filter = filter_output
|
119 |
+
|
120 |
+
def forward(self, x, to_filter=False):
|
121 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
122 |
+
x = layer(x)
|
123 |
+
if i_layer != self.n_layers - 1:
|
124 |
+
x = self.activation(dropout(x))
|
125 |
+
if to_filter:
|
126 |
+
x = self.filter(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
class PredictorHead(nn.Module):
|
130 |
+
def __init__(self, latent_dim, dim_output, final_activ):
|
131 |
+
super(PredictorHead, self).__init__()
|
132 |
+
self.linear = nn.Linear(latent_dim, dim_output)
|
133 |
+
if final_activ != None:
|
134 |
+
self.final_activ = get_activation(final_activ)
|
135 |
+
self.use_final_activ = True
|
136 |
+
else:
|
137 |
+
self.use_final_activ = False
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
x = self.linear(x)
|
141 |
+
if self.use_final_activ: x = self.final_activ(x)
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
class VAEModel(nn.Module):
|
146 |
+
def __init__(self, encoder, decoder, auxiliaries_dict):
|
147 |
+
super(VAEModel, self).__init__()
|
148 |
+
self.encoder = encoder
|
149 |
+
self.decoder = decoder
|
150 |
+
self.latent_dim = self.encoder.latent_dim
|
151 |
+
self.auxiliaries_str = []
|
152 |
+
self.auxiliaries = nn.ModuleList()
|
153 |
+
for aux_str in sorted(auxiliaries_dict.keys()):
|
154 |
+
if aux_str == 'taste_reps':
|
155 |
+
self.taste_reps_decoder = PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ'])
|
156 |
+
else:
|
157 |
+
self.auxiliaries_str.append(aux_str)
|
158 |
+
self.auxiliaries.append(PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ']))
|
159 |
+
|
160 |
+
def reparameterization(self, mean, logvar):
|
161 |
+
std = torch.exp(0.5 * logvar)
|
162 |
+
epsilon = torch.randn_like(std).to(device) # sampling epsilon
|
163 |
+
z = mean + std * epsilon # reparameterization trick
|
164 |
+
return z
|
165 |
+
|
166 |
+
|
167 |
+
def sample(self, n=1):
|
168 |
+
z = torch.randn(size=(n, self.latent_dim))
|
169 |
+
return self.decoder(z)
|
170 |
+
|
171 |
+
def get_all_auxiliaries(self, x):
|
172 |
+
return [aux(x) for aux in self.auxiliaries]
|
173 |
+
|
174 |
+
def get_auxiliary(self, z, aux_str):
|
175 |
+
if aux_str == 'taste_reps':
|
176 |
+
return self.taste_reps_decoder(z)
|
177 |
+
else:
|
178 |
+
index = self.auxiliaries_str.index(aux_str)
|
179 |
+
return self.auxiliaries[index](z)
|
180 |
+
|
181 |
+
def forward_direct(self, x, aux_str=None, to_filter=False):
|
182 |
+
mean, logvar = self.encoder(x)
|
183 |
+
z = self.reparameterization(mean, logvar) # takes exponential function (log var -> std)
|
184 |
+
x_hat = self.decoder(mean, to_filter=to_filter)
|
185 |
+
if aux_str is not None:
|
186 |
+
return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str]
|
187 |
+
else:
|
188 |
+
return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str
|
189 |
+
|
190 |
+
def forward(self, nb_ingredients, x, aux_str=None, to_filter=False):
|
191 |
+
assert False
|
192 |
+
mean, std = self.encoder(nb_ingredients, x)
|
193 |
+
z = self.reparameterization(mean, std) # takes exponential function (log var -> std)
|
194 |
+
x_hat = self.decoder(mean, to_filter=to_filter)
|
195 |
+
if aux_str is not None:
|
196 |
+
return x_hat, z, mean, std, self.get_auxiliary(z, aux_str), [aux_str]
|
197 |
+
else:
|
198 |
+
return x_hat, z, mean, std, self.get_all_auxiliaries(z), self.auxiliaries_str
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
class SimpleEncoder(nn.Module):
|
204 |
+
|
205 |
+
def __init__(self, input_dim, hidden_dims, latent_dim, activation, dropout):
|
206 |
+
super(SimpleEncoder, self).__init__()
|
207 |
+
self.latent_dim = latent_dim
|
208 |
+
# post aggregation network
|
209 |
+
self.linears = nn.ModuleList()
|
210 |
+
self.dropouts = nn.ModuleList()
|
211 |
+
dims = [input_dim] + hidden_dims
|
212 |
+
for d_in, d_out in zip(dims[:-1], dims[1:]):
|
213 |
+
self.linears.append(nn.Linear(d_in, d_out))
|
214 |
+
self.dropouts.append(nn.Dropout(dropout))
|
215 |
+
self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim)
|
216 |
+
self.FC_logvar = nn.Linear(hidden_dims[-1], latent_dim)
|
217 |
+
# self.softplus = nn.Softplus()
|
218 |
+
|
219 |
+
self.activation = get_activation(activation)
|
220 |
+
self.n_layers = len(self.linears)
|
221 |
+
self.layer_range = range(self.n_layers)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
|
225 |
+
x = self.activation(dropout(layer(x)))
|
226 |
+
mean = self.FC_mean(x)
|
227 |
+
logvar = self.FC_logvar(x)
|
228 |
+
return mean, logvar
|
229 |
+
|
230 |
+
def get_vae_model(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
|
231 |
+
hidden_dims_cocktail, hidden_dims_decoder, num_ingredients, latent_dim, aggregation, dropout, auxiliaries_dict,
|
232 |
+
filter_decoder_output):
|
233 |
+
# encoder = DeepsetCocktailEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
|
234 |
+
# hidden_dims_cocktail, latent_dim, aggregation, dropout)
|
235 |
+
encoder = SimpleEncoder(num_ingredients, hidden_dims_cocktail, latent_dim, activation, dropout)
|
236 |
+
decoder = Decoder(latent_dim, hidden_dims_decoder, num_ingredients, activation, dropout, filter_output=filter_decoder_output)
|
237 |
+
vae = VAEModel(encoder, decoder, auxiliaries_dict)
|
238 |
+
return vae
|
src/cocktails/utilities/__init__.py
ADDED
File without changes
|
src/cocktails/utilities/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (208 Bytes). View file
|
|
src/cocktails/utilities/__pycache__/cocktail_category_detection_utilities.cpython-39.pyc
ADDED
Binary file (9.62 kB). View file
|
|
src/cocktails/utilities/__pycache__/cocktail_utilities.cpython-39.pyc
ADDED
Binary file (8.12 kB). View file
|
|
src/cocktails/utilities/__pycache__/glass_and_volume_utilities.cpython-39.pyc
ADDED
Binary file (1.19 kB). View file
|
|
src/cocktails/utilities/__pycache__/ingredients_utilities.cpython-39.pyc
ADDED
Binary file (6.86 kB). View file
|
|
src/cocktails/utilities/__pycache__/other_scrubbing_utilities.cpython-39.pyc
ADDED
Binary file (8.55 kB). View file
|
|
src/cocktails/utilities/analysis_utilities.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
|
4 |
+
from src.cocktails.utilities.ingredients_utilities import ingredient_list, extract_ingredients, ingredients_per_type
|
5 |
+
|
6 |
+
color_codes = dict(ancestral='#000000',
|
7 |
+
spirit_forward='#2320D2',
|
8 |
+
duo='#6E20D2',
|
9 |
+
champagne_cocktail='#25FFCA',
|
10 |
+
complex_highball='#068F25',
|
11 |
+
simple_highball='#25FF57',
|
12 |
+
collins='#77FF96',
|
13 |
+
julep='#25B8FF',
|
14 |
+
simple_sour='#FBD756',
|
15 |
+
complex_sour='#DCAD07',
|
16 |
+
simple_sour_with_juice='#FF5033',
|
17 |
+
complex_sour_with_juice='#D42306',
|
18 |
+
# simple_sour_with_egg='#FF9C54',
|
19 |
+
# complex_sour_with_egg='#CF5700',
|
20 |
+
# almost_simple_sor='#FF5033',
|
21 |
+
# almost_sor='#D42306',
|
22 |
+
# almost_sor_with_egg='#D42306',
|
23 |
+
other='#9B9B9B'
|
24 |
+
)
|
25 |
+
|
26 |
+
def get_subcategories(data):
|
27 |
+
subcategories = np.array(data['subcategory'])
|
28 |
+
sub_categories_list = sorted(set(subcategories))
|
29 |
+
subcat_count = dict(zip(sub_categories_list, [0] * len(sub_categories_list)))
|
30 |
+
for sc in data['subcategory']:
|
31 |
+
subcat_count[sc] += 1
|
32 |
+
return subcategories, sub_categories_list, subcat_count
|
33 |
+
|
34 |
+
def get_ingredient_count(data):
|
35 |
+
ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list)))
|
36 |
+
for ing_str in data['ingredients_str']:
|
37 |
+
ingredients, _ = extract_ingredients(ing_str)
|
38 |
+
for ing in ingredients:
|
39 |
+
ingredient_counts[ing] += 1
|
40 |
+
return ingredient_counts
|
41 |
+
|
42 |
+
def compute_eucl_dist(a, b):
|
43 |
+
return np.sqrt(np.sum((a - b)**2))
|
44 |
+
|
45 |
+
def recipe_contains(ingredients, stuff):
|
46 |
+
if stuff in ingredient_list:
|
47 |
+
return stuff in ingredients
|
48 |
+
elif stuff == 'juice':
|
49 |
+
return any(['juice' in ing and 'lemon' not in ing and 'lime' not in ing for ing in ingredients])
|
50 |
+
elif stuff == 'bubbles':
|
51 |
+
return any([ing in ['soda', 'tonic', 'cola', 'sparkling wine', 'ginger beer'] for ing in ingredients])
|
52 |
+
elif stuff == 'acid':
|
53 |
+
return any([ing in ['lemon juice', 'lime juice'] for ing in ingredients])
|
54 |
+
elif stuff == 'vermouth':
|
55 |
+
return any([ing in ingredients_per_type['vermouth'] for ing in ingredients])
|
56 |
+
elif stuff == 'plain sweet':
|
57 |
+
plain_sweet = ingredients_per_type['sweeteners']
|
58 |
+
return any([ing in plain_sweet for ing in ingredients])
|
59 |
+
elif stuff == 'sweet':
|
60 |
+
sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc']
|
61 |
+
return any([ing in sweet for ing in ingredients])
|
62 |
+
elif stuff == 'spirit':
|
63 |
+
return any([ing in ingredients_per_type['liquor'] for ing in ingredients])
|
64 |
+
else:
|
65 |
+
raise ValueError
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
def radar_factory(num_vars, frame='circle'):
|
70 |
+
# from stackoverflow's post? Or matplotlib's blog
|
71 |
+
"""
|
72 |
+
Create a radar chart with `num_vars` axes.
|
73 |
+
|
74 |
+
This function creates a RadarAxes projection and registers it.
|
75 |
+
|
76 |
+
Parameters
|
77 |
+
----------
|
78 |
+
num_vars : int
|
79 |
+
Number of variables for radar chart.
|
80 |
+
frame : {'circle', 'polygon'}
|
81 |
+
Shape of frame surrounding axes.
|
82 |
+
|
83 |
+
"""
|
84 |
+
import numpy as np
|
85 |
+
|
86 |
+
from matplotlib.patches import Circle, RegularPolygon
|
87 |
+
from matplotlib.path import Path
|
88 |
+
from matplotlib.projections.polar import PolarAxes
|
89 |
+
from matplotlib.projections import register_projection
|
90 |
+
from matplotlib.spines import Spine
|
91 |
+
from matplotlib.transforms import Affine2D
|
92 |
+
# calculate evenly-spaced axis angles
|
93 |
+
theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
|
94 |
+
|
95 |
+
class RadarAxes(PolarAxes):
|
96 |
+
|
97 |
+
name = 'radar'
|
98 |
+
# use 1 line segment to connect specified points
|
99 |
+
RESOLUTION = 1
|
100 |
+
|
101 |
+
def __init__(self, *args, **kwargs):
|
102 |
+
super().__init__(*args, **kwargs)
|
103 |
+
# rotate plot such that the first axis is at the top
|
104 |
+
self.set_theta_zero_location('N')
|
105 |
+
|
106 |
+
def fill(self, *args, closed=True, **kwargs):
|
107 |
+
"""Override fill so that line is closed by default"""
|
108 |
+
return super().fill(closed=closed, *args, **kwargs)
|
109 |
+
|
110 |
+
def plot(self, *args, **kwargs):
|
111 |
+
"""Override plot so that line is closed by default"""
|
112 |
+
lines = super().plot(*args, **kwargs)
|
113 |
+
for line in lines:
|
114 |
+
self._close_line(line)
|
115 |
+
|
116 |
+
def _close_line(self, line):
|
117 |
+
x, y = line.get_data()
|
118 |
+
# FIXME: markers at x[0], y[0] get doubled-up
|
119 |
+
if x[0] != x[-1]:
|
120 |
+
x = np.append(x, x[0])
|
121 |
+
y = np.append(y, y[0])
|
122 |
+
line.set_data(x, y)
|
123 |
+
|
124 |
+
def set_varlabels(self, labels):
|
125 |
+
self.set_thetagrids(np.degrees(theta), labels)
|
126 |
+
|
127 |
+
def _gen_axes_patch(self):
|
128 |
+
# The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
|
129 |
+
# in axes coordinates.
|
130 |
+
if frame == 'circle':
|
131 |
+
return Circle((0.5, 0.5), 0.5)
|
132 |
+
elif frame == 'polygon':
|
133 |
+
return RegularPolygon((0.5, 0.5), num_vars,
|
134 |
+
radius=.5, edgecolor="k")
|
135 |
+
else:
|
136 |
+
raise ValueError("Unknown value for 'frame': %s" % frame)
|
137 |
+
|
138 |
+
def _gen_axes_spines(self):
|
139 |
+
if frame == 'circle':
|
140 |
+
return super()._gen_axes_spines()
|
141 |
+
elif frame == 'polygon':
|
142 |
+
# spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
|
143 |
+
spine = Spine(axes=self,
|
144 |
+
spine_type='circle',
|
145 |
+
path=Path.unit_regular_polygon(num_vars))
|
146 |
+
# unit_regular_polygon gives a polygon of radius 1 centered at
|
147 |
+
# (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
|
148 |
+
# 0.5) in axes coordinates.
|
149 |
+
spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
|
150 |
+
+ self.transAxes)
|
151 |
+
return {'polar': spine}
|
152 |
+
else:
|
153 |
+
raise ValueError("Unknown value for 'frame': %s" % frame)
|
154 |
+
|
155 |
+
register_projection(RadarAxes)
|
156 |
+
return theta
|
157 |
+
|
158 |
+
def plot_radar_cocktail(representation, labels_dim, labels_cocktails, save_path=None, to_show=False, to_save=False):
|
159 |
+
assert to_show or to_save, 'either show or save'
|
160 |
+
assert representation.ndim == 2
|
161 |
+
n_data, dim_rep = representation.shape
|
162 |
+
assert len(labels_cocktails) == n_data
|
163 |
+
assert len(labels_dim) == dim_rep
|
164 |
+
assert n_data <= 5, 'max 5 representation_analysis please'
|
165 |
+
|
166 |
+
theta = radar_factory(dim_rep, frame='circle')
|
167 |
+
|
168 |
+
|
169 |
+
fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(projection='radar'))
|
170 |
+
fig.subplots_adjust(wspace=0.25, hspace=0.20, top=0.85, bottom=0.05)
|
171 |
+
|
172 |
+
colors = ['b', 'r', 'g', 'm', 'y']
|
173 |
+
# Plot the four cases from the example data on separate axes
|
174 |
+
ax.set_rgrids([0.2, 0.4, 0.6, 0.8])
|
175 |
+
for d, color in zip(representation, colors):
|
176 |
+
ax.plot(theta, d, color=color)
|
177 |
+
for d, color in zip(representation, colors):
|
178 |
+
ax.fill(theta, d, facecolor=color, alpha=0.25)
|
179 |
+
ax.set_varlabels(labels_dim)
|
180 |
+
|
181 |
+
# add legend relative to top-left plot
|
182 |
+
legend = ax.legend(labels_cocktails, loc=(0.9, .95),
|
183 |
+
labelspacing=0.1, fontsize='small')
|
184 |
+
|
185 |
+
if to_save:
|
186 |
+
plt.savefig(save_path, bbox_artists=(legend,), dpi=200)
|
187 |
+
else:
|
188 |
+
plt.show()
|
189 |
+
|
src/cocktails/utilities/cocktail_category_detection_utilities.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The following functions check whether a cocktail belong to any of N categories
|
2 |
+
import numpy as np
|
3 |
+
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles, ingredients_per_type, ingredient2ingredient_id, extract_ingredients
|
4 |
+
|
5 |
+
|
6 |
+
def is_ancestral(n, ingredient_indexes, ingredients, quantities):
|
7 |
+
# ancestrals have a strong spirit and some sweetness from sugar, syrup or liqueurs, no citrus.
|
8 |
+
# absinthe can be added up to 3 dashes.
|
9 |
+
# Liqueurs are there to bring sweetness, thus must stay below 15ml (if not it's a duo)
|
10 |
+
if n['spirit'] > 0 and n['citrus'] == 0 and n['plain_sweet'] + n['liqueur'] <= 2:
|
11 |
+
if n['spirit'] > 1 and 'absinthe' in ingredients:
|
12 |
+
if quantities[ingredients.index('absinthe')] < 3:
|
13 |
+
pass
|
14 |
+
else:
|
15 |
+
return False
|
16 |
+
if n['sugar'] < 2 and n['liqueur'] < 3:
|
17 |
+
if n['all'] - n['spirit'] - n['sugar'] -n['syrup']- n['liqueur']- n['inconsequentials'] == 0:
|
18 |
+
if n['liqueur'] == 0:
|
19 |
+
return True
|
20 |
+
else:
|
21 |
+
q_liqueur = np.sum([quantities[i_ing]
|
22 |
+
for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients)))
|
23 |
+
if ingredient_profiles['type'][i_ind].lower() == 'liqueur'])
|
24 |
+
if q_liqueur <= 15:
|
25 |
+
return True
|
26 |
+
else:
|
27 |
+
return False
|
28 |
+
return False
|
29 |
+
|
30 |
+
|
31 |
+
def is_simple_sour(n, ingredient_indexes, ingredients, quantities):
|
32 |
+
# simple sours contain a citrus, at least 1 spirit and non-alcoholic sweetness
|
33 |
+
if n['citrus'] + n['coffee']> 0 and n['spirit'] > 0 and n['plain_sweet'] > 0 and n['juice'] == 0:
|
34 |
+
if n['all'] - n['citrus'] - n['coffee'] - n['spirit'] - n['plain_sweet'] - n['juice'] -n['egg'] - n['inconsequentials'] == 0:
|
35 |
+
return True
|
36 |
+
return False
|
37 |
+
|
38 |
+
def is_complex_sour(n, ingredient_indexes, ingredients, quantities):
|
39 |
+
# complex sours are simple sours that use alcoholic sweetness, at least in part
|
40 |
+
if n['citrus'] + n['coffee'] > 0 and n['all_sweet'] > 0 and n['juice'] == 0:
|
41 |
+
if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0:
|
42 |
+
if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0:
|
43 |
+
if n['all'] -n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \
|
44 |
+
- n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0:
|
45 |
+
return True
|
46 |
+
return False
|
47 |
+
|
48 |
+
def is_spirit_forward(n, ingredient_indexes, ingredients, quantities):
|
49 |
+
# spirit forward contain at least a spirit and vermouth, no citrus. Can contain sweet (sugar, syrups, liqueurs)
|
50 |
+
if n['spirit'] > 0 and n['citrus'] == 0 and n['vermouth'] > 0:
|
51 |
+
if n['all'] - n['spirit'] - n['sugar'] - n['syrup'] - n['liqueur'] -n['egg'] - n['vermouth'] - n['inconsequentials']== 0:
|
52 |
+
return True
|
53 |
+
return False
|
54 |
+
|
55 |
+
def is_duo(n, ingredient_indexes, ingredients, quantities):
|
56 |
+
# duos are made of one spirit and one liqueur (above 15ml), under it's an ancestral, no citrus.
|
57 |
+
if n['spirit'] >= 1 and n['citrus'] == 0 and n['sugar']==0 and n['liqueur'] > 0 and n['vermouth'] == 0:
|
58 |
+
if n['all'] - n['spirit'] - n['sugar'] - n['liqueur'] - n['vermouth'] - n['inconsequentials'] == 0:
|
59 |
+
q_liqueur = np.sum([quantities[i_ing]
|
60 |
+
for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients)))
|
61 |
+
if ingredient_profiles['type'][i_ind].lower() == 'liqueur'])
|
62 |
+
if q_liqueur > 15:
|
63 |
+
return True
|
64 |
+
else:
|
65 |
+
return False
|
66 |
+
return False
|
67 |
+
|
68 |
+
def is_champagne_cocktail(n, ingredient_indexes, ingredients, quantities):
|
69 |
+
if n['sparkling'] > 0:
|
70 |
+
return True
|
71 |
+
else:
|
72 |
+
return False
|
73 |
+
|
74 |
+
def is_simple_highball(n, ingredient_indexes, ingredients, quantities):
|
75 |
+
# simple highballs have one alcoholic ingredient and bubbles
|
76 |
+
if n['alcoholic'] == 1 and n['bubbles'] > 0:
|
77 |
+
if n['all'] - n['alcoholic'] - n['bubbles'] - n['inconsequentials']== 0:
|
78 |
+
return True
|
79 |
+
return False
|
80 |
+
|
81 |
+
def is_complex_highball(n, ingredient_indexes, ingredients, quantities):
|
82 |
+
# complex highballs have at least one alcoholic ingredient and bubbles (possibly alcoholic). They also contain extra sugar under any form and juice
|
83 |
+
if n['alcoholic'] > 0 and (n['bubbles'] + n['sparkling']) == 1 and n['juice'] + n['all_sweet'] + n['sugar_bubbles']> 0:
|
84 |
+
if n['all'] - n['spirit'] - n['bubbles'] - n['sparkling'] - n['citrus'] - n['juice'] - n['liqueur'] \
|
85 |
+
- n['syrup'] - n['sugar'] -n['vermouth'] -n['egg'] - n['inconsequentials'] == 0:
|
86 |
+
if not is_collins(n, ingredient_indexes, ingredients, quantities) and not is_simple_highball(n, ingredient_indexes, ingredients, quantities):
|
87 |
+
return True
|
88 |
+
return False
|
89 |
+
|
90 |
+
def is_collins(n, ingredient_indexes, ingredients, quantities):
|
91 |
+
# collins are a particular kind of highball with sugar and citrus
|
92 |
+
if n['alcoholic'] == 1 and n['bubbles'] == 1 and n['citrus'] > 0 and n['plain_sweet'] + n['sugar_bubbles'] > 0:
|
93 |
+
if n['all'] - n['spirit'] - n['bubbles'] - n['citrus'] - n['sugar'] - n['inconsequentials'] == 0:
|
94 |
+
return True
|
95 |
+
return False
|
96 |
+
|
97 |
+
def is_julep(n, ingredient_indexes, ingredients, quantities):
|
98 |
+
# juleps involve smashd mint, sugar and a spirit, no citrus.
|
99 |
+
if 'mint' in ingredients and n['sugar'] > 0 and n['spirit'] > 0 and n['vermouth'] == 0 and n['citrus'] == 0:
|
100 |
+
return True
|
101 |
+
return False
|
102 |
+
|
103 |
+
def is_simple_sour_with_juice(n, ingredient_indexes, ingredients, quantities):
|
104 |
+
# almost sours are sours with juice
|
105 |
+
if n['juice'] > 0 and n['spirit'] > 0 and n['plain_sweet'] > 0:
|
106 |
+
if n['all'] - n['citrus'] - n['coffee'] - n['juice'] - n['spirit'] - n['sugar'] - n['syrup'] - n['egg'] - n['inconsequentials'] == 0:
|
107 |
+
return True
|
108 |
+
return False
|
109 |
+
|
110 |
+
|
111 |
+
def is_complex_sour_with_juice(n, ingredient_indexes, ingredients, quantities):
|
112 |
+
# almost sours are sours with juice
|
113 |
+
if n['juice'] > 0 and n['all_sweet'] > 0:
|
114 |
+
if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0:
|
115 |
+
if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0:
|
116 |
+
if n['all'] -n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \
|
117 |
+
- n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0:
|
118 |
+
return True
|
119 |
+
return False
|
120 |
+
|
121 |
+
|
122 |
+
is_sub_category = [is_ancestral, is_complex_sour, is_simple_sour, is_duo, is_champagne_cocktail,
|
123 |
+
is_spirit_forward, is_simple_highball, is_complex_highball, is_collins,
|
124 |
+
is_julep, is_simple_sour_with_juice, is_complex_sour_with_juice]
|
125 |
+
sub_categories = ['ancestral', 'complex_sour', 'simple_sour', 'duo', 'champagne_cocktail',
|
126 |
+
'spirit_forward', 'simple_highball', 'complex_highball', 'collins',
|
127 |
+
'julep', 'simple_sour_with_juice', 'complex_sour_with_juice']
|
128 |
+
|
129 |
+
|
130 |
+
# compute cocktail category as a function of ingredients and quantities, uses name to check match between name and cat (e.g. XXX Collins should be collins..)
|
131 |
+
# Categories definitions are based on https://www.seriouseats.com/cocktail-style-guide-categories-of-cocktails-glossary-families-of-drinks
|
132 |
+
def find_cocktail_sub_category(ingredients, quantities, name=None):
|
133 |
+
ingredient_indexes = [ingredient2ingredient_id[ing] for ing in ingredients]
|
134 |
+
n_spirit = np.sum([ingredient_profiles['type'][i].lower() == 'liquor' for i in ingredient_indexes ])
|
135 |
+
n_citrus = np.sum([ingredient_profiles['type'][i].lower()== 'acid' for i in ingredient_indexes])
|
136 |
+
n_sugar = np.sum([ingredient_profiles['ingredient'][i].lower() in ['double syrup', 'simple syrup', 'honey syrup'] for i in ingredient_indexes])
|
137 |
+
plain_sweet = ingredients_per_type['sweeteners']
|
138 |
+
all_sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc']
|
139 |
+
n_plain_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in plain_sweet for i in ingredient_indexes])
|
140 |
+
n_all_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in all_sweet for i in ingredient_indexes])
|
141 |
+
n_sugar_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['cola', 'ginger beer', 'tonic'] for i in ingredient_indexes])
|
142 |
+
n_juice = np.sum([ingredient_profiles['type'][i].lower() == 'juice' for i in ingredient_indexes])
|
143 |
+
n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
|
144 |
+
alcoholic = ingredients_per_type['liquor'] + ingredients_per_type['liqueur'] + ingredients_per_type['vermouth']
|
145 |
+
n_alcoholic = np.sum([ingredient_profiles['ingredient'][i].lower() in alcoholic for i in ingredient_indexes])
|
146 |
+
n_bitter = np.sum([ingredient_profiles['type'][i].lower() == 'bitters' for i in ingredient_indexes])
|
147 |
+
n_egg = np.sum([ingredient_profiles['ingredient'][i].lower() == 'egg' for i in ingredient_indexes])
|
148 |
+
n_vermouth = np.sum([ingredient_profiles['type'][i].lower() == 'vermouth' for i in ingredient_indexes])
|
149 |
+
n_sparkling = np.sum([ingredient_profiles['ingredient'][i].lower() == 'sparkling wine' for i in ingredient_indexes])
|
150 |
+
n_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['soda', 'tonic', 'cola', 'ginger beer'] for i in ingredient_indexes])
|
151 |
+
n_syrup = np.sum([ingredient_profiles['ingredient'][i].lower() in ['grenadine', 'raspberry syrup'] for i in ingredient_indexes])
|
152 |
+
n_coffee = np.sum([ingredient_profiles['ingredient'][i].lower() == 'espresso' for i in ingredient_indexes])
|
153 |
+
inconsequentials = ['water', 'salt', 'angostura', 'orange bitters', 'mint']
|
154 |
+
n_inconsequentials = np.sum([ingredient_profiles['ingredient'][i].lower() in inconsequentials for i in ingredient_indexes])
|
155 |
+
n = dict(all=len(ingredients),
|
156 |
+
inconsequentials=n_inconsequentials,
|
157 |
+
sugar_bubbles=n_sugar_bubbles,
|
158 |
+
bubbles=n_bubbles,
|
159 |
+
plain_sweet=n_plain_sweet,
|
160 |
+
all_sweet=n_all_sweet,
|
161 |
+
coffee=n_coffee,
|
162 |
+
alcoholic=n_alcoholic,
|
163 |
+
syrup=n_syrup,
|
164 |
+
sparkling=n_sparkling,
|
165 |
+
sugar=n_sugar,
|
166 |
+
spirit=n_spirit,
|
167 |
+
citrus=n_citrus,
|
168 |
+
juice=n_juice,
|
169 |
+
liqueur=n_liqueur,
|
170 |
+
bitter=n_bitter,
|
171 |
+
egg=n_egg,
|
172 |
+
vermouth=n_vermouth)
|
173 |
+
|
174 |
+
sub_cats = [c for c, test_c in zip(sub_categories, is_sub_category) if test_c(n, ingredient_indexes, ingredients, quantities)]
|
175 |
+
if name != None:
|
176 |
+
name = name.lower()
|
177 |
+
keywords_to_test = ['julep', 'collins', 'highball', 'sour', 'champagne']
|
178 |
+
for k in keywords_to_test:
|
179 |
+
if k in name and not any([k in cat for cat in sub_cats]):
|
180 |
+
print(k)
|
181 |
+
for ing, q in zip(ingredients, quantities):
|
182 |
+
print(f'{ing}: {q} ml')
|
183 |
+
print(n)
|
184 |
+
break
|
185 |
+
if sorted(sub_cats) == ['champagne_cocktail', 'complex_highball']:
|
186 |
+
sub_cats = ['champagne_cocktail']
|
187 |
+
elif sorted(sub_cats) == ['collins', 'complex_highball']:
|
188 |
+
sub_cats = ['collins']
|
189 |
+
elif sorted(sub_cats) == ['champagne_cocktail', 'complex_highball', 'julep']:
|
190 |
+
sub_cats = ['champagne_cocktail']
|
191 |
+
elif sorted(sub_cats) == ['ancestral', 'julep']:
|
192 |
+
sub_cats = ['julep']
|
193 |
+
elif sorted(sub_cats) == ['complex_highball', 'julep']:
|
194 |
+
sub_cats = ['complex_highball']
|
195 |
+
elif sorted(sub_cats) == ['julep', 'simple_sour_with_juice']:
|
196 |
+
sub_cats = ['simple_sour_with_juice']
|
197 |
+
elif sorted(sub_cats) == ['complex_sour_with_juice', 'julep']:
|
198 |
+
sub_cats = ['complex_sour_with_juice']
|
199 |
+
if len(sub_cats) != 1:
|
200 |
+
# print(sub_cats)
|
201 |
+
# for ing, q in zip(ingredients, quantities):
|
202 |
+
# print(f'{ing}: {q} ml')
|
203 |
+
# print(n)
|
204 |
+
# if len(sub_cats) == 0:
|
205 |
+
sub_cats = ['other']
|
206 |
+
assert len(sub_cats) == 1, sub_cats
|
207 |
+
return sub_cats[0], n
|
208 |
+
|
209 |
+
def get_cocktails_attributes(ing_strs):
|
210 |
+
attributes = dict()
|
211 |
+
cats = []
|
212 |
+
for ing_str in ing_strs:
|
213 |
+
ingredients, quantities = extract_ingredients(ing_str)
|
214 |
+
cat, atts = find_cocktail_sub_category(ingredients, quantities)
|
215 |
+
for k in atts.keys():
|
216 |
+
if k not in attributes.keys():
|
217 |
+
attributes[k] = [atts[k]]
|
218 |
+
else:
|
219 |
+
attributes[k].append(atts[k])
|
220 |
+
cats.append(cat)
|
221 |
+
return cats, attributes
|
src/cocktails/utilities/cocktail_generation_utilities/__init__.py
ADDED
File without changes
|
src/cocktails/utilities/cocktail_generation_utilities/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (238 Bytes). View file
|
|
src/cocktails/utilities/cocktail_generation_utilities/__pycache__/individual.cpython-39.pyc
ADDED
Binary file (20.2 kB). View file
|
|
src/cocktails/utilities/cocktail_generation_utilities/__pycache__/population.cpython-39.pyc
ADDED
Binary file (8.36 kB). View file
|
|
src/cocktails/utilities/cocktail_generation_utilities/individual.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.cocktails.utilities.ingredients_utilities import get_ingredients_info, format_ingredients, extract_ingredients, ingredients_per_type, bubble_ingredients
|
2 |
+
import numpy as np
|
3 |
+
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
|
4 |
+
from src.cocktails.utilities.cocktail_utilities import get_cocktail_rep, get_profile, get_bunch_of_rep_keys
|
5 |
+
from src.cocktails.utilities.glass_and_volume_utilities import glass_volume
|
6 |
+
from src.cocktails.representation_learning.run import get_model
|
7 |
+
from src.cocktails.pipeline.get_cocktail2affective_cluster import get_cocktail2affective_cluster
|
8 |
+
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, REPO_PATH, COCKTAIL_REP_CHKPT_PATH, RECIPE2FEATURES_PATH
|
9 |
+
from src.cocktails.representation_learning.run_without_vae import get_model
|
10 |
+
from src.cocktails.utilities.cocktail_category_detection_utilities import find_cocktail_sub_category
|
11 |
+
|
12 |
+
import pandas as pd
|
13 |
+
import torch
|
14 |
+
import time
|
15 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
16 |
+
|
17 |
+
density_ingredients = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'density_ingredients.txt')
|
18 |
+
max_ingredients, ingredient_list, ind_alcohol = get_ingredients_info()
|
19 |
+
min_ingredients = 2
|
20 |
+
factor_max = 1.2 # generated recipes can go up to 1.2 times the max quantity of the ingredient found in the dataset
|
21 |
+
|
22 |
+
prep_model = get_model(RECIPE2FEATURES_PATH + 'multi_predictor/')[0]
|
23 |
+
|
24 |
+
all_rep_path = FULL_COCKTAIL_REP_PATH
|
25 |
+
all_reps = np.loadtxt(all_rep_path)
|
26 |
+
experiment_dir = REPO_PATH + '/experiments/cocktails/'
|
27 |
+
rep_keys = get_bunch_of_rep_keys()['custom']
|
28 |
+
dict_weights_mse_computation = {'end volume': .1, 'end sour': 2, 'end sweet': 2, 'end booze': 4, 'end bitter': 2, 'end fruit': 1, 'end herb': 1,
|
29 |
+
'end complex': 1, 'end spicy': 5, 'end oaky': 1, 'end fizzy': 10, 'end colorful': 1, 'end eggy': 10}
|
30 |
+
assert sorted(dict_weights_mse_computation.keys()) == sorted(rep_keys)
|
31 |
+
weights_mse_computation = np.array([dict_weights_mse_computation[k] for k in rep_keys])
|
32 |
+
weights_mse_computation /= weights_mse_computation.sum()
|
33 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
34 |
+
preparation_list = sorted(set(data['category']))
|
35 |
+
glasses_list = sorted(set(data['glass']))
|
36 |
+
|
37 |
+
weights_perf_n_ing = {2:0.71, 3:0.81, 4:0.93, 5:1., 6:1.03, 7:1.08, 8:1.05}
|
38 |
+
|
39 |
+
# weights_perf_n_ing = {2:0.75, 3:0.8, 4:0.95, 5:1.05, 6:1.05, 7:1.05, 8:1.05}
|
40 |
+
min_ingredients_quantities_when_present = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'ingredients_min_quantities_when_present.txt')
|
41 |
+
min_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'ingredients_min_quantities.txt')
|
42 |
+
max_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'ingredients_max_quantities.txt')
|
43 |
+
min_cocktail_rep, max_cocktail_rep = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'cocktail_minmax_dim13_customkeys.txt')
|
44 |
+
distrib_nb_ings_2_8 = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'distrib_nb_ing.txt')[2:]
|
45 |
+
def normalize_cocktail(cocktail_rep):
|
46 |
+
return ((cocktail_rep - min_cocktail_rep) / (max_cocktail_rep - min_cocktail_rep) - 0.5) * 2
|
47 |
+
|
48 |
+
def denormalize_cocktail(cocktail_rep):
|
49 |
+
return (cocktail_rep / 2 + 0.5) * (max_cocktail_rep - min_cocktail_rep) + min_cocktail_rep
|
50 |
+
|
51 |
+
def normalize_ingredient_q_rep(ingredients_q):
|
52 |
+
return (ingredients_q - min_ingredients_quantities_when_present) / (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present)
|
53 |
+
|
54 |
+
COCKTAIL_REPS = normalize_cocktail(np.array([data[k] for k in rep_keys]).transpose())
|
55 |
+
assert np.abs(COCKTAIL_REPS - all_reps).sum() < 1e-8
|
56 |
+
|
57 |
+
cocktail2affective_cluster = get_cocktail2affective_cluster()
|
58 |
+
|
59 |
+
original_affective_keys = get_bunch_of_rep_keys()['affective']
|
60 |
+
def sigmoid(x, shift, beta):
|
61 |
+
return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
|
62 |
+
|
63 |
+
def get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(cocktail_rep):
|
64 |
+
indexes = np.array([rep_keys.index(key) for key in original_affective_keys])
|
65 |
+
cocktail_rep = cocktail_rep[indexes]
|
66 |
+
cocktail_rep[0] = sigmoid(cocktail_rep[0], shift=0.05, beta=4)
|
67 |
+
cocktail_rep[1] = sigmoid(cocktail_rep[1], shift=0.3, beta=5)
|
68 |
+
cocktail_rep[2] = sigmoid(cocktail_rep[2], shift=0.15, beta=3)
|
69 |
+
cocktail_rep[3] = sigmoid(cocktail_rep[3], shift=0.9, beta=20)
|
70 |
+
cocktail_rep[4] = sigmoid(cocktail_rep[4], shift=0, beta=4)
|
71 |
+
cocktail_rep[5] = sigmoid(cocktail_rep[5], shift=0.2, beta=3)
|
72 |
+
cocktail_rep[6] = sigmoid(cocktail_rep[6], shift=0.5, beta=5)
|
73 |
+
cocktail_rep[7] = sigmoid(cocktail_rep[7], shift=0.2, beta=6)
|
74 |
+
return cocktail_rep
|
75 |
+
|
76 |
+
class IndividualCocktail():
|
77 |
+
def __init__(self, pop_params, target, target_affective_cluster, genes_presence=None, genes_quantity=None,
|
78 |
+
compute_perf=True, known_target_dict=None, run_hard_check=False):
|
79 |
+
|
80 |
+
self.pop_params = pop_params
|
81 |
+
self.n_genes = len(ingredient_list)
|
82 |
+
self.max_ingredients = max_ingredients
|
83 |
+
self.min_ingredients = min_ingredients
|
84 |
+
self.mutation_params = pop_params['mutation_params']
|
85 |
+
self.dist = pop_params['dist']
|
86 |
+
self.target = target
|
87 |
+
self.is_known = known_target_dict is not None
|
88 |
+
self.known_target_dict = known_target_dict
|
89 |
+
self.perf = None
|
90 |
+
self.cocktail_rep = None
|
91 |
+
self.affective_cluster = None
|
92 |
+
self.target_affective_cluster = target_affective_cluster
|
93 |
+
self.ing_list = np.array(ingredient_list)
|
94 |
+
self.ing_set = set(ingredient_list)
|
95 |
+
|
96 |
+
self.ing_ids_per_cat = dict(bubbles=set(self.get_ingredients_ids_from_list(bubble_ingredients)),
|
97 |
+
liquor=set(self.get_ingredients_ids_from_list(ingredients_per_type['liquor'])),
|
98 |
+
liqueur=set(self.get_ingredients_ids_from_list(ingredients_per_type['liqueur'])),
|
99 |
+
citrus=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'] + ['orange juice'])),
|
100 |
+
alcohol=set(ind_alcohol),
|
101 |
+
sweeteners=set(self.get_ingredients_ids_from_list(ingredients_per_type['sweeteners'])),
|
102 |
+
vermouth=set(self.get_ingredients_ids_from_list(ingredients_per_type['vermouth'])),
|
103 |
+
bitters=set(self.get_ingredients_ids_from_list(ingredients_per_type['bitters'])),
|
104 |
+
juice=set(self.get_ingredients_ids_from_list(ingredients_per_type['juice'])),
|
105 |
+
acid=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'])),
|
106 |
+
egg=set(self.get_ingredients_ids_from_list(['egg']))
|
107 |
+
)
|
108 |
+
|
109 |
+
if genes_presence is not None:
|
110 |
+
assert len(genes_presence) == self.n_genes
|
111 |
+
assert len(genes_quantity) == self.n_genes
|
112 |
+
self.genes_presence = genes_presence
|
113 |
+
self.genes_quantity = genes_quantity
|
114 |
+
if compute_perf:
|
115 |
+
self.compute_cocktail_rep()
|
116 |
+
self.compute_perf()
|
117 |
+
else:
|
118 |
+
self.sample_initial_genes()
|
119 |
+
self.compute_cocktail_rep()
|
120 |
+
# self.make_recipe_fit_the_glass()
|
121 |
+
self.compute_perf()
|
122 |
+
|
123 |
+
|
124 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
125 |
+
# Sample initial genes with smart rules
|
126 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
127 |
+
|
128 |
+
def sample_initial_genes(self):
|
129 |
+
# rules:
|
130 |
+
# - between min_ingredients and max_ingredients
|
131 |
+
# - at most one type of bubbles
|
132 |
+
# - at least one alcohol
|
133 |
+
# - no egg without lime or lemon
|
134 |
+
# - at most two liqueurs
|
135 |
+
# - at most three liquors
|
136 |
+
# - at most two sweetener
|
137 |
+
self.genes_quantity = np.random.uniform(0, 1, size=self.n_genes) # holds quantities for each ingredient
|
138 |
+
n_ingredients = np.random.choice(np.arange(min_ingredients, max_ingredients + 1), p=distrib_nb_ings_2_8)
|
139 |
+
self.genes_presence = np.zeros(self.n_genes)
|
140 |
+
# add one alchohol
|
141 |
+
self.genes_presence[np.random.choice(ind_alcohol)] = 1
|
142 |
+
while self.get_ing_count() < n_ingredients:
|
143 |
+
candidate_ids = self.get_candidate_ingredients_ids(self.genes_presence)
|
144 |
+
probas = density_ingredients[candidate_ids] / np.sum(density_ingredients[candidate_ids])
|
145 |
+
self.genes_presence[np.random.choice(candidate_ids, p=probas)] = 1
|
146 |
+
|
147 |
+
def get_candidate_ingredients_ids(self, genes_presence):
|
148 |
+
candidates = set(np.argwhere(genes_presence==0).flatten())
|
149 |
+
present_ids = set(np.argwhere(genes_presence==1).flatten())
|
150 |
+
|
151 |
+
if self.count_in_genes(present_ids, 'bubbles') >= 1: # at most one type of bubbles
|
152 |
+
candidates = candidates - self.ing_ids_per_cat['bubbles']
|
153 |
+
if self.count_in_genes(present_ids, 'liquor') >= 3: # at most three liquors
|
154 |
+
candidates = candidates - self.ing_ids_per_cat['liquor']
|
155 |
+
if self.count_in_genes(present_ids, 'liqueur') >= 2: # at most two liqueurs
|
156 |
+
candidates = candidates - self.ing_ids_per_cat['liqueur']
|
157 |
+
if self.count_in_genes(present_ids, 'sweeteners') >= 2: # at most two sweetener
|
158 |
+
candidates = candidates - self.ing_ids_per_cat['sweeteners']
|
159 |
+
if self.count_in_genes(present_ids, 'citrus') == 0: # no egg without lime or lemon
|
160 |
+
candidates = candidates - self.ing_ids_per_cat['egg']
|
161 |
+
return np.array(sorted(candidates))
|
162 |
+
|
163 |
+
def count_in_genes(self, present_ids, keyword):
|
164 |
+
if keyword == 'citrus': return len(present_ids & self.ing_ids_per_cat['citrus'])
|
165 |
+
elif keyword == 'bubbles': return len(present_ids & self.ing_ids_per_cat['bubbles'])
|
166 |
+
elif keyword == 'liquor': return len(present_ids & self.ing_ids_per_cat['liquor'])
|
167 |
+
elif keyword == 'liqueur': return len(present_ids & self.ing_ids_per_cat['liqueur'])
|
168 |
+
elif keyword == 'alcohol': return len(present_ids & self.ing_ids_per_cat['alcohol'])
|
169 |
+
elif keyword == 'sweeteners': return len(present_ids & self.ing_ids_per_cat['sweeteners'])
|
170 |
+
else: raise ValueError
|
171 |
+
|
172 |
+
def get_ingredients_ids_from_list(self, ing_list):
|
173 |
+
return [ingredient_list.index(ing) for ing in ing_list]
|
174 |
+
|
175 |
+
def get_ing_count(self):
|
176 |
+
return np.sum(self.genes_presence)
|
177 |
+
|
178 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
179 |
+
# Compute cocktail representations
|
180 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
181 |
+
|
182 |
+
def get_absent_ing(self):
|
183 |
+
return np.argwhere(self.genes_presence==0).flatten()
|
184 |
+
|
185 |
+
def get_present_ing(self):
|
186 |
+
return np.argwhere(self.genes_presence==1).flatten()
|
187 |
+
|
188 |
+
def get_ingredient_quantities(self):
|
189 |
+
# unnormalize quantities to get real ones
|
190 |
+
return (self.genes_quantity * (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present) + min_ingredients_quantities_when_present) * self.genes_presence
|
191 |
+
|
192 |
+
def get_ing_and_q_from_genes(self):
|
193 |
+
present_ings = self.get_present_ing()
|
194 |
+
ing_quantities = self.get_ingredient_quantities()
|
195 |
+
ingredients, quantities = [], []
|
196 |
+
for i_ing in present_ings:
|
197 |
+
ingredients.append(ingredient_list[i_ing])
|
198 |
+
quantities.append(ing_quantities[i_ing])
|
199 |
+
return ingredients, quantities, ing_quantities
|
200 |
+
|
201 |
+
def compute_cocktail_rep(self):
|
202 |
+
# only call when genes have changes
|
203 |
+
init_time = time.time()
|
204 |
+
ingredients, quantities, ing_quantities = self.get_ing_and_q_from_genes()
|
205 |
+
# compute cocktail category
|
206 |
+
self.category = find_cocktail_sub_category(ingredients, quantities)[0]
|
207 |
+
# print(f't1: {time.time() - init_time}')
|
208 |
+
init_time = time.time()
|
209 |
+
self.prep_type = self.get_prep_type(ing_quantities)
|
210 |
+
# print(f't2: {time.time() - init_time}')
|
211 |
+
init_time = time.time()
|
212 |
+
cocktail_rep, self.end_volume, self.end_alcohol = get_cocktail_rep(self.prep_type, ingredients, quantities, keys=rep_keys[1:]) # volume is added later
|
213 |
+
# print(f't3: {time.time() - init_time}')
|
214 |
+
init_time = time.time()
|
215 |
+
self.cocktail_rep = normalize_cocktail(cocktail_rep)
|
216 |
+
# print(f't4: {time.time() - init_time}')
|
217 |
+
init_time = time.time()
|
218 |
+
self.glass = self.get_glass_type(ing_quantities)
|
219 |
+
# print(f't5: {time.time() - init_time}')
|
220 |
+
init_time = time.time()
|
221 |
+
if self.is_known:
|
222 |
+
assert np.abs(self.cocktail_rep - self.target).sum() < 1e-6
|
223 |
+
return self.cocktail_rep
|
224 |
+
|
225 |
+
def get_prep_type(self, quantities=None):
|
226 |
+
if self.is_known: return self.known_target_dict['prep_type']
|
227 |
+
else:
|
228 |
+
if quantities is None:
|
229 |
+
quantities = self.get_ingredient_quantities()
|
230 |
+
if quantities[ingredient_list.index('egg')] > 0:
|
231 |
+
prep_cat = 'egg_shaken'
|
232 |
+
elif self.category in ['spirit_forward', 'simple_sour_with_juice', 'julep', 'duo', 'ancestral', 'complex_sour_with_juice']:
|
233 |
+
# use hard coded rules for most obvious cases determined with the correlations_glass_cat_prep_script
|
234 |
+
if self.category in ['ancestral', 'spirit_forward', 'duo']:
|
235 |
+
prep_cat = 'stirred'
|
236 |
+
elif self.category in ['complex_sour_with_juice', 'julep', 'simple_sour_with_juice']:
|
237 |
+
prep_cat = 'shaken'
|
238 |
+
else:
|
239 |
+
raise ValueError
|
240 |
+
else:
|
241 |
+
output = prep_model(quantities, aux_str='prep_type').flatten()
|
242 |
+
output[preparation_list.index('egg_shaken')] = -np.inf
|
243 |
+
prep_cat = preparation_list[np.argmax(output)]
|
244 |
+
return prep_cat
|
245 |
+
|
246 |
+
def get_glass_type(self, quantities=None):
|
247 |
+
if self.is_known: return self.known_target_dict['glass']
|
248 |
+
else:
|
249 |
+
if self.category in ['collins', 'complex_highball', 'simple_highball', 'champagne_cocktail', 'complex_sour']:
|
250 |
+
# use hard coded rules for most obvious cases determined with the correlations_glass_cat_prep_script
|
251 |
+
if self.category in ['collins', 'complex_highball', 'simple_highball']:
|
252 |
+
glass = 'collins'
|
253 |
+
elif self.category in ['champagne_cocktail', 'complex_sour']:
|
254 |
+
glass = 'coupe'
|
255 |
+
else:
|
256 |
+
if quantities is None:
|
257 |
+
quantities = self.get_ingredient_quantities()
|
258 |
+
output = prep_model(quantities, aux_str='glasses').flatten()
|
259 |
+
glass = glasses_list[np.argmax(output)]
|
260 |
+
return glass
|
261 |
+
|
262 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
263 |
+
# Adapt recipe to fit the glass
|
264 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
265 |
+
|
266 |
+
def is_too_large_for_glass(self):
|
267 |
+
return self.end_volume > glass_volume[self.glass] * 0.80
|
268 |
+
|
269 |
+
def is_too_small_for_glass(self):
|
270 |
+
return self.end_volume < glass_volume[self.glass] * 0.3
|
271 |
+
|
272 |
+
def scale_ing_quantities(self, present_ings, factor):
|
273 |
+
qs = self.get_ingredient_quantities().copy()
|
274 |
+
qs[present_ings] *= factor
|
275 |
+
self.set_genes_from_quantities(present_ings, qs)
|
276 |
+
|
277 |
+
def set_genes_from_quantities(self, present_ings, quantities):
|
278 |
+
genes_quantity = np.clip((quantities - min_ingredients_quantities_when_present) /
|
279 |
+
(factor_max * max_ingredients_quantities - min_ingredients_quantities_when_present), 0, 1)
|
280 |
+
self.genes_quantity[present_ings] = genes_quantity[present_ings]
|
281 |
+
|
282 |
+
def make_recipe_fit_the_glass(self):
|
283 |
+
# check if citrus, if not remove egg
|
284 |
+
present_ids = np.argwhere(self.genes_presence == 1).flatten()
|
285 |
+
ing_list = self.ing_list[present_ids]
|
286 |
+
present_ids = set(present_ids)
|
287 |
+
if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list:
|
288 |
+
if self.genes_presence.sum() > 2:
|
289 |
+
i_egg = ingredient_list.index('egg')
|
290 |
+
self.genes_presence[i_egg] = 0.
|
291 |
+
self.compute_cocktail_rep()
|
292 |
+
|
293 |
+
|
294 |
+
i_trial = 0
|
295 |
+
present_ings = self.get_present_ing()
|
296 |
+
while self.is_too_large_for_glass():
|
297 |
+
i_trial += 1
|
298 |
+
end_volume = self.end_volume
|
299 |
+
desired_volume = glass_volume[self.glass] * 0.80
|
300 |
+
ratio = desired_volume / end_volume
|
301 |
+
self.scale_ing_quantities(present_ings, factor=ratio)
|
302 |
+
self.compute_cocktail_rep()
|
303 |
+
if end_volume == self.end_volume: break
|
304 |
+
if i_trial == 10: break
|
305 |
+
while self.is_too_small_for_glass():
|
306 |
+
i_trial += 1
|
307 |
+
end_volume = self.end_volume
|
308 |
+
desired_volume = glass_volume[self.glass] * 0.80
|
309 |
+
ratio = desired_volume / end_volume
|
310 |
+
self.scale_ing_quantities(present_ings, factor=ratio)
|
311 |
+
self.compute_cocktail_rep()
|
312 |
+
if end_volume == self.end_volume: break
|
313 |
+
if i_trial == 10: break
|
314 |
+
|
315 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
316 |
+
# Compute performance
|
317 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
318 |
+
|
319 |
+
def passes_checks(self):
|
320 |
+
present_ids = np.argwhere(self.genes_presence==1).flatten()
|
321 |
+
# ing_list = self.ing_list[present_ids]
|
322 |
+
present_ids = set(present_ids)
|
323 |
+
if len(present_ids) < 2 or len(present_ids) > 8: return False
|
324 |
+
# if self.is_too_large_for_glass(): return False
|
325 |
+
# if self.is_too_small_for_glass(): return False
|
326 |
+
if self.end_alcohol < 0.05 or self.end_alcohol > 0.31: return False
|
327 |
+
if self.count_in_genes(present_ids, 'sweeteners') > 2: return False
|
328 |
+
if self.count_in_genes(present_ids, 'liqueur') > 2: return False
|
329 |
+
if self.count_in_genes(present_ids, 'liquor') > 3: return False
|
330 |
+
# if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list: return False
|
331 |
+
if self.count_in_genes(present_ids, 'bubbles') > 1: return False
|
332 |
+
else: return True
|
333 |
+
|
334 |
+
def get_affective_cluster(self):
|
335 |
+
cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(self.cocktail_rep)
|
336 |
+
self.affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0]
|
337 |
+
return self.affective_cluster
|
338 |
+
|
339 |
+
def does_affective_cluster_match(self):
|
340 |
+
return True#self.get_affective_cluster() == self.target_affective_cluster
|
341 |
+
|
342 |
+
def compute_perf(self):
|
343 |
+
if not self.passes_checks(): self.perf = -100
|
344 |
+
else:
|
345 |
+
if self.dist == 'mse':
|
346 |
+
# self.perf = - np.sqrt(((self.cocktail_rep - self.target)**2).mean())
|
347 |
+
self.perf = - np.sqrt(np.dot((self.cocktail_rep - self.target)**2, weights_mse_computation))
|
348 |
+
self.perf *= weights_perf_n_ing[int(self.genes_presence.sum())]
|
349 |
+
if not self.does_affective_cluster_match():
|
350 |
+
self.perf *= 2
|
351 |
+
else: raise NotImplemented
|
352 |
+
|
353 |
+
|
354 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
355 |
+
# Mutations and crossover
|
356 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
357 |
+
|
358 |
+
def get_child(self):
|
359 |
+
time_dict = dict()
|
360 |
+
init_time = time.time()
|
361 |
+
child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster,
|
362 |
+
target=self.target, genes_presence=self.genes_presence.copy(),
|
363 |
+
genes_quantity=self.genes_quantity.copy(), compute_perf=False)
|
364 |
+
time_dict[' asexual child creation'] = [time.time() - init_time]
|
365 |
+
init_time = time.time()
|
366 |
+
this_time_dict = child.mutate()
|
367 |
+
time_dict = self.update_time_dict(time_dict, this_time_dict)
|
368 |
+
time_dict[' asexual child mutation'] = [time.time() - init_time]
|
369 |
+
return child, time_dict
|
370 |
+
|
371 |
+
def get_child_with(self, other_parent):
|
372 |
+
time_dict = dict()
|
373 |
+
init_time = time.time()
|
374 |
+
new_genes_presence = np.zeros(self.n_genes)
|
375 |
+
present_ing = self.get_present_ing()
|
376 |
+
other_present_ing = other_parent.get_present_ing()
|
377 |
+
new_genes_quantity = np.random.uniform(0, 1, size=self.n_genes)
|
378 |
+
shared_ingredients = sorted(set(present_ing) & set(other_present_ing))
|
379 |
+
unique_ingredients_one = sorted(set(present_ing) - set(other_present_ing))
|
380 |
+
unique_ingredients_two = sorted(set(other_present_ing) - set(present_ing))
|
381 |
+
for i in shared_ingredients:
|
382 |
+
new_genes_presence[i] = 1
|
383 |
+
new_genes_quantity[i] = (self.genes_quantity[i] + other_parent.genes_quantity[i]) / 2
|
384 |
+
time_dict[' crossover child creation'] = [time.time() - init_time]
|
385 |
+
init_time = time.time()
|
386 |
+
# add one alcohol if none present
|
387 |
+
if len(set(np.argwhere(new_genes_presence==1).flatten()).intersection(ind_alcohol)) == 0:
|
388 |
+
new_genes_presence[np.random.choice(ind_alcohol)] = 1
|
389 |
+
# up to here, we respect the constraints (assuming both parents do).
|
390 |
+
candidate_genes = np.array(unique_ingredients_one + unique_ingredients_two)
|
391 |
+
candidate_quantities = np.array([self.genes_quantity[i] for i in unique_ingredients_one] + [other_parent.genes_quantity[i] for i in unique_ingredients_two])
|
392 |
+
indexes = np.arange(len(candidate_genes))
|
393 |
+
np.random.shuffle(indexes)
|
394 |
+
candidate_genes = candidate_genes[indexes]
|
395 |
+
candidate_quantities = candidate_quantities[indexes]
|
396 |
+
time_dict[' crossover prepare selection'] = [time.time() - init_time]
|
397 |
+
init_time = time.time()
|
398 |
+
# now let's try to add each of them while respecting the constraints
|
399 |
+
for i in range(len(indexes)):
|
400 |
+
if np.random.rand() < 0.5 or np.sum(new_genes_presence) < self.min_ingredients: # only try to add one every two ingredient
|
401 |
+
ing_id = candidate_genes[i]
|
402 |
+
q = candidate_quantities[i]
|
403 |
+
new_genes_presence[ing_id] = 1
|
404 |
+
new_genes_quantity[ing_id] = q
|
405 |
+
if np.sum(new_genes_presence) == self.max_ingredients:
|
406 |
+
break
|
407 |
+
time_dict[' crossover do selection'] = [time.time() - init_time]
|
408 |
+
init_time = time.time()
|
409 |
+
# create new child
|
410 |
+
child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster, target=self.target,
|
411 |
+
genes_presence=new_genes_presence.copy(), genes_quantity=new_genes_quantity.copy(), compute_perf=False)
|
412 |
+
time_dict[' crossover create child'] = [time.time() - init_time]
|
413 |
+
init_time = time.time()
|
414 |
+
this_time_dict = child.mutate()
|
415 |
+
time_dict = self.update_time_dict(time_dict, this_time_dict)
|
416 |
+
time_dict[' crossover child mutation'] = [time.time() - init_time]
|
417 |
+
init_time = time.time()
|
418 |
+
return child, time_dict
|
419 |
+
|
420 |
+
def mutate(self):
|
421 |
+
# self.print_recipe()
|
422 |
+
time_dict = dict()
|
423 |
+
# remove an ingredient
|
424 |
+
init_time = time.time()
|
425 |
+
present_ids = set(np.argwhere(self.genes_presence==1).flatten())
|
426 |
+
|
427 |
+
if np.random.rand() < self.mutation_params['p_remove_ing']:
|
428 |
+
if self.get_ing_count() > self.min_ingredients:
|
429 |
+
candidate_ings = self.get_present_ing()
|
430 |
+
if self.count_in_genes(present_ids, 'alcohol') == 1: # make sure we keep at least one liquor
|
431 |
+
candidate_ings = np.array(sorted(set(candidate_ings) - set(ind_alcohol)))
|
432 |
+
index_to_remove = np.random.choice(candidate_ings)
|
433 |
+
self.genes_presence[index_to_remove] = 0
|
434 |
+
time_dict[' mutation remove ing'] = [time.time() - init_time]
|
435 |
+
init_time = time.time()
|
436 |
+
# add an ingredient
|
437 |
+
if np.random.rand() < self.mutation_params['p_add_ing']:
|
438 |
+
if self.get_ing_count() < self.max_ingredients:
|
439 |
+
candidate_ings = self.get_candidate_ingredients_ids(self.genes_presence.copy())
|
440 |
+
index_to_add = np.random.choice(candidate_ings, p=density_ingredients[candidate_ings] / np.sum(density_ingredients[candidate_ings]))
|
441 |
+
self.genes_presence[index_to_add] = 1
|
442 |
+
time_dict[' mutation add ing'] = [time.time() - init_time]
|
443 |
+
|
444 |
+
init_time = time.time()
|
445 |
+
# replace ings by others from the same family
|
446 |
+
if np.random.rand() < self.mutation_params['p_switch_ing']:
|
447 |
+
i = np.random.choice(self.get_present_ing())
|
448 |
+
ing_str = ingredient_list[i]
|
449 |
+
if ing_str not in ['sparkling wine', 'orange juice']:
|
450 |
+
if ing_str in bubble_ingredients:
|
451 |
+
candidates_ids = np.array(sorted(self.ing_ids_per_cat['bubbles'] - set([i])))
|
452 |
+
new_bubble = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids]))
|
453 |
+
self.genes_presence[i] = 0
|
454 |
+
self.genes_presence[new_bubble] = 1
|
455 |
+
self.genes_quantity[new_bubble] = self.genes_quantity[i] # copy quantity
|
456 |
+
categories = ['acid', 'bitters', 'juice', 'liqueur', 'liquor', 'sweeteners', 'vermouth']
|
457 |
+
for cat in categories:
|
458 |
+
if ing_str in ingredients_per_type[cat]:
|
459 |
+
present_ings = self.get_present_ing()
|
460 |
+
candidates_ids = np.array(sorted(self.ing_ids_per_cat[cat] - set([i]) - set(present_ings)))
|
461 |
+
if len(candidates_ids) > 0:
|
462 |
+
replacing_ing = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids]))
|
463 |
+
self.genes_presence[i] = 0
|
464 |
+
self.genes_presence[replacing_ing] = 1
|
465 |
+
self.genes_quantity[replacing_ing] = self.genes_quantity[i] # copy quantity
|
466 |
+
break
|
467 |
+
time_dict[' mutation switch ing'] = [time.time() - init_time]
|
468 |
+
init_time = time.time()
|
469 |
+
# add noise on ing quantity
|
470 |
+
for i in self.get_present_ing():
|
471 |
+
if np.random.rand() < self.mutation_params['p_change_q']:
|
472 |
+
self.genes_quantity[i] += np.random.randn() * self.mutation_params['delta_change_q']
|
473 |
+
self.genes_quantity = np.clip(self.genes_quantity, 0, 1)
|
474 |
+
time_dict[' mutation change quantity'] = [time.time() - init_time]
|
475 |
+
|
476 |
+
init_time = time.time()
|
477 |
+
self.compute_cocktail_rep()
|
478 |
+
time_dict[' mutation compute cocktail rep'] = [time.time() - init_time]
|
479 |
+
init_time = time.time()
|
480 |
+
# self.make_recipe_fit_the_glass()
|
481 |
+
time_dict[' mutation check glass fit'] = [time.time() - init_time]
|
482 |
+
init_time = time.time()
|
483 |
+
self.compute_perf()
|
484 |
+
time_dict[' mutation compute perf'] = [time.time() - init_time]
|
485 |
+
init_time = time.time()
|
486 |
+
stop = 1
|
487 |
+
return time_dict
|
488 |
+
|
489 |
+
|
490 |
+
def update_time_dict(self, main_dict, new_dict):
|
491 |
+
for k in new_dict.keys():
|
492 |
+
if k in main_dict.keys():
|
493 |
+
main_dict[k].append(np.sum(new_dict[k]))
|
494 |
+
else:
|
495 |
+
main_dict[k] = [np.sum(new_dict[k])]
|
496 |
+
return main_dict
|
497 |
+
|
498 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
499 |
+
# Get recipe and print
|
500 |
+
# # # # # # # # # # # # # # # # # # # # # # # #
|
501 |
+
|
502 |
+
def get_recipe(self, unit='mL', name=None):
|
503 |
+
ing_quantities = self.get_ingredient_quantities()
|
504 |
+
ingredients, quantities = [], []
|
505 |
+
for i_ing, q_ing in enumerate(ing_quantities):
|
506 |
+
if q_ing > 0.8:
|
507 |
+
ingredients.append(ingredient_list[i_ing])
|
508 |
+
quantities.append(round(q_ing))
|
509 |
+
recipe_str = format_ingredients(ingredients, quantities)
|
510 |
+
recipe_str_readable = print_recipe(unit=unit, ingredient_str=recipe_str, name=name, to_print=False)
|
511 |
+
return ingredients, quantities, recipe_str, recipe_str_readable
|
512 |
+
|
513 |
+
def get_instructions(self):
|
514 |
+
ing_quantities = self.get_ingredient_quantities()
|
515 |
+
ingredients, quantities = [], []
|
516 |
+
for i_ing, q_ing in enumerate(ing_quantities):
|
517 |
+
if q_ing > 0.8:
|
518 |
+
ingredients.append(ingredient_list[i_ing])
|
519 |
+
quantities.append(round(q_ing))
|
520 |
+
str_out = 'Instructions:\n '
|
521 |
+
|
522 |
+
if 'mint' in ingredients:
|
523 |
+
i_mint = ingredients.index('mint')
|
524 |
+
n_leaves = quantities[i_mint]
|
525 |
+
str_out += f'Add {n_leaves} mint leaves to a shaker, followed by an ice cube.\n Muddle the mint and ice together with a muddler.\n '
|
526 |
+
bubbles = ['sparkling wine', 'tonic', 'soda', 'ginger beer']
|
527 |
+
other_ings = [ing for ing in ingredients if ing not in ['egg', 'angostura', 'orange bitters'] + bubbles]
|
528 |
+
|
529 |
+
if self.prep_type == 'built':
|
530 |
+
str_out += 'Add a large ice cube in the glass.\n '
|
531 |
+
# add ingredients to pour
|
532 |
+
str_out += 'Pour'
|
533 |
+
for i, ing in enumerate(other_ings):
|
534 |
+
if i == len(other_ings) - 2:
|
535 |
+
str_out += f' {ing} and'
|
536 |
+
elif i == len(other_ings) - 1:
|
537 |
+
str_out += f' {ing}'
|
538 |
+
else:
|
539 |
+
str_out += f' {ing},'
|
540 |
+
|
541 |
+
if self.prep_type in ['built'] and 'mint' not in ingredients:
|
542 |
+
str_out += ' into the glass.\n '
|
543 |
+
else:
|
544 |
+
str_out += ' into the shaker.\n '
|
545 |
+
|
546 |
+
if self.prep_type == 'egg_shaken' and 'egg' in ingredients:
|
547 |
+
str_out += 'Add the egg white.\n Dry-shake for 15s (without ice), then fill with ice and shake for another 15s.\n Serve into the glass through a strainer.\n '
|
548 |
+
elif 'shaken' in self.prep_type:
|
549 |
+
str_out += 'Fill with ice and shake for 15s.\n Serve into the glass through a strainer.\n '
|
550 |
+
elif self.prep_type == 'stirred':
|
551 |
+
str_out += 'Add ice and stir the cocktail with a spoon for 15s.\n Serve into the glass through a strainer.\n '
|
552 |
+
elif self.prep_type == 'built':
|
553 |
+
str_out += 'Stir two turns with a spoon.\n '
|
554 |
+
|
555 |
+
bubble_ing = [ing for ing in ingredients if ing in bubbles]
|
556 |
+
if len(bubble_ing) > 0:
|
557 |
+
str_out += f'Top up with '
|
558 |
+
for ing in bubble_ing:
|
559 |
+
str_out += f'{ing}, '
|
560 |
+
str_out = str_out[:-2] + '.\n '
|
561 |
+
bitter_ing = [ing for ing in ingredients if ing in ['angostura', 'orange bitters']]
|
562 |
+
if len(bitter_ing) > 0:
|
563 |
+
if len(bitter_ing) == 1:
|
564 |
+
q = quantities[ingredients.index(bitter_ing[0])]
|
565 |
+
n_dashes = max(1, int(q / 0.6))
|
566 |
+
str_out += f'Add {n_dashes} dash'
|
567 |
+
if n_dashes > 1:
|
568 |
+
str_out += 'es'
|
569 |
+
str_out += f' of {bitter_ing[0]}.\n '
|
570 |
+
elif len(bitter_ing) == 2:
|
571 |
+
q = quantities[ingredients.index(bitter_ing[0])]
|
572 |
+
n_dashes = max(1, int(q / 0.6))
|
573 |
+
str_out += f'Add {n_dashes} dash'
|
574 |
+
if n_dashes > 1:
|
575 |
+
str_out += 'es'
|
576 |
+
str_out += f' of {bitter_ing[0]} and '
|
577 |
+
q = quantities[ingredients.index(bitter_ing[1])]
|
578 |
+
n_dashes = max(1, int(q / 0.6))
|
579 |
+
str_out += f'{n_dashes} dash'
|
580 |
+
if n_dashes > 1:
|
581 |
+
str_out += 'es'
|
582 |
+
str_out += f' of {bitter_ing[1]}.\n '
|
583 |
+
str_out += 'Enjoy!'
|
584 |
+
return str_out
|
585 |
+
|
586 |
+
def print_recipe(self, name=None):
|
587 |
+
print(self.get_recipe(name)[3])
|
src/cocktails/utilities/cocktail_generation_utilities/population.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.cocktails.utilities.cocktail_generation_utilities.individual import *
|
2 |
+
from sklearn.neighbors import NearestNeighbors
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
from src.cocktails.config import COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
|
6 |
+
|
7 |
+
class Population:
|
8 |
+
def __init__(self, target, pop_params, target_affective_cluster=None, known_target_dict=None):
|
9 |
+
self.pop_params = pop_params
|
10 |
+
self.pop_size = pop_params['pop_size']
|
11 |
+
self.nb_elite = pop_params['nb_elites']
|
12 |
+
self.nb_generations = pop_params['nb_generations']
|
13 |
+
self.target = target
|
14 |
+
self.mutation_params = pop_params['mutation_params']
|
15 |
+
self.dist = pop_params['dist']
|
16 |
+
self.n_neighbors = pop_params['n_neighbors']
|
17 |
+
self.known_target_dict = known_target_dict
|
18 |
+
|
19 |
+
|
20 |
+
with open(COCKTAIL_NN_PATH, 'rb') as f:
|
21 |
+
data = pickle.load(f)
|
22 |
+
self.nn_model_cocktail = data['nn_model']
|
23 |
+
self.dim_rep_cocktail = data['dim_rep_cocktail']
|
24 |
+
self.n_cocktails = data['n_cocktails']
|
25 |
+
self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
|
26 |
+
|
27 |
+
if target_affective_cluster is None:
|
28 |
+
cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(target)
|
29 |
+
self.target_affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0]
|
30 |
+
else:
|
31 |
+
self.target_affective_cluster = target_affective_cluster
|
32 |
+
|
33 |
+
self.pop_elite = []
|
34 |
+
self.pop = []
|
35 |
+
self.add_target_individual() # create a target individual (not in pop)
|
36 |
+
self.add_nearest_neighbors_in_pop() # add nearest neighbor from dataset into the population
|
37 |
+
|
38 |
+
# fill population
|
39 |
+
while self.get_pop_size() < self.pop_size:
|
40 |
+
self.add_individual()
|
41 |
+
while len(self.pop_elite) < self.nb_elite:
|
42 |
+
self.pop_elite.append(IndividualCocktail(pop_params=self.pop_params,
|
43 |
+
target=self.target.copy(),
|
44 |
+
target_affective_cluster=self.target_affective_cluster))
|
45 |
+
self.update_elite_and_get_next_pop()
|
46 |
+
|
47 |
+
def add_target_individual(self):
|
48 |
+
if self.known_target_dict is not None:
|
49 |
+
genes_presence, genes_quantity = self.get_q_rep(*extract_ingredients(self.known_target_dict['ing_str']))
|
50 |
+
self.target_individual = IndividualCocktail(pop_params=self.pop_params,
|
51 |
+
target=self.target.copy(),
|
52 |
+
known_target_dict=self.known_target_dict,
|
53 |
+
target_affective_cluster=self.target_affective_cluster,
|
54 |
+
genes_presence=genes_presence,
|
55 |
+
genes_quantity=genes_quantity
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
self.target_individual = None
|
59 |
+
|
60 |
+
|
61 |
+
def add_nearest_neighbors_in_pop(self):
|
62 |
+
# add nearest neighbor from dataset into the population
|
63 |
+
if self.n_neighbors > 0:
|
64 |
+
dists, indexes = self.nn_model_cocktail.kneighbors(self.target.reshape(1, -1))
|
65 |
+
dists, indexes = dists.flatten(), indexes.flatten()
|
66 |
+
first = 1 if dists[0] == 0 else 0 # avoid taking the target when testing with known targets from the dataset
|
67 |
+
indexes = indexes[first:first + self.n_neighbors]
|
68 |
+
self.ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes]
|
69 |
+
recipes = [extract_ingredients(ing_str) for ing_str in self.ing_strs]
|
70 |
+
for r in recipes:
|
71 |
+
genes_presence, genes_quantity = self.get_q_rep(r[0], r[1])
|
72 |
+
genes_presence[-1] = 0 # remove water ingredient
|
73 |
+
self.add_individual(genes_presence=genes_presence.copy(), genes_quantity=genes_quantity.copy())
|
74 |
+
self.nn_recipes = [ind.get_recipe()[3] for ind in self.pop]
|
75 |
+
self.nn_scores = [ind.perf for ind in self.pop]
|
76 |
+
else:
|
77 |
+
self.ing_strs = None
|
78 |
+
|
79 |
+
def add_individual(self, genes_presence=None, genes_quantity=None):
|
80 |
+
self.pop.append(IndividualCocktail(pop_params=self.pop_params,
|
81 |
+
target=self.target.copy(),
|
82 |
+
target_affective_cluster=self.target_affective_cluster,
|
83 |
+
genes_presence=genes_presence,
|
84 |
+
genes_quantity=genes_quantity))
|
85 |
+
|
86 |
+
def get_elite_perf(self):
|
87 |
+
return np.array([e.perf for e in self.pop_elite])
|
88 |
+
|
89 |
+
def get_pop_perf(self):
|
90 |
+
return np.array([ind.perf for ind in self.pop])
|
91 |
+
|
92 |
+
|
93 |
+
def update_elite_and_get_next_pop(self):
|
94 |
+
time_dict = dict()
|
95 |
+
init_time = time.time()
|
96 |
+
elite_perfs = self.get_elite_perf()
|
97 |
+
pop_perfs = self.get_pop_perf()
|
98 |
+
all_perfs = np.concatenate([elite_perfs, pop_perfs])
|
99 |
+
temp_list = self.pop_elite + self.pop
|
100 |
+
time_dict[' get pop perfs'] = [time.time() - init_time]
|
101 |
+
init_time = time.time()
|
102 |
+
# update elite population with new bests
|
103 |
+
indexes_sorted = np.flip(np.argsort(all_perfs))
|
104 |
+
new_pop_elite = [IndividualCocktail(pop_params=self.pop_params,
|
105 |
+
target=self.target.copy(),
|
106 |
+
target_affective_cluster=self.target_affective_cluster,
|
107 |
+
genes_presence=temp_list[i_new_e].genes_presence.copy(),
|
108 |
+
genes_quantity=temp_list[i_new_e].genes_quantity.copy()) for i_new_e in indexes_sorted[:self.nb_elite]]
|
109 |
+
time_dict[' recreate elite individuals'] = [time.time() - init_time]
|
110 |
+
init_time = time.time()
|
111 |
+
# select parents
|
112 |
+
rank_perfs = np.flip(np.arange(len(temp_list)))
|
113 |
+
sampling_probs = rank_perfs / np.sum(rank_perfs)
|
114 |
+
if self.mutation_params['asexual_rep'] and not self.mutation_params['crossover']:
|
115 |
+
new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size)
|
116 |
+
self.pop = [temp_list[i].get_child() for i in new_pop_indexes]
|
117 |
+
elif self.mutation_params['crossover'] and not self.mutation_params['asexual_rep']:
|
118 |
+
self.pop = []
|
119 |
+
while len(self.pop) < self.pop_size:
|
120 |
+
parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False)
|
121 |
+
self.pop.append(temp_list[parents[0]].get_child_with(temp_list[parents[1]]))
|
122 |
+
elif self.mutation_params['crossover'] and self.mutation_params['asexual_rep']:
|
123 |
+
new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size//2)
|
124 |
+
time_dict[' choose asexual parent indexes'] = [time.time() - init_time]
|
125 |
+
init_time = time.time()
|
126 |
+
self.pop = []
|
127 |
+
for i in new_pop_indexes:
|
128 |
+
child, this_time_dict = temp_list[i].get_child()
|
129 |
+
self.pop.append(child)
|
130 |
+
time_dict = self.update_time_dict(time_dict, this_time_dict)
|
131 |
+
time_dict[' get asexual children'] = [time.time() - init_time]
|
132 |
+
init_time = time.time()
|
133 |
+
while len(self.pop) < self.pop_size:
|
134 |
+
parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False)
|
135 |
+
child, this_time_dict = temp_list[parents[0]].get_child_with(temp_list[parents[1]])
|
136 |
+
self.pop.append(child)
|
137 |
+
time_dict = self.update_time_dict(time_dict, this_time_dict)
|
138 |
+
time_dict[' get sexual children'] = [time.time() - init_time]
|
139 |
+
self.pop_elite = new_pop_elite
|
140 |
+
return time_dict
|
141 |
+
|
142 |
+
def get_pop_size(self):
|
143 |
+
return len(self.pop)
|
144 |
+
|
145 |
+
def get_q_rep(self, ingredients, quantities):
|
146 |
+
ingredient_q_rep = np.zeros([len(ingredient_list)])
|
147 |
+
genes_presence = np.zeros([len(ingredient_list)])
|
148 |
+
for ing, q in zip(ingredients, quantities):
|
149 |
+
ingredient_q_rep[ingredient_list.index(ing)] = q
|
150 |
+
genes_presence[ingredient_list.index(ing)] = 1
|
151 |
+
return genes_presence.copy(), normalize_ingredient_q_rep(ingredient_q_rep)
|
152 |
+
|
153 |
+
def get_best_score(self, affective_cluster_check=False):
|
154 |
+
elite_perfs = self.get_elite_perf()
|
155 |
+
pop_perfs = self.get_pop_perf()
|
156 |
+
all_perfs = np.concatenate([elite_perfs, pop_perfs])
|
157 |
+
temp_list = self.pop_elite + self.pop
|
158 |
+
if affective_cluster_check:
|
159 |
+
indexes = np.array([i for i in range(len(temp_list)) if temp_list[i].does_affective_cluster_match()])
|
160 |
+
if indexes.size > 0:
|
161 |
+
temp_list = np.array(temp_list)[indexes]
|
162 |
+
all_perfs = all_perfs[indexes]
|
163 |
+
indexes_best = np.flip(np.argsort(all_perfs))
|
164 |
+
return np.array(all_perfs)[indexes_best], np.array(temp_list)[indexes_best]
|
165 |
+
|
166 |
+
def update_time_dict(self, main_dict, new_dict):
|
167 |
+
for k in new_dict.keys():
|
168 |
+
if k in main_dict.keys():
|
169 |
+
main_dict[k].append(np.sum(new_dict[k]))
|
170 |
+
else:
|
171 |
+
main_dict[k] = [np.sum(new_dict[k])]
|
172 |
+
return main_dict
|
173 |
+
|
174 |
+
def run_one_generation(self, verbose=True, affective_cluster_check=False):
|
175 |
+
time_dict = dict()
|
176 |
+
init_time = time.time()
|
177 |
+
this_time_dict = self.update_elite_and_get_next_pop()
|
178 |
+
time_dict['update_elite_and_pop'] = [time.time() - init_time]
|
179 |
+
time_dict = self.update_time_dict(time_dict, this_time_dict)
|
180 |
+
init_time = time.time()
|
181 |
+
best_perfs, best_individuals = self.get_best_score(affective_cluster_check)
|
182 |
+
time_dict['get best scores'] = [time.time() - init_time]
|
183 |
+
return best_perfs[0], time_dict
|
184 |
+
|
185 |
+
def run_evolution(self, verbose=False, print_every=10, affective_cluster_check=False, level=0):
|
186 |
+
best_score = -np.inf
|
187 |
+
time_dict = dict()
|
188 |
+
init_time = time.time()
|
189 |
+
for i in range(self.nb_generations):
|
190 |
+
best_score, this_time_dict = self.run_one_generation(verbose, affective_cluster_check=affective_cluster_check)
|
191 |
+
time_dict = self.update_time_dict(time_dict, this_time_dict)
|
192 |
+
if verbose and (i+1) % print_every == 0:
|
193 |
+
print(' ' * level + f'Gen #{i+1} - Current best perf: {best_score:.2f}, time: {time.time() - init_time:.4f}')
|
194 |
+
init_time = time.time()
|
195 |
+
#
|
196 |
+
# to_print = time_dict.copy()
|
197 |
+
# keys = sorted(to_print.keys())
|
198 |
+
# values = []
|
199 |
+
# for k in keys:
|
200 |
+
# to_print[k] = np.sum(to_print[k])
|
201 |
+
# values.append(to_print[k])
|
202 |
+
# sorted_inds = np.flip(np.argsort(values))
|
203 |
+
# for i in sorted_inds:
|
204 |
+
# print(f'{keys[i]}: {values[i]:.4f}')
|
205 |
+
if verbose: print(' ' * level + f'Evolution over, best perf: {best_score:.2f}')
|
206 |
+
return self.get_best_score()
|
207 |
+
|
208 |
+
def print_results(self, n=3):
|
209 |
+
best_scores, best_ind = self.get_best_score()
|
210 |
+
for i in range(n):
|
211 |
+
best_ind[i].print_recipe(f'Candidate #{i+1}, Score: {best_scores[i]:.2f}')
|
212 |
+
|
213 |
+
|
src/cocktails/utilities/cocktail_utilities.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from src.cocktails.utilities.ingredients_utilities import ingredient2ingredient_id, ingredient_profiles, ingredients_per_type, ingredient_list, find_ingredient_from_str
|
3 |
+
from src.cocktails.utilities.cocktail_category_detection_utilities import *
|
4 |
+
import time
|
5 |
+
|
6 |
+
# representation_keys = ['pH', 'sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb',
|
7 |
+
# 'complex', 'spicy', 'strong', 'oaky', 'fizzy', 'colorful', 'eggy']
|
8 |
+
representation_keys = ['sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb',
|
9 |
+
'complex', 'spicy', 'oaky', 'fizzy', 'colorful', 'eggy']
|
10 |
+
representation_keys_linear = list(set(representation_keys) - set(['pH', 'complex']))
|
11 |
+
|
12 |
+
ing_reps = np.array([[ingredient_profiles[k][ing_id] for ing_id in ingredient2ingredient_id.values()] for k in representation_keys]).transpose()
|
13 |
+
|
14 |
+
|
15 |
+
def compute_cocktail_representation(profile, ingredients, quantities):
|
16 |
+
# computes representation of a cocktail from the recipe (ingredients, quantities) and volume
|
17 |
+
n = len(ingredients)
|
18 |
+
assert n == len(quantities)
|
19 |
+
quantities = np.array(quantities)
|
20 |
+
|
21 |
+
weights = quantities / np.sum(quantities)
|
22 |
+
rep = dict()
|
23 |
+
|
24 |
+
ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients])
|
25 |
+
# compute features as linear combination of ingredient features
|
26 |
+
for k in representation_keys_linear:
|
27 |
+
k_ing = np.array([ingredient_profiles[k][ing_id] for ing_id in ing_ids])
|
28 |
+
rep[k] = np.dot(weights, k_ing)
|
29 |
+
|
30 |
+
# for ph
|
31 |
+
# ph = - log10 x
|
32 |
+
phs = np.array([ingredient_profiles['pH'][ing_id] for ing_id in ing_ids])
|
33 |
+
concentrations = 10 ** (- phs)
|
34 |
+
mix_c = np.dot(weights, concentrations)
|
35 |
+
|
36 |
+
rep['pH'] = - np.log10(mix_c)
|
37 |
+
|
38 |
+
rep['complex'] = np.mean([ingredient_profiles['complex'][ing_id] for ing_id in ing_ids]) + len(ing_ids)
|
39 |
+
|
40 |
+
# compute profile after dilution
|
41 |
+
volume_ratio = profile['mix volume'] / profile['end volume']
|
42 |
+
for k in representation_keys:
|
43 |
+
rep['end ' + k] = rep[k] * volume_ratio
|
44 |
+
concentration = 10 ** (-rep['pH'])
|
45 |
+
end_concentration = concentration * volume_ratio
|
46 |
+
rep['end pH'] = - np.log10(end_concentration)
|
47 |
+
return rep
|
48 |
+
|
49 |
+
def get_alcohol_profile(ingredients, quantities):
|
50 |
+
ingredients = ingredients.copy()
|
51 |
+
quantities = quantities.copy()
|
52 |
+
assert len(ingredients) == len(quantities)
|
53 |
+
if 'mint' in ingredients:
|
54 |
+
mint_ind = ingredients.index('mint')
|
55 |
+
ingredients.pop(mint_ind)
|
56 |
+
quantities.pop(mint_ind)
|
57 |
+
alcohol = []
|
58 |
+
volume_mix = np.sum(quantities)
|
59 |
+
weights = quantities / volume_mix
|
60 |
+
assert np.abs(np.sum(weights) - 1) < 1e-4
|
61 |
+
ingredients_list = [ing.lower() for ing in ingredient_list]
|
62 |
+
for ing, q in zip(ingredients, quantities):
|
63 |
+
id = ingredients_list.index(ing)
|
64 |
+
alcohol.append(ingredient_profiles['ethanol'][id])
|
65 |
+
alcohol = np.dot(alcohol, weights)
|
66 |
+
return alcohol, volume_mix
|
67 |
+
|
68 |
+
def get_mix_profile(ingredients, quantities):
|
69 |
+
ingredients = ingredients.copy()
|
70 |
+
quantities = quantities.copy()
|
71 |
+
assert len(ingredients) == len(quantities)
|
72 |
+
if 'mint' in ingredients:
|
73 |
+
mint_ind = ingredients.index('mint')
|
74 |
+
ingredients.pop(mint_ind)
|
75 |
+
quantities.pop(mint_ind)
|
76 |
+
alcohol, sugar, acid = [], [], []
|
77 |
+
volume_mix = np.sum(quantities)
|
78 |
+
weights = quantities / volume_mix
|
79 |
+
assert np.abs(np.sum(weights) - 1) < 1e-4
|
80 |
+
ingredients_list = [ing.lower() for ing in ingredient_list]
|
81 |
+
for ing, q in zip(ingredients, quantities):
|
82 |
+
id = ingredients_list.index(ing)
|
83 |
+
sugar.append(ingredient_profiles['sugar'][id])
|
84 |
+
alcohol.append(ingredient_profiles['ethanol'][id])
|
85 |
+
acid.append(ingredient_profiles['acid'][id])
|
86 |
+
sugar = np.dot(sugar, weights)
|
87 |
+
acid = np.dot(acid, weights)
|
88 |
+
alcohol = np.dot(alcohol, weights)
|
89 |
+
return alcohol, sugar, acid
|
90 |
+
|
91 |
+
|
92 |
+
def extract_preparation_type(instructions, recipe):
|
93 |
+
flag = False
|
94 |
+
instructions = instructions.lower()
|
95 |
+
egg_in_recipe = any([find_ingredient_from_str(ing_str)[1]=='egg' for ing_str in recipe[1]])
|
96 |
+
if 'shake' in instructions:
|
97 |
+
if egg_in_recipe:
|
98 |
+
prep_type = 'egg_shaken'
|
99 |
+
else:
|
100 |
+
prep_type = 'shaken'
|
101 |
+
elif 'stir' in instructions:
|
102 |
+
prep_type = 'stirred'
|
103 |
+
elif 'blend' in instructions:
|
104 |
+
prep_type = 'blended'
|
105 |
+
elif any([w in instructions for w in ['build', 'mix', 'pour', 'combine', 'place']]):
|
106 |
+
prep_type = 'built'
|
107 |
+
else:
|
108 |
+
prep_type = 'built'
|
109 |
+
if egg_in_recipe and 'shaken' not in prep_type:
|
110 |
+
stop = 1
|
111 |
+
return flag, prep_type
|
112 |
+
|
113 |
+
def get_dilution_ratio(category, alcohol):
|
114 |
+
# formulas from the Liquid Intelligence book
|
115 |
+
# The formula for built was invented
|
116 |
+
if category == 'stirred':
|
117 |
+
return -1.21 * alcohol**2 + 1.246 * alcohol + 0.145
|
118 |
+
elif category in ['shaken', 'egg_shaken']:
|
119 |
+
return -1.567 * alcohol**2 + 1.742 * alcohol + 0.203
|
120 |
+
elif category == 'built':
|
121 |
+
return (-1.21 * alcohol**2 + 1.246 * alcohol + 0.145) /2
|
122 |
+
else:
|
123 |
+
return 1
|
124 |
+
|
125 |
+
def get_cocktail_rep(category, ingredients, quantities, keys):
|
126 |
+
ingredients = ingredients.copy()
|
127 |
+
quantities = quantities.copy()
|
128 |
+
assert len(ingredients) == len(quantities)
|
129 |
+
|
130 |
+
volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
|
131 |
+
|
132 |
+
# compute alcohol content without mint ingredient
|
133 |
+
ingredients2 = [ing for ing in ingredients if ing != 'mint']
|
134 |
+
quantities2 = [q for ing, q in zip(ingredients, quantities) if ing != 'mint']
|
135 |
+
weights2 = quantities2 / np.sum(quantities2)
|
136 |
+
assert np.abs(np.sum(weights2) - 1) < 1e-4
|
137 |
+
ing_ids2 = np.array([ingredient2ingredient_id[ing] for ing in ingredients2])
|
138 |
+
alcohol = np.array([ingredient_profiles['ethanol'][ing_id] for ing_id in ing_ids2])
|
139 |
+
alcohol = np.dot(alcohol, weights2)
|
140 |
+
dilution_ratio = get_dilution_ratio(category, alcohol)
|
141 |
+
end_volume = volume_mix + volume_mix * dilution_ratio
|
142 |
+
volume_ratio = volume_mix / end_volume
|
143 |
+
end_alcohol = alcohol * volume_ratio
|
144 |
+
|
145 |
+
# computes representation of a cocktail from the recipe (ingredients, quantities) and volume
|
146 |
+
weights = quantities / np.sum(quantities)
|
147 |
+
assert np.abs(np.sum(weights) - 1) < 1e-4
|
148 |
+
ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients])
|
149 |
+
reps = ing_reps[ing_ids]
|
150 |
+
cocktail_rep = np.dot(weights, reps)
|
151 |
+
i_complex = keys.index('end complex')
|
152 |
+
cocktail_rep[i_complex] = np.mean(reps[:, i_complex]) + len(ing_ids) # complexity increases with number of ingredients
|
153 |
+
|
154 |
+
# compute profile after dilution
|
155 |
+
cocktail_rep = cocktail_rep * volume_ratio
|
156 |
+
cocktail_rep = np.concatenate([[end_volume], cocktail_rep])
|
157 |
+
return cocktail_rep, end_volume, end_alcohol
|
158 |
+
|
159 |
+
def get_profile(category, ingredients, quantities):
|
160 |
+
|
161 |
+
volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
|
162 |
+
alcohol, sugar, acid = get_mix_profile(ingredients, quantities)
|
163 |
+
dilution_ratio = get_dilution_ratio(category, alcohol)
|
164 |
+
end_volume = volume_mix + volume_mix * dilution_ratio
|
165 |
+
volume_ratio = volume_mix / end_volume
|
166 |
+
profile = {'mix volume': volume_mix,
|
167 |
+
'mix alcohol': alcohol,
|
168 |
+
'mix sugar': sugar,
|
169 |
+
'mix acid': acid,
|
170 |
+
'dilution ratio': dilution_ratio,
|
171 |
+
'end volume': end_volume,
|
172 |
+
'end alcohol': alcohol * volume_ratio,
|
173 |
+
'end sugar': sugar * volume_ratio,
|
174 |
+
'end acid': acid * volume_ratio}
|
175 |
+
cocktail_rep = compute_cocktail_representation(profile, ingredients, quantities)
|
176 |
+
profile.update(cocktail_rep)
|
177 |
+
return profile
|
178 |
+
|
179 |
+
profile_keys = ['mix volume', 'end volume',
|
180 |
+
'dilution ratio',
|
181 |
+
'mix alcohol', 'end alcohol',
|
182 |
+
'mix sugar', 'end sugar',
|
183 |
+
'mix acid', 'end acid'] \
|
184 |
+
+ representation_keys \
|
185 |
+
+ ['end ' + k for k in representation_keys]
|
186 |
+
|
187 |
+
def update_profile_in_datapoint(datapoint, category, ingredients, quantities):
|
188 |
+
profile = get_profile(category, ingredients, quantities)
|
189 |
+
for k in profile_keys:
|
190 |
+
datapoint[k] = profile[k]
|
191 |
+
return datapoint
|
192 |
+
|
193 |
+
# define representation keys
|
194 |
+
def get_bunch_of_rep_keys():
|
195 |
+
dict_rep_keys = dict()
|
196 |
+
# all
|
197 |
+
rep_keys = profile_keys
|
198 |
+
dict_rep_keys['all'] = rep_keys
|
199 |
+
# only_end
|
200 |
+
rep_keys = [k for k in profile_keys if 'end' in k ]
|
201 |
+
dict_rep_keys['only_end'] = rep_keys
|
202 |
+
# except_end
|
203 |
+
rep_keys = [k for k in profile_keys if 'end' not in k ]
|
204 |
+
dict_rep_keys['except_end'] = rep_keys
|
205 |
+
# custom
|
206 |
+
to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong']
|
207 |
+
rep_keys = [k for k in profile_keys if 'end' in k ]
|
208 |
+
for k in to_remove:
|
209 |
+
if k in rep_keys:
|
210 |
+
rep_keys.remove(k)
|
211 |
+
dict_rep_keys['custom'] = rep_keys
|
212 |
+
# custom restricted
|
213 |
+
to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong', 'end spicy', 'end oaky']
|
214 |
+
rep_keys = [k for k in profile_keys if 'end' in k ]
|
215 |
+
for k in to_remove:
|
216 |
+
if k in rep_keys:
|
217 |
+
rep_keys.remove(k)
|
218 |
+
dict_rep_keys['restricted'] = rep_keys
|
219 |
+
dict_rep_keys['affective'] = ['end booze', 'end sweet', 'end sour', 'end fizzy', 'end complex', 'end bitter', 'end spicy', 'end colorful']
|
220 |
+
return dict_rep_keys
|
src/cocktails/utilities/glass_and_volume_utilities.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
glass_conversion = {'coupe':'coupe',
|
4 |
+
'martini': 'martini',
|
5 |
+
'collins': 'collins',
|
6 |
+
'oldfashion': 'oldfashion',
|
7 |
+
'Coupe glass': 'coupe',
|
8 |
+
'Old-fashioned glass': 'oldfashion',
|
9 |
+
'Martini glass': 'martini',
|
10 |
+
'Nick & Nora glass': 'coupe',
|
11 |
+
'Julep tin': 'oldfashion',
|
12 |
+
'Collins or Pineapple shell glass': 'collins',
|
13 |
+
'Collins glass': 'collins',
|
14 |
+
'Rocks glass': 'oldfashion',
|
15 |
+
'Highball (max 10oz/300ml)': 'collins',
|
16 |
+
'Wine glass': 'coupe',
|
17 |
+
'Flute glass': 'coupe',
|
18 |
+
'Double old-fashioned': 'oldfashion',
|
19 |
+
'Copa glass': 'coupe',
|
20 |
+
'Toddy glass': 'oldfashion',
|
21 |
+
'Sling glass': 'collins',
|
22 |
+
'Goblet glass': 'oldfashion',
|
23 |
+
'Fizz or Highball (8oz to 10oz)': 'collins',
|
24 |
+
'Copper mug or Collins glass': 'collins',
|
25 |
+
'Tiki mug or collins': 'collins',
|
26 |
+
'Snifter glass': 'oldfashion',
|
27 |
+
'Coconut shell or Collins glass': 'collins',
|
28 |
+
'Martini (large 10oz) glass': 'martini',
|
29 |
+
'Hurricane glass': 'collins',
|
30 |
+
'Absinthe glass or old-fashioned glass': 'oldfashion'
|
31 |
+
}
|
32 |
+
glass_volume = dict(coupe = 200,
|
33 |
+
collins=350,
|
34 |
+
martini=200,
|
35 |
+
oldfashion=320)
|
36 |
+
assert set(glass_conversion.values()) == set(glass_volume.keys())
|
37 |
+
|
38 |
+
volume_ranges = dict(stirred=(90, 97),
|
39 |
+
built=(70, 75),
|
40 |
+
shaken=(98, 112),
|
41 |
+
egg_shaken=(130, 143),
|
42 |
+
carbonated=(150, 150))
|
src/cocktails/utilities/ingredients_utilities.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script loads the list and profiles of our ingredients selection.
|
2 |
+
# It defines rules to recognize ingredients from the list in recipes and the function to extract that information from ingredient strings.
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from src.cocktails.config import INGREDIENTS_LIST_PATH, COCKTAILS_CSV_DATA
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
ingredient_profiles = pd.read_csv(INGREDIENTS_LIST_PATH)
|
9 |
+
ingredient_list = [ing.lower() for ing in ingredient_profiles['ingredient']]
|
10 |
+
n_ingredients = len(ingredient_list)
|
11 |
+
ingredient2ingredient_id = dict(zip(ingredient_list, range(n_ingredients)))
|
12 |
+
|
13 |
+
ingredients_types = sorted(set(ingredient_profiles['type']))
|
14 |
+
# for each type, get all ingredients
|
15 |
+
ing_per_type = [[ing for ing in ingredient_list if ingredient_profiles['type'][ingredient_list.index(ing)] == type] for type in ingredients_types]
|
16 |
+
ingredients_per_type = dict(zip(ingredients_types, ing_per_type))
|
17 |
+
|
18 |
+
bubble_ingredients = ['soda', 'ginger beer', 'tonic', 'sparkling wine']
|
19 |
+
# rules to recognize ingredients in recipes.
|
20 |
+
# in [] are separate rules with an OR relation: only one needs to be satisfied
|
21 |
+
# within [], rules apply with and AND relation: all rules need to be satisfied.
|
22 |
+
# ~ indicates that the following expression must NOT appear
|
23 |
+
# simple expression indicate that the expression MUST appear.
|
24 |
+
ingredient_search = {#'salt': ['salt'],
|
25 |
+
'lime juice': [['lime', '~soda', '~lemonade', '~cordial']],
|
26 |
+
'lemon juice': [['lemon', '~soda', '~lemonade']],
|
27 |
+
'angostura': [['angostura', '~orange'],
|
28 |
+
['bitter', '~campari', '~orange', '~red', '~italian', '~fernet']],
|
29 |
+
'orange bitters': [['orange', 'bitter', '~bittersweet']],
|
30 |
+
'orange juice': [['orange', '~bitter', '~jam', '~marmalade', '~liqueur', '~water'],
|
31 |
+
['orange', 'squeeze']],
|
32 |
+
'pineapple juice': [['pineapple']],
|
33 |
+
# 'apple juice': [['apple', 'juice', '~pine']],
|
34 |
+
'cranberry juice': [['cranberry', 'juice']],
|
35 |
+
'cointreau': ['cointreau', 'triple sec', 'grand marnier', 'curaçao', 'curacao'],
|
36 |
+
'luxardo maraschino': ['luxardo', 'maraschino', 'kirsch'],
|
37 |
+
'amaretto': ['amaretto'],
|
38 |
+
'benedictine': ['benedictine', 'bénédictine', 'bénedictine', 'benédictine'],
|
39 |
+
'campari': ['campari', ['italian', 'red', 'bitter'], 'aperol', 'bittersweet', 'aperitivo', 'orange-red'],
|
40 |
+
# 'campari': ['campari', ['italian', 'red', 'bitter']],
|
41 |
+
# 'crème de violette': [['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']],
|
42 |
+
# 'aperol': ['aperol', 'bittersweet', 'aperitivo', 'orange-red'],
|
43 |
+
'green chartreuse': ['chartreuse'],
|
44 |
+
'black raspberry liqueur': [['cassis', 'liqueur'],
|
45 |
+
['black raspberry', 'liqueur'],
|
46 |
+
['raspberry', 'liqueur'],
|
47 |
+
['strawberry', 'liqueur'],
|
48 |
+
['blackberry', 'liqueur'],
|
49 |
+
['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']],
|
50 |
+
# 'simple syrup': [],
|
51 |
+
# 'drambuie': ['drambuie'],
|
52 |
+
# 'fernet branca': ['fernet', 'branca'],
|
53 |
+
'gin': [['gin', '~sloe', '~ginger']],
|
54 |
+
'vodka': ['vodka'],
|
55 |
+
'cuban rum': [['rum', 'puerto rican'], ['light', 'rum'], ['white', 'rum'], ['rum', 'havana', '~7'], ['rum', 'bacardi']],
|
56 |
+
'cognac': [['cognac', '~grand marnier', '~cointreau', '~orange']],
|
57 |
+
# 'bourbon': [['bourbon', '~liqueur']],
|
58 |
+
# 'tequila': ['tequila', 'pisco'],
|
59 |
+
# 'tequila': ['tequila'],
|
60 |
+
'scotch': ['scotch'],
|
61 |
+
'dark rum': [['rum', 'age', '~bacardi', '~havana'],
|
62 |
+
['rum', 'dark', '~bacardi', '~havana'],
|
63 |
+
['rum', 'old', '~bacardi', '~havana'],
|
64 |
+
['rum', 'old', '7'],
|
65 |
+
['rum', 'havana', '7'],
|
66 |
+
['havana', 'rum', 'especial']],
|
67 |
+
'absinthe': ['absinthe'],
|
68 |
+
'rye whiskey': ['rye', ['bourbon', '~liqueur']],
|
69 |
+
# 'rye whiskey': ['rye'],
|
70 |
+
'apricot brandy': [['apricot', 'brandy']],
|
71 |
+
# 'pisco': ['pisco'],
|
72 |
+
# 'cachaça': ['cachaça', 'cachaca'],
|
73 |
+
'egg': [['egg', 'white', '~yolk', '~whole']],
|
74 |
+
'soda': [['soda', 'water', '~lemon', '~lime']],
|
75 |
+
'mint': ['mint'],
|
76 |
+
'sparkling wine': ['sparkling wine', 'prosecco', 'champagne'],
|
77 |
+
'ginger beer': [['ginger', 'beer'], ['ginger', 'ale']],
|
78 |
+
'tonic': [['tonic'], ['7up'], ['sprite']],
|
79 |
+
# 'espresso': ['espresso', 'expresso', ['café', '~liqueur', '~cream'],
|
80 |
+
# ['cafe', '~liqueur', '~cream'],
|
81 |
+
# ['coffee', '~liqueur', '~cream']],
|
82 |
+
# 'southern comfort': ['southern comfort'],
|
83 |
+
# 'cola': ['cola', 'coke', 'pepsi'],
|
84 |
+
'double syrup': [['sugar','~raspberry'], ['simple', 'syrup'], ['double', 'syrup']],
|
85 |
+
# 'grenadine': ['grenadine', ['pomegranate', 'syrup']],
|
86 |
+
'grenadine': ['grenadine', ['pomegranate', 'syrup'], ['raspberry', 'syrup', '~black']],
|
87 |
+
'honey syrup': ['honey', ['maple', 'syrup']],
|
88 |
+
# 'raspberry syrup': [['raspberry', 'syrup', '~black']],
|
89 |
+
'dry vermouth': [['vermouth', 'dry'], ['vermouth', 'white'], ['vermouth', 'french'], 'lillet'],
|
90 |
+
'sweet vermouth': [['vermouth', 'sweet'], ['vermouth', 'red'], ['vermouth', 'italian']],
|
91 |
+
# 'lillet blanc': ['lillet'],
|
92 |
+
'water': [['water', '~sugar', '~coconut', '~soda', '~tonic', '~honey', '~orange', '~melon']]
|
93 |
+
}
|
94 |
+
# check that there is a rule for all ingredients in the list
|
95 |
+
assert sorted(ingredient_list) == sorted(ingredient_search.keys()), 'ing search dict keys do not match ingredient list'
|
96 |
+
|
97 |
+
def get_ingredients_info():
|
98 |
+
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
99 |
+
max_ingredients, ingredient_set, liquor_set, liqueur_set, vermouth_set = get_max_n_ingredients(data)
|
100 |
+
ingredient_list = sorted(ingredient_set)
|
101 |
+
alcohol = sorted(liquor_set.union(liqueur_set).union(vermouth_set).union(set(['sparkling wine'])))
|
102 |
+
ind_alcohol = [i for i in range(len(ingredient_list)) if ingredient_list[i] in alcohol]
|
103 |
+
return max_ingredients, ingredient_list, ind_alcohol
|
104 |
+
|
105 |
+
def get_max_n_ingredients(data):
|
106 |
+
max_count = 0
|
107 |
+
ingredient_set = set()
|
108 |
+
alcohol_set = set()
|
109 |
+
liqueur_set = set()
|
110 |
+
vermouth_set = set()
|
111 |
+
ing_str = np.array(data['ingredients_str'])
|
112 |
+
for i in range(len(data['names'])):
|
113 |
+
ingredients, quantities = extract_ingredients(ing_str[i])
|
114 |
+
max_count = max(max_count, len(ingredients))
|
115 |
+
for ing in ingredients:
|
116 |
+
ingredient_set.add(ing)
|
117 |
+
if ing in ingredients_per_type['liquor']:
|
118 |
+
alcohol_set.add(ing)
|
119 |
+
if ing in ingredients_per_type['liqueur']:
|
120 |
+
liqueur_set.add(ing)
|
121 |
+
if ing in ingredients_per_type['vermouth']:
|
122 |
+
vermouth_set.add(ing)
|
123 |
+
return max_count, ingredient_set, alcohol_set, liqueur_set, vermouth_set
|
124 |
+
|
125 |
+
def find_ingredient_from_str(ing_str):
|
126 |
+
# function that assigns an ingredient string to one of the ingredient if possible, following the rules defined above.
|
127 |
+
# return a flag and the ingredient string. When flag is false, the ingredient has not been found and the cocktail is rejected.
|
128 |
+
ing_str = ing_str.lower()
|
129 |
+
flags = []
|
130 |
+
for k in ingredient_list:
|
131 |
+
or_flags = [] # get flag for each of several conditions
|
132 |
+
for i_p, pattern in enumerate(ingredient_search[k]):
|
133 |
+
or_flags.append(True)
|
134 |
+
if isinstance(pattern, str):
|
135 |
+
if pattern[0] == '~' and pattern[1:] in ing_str:
|
136 |
+
or_flags[-1] = False
|
137 |
+
elif pattern[0] != '~' and pattern not in ing_str:
|
138 |
+
or_flags[-1] = False
|
139 |
+
elif isinstance(pattern, list):
|
140 |
+
for element in pattern:
|
141 |
+
if element[0] == '~':
|
142 |
+
or_flags[-1] = or_flags[-1] and not element[1:] in ing_str
|
143 |
+
else:
|
144 |
+
or_flags[-1] = or_flags[-1] and element in ing_str
|
145 |
+
else:
|
146 |
+
raise ValueError
|
147 |
+
flags.append(any(or_flags))
|
148 |
+
if sum(flags) > 1:
|
149 |
+
print(ing_str)
|
150 |
+
for i_f, f in enumerate(flags):
|
151 |
+
if f:
|
152 |
+
print(ingredient_list[i_f])
|
153 |
+
stop = 1
|
154 |
+
return True, ingredient_list[flags.index(True)]
|
155 |
+
elif sum(flags) == 0:
|
156 |
+
# if 'grape' not in ing_str:
|
157 |
+
# print('\t\t Not found:', ing_str)
|
158 |
+
return True, None
|
159 |
+
else:
|
160 |
+
return False, ingredient_list[flags.index(True)]
|
161 |
+
|
162 |
+
def get_cocktails_per_ingredient(ing_strs):
|
163 |
+
cocktails_per_ing = dict(zip(ingredient_list, [[] for _ in range(len(ingredient_list))]))
|
164 |
+
for i_ing, ing_str in enumerate(ing_strs):
|
165 |
+
ingredients, _ = extract_ingredients(ing_str)
|
166 |
+
for ing in ingredients:
|
167 |
+
cocktails_per_ing[ing].append(i_ing)
|
168 |
+
return cocktails_per_ing
|
169 |
+
|
170 |
+
def extract_ingredients(ingredient_str):
|
171 |
+
# extract list of ingredients and quantities from an formatted ingredient string (reverse of format_ingredients)
|
172 |
+
ingredient_str = ingredient_str[1: -1]
|
173 |
+
words = ingredient_str.split(',')
|
174 |
+
ingredients = []
|
175 |
+
quantities = []
|
176 |
+
for i in range(len(words)//2):
|
177 |
+
ingredients.append(words[2 * i][1:])
|
178 |
+
quantities.append(float(words[2 * i + 1][:-1]))
|
179 |
+
return ingredients, quantities
|
180 |
+
|
181 |
+
def format_ingredients(ingredients, quantities):
|
182 |
+
# format an ingredient string from the lists of ingredients and quantities (reverse of extract_ingredients)
|
183 |
+
out = '['
|
184 |
+
for ing, q in zip(ingredients, quantities):
|
185 |
+
if ing[-1] == ' ':
|
186 |
+
ingre = ing[:-1]
|
187 |
+
else:
|
188 |
+
ingre = ing
|
189 |
+
out += f'({ingre},{q}),'
|
190 |
+
out = out[:-1] + ']'
|
191 |
+
return out
|
192 |
+
|
193 |
+
|
194 |
+
def get_ingredient_count(data):
|
195 |
+
# get count of ingredients in the whole dataset
|
196 |
+
ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list)))
|
197 |
+
for i in range(len(data['names'])):
|
198 |
+
if data['to_keep'][i]:
|
199 |
+
ingredients, _ = extract_ingredients(data['ingredients_str'][i])
|
200 |
+
for i in ingredients:
|
201 |
+
ingredient_counts[i] += 1
|
202 |
+
return ingredient_counts
|
203 |
+
|
204 |
+
def add_counts_to_ingredient_list(data):
|
205 |
+
# update the list of ingredients to add their count of occurence in dataset.
|
206 |
+
ingredient_counts = get_ingredient_count(data)
|
207 |
+
counts = [ingredient_counts[k] for k in ingredient_list]
|
208 |
+
ingredient_profiles['counts'] = counts
|
209 |
+
ingredient_profiles.to_csv(INGREDIENTS_LIST_PATH, index=False)
|
src/cocktails/utilities/other_scrubbing_utilities.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pickle
|
3 |
+
from src.cocktails.utilities.cocktail_utilities import get_profile, profile_keys
|
4 |
+
from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles
|
5 |
+
from src.cocktails.utilities.glass_and_volume_utilities import glass_volume, volume_ranges
|
6 |
+
|
7 |
+
one_dash = 1
|
8 |
+
one_splash = 6
|
9 |
+
one_tablespoon = 15
|
10 |
+
one_barspoon = 5
|
11 |
+
fill_rate = 0.8
|
12 |
+
quantity_factors ={'ml':1,
|
13 |
+
'cl':10,
|
14 |
+
'splash':one_splash,
|
15 |
+
'splashes':one_splash,
|
16 |
+
'dash':one_dash,
|
17 |
+
'dashes':one_dash,
|
18 |
+
'spoon':one_barspoon,
|
19 |
+
'spoons':one_barspoon,
|
20 |
+
'tablespoon':one_tablespoon,
|
21 |
+
'barspoons':one_barspoon,
|
22 |
+
'barspoon':one_barspoon,
|
23 |
+
'bar spoons': one_barspoon,
|
24 |
+
'bar spoon': one_barspoon,
|
25 |
+
'tablespoons':one_tablespoon,
|
26 |
+
'teaspoon':5,
|
27 |
+
'teaspoons':5,
|
28 |
+
'drop':0.05,
|
29 |
+
'drops':0.05}
|
30 |
+
quantitiy_keys = sorted(quantity_factors.keys())
|
31 |
+
indexes_keys = np.flip(np.argsort([len(k) for k in quantitiy_keys]))
|
32 |
+
quantity_factors_keys = list(np.array(quantitiy_keys)[indexes_keys])
|
33 |
+
|
34 |
+
keys_to_track = ['names', 'urls', 'glass', 'garnish', 'recipe', 'how_to', 'review', 'taste_rep', 'valid']
|
35 |
+
keys_to_add = ['category', 'subcategory', 'ingredients_str', 'ingredients', 'quantities', 'to_keep']
|
36 |
+
keys_to_update = ['glass']
|
37 |
+
keys_for_csv = ['names', 'category', 'subcategory', 'ingredients_str', 'urls', 'glass', 'garnish', 'how_to', 'review', 'taste_rep'] + profile_keys
|
38 |
+
|
39 |
+
to_replace_q = {' fresh': ''}
|
40 |
+
to_replace_ing = {'maple syrup': 'honey syrup',
|
41 |
+
'agave syrup': 'honey syrup',
|
42 |
+
'basil': 'mint'}
|
43 |
+
|
44 |
+
def print_recipe(unit='mL', ingredient_str=None, ingredients=None, quantities=None, name='', cat='', to_print=True):
|
45 |
+
str_out = ''
|
46 |
+
if ingredient_str is None:
|
47 |
+
assert len(ingredients) == len(quantities), 'provide either ingredient_str, or list ingredients and quantities'
|
48 |
+
else:
|
49 |
+
assert ingredients is None and quantities is None, 'provide either ingredient_str, or list ingredients and quantities'
|
50 |
+
ingredients, quantities = extract_ingredients(ingredient_str)
|
51 |
+
|
52 |
+
str_out += f'\nRecipe:'
|
53 |
+
if name != '' and name is not None: str_out += f' {name}'
|
54 |
+
if cat != '': str_out += f' ({cat})'
|
55 |
+
str_out += '\n'
|
56 |
+
for i in range(len(ingredients)):
|
57 |
+
# get quantifier
|
58 |
+
if ingredients[i] == 'egg':
|
59 |
+
quantities[i] = 1
|
60 |
+
ingredients[i] = 'egg white'
|
61 |
+
if unit == 'mL':
|
62 |
+
quantifier = ' (30 mL)'
|
63 |
+
elif unit == 'oz':
|
64 |
+
quantifier = ' (1 fl oz)'
|
65 |
+
else:
|
66 |
+
raise ValueError
|
67 |
+
elif ingredients[i] in ['angostura', 'orange bitters']:
|
68 |
+
quantities[i] = max(1, int(quantities[i] / 0.6))
|
69 |
+
quantifier = ' dash'
|
70 |
+
if quantities[i] > 1: quantifier += 'es'
|
71 |
+
elif ingredients[i] == 'mint':
|
72 |
+
if quantities[i] > 1: quantifier = ' leaves'
|
73 |
+
else: quantifier = ' leaf'
|
74 |
+
else:
|
75 |
+
if unit == "oz":
|
76 |
+
quantities[i] = float(f"{quantities[i] * 0.033814:.3f}") # convert to fl oz
|
77 |
+
quantifier = ' fl oz'
|
78 |
+
else:
|
79 |
+
quantifier = ' mL'
|
80 |
+
str_out += f' {quantities[i]}{quantifier} - {ingredients[i]}\n'
|
81 |
+
|
82 |
+
if to_print:
|
83 |
+
print(str_out)
|
84 |
+
return str_out
|
85 |
+
|
86 |
+
|
87 |
+
def test_datapoint(datapoint, category, ingredients, quantities):
|
88 |
+
# run checks
|
89 |
+
ingredient_indexes = [ingredient_list.index(ing) for ing in ingredients]
|
90 |
+
profile = get_profile(category, ingredients, quantities)
|
91 |
+
volume = profile['end volume']
|
92 |
+
alcohol = profile['end alcohol']
|
93 |
+
acid = profile['end acid']
|
94 |
+
sugar = profile['end sugar']
|
95 |
+
# check volume
|
96 |
+
if datapoint['glass'] != None:
|
97 |
+
if volume > glass_volume[datapoint['glass']] * fill_rate:
|
98 |
+
# recompute quantities for it to match
|
99 |
+
ratio = fill_rate * glass_volume[datapoint['glass']] / volume
|
100 |
+
for i_q in range(len(quantities)):
|
101 |
+
quantities[i_q] = float(f'{quantities[i_q] * ratio:.2f}')
|
102 |
+
# check alcohol
|
103 |
+
assert alcohol < 30, 'too boozy'
|
104 |
+
assert alcohol < 5, 'not boozy enough'
|
105 |
+
assert acid < 2, 'too much acid'
|
106 |
+
assert sugar < 20, 'too much sugar'
|
107 |
+
assert len(ingredients) > 1, 'only one ingredient'
|
108 |
+
if len(set(ingredients)) != len(ingredients):
|
109 |
+
i_doubles = []
|
110 |
+
s_ing = set()
|
111 |
+
for i, ing in enumerate(ingredients):
|
112 |
+
if ing in s_ing:
|
113 |
+
i_doubles.append(i)
|
114 |
+
else:
|
115 |
+
s_ing.add(ing)
|
116 |
+
ingredient_double_ok = ['mint', 'cointreau', 'lemon juice', 'cuban rum', 'double syrup']
|
117 |
+
if len(i_doubles) == 1 and ingredients[i_doubles[0]] in ingredient_double_ok:
|
118 |
+
ing_double = ingredients[i_doubles[0]]
|
119 |
+
double_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == ing_double])
|
120 |
+
ingredients.pop(i_doubles[0])
|
121 |
+
quantities.pop(i_doubles[0])
|
122 |
+
quantities[ingredients.index(ing_double)] = double_q
|
123 |
+
else:
|
124 |
+
assert False, f'double ingredient, not {ingredient_double_ok}'
|
125 |
+
lemon_lime_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['lime juice', 'lemon juice']])
|
126 |
+
assert lemon_lime_q <= 45, 'too much lemon and lime'
|
127 |
+
salt_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'salt'])
|
128 |
+
assert salt_q <= 8, 'too much salt'
|
129 |
+
bitter_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['angostura', 'orange bitters']])
|
130 |
+
assert bitter_q <= 5 * one_dash, 'too much bitter'
|
131 |
+
absinthe_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'absinthe'])
|
132 |
+
if absinthe_q > 4 * one_dash:
|
133 |
+
mix_volume = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
|
134 |
+
assert absinthe_q < 0.5 * mix_volume, 'filter absinthe glasses'
|
135 |
+
if any([w in datapoint['how_to'] or any([w in ing.lower() for ing in datapoint['recipe'][1]]) for w in ['warm', 'boil', 'hot']]) and 'shot' not in datapoint['how_to']:
|
136 |
+
assert False
|
137 |
+
water_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'water'])
|
138 |
+
assert water_q < 40
|
139 |
+
# n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
|
140 |
+
# assert n_liqueur <= 2
|
141 |
+
n_liqueur_and_vermouth = np.sum([ingredient_profiles['type'][i].lower() in ['liqueur', 'vermouth'] for i in ingredient_indexes])
|
142 |
+
assert n_liqueur_and_vermouth <= 3
|
143 |
+
return ingredients, quantities
|
144 |
+
|
145 |
+
def run_battery_checks_difford(datapoint, category, ingredients, quantities):
|
146 |
+
flag = False
|
147 |
+
try:
|
148 |
+
ingredients, quantities = test_datapoint(datapoint, category, ingredients, quantities)
|
149 |
+
except:
|
150 |
+
flag = True
|
151 |
+
print(datapoint["names"])
|
152 |
+
print(datapoint["urls"])
|
153 |
+
ingredients, quantities = None, None
|
154 |
+
|
155 |
+
return flag, ingredients, quantities
|
156 |
+
|
157 |
+
def tambouille(q, ingredients_scrubbed, quantities_scrubbed, cat):
|
158 |
+
# ugly
|
159 |
+
ing_scrubbed = ingredients_scrubbed[len(quantities_scrubbed)]
|
160 |
+
if q == '4 cube' and ing_scrubbed == 'pineapple juice':
|
161 |
+
q = '20 ml'
|
162 |
+
elif 'top up with' in q:
|
163 |
+
volume_so_far = np.sum([quantities_scrubbed[i] for i in range(len(quantities_scrubbed)) if ingredients_scrubbed[i] != 'mint'])
|
164 |
+
volume_mix = np.sum(volume_ranges[cat]) / 2
|
165 |
+
if (volume_mix - volume_so_far) < 15:
|
166 |
+
q = '15 ml'#
|
167 |
+
else:
|
168 |
+
q = str(int(volume_mix - volume_so_far)) + ' ml'
|
169 |
+
elif q == '1 pinch' and ing_scrubbed == 'salt':
|
170 |
+
q = '2 drops'
|
171 |
+
elif 'cube' in q and ing_scrubbed == 'double syrup':
|
172 |
+
q = f'{float(q.split(" ")[0]) * 2 * 1.7:.2f} ml' #2g per cube, 1.7 is ratio solid / syrup
|
173 |
+
elif 'wedge' in q:
|
174 |
+
if ing_scrubbed == 'orange juice':
|
175 |
+
vol = 70
|
176 |
+
elif ing_scrubbed == 'lime juice':
|
177 |
+
vol = 30
|
178 |
+
elif ing_scrubbed == 'lemon juice':
|
179 |
+
vol = 45
|
180 |
+
elif ing_scrubbed == 'pineapple juice':
|
181 |
+
vol = 140
|
182 |
+
factor = float(q.split(' ')[0]) * 0.15 # consider a wedge to be 0.15*the fruit.
|
183 |
+
q = f'{factor * vol:.2f} ml'
|
184 |
+
elif 'slice' in q:
|
185 |
+
if ing_scrubbed == 'orange juice':
|
186 |
+
vol = 70
|
187 |
+
elif ing_scrubbed == 'lime juice':
|
188 |
+
vol = 30
|
189 |
+
elif ing_scrubbed == 'lemon juice':
|
190 |
+
vol = 45
|
191 |
+
elif ing_scrubbed == 'pineapple juice':
|
192 |
+
vol = 140
|
193 |
+
f = q.split(' ')[0]
|
194 |
+
if len(f.split('⁄')) > 1:
|
195 |
+
frac = f.split('⁄')
|
196 |
+
factor = float(frac[0]) / float(frac[1])
|
197 |
+
else:
|
198 |
+
factor = float(f)
|
199 |
+
factor *= 0.1 # consider a slice to be 0.1*the fruit.
|
200 |
+
q = f'{factor * vol:.2f} ml'
|
201 |
+
elif q == '1 whole' and ing_scrubbed == 'luxardo maraschino':
|
202 |
+
q = '10 ml'
|
203 |
+
elif ing_scrubbed == 'egg' and 'ml' not in q:
|
204 |
+
q = f'{float(q) * 30:.2f} ml' # 30 ml per egg
|
205 |
+
return q
|
206 |
+
|
207 |
+
|
208 |
+
def compute_eucl_dist(a, b):
|
209 |
+
return np.sqrt(np.sum((a - b)**2))
|
210 |
+
|
211 |
+
def evaluate_with_quadruplets(representations, strategy='all'):
|
212 |
+
with open(QUADRUPLETS_PATH, 'rb') as f:
|
213 |
+
data = pickle.load(f)
|
214 |
+
data = list(data.values())
|
215 |
+
quadruplets = []
|
216 |
+
if strategy != 'all':
|
217 |
+
for d in data:
|
218 |
+
if d[0] == strategy:
|
219 |
+
quadruplets.append(d[1:])
|
220 |
+
elif strategy == 'all':
|
221 |
+
for d in data:
|
222 |
+
quadruplets.append(d[1:])
|
223 |
+
else:
|
224 |
+
raise ValueError
|
225 |
+
|
226 |
+
scores = []
|
227 |
+
for q in quadruplets:
|
228 |
+
close = q[0]
|
229 |
+
if len(close) == 2:
|
230 |
+
far = q[1]
|
231 |
+
distance_close = compute_eucl_dist(representations[close[0]], representations[close[1]])
|
232 |
+
distances_far = [compute_eucl_dist(representations[far[i][0]], representations[far[i][1]]) for i in range(len(far))]
|
233 |
+
scores.append(distance_close < np.min(distances_far))
|
234 |
+
if len(scores) == 0:
|
235 |
+
score = np.nan
|
236 |
+
else:
|
237 |
+
score = np.mean(scores)
|
238 |
+
return score
|
239 |
+
|
240 |
+
|
src/debugger.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
|
3 |
+
# from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob
|
4 |
+
from src.music.utils import load_audio
|
5 |
+
from src.music.config import FPS
|
6 |
+
import pretty_midi as pm
|
7 |
+
import numpy as np
|
8 |
+
from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH
|
9 |
+
from sklearn.neighbors import NearestNeighbors
|
10 |
+
from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
|
11 |
+
# from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers
|
12 |
+
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
|
13 |
+
from src.music.utils import get_all_subfiles_with_extension
|
14 |
+
import os
|
15 |
+
import pickle
|
16 |
+
import pandas as pd
|
17 |
+
import time
|
18 |
+
|
19 |
+
keyword = 'b256_r128_represented'
|
20 |
+
def load_reps(rep_path, sample_size=None):
|
21 |
+
if sample_size:
|
22 |
+
with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
|
23 |
+
data = pickle.load(f)
|
24 |
+
else:
|
25 |
+
with open(rep_path + f'music_reps_unnormalized.pickle', 'rb') as f:
|
26 |
+
data = pickle.load(f)
|
27 |
+
reps = data['reps']
|
28 |
+
# playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
|
29 |
+
playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
|
30 |
+
n_data, dim_data = reps.shape
|
31 |
+
return reps, data['paths'], playlists, n_data, dim_data
|
32 |
+
|
33 |
+
class Debugger():
|
34 |
+
def __init__(self, verbose=True):
|
35 |
+
|
36 |
+
if verbose: print('Setting up debugger.')
|
37 |
+
if not os.path.exists(MUSIC_NN_PATH):
|
38 |
+
reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle'
|
39 |
+
if not os.path.exists(reps_path):
|
40 |
+
all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt', current_depth=0)
|
41 |
+
all_data = []
|
42 |
+
new_all_rep_path = []
|
43 |
+
for i_r, r in enumerate(all_rep_path):
|
44 |
+
if 'mean_std' not in r:
|
45 |
+
all_data.append(np.loadtxt(r))
|
46 |
+
assert len(all_data[-1]) == 128
|
47 |
+
new_all_rep_path.append(r)
|
48 |
+
data = np.array(all_data)
|
49 |
+
to_save = dict(reps=data,
|
50 |
+
paths=new_all_rep_path)
|
51 |
+
with open(reps_path, 'wb') as f:
|
52 |
+
pickle.dump(to_save, f)
|
53 |
+
|
54 |
+
reps, self.rep_paths, playlists, n_data, self.dim_rep_music = load_reps(MUSIC_REP_PATH)
|
55 |
+
self.nn_model_music = NearestNeighbors(n_neighbors=6, metric='cosine')
|
56 |
+
self.nn_model_music.fit(reps)
|
57 |
+
to_save = dict(nn_model=self.nn_model_music,
|
58 |
+
rep_paths=self.rep_paths,
|
59 |
+
dim_rep_music=self.dim_rep_music)
|
60 |
+
with open(MUSIC_NN_PATH, 'wb') as f:
|
61 |
+
pickle.dump(to_save, f)
|
62 |
+
else:
|
63 |
+
with open(MUSIC_NN_PATH, 'rb') as f:
|
64 |
+
data = pickle.load(f)
|
65 |
+
self.nn_model_music = data['nn_model']
|
66 |
+
self.rep_paths = data['rep_paths']
|
67 |
+
self.dim_rep_music = data['dim_rep_music']
|
68 |
+
if verbose: print(f' {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}')
|
69 |
+
self.rep_paths = np.array(self.rep_paths)
|
70 |
+
if not os.path.exists(COCKTAIL_NN_PATH):
|
71 |
+
cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
|
72 |
+
# cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0)
|
73 |
+
self.nn_model_cocktail = NearestNeighbors(n_neighbors=6)
|
74 |
+
self.nn_model_cocktail.fit(cocktail_reps)
|
75 |
+
self.dim_rep_cocktail = cocktail_reps.shape[1]
|
76 |
+
self.n_cocktails = cocktail_reps.shape[0]
|
77 |
+
to_save = dict(nn_model=self.nn_model_cocktail,
|
78 |
+
dim_rep_cocktail=self.dim_rep_cocktail,
|
79 |
+
n_cocktails=self.n_cocktails)
|
80 |
+
with open(COCKTAIL_NN_PATH, 'wb') as f:
|
81 |
+
pickle.dump(to_save, f)
|
82 |
+
else:
|
83 |
+
with open(COCKTAIL_NN_PATH, 'rb') as f:
|
84 |
+
data = pickle.load(f)
|
85 |
+
self.nn_model_cocktail = data['nn_model']
|
86 |
+
self.dim_rep_cocktail = data['dim_rep_cocktail']
|
87 |
+
self.n_cocktails = data['n_cocktails']
|
88 |
+
if verbose: print(f' {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}')
|
89 |
+
|
90 |
+
self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
|
91 |
+
# self.affective_cluster_centers = get_affective_cluster_centers()
|
92 |
+
self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls',
|
93 |
+
'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len', 'piano_solo_prob', 'recipe_score', 'cocktail_rep']
|
94 |
+
# 'affect', 'affective_cluster_id', 'affective_cluster_center',
|
95 |
+
|
96 |
+
|
97 |
+
def get_nearest_songs(self, music_rep):
|
98 |
+
dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1))
|
99 |
+
indexes = indexes.flatten()[:5]
|
100 |
+
rep_paths = [r.split('/')[-1] for r in self.rep_paths[indexes[:5]]]
|
101 |
+
return rep_paths, dists.flatten().tolist()
|
102 |
+
|
103 |
+
def get_nearest_cocktails(self, cocktail_rep):
|
104 |
+
dists, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1))
|
105 |
+
indexes = indexes.flatten()
|
106 |
+
nn_names = np.array(self.cocktail_data['names'])[indexes].tolist()
|
107 |
+
nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist()
|
108 |
+
nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in np.array(self.cocktail_data['ingredients_str'])[indexes]]
|
109 |
+
nn_ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes].tolist()
|
110 |
+
return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs
|
111 |
+
|
112 |
+
def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score, verbose=False, level=0):
|
113 |
+
if verbose: print(' ' * level + 'Extracting debug info..')
|
114 |
+
init_time = time.time()
|
115 |
+
debug_dict = dict()
|
116 |
+
debug_dict['all_paths'] = all_paths
|
117 |
+
debug_dict['recipe_score'] = recipe_score
|
118 |
+
|
119 |
+
if all_paths['audio_path'] != None:
|
120 |
+
# is it piano?
|
121 |
+
debug_dict['piano_solo_prob'] = None#float(calculate_piano_solo_prob(all_paths['audio_path'])[0])
|
122 |
+
# how long is the audio
|
123 |
+
(audio, _) = load_audio(all_paths['audio_path'], sr=FPS, mono=True)
|
124 |
+
debug_dict['audio_len'] = int(len(audio) / FPS)
|
125 |
+
else:
|
126 |
+
debug_dict['piano_solo_prob'] = None
|
127 |
+
debug_dict['audio_len'] = None
|
128 |
+
|
129 |
+
# how many notes?
|
130 |
+
midi = pm.PrettyMIDI(all_paths['processed_path'])
|
131 |
+
debug_dict['nb_notes'] = len(midi.instruments[0].notes)
|
132 |
+
|
133 |
+
# dimension of music rep
|
134 |
+
representation = np.loadtxt(all_paths['representation_path'])
|
135 |
+
debug_dict['dim_rep'] = representation.shape[0]
|
136 |
+
|
137 |
+
# closest songs in dataset
|
138 |
+
debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation)
|
139 |
+
|
140 |
+
# get affective cluster info
|
141 |
+
# debug_dict['affective_cluster_id'] = affective_cluster_id[0]
|
142 |
+
# debug_dict['affective_cluster_center'] = self.affective_cluster_centers[affective_cluster_id].flatten().tolist()
|
143 |
+
# debug_dict['affect'] = affect.flatten().tolist()
|
144 |
+
indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep)
|
145 |
+
debug_dict['cocktail_rep'] = cocktail_rep.copy().tolist()
|
146 |
+
debug_dict['nearest_cocktail_indexes'] = indexes.tolist()
|
147 |
+
debug_dict['nn_ing_strs'] = nn_ing_strs
|
148 |
+
debug_dict['nearest_cocktail_names'] = nn_names
|
149 |
+
debug_dict['nearest_cocktail_urls'] = nn_urls
|
150 |
+
debug_dict['nearest_cocktail_recipes'] = nn_recipes
|
151 |
+
|
152 |
+
debug_dict['music_reconstruction'] = music_reconstruction.tolist()
|
153 |
+
debug_dict['mse_reconstruction'] = ((music_reconstruction - representation) ** 2).mean()
|
154 |
+
self.debug_dict = debug_dict
|
155 |
+
if verbose: print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.')
|
156 |
+
|
157 |
+
return self.debug_dict
|
158 |
+
|
159 |
+
def print_debug(self, level=0):
|
160 |
+
print(' ' * level + '__DEBUGGING INFO__')
|
161 |
+
for k in self.keys_to_print:
|
162 |
+
to_print = self.debug_dict[k]
|
163 |
+
if k == 'nearest_cocktail_recipes':
|
164 |
+
to_print = self.debug_dict[k].copy()
|
165 |
+
for i in range(len(to_print)):
|
166 |
+
to_print[i] = to_print[i].replace('\n', '').replace('\t', '').replace('()', '')
|
167 |
+
if k == "nn_music":
|
168 |
+
to_print = self.debug_dict[k].copy()
|
169 |
+
for i in range(len(to_print)):
|
170 |
+
to_print[i] = to_print[i].replace('encoded_new_structured_', '').replace('_represented.txt', '')
|
171 |
+
to_print_str = f'{to_print}'
|
172 |
+
if isinstance(to_print, float):
|
173 |
+
to_print_str = f'{to_print:.2f}'
|
174 |
+
elif isinstance(to_print, list):
|
175 |
+
if isinstance(to_print[0], float):
|
176 |
+
to_print_str = '['
|
177 |
+
for element in to_print:
|
178 |
+
to_print_str += f'{element:.2f}, '
|
179 |
+
to_print_str = to_print_str[:-2] + ']'
|
180 |
+
print(' ' * (level + 2) + f'{k} : ' + to_print_str)
|