Delete src
This view is limited to 50 files because it contains too many changes. See raw diff.
- src/__init__.py +0 -0
- src/cocktails/__init__.py +0 -0
- src/cocktails/config.py +0 -21
- src/cocktails/pipeline/__init__.py +0 -0
- src/cocktails/pipeline/cocktail2affect.py +0 -372
- src/cocktails/pipeline/cocktailrep2recipe.py +0 -329
- src/cocktails/pipeline/get_affect2affective_cluster.py +0 -23
- src/cocktails/pipeline/get_cocktail2affective_cluster.py +0 -9
- src/cocktails/representation_learning/__init__.py +0 -0
- src/cocktails/representation_learning/dataset.py +0 -324
- src/cocktails/representation_learning/multihead_model.py +0 -148
- src/cocktails/representation_learning/run.py +0 -557
- src/cocktails/representation_learning/run_simple_net.py +0 -302
- src/cocktails/representation_learning/run_without_vae.py +0 -514
- src/cocktails/representation_learning/simple_model.py +0 -54
- src/cocktails/representation_learning/vae_model.py +0 -238
- src/cocktails/utilities/__init__.py +0 -0
- src/cocktails/utilities/analysis_utilities.py +0 -189
- src/cocktails/utilities/cocktail_category_detection_utilities.py +0 -221
- src/cocktails/utilities/cocktail_generation_utilities/__init__.py +0 -0
- src/cocktails/utilities/cocktail_generation_utilities/individual.py +0 -587
- src/cocktails/utilities/cocktail_generation_utilities/population.py +0 -213
- src/cocktails/utilities/cocktail_utilities.py +0 -220
- src/cocktails/utilities/glass_and_volume_utilities.py +0 -42
- src/cocktails/utilities/ingredients_utilities.py +0 -209
- src/cocktails/utilities/other_scrubbing_utilities.py +0 -240
- src/debugger.py +0 -180
- src/music/__init__.py +0 -0
- src/music/config.py +0 -72
- src/music/pipeline/__init__.py +0 -0
- src/music/pipeline/audio2midi.py +0 -52
- src/music/pipeline/audio2piano_solo_prob.py +0 -47
- src/music/pipeline/encoded2rep.py +0 -88
- src/music/pipeline/midi2processed.py +0 -152
- src/music/pipeline/music_pipeline.py +0 -86
- src/music/pipeline/processed2encoded.py +0 -52
- src/music/pipeline/processed2handcodedrep.py +0 -343
- src/music/pipeline/synth2audio.py +0 -170
- src/music/pipeline/synth2midi.py +0 -146
- src/music/pipeline/url2audio.py +0 -119
- src/music/representation_analysis/__init__.py +0 -0
- src/music/representation_analysis/analyze_rep.py +0 -146
- src/music/representation_learning/__init__.py +0 -0
- src/music/representation_learning/mlm_pretrain/__init__.py +0 -0
- src/music/representation_learning/mlm_pretrain/data_collators.py +0 -180
- src/music/representation_learning/mlm_pretrain/models/music-bert/config.json +0 -20
- src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json +0 -1
- src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json +0 -20
- src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json +0 -1
- src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json +0 -56
src/__init__.py
DELETED
File without changes
src/cocktails/__init__.py
DELETED
File without changes
src/cocktails/config.py
DELETED
@@ -1,21 +0,0 @@
import os

REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/'

# QUADRUPLETS_PATH = REPO_PATH + 'checkpoints/cocktail_representation/quadruplets.pickle'
INGREDIENTS_LIST_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_list.csv'
# ING_MATCH_SCORE_Q_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_q.txt'
# ING_MATCH_SCORE_COUNT_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_count.txt'
# COCKTAIL_DATA_FOLDER_PATH = REPO_PATH + 'checkpoints/cocktail_representation/'
COCKTAILS_CSV_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.csv'
# COCKTAILS_PKL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.pkl'
# COCKTAILS_URL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_names_urls.pkl'
EXPERIMENT_PATH = REPO_PATH + 'experiments/cocktails/representation_learning/'
# ANALYSIS_PATH = REPO_PATH + 'experiments/cocktails/representation_analysis/'
# REPRESENTATIONS_PATH = REPO_PATH + 'experiments/cocktails/learned_representations/'

FULL_COCKTAIL_REP_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/cocktail_handcoded_reps_minmax_norm-1_1_dim13_customkeys.txt"
RECIPE2FEATURES_PATH = REPO_PATH + "/checkpoints/cocktail_representation/"  # get this by running run_without_vae
COCKTAIL_REP_CHKPT_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/"
# FULL_COCKTAIL_REP_PATH = REPO_PATH + "experiments/cocktails/representation_analysis/affective_mapping/clustered_representations/all_cocktail_reps_norm-1_1_custom_keys_dim13.txt'
COCKTAIL_NN_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/nn_model.pickle"

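Every other constant in this module derives from REPO_PATH, which resolves three directory levels above config.py. A minimal usage sketch (not part of the deleted file; it assumes the repository root is on PYTHONPATH and the checkpoint data is present):

import os
from src.cocktails.config import REPO_PATH, COCKTAILS_CSV_DATA

# REPO_PATH ends with a slash, so concatenated paths are well formed.
print(REPO_PATH)
print(os.path.exists(COCKTAILS_CSV_DATA))  # True once the checkpoint CSV is in place
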
src/cocktails/pipeline/__init__.py
DELETED
File without changes
src/cocktails/pipeline/cocktail2affect.py
DELETED
@@ -1,372 +0,0 @@
import pandas as pd
import numpy as np
import os
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
from src.cocktails.config import COCKTAILS_CSV_DATA
from src.music.config import CHECKPOINTS_PATH, EXPERIMENT_PATH
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
import pickle
import random

experiment_path = EXPERIMENT_PATH + '/cocktails/representation_analysis/affective_mapping/'
min_max_path = CHECKPOINTS_PATH + "/cocktail_representation/minmax/"
cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
affective_space_dimensions = ((-1, 1), (-1, 1), (-1, 1))  # valence, arousal, dominance
n_splits = (3, 3, 2)  # number of bins per dimension
# dimensions_weights = [1, 1, 0.5]
dimensions_weights = [1, 1, 1]
total_n_clusters = np.prod(n_splits)  # total number of bins
affective_boundaries = [np.arange(asd[0], asd[1]+1e-6, (asd[1] - asd[0]) / n_split) for asd, n_split in zip(affective_space_dimensions, n_splits)]
for af in affective_boundaries:
    af[-1] += 1e-6
all_keys = get_bunch_of_rep_keys()['custom']
original_affective_keys = get_bunch_of_rep_keys()['affective']
affective_keys = [a.split(' ')[1] for a in original_affective_keys]
random.seed(0)
cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(total_n_clusters)]

clustering_method = 'k_means'  # 'k_means', 'handcoded', 'agglo', 'spectral'
if clustering_method != 'handcoded':
    total_n_clusters = 10
min_arousal = np.loadtxt(min_max_path + 'min_arousal.txt')
max_arousal = np.loadtxt(min_max_path + 'max_arousal.txt')
min_val = np.loadtxt(min_max_path + 'min_valence.txt')
max_val = np.loadtxt(min_max_path + 'max_valence.txt')
min_dom = np.loadtxt(min_max_path + 'min_dominance.txt')
max_dom = np.loadtxt(min_max_path + 'max_dominance.txt')

def get_cocktail_reps(path, save=False):
    cocktail_data = pd.read_csv(path)
    cocktail_reps = np.array([cocktail_data[k] for k in original_affective_keys]).transpose()
    n_data, dim_rep = cocktail_reps.shape
    # print(f'{n_data} data points of {dim_rep} dimensions: {affective_keys}')
    cocktail_reps = normalize_cocktail_reps_affective(cocktail_reps, save=save)
    if save:
        np.savetxt(experiment_path + f'cocktail_reps_for_affective_mapping_-1_1_norm_sigmoid_rescaling_{dim_rep}_keys.txt', cocktail_reps)
    return cocktail_reps

def sigmoid(x, shift, beta):
    return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2

def normalize_cocktail_reps_affective(cocktail_reps, save=False):
    if save:
        min_cr = cocktail_reps.min(axis=0)
        max_cr = cocktail_reps.max(axis=0)
        np.savetxt(min_max_path + 'min_cocktail_reps_affective.txt', min_cr)
        np.savetxt(min_max_path + 'max_cocktail_reps_affective.txt', max_cr)
    else:
        min_cr = np.loadtxt(min_max_path + 'min_cocktail_reps_affective.txt')
        max_cr = np.loadtxt(min_max_path + 'max_cocktail_reps_affective.txt')
    cocktail_reps = ((cocktail_reps - min_cr) / (max_cr - min_cr) - 0.5) * 2
    cocktail_reps[:, 0] = sigmoid(cocktail_reps[:, 0], shift=0.05, beta=4)
    cocktail_reps[:, 1] = sigmoid(cocktail_reps[:, 1], shift=0.3, beta=5)
    cocktail_reps[:, 2] = sigmoid(cocktail_reps[:, 2], shift=0.15, beta=3)
    cocktail_reps[:, 3] = sigmoid(cocktail_reps[:, 3], shift=0.9, beta=20)
    cocktail_reps[:, 4] = sigmoid(cocktail_reps[:, 4], shift=0, beta=4)
    cocktail_reps[:, 5] = sigmoid(cocktail_reps[:, 5], shift=0.2, beta=3)
    cocktail_reps[:, 6] = sigmoid(cocktail_reps[:, 6], shift=0.5, beta=5)
    cocktail_reps[:, 7] = sigmoid(cocktail_reps[:, 7], shift=0.2, beta=6)
    return cocktail_reps

def plot(cocktail_reps):
    dim_rep = cocktail_reps.shape[1]
    for i in range(dim_rep):
        for j in range(i+1, dim_rep):
            plt.figure()
            plt.scatter(cocktail_reps[:, i], cocktail_reps[:, j], s=150, alpha=0.5)
            plt.xlabel(affective_keys[i])
            plt.ylabel(affective_keys[j])
            plt.savefig(experiment_path + f'scatters/{affective_keys[i]}_vs_{affective_keys[j]}.png', dpi=300)
            plt.close('all')
        plt.figure()
        plt.hist(cocktail_reps[:, i])
        plt.xlabel(affective_keys[i])
        plt.savefig(experiment_path + f'hists/{affective_keys[i]}.png', dpi=300)
        plt.close('all')

def get_clusters(affective_coordinates, save=False):
    if clustering_method in ['k_means', 'gmm',]:
        if clustering_method == 'k_means': model = KMeans(n_clusters=total_n_clusters)
        elif clustering_method == 'gmm': model = GaussianMixture(n_components=total_n_clusters, covariance_type="full")
        model.fit(affective_coordinates * np.array(dimensions_weights))

        def find_cluster(aff_coord):
            if aff_coord.ndim == 1:
                aff_coord = aff_coord.reshape(1, -1)
            return model.predict(aff_coord * np.array(dimensions_weights))
        cluster_centers = model.cluster_centers_ if clustering_method == 'k_means' else []
        if save:
            to_save = dict(cluster_model=model,
                           cluster_centers=cluster_centers,
                           nb_clusters=len(cluster_centers),
                           dimensions_weights=dimensions_weights)
            with open(cluster_model_path, 'wb') as f:
                pickle.dump(to_save, f)
            stop = 1

    elif clustering_method == 'handcoded':
        def find_cluster(aff_coord):
            if aff_coord.ndim == 1:
                aff_coord = aff_coord.reshape(1, -1)
            cluster_coordinates = []
            for i in range(aff_coord.shape[0]):
                cluster_coordinates.append([np.argwhere(affective_boundaries[j] <= aff_coord[i, j]).flatten()[-1] for j in range(3)])
            cluster_coordinates = np.array(cluster_coordinates)
            cluster_ids = cluster_coordinates[:, 0] * np.prod(n_splits[1:]) + cluster_coordinates[:, 1] * n_splits[-1] + cluster_coordinates[:, 2]
            return cluster_ids
        # find cluster centers
        cluster_centers = []
        for i in range(n_splits[0]):
            asd = affective_space_dimensions[0]
            x_coordinate = np.arange(asd[0] + 1 / n_splits[0], asd[1], (asd[1] - asd[0]) / n_splits[0])[i]
            for j in range(n_splits[1]):
                asd = affective_space_dimensions[1]
                y_coordinate = np.arange(asd[0] + 1 / n_splits[1], asd[1], (asd[1] - asd[0]) / n_splits[1])[j]
                for k in range(n_splits[2]):
                    asd = affective_space_dimensions[2]
                    z_coordinate = np.arange(asd[0] + 1 / n_splits[2], asd[1], (asd[1] - asd[0]) / n_splits[2])[k]
                    cluster_centers.append([x_coordinate, y_coordinate, z_coordinate])
        cluster_centers = np.array(cluster_centers)
    else:
        raise NotImplemented
    cluster_ids = find_cluster(affective_coordinates)
    return cluster_ids, cluster_centers, find_cluster


def cocktail2affect(cocktail_reps, save=False):
    if cocktail_reps.ndim == 1:
        cocktail_reps = cocktail_reps.reshape(1, -1)

    assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful']
    all_weights = []

    # valence
    # + sweet - bitter - booze + colorful
    weights = np.array([-1, 1, 0, 0, 0, -1, 0, 1])
    valence = (cocktail_reps * weights).sum(axis=1)
    if save:
        min_ = valence.min()
        max_ = valence.max()
        np.savetxt(min_max_path + 'min_valence.txt', np.array([min_]))
        np.savetxt(min_max_path + 'max_valence.txt', np.array([max_]))
    else:
        min_ = min_val
        max_ = max_val
    valence = 2 * ((valence - min_) / (max_ - min_) - 0.5)
    valence = sigmoid(valence, shift=0.1, beta=3.5)
    valence = valence.reshape(-1, 1)
    all_weights.append(weights.copy())

    # arousal
    # + fizzy + sour + complex - sweet + spicy + bitter
    # weights = np.array([0, -1, 1, 1, 1, 1, 1, 0])
    weights = np.array([0.7, 0, 1.5, 1.5, 0.6, 0, 0.6, 0])
    arousal = (cocktail_reps * weights).sum(axis=1)
    if save:
        min_ = arousal.min()
        max_ = arousal.max()
        np.savetxt(min_max_path + 'min_arousal.txt', np.array([min_]))
        np.savetxt(min_max_path + 'max_arousal.txt', np.array([max_]))
    else:
        min_, max_ = min_arousal, max_arousal
    arousal = 2 * ((arousal - min_) / (max_ - min_) - 0.5)  # normalize to -1, 1
    arousal = sigmoid(arousal, shift=0.3, beta=4)
    arousal = arousal.reshape(-1, 1)
    all_weights.append(weights.copy())

    # dominance
    # assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful']
    # + booze + fizzy - complex - bitter - sweet
    weights = np.array([1.5, -0.8, 0, 0.7, -1, -1.5, 0, 0])
    dominance = (cocktail_reps * weights).sum(axis=1)
    if save:
        min_ = dominance.min()
        max_ = dominance.max()
        np.savetxt(min_max_path + 'min_dominance.txt', np.array([min_]))
        np.savetxt(min_max_path + 'max_dominance.txt', np.array([max_]))
    else:
        min_, max_ = min_dom, max_dom
    dominance = 2 * ((dominance - min_) / (max_ - min_) - 0.5)
    dominance = sigmoid(dominance, shift=-0.05, beta=5)
    dominance = dominance.reshape(-1, 1)
    all_weights.append(weights.copy())

    affective_coordinates = np.concatenate([valence, arousal, dominance], axis=1)
    # if save:
    #     assert (affective_coordinates.min(axis=0) == np.array([ac[0] for ac in affective_space_dimensions])).all()
    #     assert (affective_coordinates.max(axis=0) == np.array([ac[1] for ac in affective_space_dimensions])).all()
    return affective_coordinates, all_weights

def save_reps(path, affective_cluster_ids):
    cocktail_data = pd.read_csv(path)
    rep_keys = get_bunch_of_rep_keys()['custom']
    cocktail_reps = np.array([cocktail_data[k] for k in rep_keys]).transpose()
    np.savetxt(experiment_path + 'clustered_representations/' + f'min_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.min(axis=0))
    np.savetxt(experiment_path + 'clustered_representations/' + f'max_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.max(axis=0))
    cocktail_reps = ((cocktail_reps - cocktail_reps.min(axis=0)) / (cocktail_reps.max(axis=0) - cocktail_reps.min(axis=0)) - 0.5) * 2  # normalize in -1, 1
    np.savetxt(experiment_path + 'clustered_representations/' + f'all_cocktail_reps_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps)
    np.savetxt(experiment_path + 'clustered_representations/' + 'affective_cluster_ids.txt', affective_cluster_ids)
    for cluster_id in sorted(set(affective_cluster_ids)):
        indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
        reps = cocktail_reps[indexes, :]
        np.savetxt(experiment_path + 'clustered_representations/' + f'rep_cluster{cluster_id}_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', reps)

def study_affects(affective_coordinates, affective_cluster_ids):
    plt.figure()
    plt.hist(affective_cluster_ids, bins=total_n_clusters)
    plt.xlabel('Affective cluster ids')
    plt.xticks(np.arange(total_n_clusters))
    plt.savefig(experiment_path + 'affective_cluster_distrib.png')
    fig = plt.gcf()
    plt.close(fig)

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_zlim([-1, 1])
    for cluster_id in sorted(set(affective_cluster_ids)):
        indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
        ax.scatter(affective_coordinates[indexes, 0], affective_coordinates[indexes, 1], affective_coordinates[indexes, 2], c=cluster_colors[cluster_id], s=150)
    ax.set_xlabel('Valence')
    ax.set_ylabel('Arousal')
    ax.set_zlabel('Dominance')
    stop = 1
    plt.savefig(experiment_path + 'scatters_affect/affective_mapping.png')
    fig = plt.gcf()
    plt.close(fig)

    affects = ['Valence', 'Arousal', 'Dominance']
    for i in range(3):
        for j in range(i + 1, 3):
            fig = plt.figure()
            ax = fig.add_subplot()
            for cluster_id in sorted(set(affective_cluster_ids)):
                indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
                ax.scatter(affective_coordinates[indexes, i], affective_coordinates[indexes, j], alpha=0.5, c=cluster_colors[cluster_id], s=150)
            ax.set_xlabel(affects[i])
            ax.set_ylabel(affects[j])
            plt.savefig(experiment_path + f'scatters_affect/scatter_{affects[i]}_vs_{affects[j]}.png')
            fig = plt.gcf()
            plt.close(fig)
        plt.figure()
        plt.hist(affective_coordinates[:, i])
        plt.xlabel(affects[i])
        plt.savefig(experiment_path + f'hists_affect/hist_{affects[i]}.png')
        fig = plt.gcf()
        plt.close(fig)
    plt.close('all')
    stop = 1

def sample_clusters(path, cocktail_reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates, n_samples=4):
    cocktail_data = pd.read_csv(path)
    these_cocktail_reps = normalize_cocktail_reps_affective(np.array([cocktail_data[k] for k in original_affective_keys]).transpose())

    names = cocktail_data['names']
    urls = cocktail_data['urls']
    ingr_str = cocktail_data['ingredients_str']
    for cluster_id in sorted(set(affective_cluster_ids)):
        indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
        print('\n\n\n---------\n----------\n-----------\n')
        cluster_str = ''
        cluster_str += f'Affective cluster #{cluster_id}' + \
                       f'\n\tSize: {len(indexes)}' + \
                       f'\n\tCenter: ' + \
                       f'\n\t\tVal: {affective_cluster_centers[cluster_id][0]:.2f}, ' + \
                       f'\n\t\tArousal: {affective_cluster_centers[cluster_id][1]:.2f}, ' + \
                       f'\n\t\tDominance: {affective_cluster_centers[cluster_id][2]:.2f}'
        print(cluster_str)
        if affective_cluster_centers[cluster_id][2] == np.max(affective_cluster_centers[:, 2]):
            stop = 1
        sampled_idx = np.random.choice(indexes, size=min(len(indexes), n_samples), replace=False)
        cocktail_str = ''
        for i in sampled_idx:
            assert np.sum(cocktail_reps[i] - these_cocktail_reps[i]) < 1e-9
            cocktail_str += f'\n\n-------------'
            cocktail_str += print_recipe(ingr_str[i], name=names[i], to_print=False)
            cocktail_str += f'\nUrl: {urls[i]}'
            cocktail_str += '\n\nRepresentation: ' + ', '.join([f'{af}: {cr:.2f}' for af, cr in zip(affective_keys, cocktail_reps[i])]) + '\n'
            cocktail_str += '\n' + generate_explanation(cocktail_reps[i], all_weights, affective_coordinates[i])
        print(cocktail_str)
        stop = 1
        cluster_str += '\n' + cocktail_str
        with open(f"/home/cedric/Documents/pianocktail/experiments/cocktails/representation_analysis/affective_mapping/clusters/cluster_{cluster_id}", 'w') as f:
            f.write(cluster_str)
    stop = 1

def explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord):
    names = ['valence', 'arousal', 'dominance']
    weights = all_weights[i]
    explanation_str = f'\n{names[i].capitalize()} explanation ({aff_coord[i]:.2f}):'
    strengths = np.abs(weights * cocktail_rep)
    strengths /= strengths.sum()
    indexes = np.flip(np.argsort(strengths))
    for ind in indexes:
        if strengths[ind] != 0:
            if np.sign(weights[ind]) == np.sign(cocktail_rep[ind]):
                keyword = 'high' if cocktail_rep[ind] > 0 else 'low'
                explanation_str += f'\n\t{int(strengths[ind]*100)}%: higher {names[i]} because {keyword} {affective_keys[ind]}'
            else:
                keyword = 'high' if cocktail_rep[ind] > 0 else 'low'
                explanation_str += f'\n\t{int(strengths[ind]*100)}%: low {names[i]} because {keyword} {affective_keys[ind]}'
    return explanation_str

def generate_explanation(cocktail_rep, all_weights, aff_coord):
    explanation_str = ''
    for i in range(3):
        explanation_str += explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord)
    return explanation_str

def cocktails2affect_clusters(cocktail_rep):
    if cocktail_rep.ndim == 1:
        cocktail_rep = cocktail_rep.reshape(1, -1)
    affective_coordinates, _ = cocktail2affect(cocktail_rep)
    affective_cluster_ids, _, _ = get_clusters(affective_coordinates)
    return affective_cluster_ids


def setup_affective_space(path, save=False):
    cocktail_data = pd.read_csv(path)
    names = cocktail_data['names']
    recipes = cocktail_data['ingredients_str']
    urls = cocktail_data['urls']
    reps = get_cocktail_reps(path)
    affective_coordinates, all_weights = cocktail2affect(reps)
    affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates, save=save)
    nn_model = NearestNeighbors(n_neighbors=1)
    nn_model.fit(affective_coordinates)
    def cocktail2affect_cluster(cocktail_rep):
        affective_coordinates, _ = cocktail2affect(cocktail_rep)
        return find_cluster(affective_coordinates)

    affective_clusters = dict(affective_coordinates=affective_coordinates,  # coordinates of cocktail in affective space
                              affective_cluster_ids=affective_cluster_ids,  # cluster id of cocktails
                              affective_cluster_centers=affective_cluster_centers,  # cluster centers in affective space
                              affective_weights=all_weights,  # weights to compute valence, arousal, dominance from cocktail representations
                              original_affective_keys=original_affective_keys,
                              cocktail_reps=reps,  # cocktail representations from the dataset (normalized)
                              find_cluster=find_cluster,  # function to retrieve a cluster from affective coordinates
                              nn_model=nn_model,  # to predict the nearest neighbor affective space,
                              names=names,  # names of cocktails in the dataset
                              urls=urls,  # urls from the dataset
                              recipes=recipes,  # recipes of the dataset
                              cocktail2affect=cocktail2affect,  # function to compute affects from cocktail representations
                              cocktails2affect_clusters=cocktails2affect_clusters,
                              cocktail2affect_cluster=cocktail2affect_cluster
                              )

    return affective_clusters

if __name__ == '__main__':
    reps = get_cocktail_reps(COCKTAILS_CSV_DATA, save=True)
    # plot(reps)
    affective_coordinates, all_weights = cocktail2affect(reps, save=True)
    affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates)
    save_reps(COCKTAILS_CSV_DATA, affective_cluster_ids)
    study_affects(affective_coordinates, affective_cluster_ids)
    sample_clusters(COCKTAILS_CSV_DATA, reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates)
    setup_affective_space(COCKTAILS_CSV_DATA, save=True)

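To make the mapping concrete, here is a small worked sketch (not from the deleted file) of the valence computation above: a weighted sum over the eight affective keys, min-max rescaling to [-1, 1] with dataset statistics, then the sigmoid squash. The min/max values below are hypothetical; the comment at the end illustrates the mixed-radix flattening used by the handcoded clustering over the (3, 3, 2) bin grid.

import numpy as np

def sigmoid(x, shift, beta):
    return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2

# One cocktail; dims = [booze, sweet, sour, fizzy, complex, bitter, spicy, colorful]
rep = np.array([0.2, 0.8, -0.1, 0.0, 0.3, -0.6, 0.0, 0.5])
weights = np.array([-1, 1, 0, 0, 0, -1, 0, 1])   # valence: + sweet - bitter - booze + colorful
raw = (rep * weights).sum()                      # 0.8 + 0.6 - 0.2 + 0.5 = 1.7
min_, max_ = -3.0, 3.0                           # hypothetical dataset min/max of raw valence
valence = sigmoid(2 * ((raw - min_) / (max_ - min_) - 0.5), shift=0.1, beta=3.5)

# Handcoded clustering flattens bin coordinates (x, y, z) on the (3, 3, 2) grid:
# cluster_id = x * (3 * 2) + y * 2 + z, e.g. bins (1, 2, 0) -> 6 + 4 + 0 = 10.
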
src/cocktails/pipeline/cocktailrep2recipe.py
DELETED
@@ -1,329 +0,0 @@
import matplotlib.pyplot as plt
import pickle
from src.cocktails.utilities.cocktail_generation_utilities.population import *
from src.cocktails.utilities.glass_and_volume_utilities import glass_volume
from src.cocktails.config import RECIPE2FEATURES_PATH

def test_mutation_params(cocktail_reps):
    indexes = np.arange(cocktail_reps.shape[0])
    np.random.shuffle(indexes)
    perfs = []
    mutated_perfs = []
    pop_params = dict(mutation_params=dict(p_add_ing=0.7,
                                           p_remove_ing=0.7,
                                           p_switch_ing=0.5,
                                           p_change_q=0.7,
                                           delta_change_q=0.3,
                                           asexual_rep=True,
                                           crossover=True,
                                           ingredient_addition=(0.1, 0.05)),
                      nb_generations=100,
                      pop_size=100,
                      nb_elites=10,
                      dist='mse',
                      n_neighbors=5)

    for i in indexes[:20]:
        target = cocktail_reps[i]
        for j in range(100):
            parent = IndividualCocktail(pop_params=pop_params,
                                        target_affective_cluster=None,
                                        target=target.copy())
            perfs.append(parent.perf)
            child = parent.get_child()[0]
            # child.compute_cocktail_rep()
            # child.compute_perf()
            if perfs[-1] != child.perf:
                mutated_perfs.append(child.perf)
            else:
                perfs.pop(-1)
    filtered_children = np.argwhere(np.array(mutated_perfs)==-100).flatten()
    non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs)!=-100, np.array(mutated_perfs)!=-100)).flatten()
    print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%')
    plt.figure()
    plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5)
    plt.xlabel('parent perf')
    plt.ylabel('child perf')
    print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1])
    plt.show()
    stop = 1

def test_crossover(cocktail_reps):
    indexes = np.arange(cocktail_reps.shape[0])
    np.random.shuffle(indexes)
    perfs = []
    mutated_perfs = []
    pop_params = dict(mutation_params=dict(p_add_ing=0.7,
                                           p_remove_ing=0.7,
                                           p_switch_ing=0.5,
                                           p_change_q=0.7,
                                           delta_change_q=0.3,
                                           asexual_rep=True,
                                           crossover=True,
                                           ingredient_addition=(0.1, 0.05)),
                      nb_generations=100,
                      pop_size=100,
                      nb_elites=10,
                      dist='mse',
                      n_neighbors=5)
    for i in indexes[:20]:
        for j in range(100):
            target = cocktail_reps[i]
            parent1 = IndividualCocktail(pop_params=pop_params,
                                         target_affective_cluster=None,
                                         target=target.copy())
            parent2 = IndividualCocktail(pop_params=pop_params,
                                         target_affective_cluster=None,
                                         target=target.copy())
            child = parent1.get_child_with(parent2)[0]
            # child.compute_cocktail_rep()
            # child.compute_perf()
            perfs.append((parent1.perf + parent2.perf)/2)
            if perfs[-1] != child.perf:
                mutated_perfs.append(child.perf)
            else:
                perfs.pop(-1)
    filtered_children = np.argwhere(np.array(mutated_perfs)==-100).flatten()
    non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs)>-45, np.array(mutated_perfs)!=-100)).flatten()
    print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%')
    plt.figure()
    plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5)
    plt.xlabel('parent perf')
    plt.ylabel('child perf')
    print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1])
    plt.show()
    stop = 1

def run_comparisons():
    np.random.seed(0)
    indexes = np.arange(cocktail_reps.shape[0])
    np.random.shuffle(indexes)
    for n_neighbors in [0, 5]:
        id_str_neigh = '5neigh_' if n_neighbors == 5 else '0_neigh_'
        for asexual_rep in [True, False]:
            id_str_as = id_str_neigh + 'asexual_' if asexual_rep else id_str_neigh
            for crossover in [True, False]:
                id_str = id_str_as + 'crossover_' if crossover else id_str_as
                if crossover or asexual_rep:
                    mutation_params = dict(p_add_ing = 0.5,
                                           p_remove_ing = 0.5,
                                           p_change_q = 0.5,
                                           delta_change_q = 0.3,
                                           asexual_rep=asexual_rep,
                                           crossover=crossover,
                                           ingredient_addition = (0.1, 0.05))
                    nb_generations = 100
                    pop_size=100
                    nb_elites=10
                    dist = 'mse'
                    results = dict()
                    print(id_str)
                    for i, ind in enumerate(indexes[:30]):
                        print(i+1)
                        target_ing_str = data['ingredients_str'][ind]
                        target = cocktail_reps[ind]
                        population = Population(nb_generations=nb_generations, pop_size=pop_size, nb_elite=nb_elites,
                                                target=target, dist=dist, mutation_params=mutation_params,
                                                n_neighbors=n_neighbors, target_ing_str=target_ing_str, true_prep_type=data['category'][ind])
                        population.run_evolution(verbose=False)
                        best_scores, best_ind = population.get_best_score()
                        recipes = [ind.get_recipe()[3] for ind in best_ind[:5]]
                        results[str(ind)] = dict(best_scores=best_scores[:5], recipes=recipes, target=population.target_individual.get_recipe()[3])
                    with open(f'/home/cedric/Desktop/ga_tests_{id_str}.pickle', 'wb') as f:
                        pickle.dump(results, f)

def get_cocktail_distribution(cocktail_reps):
    return (np.mean(cocktail_reps, axis=0), np.cov(cocktail_reps, rowvar=0))

def sample_cocktails(cocktail_reps, n=10, target_affective_cluster=None, to_print=True):
    distrib = get_cocktail_distribution(cocktail_reps)
    sampled_cocktail_reps = np.random.multivariate_normal(distrib[0], distrib[1], size=n)
    recipes = []
    closest_recipes = []
    for i_c, cr in enumerate(sampled_cocktail_reps):
        population = setup_recipe_generation(cr.copy(), target_affective_cluster=target_affective_cluster)
        closest_recipes.append(population.nn_recipes[0])
        best_scores, best_individuals = population.run_evolution()
        recipes.append(best_individuals[0].get_recipe()[3])
        if to_print:
            print(f'Sample #{len(recipes)}:')
            print(recipes[-1])
            print('Closest from dataset:')
            print(closest_recipes[-1])
            stop = 1
    return recipes, closest_recipes

def setup_recipe_generation(target, known_target_dict=None, target_affective_cluster=None):
    # pop_params = dict(mutation_params=dict(p_add_ing=0.7,
    #                                        p_remove_ing=0.7,
    #                                        p_switch_ing=0.5,
    #                                        p_change_q=0.7,
    #                                        delta_change_q=0.3,
    #                                        asexual_rep=True,
    #                                        crossover=True,
    #                                        ingredient_addition=(0.1, 0.05)),
    #                   nb_generations=2, #100
    #                   pop_size=5, #100
    #                   nb_elites=2, #10
    #                   dist='mse',
    #                   n_neighbors=3) #5
    pop_params = dict(mutation_params=dict(p_add_ing=0.4,
                                           p_remove_ing=1,
                                           p_switch_ing=0.5,
                                           p_change_q=1,
                                           delta_change_q=0.3,
                                           asexual_rep=True,
                                           crossover=True,
                                           ingredient_addition=(0.1, 0.05)),
                      nb_generations=100,  # 100
                      pop_size=100,  # 100
                      nb_elites=10,  # 10
                      dist='mse',
                      n_neighbors=5)  # 5

    population = Population(target=target, target_affective_cluster=target_affective_cluster, known_target_dict=known_target_dict, pop_params=pop_params)
    return population

def cocktailrep2recipe(cocktail_rep, unit='mL', target_affective_cluster=None, known_target_dict=None, n_output=1, return_ind=False, verbose=True, full_verbose=False, level=0):
    init_time = time.time()
    if verbose: print(' ' * level + 'Generating cocktail..')
    if cocktail_rep.ndim > 1:
        assert cocktail_rep.shape[0] == 1
        cocktail_rep = cocktail_rep.flatten()
        # target_affective_cluster = target_affective_cluster[0]
    population = setup_recipe_generation(cocktail_rep.copy(), known_target_dict=known_target_dict, target_affective_cluster=target_affective_cluster)
    if full_verbose:
        print(' ' * (level + 2) + '3 nearest neighbors:')
        for i, recipe, score in zip(range(3), population.nn_recipes[:3], population.nn_scores[:3]):
            print(' ' * (level + 4) + f'#{i+1}, score: {score:.2f}')
            print(' ' * (level + 4) + recipe[1:].replace('None ()', '').replace('\t\t', ' ' * (level + 6)))
    best_scores, best_individuals = population.run_evolution(verbose=full_verbose, level=level+2)
    for i in range(n_output):
        best_individuals[i].make_recipe_fit_the_glass()
    instructions = [ind.get_instructions() for ind in best_individuals[:n_output]]
    recipes = [ind.get_recipe(unit=unit)[3] for ind in best_individuals[:n_output]]
    glasses = [ind.glass for ind in best_individuals[:n_output]]
    prep_types = [ind.prep_type for ind in best_individuals[:n_output]]
    for i, g, p, inst in zip(range(len(recipes)), glasses, prep_types, instructions):
        recipes[i] = recipes[i].replace('Recipe', 'Ingredients') + f'Serve in:\n  {g.capitalize()} glass.\n' + inst
    if full_verbose:
        print(f'\n--------------\n{n_output} best results:')
        for i, recipe, score in zip(range(n_output), recipes, best_scores[:n_output]):
            print(f'#{i+1}, score: {score:.2f}')
            print(recipe)
    if verbose: print(' ' * (level + 2) + f'Generated in {int(time.time() - init_time)} seconds.')
    if return_ind:
        return recipes, best_scores[:n_output], best_individuals[:n_output]
    else:
        return recipes, best_scores[:n_output]


def interpolate(cocktail_rep1, cocktail_rep2, alpha, verbose=False):
    recipe, score = cocktailrep2recipe(alpha * cocktail_rep1 + (1 - alpha) * cocktail_rep2, verbose=verbose)
    return recipe[0], score

def interpolation_study(n_steps, cocktail_reps):
    alphas = np.arange(0, 1 + 1e-6, 1/(n_steps + 1))
    indexes = np.random.choice(np.arange(cocktail_reps.shape[0]), size=2, replace=False)
    target_ing_str1, target_ing_str2 = data['ingredients_str'][indexes[0]], data['ingredients_str'][indexes[1]]
    cocktail_rep1, cocktail_rep2 = cocktail_reps[indexes[0]], cocktail_reps[indexes[1]]
    recipes, scores = [], []
    for alpha in alphas:
        recipe, score = interpolate(cocktail_rep1, cocktail_rep2, alpha)
        recipes.append(recipe)
        scores.append(score[0])
    print('Point A:')
    print_recipe(ingredient_str=target_ing_str2)
    for i, alpha in enumerate(alphas):
        print(f'Alpha = {alpha}, score = {scores[i]}')
        print(recipes[i])
    print('Point B:')
    print_recipe(ingredient_str=target_ing_str1)
    stop = 1

def test_robustness_affective_cluster(cocktail_reps):
    indexes = np.arange(cocktail_reps.shape[0])
    np.random.shuffle(indexes)
    matches = []
    for i in indexes:
        target_ing_str = data['ingredients_str'][i]
        true_prep_type = data['category'][i]
        target = cocktail_reps[i]
        # get affective cluster
        recipes, best_scores, best_inds = cocktailrep2recipe(cocktail_rep=target, target_ing_str=target_ing_str, true_prep_type=true_prep_type, n_output=1, verbose=False,
                                                             return_ind=True)

        matches.append(best_inds[0].does_affective_cluster_match())
        print(np.mean(matches))

def test(cocktail_reps):
    indexes = np.arange(these_cocktail_reps.shape[0])
    unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose()

    for i in indexes:
        target_ing_str = data['ingredients_str'][i]
        true_prep_type = data['category'][i]
        target = these_cocktail_reps[i]
        # print('preptype:', true_prep_type)
        # print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i])
        # print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i]))
        # print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i])
        # print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i])

        population = setup_recipe_generation(target.copy(), target_ing_str=target_ing_str, target_affective_cluster=None, true_prep_type=true_prep_type)
        target = population.target_individual
        target.compute_perf()
        if target.perf < -50:
            print(i)
            print_recipe(target_ing_str)
            if not target.is_alcohol_present(): print('No alcohol')
            if not target.is_total_volume_enough(): print('small volume')
            if not target.does_fit_glass():
                print(target.end_volume)
                print(glass_volume[target.get_glass_type()] * 0.81)
                print('too much volume')
            if not target.is_alcohol_reasonable():
                print(f'amount of alcohol too small or too large: {target.alcohol_precentage}')
            stop = 1


if __name__ == '__main__':
    these_cocktail_reps = COCKTAIL_REPS.copy()
    # test_crossover(these_cocktail_reps)
    # test_mutation_params(these_cocktail_reps)
    # test(these_cocktail_reps)
    # recipes, closest_recipes = sample_cocktails(these_cocktail_reps, n=10)
    # interpolation_study(n_steps=4, cocktail_reps=these_cocktail_reps)
    # test_robustness_affective_cluster(these_cocktail_reps)
    indexes = np.arange(these_cocktail_reps.shape[0])
    np.random.shuffle(indexes)
    # test_crossover(mutation_params, dist)
    # test_mutation_params(mutation_params, dist)
    stop = 1
    unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose()
    for i in indexes:
        print(i)
        target_ing_str = data['ingredients_str'][i]
        target_prep_type = data['category'][i]
        target_glass = data['glass'][i]

        print('preptype:', target_prep_type)
        print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i])
        print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i]))
        print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i])
        print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i])
        print(i)

        print('___________Target')
        nn_model = NearestNeighbors()
        nn_model.fit(these_cocktail_reps)
        dists, indexes = nn_model.kneighbors(these_cocktail_reps[i].reshape(1, -1))
        print(indexes)
        print_recipe(target_ing_str)
        target = these_cocktail_reps[i]
        known_target_dict = dict(prep_type=target_prep_type,
                                 ing_str=target_ing_str,
                                 glass=target_glass)
        recipes, best_scores = cocktailrep2recipe(cocktail_rep=target, known_target_dict=known_target_dict, n_output=1, verbose=True, full_verbose=True)

        stop = 1

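The interpolation study boils down to decoding convex combinations of two cocktail representations. A minimal usage sketch (not from the deleted file), assuming COCKTAIL_REPS has been loaded by the population utilities, as in the __main__ block above:

# Decode the midpoint between two dataset cocktails.
rep_a, rep_b = COCKTAIL_REPS[0].copy(), COCKTAIL_REPS[1].copy()
recipes, best_scores = cocktailrep2recipe(0.5 * rep_a + 0.5 * rep_b, n_output=1, verbose=False)
print(best_scores[0])
print(recipes[0])
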
src/cocktails/pipeline/get_affect2affective_cluster.py
DELETED
@@ -1,23 +0,0 @@
from src.music.config import CHECKPOINTS_PATH
import pickle
import numpy as np

# can be computed from cocktail2affect
cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"

def get_affect2affective_cluster():
    with open(cluster_model_path, 'rb') as f:
        data = pickle.load(f)
    model = data['cluster_model']
    dimensions_weights = data['dimensions_weights']
    def find_cluster(aff_coord):
        if aff_coord.ndim == 1:
            aff_coord = aff_coord.reshape(1, -1)
        return model.predict(aff_coord * np.array(dimensions_weights))
    return find_cluster

def get_affective_cluster_centers():
    with open(cluster_model_path, 'rb') as f:
        data = pickle.load(f)
    return data['cluster_centers']

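A minimal usage sketch (not from the deleted file), assuming the cluster model pickle was already written by cocktail2affect.py. The returned closure maps (valence, arousal, dominance) coordinates to cluster ids:

import numpy as np
from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster

find_cluster = get_affect2affective_cluster()
coords = np.array([0.2, -0.5, 0.1])  # hypothetical affective coordinates
print(find_cluster(coords))          # array containing one cluster id
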
src/cocktails/pipeline/get_cocktail2affective_cluster.py
DELETED
@@ -1,9 +0,0 @@
from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster
from src.cocktails.pipeline.cocktail2affect import cocktail2affect

def get_cocktail2affective_cluster():
    find_cluster = get_affect2affective_cluster()
    def cocktail2affect_cluster(cocktail_rep):
        affective_coordinates, _ = cocktail2affect(cocktail_rep)
        return find_cluster(affective_coordinates)
    return cocktail2affect_cluster

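Composing the two pieces gives a direct representation-to-cluster map. A sketch under the same assumption that the cluster checkpoint exists (the 8-dimensional zero vector is purely illustrative):

import numpy as np
from src.cocktails.pipeline.get_cocktail2affective_cluster import get_cocktail2affective_cluster

to_cluster = get_cocktail2affective_cluster()
rep = np.zeros(8)       # stands in for a normalized cocktail representation
print(to_cluster(rep))  # cluster id of the cocktail's (valence, arousal, dominance)
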
src/cocktails/representation_learning/__init__.py
DELETED
File without changes
src/cocktails/representation_learning/dataset.py
DELETED
@@ -1,324 +0,0 @@
|
|
1 |
-
from torch.utils.data import Dataset
|
2 |
-
import pickle
|
3 |
-
from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles, ingredients_per_type
|
4 |
-
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
def get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index, params):
|
8 |
-
assert len(ingredients) == len(quantities)
|
9 |
-
ing, q = ingredients[index], quantities[index]
|
10 |
-
proportion = q / np.sum(quantities)
|
11 |
-
index_ing = ingredient_list.index(ing)
|
12 |
-
# add keys of profile
|
13 |
-
rep_ingredient = []
|
14 |
-
rep_ingredient += [ingredient_profiles[k][index_ing] for k in params['ing_keys']]
|
15 |
-
# add category encoding
|
16 |
-
# rep_ingredient += list(params['category_encodings'][ingredient_profiles['type'][index_ing]])
|
17 |
-
# add quantitiy and relative quantity
|
18 |
-
rep_ingredient += [q / max_q_per_ing[ing], proportion]
|
19 |
-
ing_one_hot = np.zeros(len(ingredient_list))
|
20 |
-
ing_one_hot[index_ing] = 1
|
21 |
-
rep_ingredient += list(ing_one_hot)
|
22 |
-
indexes_to_normalize = list(range(len(params['ing_keys'])))
|
23 |
-
#TODO: should we add ing one hot? Or make sure no 2 ing have same embedding
|
24 |
-
return np.array(rep_ingredient), indexes_to_normalize
|
25 |
-
|
26 |
-
def get_max_n_ingredients(data):
|
27 |
-
max_count = 0
|
28 |
-
ingredient_set = set()
|
29 |
-
alcohol_set = set()
|
30 |
-
liqueur_set = set()
|
31 |
-
ing_str = np.array(data['ingredients_str'])
|
32 |
-
for i in range(len(data['names'])):
|
33 |
-
ingredients, quantities = extract_ingredients(ing_str[i])
|
34 |
-
max_count = max(max_count, len(ingredients))
|
35 |
-
for ing in ingredients:
|
36 |
-
ingredient_set.add(ing)
|
37 |
-
if ing in ingredients_per_type['liquor']:
|
38 |
-
alcohol_set.add(ing)
|
39 |
-
if ing in ingredients_per_type['liqueur']:
|
40 |
-
liqueur_set.add(ing)
|
41 |
-
return max_count, ingredient_set, alcohol_set, liqueur_set
|
42 |
-
|
43 |
-
# Add your custom dataset class here
|
44 |
-
class MyDataset(Dataset):
|
45 |
-
def __init__(self, split, params):
|
46 |
-
data = params['raw_data']
|
47 |
-
self.dim_rep_ingredient = params['dim_rep_ingredient']
|
48 |
-
n_data = len(data["names"])
|
49 |
-
|
50 |
-
preparation_list = sorted(set(data['category']))
|
51 |
-
categories_list = sorted(set(data['subcategory']))
|
52 |
-
glasses_list = sorted(set(data['glass']))
|
53 |
-
|
54 |
-
max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
|
55 |
-
ingredient_set = sorted(ingredient_set)
|
56 |
-
self.ingredient_set = ingredient_set
|
57 |
-
|
58 |
-
ingredient_quantities = [] # output of our network
|
59 |
-
ingr_strs = np.array(data['ingredients_str'])
|
60 |
-
for i in range(n_data):
|
61 |
-
ingredients, quantities = extract_ingredients(ingr_strs[i])
|
62 |
-
# get ingredient presence and quantity
|
63 |
-
ingredient_q_rep = np.zeros([len(ingredient_set)])
|
64 |
-
for ing, q in zip(ingredients, quantities):
|
65 |
-
ingredient_q_rep[ingredient_set.index(ing)] = q
|
66 |
-
ingredient_quantities.append(ingredient_q_rep)
|
67 |
-
|
68 |
-
# take care of ingredient quantities (OUTPUTS)
|
69 |
-
ingredient_quantities = np.array(ingredient_quantities)
|
70 |
-
ingredients_presence = (ingredient_quantities>0).astype(np.int)
|
71 |
-
|
72 |
-
min_ing_quantities = np.min(ingredient_quantities, axis=0)
|
73 |
-
max_ing_quantities = np.max(ingredient_quantities, axis=0)
|
74 |
-
def normalize_ing_quantities(ing_quantities):
|
75 |
-
return ((ing_quantities - min_ing_quantities) / (max_ing_quantities - min_ing_quantities)).copy()
|
76 |
-
|
77 |
-
def denormalize_ing_quantities(normalized_ing_quantities):
|
78 |
-
return (normalized_ing_quantities * (max_ing_quantities - min_ing_quantities) + min_ing_quantities).copy()
|
79 |
-
ing_q_when_present = ingredient_quantities.copy()
|
80 |
-
for i in range(len(ing_q_when_present)):
|
81 |
-
ing_q_when_present[i, np.where(ing_q_when_present[i, :] == 0)] = np.nan
|
82 |
-
self.min_when_present_ing_quantities = np.nanmin(ing_q_when_present, axis=0)
|
83 |
-
|
84 |
-
|
85 |
-
def filter_decoder_output(output):
|
86 |
-
output_unnormalized = output * max_ing_quantities
|
87 |
-
if output.ndim == 1:
|
88 |
-
output_unnormalized[np.where(output_unnormalized<self.min_when_present_ing_quantities)] = 0
|
89 |
-
else:
|
90 |
-
for i in range(output.shape[0]):
|
91 |
-
output_unnormalized[i, np.where(output_unnormalized[i] < self.min_when_present_ing_quantities)] = 0
|
92 |
-
return output_unnormalized.copy()
|
93 |
-
self.filter_decoder_output = filter_decoder_output
|
94 |
-
# arg_mins = np.nanargmin(ing_q_when_present, axis=0)
|
95 |
-
#
|
96 |
-
# for ing, minq, argminq in zip(ingredient_set, self.min_when_present_ing_quantities, arg_mins):
|
97 |
-
# print(f'__\n{ing}: {minq}')
|
98 |
-
# print_recipe(ingr_strs[argminq])
|
99 |
-
# ingredients, quantities = extract_ingredients(ingr_strs[argminq])
|
100 |
-
# # get ingredient presence and quantity
|
101 |
-
# ingredient_q_rep = np.zeros([len(ingredient_set)])
|
102 |
-
# for ing, q in zip(ingredients, quantities):
|
103 |
-
# ingredient_q_rep[ingredient_set.index(ing)] = q
|
104 |
-
# print(np.array(data['urls'])[argminq])
|
105 |
-
# stop = 1
|
106 |
-
|
107 |
-
self.max_ing_quantities = max_ing_quantities
|
108 |
-
self.mean_ing_quantities = np.mean(ingredient_quantities, axis=0)
|
109 |
-
self.std_ing_quantities = np.std(ingredient_quantities, axis=0)
|
110 |
-
if split == 'train':
|
111 |
-
np.savetxt(params['save_path'] + 'min_when_present_ing_quantities.txt', self.min_when_present_ing_quantities)
|
112 |
-
np.savetxt(params['save_path'] + 'max_ing_quantities.txt', max_ing_quantities)
|
113 |
-
np.savetxt(params['save_path'] + 'mean_ing_quantities.txt', self.mean_ing_quantities)
|
114 |
-
np.savetxt(params['save_path'] + 'std_ing_quantities.txt', self.std_ing_quantities)
|
115 |
-
|
116 |
-
# print(ingredient_quantities[0])
|
117 |
-
# ingredient_quantities = (ingredient_quantities - self.mean_ing_quantities) / self.std_ing_quantities
|
118 |
-
# print(ingredient_quantities[0])
|
119 |
-
# print(ingredient_quantities[0] * self.std_ing_quantities + self.mean_ing_quantities )
|
120 |
-
ingredient_quantities = ingredient_quantities / max_ing_quantities#= normalize_ing_quantities(ingredient_quantities)
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
max_q_per_ing = dict(zip(ingredient_set, max_ing_quantities))
|
126 |
-
# print(ingredient_quantities[0])
|
127 |
-
#########
|
128 |
-
# Process input representation_analysis: list of ingredient representation_analysis
|
129 |
-
#########
|
130 |
-
input_data = [] # input of ingredient encoders
|
131 |
-
all_ing_reps = []
|
132 |
-
for i in range(n_data):
|
133 |
-
ingredients, quantities = extract_ingredients(ingr_strs[i])
|
134 |
-
# get ingredient presence and quantity
|
135 |
-
ingredient_q_rep = np.zeros([len(ingredient_set)])
|
136 |
-
for ing, q in zip(ingredients, quantities):
|
137 |
-
ingredient_q_rep[ingredient_set.index(ing)] = q
|
138 |
-
# get main liquor
|
139 |
-
cocktail_rep = []
|
140 |
-
for j in range(len(ingredients)):
|
141 |
-
cocktail_rep.append(get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index=j, params=params)[0])
|
142 |
-
all_ing_reps.append(cocktail_rep[-1].copy())
|
143 |
-
input_data.append(cocktail_rep)
|
144 |
-
|
145 |
-
|
146 |
-
all_ing_reps = np.array(all_ing_reps)
|
147 |
-
min_ing_reps = np.min(all_ing_reps[:, params['indexes_ing_to_normalize']], axis=0)
|
148 |
-
max_ing_reps = np.max(all_ing_reps[:, params['indexes_ing_to_normalize']], axis=0)
|
149 |
-
|
150 |
-
def normalize_ing_reps(ing_reps):
|
151 |
-
if ing_reps.ndim == 1:
|
152 |
-
ing_reps = ing_reps.reshape(1, -1)
|
153 |
-
out = ing_reps.copy()
|
154 |
-
out[:, params['indexes_ing_to_normalize']] = (out[:, params['indexes_ing_to_normalize']] - min_ing_reps) / (max_ing_reps - min_ing_reps)
|
155 |
-
return out
|
156 |
-
|
157 |
-
def denormalize_ing_reps(normalized_ing_reps):
|
158 |
-
if normalized_ing_reps.ndim == 1:
|
159 |
-
normalized_ing_reps = normalized_ing_reps.reshape(1, -1)
|
160 |
-
out = normalized_ing_reps.copy()
|
161 |
-
out[:, params['indexes_ing_to_normalize']] = out[:, params['indexes_ing_to_normalize']] * (max_ing_reps - min_ing_reps) + min_ing_reps
|
162 |
-
return out
|
163 |
-
|
164 |
-
|
165 |
-
        # put everything in a big matrix
        dim_cocktail_rep = max_ingredients * self.dim_rep_ingredient
        input_data2 = []
        nb_ingredients = []
        for d in input_data:
            cocktail_rep = np.zeros([dim_cocktail_rep])
            cocktail_rep.fill(np.nan)  # pad unused ingredient slots with NaNs
            index = 0
            nb_ingredients.append(len(d))
            for dj in d:
                cocktail_rep[index:index + self.dim_rep_ingredient] = normalize_ing_reps(dj)
                index += self.dim_rep_ingredient
            input_data2.append(cocktail_rep)
        input_data = np.array(input_data2)
        nb_ingredients = np.array(nb_ingredients)

        # let us now extract the various outputs we might want to predict:
        #########
        # Process output cocktail representations (computed from ingredient reps)
        #########
        # quantities_indexes = np.arange(20, 456, 57)
        # qs = input_data[0, quantities_indexes]
        # ingredient_quantities[0]
        # get final volume
        volumes = np.array(params['raw_data']['end volume'])

        min_vol = volumes.min()
        max_vol = volumes.max()

        def normalize_vol(volume):
            return (volume - min_vol) / (max_vol - min_vol)

        def denormalize_vol(normalized_vol):
            return normalized_vol * (max_vol - min_vol) + min_vol

        volumes = normalize_vol(volumes)

        # computed cocktail representation
        computed_cocktail_reps = params['cocktail_reps']
        self.dim_rep = computed_cocktail_reps.shape[1]

        #########
        # Process output subcategories
        #########
        categories = np.array([categories_list.index(sc) for sc in data['subcategory']])
        counts = dict(zip(categories_list, [0] * len(categories_list)))
        for c in data['subcategory']:
            counts[c] += 1
        for k in counts.keys():
            counts[k] /= len(data['subcategory'])
        self.categories = categories_list
        self.categories_weights = []
        for c in self.categories:
            self.categories_weights.append(1 / len(self.categories) / counts[c])
        print(counts)

        #########
        # Process output glass type
        #########
        glasses = np.array([glasses_list.index(sc) for sc in data['glass']])
        counts = dict(zip(glasses_list, [0] * len(glasses_list)))
        for c in data['glass']:
            counts[c] += 1
        for k in counts.keys():
            counts[k] /= len(data['glass'])
        self.glasses = glasses_list
        self.glasses_weights = []
        for c in self.glasses:
            self.glasses_weights.append(1 / len(self.glasses) / counts[c])
        print(counts)

        #########
        # Process output preparation type
        #########
        prep_type = np.array([preparation_list.index(sc) for sc in data['category']])
        counts = dict(zip(preparation_list, [0] * len(preparation_list)))
        for c in data['category']:
            counts[c] += 1
        for k in counts.keys():
            counts[k] /= len(data['category'])
        self.prep_types = preparation_list
        self.prep_types_weights = []
        for c in self.prep_types:
            self.prep_types_weights.append(1 / len(self.prep_types) / counts[c])
        print(counts)

        # parse taste representations stored as strings of the form '[x, y]';
        # not all cocktails have one
        taste_reps = list(data['taste_rep'])
        taste_rep_ground_truth = []
        taste_rep_valid = []
        for tr in taste_reps:
            if len(tr) > 2:
                taste_rep_valid.append(True)
                taste_rep_ground_truth.append([float(tr.split('[')[1].split(',')[0]), float(tr.split(']')[0].split(',')[1][1:])])
            else:
                taste_rep_valid.append(False)
                taste_rep_ground_truth.append([np.nan, np.nan])
        taste_rep_ground_truth = np.array(taste_rep_ground_truth)
        taste_rep_valid = np.array(taste_rep_valid)
        taste_rep_ground_truth /= 10

        auxiliary_data = dict(categories=categories,
                              glasses=glasses,
                              prep_type=prep_type,
                              cocktail_reps=computed_cocktail_reps,
                              ingredients_presence=ingredients_presence,
                              taste_reps=taste_rep_ground_truth,
                              volume=volumes,
                              ingredients_quantities=ingredient_quantities)
        self.auxiliary_keys = sorted(params['auxiliaries_dict'].keys())
        assert self.auxiliary_keys == sorted(auxiliary_data.keys())

        data_preprocessing = dict(min_max_ing_quantities=(min_ing_quantities, max_ing_quantities),
                                  min_max_ing_reps=(min_ing_reps, max_ing_reps),
                                  min_max_vol=(min_vol, max_vol))

        if split == 'train':
            with open(params['save_path'] + 'normalization_funcs.pickle', 'wb') as f:
                pickle.dump(data_preprocessing, f)

        n_data = len(input_data)
        assert len(ingredient_quantities) == n_data
        for aux in self.auxiliary_keys:
            assert len(auxiliary_data[aux]) == n_data

        # deterministic 90/10 train/test split
        if split == 'train':
            indexes = np.arange(int(0.9 * n_data))
        elif split == 'test':
            indexes = np.arange(int(0.9 * n_data), n_data)
        elif split == 'all':
            indexes = np.arange(n_data)
        else:
            raise ValueError
        # np.random.shuffle(indexes)

        self.taste_rep_valid = taste_rep_valid[indexes]
        self.input_ingredients = input_data[indexes]
        self.ingredient_quantities = ingredient_quantities[indexes]
        self.computed_cocktail_reps = computed_cocktail_reps[indexes]
        self.auxiliaries = dict()
        for aux in self.auxiliary_keys:
            self.auxiliaries[aux] = auxiliary_data[aux][indexes]
        self.nb_ingredients = nb_ingredients[indexes]

    def __len__(self):
        return len(self.input_ingredients)

    def get_auxiliary_data(self, idx):
        out = dict()
        for aux in self.auxiliary_keys:
            out[aux] = self.auxiliaries[aux][idx]
        return out

    def __getitem__(self, idx):
        assert self.nb_ingredients[idx] == np.argwhere(~np.isnan(self.input_ingredients[idx])).flatten().size / self.dim_rep_ingredient
        return [self.nb_ingredients[idx], self.input_ingredients[idx], self.ingredient_quantities[idx], self.computed_cocktail_reps[idx], self.get_auxiliary_data(idx),
                self.taste_rep_valid[idx]]
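A minimal usage sketch for this dataset (editor's addition, not part of the deleted repo): it mirrors how run.py builds its loaders, and assumes the repo's config paths resolve locally; note that get_params has side effects (it creates the run directory and dumps params.json).

from torch.utils.data import DataLoader
from src.cocktails.representation_learning.dataset import MyDataset
from src.cocktails.representation_learning.run import get_params

params = get_params()  # builds the params dict and creates the save directories
train_data = MyDataset(split='train', params=params)
loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
nb_ing, x_ing, ing_q, cocktail_rep, aux, taste_valid = next(iter(loader))
print(nb_ing.shape, x_ing.shape)  # per-cocktail ingredient counts and NaN-padded ingredient reps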
src/cocktails/representation_learning/multihead_model.py
DELETED
@@ -1,148 +0,0 @@
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.simple_model import SimpleNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_activation(activation):
    if activation == 'tanh':
        activ = F.tanh
    elif activation == 'relu':
        activ = F.relu
    elif activation == 'mish':
        activ = F.mish
    elif activation == 'sigmoid':
        activ = F.sigmoid
    elif activation == 'leakyrelu':
        activ = F.leaky_relu
    elif activation == 'exp':
        activ = torch.exp
    else:
        raise ValueError
    return activ

class IngredientEncoder(nn.Module):
    def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
        super(IngredientEncoder, self).__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [input_dim] + hidden_dims + [deepset_latent_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = layer(x)
            if i_layer != self.n_layers - 1:
                x = self.activation(dropout(x))
        return x  # no activation or dropout on the last layer

class DeepsetCocktailEncoder(nn.Module):
    def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                 hidden_dims_cocktail, latent_dim, aggregation, dropout):
        super(DeepsetCocktailEncoder, self).__init__()
        self.input_dim = input_dim  # dimension of ingredient representation + quantity
        self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout)  # encode each ingredient separately
        self.deepset_latent_dim = deepset_latent_dim  # dimension of the deepset aggregation
        self.aggregation = aggregation
        self.latent_dim = latent_dim
        # post-aggregation network
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [deepset_latent_dim] + hidden_dims_cocktail
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
        self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
        self.softplus = nn.Softplus()

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, nb_ingredients, x):
        # reshape x to (batch size * nb ingredients, dim_ing_rep)
        batch_size = x.shape[0]
        all_ingredients = []
        for i in range(batch_size):
            for j in range(nb_ingredients[i]):
                all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1))
        x = torch.cat(all_ingredients, dim=0)
        # encode ingredients in parallel
        ingredients_encodings = self.ingredient_encoder(x)
        assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim)

        # aggregate the encodings of each cocktail's ingredients (permutation-invariant)
        x = []
        index_first = 0
        for i in range(batch_size):
            index_last = index_first + nb_ingredients[i]
            if self.aggregation == 'sum':
                x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
            elif self.aggregation == 'mean':
                x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
            else:
                raise ValueError
            index_first = index_last
        x = torch.cat(x, dim=0)
        assert x.shape[0] == batch_size

        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        mean = self.FC_mean(x)
        logvar = self.FC_logvar(x)
        return mean, logvar


class MultiHeadModel(nn.Module):
    def __init__(self, encoder, auxiliaries_dict, activation, hidden_dims_decoder):
        super(MultiHeadModel, self).__init__()
        self.encoder = encoder
        self.latent_dim = self.encoder.output_dim
        self.auxiliaries_str = []
        self.auxiliaries = nn.ModuleList()
        for aux_str in sorted(auxiliaries_dict.keys()):
            if aux_str == 'taste_reps':
                self.taste_reps_decoder = SimpleNet(input_dim=self.latent_dim, hidden_dims=[], output_dim=auxiliaries_dict[aux_str]['dim_output'],
                                                    activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ'])
            else:
                self.auxiliaries_str.append(aux_str)
                if aux_str == 'ingredients_quantities':
                    hd = hidden_dims_decoder
                else:
                    hd = []
                self.auxiliaries.append(SimpleNet(input_dim=self.latent_dim, hidden_dims=hd, output_dim=auxiliaries_dict[aux_str]['dim_output'],
                                                  activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ']))

    def get_all_auxiliaries(self, x):
        return [aux(x) for aux in self.auxiliaries]

    def get_auxiliary(self, z, aux_str):
        if aux_str == 'taste_reps':
            return self.taste_reps_decoder(z)
        else:
            index = self.auxiliaries_str.index(aux_str)
            return self.auxiliaries[index](z)

    def forward(self, x, aux_str=None):
        z = self.encoder(x)
        if aux_str is not None:
            return z, self.get_auxiliary(z, aux_str), [aux_str]
        else:
            return z, self.get_all_auxiliaries(z), self.auxiliaries_str

def get_multihead_model(input_dim, activation, hidden_dims_cocktail, latent_dim, dropout, auxiliaries_dict, hidden_dims_decoder):
    encoder = SimpleNet(input_dim, hidden_dims_cocktail, latent_dim, activation, dropout)
    model = MultiHeadModel(encoder, auxiliaries_dict, activation, hidden_dims_decoder)
    return model
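A toy construction of the multihead model (editor's sketch, not from the repo): the dimensions and head names below are illustrative, and it assumes SimpleNet exposes an output_dim attribute and a default final_activ, as the calls above require.

import torch
from src.cocktails.representation_learning.multihead_model import get_multihead_model

aux_dict = dict(glasses=dict(weight=1, type='classif', final_activ=None, dim_output=5),
                volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1))
model = get_multihead_model(input_dim=76, activation='relu', hidden_dims_cocktail=[64],
                            latent_dim=10, dropout=0.0, auxiliaries_dict=aux_dict,
                            hidden_dims_decoder=[64])
z, outputs, names = model(torch.randn(8, 76))  # z: (8, 10); one output tensor per head, in sorted-key order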
src/cocktails/representation_learning/run.py
DELETED
@@ -1,557 +0,0 @@
import torch; torch.manual_seed(0)
import torch.utils
from torch.utils.data import DataLoader
import torch.distributions
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
import json
import pandas as pd
import numpy as np
import os
from src.cocktails.representation_learning.vae_model import get_vae_model
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
from resource import getrusage
from resource import RUSAGE_SELF
import gc
gc.collect(2)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_params():
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
    num_ingredients = len(ingredient_set)
    rep_keys = get_bunch_of_rep_keys()['custom']
    ing_keys = [k.split(' ')[1] for k in rep_keys]
    ing_keys.remove('volume')
    nb_ing_categories = len(set(ingredient_profiles['type']))
    category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))

    params = dict(trial_id='test',
                  save_path=EXPERIMENT_PATH + "/deepset_vae/",
                  nb_epochs=2000,
                  print_every=50,
                  plot_every=100,
                  batch_size=64,
                  lr=0.001,
                  dropout=0.,
                  nb_epoch_switch_beta=600,
                  latent_dim=10,
                  beta_vae=0.2,
                  ing_keys=ing_keys,
                  nb_ingredients=len(ingredient_set),
                  hidden_dims_ingredients=[128],
                  hidden_dims_cocktail=[32],
                  hidden_dims_decoder=[32],
                  agg='mean',
                  activation='relu',
                  auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
                                        glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
                                        prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
                                        cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
                                        volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1),
                                        taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
                                        ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
                  category_encodings=category_encodings
                  )
    # params = dict(trial_id='test',
    #               save_path=EXPERIMENT_PATH + "/deepset_vae/",
    #               nb_epochs=1000,
    #               print_every=50,
    #               plot_every=100,
    #               batch_size=64,
    #               lr=0.001,
    #               dropout=0.,
    #               nb_epoch_switch_beta=500,
    #               latent_dim=64,
    #               beta_vae=0.3,
    #               ing_keys=ing_keys,
    #               nb_ingredients=len(ingredient_set),
    #               hidden_dims_ingredients=[128],
    #               hidden_dims_cocktail=[128, 128],
    #               hidden_dims_decoder=[128, 128],
    #               agg='mean',
    #               activation='mish',
    #               auxiliaries_dict=dict(categories=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
    #                                     glasses=dict(weight=0.03, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
    #                                     prep_type=dict(weight=0.02, type='classif', final_activ=None, dim_output=len(set(data['category']))),
    #                                     cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),
    #                                     volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),
    #                                     taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),
    #                                     ingredients_presence=dict(weight=1.5, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
    #               category_encodings=category_encodings
    #               )
    water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
                                                                         max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
                                                                         params=params)
    dim_rep_ingredient = water_rep.size
    params['indexes_ing_to_normalize'] = indexes_to_normalize
    params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
    params['input_dim'] = dim_rep_ingredient
    params['dim_rep_ingredient'] = dim_rep_ingredient
    params = compute_expe_name_and_save_path(params)
    del params['category_encodings']  # to dump
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)

    params = complete_params(params)
    return params

def complete_params(params):
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
    nb_ing_categories = len(set(ingredient_profiles['type']))
    category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
    params['cocktail_reps'] = cocktail_reps
    params['raw_data'] = data
    params['category_encodings'] = category_encodings
    return params

def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
    losses = dict()
    accuracies = dict()
    other_metrics = dict()
    for i_k, k in enumerate(auxiliaries_str):
        # fetch the ground truth and compute the loss
        if k == 'volume':
            outputs[i_k] = outputs[i_k].flatten()
        ground_truth = auxiliaries[k]
        if ground_truth.dtype == torch.float64:
            losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float()
        elif ground_truth.dtype == torch.int64:
            # loss types are distinguished by their repr string
            if str(loss_functions[k]) != "BCEWithLogitsLoss()":
                losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float()
            else:
                losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float()
        else:
            losses[k] = loss_functions[k](outputs[i_k], ground_truth).float()
        # compute accuracies
        if str(loss_functions[k]) == 'CrossEntropyLoss()':
            bs, n_options = outputs[i_k].shape
            predicted = outputs[i_k].argmax(dim=1).detach().numpy()
            true = ground_truth.int().detach().numpy()
            confusion_matrix = np.zeros([n_options, n_options])
            for i in range(bs):
                confusion_matrix[true[i], predicted[i]] += 1
            acc = confusion_matrix.diagonal().sum() / bs
            for i in range(n_options):
                if confusion_matrix[i].sum() != 0:
                    confusion_matrix[i] /= confusion_matrix[i].sum()
            other_metrics[k + '_confusion'] = confusion_matrix
            accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy())
            assert (acc - accuracies[k]) < 1e-5

        elif str(loss_functions[k]) == 'BCEWithLogitsLoss()':
            assert k == 'ingredients_presence'
            outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
            predicted_presence = (outputs_rescaled > 0).astype(bool)
            presence = ground_truth.detach().numpy().astype(bool)
            other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool)))
            other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool)))
            accuracies[k] = np.mean(predicted_presence == presence)  # accuracy for multi-label classification
        elif str(loss_functions[k]) == 'MSELoss()':
            accuracies[k] = np.nan
        else:
            raise ValueError
    return losses, accuracies, other_metrics

def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
    ing_q = ingredient_quantities.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    ing_presence = (ing_q > 0)
    x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    # abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
    abs_diff = np.abs(ing_q - x_hat)
    ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
    for i in range(ingredient_quantities.shape[0]):
        ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])]))
        ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])]))
    aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
    aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
    return aux_other_metrics

def run_epoch(opt, train, model, data, loss_functions, weights, params):
    if train:
        model.train()
    else:
        model.eval()

    # prepare logging of losses
    losses = dict(kld_loss=[],
                  mse_loss=[],
                  vae_loss=[],
                  volume_loss=[],
                  global_loss=[])
    accuracies = dict()
    other_metrics = dict()
    for aux in params['auxiliaries_dict'].keys():
        losses[aux] = []
        accuracies[aux] = []
    if train: opt.zero_grad()

    for d in data:
        nb_ingredients = d[0]
        batch_size = nb_ingredients.shape[0]
        x_ingredients = d[1].float()
        ingredient_quantities = d[2]
        cocktail_reps = d[3]
        auxiliaries = d[4]
        for k in auxiliaries.keys():
            if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
        taste_valid = d[-1]
        x = x_ingredients.to(device)
        x_hat, z, mean, log_var, outputs, auxiliaries_str = model.forward_direct(ingredient_quantities.float())
        # get auxiliary losses and accuracies
        aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data)

        # compute vae loss
        mse_loss = ((ingredient_quantities - x_hat) ** 2).mean().float()
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)).float()
        vae_loss = mse_loss + params['beta_vae'] * (params['latent_dim'] / params['nb_ingredients']) * kld_loss
        # compute total volume loss to train decoder
        # volume_loss = ((ingredient_quantities.sum(dim=1) - x_hat.sum(dim=1)) ** 2).mean().float()
        volume_loss = torch.FloatTensor([0])

        aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat)

        indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
        if indexes_taste_valid.size > 0:
            outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
            gt = auxiliaries['taste_reps'][indexes_taste_valid]
            # scale the taste loss by the share of valid taste labels in the batch:
            # equal to 1 when it matches the dataset ratio (~0.3), lower with fewer labels, higher with more
            factor_loss = indexes_taste_valid.size / (0.3 * batch_size)
            aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
        else:
            aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
            aux_accuracies['taste_reps'] = 0

        # aggregate losses
        global_loss = torch.sum(torch.cat([torch.atleast_1d(vae_loss), torch.atleast_1d(volume_loss)] + [torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()]))
        # for k in params['auxiliaries_dict'].keys():
        #     global_loss += aux_losses[k] * weights[k]

        if train:
            global_loss.backward()
            opt.step()
            opt.zero_grad()

        # logging
        losses['global_loss'].append(float(global_loss))
        losses['mse_loss'].append(float(mse_loss))
        losses['vae_loss'].append(float(vae_loss))
        losses['volume_loss'].append(float(volume_loss))
        losses['kld_loss'].append(float(kld_loss))
        for k in params['auxiliaries_dict'].keys():
            losses[k].append(float(aux_losses[k]))
            accuracies[k].append(float(aux_accuracies[k]))
        for k in aux_other_metrics.keys():
            if k not in other_metrics.keys():
                other_metrics[k] = [aux_other_metrics[k]]
            else:
                other_metrics[k].append(aux_other_metrics[k])

    for k in losses.keys():
        losses[k] = np.mean(losses[k])
    for k in accuracies.keys():
        accuracies[k] = np.mean(accuracies[k])
    for k in other_metrics.keys():
        other_metrics[k] = np.mean(other_metrics[k], axis=0)
    return model, losses, accuracies, other_metrics

def prepare_data_and_loss(params):
    train_data = MyDataset(split='train', params=params)
    test_data = MyDataset(split='test', params=params)

    train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
    test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)

    loss_functions = dict()
    weights = dict()
    for k in sorted(params['auxiliaries_dict'].keys()):
        if params['auxiliaries_dict'][k]['type'] == 'classif':
            if k == 'glasses':
                classif_weights = train_data.glasses_weights
            elif k == 'prep_type':
                classif_weights = train_data.prep_types_weights
            elif k == 'categories':
                classif_weights = train_data.categories_weights
            else:
                raise ValueError
            loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
        elif params['auxiliaries_dict'][k]['type'] == 'multiclassif':
            loss_functions[k] = nn.BCEWithLogitsLoss()
        elif params['auxiliaries_dict'][k]['type'] == 'regression':
            loss_functions[k] = nn.MSELoss()
        else:
            raise ValueError
        weights[k] = params['auxiliaries_dict'][k]['weight']

    return loss_functions, train_data_loader, test_data_loader, weights

def print_losses(train, losses, accuracies, other_metrics):
    keyword = 'Train' if train else 'Eval'
    print(f'\t{keyword} logs:')
    keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
    for k in keys:
        print(f'\t\t{k} - Loss: {losses[k]:.2f}')
    for k in sorted(accuracies.keys()):
        print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
    for k in sorted(other_metrics.keys()):
        if 'confusion' not in k:
            print(f'\t\t{k} - {other_metrics[k]:.2f}')


def run_experiment(params, verbose=True):
    loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)
    params['filter_decoder_output'] = train_data_loader.dataset.filter_decoder_output

    model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
                                        "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
                                        "filter_decoder_output"]]
    model = get_vae_model(*model_params)
    opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])

    all_train_losses = []
    all_eval_losses = []
    all_train_accuracies = []
    all_eval_accuracies = []
    all_eval_other_metrics = []
    all_train_other_metrics = []
    best_loss = np.inf
    model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
                                                                        weights=weights, params=params)
    all_eval_losses.append(eval_losses)
    all_eval_accuracies.append(eval_accuracies)
    all_eval_other_metrics.append(eval_other_metrics)
    if verbose: print('\n--------\nEpoch #0')
    if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
    for epoch in range(params['nb_epochs']):
        if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
        model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions,
                                                                               weights=weights, params=params)
        if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
        model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
                                                                            weights=weights, params=params)
        if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
        if eval_losses['global_loss'] < best_loss:
            best_loss = eval_losses['global_loss']
            if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
            torch.save(model.state_dict(), params['save_path'] + 'checkpoint_best.save')

        # log
        all_train_losses.append(train_losses)
        all_train_accuracies.append(train_accuracies)
        all_eval_losses.append(eval_losses)
        all_eval_accuracies.append(eval_accuracies)
        all_eval_other_metrics.append(eval_other_metrics)
        all_train_other_metrics.append(train_other_metrics)

        # if epoch == params['nb_epoch_switch_beta']:
        #     params['beta_vae'] = 2.5
        #     params['auxiliaries_dict']['prep_type']['weight'] /= 10
        #     params['auxiliaries_dict']['glasses']['weight'] /= 10

        if (epoch + 1) % params['plot_every'] == 0:
            plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                         all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights)

    return model

def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                 all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):

    steps = np.arange(len(all_eval_accuracies))

    loss_keys = sorted(all_train_losses[0].keys())
    acc_keys = sorted(all_train_accuracies[0].keys())
    metrics_keys = sorted(all_train_other_metrics[0].keys())

    plt.figure()
    plt.title('Train losses')
    for k in loss_keys:
        # plot every loss; auxiliary losses are skipped when their weight is zero
        if k not in weights.keys() or weights[k] != 0:
            plt.plot(steps[1:], [train_loss[k] for train_loss in all_train_losses], label=k)
    plt.legend()
    plt.ylim([0, 4])
    plt.savefig(plot_path + 'train_losses.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train accuracies')
    for k in acc_keys:
        if weights[k] != 0:
            plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'train_acc.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' in k:
            plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' not in k:
            plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
    plt.legend()
    plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval losses')
    for k in loss_keys:
        if k not in weights.keys() or weights[k] != 0:
            plt.plot(steps, [eval_loss[k] for eval_loss in all_eval_losses], label=k)
    plt.legend()
    plt.ylim([0, 4])
    plt.savefig(plot_path + 'eval_losses.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval accuracies')
    for k in acc_keys:
        if weights[k] != 0:
            plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'eval_acc.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' in k:
            plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' not in k:
            plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
    plt.legend()
    plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    for k in metrics_keys:
        if 'confusion' in k:
            plt.figure()
            plt.title(k)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1)
            plt.colorbar()
            plt.savefig(plot_path + f'eval_{k}.png', dpi=200)
            fig = plt.gcf()
            plt.close(fig)

    for k in metrics_keys:
        if 'confusion' in k:
            plt.figure()
            plt.title(k)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1)
            plt.colorbar()
            plt.savefig(plot_path + f'train_{k}.png', dpi=200)
            fig = plt.gcf()
            plt.close(fig)

    plt.close('all')


def get_model(model_path):
    with open(model_path + 'params.json', 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path
    max_ing_quantities = np.loadtxt(params['save_path'] + 'max_ing_quantities.txt')
    mean_ing_quantities = np.loadtxt(params['save_path'] + 'mean_ing_quantities.txt')
    std_ing_quantities = np.loadtxt(params['save_path'] + 'std_ing_quantities.txt')
    min_when_present_ing_quantities = np.loadtxt(params['save_path'] + 'min_when_present_ing_quantities.txt')

    def filter_decoder_output(output):
        output = output.detach().numpy()
        output_unnormalized = output * std_ing_quantities + mean_ing_quantities
        if output.ndim == 1:
            output_unnormalized[np.where(output_unnormalized < min_when_present_ing_quantities)] = 0
        else:
            for i in range(output.shape[0]):
                output_unnormalized[i, np.where(output_unnormalized[i] < min_when_present_ing_quantities)] = 0
        return output_unnormalized.copy()

    params['filter_decoder_output'] = filter_decoder_output
    model_chkpt = model_path + "checkpoint_best.save"
    model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
                                        "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
                                        "filter_decoder_output"]]
    model = get_vae_model(*model_params)
    model.load_state_dict(torch.load(model_chkpt))
    model.eval()
    return model, filter_decoder_output, params
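# Editor's sketch (comments only, not in the original file): typical use of get_model
# on a finished run directory; `run_dir` is hypothetical, and forward_direct is the
# same entry point run_epoch calls on the VAE model:
#   model, filter_decoder_output, params = get_model(run_dir)
#   x_hat, z, mean, log_var, outputs, aux_str = model.forward_direct(normalized_quantities)
#   quantities = filter_decoder_output(x_hat)  # unnormalize, zero-out sub-threshold ingredients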

def compute_expe_name_and_save_path(params):
    weights_str = '['
    for aux in params['auxiliaries_dict'].keys():
        weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
    weights_str = weights_str[:-2] + ']'
    save_path = params['save_path'] + params["trial_id"]
    save_path += f'_lr{params["lr"]}'
    save_path += f'_betavae{params["beta_vae"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_latentdim{params["latent_dim"]}'
    save_path += f'_hding{params["hidden_dims_ingredients"]}'
    save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
    save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
    save_path += f'_agg{params["agg"]}'
    save_path += f'_activ{params["activation"]}'
    save_path += f'_w{weights_str}'
    counter = 0
    while os.path.exists(save_path + f"_{counter}"):
        counter += 1
    save_path = save_path + f"_{counter}" + '/'
    params["save_path"] = save_path
    os.makedirs(save_path)
    os.makedirs(save_path + 'plots/')
    params['plot_path'] = save_path + 'plots/'
    print(f'logging to {save_path}')
    return params


if __name__ == '__main__':
    params = get_params()
    run_experiment(params)
src/cocktails/representation_learning/run_simple_net.py
DELETED
@@ -1,302 +0,0 @@
|
|
1 |
-
import torch; torch.manual_seed(0)
|
2 |
-
import torch.utils
|
3 |
-
from torch.utils.data import DataLoader
|
4 |
-
import torch.distributions
|
5 |
-
import torch.nn as nn
|
6 |
-
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
|
7 |
-
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
|
8 |
-
import json
|
9 |
-
import pandas as pd
|
10 |
-
import numpy as np
|
11 |
-
import os
|
12 |
-
from src.cocktails.representation_learning.simple_model import SimpleNet
|
13 |
-
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
|
14 |
-
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
|
15 |
-
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
|
16 |
-
from resource import getrusage
|
17 |
-
from resource import RUSAGE_SELF
|
18 |
-
import gc
|
19 |
-
gc.collect(2)
|
20 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
-
|
22 |
-
def get_params():
|
23 |
-
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
24 |
-
max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
|
25 |
-
num_ingredients = len(ingredient_set)
|
26 |
-
rep_keys = get_bunch_of_rep_keys()['custom']
|
27 |
-
ing_keys = [k.split(' ')[1] for k in rep_keys]
|
28 |
-
ing_keys.remove('volume')
|
29 |
-
nb_ing_categories = len(set(ingredient_profiles['type']))
|
30 |
-
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
31 |
-
|
32 |
-
params = dict(trial_id='test',
|
33 |
-
save_path=EXPERIMENT_PATH + "/simple_net/",
|
34 |
-
nb_epochs=100,
|
35 |
-
print_every=50,
|
36 |
-
plot_every=50,
|
37 |
-
batch_size=128,
|
38 |
-
lr=0.001,
|
39 |
-
dropout=0.15,
|
40 |
-
output_keyword='glasses',
|
41 |
-
ing_keys=ing_keys,
|
42 |
-
nb_ingredients=len(ingredient_set),
|
43 |
-
hidden_dims=[16],
|
44 |
-
activation='sigmoid',
|
45 |
-
auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
|
46 |
-
glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
|
47 |
-
prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
|
48 |
-
cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
|
49 |
-
volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1),
|
50 |
-
taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
|
51 |
-
ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),
|
52 |
-
ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)),
|
53 |
-
|
54 |
-
category_encodings=category_encodings
|
55 |
-
)
|
56 |
-
params['output_dim'] = params['auxiliaries_dict'][params['output_keyword']]['dim_output']
|
57 |
-
water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
|
58 |
-
max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
|
59 |
-
params=params)
|
60 |
-
dim_rep_ingredient = water_rep.size
|
61 |
-
params['indexes_ing_to_normalize'] = indexes_to_normalize
|
62 |
-
params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
|
63 |
-
params['dim_rep_ingredient'] = dim_rep_ingredient
|
64 |
-
params['input_dim'] = params['nb_ingredients']
|
65 |
-
params = compute_expe_name_and_save_path(params)
|
66 |
-
del params['category_encodings'] # to dump
|
67 |
-
with open(params['save_path'] + 'params.json', 'w') as f:
|
68 |
-
json.dump(params, f)
|
69 |
-
|
70 |
-
params = complete_params(params)
|
71 |
-
return params
|
72 |
-
|
73 |
-
def complete_params(params):
|
74 |
-
data = pd.read_csv(COCKTAILS_CSV_DATA)
|
75 |
-
cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
|
76 |
-
nb_ing_categories = len(set(ingredient_profiles['type']))
|
77 |
-
category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
|
78 |
-
params['cocktail_reps'] = cocktail_reps
|
79 |
-
params['raw_data'] = data
|
80 |
-
params['category_encodings'] = category_encodings
|
81 |
-
return params
|
82 |
-
|
83 |
-
def compute_confusion_matrix_and_accuracy(predictions, ground_truth):
|
84 |
-
bs, n_options = predictions.shape
|
85 |
-
predicted = predictions.argmax(dim=1).detach().numpy()
|
86 |
-
true = ground_truth.int().detach().numpy()
|
87 |
-
confusion_matrix = np.zeros([n_options, n_options])
|
88 |
-
for i in range(bs):
|
89 |
-
confusion_matrix[true[i], predicted[i]] += 1
|
90 |
-
acc = confusion_matrix.diagonal().sum() / bs
|
91 |
-
for i in range(n_options):
|
92 |
-
if confusion_matrix[i].sum() != 0:
|
93 |
-
confusion_matrix[i] /= confusion_matrix[i].sum()
|
94 |
-
acc2 = np.mean(predicted == true)
|
95 |
-
assert (acc - acc2) < 1e-5
|
96 |
-
return confusion_matrix, acc
|
97 |
-
|
98 |
-
|
99 |
-
def run_epoch(opt, train, model, data, loss_function, params):
|
100 |
-
if train:
|
101 |
-
model.train()
|
102 |
-
else:
|
103 |
-
model.eval()
|
104 |
-
|
105 |
-
# prepare logging of losses
|
106 |
-
losses = []
|
107 |
-
accuracies = []
|
108 |
-
cf_matrices = []
|
109 |
-
if train: opt.zero_grad()
|
110 |
-
|
111 |
-
for d in data:
|
112 |
-
nb_ingredients = d[0]
|
113 |
-
batch_size = nb_ingredients.shape[0]
|
114 |
-
x_ingredients = d[1].float()
|
115 |
-
ingredient_quantities = d[2].float()
|
116 |
-
cocktail_reps = d[3].float()
|
117 |
-
auxiliaries = d[4]
|
118 |
-
for k in auxiliaries.keys():
|
119 |
-
if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
|
120 |
-
taste_valid = d[-1]
|
121 |
-
predictions = model(ingredient_quantities)
|
122 |
-
loss = loss_function(predictions, auxiliaries[params['output_keyword']].long()).float()
|
123 |
-
cf_matrix, accuracy = compute_confusion_matrix_and_accuracy(predictions, auxiliaries[params['output_keyword']])
|
124 |
-
if train:
|
125 |
-
loss.backward()
|
126 |
-
opt.step()
|
127 |
-
opt.zero_grad()
|
128 |
-
|
129 |
-
losses.append(float(loss))
|
130 |
-
cf_matrices.append(cf_matrix)
|
131 |
-
accuracies.append(accuracy)
|
132 |
-
|
133 |
-
return model, np.mean(losses), np.mean(accuracies), np.mean(cf_matrices, axis=0)
|
134 |
-
|
135 |
-
def prepare_data_and_loss(params):
|
136 |
-
train_data = MyDataset(split='train', params=params)
|
137 |
-
test_data = MyDataset(split='test', params=params)
|
138 |
-
|
139 |
-
train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
|
140 |
-
test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
|
141 |
-
|
142 |
-
|
143 |
-
if params['auxiliaries_dict'][params['output_keyword']]['type'] == 'classif':
|
144 |
-
if params['output_keyword'] == 'glasses':
|
145 |
-
classif_weights = train_data.glasses_weights
|
146 |
-
elif params['output_keyword'] == 'prep_type':
|
147 |
-
classif_weights = train_data.prep_types_weights
|
148 |
-
elif params['output_keyword'] == 'categories':
|
149 |
-
classif_weights = train_data.categories_weights
|
150 |
-
else:
|
151 |
-
raise ValueError
|
152 |
-
# classif_weights = (np.array(classif_weights) * 2 + np.ones(len(classif_weights))) / 3
|
153 |
-
loss_function = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
|
154 |
-
# loss_function = nn.CrossEntropyLoss()
|
155 |
-
|
156 |
-
elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'multiclassif':
|
157 |
-
loss_function = nn.BCEWithLogitsLoss()
|
158 |
-
elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'regression':
|
159 |
-
loss_function = nn.MSELoss()
|
160 |
-
else:
|
161 |
-
raise ValueError
|
162 |
-
|
163 |
-
return loss_function, train_data_loader, test_data_loader
|
164 |
-
|
165 |
-
def print_losses(train, loss, accuracy):
|
166 |
-
keyword = 'Train' if train else 'Eval'
|
167 |
-
print(f'\t{keyword} logs:')
|
168 |
-
print(f'\t\t Loss: {loss:.2f}, Acc: {accuracy:.2f}')
|
169 |
-
|
170 |
-
|
171 |
-
def run_experiment(params, verbose=True):
|
172 |
-
loss_function, train_data_loader, test_data_loader = prepare_data_and_loss(params)
|
173 |
-
|
174 |
-
model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout'])
|
175 |
-
opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
|
176 |
-
|
177 |
-
all_train_losses = []
|
178 |
-
all_eval_losses = []
|
179 |
-
all_eval_cf_matrices = []
|
180 |
-
all_train_accuracies = []
|
181 |
-
all_eval_accuracies = []
|
182 |
-
all_train_cf_matrices = []
|
183 |
-
best_loss = np.inf
|
184 |
-
model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params)
|
185 |
-
all_eval_losses.append(eval_loss)
|
186 |
-
all_eval_accuracies.append(eval_accuracy)
|
187 |
-
if verbose: print(f'\n--------\nEpoch #0')
|
188 |
-
if verbose: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss)
|
189 |
-
for epoch in range(params['nb_epochs']):
|
190 |
-
if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
|
191 |
-
model, train_loss, train_accuracy, train_cf_matrix = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_function=loss_function, params=params)
|
192 |
-
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracy=train_accuracy, loss=train_loss)
|
193 |
-
model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params)
|
194 |
-
if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss)
|
195 |
-
if eval_loss < best_loss:
|
196 |
-
best_loss = eval_loss
|
197 |
-
if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
|
198 |
-
torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
|
199 |
-
|
200 |
-
# log
|
201 |
-
all_train_losses.append(train_loss)
|
202 |
-
all_train_accuracies.append(train_accuracy)
|
203 |
-
all_eval_losses.append(eval_loss)
        all_eval_accuracies.append(eval_accuracy)
        all_eval_cf_matrices.append(eval_cf_matrix)
        all_train_cf_matrices.append(train_cf_matrix)

        if (epoch + 1) % params['plot_every'] == 0:
            plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices,
                         all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, params['plot_path'])

    return model

def plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices,
                 all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, plot_path):

    steps = np.arange(len(all_eval_accuracies))

    plt.figure()
    plt.title('Losses')
    plt.plot(steps[1:], all_train_losses, label='train')
    plt.plot(steps, all_eval_losses, label='eval')
    plt.legend()
    plt.ylim([0, 4])
    plt.savefig(plot_path + 'losses.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Accuracies')
    plt.plot(steps[1:], all_train_accuracies, label='train')
    plt.plot(steps, all_eval_accuracies, label='eval')
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'accs.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train confusion matrix')
    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.imshow(all_train_cf_matrices[-1], vmin=0, vmax=1)
    plt.colorbar()
    plt.savefig(plot_path + 'train_confusion_matrix.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval confusion matrix')
    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.imshow(all_eval_cf_matrices[-1], vmin=0, vmax=1)
    plt.colorbar()
    plt.savefig(plot_path + 'eval_confusion_matrix.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.close('all')


def get_model(model_path):
    with open(model_path + 'params.json', 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path
    model_chkpt = model_path + "checkpoint_best.save"
    model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout'])
    model.load_state_dict(torch.load(model_chkpt))
    model.eval()
    return model, params


def compute_expe_name_and_save_path(params):
    weights_str = '['
    for aux in params['auxiliaries_dict'].keys():
        weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
    weights_str = weights_str[:-2] + ']'
    save_path = params['save_path'] + params["trial_id"]
    save_path += f'_lr{params["lr"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_hd{params["hidden_dims"]}'
    save_path += f'_activ{params["activation"]}'
    save_path += f'_w{weights_str}'
    counter = 0
    while os.path.exists(save_path + f"_{counter}"):
        counter += 1
    save_path = save_path + f"_{counter}" + '/'
    params["save_path"] = save_path
    os.makedirs(save_path)
    os.makedirs(save_path + 'plots/')
    params['plot_path'] = save_path + 'plots/'
    print(f'logging to {save_path}')
    return params


if __name__ == '__main__':
    params = get_params()
    run_experiment(params)
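Below, a minimal usage sketch (not part of the deleted file) of how the get_model helper above would be called once a run has finished. The run directory name is hypothetical; get_model expects it to contain the params.json and checkpoint_best.save written during training.

import torch
# assumes get_model and SimpleNet from the file above are in scope
model, params = get_model('experiments/simple_net/test_run_0/')  # hypothetical run dir
x = torch.zeros(1, params['input_dim'])   # one dummy input row
with torch.no_grad():
    logits = model(x)                     # shape: (1, params['output_dim'])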
src/cocktails/representation_learning/run_without_vae.py
DELETED
@@ -1,514 +0,0 @@
import torch; torch.manual_seed(0)
import torch.utils
from torch.utils.data import DataLoader
import torch.distributions
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
import json
import pandas as pd
import numpy as np
import os
from src.cocktails.representation_learning.multihead_model import get_multihead_model
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
from resource import getrusage
from resource import RUSAGE_SELF
import gc
gc.collect(2)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_params():
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
    num_ingredients = len(ingredient_set)
    rep_keys = get_bunch_of_rep_keys()['custom']
    ing_keys = [k.split(' ')[1] for k in rep_keys]
    ing_keys.remove('volume')
    nb_ing_categories = len(set(ingredient_profiles['type']))
    category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))

    params = dict(trial_id='test',
                  save_path=EXPERIMENT_PATH + "/multihead_model/",
                  nb_epochs=500,
                  print_every=50,
                  plot_every=50,
                  batch_size=128,
                  lr=0.001,
                  dropout=0.,
                  nb_epoch_switch_beta=600,
                  latent_dim=10,
                  beta_vae=0.2,
                  ing_keys=ing_keys,
                  nb_ingredients=len(ingredient_set),
                  hidden_dims_ingredients=[128],
                  hidden_dims_cocktail=[64],
                  hidden_dims_decoder=[32],
                  agg='mean',
                  activation='relu',
                  auxiliaries_dict=dict(categories=dict(weight=5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),  # 0.5
                                        glasses=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['glass']))),  # 0.1
                                        prep_type=dict(weight=0.1, type='classif', final_activ=None, dim_output=len(set(data['category']))),  # 1
                                        cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),  # 1
                                        volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),  # 1
                                        taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),  # 1
                                        ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),  # 10
                                        ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)),
                  category_encodings=category_encodings
                  )
    water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
                                                                         max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
                                                                         params=params)
    dim_rep_ingredient = water_rep.size
    params['indexes_ing_to_normalize'] = indexes_to_normalize
    params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
    params['dim_rep_ingredient'] = dim_rep_ingredient
    params['input_dim'] = params['nb_ingredients']
    params = compute_expe_name_and_save_path(params)
    del params['category_encodings']  # to dump
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)

    params = complete_params(params)
    return params

def complete_params(params):
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
    nb_ing_categories = len(set(ingredient_profiles['type']))
    category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
    params['cocktail_reps'] = cocktail_reps
    params['raw_data'] = data
    params['category_encodings'] = category_encodings
    return params

def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
    losses = dict()
    accuracies = dict()
    other_metrics = dict()
    for i_k, k in enumerate(auxiliaries_str):
        # get ground truth
        # compute loss
        if k == 'volume':
            outputs[i_k] = outputs[i_k].flatten()
        ground_truth = auxiliaries[k]
        if ground_truth.dtype == torch.float64:
            losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float()
        elif ground_truth.dtype == torch.int64:
            if str(loss_functions[k]) != "BCEWithLogitsLoss()":
                losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float()
            else:
                losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float()
        else:
            losses[k] = loss_functions[k](outputs[i_k], ground_truth).float()
        # compute accuracies
        if str(loss_functions[k]) == 'CrossEntropyLoss()':
            bs, n_options = outputs[i_k].shape
            predicted = outputs[i_k].argmax(dim=1).detach().numpy()
            true = ground_truth.int().detach().numpy()
            confusion_matrix = np.zeros([n_options, n_options])
            for i in range(bs):
                confusion_matrix[true[i], predicted[i]] += 1
            acc = confusion_matrix.diagonal().sum() / bs
            for i in range(n_options):
                if confusion_matrix[i].sum() != 0:
                    confusion_matrix[i] /= confusion_matrix[i].sum()
            other_metrics[k + '_confusion'] = confusion_matrix
            accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy())
            assert (acc - accuracies[k]) < 1e-5

        elif str(loss_functions[k]) == 'BCEWithLogitsLoss()':
            assert k == 'ingredients_presence'
            outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
            predicted_presence = (outputs_rescaled > 0).astype(bool)
            presence = ground_truth.detach().numpy().astype(bool)
            other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool)))
            other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool)))
            accuracies[k] = np.mean(predicted_presence == presence)  # accuracy for multi-class labeling
        elif str(loss_functions[k]) == 'MSELoss()':
            accuracies[k] = np.nan
        else:
            raise ValueError
    return losses, accuracies, other_metrics

def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
    ing_q = ingredient_quantities.detach().numpy()  # * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    ing_presence = (ing_q > 0)
    x_hat = x_hat.detach().numpy()
    # x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
    # abs_diff = np.abs(ing_q - x_hat)
    ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
    for i in range(ingredient_quantities.shape[0]):
        ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])]))
        ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])]))
    aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
    aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
    return aux_other_metrics

def run_epoch(opt, train, model, data, loss_functions, weights, params):
    if train:
        model.train()
    else:
        model.eval()

    # prepare logging of losses
    losses = dict(kld_loss=[],
                  mse_loss=[],
                  vae_loss=[],
                  volume_loss=[],
                  global_loss=[])
    accuracies = dict()
    other_metrics = dict()
    for aux in params['auxiliaries_dict'].keys():
        losses[aux] = []
        accuracies[aux] = []
    if train: opt.zero_grad()

    for d in data:
        nb_ingredients = d[0]
        batch_size = nb_ingredients.shape[0]
        x_ingredients = d[1].float()
        ingredient_quantities = d[2]
        cocktail_reps = d[3]
        auxiliaries = d[4]
        for k in auxiliaries.keys():
            if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
        taste_valid = d[-1]
        z, outputs, auxiliaries_str = model.forward(ingredient_quantities.float())
        # get auxiliary losses and accuracies
        aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data)

        # compute vae loss
        aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, outputs[auxiliaries_str.index('ingredients_quantities')])

        indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
        if indexes_taste_valid.size > 0:
            outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
            gt = auxiliaries['taste_reps'][indexes_taste_valid]
            factor_loss = indexes_taste_valid.size / (0.3 * batch_size)  # loss scaling factor: equals 1 when the batch's share of taste-labeled samples matches the dataset ratio (0.3); smaller with fewer such samples, larger with more
            aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
        else:
            aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
            aux_accuracies['taste_reps'] = 0

        # aggregate losses
        global_loss = torch.sum(torch.cat([torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()]))
        # for k in params['auxiliaries_dict'].keys():
        #     global_loss += aux_losses[k] * weights[k]

        if train:
            global_loss.backward()
            opt.step()
            opt.zero_grad()

        # logging
        losses['global_loss'].append(float(global_loss))
        for k in params['auxiliaries_dict'].keys():
            losses[k].append(float(aux_losses[k]))
            accuracies[k].append(float(aux_accuracies[k]))
        for k in aux_other_metrics.keys():
            if k not in other_metrics.keys():
                other_metrics[k] = [aux_other_metrics[k]]
            else:
                other_metrics[k].append(aux_other_metrics[k])

    for k in losses.keys():
        losses[k] = np.mean(losses[k])
    for k in accuracies.keys():
        accuracies[k] = np.mean(accuracies[k])
    for k in other_metrics.keys():
        other_metrics[k] = np.mean(other_metrics[k], axis=0)
    return model, losses, accuracies, other_metrics

def prepare_data_and_loss(params):
    train_data = MyDataset(split='train', params=params)
    test_data = MyDataset(split='test', params=params)

    train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
    test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)

    loss_functions = dict()
    weights = dict()
    for k in sorted(params['auxiliaries_dict'].keys()):
        if params['auxiliaries_dict'][k]['type'] == 'classif':
            if k == 'glasses':
                classif_weights = train_data.glasses_weights
            elif k == 'prep_type':
                classif_weights = train_data.prep_types_weights
            elif k == 'categories':
                classif_weights = train_data.categories_weights
            else:
                raise ValueError
            loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
        elif params['auxiliaries_dict'][k]['type'] == 'multiclassif':
            loss_functions[k] = nn.BCEWithLogitsLoss()
        elif params['auxiliaries_dict'][k]['type'] == 'regression':
            loss_functions[k] = nn.MSELoss()
        else:
            raise ValueError
        weights[k] = params['auxiliaries_dict'][k]['weight']

    return loss_functions, train_data_loader, test_data_loader, weights

def print_losses(train, losses, accuracies, other_metrics):
    keyword = 'Train' if train else 'Eval'
    print(f'\t{keyword} logs:')
    keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
    for k in keys:
        print(f'\t\t{k} - Loss: {losses[k]:.2f}')
    for k in sorted(accuracies.keys()):
        print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
    for k in sorted(other_metrics.keys()):
        if 'confusion' not in k:
            print(f'\t\t{k} - {other_metrics[k]:.2f}')


def run_experiment(params, verbose=True):
    loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)

    model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
    model = get_multihead_model(*model_params)
    opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])

    all_train_losses = []
    all_eval_losses = []
    all_train_accuracies = []
    all_eval_accuracies = []
    all_eval_other_metrics = []
    all_train_other_metrics = []
    best_loss = np.inf
    model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
                                                                        weights=weights, params=params)
    all_eval_losses.append(eval_losses)
    all_eval_accuracies.append(eval_accuracies)
    all_eval_other_metrics.append(eval_other_metrics)
    if verbose: print('\n--------\nEpoch #0')
    if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
    for epoch in range(params['nb_epochs']):
        if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
        model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions,
                                                                               weights=weights, params=params)
        if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
        model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
                                                                            weights=weights, params=params)
        if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
        if eval_losses['global_loss'] < best_loss:
            best_loss = eval_losses['global_loss']
            if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
            torch.save(model.state_dict(), params['save_path'] + 'checkpoint_best.save')

        # log
        all_train_losses.append(train_losses)
        all_train_accuracies.append(train_accuracies)
        all_eval_losses.append(eval_losses)
        all_eval_accuracies.append(eval_accuracies)
        all_eval_other_metrics.append(eval_other_metrics)
        all_train_other_metrics.append(train_other_metrics)

        # if epoch == params['nb_epoch_switch_beta']:
        #     params['beta_vae'] = 2.5
        #     params['auxiliaries_dict']['prep_type']['weight'] /= 10
        #     params['auxiliaries_dict']['glasses']['weight'] /= 10

        if (epoch + 1) % params['plot_every'] == 0:
            plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                         all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights)

    return model

def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                 all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):

    steps = np.arange(len(all_eval_accuracies))

    loss_keys = sorted(all_train_losses[0].keys())
    acc_keys = sorted(all_train_accuracies[0].keys())
    metrics_keys = sorted(all_train_other_metrics[0].keys())

    plt.figure()
    plt.title('Train losses')
    for k in loss_keys:
        factor = 1 if k == 'mse_loss' else 1
        if k not in weights.keys():
            plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
        else:
            if weights[k] != 0:
                plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
    plt.legend()
    plt.ylim([0, 4])
    plt.savefig(plot_path + 'train_losses.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train accuracies')
    for k in acc_keys:
        if weights[k] != 0:
            plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'train_acc.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' in k:
            plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Train other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' not in k:
            plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
    plt.legend()
    plt.ylim([0, 15])
    plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval losses')
    for k in loss_keys:
        factor = 1 if k == 'mse_loss' else 1
        if k not in weights.keys():
            plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
        else:
            if weights[k] != 0:
                plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
    plt.legend()
    plt.ylim([0, 4])
    plt.savefig(plot_path + 'eval_losses.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval accuracies')
    for k in acc_keys:
        if weights[k] != 0:
            plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'eval_acc.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' in k:
            plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
    plt.legend()
    plt.ylim([0, 1])
    plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    plt.figure()
    plt.title('Eval other metrics')
    for k in metrics_keys:
        if 'confusion' not in k and 'presence' not in k:
            plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
    plt.legend()
    plt.ylim([0, 15])
    plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200)
    fig = plt.gcf()
    plt.close(fig)

    for k in metrics_keys:
        if 'confusion' in k:
            plt.figure()
            plt.title(k)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1)
            plt.colorbar()
            plt.savefig(plot_path + f'eval_{k}.png', dpi=200)
            fig = plt.gcf()
            plt.close(fig)

    for k in metrics_keys:
        if 'confusion' in k:
            plt.figure()
            plt.title(k)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1)
            plt.colorbar()
            plt.savefig(plot_path + f'train_{k}.png', dpi=200)
            fig = plt.gcf()
            plt.close(fig)

    plt.close('all')


def get_model(model_path):

    with open(model_path + 'params.json', 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path
    model_chkpt = model_path + "checkpoint_best.save"
    model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
    model = get_multihead_model(*model_params)
    model.load_state_dict(torch.load(model_chkpt))
    model.eval()
    max_ing_quantities = np.loadtxt(model_path + 'max_ing_quantities.txt')
    def predict(ing_qs, aux_str):
        ing_qs /= max_ing_quantities
        input_model = torch.FloatTensor(ing_qs).reshape(1, -1)
        _, outputs, auxiliaries_str = model.forward(input_model)
        if isinstance(aux_str, str):
            return outputs[auxiliaries_str.index(aux_str)].detach().numpy()
        elif isinstance(aux_str, list):
            return [outputs[auxiliaries_str.index(aux)].detach().numpy() for aux in aux_str]
        else:
            raise ValueError
    return predict, params


def compute_expe_name_and_save_path(params):
    weights_str = '['
    for aux in params['auxiliaries_dict'].keys():
        weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
    weights_str = weights_str[:-2] + ']'
    save_path = params['save_path'] + params["trial_id"]
    save_path += f'_lr{params["lr"]}'
    save_path += f'_betavae{params["beta_vae"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_latentdim{params["latent_dim"]}'
    save_path += f'_hding{params["hidden_dims_ingredients"]}'
    save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
    save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
    save_path += f'_agg{params["agg"]}'
    save_path += f'_activ{params["activation"]}'
    save_path += f'_w{weights_str}'
    counter = 0
    while os.path.exists(save_path + f"_{counter}"):
        counter += 1
    save_path = save_path + f"_{counter}" + '/'
    params["save_path"] = save_path
    os.makedirs(save_path)
    os.makedirs(save_path + 'plots/')
    params['plot_path'] = save_path + 'plots/'
    print(f'logging to {save_path}')
    return params


if __name__ == '__main__':
    params = get_params()
    run_experiment(params)
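As a companion to the get_model defined just above, a hedged usage sketch (not part of the deleted file): this variant returns a predict closure rather than the raw model, and the closure normalizes its input array in place, hence the copies below. The run directory name is hypothetical; it must contain params.json, checkpoint_best.save and max_ing_quantities.txt as saved by the training script.

import numpy as np
# assumes get_model from the file above is in scope
predict, params = get_model('experiments/multihead_model/test_run_0/')  # hypothetical run dir
ing_qs = np.zeros(params['nb_ingredients'])        # raw ingredient quantities
# predict divides ing_qs by max_ing_quantities in place, so pass copies
glass_logits = predict(ing_qs.copy(), 'glasses')                 # one head
volume, taste = predict(ing_qs.copy(), ['volume', 'taste_reps'])  # several heads at once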
src/cocktails/representation_learning/simple_model.py
DELETED
@@ -1,54 +0,0 @@
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_activation(activation):
    if activation == 'tanh':
        activ = F.tanh
    elif activation == 'relu':
        activ = F.relu
    elif activation == 'mish':
        activ = F.mish
    elif activation == 'sigmoid':
        activ = torch.sigmoid
    elif activation == 'leakyrelu':
        activ = F.leaky_relu
    elif activation == 'exp':
        activ = torch.exp
    else:
        raise ValueError
    return activ


class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None):
        super(SimpleNet, self).__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.output_dim = output_dim
        dims = [input_dim] + hidden_dims + [output_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)
        if final_activ != None:
            self.final_activ = get_activation(final_activ)
            self.use_final_activ = True
        else:
            self.use_final_activ = False

    def forward(self, x):
        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = layer(x)
            if i_layer != self.n_layers - 1:
                x = self.activation(dropout(x))
        if self.use_final_activ: x = self.final_activ(x)
        return x
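A quick shape check of SimpleNet as defined above (not part of the deleted file); all sizes are arbitrary examples. Dropout and the activation are applied to hidden layers only, so the last linear layer emits raw outputs unless final_activ is set.

import torch
# assumes SimpleNet from the file above is in scope
net = SimpleNet(input_dim=10, hidden_dims=[32, 32], output_dim=4,
                activation='relu', dropout=0.1)
out = net(torch.randn(8, 10))   # batch of 8 inputs of dimension 10
assert out.shape == (8, 4)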
src/cocktails/representation_learning/vae_model.py
DELETED
@@ -1,238 +0,0 @@
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_activation(activation):
    if activation == 'tanh':
        activ = F.tanh
    elif activation == 'relu':
        activ = F.relu
    elif activation == 'mish':
        activ = F.mish
    elif activation == 'sigmoid':
        activ = F.sigmoid
    elif activation == 'leakyrelu':
        activ = F.leaky_relu
    elif activation == 'exp':
        activ = torch.exp
    else:
        raise ValueError
    return activ

class IngredientEncoder(nn.Module):
    def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
        super(IngredientEncoder, self).__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [input_dim] + hidden_dims + [deepset_latent_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = layer(x)
            if i_layer != self.n_layers - 1:
                x = self.activation(dropout(x))
        return x  # do not use dropout on last layer?

class DeepsetCocktailEncoder(nn.Module):
    def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                 hidden_dims_cocktail, latent_dim, aggregation, dropout):
        super(DeepsetCocktailEncoder, self).__init__()
        self.input_dim = input_dim  # dimension of ingredient representation + quantity
        self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout)  # encode each ingredient separately
        self.deepset_latent_dim = deepset_latent_dim  # dimension of the deepset aggregation
        self.aggregation = aggregation
        self.latent_dim = latent_dim
        # post aggregation network
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [deepset_latent_dim] + hidden_dims_cocktail
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
        self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
        self.softplus = nn.Softplus()

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, nb_ingredients, x):

        # reshape x in (batch size * nb ingredients, dim_ing_rep)
        batch_size = x.shape[0]
        all_ingredients = []
        for i in range(batch_size):
            for j in range(nb_ingredients[i]):
                all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1))
        x = torch.cat(all_ingredients, dim=0)
        # encode ingredients in parallel
        ingredients_encodings = self.ingredient_encoder(x)
        assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim)

        # aggregate
        x = []
        index_first = 0
        for i in range(batch_size):
            index_last = index_first + nb_ingredients[i]
            # aggregate
            if self.aggregation == 'sum':
                x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
            elif self.aggregation == 'mean':
                x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
            else:
                raise ValueError
            index_first = index_last
        x = torch.cat(x, dim=0)
        assert x.shape[0] == batch_size

        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        mean = self.FC_mean(x)
        logvar = self.FC_logvar(x)
        return mean, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dims, num_ingredients, activation, dropout, filter_output=None):
        super(Decoder, self).__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [latent_dim] + hidden_dims + [num_ingredients]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)
        self.filter = filter_output

    def forward(self, x, to_filter=False):
        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = layer(x)
            if i_layer != self.n_layers - 1:
                x = self.activation(dropout(x))
        if to_filter:
            x = self.filter(x)
        return x

class PredictorHead(nn.Module):
    def __init__(self, latent_dim, dim_output, final_activ):
        super(PredictorHead, self).__init__()
        self.linear = nn.Linear(latent_dim, dim_output)
        if final_activ != None:
            self.final_activ = get_activation(final_activ)
            self.use_final_activ = True
        else:
            self.use_final_activ = False

    def forward(self, x):
        x = self.linear(x)
        if self.use_final_activ: x = self.final_activ(x)
        return x


class VAEModel(nn.Module):
    def __init__(self, encoder, decoder, auxiliaries_dict):
        super(VAEModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = self.encoder.latent_dim
        self.auxiliaries_str = []
        self.auxiliaries = nn.ModuleList()
        for aux_str in sorted(auxiliaries_dict.keys()):
            if aux_str == 'taste_reps':
                self.taste_reps_decoder = PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ'])
            else:
                self.auxiliaries_str.append(aux_str)
                self.auxiliaries.append(PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ']))

    def reparameterization(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std).to(device)  # sampling epsilon
        z = mean + std * epsilon  # reparameterization trick
        return z

    def sample(self, n=1):
        z = torch.randn(size=(n, self.latent_dim))
        return self.decoder(z)

    def get_all_auxiliaries(self, x):
        return [aux(x) for aux in self.auxiliaries]

    def get_auxiliary(self, z, aux_str):
        if aux_str == 'taste_reps':
            return self.taste_reps_decoder(z)
        else:
            index = self.auxiliaries_str.index(aux_str)
            return self.auxiliaries[index](z)

    def forward_direct(self, x, aux_str=None, to_filter=False):
        mean, logvar = self.encoder(x)
        z = self.reparameterization(mean, logvar)  # takes exponential function (log var -> std)
        x_hat = self.decoder(mean, to_filter=to_filter)
        if aux_str is not None:
            return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str]
        else:
            return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str

    def forward(self, nb_ingredients, x, aux_str=None, to_filter=False):
        assert False
        mean, std = self.encoder(nb_ingredients, x)
        z = self.reparameterization(mean, std)  # takes exponential function (log var -> std)
        x_hat = self.decoder(mean, to_filter=to_filter)
        if aux_str is not None:
            return x_hat, z, mean, std, self.get_auxiliary(z, aux_str), [aux_str]
        else:
            return x_hat, z, mean, std, self.get_all_auxiliaries(z), self.auxiliaries_str


class SimpleEncoder(nn.Module):

    def __init__(self, input_dim, hidden_dims, latent_dim, activation, dropout):
        super(SimpleEncoder, self).__init__()
        self.latent_dim = latent_dim
        # post aggregation network
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [input_dim] + hidden_dims
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim)
        self.FC_logvar = nn.Linear(hidden_dims[-1], latent_dim)
        # self.softplus = nn.Softplus()

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        mean = self.FC_mean(x)
        logvar = self.FC_logvar(x)
        return mean, logvar

def get_vae_model(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                  hidden_dims_cocktail, hidden_dims_decoder, num_ingredients, latent_dim, aggregation, dropout, auxiliaries_dict,
                  filter_decoder_output):
    # encoder = DeepsetCocktailEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
    #                                  hidden_dims_cocktail, latent_dim, aggregation, dropout)
    encoder = SimpleEncoder(num_ingredients, hidden_dims_cocktail, latent_dim, activation, dropout)
    decoder = Decoder(latent_dim, hidden_dims_decoder, num_ingredients, activation, dropout, filter_output=filter_decoder_output)
    vae = VAEModel(encoder, decoder, auxiliaries_dict)
    return vae
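A minimal round trip through get_vae_model above (not part of the deleted file), assuming CPU execution; the dimensions and the single auxiliary head are arbitrary examples. Note the design choice visible in forward_direct: the decoder reconstructs from the posterior mean, while the sampled z is only fed to the auxiliary heads.

import torch
# assumes get_vae_model from the file above is in scope
aux = dict(glasses=dict(dim_output=5, final_activ=None))  # one illustrative head
vae = get_vae_model(input_dim=20, deepset_latent_dim=16, hidden_dims_ing=[32],
                    activation='relu', hidden_dims_cocktail=[32],
                    hidden_dims_decoder=[32], num_ingredients=20, latent_dim=8,
                    aggregation='mean', dropout=0., auxiliaries_dict=aux,
                    filter_decoder_output=None)
x = torch.randn(4, 20)  # 4 cocktails described by 20 ingredient quantities
x_hat, z, mean, logvar, aux_out, aux_names = vae.forward_direct(x)
assert x_hat.shape == (4, 20) and z.shape == (4, 8) and aux_names == ['glasses']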
src/cocktails/utilities/__init__.py
DELETED
File without changes
src/cocktails/utilities/analysis_utilities.py
DELETED
@@ -1,189 +0,0 @@
import numpy as np
import matplotlib.pyplot as plt

from src.cocktails.utilities.ingredients_utilities import ingredient_list, extract_ingredients, ingredients_per_type

color_codes = dict(ancestral='#000000',
                   spirit_forward='#2320D2',
                   duo='#6E20D2',
                   champagne_cocktail='#25FFCA',
                   complex_highball='#068F25',
                   simple_highball='#25FF57',
                   collins='#77FF96',
                   julep='#25B8FF',
                   simple_sour='#FBD756',
                   complex_sour='#DCAD07',
                   simple_sour_with_juice='#FF5033',
                   complex_sour_with_juice='#D42306',
                   # simple_sour_with_egg='#FF9C54',
                   # complex_sour_with_egg='#CF5700',
                   # almost_simple_sor='#FF5033',
                   # almost_sor='#D42306',
                   # almost_sor_with_egg='#D42306',
                   other='#9B9B9B'
                   )

def get_subcategories(data):
    subcategories = np.array(data['subcategory'])
    sub_categories_list = sorted(set(subcategories))
    subcat_count = dict(zip(sub_categories_list, [0] * len(sub_categories_list)))
    for sc in data['subcategory']:
        subcat_count[sc] += 1
    return subcategories, sub_categories_list, subcat_count

def get_ingredient_count(data):
    ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list)))
    for ing_str in data['ingredients_str']:
        ingredients, _ = extract_ingredients(ing_str)
        for ing in ingredients:
            ingredient_counts[ing] += 1
    return ingredient_counts

def compute_eucl_dist(a, b):
    return np.sqrt(np.sum((a - b)**2))

def recipe_contains(ingredients, stuff):
    if stuff in ingredient_list:
        return stuff in ingredients
    elif stuff == 'juice':
        return any(['juice' in ing and 'lemon' not in ing and 'lime' not in ing for ing in ingredients])
    elif stuff == 'bubbles':
        return any([ing in ['soda', 'tonic', 'cola', 'sparkling wine', 'ginger beer'] for ing in ingredients])
    elif stuff == 'acid':
        return any([ing in ['lemon juice', 'lime juice'] for ing in ingredients])
    elif stuff == 'vermouth':
        return any([ing in ingredients_per_type['vermouth'] for ing in ingredients])
    elif stuff == 'plain sweet':
        plain_sweet = ingredients_per_type['sweeteners']
        return any([ing in plain_sweet for ing in ingredients])
    elif stuff == 'sweet':
        sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc']
        return any([ing in sweet for ing in ingredients])
    elif stuff == 'spirit':
        return any([ing in ingredients_per_type['liquor'] for ing in ingredients])
    else:
        raise ValueError


def radar_factory(num_vars, frame='circle'):
    # adapted from a Stack Overflow post / the matplotlib gallery
    """
    Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle', 'polygon'}
        Shape of frame surrounding axes.

    """
    import numpy as np

    from matplotlib.patches import Circle, RegularPolygon
    from matplotlib.path import Path
    from matplotlib.projections.polar import PolarAxes
    from matplotlib.projections import register_projection
    from matplotlib.spines import Spine
    from matplotlib.transforms import Affine2D
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'
        # use 1 line segment to connect specified points
        RESOLUTION = 1

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.append(x, x[0])
                y = np.append(y, y[0])
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("Unknown value for 'frame': %s" % frame)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)
                return {'polar': spine}
            else:
                raise ValueError("Unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta

def plot_radar_cocktail(representation, labels_dim, labels_cocktails, save_path=None, to_show=False, to_save=False):
    assert to_show or to_save, 'either show or save'
    assert representation.ndim == 2
    n_data, dim_rep = representation.shape
    assert len(labels_cocktails) == n_data
    assert len(labels_dim) == dim_rep
    assert n_data <= 5, 'max 5 representation_analysis please'

    theta = radar_factory(dim_rep, frame='circle')

    fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(projection='radar'))
    fig.subplots_adjust(wspace=0.25, hspace=0.20, top=0.85, bottom=0.05)

    colors = ['b', 'r', 'g', 'm', 'y']
    # Plot the cases from the data on the same axes
    ax.set_rgrids([0.2, 0.4, 0.6, 0.8])
    for d, color in zip(representation, colors):
        ax.plot(theta, d, color=color)
    for d, color in zip(representation, colors):
        ax.fill(theta, d, facecolor=color, alpha=0.25)
    ax.set_varlabels(labels_dim)

    # add legend relative to top-left plot
    legend = ax.legend(labels_cocktails, loc=(0.9, .95),
                       labelspacing=0.1, fontsize='small')

    if to_save:
        plt.savefig(save_path, bbox_artists=(legend,), dpi=200)
    else:
        plt.show()
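An illustrative call to plot_radar_cocktail above (not part of the deleted file), with made-up representations for two cocktails over four dimensions; values are expected in roughly [0, 1] given the fixed radial grid.

import numpy as np
# assumes plot_radar_cocktail from the file above is in scope
reps = np.array([[0.2, 0.8, 0.5, 0.1],
                 [0.6, 0.3, 0.4, 0.7]])
plot_radar_cocktail(reps,
                    labels_dim=['sweet', 'sour', 'strength', 'fizz'],
                    labels_cocktails=['daiquiri', 'mojito'],
                    to_show=True)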
src/cocktails/utilities/cocktail_category_detection_utilities.py
DELETED
@@ -1,221 +0,0 @@
|
|
1 |
-
# The following functions check whether a cocktail belong to any of N categories
|
2 |
-
import numpy as np
|
3 |
-
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles, ingredients_per_type, ingredient2ingredient_id, extract_ingredients
|
4 |
-
|
5 |
-
|
6 |
-
def is_ancestral(n, ingredient_indexes, ingredients, quantities):
|
7 |
-
# ancestrals have a strong spirit and some sweetness from sugar, syrup or liqueurs, no citrus.
|
8 |
-
# absinthe can be added up to 3 dashes.
|
9 |
-
# Liqueurs are there to bring sweetness, thus must stay below 15ml (if not it's a duo)
|
10 |
-
if n['spirit'] > 0 and n['citrus'] == 0 and n['plain_sweet'] + n['liqueur'] <= 2:
|
11 |
-
if n['spirit'] > 1 and 'absinthe' in ingredients:
|
12 |
-
if quantities[ingredients.index('absinthe')] < 3:
|
13 |
-
pass
|
14 |
-
else:
|
15 |
-
return False
|
16 |
-
if n['sugar'] < 2 and n['liqueur'] < 3:
|
17 |
-
if n['all'] - n['spirit'] - n['sugar'] -n['syrup']- n['liqueur']- n['inconsequentials'] == 0:
|
18 |
-
if n['liqueur'] == 0:
|
19 |
-
return True
|
20 |
-
else:
|
21 |
-
q_liqueur = np.sum([quantities[i_ing]
|
22 |
-
for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients)))
|
23 |
-
if ingredient_profiles['type'][i_ind].lower() == 'liqueur'])
|
24 |
-
if q_liqueur <= 15:
|
25 |
-
return True
|
26 |
-
else:
|
27 |
-
return False
|
28 |
-
return False
|
29 |
-
|
30 |
-
|
31 |
-
def is_simple_sour(n, ingredient_indexes, ingredients, quantities):
|
32 |
-
# simple sours contain a citrus, at least 1 spirit and non-alcoholic sweetness
|
33 |
-
if n['citrus'] + n['coffee']> 0 and n['spirit'] > 0 and n['plain_sweet'] > 0 and n['juice'] == 0:
|
34 |
-
if n['all'] - n['citrus'] - n['coffee'] - n['spirit'] - n['plain_sweet'] - n['juice'] -n['egg'] - n['inconsequentials'] == 0:
|
35 |
-
return True
|
36 |
-
return False
|
37 |
-
|
38 |
-
def is_complex_sour(n, ingredient_indexes, ingredients, quantities):
|
39 |
-
# complex sours are simple sours that use alcoholic sweetness, at least in part
|
40 |
-
if n['citrus'] + n['coffee'] > 0 and n['all_sweet'] > 0 and n['juice'] == 0:
|
41 |
-
if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0:
|
42 |
-
if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0:
|
43 |
-
if n['all'] -n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \
|
44 |
-
- n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0:
|
45 |
-
return True
|
46 |
-
return False
|
47 |
def is_spirit_forward(n, ingredient_indexes, ingredients, quantities):
    # spirit forwards contain at least a spirit and a vermouth, no citrus. They can contain sweet (sugar, syrups, liqueurs).
    if n['spirit'] > 0 and n['citrus'] == 0 and n['vermouth'] > 0:
        if n['all'] - n['spirit'] - n['sugar'] - n['syrup'] - n['liqueur'] - n['egg'] - n['vermouth'] - n['inconsequentials'] == 0:
            return True
    return False

def is_duo(n, ingredient_indexes, ingredients, quantities):
    # duos are made of one spirit and one liqueur (above 15 ml; under that it's an ancestral), no citrus.
    if n['spirit'] >= 1 and n['citrus'] == 0 and n['sugar'] == 0 and n['liqueur'] > 0 and n['vermouth'] == 0:
        if n['all'] - n['spirit'] - n['sugar'] - n['liqueur'] - n['vermouth'] - n['inconsequentials'] == 0:
            q_liqueur = np.sum([quantities[i_ing]
                                for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients)))
                                if ingredient_profiles['type'][i_ind].lower() == 'liqueur'])
            if q_liqueur > 15:
                return True
            else:
                return False
    return False

def is_champagne_cocktail(n, ingredient_indexes, ingredients, quantities):
    # champagne cocktails contain sparkling wine
    if n['sparkling'] > 0:
        return True
    else:
        return False

def is_simple_highball(n, ingredient_indexes, ingredients, quantities):
    # simple highballs have one alcoholic ingredient and bubbles
    if n['alcoholic'] == 1 and n['bubbles'] > 0:
        if n['all'] - n['alcoholic'] - n['bubbles'] - n['inconsequentials'] == 0:
            return True
    return False

def is_complex_highball(n, ingredient_indexes, ingredients, quantities):
    # complex highballs have at least one alcoholic ingredient and bubbles (possibly alcoholic).
    # They also contain extra sugar in some form, or juice.
    if n['alcoholic'] > 0 and (n['bubbles'] + n['sparkling']) == 1 and n['juice'] + n['all_sweet'] + n['sugar_bubbles'] > 0:
        if n['all'] - n['spirit'] - n['bubbles'] - n['sparkling'] - n['citrus'] - n['juice'] - n['liqueur'] \
                - n['syrup'] - n['sugar'] - n['vermouth'] - n['egg'] - n['inconsequentials'] == 0:
            if not is_collins(n, ingredient_indexes, ingredients, quantities) and not is_simple_highball(n, ingredient_indexes, ingredients, quantities):
                return True
    return False

def is_collins(n, ingredient_indexes, ingredients, quantities):
    # collins are a particular kind of highball with sugar and citrus
    if n['alcoholic'] == 1 and n['bubbles'] == 1 and n['citrus'] > 0 and n['plain_sweet'] + n['sugar_bubbles'] > 0:
        if n['all'] - n['spirit'] - n['bubbles'] - n['citrus'] - n['sugar'] - n['inconsequentials'] == 0:
            return True
    return False

def is_julep(n, ingredient_indexes, ingredients, quantities):
    # juleps involve smashed mint, sugar and a spirit, no citrus.
    if 'mint' in ingredients and n['sugar'] > 0 and n['spirit'] > 0 and n['vermouth'] == 0 and n['citrus'] == 0:
        return True
    return False

def is_simple_sour_with_juice(n, ingredient_indexes, ingredients, quantities):
    # simple sours with juice are sours that also contain juice
    if n['juice'] > 0 and n['spirit'] > 0 and n['plain_sweet'] > 0:
        if n['all'] - n['citrus'] - n['coffee'] - n['juice'] - n['spirit'] - n['sugar'] - n['syrup'] - n['egg'] - n['inconsequentials'] == 0:
            return True
    return False

def is_complex_sour_with_juice(n, ingredient_indexes, ingredients, quantities):
    # complex sours with juice also accept liqueurs and vermouths as sweeteners
    if n['juice'] > 0 and n['all_sweet'] > 0:
        if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0:
            if 0 < n['vermouth'] + n['liqueur'] <= 2:
                if n['all'] - n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \
                        - n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0:
                    return True
    return False

is_sub_category = [is_ancestral, is_complex_sour, is_simple_sour, is_duo, is_champagne_cocktail,
                   is_spirit_forward, is_simple_highball, is_complex_highball, is_collins,
                   is_julep, is_simple_sour_with_juice, is_complex_sour_with_juice]
sub_categories = ['ancestral', 'complex_sour', 'simple_sour', 'duo', 'champagne_cocktail',
                  'spirit_forward', 'simple_highball', 'complex_highball', 'collins',
                  'julep', 'simple_sour_with_juice', 'complex_sour_with_juice']

# Compute the cocktail category as a function of ingredients and quantities; the name, when given,
# is used to check the match between name and category (e.g. 'XXX Collins' should be a collins).
# Category definitions are based on https://www.seriouseats.com/cocktail-style-guide-categories-of-cocktails-glossary-families-of-drinks
def find_cocktail_sub_category(ingredients, quantities, name=None):
    ingredient_indexes = [ingredient2ingredient_id[ing] for ing in ingredients]
    n_spirit = np.sum([ingredient_profiles['type'][i].lower() == 'liquor' for i in ingredient_indexes])
    n_citrus = np.sum([ingredient_profiles['type'][i].lower() == 'acid' for i in ingredient_indexes])
    n_sugar = np.sum([ingredient_profiles['ingredient'][i].lower() in ['double syrup', 'simple syrup', 'honey syrup'] for i in ingredient_indexes])
    plain_sweet = ingredients_per_type['sweeteners']
    all_sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc']
    n_plain_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in plain_sweet for i in ingredient_indexes])
    n_all_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in all_sweet for i in ingredient_indexes])
    n_sugar_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['cola', 'ginger beer', 'tonic'] for i in ingredient_indexes])
    n_juice = np.sum([ingredient_profiles['type'][i].lower() == 'juice' for i in ingredient_indexes])
    n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
    alcoholic = ingredients_per_type['liquor'] + ingredients_per_type['liqueur'] + ingredients_per_type['vermouth']
    n_alcoholic = np.sum([ingredient_profiles['ingredient'][i].lower() in alcoholic for i in ingredient_indexes])
    n_bitter = np.sum([ingredient_profiles['type'][i].lower() == 'bitters' for i in ingredient_indexes])
    n_egg = np.sum([ingredient_profiles['ingredient'][i].lower() == 'egg' for i in ingredient_indexes])
    n_vermouth = np.sum([ingredient_profiles['type'][i].lower() == 'vermouth' for i in ingredient_indexes])
    n_sparkling = np.sum([ingredient_profiles['ingredient'][i].lower() == 'sparkling wine' for i in ingredient_indexes])
    n_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['soda', 'tonic', 'cola', 'ginger beer'] for i in ingredient_indexes])
    n_syrup = np.sum([ingredient_profiles['ingredient'][i].lower() in ['grenadine', 'raspberry syrup'] for i in ingredient_indexes])
    n_coffee = np.sum([ingredient_profiles['ingredient'][i].lower() == 'espresso' for i in ingredient_indexes])
    inconsequentials = ['water', 'salt', 'angostura', 'orange bitters', 'mint']
    n_inconsequentials = np.sum([ingredient_profiles['ingredient'][i].lower() in inconsequentials for i in ingredient_indexes])
    n = dict(all=len(ingredients),
             inconsequentials=n_inconsequentials,
             sugar_bubbles=n_sugar_bubbles,
             bubbles=n_bubbles,
             plain_sweet=n_plain_sweet,
             all_sweet=n_all_sweet,
             coffee=n_coffee,
             alcoholic=n_alcoholic,
             syrup=n_syrup,
             sparkling=n_sparkling,
             sugar=n_sugar,
             spirit=n_spirit,
             citrus=n_citrus,
             juice=n_juice,
             liqueur=n_liqueur,
             bitter=n_bitter,
             egg=n_egg,
             vermouth=n_vermouth)

    sub_cats = [c for c, test_c in zip(sub_categories, is_sub_category) if test_c(n, ingredient_indexes, ingredients, quantities)]
    if name is not None:
        name = name.lower()
        keywords_to_test = ['julep', 'collins', 'highball', 'sour', 'champagne']
        for k in keywords_to_test:
            if k in name and not any([k in cat for cat in sub_cats]):
                print(k)
                for ing, q in zip(ingredients, quantities):
                    print(f'{ing}: {q} ml')
                print(n)
                break
    # disambiguate cocktails that match several categories
    if sorted(sub_cats) == ['champagne_cocktail', 'complex_highball']:
        sub_cats = ['champagne_cocktail']
    elif sorted(sub_cats) == ['collins', 'complex_highball']:
        sub_cats = ['collins']
    elif sorted(sub_cats) == ['champagne_cocktail', 'complex_highball', 'julep']:
        sub_cats = ['champagne_cocktail']
    elif sorted(sub_cats) == ['ancestral', 'julep']:
        sub_cats = ['julep']
    elif sorted(sub_cats) == ['complex_highball', 'julep']:
        sub_cats = ['complex_highball']
    elif sorted(sub_cats) == ['julep', 'simple_sour_with_juice']:
        sub_cats = ['simple_sour_with_juice']
    elif sorted(sub_cats) == ['complex_sour_with_juice', 'julep']:
        sub_cats = ['complex_sour_with_juice']
    if len(sub_cats) != 1:
        sub_cats = ['other']
    assert len(sub_cats) == 1, sub_cats
    return sub_cats[0], n

def get_cocktails_attributes(ing_strs):
    attributes = dict()
    cats = []
    for ing_str in ing_strs:
        ingredients, quantities = extract_ingredients(ing_str)
        cat, atts = find_cocktail_sub_category(ingredients, quantities)
        for k in atts.keys():
            if k not in attributes.keys():
                attributes[k] = [atts[k]]
            else:
                attributes[k].append(atts[k])
        cats.append(cat)
    return cats, attributes
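
A minimal usage sketch of the classifier above (not part of the deleted file; the ingredient names are assumptions and must exist in the vocabulary loaded by ingredients_utilities):

# Hypothetical example: classify a daiquiri-style recipe (quantities in mL).
ingredients = ['rum', 'lime juice', 'double syrup']   # assumed to be in the ingredient list
quantities = [60, 25, 10]
sub_cat, counts = find_cocktail_sub_category(ingredients, quantities, name='test daiquiri')
print(sub_cat)                                        # expected: a sour-type category
print(counts['spirit'], counts['citrus'], counts['plain_sweet'])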
src/cocktails/utilities/cocktail_generation_utilities/__init__.py
DELETED
File without changes
src/cocktails/utilities/cocktail_generation_utilities/individual.py
DELETED
@@ -1,587 +0,0 @@
from src.cocktails.utilities.ingredients_utilities import get_ingredients_info, format_ingredients, extract_ingredients, ingredients_per_type, bubble_ingredients
import numpy as np
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
from src.cocktails.utilities.cocktail_utilities import get_cocktail_rep, get_profile, get_bunch_of_rep_keys
from src.cocktails.utilities.glass_and_volume_utilities import glass_volume
from src.cocktails.pipeline.get_cocktail2affective_cluster import get_cocktail2affective_cluster
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, REPO_PATH, COCKTAIL_REP_CHKPT_PATH, RECIPE2FEATURES_PATH
# the original file imported get_model from both representation_learning.run and
# representation_learning.run_without_vae; only the second import was effective, so it is the one kept.
from src.cocktails.representation_learning.run_without_vae import get_model
from src.cocktails.utilities.cocktail_category_detection_utilities import find_cocktail_sub_category

import pandas as pd
import torch
import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'

density_ingredients = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'density_ingredients.txt')
max_ingredients, ingredient_list, ind_alcohol = get_ingredients_info()
min_ingredients = 2
factor_max = 1.2  # generated recipes can go up to 1.2 times the max quantity of the ingredient found in the dataset

prep_model = get_model(RECIPE2FEATURES_PATH + 'multi_predictor/')[0]

all_rep_path = FULL_COCKTAIL_REP_PATH
all_reps = np.loadtxt(all_rep_path)
experiment_dir = REPO_PATH + '/experiments/cocktails/'
rep_keys = get_bunch_of_rep_keys()['custom']
dict_weights_mse_computation = {'end volume': .1, 'end sour': 2, 'end sweet': 2, 'end booze': 4, 'end bitter': 2, 'end fruit': 1, 'end herb': 1,
                                'end complex': 1, 'end spicy': 5, 'end oaky': 1, 'end fizzy': 10, 'end colorful': 1, 'end eggy': 10}
assert sorted(dict_weights_mse_computation.keys()) == sorted(rep_keys)
weights_mse_computation = np.array([dict_weights_mse_computation[k] for k in rep_keys])
weights_mse_computation /= weights_mse_computation.sum()
data = pd.read_csv(COCKTAILS_CSV_DATA)
preparation_list = sorted(set(data['category']))
glasses_list = sorted(set(data['glass']))

# performance weight as a function of the number of ingredients
weights_perf_n_ing = {2: 0.71, 3: 0.81, 4: 0.93, 5: 1., 6: 1.03, 7: 1.08, 8: 1.05}
# weights_perf_n_ing = {2: 0.75, 3: 0.8, 4: 0.95, 5: 1.05, 6: 1.05, 7: 1.05, 8: 1.05}
min_ingredients_quantities_when_present = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'ingredients_min_quantities_when_present.txt')
min_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'ingredients_min_quantities.txt')
max_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'ingredients_max_quantities.txt')
min_cocktail_rep, max_cocktail_rep = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'cocktail_minmax_dim13_customkeys.txt')
distrib_nb_ings_2_8 = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'distrib_nb_ing.txt')[2:]

def normalize_cocktail(cocktail_rep):
    return ((cocktail_rep - min_cocktail_rep) / (max_cocktail_rep - min_cocktail_rep) - 0.5) * 2

def denormalize_cocktail(cocktail_rep):
    return (cocktail_rep / 2 + 0.5) * (max_cocktail_rep - min_cocktail_rep) + min_cocktail_rep

def normalize_ingredient_q_rep(ingredients_q):
    return (ingredients_q - min_ingredients_quantities_when_present) / (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present)

COCKTAIL_REPS = normalize_cocktail(np.array([data[k] for k in rep_keys]).transpose())
assert np.abs(COCKTAIL_REPS - all_reps).sum() < 1e-8

cocktail2affective_cluster = get_cocktail2affective_cluster()

original_affective_keys = get_bunch_of_rep_keys()['affective']
def sigmoid(x, shift, beta):
    return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2

def get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(cocktail_rep):
    indexes = np.array([rep_keys.index(key) for key in original_affective_keys])
    cocktail_rep = cocktail_rep[indexes]
    cocktail_rep[0] = sigmoid(cocktail_rep[0], shift=0.05, beta=4)
    cocktail_rep[1] = sigmoid(cocktail_rep[1], shift=0.3, beta=5)
    cocktail_rep[2] = sigmoid(cocktail_rep[2], shift=0.15, beta=3)
    cocktail_rep[3] = sigmoid(cocktail_rep[3], shift=0.9, beta=20)
    cocktail_rep[4] = sigmoid(cocktail_rep[4], shift=0, beta=4)
    cocktail_rep[5] = sigmoid(cocktail_rep[5], shift=0.2, beta=3)
    cocktail_rep[6] = sigmoid(cocktail_rep[6], shift=0.5, beta=5)
    cocktail_rep[7] = sigmoid(cocktail_rep[7], shift=0.2, beta=6)
    return cocktail_rep

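A small illustrative check of the min-max normalization above (not from the deleted file): normalize_cocktail maps each dimension into [-1, 1] and denormalize_cocktail inverts it exactly.

# Hypothetical round-trip check of the normalization helpers.
raw = min_cocktail_rep + 0.25 * (max_cocktail_rep - min_cocktail_rep)  # 25% of the range
normed = normalize_cocktail(raw)        # every dimension equals -0.5 here
restored = denormalize_cocktail(normed)
assert np.allclose(raw, restored)
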
class IndividualCocktail():
    def __init__(self, pop_params, target, target_affective_cluster, genes_presence=None, genes_quantity=None,
                 compute_perf=True, known_target_dict=None, run_hard_check=False):

        self.pop_params = pop_params
        self.n_genes = len(ingredient_list)
        self.max_ingredients = max_ingredients
        self.min_ingredients = min_ingredients
        self.mutation_params = pop_params['mutation_params']
        self.dist = pop_params['dist']
        self.target = target
        self.is_known = known_target_dict is not None
        self.known_target_dict = known_target_dict
        self.perf = None
        self.cocktail_rep = None
        self.affective_cluster = None
        self.target_affective_cluster = target_affective_cluster
        self.ing_list = np.array(ingredient_list)
        self.ing_set = set(ingredient_list)

        self.ing_ids_per_cat = dict(bubbles=set(self.get_ingredients_ids_from_list(bubble_ingredients)),
                                    liquor=set(self.get_ingredients_ids_from_list(ingredients_per_type['liquor'])),
                                    liqueur=set(self.get_ingredients_ids_from_list(ingredients_per_type['liqueur'])),
                                    citrus=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'] + ['orange juice'])),
                                    alcohol=set(ind_alcohol),
                                    sweeteners=set(self.get_ingredients_ids_from_list(ingredients_per_type['sweeteners'])),
                                    vermouth=set(self.get_ingredients_ids_from_list(ingredients_per_type['vermouth'])),
                                    bitters=set(self.get_ingredients_ids_from_list(ingredients_per_type['bitters'])),
                                    juice=set(self.get_ingredients_ids_from_list(ingredients_per_type['juice'])),
                                    acid=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'])),
                                    egg=set(self.get_ingredients_ids_from_list(['egg']))
                                    )

        if genes_presence is not None:
            assert len(genes_presence) == self.n_genes
            assert len(genes_quantity) == self.n_genes
            self.genes_presence = genes_presence
            self.genes_quantity = genes_quantity
            if compute_perf:
                self.compute_cocktail_rep()
                self.compute_perf()
        else:
            self.sample_initial_genes()
            self.compute_cocktail_rep()
            # self.make_recipe_fit_the_glass()
            self.compute_perf()

    # # # # # # # # # # # # # # # # # # # # # # # #
    # Sample initial genes with smart rules
    # # # # # # # # # # # # # # # # # # # # # # # #

    def sample_initial_genes(self):
        # rules:
        # - between min_ingredients and max_ingredients
        # - at most one type of bubbles
        # - at least one alcohol
        # - no egg without lime or lemon
        # - at most two liqueurs
        # - at most three liquors
        # - at most two sweeteners
        self.genes_quantity = np.random.uniform(0, 1, size=self.n_genes)  # holds quantities for each ingredient
        n_ingredients = np.random.choice(np.arange(min_ingredients, max_ingredients + 1), p=distrib_nb_ings_2_8)
        self.genes_presence = np.zeros(self.n_genes)
        # add one alcohol
        self.genes_presence[np.random.choice(ind_alcohol)] = 1
        while self.get_ing_count() < n_ingredients:
            candidate_ids = self.get_candidate_ingredients_ids(self.genes_presence)
            probas = density_ingredients[candidate_ids] / np.sum(density_ingredients[candidate_ids])
            self.genes_presence[np.random.choice(candidate_ids, p=probas)] = 1

    def get_candidate_ingredients_ids(self, genes_presence):
        candidates = set(np.argwhere(genes_presence == 0).flatten())
        present_ids = set(np.argwhere(genes_presence == 1).flatten())

        if self.count_in_genes(present_ids, 'bubbles') >= 1:  # at most one type of bubbles
            candidates = candidates - self.ing_ids_per_cat['bubbles']
        if self.count_in_genes(present_ids, 'liquor') >= 3:  # at most three liquors
            candidates = candidates - self.ing_ids_per_cat['liquor']
        if self.count_in_genes(present_ids, 'liqueur') >= 2:  # at most two liqueurs
            candidates = candidates - self.ing_ids_per_cat['liqueur']
        if self.count_in_genes(present_ids, 'sweeteners') >= 2:  # at most two sweeteners
            candidates = candidates - self.ing_ids_per_cat['sweeteners']
        if self.count_in_genes(present_ids, 'citrus') == 0:  # no egg without lime or lemon
            candidates = candidates - self.ing_ids_per_cat['egg']
        return np.array(sorted(candidates))

    def count_in_genes(self, present_ids, keyword):
        if keyword == 'citrus': return len(present_ids & self.ing_ids_per_cat['citrus'])
        elif keyword == 'bubbles': return len(present_ids & self.ing_ids_per_cat['bubbles'])
        elif keyword == 'liquor': return len(present_ids & self.ing_ids_per_cat['liquor'])
        elif keyword == 'liqueur': return len(present_ids & self.ing_ids_per_cat['liqueur'])
        elif keyword == 'alcohol': return len(present_ids & self.ing_ids_per_cat['alcohol'])
        elif keyword == 'sweeteners': return len(present_ids & self.ing_ids_per_cat['sweeteners'])
        else: raise ValueError

    def get_ingredients_ids_from_list(self, ing_list):
        return [ingredient_list.index(ing) for ing in ing_list]

    def get_ing_count(self):
        return np.sum(self.genes_presence)

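For readers skimming the genome encoding: a recipe is two vectors of length n_genes, a binary presence mask and a normalized quantity in [0, 1] per ingredient. A hypothetical sketch (the ingredient names are assumptions, not from the original file):

# Hypothetical sketch of the genome encoding used above.
n_genes = len(ingredient_list)
genes_presence = np.zeros(n_genes)
genes_quantity = np.random.uniform(0, 1, size=n_genes)
genes_presence[ingredient_list.index('gin')] = 1             # assumes 'gin' is in ingredient_list
genes_presence[ingredient_list.index('sweet vermouth')] = 1  # assumes 'sweet vermouth' is too
# real quantities in mL are recovered per present ingredient via
# genes_quantity * (factor_max * max_q - min_q_when_present) + min_q_when_present
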
    # # # # # # # # # # # # # # # # # # # # # # # #
    # Compute cocktail representations
    # # # # # # # # # # # # # # # # # # # # # # # #

    def get_absent_ing(self):
        return np.argwhere(self.genes_presence == 0).flatten()

    def get_present_ing(self):
        return np.argwhere(self.genes_presence == 1).flatten()

    def get_ingredient_quantities(self):
        # unnormalize quantities to get real ones
        return (self.genes_quantity * (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present) + min_ingredients_quantities_when_present) * self.genes_presence

    def get_ing_and_q_from_genes(self):
        present_ings = self.get_present_ing()
        ing_quantities = self.get_ingredient_quantities()
        ingredients, quantities = [], []
        for i_ing in present_ings:
            ingredients.append(ingredient_list[i_ing])
            quantities.append(ing_quantities[i_ing])
        return ingredients, quantities, ing_quantities

    def compute_cocktail_rep(self):
        # only call when genes have changed
        ingredients, quantities, ing_quantities = self.get_ing_and_q_from_genes()
        # compute the cocktail category
        self.category = find_cocktail_sub_category(ingredients, quantities)[0]
        self.prep_type = self.get_prep_type(ing_quantities)
        cocktail_rep, self.end_volume, self.end_alcohol = get_cocktail_rep(self.prep_type, ingredients, quantities, keys=rep_keys[1:])  # volume is added later
        self.cocktail_rep = normalize_cocktail(cocktail_rep)
        self.glass = self.get_glass_type(ing_quantities)
        if self.is_known:
            assert np.abs(self.cocktail_rep - self.target).sum() < 1e-6
        return self.cocktail_rep

    def get_prep_type(self, quantities=None):
        if self.is_known: return self.known_target_dict['prep_type']
        else:
            if quantities is None:
                quantities = self.get_ingredient_quantities()
            if quantities[ingredient_list.index('egg')] > 0:
                prep_cat = 'egg_shaken'
            elif self.category in ['spirit_forward', 'simple_sour_with_juice', 'julep', 'duo', 'ancestral', 'complex_sour_with_juice']:
                # use hard-coded rules for the most obvious cases, determined with the correlations_glass_cat_prep script
                if self.category in ['ancestral', 'spirit_forward', 'duo']:
                    prep_cat = 'stirred'
                elif self.category in ['complex_sour_with_juice', 'julep', 'simple_sour_with_juice']:
                    prep_cat = 'shaken'
                else:
                    raise ValueError
            else:
                output = prep_model(quantities, aux_str='prep_type').flatten()
                output[preparation_list.index('egg_shaken')] = -np.inf
                prep_cat = preparation_list[np.argmax(output)]
            return prep_cat

    def get_glass_type(self, quantities=None):
        if self.is_known: return self.known_target_dict['glass']
        else:
            if self.category in ['collins', 'complex_highball', 'simple_highball', 'champagne_cocktail', 'complex_sour']:
                # use hard-coded rules for the most obvious cases, determined with the correlations_glass_cat_prep script
                if self.category in ['collins', 'complex_highball', 'simple_highball']:
                    glass = 'collins'
                elif self.category in ['champagne_cocktail', 'complex_sour']:
                    glass = 'coupe'
            else:
                if quantities is None:
                    quantities = self.get_ingredient_quantities()
                output = prep_model(quantities, aux_str='glasses').flatten()
                glass = glasses_list[np.argmax(output)]
            return glass

    # # # # # # # # # # # # # # # # # # # # # # # #
    # Adapt recipe to fit the glass
    # # # # # # # # # # # # # # # # # # # # # # # #

    def is_too_large_for_glass(self):
        return self.end_volume > glass_volume[self.glass] * 0.80

    def is_too_small_for_glass(self):
        return self.end_volume < glass_volume[self.glass] * 0.3

    def scale_ing_quantities(self, present_ings, factor):
        qs = self.get_ingredient_quantities().copy()
        qs[present_ings] *= factor
        self.set_genes_from_quantities(present_ings, qs)

    def set_genes_from_quantities(self, present_ings, quantities):
        genes_quantity = np.clip((quantities - min_ingredients_quantities_when_present) /
                                 (factor_max * max_ingredients_quantities - min_ingredients_quantities_when_present), 0, 1)
        self.genes_quantity[present_ings] = genes_quantity[present_ings]

    def make_recipe_fit_the_glass(self):
        # check for citrus; if there is none, remove the egg
        present_ids = np.argwhere(self.genes_presence == 1).flatten()
        ing_list = self.ing_list[present_ids]
        present_ids = set(present_ids)
        if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list:
            if self.genes_presence.sum() > 2:
                i_egg = ingredient_list.index('egg')
                self.genes_presence[i_egg] = 0.
                self.compute_cocktail_rep()

        # rescale quantities until the end volume fits the glass, ten trials at most
        i_trial = 0
        present_ings = self.get_present_ing()
        while self.is_too_large_for_glass():
            i_trial += 1
            end_volume = self.end_volume
            desired_volume = glass_volume[self.glass] * 0.80
            ratio = desired_volume / end_volume
            self.scale_ing_quantities(present_ings, factor=ratio)
            self.compute_cocktail_rep()
            if end_volume == self.end_volume: break
            if i_trial == 10: break
        while self.is_too_small_for_glass():
            i_trial += 1
            end_volume = self.end_volume
            desired_volume = glass_volume[self.glass] * 0.80
            ratio = desired_volume / end_volume
            self.scale_ing_quantities(present_ings, factor=ratio)
            self.compute_cocktail_rep()
            if end_volume == self.end_volume: break
            if i_trial == 10: break

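A worked example of the fit rule above (illustrative numbers, not from the deleted file):

# Scaling a too-large recipe down to 80% of the glass capacity.
end_volume = 320.0                            # mL after dilution
glass_capacity = 250.0                        # mL, e.g. a collins glass
ratio = glass_capacity * 0.80 / end_volume    # = 0.625
# scale_ing_quantities multiplies every present quantity by this ratio, then
# compute_cocktail_rep() re-estimates dilution, so one pass may not suffice;
# hence the bounded retry loops above.
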
    # # # # # # # # # # # # # # # # # # # # # # # #
    # Compute performance
    # # # # # # # # # # # # # # # # # # # # # # # #

    def passes_checks(self):
        present_ids = np.argwhere(self.genes_presence == 1).flatten()
        # ing_list = self.ing_list[present_ids]
        present_ids = set(present_ids)
        if len(present_ids) < 2 or len(present_ids) > 8: return False
        # if self.is_too_large_for_glass(): return False
        # if self.is_too_small_for_glass(): return False
        if self.end_alcohol < 0.05 or self.end_alcohol > 0.31: return False
        if self.count_in_genes(present_ids, 'sweeteners') > 2: return False
        if self.count_in_genes(present_ids, 'liqueur') > 2: return False
        if self.count_in_genes(present_ids, 'liquor') > 3: return False
        # if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list: return False
        if self.count_in_genes(present_ids, 'bubbles') > 1: return False
        else: return True

    def get_affective_cluster(self):
        cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(self.cocktail_rep)
        self.affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0]
        return self.affective_cluster

    def does_affective_cluster_match(self):
        return True  # self.get_affective_cluster() == self.target_affective_cluster

    def compute_perf(self):
        if not self.passes_checks(): self.perf = -100
        else:
            if self.dist == 'mse':
                # weighted RMSE to the target representation, negated so that higher is better
                # self.perf = - np.sqrt(((self.cocktail_rep - self.target)**2).mean())
                self.perf = - np.sqrt(np.dot((self.cocktail_rep - self.target) ** 2, weights_mse_computation))
                self.perf *= weights_perf_n_ing[int(self.genes_presence.sum())]
                if not self.does_affective_cluster_match():
                    self.perf *= 2
            else: raise NotImplementedError

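An illustrative computation of the fitness above (made-up numbers, 13 dimensions reduced to 2 for brevity):

# errors  = cocktail_rep - target = [0.2, -0.1]; weights = [0.8, 0.2]
# perf    = -sqrt(0.8 * 0.2**2 + 0.2 * 0.1**2) = -sqrt(0.034) ≈ -0.184
# with 4 ingredients, perf *= 0.93  →  ≈ -0.171
# a recipe failing passes_checks() would instead be pinned at -100.
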
    # # # # # # # # # # # # # # # # # # # # # # # #
    # Mutations and crossover
    # # # # # # # # # # # # # # # # # # # # # # # #

    def get_child(self):
        time_dict = dict()
        init_time = time.time()
        child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster,
                                   target=self.target, genes_presence=self.genes_presence.copy(),
                                   genes_quantity=self.genes_quantity.copy(), compute_perf=False)
        time_dict[' asexual child creation'] = [time.time() - init_time]
        init_time = time.time()
        this_time_dict = child.mutate()
        time_dict = self.update_time_dict(time_dict, this_time_dict)
        time_dict[' asexual child mutation'] = [time.time() - init_time]
        return child, time_dict

    def get_child_with(self, other_parent):
        time_dict = dict()
        init_time = time.time()
        new_genes_presence = np.zeros(self.n_genes)
        present_ing = self.get_present_ing()
        other_present_ing = other_parent.get_present_ing()
        new_genes_quantity = np.random.uniform(0, 1, size=self.n_genes)
        shared_ingredients = sorted(set(present_ing) & set(other_present_ing))
        unique_ingredients_one = sorted(set(present_ing) - set(other_present_ing))
        unique_ingredients_two = sorted(set(other_present_ing) - set(present_ing))
        for i in shared_ingredients:
            new_genes_presence[i] = 1
            new_genes_quantity[i] = (self.genes_quantity[i] + other_parent.genes_quantity[i]) / 2
        time_dict[' crossover child creation'] = [time.time() - init_time]
        init_time = time.time()
        # add one alcohol if none is present
        if len(set(np.argwhere(new_genes_presence == 1).flatten()).intersection(ind_alcohol)) == 0:
            new_genes_presence[np.random.choice(ind_alcohol)] = 1
        # up to here, we respect the constraints (assuming both parents do).
        candidate_genes = np.array(unique_ingredients_one + unique_ingredients_two)
        candidate_quantities = np.array([self.genes_quantity[i] for i in unique_ingredients_one] + [other_parent.genes_quantity[i] for i in unique_ingredients_two])
        indexes = np.arange(len(candidate_genes))
        np.random.shuffle(indexes)
        candidate_genes = candidate_genes[indexes]
        candidate_quantities = candidate_quantities[indexes]
        time_dict[' crossover prepare selection'] = [time.time() - init_time]
        init_time = time.time()
        # now let's try to add each of them while respecting the constraints
        for i in range(len(indexes)):
            if np.random.rand() < 0.5 or np.sum(new_genes_presence) < self.min_ingredients:  # only try to add one ingredient in two
                ing_id = candidate_genes[i]
                q = candidate_quantities[i]
                new_genes_presence[ing_id] = 1
                new_genes_quantity[ing_id] = q
                if np.sum(new_genes_presence) == self.max_ingredients:
                    break
        time_dict[' crossover do selection'] = [time.time() - init_time]
        init_time = time.time()
        # create the new child
        child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster, target=self.target,
                                   genes_presence=new_genes_presence.copy(), genes_quantity=new_genes_quantity.copy(), compute_perf=False)
        time_dict[' crossover create child'] = [time.time() - init_time]
        init_time = time.time()
        this_time_dict = child.mutate()
        time_dict = self.update_time_dict(time_dict, this_time_dict)
        time_dict[' crossover child mutation'] = [time.time() - init_time]
        return child, time_dict

    def mutate(self):
        time_dict = dict()
        # remove an ingredient
        init_time = time.time()
        present_ids = set(np.argwhere(self.genes_presence == 1).flatten())

        if np.random.rand() < self.mutation_params['p_remove_ing']:
            if self.get_ing_count() > self.min_ingredients:
                candidate_ings = self.get_present_ing()
                if self.count_in_genes(present_ids, 'alcohol') == 1:  # make sure we keep at least one alcohol
                    candidate_ings = np.array(sorted(set(candidate_ings) - set(ind_alcohol)))
                index_to_remove = np.random.choice(candidate_ings)
                self.genes_presence[index_to_remove] = 0
        time_dict[' mutation remove ing'] = [time.time() - init_time]
        init_time = time.time()
        # add an ingredient
        if np.random.rand() < self.mutation_params['p_add_ing']:
            if self.get_ing_count() < self.max_ingredients:
                candidate_ings = self.get_candidate_ingredients_ids(self.genes_presence.copy())
                index_to_add = np.random.choice(candidate_ings, p=density_ingredients[candidate_ings] / np.sum(density_ingredients[candidate_ings]))
                self.genes_presence[index_to_add] = 1
        time_dict[' mutation add ing'] = [time.time() - init_time]

        init_time = time.time()
        # replace an ingredient by another from the same family
        if np.random.rand() < self.mutation_params['p_switch_ing']:
            i = np.random.choice(self.get_present_ing())
            ing_str = ingredient_list[i]
            if ing_str not in ['sparkling wine', 'orange juice']:
                if ing_str in bubble_ingredients:
                    candidates_ids = np.array(sorted(self.ing_ids_per_cat['bubbles'] - set([i])))
                    new_bubble = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids]))
                    self.genes_presence[i] = 0
                    self.genes_presence[new_bubble] = 1
                    self.genes_quantity[new_bubble] = self.genes_quantity[i]  # copy quantity
                categories = ['acid', 'bitters', 'juice', 'liqueur', 'liquor', 'sweeteners', 'vermouth']
                for cat in categories:
                    if ing_str in ingredients_per_type[cat]:
                        present_ings = self.get_present_ing()
                        candidates_ids = np.array(sorted(self.ing_ids_per_cat[cat] - set([i]) - set(present_ings)))
                        if len(candidates_ids) > 0:
                            replacing_ing = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids]))
                            self.genes_presence[i] = 0
                            self.genes_presence[replacing_ing] = 1
                            self.genes_quantity[replacing_ing] = self.genes_quantity[i]  # copy quantity
                        break
        time_dict[' mutation switch ing'] = [time.time() - init_time]
        init_time = time.time()
        # add noise on ingredient quantities
        for i in self.get_present_ing():
            if np.random.rand() < self.mutation_params['p_change_q']:
                self.genes_quantity[i] += np.random.randn() * self.mutation_params['delta_change_q']
        self.genes_quantity = np.clip(self.genes_quantity, 0, 1)
        time_dict[' mutation change quantity'] = [time.time() - init_time]

        init_time = time.time()
        self.compute_cocktail_rep()
        time_dict[' mutation compute cocktail rep'] = [time.time() - init_time]
        init_time = time.time()
        # self.make_recipe_fit_the_glass()
        time_dict[' mutation check glass fit'] = [time.time() - init_time]
        init_time = time.time()
        self.compute_perf()
        time_dict[' mutation compute perf'] = [time.time() - init_time]
        return time_dict

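The mutation operators are driven by pop_params['mutation_params']. A plausible configuration follows; only the keys are grounded in the code above, the values are made up:

# Hypothetical mutation configuration.
mutation_params = dict(p_remove_ing=0.1,    # probability of dropping one ingredient
                       p_add_ing=0.1,       # probability of adding one ingredient
                       p_switch_ing=0.2,    # probability of an in-family substitution
                       p_change_q=0.3,      # per-ingredient probability of quantity noise
                       delta_change_q=0.1,  # std of the Gaussian noise on normalized quantities
                       asexual_rep=True, crossover=True)
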
    def update_time_dict(self, main_dict, new_dict):
        for k in new_dict.keys():
            if k in main_dict.keys():
                main_dict[k].append(np.sum(new_dict[k]))
            else:
                main_dict[k] = [np.sum(new_dict[k])]
        return main_dict

    # # # # # # # # # # # # # # # # # # # # # # # #
    # Get recipe and print
    # # # # # # # # # # # # # # # # # # # # # # # #

    def get_recipe(self, unit='mL', name=None):
        ing_quantities = self.get_ingredient_quantities()
        ingredients, quantities = [], []
        for i_ing, q_ing in enumerate(ing_quantities):
            if q_ing > 0.8:  # skip negligible quantities (below 0.8 mL)
                ingredients.append(ingredient_list[i_ing])
                quantities.append(round(q_ing))
        recipe_str = format_ingredients(ingredients, quantities)
        recipe_str_readable = print_recipe(unit=unit, ingredient_str=recipe_str, name=name, to_print=False)
        return ingredients, quantities, recipe_str, recipe_str_readable

    def get_instructions(self):
        ing_quantities = self.get_ingredient_quantities()
        ingredients, quantities = [], []
        for i_ing, q_ing in enumerate(ing_quantities):
            if q_ing > 0.8:
                ingredients.append(ingredient_list[i_ing])
                quantities.append(round(q_ing))
        str_out = 'Instructions:\n '

        if 'mint' in ingredients:
            i_mint = ingredients.index('mint')
            n_leaves = quantities[i_mint]
            str_out += f'Add {n_leaves} mint leaves to a shaker, followed by an ice cube.\n Muddle the mint and ice together with a muddler.\n '
        bubbles = ['sparkling wine', 'tonic', 'soda', 'ginger beer']
        other_ings = [ing for ing in ingredients if ing not in ['egg', 'angostura', 'orange bitters'] + bubbles]

        if self.prep_type == 'built':
            str_out += 'Add a large ice cube in the glass.\n '
        # list the ingredients to pour
        str_out += 'Pour'
        for i, ing in enumerate(other_ings):
            if i == len(other_ings) - 2:
                str_out += f' {ing} and'
            elif i == len(other_ings) - 1:
                str_out += f' {ing}'
            else:
                str_out += f' {ing},'

        if self.prep_type in ['built'] and 'mint' not in ingredients:
            str_out += ' into the glass.\n '
        else:
            str_out += ' into the shaker.\n '

        if self.prep_type == 'egg_shaken' and 'egg' in ingredients:
            str_out += 'Add the egg white.\n Dry-shake for 15s (without ice), then fill with ice and shake for another 15s.\n Serve into the glass through a strainer.\n '
        elif 'shaken' in self.prep_type:
            str_out += 'Fill with ice and shake for 15s.\n Serve into the glass through a strainer.\n '
        elif self.prep_type == 'stirred':
            str_out += 'Add ice and stir the cocktail with a spoon for 15s.\n Serve into the glass through a strainer.\n '
        elif self.prep_type == 'built':
            str_out += 'Stir two turns with a spoon.\n '

        bubble_ing = [ing for ing in ingredients if ing in bubbles]
        if len(bubble_ing) > 0:
            str_out += 'Top up with '
            for ing in bubble_ing:
                str_out += f'{ing}, '
            str_out = str_out[:-2] + '.\n '
        bitter_ing = [ing for ing in ingredients if ing in ['angostura', 'orange bitters']]
        if len(bitter_ing) > 0:
            if len(bitter_ing) == 1:
                q = quantities[ingredients.index(bitter_ing[0])]
                n_dashes = max(1, int(q / 0.6))
                str_out += f'Add {n_dashes} dash'
                if n_dashes > 1:
                    str_out += 'es'
                str_out += f' of {bitter_ing[0]}.\n '
            elif len(bitter_ing) == 2:
                q = quantities[ingredients.index(bitter_ing[0])]
                n_dashes = max(1, int(q / 0.6))
                str_out += f'Add {n_dashes} dash'
                if n_dashes > 1:
                    str_out += 'es'
                str_out += f' of {bitter_ing[0]} and '
                q = quantities[ingredients.index(bitter_ing[1])]
                n_dashes = max(1, int(q / 0.6))
                str_out += f'{n_dashes} dash'
                if n_dashes > 1:
                    str_out += 'es'
                str_out += f' of {bitter_ing[1]}.\n '
        str_out += 'Enjoy!'
        return str_out

    def print_recipe(self, name=None):
        # get_recipe's first positional argument is the unit, so the name must be passed by keyword
        print(self.get_recipe(name=name)[3])
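
An end-to-end sketch of how an individual might be used (hypothetical pop_params; the real values live in the calling code, which is not part of this file):

# Hypothetical driver; assumes a normalized 13-dim target and made-up parameters.
pop_params = dict(mutation_params=dict(p_remove_ing=0.1, p_add_ing=0.1, p_switch_ing=0.2,
                                       p_change_q=0.3, delta_change_q=0.1,
                                       asexual_rep=True, crossover=True),
                  dist='mse')
target = COCKTAIL_REPS[0]  # aim at the first cocktail of the dataset
ind = IndividualCocktail(pop_params=pop_params, target=target, target_affective_cluster=None)
print(ind.perf)
ind.print_recipe()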
src/cocktails/utilities/cocktail_generation_utilities/population.py
DELETED
@@ -1,213 +0,0 @@
from src.cocktails.utilities.cocktail_generation_utilities.individual import *
from sklearn.neighbors import NearestNeighbors
import time
import pickle
from src.cocktails.config import COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA

class Population:
    def __init__(self, target, pop_params, target_affective_cluster=None, known_target_dict=None):
        self.pop_params = pop_params
        self.pop_size = pop_params['pop_size']
        self.nb_elite = pop_params['nb_elites']
        self.nb_generations = pop_params['nb_generations']
        self.target = target
        self.mutation_params = pop_params['mutation_params']
        self.dist = pop_params['dist']
        self.n_neighbors = pop_params['n_neighbors']
        self.known_target_dict = known_target_dict

        with open(COCKTAIL_NN_PATH, 'rb') as f:
            data = pickle.load(f)
        self.nn_model_cocktail = data['nn_model']
        self.dim_rep_cocktail = data['dim_rep_cocktail']
        self.n_cocktails = data['n_cocktails']
        self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)

        if target_affective_cluster is None:
            cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(target)
            self.target_affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0]
        else:
            self.target_affective_cluster = target_affective_cluster

        self.pop_elite = []
        self.pop = []
        self.add_target_individual()  # create a target individual (not in pop)
        self.add_nearest_neighbors_in_pop()  # add nearest neighbors from the dataset into the population

        # fill the population
        while self.get_pop_size() < self.pop_size:
            self.add_individual()
        while len(self.pop_elite) < self.nb_elite:
            self.pop_elite.append(IndividualCocktail(pop_params=self.pop_params,
                                                     target=self.target.copy(),
                                                     target_affective_cluster=self.target_affective_cluster))
        self.update_elite_and_get_next_pop()

    def add_target_individual(self):
        if self.known_target_dict is not None:
            genes_presence, genes_quantity = self.get_q_rep(*extract_ingredients(self.known_target_dict['ing_str']))
            self.target_individual = IndividualCocktail(pop_params=self.pop_params,
                                                        target=self.target.copy(),
                                                        known_target_dict=self.known_target_dict,
                                                        target_affective_cluster=self.target_affective_cluster,
                                                        genes_presence=genes_presence,
                                                        genes_quantity=genes_quantity
                                                        )
        else:
            self.target_individual = None

    def add_nearest_neighbors_in_pop(self):
        # add nearest neighbors from the dataset into the population
        if self.n_neighbors > 0:
            dists, indexes = self.nn_model_cocktail.kneighbors(self.target.reshape(1, -1))
            dists, indexes = dists.flatten(), indexes.flatten()
            first = 1 if dists[0] == 0 else 0  # avoid taking the target itself when testing with known targets from the dataset
            indexes = indexes[first:first + self.n_neighbors]
            self.ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes]
            recipes = [extract_ingredients(ing_str) for ing_str in self.ing_strs]
            for r in recipes:
                genes_presence, genes_quantity = self.get_q_rep(r[0], r[1])
                genes_presence[-1] = 0  # remove the water ingredient
                self.add_individual(genes_presence=genes_presence.copy(), genes_quantity=genes_quantity.copy())
            self.nn_recipes = [ind.get_recipe()[3] for ind in self.pop]
            self.nn_scores = [ind.perf for ind in self.pop]
        else:
            self.ing_strs = None

    def add_individual(self, genes_presence=None, genes_quantity=None):
        self.pop.append(IndividualCocktail(pop_params=self.pop_params,
                                           target=self.target.copy(),
                                           target_affective_cluster=self.target_affective_cluster,
                                           genes_presence=genes_presence,
                                           genes_quantity=genes_quantity))

    def get_elite_perf(self):
        return np.array([e.perf for e in self.pop_elite])

    def get_pop_perf(self):
        return np.array([ind.perf for ind in self.pop])

    def update_elite_and_get_next_pop(self):
        time_dict = dict()
        init_time = time.time()
        elite_perfs = self.get_elite_perf()
        pop_perfs = self.get_pop_perf()
        all_perfs = np.concatenate([elite_perfs, pop_perfs])
        temp_list = self.pop_elite + self.pop
        time_dict[' get pop perfs'] = [time.time() - init_time]
        init_time = time.time()
        # update the elite population with the new best individuals
        indexes_sorted = np.flip(np.argsort(all_perfs))
        new_pop_elite = [IndividualCocktail(pop_params=self.pop_params,
                                            target=self.target.copy(),
                                            target_affective_cluster=self.target_affective_cluster,
                                            genes_presence=temp_list[i_new_e].genes_presence.copy(),
                                            genes_quantity=temp_list[i_new_e].genes_quantity.copy()) for i_new_e in indexes_sorted[:self.nb_elite]]
        time_dict[' recreate elite individuals'] = [time.time() - init_time]
        init_time = time.time()
        # select parents, with sampling probabilities proportional to performance ranks
        rank_perfs = np.flip(np.arange(len(temp_list)))
        sampling_probs = rank_perfs / np.sum(rank_perfs)
        if self.mutation_params['asexual_rep'] and not self.mutation_params['crossover']:
            new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size)
            self.pop = [temp_list[i].get_child()[0] for i in new_pop_indexes]  # get_child returns (child, time_dict)
        elif self.mutation_params['crossover'] and not self.mutation_params['asexual_rep']:
            self.pop = []
            while len(self.pop) < self.pop_size:
                parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False)
                self.pop.append(temp_list[parents[0]].get_child_with(temp_list[parents[1]])[0])  # keep the child, drop the timings
        elif self.mutation_params['crossover'] and self.mutation_params['asexual_rep']:
            new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size // 2)
            time_dict[' choose asexual parent indexes'] = [time.time() - init_time]
            init_time = time.time()
            self.pop = []
            for i in new_pop_indexes:
                child, this_time_dict = temp_list[i].get_child()
                self.pop.append(child)
                time_dict = self.update_time_dict(time_dict, this_time_dict)
            time_dict[' get asexual children'] = [time.time() - init_time]
            init_time = time.time()
            while len(self.pop) < self.pop_size:
                parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False)
                child, this_time_dict = temp_list[parents[0]].get_child_with(temp_list[parents[1]])
                self.pop.append(child)
                time_dict = self.update_time_dict(time_dict, this_time_dict)
            time_dict[' get sexual children'] = [time.time() - init_time]
        self.pop_elite = new_pop_elite
        return time_dict

    def get_pop_size(self):
        return len(self.pop)

    def get_q_rep(self, ingredients, quantities):
        ingredient_q_rep = np.zeros([len(ingredient_list)])
        genes_presence = np.zeros([len(ingredient_list)])
        for ing, q in zip(ingredients, quantities):
            ingredient_q_rep[ingredient_list.index(ing)] = q
            genes_presence[ingredient_list.index(ing)] = 1
        return genes_presence.copy(), normalize_ingredient_q_rep(ingredient_q_rep)

    def get_best_score(self, affective_cluster_check=False):
        elite_perfs = self.get_elite_perf()
        pop_perfs = self.get_pop_perf()
        all_perfs = np.concatenate([elite_perfs, pop_perfs])
        temp_list = self.pop_elite + self.pop
        if affective_cluster_check:
            indexes = np.array([i for i in range(len(temp_list)) if temp_list[i].does_affective_cluster_match()])
            if indexes.size > 0:
                temp_list = np.array(temp_list)[indexes]
                all_perfs = all_perfs[indexes]
        indexes_best = np.flip(np.argsort(all_perfs))
        return np.array(all_perfs)[indexes_best], np.array(temp_list)[indexes_best]

    def update_time_dict(self, main_dict, new_dict):
        for k in new_dict.keys():
            if k in main_dict.keys():
                main_dict[k].append(np.sum(new_dict[k]))
            else:
                main_dict[k] = [np.sum(new_dict[k])]
        return main_dict

    def run_one_generation(self, verbose=True, affective_cluster_check=False):
        time_dict = dict()
        init_time = time.time()
        this_time_dict = self.update_elite_and_get_next_pop()
        time_dict['update_elite_and_pop'] = [time.time() - init_time]
        time_dict = self.update_time_dict(time_dict, this_time_dict)
        init_time = time.time()
        best_perfs, best_individuals = self.get_best_score(affective_cluster_check)
        time_dict['get best scores'] = [time.time() - init_time]
        return best_perfs[0], time_dict

    def run_evolution(self, verbose=False, print_every=10, affective_cluster_check=False, level=0):
        best_score = -np.inf
        time_dict = dict()
        init_time = time.time()
        for i in range(self.nb_generations):
            best_score, this_time_dict = self.run_one_generation(verbose, affective_cluster_check=affective_cluster_check)
            time_dict = self.update_time_dict(time_dict, this_time_dict)
            if verbose and (i + 1) % print_every == 0:
                print('  ' * level + f'Gen #{i+1} - Current best perf: {best_score:.2f}, time: {time.time() - init_time:.4f}')
                init_time = time.time()
        # to_print = time_dict.copy()
        # keys = sorted(to_print.keys())
        # values = []
        # for k in keys:
        #     to_print[k] = np.sum(to_print[k])
        #     values.append(to_print[k])
        # sorted_inds = np.flip(np.argsort(values))
        # for i in sorted_inds:
        #     print(f'{keys[i]}: {values[i]:.4f}')
        if verbose: print('  ' * level + f'Evolution over, best perf: {best_score:.2f}')
        return self.get_best_score()

    def print_results(self, n=3):
        best_scores, best_ind = self.get_best_score()
        for i in range(n):
            best_ind[i].print_recipe(f'Candidate #{i+1}, Score: {best_scores[i]:.2f}')
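
A usage sketch for the evolutionary search (not from the deleted file; the pop_params keys are those read by Population and IndividualCocktail, the values are made up, and the COCKTAIL_NN_PATH pickle must exist):

# Hypothetical end-to-end run of the genetic search.
pop_params = dict(pop_size=100, nb_elites=10, nb_generations=50, n_neighbors=5, dist='mse',
                  mutation_params=dict(p_remove_ing=0.1, p_add_ing=0.1, p_switch_ing=0.2,
                                       p_change_q=0.3, delta_change_q=0.1,
                                       asexual_rep=True, crossover=True))
target = COCKTAIL_REPS[0]  # available via the star import from individual.py
pop = Population(target=target, pop_params=pop_params)
best_scores, best_individuals = pop.run_evolution(verbose=True)
pop.print_results(n=3)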
src/cocktails/utilities/cocktail_utilities.py
DELETED
@@ -1,220 +0,0 @@
import numpy as np
from src.cocktails.utilities.ingredients_utilities import ingredient2ingredient_id, ingredient_profiles, ingredients_per_type, ingredient_list, find_ingredient_from_str
from src.cocktails.utilities.cocktail_category_detection_utilities import *
import time

# representation_keys = ['pH', 'sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb',
#                        'complex', 'spicy', 'strong', 'oaky', 'fizzy', 'colorful', 'eggy']
representation_keys = ['sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb',
                       'complex', 'spicy', 'oaky', 'fizzy', 'colorful', 'eggy']
representation_keys_linear = list(set(representation_keys) - set(['pH', 'complex']))

ing_reps = np.array([[ingredient_profiles[k][ing_id] for ing_id in ingredient2ingredient_id.values()] for k in representation_keys]).transpose()

def compute_cocktail_representation(profile, ingredients, quantities):
    # computes the representation of a cocktail from the recipe (ingredients, quantities) and volume
    n = len(ingredients)
    assert n == len(quantities)
    quantities = np.array(quantities)

    weights = quantities / np.sum(quantities)
    rep = dict()

    ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients])
    # compute features as a linear combination of ingredient features
    for k in representation_keys_linear:
        k_ing = np.array([ingredient_profiles[k][ing_id] for ing_id in ing_ids])
        rep[k] = np.dot(weights, k_ing)

    # for pH: pH = -log10 of the H+ concentration, so mix concentrations rather than pH values
    phs = np.array([ingredient_profiles['pH'][ing_id] for ing_id in ing_ids])
    concentrations = 10 ** (- phs)
    mix_c = np.dot(weights, concentrations)
    rep['pH'] = - np.log10(mix_c)

    rep['complex'] = np.mean([ingredient_profiles['complex'][ing_id] for ing_id in ing_ids]) + len(ing_ids)

    # compute the profile after dilution
    volume_ratio = profile['mix volume'] / profile['end volume']
    for k in representation_keys:
        rep['end ' + k] = rep[k] * volume_ratio
    concentration = 10 ** (-rep['pH'])
    end_concentration = concentration * volume_ratio
    rep['end pH'] = - np.log10(end_concentration)
    return rep

def get_alcohol_profile(ingredients, quantities):
    ingredients = ingredients.copy()
    quantities = quantities.copy()
    assert len(ingredients) == len(quantities)
    if 'mint' in ingredients:
        mint_ind = ingredients.index('mint')
        ingredients.pop(mint_ind)
        quantities.pop(mint_ind)
    alcohol = []
    volume_mix = np.sum(quantities)
    weights = np.array(quantities) / volume_mix  # cast to array: quantities may be a plain list
    assert np.abs(np.sum(weights) - 1) < 1e-4
    ingredients_list = [ing.lower() for ing in ingredient_list]
    for ing, q in zip(ingredients, quantities):
        id = ingredients_list.index(ing)
        alcohol.append(ingredient_profiles['ethanol'][id])
    alcohol = np.dot(alcohol, weights)
    return alcohol, volume_mix

def get_mix_profile(ingredients, quantities):
    ingredients = ingredients.copy()
    quantities = quantities.copy()
    assert len(ingredients) == len(quantities)
    if 'mint' in ingredients:
        mint_ind = ingredients.index('mint')
        ingredients.pop(mint_ind)
        quantities.pop(mint_ind)
    alcohol, sugar, acid = [], [], []
    volume_mix = np.sum(quantities)
    weights = np.array(quantities) / volume_mix
    assert np.abs(np.sum(weights) - 1) < 1e-4
    ingredients_list = [ing.lower() for ing in ingredient_list]
    for ing, q in zip(ingredients, quantities):
        id = ingredients_list.index(ing)
        sugar.append(ingredient_profiles['sugar'][id])
        alcohol.append(ingredient_profiles['ethanol'][id])
        acid.append(ingredient_profiles['acid'][id])
    sugar = np.dot(sugar, weights)
    acid = np.dot(acid, weights)
    alcohol = np.dot(alcohol, weights)
    return alcohol, sugar, acid

92 |
-
def extract_preparation_type(instructions, recipe):
|
93 |
-
flag = False
|
94 |
-
instructions = instructions.lower()
|
95 |
-
egg_in_recipe = any([find_ingredient_from_str(ing_str)[1]=='egg' for ing_str in recipe[1]])
|
96 |
-
if 'shake' in instructions:
|
97 |
-
if egg_in_recipe:
|
98 |
-
prep_type = 'egg_shaken'
|
99 |
-
else:
|
100 |
-
prep_type = 'shaken'
|
101 |
-
elif 'stir' in instructions:
|
102 |
-
prep_type = 'stirred'
|
103 |
-
elif 'blend' in instructions:
|
104 |
-
prep_type = 'blended'
|
105 |
-
elif any([w in instructions for w in ['build', 'mix', 'pour', 'combine', 'place']]):
|
106 |
-
prep_type = 'built'
|
107 |
-
else:
|
108 |
-
prep_type = 'built'
|
109 |
-
if egg_in_recipe and 'shaken' not in prep_type:
|
110 |
-
stop = 1
|
111 |
-
return flag, prep_type
|
112 |
-
|
113 |
-
def get_dilution_ratio(category, alcohol):
|
114 |
-
# formulas from the Liquid Intelligence book
|
115 |
-
# The formula for built was invented
|
116 |
-
if category == 'stirred':
|
117 |
-
return -1.21 * alcohol**2 + 1.246 * alcohol + 0.145
|
118 |
-
elif category in ['shaken', 'egg_shaken']:
|
119 |
-
return -1.567 * alcohol**2 + 1.742 * alcohol + 0.203
|
120 |
-
elif category == 'built':
|
121 |
-
return (-1.21 * alcohol**2 + 1.246 * alcohol + 0.145) /2
|
122 |
-
else:
|
123 |
-
return 1
|
124 |
-
|
125 |
-
def get_cocktail_rep(category, ingredients, quantities, keys):
|
126 |
-
ingredients = ingredients.copy()
|
127 |
-
quantities = quantities.copy()
|
128 |
-
assert len(ingredients) == len(quantities)
|
129 |
-
|
130 |
-
volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
|
131 |
-
|
132 |
-
# compute alcohol content without mint ingredient
|
133 |
-
ingredients2 = [ing for ing in ingredients if ing != 'mint']
|
134 |
-
quantities2 = [q for ing, q in zip(ingredients, quantities) if ing != 'mint']
|
135 |
-
weights2 = quantities2 / np.sum(quantities2)
|
136 |
-
assert np.abs(np.sum(weights2) - 1) < 1e-4
|
137 |
-
ing_ids2 = np.array([ingredient2ingredient_id[ing] for ing in ingredients2])
|
138 |
-
alcohol = np.array([ingredient_profiles['ethanol'][ing_id] for ing_id in ing_ids2])
|
139 |
-
alcohol = np.dot(alcohol, weights2)
|
140 |
-
dilution_ratio = get_dilution_ratio(category, alcohol)
|
141 |
-
end_volume = volume_mix + volume_mix * dilution_ratio
|
142 |
-
volume_ratio = volume_mix / end_volume
|
143 |
-
end_alcohol = alcohol * volume_ratio
|
144 |
-
|
145 |
-
# computes representation of a cocktail from the recipe (ingredients, quantities) and volume
|
146 |
-
weights = quantities / np.sum(quantities)
|
147 |
-
assert np.abs(np.sum(weights) - 1) < 1e-4
|
148 |
-
ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients])
|
149 |
-
reps = ing_reps[ing_ids]
|
150 |
-
cocktail_rep = np.dot(weights, reps)
|
151 |
-
i_complex = keys.index('end complex')
|
152 |
-
cocktail_rep[i_complex] = np.mean(reps[:, i_complex]) + len(ing_ids) # complexity increases with number of ingredients
|
153 |
-
|
154 |
-
# compute profile after dilution
|
155 |
-
cocktail_rep = cocktail_rep * volume_ratio
|
156 |
-
cocktail_rep = np.concatenate([[end_volume], cocktail_rep])
|
157 |
-
return cocktail_rep, end_volume, end_alcohol
|
158 |
-
|
159 |
-
def get_profile(category, ingredients, quantities):
|
160 |
-
|
161 |
-
volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
|
162 |
-
alcohol, sugar, acid = get_mix_profile(ingredients, quantities)
|
163 |
-
dilution_ratio = get_dilution_ratio(category, alcohol)
|
164 |
-
end_volume = volume_mix + volume_mix * dilution_ratio
|
165 |
-
volume_ratio = volume_mix / end_volume
|
166 |
-
profile = {'mix volume': volume_mix,
|
167 |
-
'mix alcohol': alcohol,
|
168 |
-
'mix sugar': sugar,
|
169 |
-
'mix acid': acid,
|
170 |
-
'dilution ratio': dilution_ratio,
|
171 |
-
'end volume': end_volume,
|
172 |
-
'end alcohol': alcohol * volume_ratio,
|
173 |
-
'end sugar': sugar * volume_ratio,
|
174 |
-
'end acid': acid * volume_ratio}
|
175 |
-
cocktail_rep = compute_cocktail_representation(profile, ingredients, quantities)
|
176 |
-
profile.update(cocktail_rep)
|
177 |
-
return profile
|
178 |
-
|
179 |
-
profile_keys = ['mix volume', 'end volume',
|
180 |
-
'dilution ratio',
|
181 |
-
'mix alcohol', 'end alcohol',
|
182 |
-
'mix sugar', 'end sugar',
|
183 |
-
'mix acid', 'end acid'] \
|
184 |
-
+ representation_keys \
|
185 |
-
+ ['end ' + k for k in representation_keys]
|
186 |
-
|
187 |
-
def update_profile_in_datapoint(datapoint, category, ingredients, quantities):
|
188 |
-
profile = get_profile(category, ingredients, quantities)
|
189 |
-
for k in profile_keys:
|
190 |
-
datapoint[k] = profile[k]
|
191 |
-
return datapoint
|
192 |
-
|
193 |
-
# define representation keys
|
194 |
-
def get_bunch_of_rep_keys():
|
195 |
-
dict_rep_keys = dict()
|
196 |
-
# all
|
197 |
-
rep_keys = profile_keys
|
198 |
-
dict_rep_keys['all'] = rep_keys
|
199 |
-
# only_end
|
200 |
-
rep_keys = [k for k in profile_keys if 'end' in k ]
|
201 |
-
dict_rep_keys['only_end'] = rep_keys
|
202 |
-
# except_end
|
203 |
-
rep_keys = [k for k in profile_keys if 'end' not in k ]
|
204 |
-
dict_rep_keys['except_end'] = rep_keys
|
205 |
-
# custom
|
206 |
-
to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong']
|
207 |
-
rep_keys = [k for k in profile_keys if 'end' in k ]
|
208 |
-
for k in to_remove:
|
209 |
-
if k in rep_keys:
|
210 |
-
rep_keys.remove(k)
|
211 |
-
dict_rep_keys['custom'] = rep_keys
|
212 |
-
# custom restricted
|
213 |
-
to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong', 'end spicy', 'end oaky']
|
214 |
-
rep_keys = [k for k in profile_keys if 'end' in k ]
|
215 |
-
for k in to_remove:
|
216 |
-
if k in rep_keys:
|
217 |
-
rep_keys.remove(k)
|
218 |
-
dict_rep_keys['restricted'] = rep_keys
|
219 |
-
dict_rep_keys['affective'] = ['end booze', 'end sweet', 'end sour', 'end fizzy', 'end complex', 'end bitter', 'end spicy', 'end colorful']
|
220 |
-
return dict_rep_keys
|
src/cocktails/utilities/glass_and_volume_utilities.py
DELETED
@@ -1,42 +0,0 @@
glass_conversion = {'coupe': 'coupe',
                    'martini': 'martini',
                    'collins': 'collins',
                    'oldfashion': 'oldfashion',
                    'Coupe glass': 'coupe',
                    'Old-fashioned glass': 'oldfashion',
                    'Martini glass': 'martini',
                    'Nick & Nora glass': 'coupe',
                    'Julep tin': 'oldfashion',
                    'Collins or Pineapple shell glass': 'collins',
                    'Collins glass': 'collins',
                    'Rocks glass': 'oldfashion',
                    'Highball (max 10oz/300ml)': 'collins',
                    'Wine glass': 'coupe',
                    'Flute glass': 'coupe',
                    'Double old-fashioned': 'oldfashion',
                    'Copa glass': 'coupe',
                    'Toddy glass': 'oldfashion',
                    'Sling glass': 'collins',
                    'Goblet glass': 'oldfashion',
                    'Fizz or Highball (8oz to 10oz)': 'collins',
                    'Copper mug or Collins glass': 'collins',
                    'Tiki mug or collins': 'collins',
                    'Snifter glass': 'oldfashion',
                    'Coconut shell or Collins glass': 'collins',
                    'Martini (large 10oz) glass': 'martini',
                    'Hurricane glass': 'collins',
                    'Absinthe glass or old-fashioned glass': 'oldfashion'
                    }
glass_volume = dict(coupe=200,
                    collins=350,
                    martini=200,
                    oldfashion=320)
assert set(glass_conversion.values()) == set(glass_volume.keys())

volume_ranges = dict(stirred=(90, 97),
                     built=(70, 75),
                     shaken=(98, 112),
                     egg_shaken=(130, 143),
                     carbonated=(150, 150))
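
Illustrative only: a sketch of how glass_volume is typically combined with a fill rate when scaling a recipe. fill_rate mirrors the 0.8 constant used in other_scrubbing_utilities.py below; the quantities are made up.

glass_volume = dict(coupe=200, collins=350, martini=200, oldfashion=320)
fill_rate = 0.8  # same constant as in other_scrubbing_utilities.py

def rescale_to_glass(quantities, glass):
    # shrink a recipe uniformly so its volume fits 80% of the glass
    cap = glass_volume[glass] * fill_rate
    ratio = min(1.0, cap / sum(quantities))
    return [round(q * ratio, 2) for q in quantities]

print(rescale_to_glass([120, 60, 30], 'coupe'))  # 210 mL scaled down to fit 160 mL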
src/cocktails/utilities/ingredients_utilities.py
DELETED
@@ -1,209 +0,0 @@
# This script loads the list and profiles of our ingredients selection.
# It defines rules to recognize ingredients from the list in recipes and the function to extract that information from ingredient strings.

import pandas as pd
from src.cocktails.config import INGREDIENTS_LIST_PATH, COCKTAILS_CSV_DATA
import numpy as np

ingredient_profiles = pd.read_csv(INGREDIENTS_LIST_PATH)
ingredient_list = [ing.lower() for ing in ingredient_profiles['ingredient']]
n_ingredients = len(ingredient_list)
ingredient2ingredient_id = dict(zip(ingredient_list, range(n_ingredients)))

ingredients_types = sorted(set(ingredient_profiles['type']))
# for each type, get all ingredients
ing_per_type = [[ing for ing in ingredient_list if ingredient_profiles['type'][ingredient_list.index(ing)] == type] for type in ingredients_types]
ingredients_per_type = dict(zip(ingredients_types, ing_per_type))

bubble_ingredients = ['soda', 'ginger beer', 'tonic', 'sparkling wine']
# rules to recognize ingredients in recipes.
# in [] are separate rules with an OR relation: only one needs to be satisfied
# within [], rules apply with an AND relation: all rules need to be satisfied.
# ~ indicates that the following expression must NOT appear
# a simple expression indicates that the expression MUST appear.
ingredient_search = {  # 'salt': ['salt'],
    'lime juice': [['lime', '~soda', '~lemonade', '~cordial']],
    'lemon juice': [['lemon', '~soda', '~lemonade']],
    'angostura': [['angostura', '~orange'],
                  ['bitter', '~campari', '~orange', '~red', '~italian', '~fernet']],
    'orange bitters': [['orange', 'bitter', '~bittersweet']],
    'orange juice': [['orange', '~bitter', '~jam', '~marmalade', '~liqueur', '~water'],
                     ['orange', 'squeeze']],
    'pineapple juice': [['pineapple']],
    # 'apple juice': [['apple', 'juice', '~pine']],
    'cranberry juice': [['cranberry', 'juice']],
    'cointreau': ['cointreau', 'triple sec', 'grand marnier', 'curaçao', 'curacao'],
    'luxardo maraschino': ['luxardo', 'maraschino', 'kirsch'],
    'amaretto': ['amaretto'],
    'benedictine': ['benedictine', 'bénédictine', 'bénedictine', 'benédictine'],
    'campari': ['campari', ['italian', 'red', 'bitter'], 'aperol', 'bittersweet', 'aperitivo', 'orange-red'],
    # 'campari': ['campari', ['italian', 'red', 'bitter']],
    # 'crème de violette': [['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']],
    # 'aperol': ['aperol', 'bittersweet', 'aperitivo', 'orange-red'],
    'green chartreuse': ['chartreuse'],
    'black raspberry liqueur': [['cassis', 'liqueur'],
                                ['black raspberry', 'liqueur'],
                                ['raspberry', 'liqueur'],
                                ['strawberry', 'liqueur'],
                                ['blackberry', 'liqueur'],
                                ['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']],
    # 'simple syrup': [],
    # 'drambuie': ['drambuie'],
    # 'fernet branca': ['fernet', 'branca'],
    'gin': [['gin', '~sloe', '~ginger']],
    'vodka': ['vodka'],
    'cuban rum': [['rum', 'puerto rican'], ['light', 'rum'], ['white', 'rum'], ['rum', 'havana', '~7'], ['rum', 'bacardi']],
    'cognac': [['cognac', '~grand marnier', '~cointreau', '~orange']],
    # 'bourbon': [['bourbon', '~liqueur']],
    # 'tequila': ['tequila', 'pisco'],
    # 'tequila': ['tequila'],
    'scotch': ['scotch'],
    'dark rum': [['rum', 'age', '~bacardi', '~havana'],
                 ['rum', 'dark', '~bacardi', '~havana'],
                 ['rum', 'old', '~bacardi', '~havana'],
                 ['rum', 'old', '7'],
                 ['rum', 'havana', '7'],
                 ['havana', 'rum', 'especial']],
    'absinthe': ['absinthe'],
    'rye whiskey': ['rye', ['bourbon', '~liqueur']],
    # 'rye whiskey': ['rye'],
    'apricot brandy': [['apricot', 'brandy']],
    # 'pisco': ['pisco'],
    # 'cachaça': ['cachaça', 'cachaca'],
    'egg': [['egg', 'white', '~yolk', '~whole']],
    'soda': [['soda', 'water', '~lemon', '~lime']],
    'mint': ['mint'],
    'sparkling wine': ['sparkling wine', 'prosecco', 'champagne'],
    'ginger beer': [['ginger', 'beer'], ['ginger', 'ale']],
    'tonic': [['tonic'], ['7up'], ['sprite']],
    # 'espresso': ['espresso', 'expresso', ['café', '~liqueur', '~cream'],
    #              ['cafe', '~liqueur', '~cream'],
    #              ['coffee', '~liqueur', '~cream']],
    # 'southern comfort': ['southern comfort'],
    # 'cola': ['cola', 'coke', 'pepsi'],
    'double syrup': [['sugar', '~raspberry'], ['simple', 'syrup'], ['double', 'syrup']],
    # 'grenadine': ['grenadine', ['pomegranate', 'syrup']],
    'grenadine': ['grenadine', ['pomegranate', 'syrup'], ['raspberry', 'syrup', '~black']],
    'honey syrup': ['honey', ['maple', 'syrup']],
    # 'raspberry syrup': [['raspberry', 'syrup', '~black']],
    'dry vermouth': [['vermouth', 'dry'], ['vermouth', 'white'], ['vermouth', 'french'], 'lillet'],
    'sweet vermouth': [['vermouth', 'sweet'], ['vermouth', 'red'], ['vermouth', 'italian']],
    # 'lillet blanc': ['lillet'],
    'water': [['water', '~sugar', '~coconut', '~soda', '~tonic', '~honey', '~orange', '~melon']]
}
# check that there is a rule for all ingredients in the list
assert sorted(ingredient_list) == sorted(ingredient_search.keys()), 'ing search dict keys do not match ingredient list'

def get_ingredients_info():
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    max_ingredients, ingredient_set, liquor_set, liqueur_set, vermouth_set = get_max_n_ingredients(data)
    ingredient_list = sorted(ingredient_set)
    alcohol = sorted(liquor_set.union(liqueur_set).union(vermouth_set).union(set(['sparkling wine'])))
    ind_alcohol = [i for i in range(len(ingredient_list)) if ingredient_list[i] in alcohol]
    return max_ingredients, ingredient_list, ind_alcohol

def get_max_n_ingredients(data):
    max_count = 0
    ingredient_set = set()
    alcohol_set = set()
    liqueur_set = set()
    vermouth_set = set()
    ing_str = np.array(data['ingredients_str'])
    for i in range(len(data['names'])):
        ingredients, quantities = extract_ingredients(ing_str[i])
        max_count = max(max_count, len(ingredients))
        for ing in ingredients:
            ingredient_set.add(ing)
            if ing in ingredients_per_type['liquor']:
                alcohol_set.add(ing)
            if ing in ingredients_per_type['liqueur']:
                liqueur_set.add(ing)
            if ing in ingredients_per_type['vermouth']:
                vermouth_set.add(ing)
    return max_count, ingredient_set, alcohol_set, liqueur_set, vermouth_set

def find_ingredient_from_str(ing_str):
    # function that assigns an ingredient string to one of the ingredients if possible, following the rules defined above.
    # returns a flag and the ingredient string. When flag is false, the ingredient has not been found and the cocktail is rejected.
    ing_str = ing_str.lower()
    flags = []
    for k in ingredient_list:
        or_flags = []  # get flag for each of several conditions
        for i_p, pattern in enumerate(ingredient_search[k]):
            or_flags.append(True)
            if isinstance(pattern, str):
                if pattern[0] == '~' and pattern[1:] in ing_str:
                    or_flags[-1] = False
                elif pattern[0] != '~' and pattern not in ing_str:
                    or_flags[-1] = False
            elif isinstance(pattern, list):
                for element in pattern:
                    if element[0] == '~':
                        or_flags[-1] = or_flags[-1] and not element[1:] in ing_str
                    else:
                        or_flags[-1] = or_flags[-1] and element in ing_str
            else:
                raise ValueError
        flags.append(any(or_flags))
    if sum(flags) > 1:
        print(ing_str)
        for i_f, f in enumerate(flags):
            if f:
                print(ingredient_list[i_f])
        stop = 1
        return True, ingredient_list[flags.index(True)]
    elif sum(flags) == 0:
        # if 'grape' not in ing_str:
        #     print('\t\t Not found:', ing_str)
        return True, None
    else:
        return False, ingredient_list[flags.index(True)]

def get_cocktails_per_ingredient(ing_strs):
    cocktails_per_ing = dict(zip(ingredient_list, [[] for _ in range(len(ingredient_list))]))
    for i_ing, ing_str in enumerate(ing_strs):
        ingredients, _ = extract_ingredients(ing_str)
        for ing in ingredients:
            cocktails_per_ing[ing].append(i_ing)
    return cocktails_per_ing

def extract_ingredients(ingredient_str):
    # extract lists of ingredients and quantities from a formatted ingredient string (reverse of format_ingredients)
    ingredient_str = ingredient_str[1:-1]
    words = ingredient_str.split(',')
    ingredients = []
    quantities = []
    for i in range(len(words) // 2):
        ingredients.append(words[2 * i][1:])
        quantities.append(float(words[2 * i + 1][:-1]))
    return ingredients, quantities

def format_ingredients(ingredients, quantities):
    # format an ingredient string from the lists of ingredients and quantities (reverse of extract_ingredients)
    out = '['
    for ing, q in zip(ingredients, quantities):
        if ing[-1] == ' ':
            ingre = ing[:-1]
        else:
            ingre = ing
        out += f'({ingre},{q}),'
    out = out[:-1] + ']'
    return out


def get_ingredient_count(data):
    # get count of ingredients in the whole dataset
    ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list)))
    for i in range(len(data['names'])):
        if data['to_keep'][i]:
            ingredients, _ = extract_ingredients(data['ingredients_str'][i])
            for i in ingredients:
                ingredient_counts[i] += 1
    return ingredient_counts

def add_counts_to_ingredient_list(data):
    # update the list of ingredients to add their count of occurrence in the dataset.
    ingredient_counts = get_ingredient_count(data)
    counts = [ingredient_counts[k] for k in ingredient_list]
    ingredient_profiles['counts'] = counts
    ingredient_profiles.to_csv(INGREDIENTS_LIST_PATH, index=False)
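
A condensed sketch of the matching semantics implemented by find_ingredient_from_str: a rule is a list of alternatives combined with OR; an alternative that is itself a list is a conjunction whose '~'-prefixed terms are negated. The helper below is an illustrative rewrite, not the deleted function.

def matches(ing_str, alternatives):
    s = ing_str.lower()
    def ok(term):
        # '~foo' means foo must NOT appear; 'foo' means it must appear
        return term[1:] not in s if term.startswith('~') else term in s
    return any(ok(alt) if isinstance(alt, str) else all(ok(t) for t in alt)
               for alt in alternatives)

lime_rule = [['lime', '~soda', '~lemonade', '~cordial']]  # rule from the table above
print(matches('freshly squeezed lime juice', lime_rule))  # True
print(matches('lime cordial', lime_rule))                 # False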
src/cocktails/utilities/other_scrubbing_utilities.py
DELETED
@@ -1,240 +0,0 @@
import numpy as np
import pickle
from src.cocktails.utilities.cocktail_utilities import get_profile, profile_keys
from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles
from src.cocktails.utilities.glass_and_volume_utilities import glass_volume, volume_ranges

one_dash = 1
one_splash = 6
one_tablespoon = 15
one_barspoon = 5
fill_rate = 0.8
quantity_factors = {'ml': 1,
                    'cl': 10,
                    'splash': one_splash,
                    'splashes': one_splash,
                    'dash': one_dash,
                    'dashes': one_dash,
                    'spoon': one_barspoon,
                    'spoons': one_barspoon,
                    'tablespoon': one_tablespoon,
                    'barspoons': one_barspoon,
                    'barspoon': one_barspoon,
                    'bar spoons': one_barspoon,
                    'bar spoon': one_barspoon,
                    'tablespoons': one_tablespoon,
                    'teaspoon': 5,
                    'teaspoons': 5,
                    'drop': 0.05,
                    'drops': 0.05}
quantity_keys = sorted(quantity_factors.keys())
indexes_keys = np.flip(np.argsort([len(k) for k in quantity_keys]))
quantity_factors_keys = list(np.array(quantity_keys)[indexes_keys])

keys_to_track = ['names', 'urls', 'glass', 'garnish', 'recipe', 'how_to', 'review', 'taste_rep', 'valid']
keys_to_add = ['category', 'subcategory', 'ingredients_str', 'ingredients', 'quantities', 'to_keep']
keys_to_update = ['glass']
keys_for_csv = ['names', 'category', 'subcategory', 'ingredients_str', 'urls', 'glass', 'garnish', 'how_to', 'review', 'taste_rep'] + profile_keys

to_replace_q = {' fresh': ''}
to_replace_ing = {'maple syrup': 'honey syrup',
                  'agave syrup': 'honey syrup',
                  'basil': 'mint'}

def print_recipe(unit='mL', ingredient_str=None, ingredients=None, quantities=None, name='', cat='', to_print=True):
    str_out = ''
    if ingredient_str is None:
        assert len(ingredients) == len(quantities), 'provide either ingredient_str, or lists ingredients and quantities'
    else:
        assert ingredients is None and quantities is None, 'provide either ingredient_str, or lists ingredients and quantities'
        ingredients, quantities = extract_ingredients(ingredient_str)

    str_out += f'\nRecipe:'
    if name != '' and name is not None: str_out += f' {name}'
    if cat != '': str_out += f' ({cat})'
    str_out += '\n'
    for i in range(len(ingredients)):
        # get quantifier
        if ingredients[i] == 'egg':
            quantities[i] = 1
            ingredients[i] = 'egg white'
            if unit == 'mL':
                quantifier = ' (30 mL)'
            elif unit == 'oz':
                quantifier = ' (1 fl oz)'
            else:
                raise ValueError
        elif ingredients[i] in ['angostura', 'orange bitters']:
            quantities[i] = max(1, int(quantities[i] / 0.6))
            quantifier = ' dash'
            if quantities[i] > 1: quantifier += 'es'
        elif ingredients[i] == 'mint':
            if quantities[i] > 1: quantifier = ' leaves'
            else: quantifier = ' leaf'
        else:
            if unit == "oz":
                quantities[i] = float(f"{quantities[i] * 0.033814:.3f}")  # convert to fl oz
                quantifier = ' fl oz'
            else:
                quantifier = ' mL'
        str_out += f'  {quantities[i]}{quantifier} - {ingredients[i]}\n'

    if to_print:
        print(str_out)
    return str_out


def test_datapoint(datapoint, category, ingredients, quantities):
    # run checks
    ingredient_indexes = [ingredient_list.index(ing) for ing in ingredients]
    profile = get_profile(category, ingredients, quantities)
    volume = profile['end volume']
    alcohol = profile['end alcohol']
    acid = profile['end acid']
    sugar = profile['end sugar']
    # check volume
    if datapoint['glass'] != None:
        if volume > glass_volume[datapoint['glass']] * fill_rate:
            # recompute quantities for it to match
            ratio = fill_rate * glass_volume[datapoint['glass']] / volume
            for i_q in range(len(quantities)):
                quantities[i_q] = float(f'{quantities[i_q] * ratio:.2f}')
    # check alcohol
    assert alcohol < 30, 'too boozy'
    assert alcohol > 5, 'not boozy enough'
    assert acid < 2, 'too much acid'
    assert sugar < 20, 'too much sugar'
    assert len(ingredients) > 1, 'only one ingredient'
    if len(set(ingredients)) != len(ingredients):
        i_doubles = []
        s_ing = set()
        for i, ing in enumerate(ingredients):
            if ing in s_ing:
                i_doubles.append(i)
            else:
                s_ing.add(ing)
        ingredient_double_ok = ['mint', 'cointreau', 'lemon juice', 'cuban rum', 'double syrup']
        if len(i_doubles) == 1 and ingredients[i_doubles[0]] in ingredient_double_ok:
            ing_double = ingredients[i_doubles[0]]
            double_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == ing_double])
            ingredients.pop(i_doubles[0])
            quantities.pop(i_doubles[0])
            quantities[ingredients.index(ing_double)] = double_q
        else:
            assert False, f'double ingredient, not {ingredient_double_ok}'
    lemon_lime_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['lime juice', 'lemon juice']])
    assert lemon_lime_q <= 45, 'too much lemon and lime'
    salt_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'salt'])
    assert salt_q <= 8, 'too much salt'
    bitter_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['angostura', 'orange bitters']])
    assert bitter_q <= 5 * one_dash, 'too much bitter'
    absinthe_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'absinthe'])
    if absinthe_q > 4 * one_dash:
        mix_volume = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
        assert absinthe_q < 0.5 * mix_volume, 'filter absinthe glasses'
    if any([w in datapoint['how_to'] or any([w in ing.lower() for ing in datapoint['recipe'][1]]) for w in ['warm', 'boil', 'hot']]) and 'shot' not in datapoint['how_to']:
        assert False
    water_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'water'])
    assert water_q < 40
    # n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
    # assert n_liqueur <= 2
    n_liqueur_and_vermouth = np.sum([ingredient_profiles['type'][i].lower() in ['liqueur', 'vermouth'] for i in ingredient_indexes])
    assert n_liqueur_and_vermouth <= 3
    return ingredients, quantities

def run_battery_checks_difford(datapoint, category, ingredients, quantities):
    flag = False
    try:
        ingredients, quantities = test_datapoint(datapoint, category, ingredients, quantities)
    except:
        flag = True
        print(datapoint["names"])
        print(datapoint["urls"])
        ingredients, quantities = None, None

    return flag, ingredients, quantities

def tambouille(q, ingredients_scrubbed, quantities_scrubbed, cat):
    # ugly
    ing_scrubbed = ingredients_scrubbed[len(quantities_scrubbed)]
    if q == '4 cube' and ing_scrubbed == 'pineapple juice':
        q = '20 ml'
    elif 'top up with' in q:
        volume_so_far = np.sum([quantities_scrubbed[i] for i in range(len(quantities_scrubbed)) if ingredients_scrubbed[i] != 'mint'])
        volume_mix = np.sum(volume_ranges[cat]) / 2
        if (volume_mix - volume_so_far) < 15:
            q = '15 ml'
        else:
            q = str(int(volume_mix - volume_so_far)) + ' ml'
    elif q == '1 pinch' and ing_scrubbed == 'salt':
        q = '2 drops'
    elif 'cube' in q and ing_scrubbed == 'double syrup':
        q = f'{float(q.split(" ")[0]) * 2 * 1.7:.2f} ml'  # 2g per cube, 1.7 is the solid / syrup ratio
    elif 'wedge' in q:
        if ing_scrubbed == 'orange juice':
            vol = 70
        elif ing_scrubbed == 'lime juice':
            vol = 30
        elif ing_scrubbed == 'lemon juice':
            vol = 45
        elif ing_scrubbed == 'pineapple juice':
            vol = 140
        factor = float(q.split(' ')[0]) * 0.15  # consider a wedge to be 0.15 * the fruit.
        q = f'{factor * vol:.2f} ml'
    elif 'slice' in q:
        if ing_scrubbed == 'orange juice':
            vol = 70
        elif ing_scrubbed == 'lime juice':
            vol = 30
        elif ing_scrubbed == 'lemon juice':
            vol = 45
        elif ing_scrubbed == 'pineapple juice':
            vol = 140
        f = q.split(' ')[0]
        if len(f.split('⁄')) > 1:
            frac = f.split('⁄')
            factor = float(frac[0]) / float(frac[1])
        else:
            factor = float(f)
        factor *= 0.1  # consider a slice to be 0.1 * the fruit.
        q = f'{factor * vol:.2f} ml'
    elif q == '1 whole' and ing_scrubbed == 'luxardo maraschino':
        q = '10 ml'
    elif ing_scrubbed == 'egg' and 'ml' not in q:
        q = f'{float(q) * 30:.2f} ml'  # 30 ml per egg
    return q


def compute_eucl_dist(a, b):
    return np.sqrt(np.sum((a - b)**2))

def evaluate_with_quadruplets(representations, strategy='all'):
    with open(QUADRUPLETS_PATH, 'rb') as f:
        data = pickle.load(f)
    data = list(data.values())
    quadruplets = []
    if strategy != 'all':
        for d in data:
            if d[0] == strategy:
                quadruplets.append(d[1:])
    elif strategy == 'all':
        for d in data:
            quadruplets.append(d[1:])
    else:
        raise ValueError

    scores = []
    for q in quadruplets:
        close = q[0]
        if len(close) == 2:
            far = q[1]
            distance_close = compute_eucl_dist(representations[close[0]], representations[close[1]])
            distances_far = [compute_eucl_dist(representations[far[i][0]], representations[far[i][1]]) for i in range(len(far))]
            scores.append(distance_close < np.min(distances_far))
    if len(scores) == 0:
        score = np.nan
    else:
        score = np.mean(scores)
    return score
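
A sketch of why quantity_factors_keys is ordered by decreasing key length: longer unit names must be tried first, otherwise '1 tablespoon' would match the substring 'spoon' and convert with the wrong factor. The units below are a subset of the table above; to_ml is an illustrative helper, not the deleted parsing code.

quantity_factors = {'ml': 1, 'cl': 10, 'spoon': 5, 'tablespoon': 15, 'dash': 1}
keys_longest_first = sorted(quantity_factors, key=len, reverse=True)

def to_ml(quantity_str):
    # first matching unit wins, so longer names shadow their substrings
    for unit in keys_longest_first:
        if unit in quantity_str:
            return float(quantity_str.split(' ')[0]) * quantity_factors[unit]
    raise ValueError(f'unknown unit in {quantity_str!r}')

print(to_ml('1 tablespoon'))  # 15.0 ('tablespoon' is tried before 'spoon')
print(to_ml('2 dashes'))      # 2.0 ('dash' matches inside 'dashes')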
src/debugger.py
DELETED
@@ -1,180 +0,0 @@
import os.path

# from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob
from src.music.utils import load_audio
from src.music.config import FPS
import pretty_midi as pm
import numpy as np
from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH
from sklearn.neighbors import NearestNeighbors
from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
# from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
from src.music.utils import get_all_subfiles_with_extension
import os
import pickle
import pandas as pd
import time

keyword = 'b256_r128_represented'
def load_reps(rep_path, sample_size=None):
    if sample_size:
        with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
            data = pickle.load(f)
    else:
        with open(rep_path + f'music_reps_unnormalized.pickle', 'rb') as f:
            data = pickle.load(f)
    reps = data['reps']
    # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
    playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
    n_data, dim_data = reps.shape
    return reps, data['paths'], playlists, n_data, dim_data

class Debugger():
    def __init__(self, verbose=True):

        if verbose: print('Setting up debugger.')
        if not os.path.exists(MUSIC_NN_PATH):
            reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle'
            if not os.path.exists(reps_path):
                all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt', current_depth=0)
                all_data = []
                new_all_rep_path = []
                for i_r, r in enumerate(all_rep_path):
                    if 'mean_std' not in r:
                        all_data.append(np.loadtxt(r))
                        assert len(all_data[-1]) == 128
                        new_all_rep_path.append(r)
                data = np.array(all_data)
                to_save = dict(reps=data,
                               paths=new_all_rep_path)
                with open(reps_path, 'wb') as f:
                    pickle.dump(to_save, f)

            reps, self.rep_paths, playlists, n_data, self.dim_rep_music = load_reps(MUSIC_REP_PATH)
            self.nn_model_music = NearestNeighbors(n_neighbors=6, metric='cosine')
            self.nn_model_music.fit(reps)
            to_save = dict(nn_model=self.nn_model_music,
                           rep_paths=self.rep_paths,
                           dim_rep_music=self.dim_rep_music)
            with open(MUSIC_NN_PATH, 'wb') as f:
                pickle.dump(to_save, f)
        else:
            with open(MUSIC_NN_PATH, 'rb') as f:
                data = pickle.load(f)
            self.nn_model_music = data['nn_model']
            self.rep_paths = data['rep_paths']
            self.dim_rep_music = data['dim_rep_music']
        if verbose: print(f'  {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}')
        self.rep_paths = np.array(self.rep_paths)
        if not os.path.exists(COCKTAIL_NN_PATH):
            cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
            # cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0)
            self.nn_model_cocktail = NearestNeighbors(n_neighbors=6)
            self.nn_model_cocktail.fit(cocktail_reps)
            self.dim_rep_cocktail = cocktail_reps.shape[1]
            self.n_cocktails = cocktail_reps.shape[0]
            to_save = dict(nn_model=self.nn_model_cocktail,
                           dim_rep_cocktail=self.dim_rep_cocktail,
                           n_cocktails=self.n_cocktails)
            with open(COCKTAIL_NN_PATH, 'wb') as f:
                pickle.dump(to_save, f)
        else:
            with open(COCKTAIL_NN_PATH, 'rb') as f:
                data = pickle.load(f)
            self.nn_model_cocktail = data['nn_model']
            self.dim_rep_cocktail = data['dim_rep_cocktail']
            self.n_cocktails = data['n_cocktails']
        if verbose: print(f'  {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}')

        self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
        # self.affective_cluster_centers = get_affective_cluster_centers()
        self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls',
                              'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len', 'piano_solo_prob', 'recipe_score', 'cocktail_rep']
        # 'affect', 'affective_cluster_id', 'affective_cluster_center',


    def get_nearest_songs(self, music_rep):
        dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1))
        indexes = indexes.flatten()[:5]
        rep_paths = [r.split('/')[-1] for r in self.rep_paths[indexes[:5]]]
        return rep_paths, dists.flatten().tolist()

    def get_nearest_cocktails(self, cocktail_rep):
        dists, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1))
        indexes = indexes.flatten()
        nn_names = np.array(self.cocktail_data['names'])[indexes].tolist()
        nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist()
        nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in np.array(self.cocktail_data['ingredients_str'])[indexes]]
        nn_ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes].tolist()
        return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs

    def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score, verbose=False, level=0):
        if verbose: print(' ' * level + 'Extracting debug info..')
        init_time = time.time()
        debug_dict = dict()
        debug_dict['all_paths'] = all_paths
        debug_dict['recipe_score'] = recipe_score

        if all_paths['audio_path'] != None:
            # is it piano?
            debug_dict['piano_solo_prob'] = None  # float(calculate_piano_solo_prob(all_paths['audio_path'])[0])
            # how long is the audio
            (audio, _) = load_audio(all_paths['audio_path'], sr=FPS, mono=True)
            debug_dict['audio_len'] = int(len(audio) / FPS)
        else:
            debug_dict['piano_solo_prob'] = None
            debug_dict['audio_len'] = None

        # how many notes?
        midi = pm.PrettyMIDI(all_paths['processed_path'])
        debug_dict['nb_notes'] = len(midi.instruments[0].notes)

        # dimension of music rep
        representation = np.loadtxt(all_paths['representation_path'])
        debug_dict['dim_rep'] = representation.shape[0]

        # closest songs in dataset
        debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation)

        # get affective cluster info
        # debug_dict['affective_cluster_id'] = affective_cluster_id[0]
        # debug_dict['affective_cluster_center'] = self.affective_cluster_centers[affective_cluster_id].flatten().tolist()
        # debug_dict['affect'] = affect.flatten().tolist()
        indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep)
        debug_dict['cocktail_rep'] = cocktail_rep.copy().tolist()
        debug_dict['nearest_cocktail_indexes'] = indexes.tolist()
        debug_dict['nn_ing_strs'] = nn_ing_strs
        debug_dict['nearest_cocktail_names'] = nn_names
        debug_dict['nearest_cocktail_urls'] = nn_urls
        debug_dict['nearest_cocktail_recipes'] = nn_recipes

        debug_dict['music_reconstruction'] = music_reconstruction.tolist()
        debug_dict['mse_reconstruction'] = ((music_reconstruction - representation) ** 2).mean()
        self.debug_dict = debug_dict
        if verbose: print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.')

        return self.debug_dict

    def print_debug(self, level=0):
        print(' ' * level + '__DEBUGGING INFO__')
        for k in self.keys_to_print:
            to_print = self.debug_dict[k]
            if k == 'nearest_cocktail_recipes':
                to_print = self.debug_dict[k].copy()
                for i in range(len(to_print)):
                    to_print[i] = to_print[i].replace('\n', '').replace('\t', '').replace('()', '')
            if k == "nn_music":
                to_print = self.debug_dict[k].copy()
                for i in range(len(to_print)):
                    to_print[i] = to_print[i].replace('encoded_new_structured_', '').replace('_represented.txt', '')
            to_print_str = f'{to_print}'
            if isinstance(to_print, float):
                to_print_str = f'{to_print:.2f}'
            elif isinstance(to_print, list):
                if isinstance(to_print[0], float):
                    to_print_str = '['
                    for element in to_print:
                        to_print_str += f'{element:.2f}, '
                    to_print_str = to_print_str[:-2] + ']'
            print(' ' * (level + 2) + f'{k} : ' + to_print_str)
src/music/__init__.py
DELETED
File without changes
|
src/music/config.py
DELETED
@@ -1,72 +0,0 @@
import numpy as np
import os

REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/'
AUDIO_PATH = REPO_PATH + 'data/music/audio/'
MIDI_PATH = REPO_PATH + 'data/music/midi/'
MUSIC_PATH = REPO_PATH + 'data/music/'
PROCESSED_PATH = REPO_PATH + 'data/music/processed/'
ENCODED_PATH = REPO_PATH + 'data/music/encoded/'
HANDCODED_REP_PATH = MUSIC_PATH + 'handcoded_reps/'
DATASET_PATH = REPO_PATH + 'data/music/encoded_new_structured/diverse_piano/'
SYNTH_RECORDED_AUDIO_PATH = AUDIO_PATH + 'synth_audio_recorded/'
SYNTH_RECORDED_MIDI_PATH = MIDI_PATH + 'synth_midi_recorded/'
CHECKPOINTS_PATH = REPO_PATH + 'checkpoints/'
EXPERIMENT_PATH = REPO_PATH + 'experiments/'
SEED = 0

# params for data download
ALL_URL_PATH = REPO_PATH + 'data/music/audio/all_urls.pickle'
ALL_FAILED_URL_PATH = REPO_PATH + 'data/music/audio/all_failed_urls.pickle'
RATE_AUDIO_SAVE = 16000
FROM_URL_PATH = AUDIO_PATH + 'from_url/'

# params transcription
CHKPT_PATH_TRANSCRIPTION = REPO_PATH + 'checkpoints/piano_transcription/note_F1=0.9677_pedal_F1=0.9186.pth'  # transcriptor chkpt path
FPS = 16000
RANDOM_CROP = True  # whether to use random crops in case of cropped audio
CROP_LEN = 26 * 60

# params midi scrubbing and processing
MAX_DEPTH = 5  # max depth when searching in folders for audio files
MAX_GAP_IN_SONG = 10  # in secs
MIN_LEN = 20  # actual min len could go down to MIN_LEN - 2 * (REMOVE_FIRST_AND_LAST / 5)
MAX_LEN = 25 * 60  # maximum audio len for playlist downloads, and maximum audio length for transcription (in sec)
MIN_NB_NOTES = 80  # min nb of notes per minute of recording
REMOVE_FIRST_AND_LAST = 10  # will be divided by 5 if cutting this makes the song fall below min len

# parameters encoding
NOISE_INJECTED = True
AUGMENTATION = True
NB_AUG = 4 if AUGMENTATION else 0
RANGE_NOTE_ON = 128
RANGE_NOTE_OFF = 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100
MAX_EMBEDDING = RANGE_VEL + RANGE_NOTE_OFF + RANGE_TIME_SHIFT + RANGE_NOTE_ON
MAX_TEST_SIZE = 1000
CHECKSUM_PATH = REPO_PATH + 'data/music/midi/checksum.pickle'
CHUNK_SIZE = 512

ALL_AUGMENTATIONS = []
for p in [-3, -2, -1, 1, 2, 3]:
    ALL_AUGMENTATIONS.append((p))
ALL_AUGMENTATIONS = np.array(ALL_AUGMENTATIONS)

ALL_NOISE = []
for s in [-5, -2.5, 0, 2.5, 5]:
    for p in np.arange(-6, 7):
        if not ((s == 0) and (p == 0)):
            ALL_NOISE.append((s, p))
ALL_NOISE = np.array(ALL_NOISE)

# music transformer params
REP_MODEL_NAME = REPO_PATH + "checkpoints/music_representation/sentence_embedding/smallbert_b256_r128_1/best_model"
MUSIC_REP_PATH = REPO_PATH + "checkpoints/b256_r128_represented/"
MUSIC_NN_PATH = REPO_PATH + "checkpoints/music_representation/b256_r128_represented/nn_model.pickle"

TRANSLATION_VAE_CHKP_PATH = REPO_PATH + "checkpoints/music2cocktails/music2flavor/b256_r128_classif001_ld40_meanstd_regground2.5_egg_bubbles/"

# piano solo evaluation
# META_DATA_PIANO_EVAL_PATH = REPO_PATH + 'data/music/audio/is_piano.csv'
# CHKPT_PATH_PIANO_EVAL = REPO_PATH + 'data/checkpoints/piano_detection/piano_solo_model_32k.pth'
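
As a sanity check on the grids above: the noise grid pairs 5 speed offsets with 13 pitch offsets and excludes (0, 0), i.e. 64 entries. A quick illustrative sketch:

import numpy as np

noise = [(s, p) for s in [-5, -2.5, 0, 2.5, 5]
         for p in np.arange(-6, 7)
         if not (s == 0 and p == 0)]
print(len(noise))  # 64 = 5 * 13 - 1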
src/music/pipeline/__init__.py
DELETED
File without changes
|
src/music/pipeline/audio2midi.py
DELETED
@@ -1,52 +0,0 @@
import torch
import piano_transcription_inference
import numpy as np
import os
import sys
sys.path.append('../../')
from src.music.utils import get_out_path, load_audio
from src.music.config import CHKPT_PATH_TRANSCRIPTION, FPS, MIN_LEN, CROP_LEN
# import librosa
device = 'cuda' if torch.cuda.is_available() else 'cpu'
TRANSCRIPTOR = piano_transcription_inference.PianoTranscription(device=device,
                                                                checkpoint_path=CHKPT_PATH_TRANSCRIPTION)

def audio2midi(audio_path, midi_path=None, crop=CROP_LEN, random_crop=True, verbose=False, level=0):
    if verbose and crop < MIN_LEN + 2:
        print('crop is shorter than the minimum tune length')
    assert '.mp3' == audio_path[-4:]
    if midi_path is None:
        midi_path, _, _ = get_out_path(in_path=audio_path, in_word='audio', out_word='midi', out_extension='.mid')

    if verbose: print(' ' * level + f'Transcribing {audio_path}.')
    if os.path.exists(midi_path):
        if verbose: print(' ' * (level + 2) + 'Midi file already exists.')
        return midi_path, ''

    error_msg = 'Error in transcription. '
    try:
        error_msg += 'Maybe in audio loading?'
        (audio, _) = load_audio(audio_path,
                                sr=FPS,
                                mono=True)
        error_msg += ' Nope. Cropping?'
        if isinstance(crop, int) and len(audio) > FPS * crop:
            rc_str = ' (random crop)' if random_crop else ' (start crop)'
            if verbose: print(' ' * (level + 2) + f'Cropping the song to {crop}s before transcription{rc_str}. ')
            size_crop = FPS * crop
            if random_crop:
                index_begining = np.random.randint(len(audio) - size_crop - 1)
            else:
                index_begining = 0
            audio = audio[index_begining: index_begining + size_crop]
        error_msg += ' Nope. Transcription?'
        TRANSCRIPTOR.transcribe(audio, midi_path)
        error_msg += ' Nope.'
        extra = f' Saved to {midi_path}' if midi_path else ''
        if verbose: print(' ' * (level + 2) + f'Success! {extra}')
        return midi_path, ''
    except:
        if verbose: print(' ' * (level + 2) + 'Transcription failed.')
        if os.path.exists(midi_path):
            os.remove(midi_path)
        return None, error_msg + ' Yes.'
src/music/pipeline/audio2piano_solo_prob.py
DELETED
@@ -1,47 +0,0 @@
import numpy as np
import librosa
import sys
sys.path.append('../../../data/')
from src.music.utilities.processing_models import piano_detection_model
from src.music.config import CHKPT_PATH_PIANO_EVAL

PIANO_SOLO_DETECTOR = piano_detection_model.PianoSoloDetector(CHKPT_PATH_PIANO_EVAL)
exclude_playlist_folders = ['synth_audio_recorded', 'from_url']

def clean_start_and_end_blanks(probs):
    if len(probs) > 20:
        # clean up to 10s in each direction
        n_zeros_start = 0
        for i in range(10):
            if probs[i] <= 0.001:
                n_zeros_start += 1
            else:
                break
        n_zeros_end = 0
        for i in range(10):
            if probs[-(i + 1)] <= 0.001:
                n_zeros_end += 1
            else:
                break
        if n_zeros_end == 0:
            return probs[n_zeros_start:]
        else:
            return probs[n_zeros_start:-n_zeros_end]
    else:
        return probs

def calculate_piano_solo_prob(audio_path, verbose=False):
    """Calculate the piano solo probability of all downloaded mp3s, and append
    the probability to the meta csv file. Code from https://github.com/bytedance/GiantMIDI-Piano
    """
    try:
        error_msg = 'Error in audio loading?'
        (audio, _) = librosa.core.load(audio_path, sr=piano_detection_model.SR, mono=True)
        error_msg += ' Nope. Error in solo prediction?'
        probs = PIANO_SOLO_DETECTOR.predict(audio)
        # probs = clean_start_and_end_blanks(probs)  # remove blanks at start and end (<=10s each way). If not piano, the rest of the song will be enough to tell.
        piano_solo_prob = np.mean(probs)
        error_msg += ' Nope. '
        return piano_solo_prob, ''
    except:
        return None, error_msg + 'Yes.'
src/music/pipeline/encoded2rep.py
DELETED
@@ -1,88 +0,0 @@
from src.music.utilities.representation_learning_utilities.constants import *
from src.music.config import REP_MODEL_NAME
from src.music.utils import get_out_path
import pickle
import numpy as np
# from transformers import AutoModel, AutoTokenizer
from torch import nn
from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer

class Argument(object):
    def __init__(self, adict):
        self.__dict__.update(adict)

class RepModel(nn.Module):
    def __init__(self, model, model_name):
        super().__init__()
        if 't5' in model_name:
            self.model = model.get_encoder()
        else:
            self.model = model
        self.model.eval()

    def forward(self, inputs):
        with torch.no_grad():
            out = self.model(inputs, output_hidden_states=True)
            embeddings = out.hidden_states[-1]
            return torch.mean(embeddings[0], dim=0)

# def get_trained_music_LM(model_name):
#     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
#     model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name)
#
#     return model, tokenizer

def get_trained_sentence_embedder(model_name):
    model = SentenceTransformer(model_name)
    return model

MODEL = get_trained_sentence_embedder(REP_MODEL_NAME)

def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
    if not rep_path:
        rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt')

    error_msg = 'Error in music transformer mapping.'
    if verbose: print(' ' * level + 'Mapping to final music representations')
    try:
        error_msg += ' Error in encoded file loading?'
        with open(encoded_path, 'rb') as f:
            data = pickle.load(f)
        performance = [str(w) for w in data['main'] if w != 1]
        assert len(performance) % 5 == 0
        if (len(performance) == 0):
            error_msg += " Error: No midi messages in primer file"
            assert False
        error_msg += ' Nope, error in tokenization?'
        perf = ' '.join(performance)
        # tokenized = torch.IntTensor(TOKENIZER.encode(perf)).unsqueeze(dim=0)
        error_msg += ' Nope. Maybe in performance encoding?'
        # reps = []
        # for i_chunk in range(min(tokenized.shape[1] // 510 - 1, 8)):
        #     chunk_tokenized = tokenized[:, i_chunk * 510: (i_chunk + 1) * 510 + 2]
        #     rep = MODEL(chunk_tokenized)
        #     reps.append(rep.detach().numpy())
        # representation = np.mean(reps, axis=0)
        p = [int(p) for p in perf.split(' ')]
        print('PERF:', np.sum(p), perf)
        representation = MODEL.encode(perf)
        print('model weights sum: ', torch.sum(torch.Tensor([param.sum() for param in list(MODEL.parameters())])))
        print('reprep', representation)
        error_msg += ' Nope. Saving performance?'
        np.savetxt(rep_path, representation)
        error_msg += ' Nope.'
        if verbose: print(' ' * (level + 2) + 'Success.')
        if return_rep:
            return rep_path, representation, ''
        else:
            return rep_path, ''
    except:
        if verbose: print(' ' * (level + 2) + f'Failed with error: {error_msg}')
        if return_rep:
            return None, None, error_msg
        else:
            return None, error_msg

if __name__ == "__main__":
    representation = encoded2rep("/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle")
    stop = 1
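
A sketch of the embedding call pattern above, using the public sentence-transformers package as a stand-in for the vendored copy, and a public model name in place of REP_MODEL_NAME; the token string is a toy example.

from sentence_transformers import SentenceTransformer  # public package, standing in for the vendored fork

model = SentenceTransformer('all-MiniLM-L6-v2')  # hypothetical stand-in for REP_MODEL_NAME
rep = model.encode('60 188 300 3680 3733')       # toy space-separated token string
print(rep.shape)                                 # 1-D numpy embedding, e.g. (384,)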
src/music/pipeline/midi2processed.py
DELETED
@@ -1,152 +0,0 @@
import time
import os
import sys
sys.path.append('../../')
import pretty_midi as pm
import numpy as np

from src.music.utils import get_out_path
from src.music.config import MIN_LEN, MIN_NB_NOTES, MAX_GAP_IN_SONG, REMOVE_FIRST_AND_LAST


def sort_notes(notes):
    starts = np.array([n.start for n in notes])
    index_sorted = np.argsort(starts)
    return [notes[i] for i in index_sorted].copy()


def delete_notes_end_after_start(notes):
    indexes_to_keep = [i for i, n in enumerate(notes) if n.start < n.end]
    return [notes[i] for i in indexes_to_keep].copy()

def compute_largest_gap(notes):
    gaps = []
    latest_note_end_so_far = notes[0].end
    for i in range(len(notes) - 1):
        note_start = notes[i + 1].start
        if latest_note_end_so_far < note_start:
            gaps.append(note_start - latest_note_end_so_far)
        latest_note_end_so_far = max(latest_note_end_so_far, notes[i + 1].end)
    if len(gaps) > 0:
        largest_gap = np.max(gaps)
    else:
        largest_gap = 0
    return largest_gap

def analyze_instrument(inst):
    # test that the piano plays throughout
    init = time.time()
    notes = inst.notes.copy()
    nb_notes = len(notes)
    start = notes[0].start
    end = inst.get_end_time()
    duration = end - start
    largest_gap = compute_largest_gap(notes)
    return nb_notes, start, end, duration, largest_gap

def remove_beginning_and_end(midi, end_time):
    notes = midi.instruments[0].notes.copy()
    new_notes = [n for n in notes if n.start > REMOVE_FIRST_AND_LAST and n.end < end_time - REMOVE_FIRST_AND_LAST]
    midi.instruments[0].notes = new_notes
    return midi

def remove_blanks_beginning_and_end(midi):
    # remove blanks at the beginning and the end
    shift = midi.instruments[0].notes[0].start
    for n in midi.instruments[0].notes:
        n.start = max(0, n.start - shift)
        n.end = max(0, n.end - shift)
    for ksc in midi.key_signature_changes:
        ksc.time = max(0, ksc.time - shift)
    for tsc in midi.time_signature_changes:
        tsc.time = max(0, tsc.time - shift)
    for pb in midi.instruments[0].pitch_bends:
        pb.time = max(0, pb.time - shift)
    for cc in midi.instruments[0].control_changes:
        cc.time = max(0, cc.time - shift)
    return midi

def is_valid_inst(largest_gap, duration, nb_notes, gap_counts=True):
    error_msg = ''
    valid = True
    if largest_gap > MAX_GAP_IN_SONG and gap_counts:
        valid = False
        error_msg += f'wide gap ({largest_gap:.2f} secs), '
    if duration < (MIN_LEN + 2 * REMOVE_FIRST_AND_LAST):
        valid = False
        error_msg += f'too short ({duration:.2f} secs), '
    if nb_notes < MIN_NB_NOTES * duration / 60:  # the note count must exceed the minimum notes-per-minute times the duration in minutes
        valid = False
        error_msg += f'too few notes ({nb_notes}), '
    return valid, error_msg

def midi2processed(midi_path, processed_path=None, apply_filtering=True, verbose=False, level=0):
    assert midi_path.split('.')[-1] in ['mid', 'midi']
    if not processed_path:
        processed_path, _, _ = get_out_path(in_path=midi_path, in_word='midi', out_word='processed', out_extension='.mid')

    if verbose: print(' ' * level + f'Processing {midi_path}.')

    if os.path.exists(processed_path):
        if verbose: print(' ' * (level + 2) + 'Processed midi file already exists.')
        return processed_path, ''
    error_msg = 'Error in scrubbing. '
    # try:
    inst_error_msg = ''
    # load mid file
    error_msg += 'Error in midi loading?'
    midi = pm.PrettyMIDI(midi_path)
    error_msg += ' Nope. Removing invalid notes?'
    midi.remove_invalid_notes()  # filter invalid notes
    error_msg += ' Nope. Filtering instruments?'
    # filter instruments
    instruments = midi.instruments.copy()
    new_instru = []
    instruments_data = []
    for i_inst, inst in enumerate(instruments):
        if inst.program <= 7 and not inst.is_drum and len(inst.notes) > 5:
            # inst is a piano
            # check data
            inst.notes = sort_notes(inst.notes)  # sort notes
            inst.notes = delete_notes_end_after_start(inst.notes)  # delete invalid notes
            nb_notes, start, end, duration, largest_gap = analyze_instrument(inst)
            is_valid, err_msg = is_valid_inst(largest_gap=largest_gap, duration=duration, nb_notes=nb_notes, gap_counts='maestro' not in midi_path)
            if is_valid or not apply_filtering:
                new_instru.append(inst)
                instruments_data.append([nb_notes, start, end, duration, largest_gap])
            else:
                inst_error_msg += 'inst1: ' + err_msg + '\n'
    instruments_data = np.array(instruments_data)
    error_msg += ' Nope. Taking one instrument?'

    if len(new_instru) == 0:
        error_msg = f'No piano instrument. {inst_error_msg}'
        assert False
    elif len(new_instru) > 1:
        # take the instrument playing the most notes
        instrument = new_instru[np.argmax(instruments_data[:, 0])]
    else:
        instrument = new_instru[0]
    instrument.program = 0  # set the instrument to Grand Piano.
    midi.instruments = [instrument]  # put the instrument in the midi file
    error_msg += ' Nope. Removing blanks?'
    # remove first and last REMOVE_FIRST_AND_LAST seconds (avoid clapping and jingles)
    end_time = midi.get_end_time()
    if apply_filtering: midi = remove_beginning_and_end(midi, end_time)

    # remove beginning and end
    midi = remove_blanks_beginning_and_end(midi)
    error_msg += ' Nope. Saving?'

    # save midi file
    midi.write(processed_path)
    error_msg += ' Nope.'
    if verbose:
        extra = f' Saved to {processed_path}' if midi_path else ''
        print(' ' * (level + 2) + f'Success! {extra}')
    return processed_path, ''
    # except:
    #     if verbose: print(' ' * (level + 2) + 'Scrubbing failed.')
    #     if os.path.exists(processed_path):
    #         os.remove(processed_path)
    #     return None, error_msg + ' Yes.'
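A minimal usage sketch for the entry point above (the input path is a placeholder); note that with the except block commented out, failures currently raise instead of returning (None, error_msg):

from src.music.pipeline.midi2processed import midi2processed

processed_path, error = midi2processed('my_song.mid', verbose=True)  # 'my_song.mid' is a placeholder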
src/music/pipeline/music_pipeline.py
DELETED
@@ -1,86 +0,0 @@
from src.music.pipeline.url2audio import url2audio
from src.music.pipeline.audio2midi import audio2midi
from src.music.pipeline.midi2processed import midi2processed
from src.music.pipeline.processed2encoded import processed2encoded
from src.music.pipeline.encoded2rep import encoded2rep
from src.music.config import RANDOM_CROP, NB_AUG, FROM_URL_PATH
# from src.music.pipeline.synth2audio import AudioRecorder
# from src.music.pipeline.processed2handcodedrep import processed2handcodedrep
import time
import hashlib

VERBOSE = True
AUGMENTATION, NOISE_INJECTED = False, False
CROP = 10  # crop duration in seconds, applied before transcription

# AUDIO_RECORDER = AudioRecorder(place='home')

def encode_music(url=None,
                 audio_path=None,
                 midi_path=None,
                 processed_path=None,
                 record=False,
                 crop=CROP,
                 random_crop=RANDOM_CROP,
                 augmentation=AUGMENTATION,
                 noise_injection=NOISE_INJECTED,
                 apply_filtering=True,
                 nb_aug=NB_AUG,
                 level=0,
                 verbose=VERBOSE):
    if not record: assert url is not None or audio_path is not None or midi_path is not None or processed_path is not None
    init_time = time.time()
    error = ''
    try:
        if record:
            assert audio_path is None and midi_path is None
            if verbose: print(' ' * level + 'Processing music, recorded from mic.')
            audio_path = AUDIO_RECORDER.record_one()  # requires the AudioRecorder import above to be uncommented
            error = ''
        if processed_path is None:
            if midi_path is None:
                if audio_path is None:
                    if verbose and not record: print(' ' * level + 'Processing music, from audio source.')
                    init_t = time.time()
                    audio_path, _, error = url2audio(playlist_path=FROM_URL_PATH, video_url=url, verbose=verbose, level=level+2)
                    if verbose: print(' ' * (level + 4) + f'Audio downloaded in {int(time.time() - init_t)} seconds.')
                else:
                    if verbose and not record: print(' ' * level + 'Processing music, from midi source.')
                init_t = time.time()
                midi_path, error = audio2midi(audio_path, crop=crop, random_crop=random_crop, verbose=verbose, level=level+2)
                if verbose: print(' ' * (level + 4) + f'Audio transcribed to midi in {int(time.time() - init_t)} seconds.')
            init_t = time.time()
            processed_path, error = midi2processed(midi_path, apply_filtering=apply_filtering, verbose=verbose, level=level+2)
            if verbose: print(' ' * (level + 4) + f'Midi preprocessed in {int(time.time() - init_t)} seconds.')
        init_t = time.time()
        encoded_path, error = processed2encoded(processed_path, augmentation=augmentation, nb_aug=nb_aug, noise_injection=noise_injection, verbose=verbose, level=level+2)
        if verbose: print(' ' * (level + 4) + f'Midi encoded in {int(time.time() - init_t)} seconds.')
        init_t = time.time()
        representation_path, representation, error = encoded2rep(encoded_path, return_rep=True, level=level+2, verbose=verbose)
        if verbose: print(' ' * (level + 4) + f'Music representation computed in {int(time.time() - init_t)} seconds.')
        init_t = time.time()
        handcoded_rep_path, handcoded_rep, error = None, None, ''
        # handcoded_rep_path, handcoded_rep, error = processed2handcodedrep(processed_path, return_rep=True, level=level+2, verbose=verbose)
        if verbose: print(' ' * (level + 4) + f'Music handcoded representation computed in {int(time.time() - init_t)} seconds.')
        # assert handcoded_rep_path is not None and representation_path is not None
        all_paths = dict(url=url, audio_path=audio_path, midi_path=midi_path, processed_path=processed_path, encoded_path=encoded_path,
                         representation_path=representation_path, handcoded_rep_path=handcoded_rep_path)
        print('audio hash: ', hashlib.md5(open(audio_path, 'rb').read()).hexdigest())
        print('midi hash: ', hashlib.md5(open(midi_path, 'rb').read()).hexdigest())
        print('processed hash: ', hashlib.md5(open(processed_path, 'rb').read()).hexdigest())
        print('encoded hash: ', hashlib.md5(open(encoded_path, 'rb').read()).hexdigest())
        print('rep hash: ', hashlib.md5(open(representation_path, 'rb').read()).hexdigest())
        print("rep:", representation[:10])
        if verbose: print(' ' * (level + 2) + f'Music processed in {int(time.time() - init_time)} seconds.')
    except Exception as err:
        print(err, error)
        if verbose: print(' ' * (level + 2) + f'Music FAILED to process in {int(time.time() - init_time)} seconds.')
        representation = None
        handcoded_rep = None
        all_paths = dict()

    return representation, handcoded_rep, all_paths, error

if __name__ == '__main__':
    representation = encode_music(url="https://www.youtube.com/watch?v=a2LFVWBmoiw")[0]
    # representation = encode_music(record=True)[0]
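Because each stage is guarded by an `is None` check, encode_music can be entered at any point of the chain; a short sketch (the midi path is a placeholder, the url is the one from the __main__ block above):

from src.music.pipeline.music_pipeline import encode_music

# full pipeline, starting from a url
rep, handcoded, paths, err = encode_music(url="https://www.youtube.com/watch?v=a2LFVWBmoiw")
# skip download and transcription by entering with an existing midi file
rep, handcoded, paths, err = encode_music(midi_path="my_song.mid")  # placeholder path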
src/music/pipeline/processed2encoded.py
DELETED
@@ -1,52 +0,0 @@
import os
import sys
import numpy as np
import pickle
sys.path.append('../../')

from src.music.utils import get_out_path
from src.music.config import ALL_NOISE, ALL_AUGMENTATIONS, NB_AUG, NOISE_INJECTED
from src.music.utilities.midi_processor import encode_midi_structured, encode_midi_chunks_structured

nb_noise = ALL_NOISE.shape[0]
nb_aug = ALL_AUGMENTATIONS.shape[0]

def sample_augmentations(n):
    return ALL_AUGMENTATIONS[np.random.choice(np.arange(nb_aug), size=n, replace=False)]

def sample_noise():
    return ALL_NOISE[np.random.choice(np.arange(nb_noise))]

def processed2encoded(processed_path, encoded_path=None, augmentation=False, nb_aug=None, noise_injection=False, verbose=False, level=0):
    assert processed_path.split('.')[-1] in ['mid', 'midi']
    if not encoded_path:
        encoded_path, _, _ = get_out_path(in_path=processed_path, in_word='processed', out_word='encoded', out_extension='.pickle')

    if verbose: print(' ' * level + f'Encoding {processed_path}')
    if os.path.exists(encoded_path):
        if verbose: print(' ' * (level + 2) + 'Midi file is already encoded.')
        return encoded_path, ''

    if augmentation:
        assert isinstance(nb_aug, int)
    error_msg = 'Error in encoding. '
    try:
        error_msg = 'Error in encoding midi?'
        nb_noise = 1 if noise_injection else 0
        encoded_main, encoded_aug, encoded_noisy = encode_midi_structured(processed_path, nb_aug, nb_noise)

        # make sure augmentations are not out of bounds
        error_msg = ' Nope. Error in saving encoding?'
        with open(encoded_path, 'wb') as f:
            pickle.dump(dict(main=encoded_main, aug=encoded_aug, noisy=encoded_noisy), f)
        error_msg = ' Nope.'
        if verbose:
            extra = f' Saved to {encoded_path}' if encoded_path else ''
            print(' ' * (level + 2) + f'Success! {extra}')
        return encoded_path, ''
    except:
        if verbose: print(' ' * (level + 2) + 'Encoding failed.')
        if os.path.exists(encoded_path):
            os.remove(encoded_path)
        return None, error_msg + ' Yes.'
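Every stage derives its output path through get_out_path, which is imported from src.music.utils and not part of this diff. Judging only by its call sites (three return values, an in_word/out_word pair and a new extension), a plausible sketch of what it does — an assumption, not the actual implementation:

# Hypothetical reconstruction of get_out_path, inferred from its call sites above;
# the real implementation in src.music.utils may differ.
import os

def get_out_path(in_path, in_word, out_word, out_extension):
    folder, filename = os.path.split(in_path)
    stem = os.path.splitext(filename)[0]
    out_folder = folder.replace(in_word, out_word)                # mirror the directory tree
    out_name = stem.replace(in_word, out_word) + out_extension    # e.g. x_processed.mid -> x_encoded.pickle
    os.makedirs(out_folder, exist_ok=True)
    return os.path.join(out_folder, out_name), out_folder, out_name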
src/music/pipeline/processed2handcodedrep.py
DELETED
@@ -1,343 +0,0 @@
import numpy as np
from music21 import *
from music21.features import native, jSymbolic, DataSet
import pretty_midi as pm
from src.music.utils import get_out_path
from src.music.utilities.handcoded_rep_utilities.tht import tactus_hypothesis_tracker, tracker_analysis
from src.music.utilities.handcoded_rep_utilities.loudness import get_loudness, compute_total_loudness, amplitude2db, velocity2amplitude, get_db_of_equivalent_loudness_at_440hz, pitch2freq
import json
import os
environment.set('musicxmlPath', '/home/cedric/Desktop/test/')
midi_path = "/home/cedric/Documents/pianocktail/data/music/processed/doug_mckenzie_processed/allthethings_reharmonized_processed.mid"

FEATURES_DICT_SCORE = dict(
    # strongest pulse: measures how fast the melody is
    # stronger_pulse=jSymbolic.StrongestRhythmicPulseFeature,
    # weights of the two strongest pulses, measures rhythmic consistency: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#combinedstrengthoftwostrongestrhythmicpulsesfeature
    pulse_strength_two=jSymbolic.CombinedStrengthOfTwoStrongestRhythmicPulsesFeature,
    # weight of the strongest pulse, measures rhythmic consistency: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#combinedstrengthoftwostrongestrhythmicpulsesfeature
    pulse_strength=jSymbolic.StrengthOfStrongestRhythmicPulseFeature,
    # variability of attacks: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#variabilityoftimebetweenattacksfeature
)
FEATURES_DICT = dict(
    # bass register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature
    # bass_register=jSymbolic.ImportanceOfBassRegisterFeature,
    # high register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature
    # high_register=jSymbolic.ImportanceOfHighRegisterFeature,
    # medium register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature
    # medium_register=jSymbolic.ImportanceOfMiddleRegisterFeature,
    # number of common pitches (at least 9% of all): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#numberofcommonmelodicintervalsfeature
    # common_pitches=jSymbolic.NumberOfCommonPitchesFeature,
    # pitch class variety (used at least once): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#pitchvarietyfeature
    # pitch_variety=jSymbolic.PitchVarietyFeature,
    # attack_variability = jSymbolic.VariabilityOfTimeBetweenAttacksFeature,
    # staccato fraction: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#staccatoincidencefeature
    # staccato_score = jSymbolic.StaccatoIncidenceFeature,
    # mode analysis: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesNative.html
    av_melodic_interval=jSymbolic.AverageMelodicIntervalFeature,
    # chromatic motion: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#chromaticmotionfeature
    chromatic_motion=jSymbolic.ChromaticMotionFeature,
    # direction of motion (fraction of rising intervals): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#directionofmotionfeature
    motion_direction=jSymbolic.DirectionOfMotionFeature,
    # duration of melodic arcs: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#durationofmelodicarcsfeature
    melodic_arcs_duration=jSymbolic.DurationOfMelodicArcsFeature,
    # melodic arcs size: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#sizeofmelodicarcsfeature
    melodic_arcs_size=jSymbolic.SizeOfMelodicArcsFeature,
    # number of common melodic intervals (at least 9% of all): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#numberofcommonmelodicintervalsfeature
    # common_melodic_intervals = jSymbolic.NumberOfCommonMelodicIntervalsFeature,
    # https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#amountofarpeggiationfeature
    # arpeggiato=jSymbolic.AmountOfArpeggiationFeature,
)


def compute_beat_info(onsets):
    onsets_in_ms = np.array(onsets) * 1000

    tht = tactus_hypothesis_tracker.default_tht()
    trackers = tht(onsets_in_ms)
    top_hts = tracker_analysis.top_hypothesis(trackers, len(onsets_in_ms))
    beats = tracker_analysis.produce_beats_information(onsets_in_ms, top_hts, adapt_period=250 is not None,
                                                       adapt_phase=tht.eval_f, max_delta_bpm=250, avoid_quickturns=None)
    tempo = 1 / (np.mean(np.diff(beats)) / 1000) * 60  # in bpm
    conf_values = tracker_analysis.tht_tracking_confs(trackers, len(onsets_in_ms))
    pulse_clarity = np.mean(np.array(conf_values), axis=0)[1]
    return tempo, pulse_clarity

def dissonance_score(A):
    """
    Given a piano-roll indicator matrix representation of a musical work (128 pitches x beats),
    return the dissonance as a function of beats.
    Input:
        A - 128 x beats indicator matrix of MIDI pitch numbers
    """
    freq_rats = np.arange(1, 7)       # Harmonic series ratios
    amps = np.exp(-.5 * freq_rats)    # Partial amplitudes
    F0 = 8.1757989156                 # base frequency for MIDI (note 0)
    diss = []                         # List for dissonance values
    thresh = 1e-3
    for beat in A.T:
        idx = np.where(beat > thresh)[0]
        if len(idx):
            freqs, mags = [], []      # lists for frequencies, mags
            for i in idx:
                freqs.extend(F0 * 2 ** (i / 12.0) * freq_rats)
                mags.extend(amps)
            freqs = np.array(freqs)
            mags = np.array(mags)
            sortIdx = freqs.argsort()
            d = compute_dissonance(freqs[sortIdx], mags[sortIdx])
            diss.extend([d])
        else:
            diss.extend([-1])         # Null value
    diss = np.array(diss)
    return diss[np.where(diss != -1)]

def compute_dissonance(freqs, amps):
    """
    From https://notebook.community/soundspotter/consonance/week1_consonance
    Compute dissonance between partials with center frequencies in freqs and amplitudes in amps,
    using a model of critical bandwidth. Based on Sethares "Tuning, Timbre, Spectrum, Scale" (1998) after Plomp and Levelt (1965).

    inputs:
        freqs - list of partial frequencies
        amps - list of corresponding amplitudes [default, uniformly 1]
    """
    b1, b2, s1, s2, c1, c2, Dstar = (-3.51, -5.75, 0.0207, 19.96, 5, -5, 0.24)
    f = np.array(freqs)
    a = np.array(amps)
    idx = np.argsort(f)
    f = f[idx]
    a = a[idx]
    N = f.size
    D = 0
    for i in range(1, N):
        Fmin = f[0: N - i]
        S = Dstar / (s1 * Fmin + s2)
        Fdif = f[i: N] - f[0: N - i]
        am = a[i: N] * a[0: N - i]
        Dnew = am * (c1 * np.exp(b1 * S * Fdif) + c2 * np.exp(b2 * S * Fdif))
        D += Dnew.sum()
    return D


def store_new_midi(notes, out_path):
    midi = pm.PrettyMIDI()
    midi.instruments.append(pm.Instrument(program=0, is_drum=False))
    midi.instruments[0].notes = notes
    midi.write(out_path)
    return midi


def processed2handcodedrep(midi_path, handcoded_rep_path=None, crop=30, verbose=False, save=True, return_rep=False, level=0):
    try:
        if not handcoded_rep_path:
            handcoded_rep_path, _, _ = get_out_path(in_path=midi_path, in_word='processed', out_word='handcoded_reps', out_extension='.mid')
        features = dict()
        if verbose: print(' ' * level + 'Computing handcoded representations')
        if os.path.exists(handcoded_rep_path):
            with open(handcoded_rep_path.replace('.mid', '.json'), 'r') as f:
                features = json.load(f)
            rep = np.array([features[k] for k in sorted(features.keys())])
            if rep.size == 49:
                os.remove(handcoded_rep_path)
            else:
                if verbose: print(' ' * (level + 2) + 'Already computed.')
                if return_rep:
                    return handcoded_rep_path, np.array([features[k] for k in sorted(features.keys())]), ''
                else:
                    return handcoded_rep_path, ''
        midi = pm.PrettyMIDI(midi_path)  # load midi with pretty midi
        notes = midi.instruments[0].notes  # get notes
        notes.sort(key=lambda x: (x.start, x.pitch))  # sort notes per start and pitch
        onsets, offsets, pitches, durations, velocities = [], [], [], [], []
        n_notes_cropped = len(notes)
        for i_n, n in enumerate(notes):
            onsets.append(n.start)
            offsets.append(n.end)
            durations.append(n.end - n.start)
            pitches.append(n.pitch)
            velocities.append(n.velocity)
            if crop is not None:  # find how many notes to keep
                if n.start > crop and n_notes_cropped == len(notes):
                    n_notes_cropped = i_n
                    break
        notes = notes[:n_notes_cropped]
        midi = store_new_midi(notes, handcoded_rep_path)
        # pianoroll = midi.get_piano_roll()  # extract piano roll representation

        # compute loudness
        amplitudes = velocity2amplitude(np.array(velocities))
        power_dbs = amplitude2db(amplitudes)
        frequencies = pitch2freq(np.array(pitches))
        loudness_values = get_loudness(power_dbs, frequencies)
        # compute average perceived loudness
        # for each power, compute the loudness, then compute the power such that the loudness at 440 Hz would be equivalent.
        # equivalent_powers_dbs = get_db_of_equivalent_loudness_at_440hz(frequencies, power_dbs)
        # then get the corresponding amplitudes
        # equivalent_amplitudes = 10 ** (equivalent_powers_dbs / 20)
        # now use an amplitude model across the sample to compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz
        # av_total_loudness, std_total_loudness = compute_total_loudness(equivalent_amplitudes, onsets, offsets)

        end_time = np.max(offsets)
        start_time = notes[0].start

        score = converter.parse(handcoded_rep_path)
        score.chordify()
        notes_without_chords = stream.Stream(score.flatten().getElementsByClass('Note'))

        velocities_wo_chords, pitches_wo_chords, amplitudes_wo_chords, dbs_wo_chords = [], [], [], []
        frequencies_wo_chords, loudness_values_wo_chords, onsets_wo_chords, offsets_wo_chords, durations_wo_chords = [], [], [], [], []
        for i_n in range(len(notes_without_chords)):
            n = notes_without_chords[i_n]
            velocities_wo_chords.append(n.volume.velocity)
            pitches_wo_chords.append(n.pitch.midi)
            onsets_wo_chords.append(n.offset)
            offsets_wo_chords.append(onsets_wo_chords[-1] + n.seconds)
            durations_wo_chords.append(n.seconds)

        amplitudes_wo_chords = velocity2amplitude(np.array(velocities_wo_chords))
        power_dbs_wo_chords = amplitude2db(amplitudes_wo_chords)
        frequencies_wo_chords = pitch2freq(np.array(pitches_wo_chords))
        loudness_values_wo_chords = get_loudness(power_dbs_wo_chords, frequencies_wo_chords)
        # compute average perceived loudness
        # for each power, compute the loudness, then compute the power such that the loudness at 440 Hz would be equivalent.
        # equivalent_powers_dbs_wo_chords = get_db_of_equivalent_loudness_at_440hz(frequencies_wo_chords, power_dbs_wo_chords)
        # then get the corresponding amplitudes
        # equivalent_amplitudes_wo_chords = 10 ** (equivalent_powers_dbs_wo_chords / 20)
        # now use an amplitude model across the sample to compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz
        # av_total_loudness_wo_chords, std_total_loudness_wo_chords = compute_total_loudness(equivalent_amplitudes_wo_chords, onsets_wo_chords, offsets_wo_chords)

        ds = DataSet(classLabel='test')
        f = list(FEATURES_DICT.values())
        ds.addFeatureExtractors(f)
        ds.addData(notes_without_chords)
        ds.process()
        for k, f in zip(FEATURES_DICT.keys(), ds.getFeaturesAsList()[0][1:-1]):
            features[k] = f

        ds = DataSet(classLabel='test')
        f = list(FEATURES_DICT_SCORE.values())
        ds.addFeatureExtractors(f)
        ds.addData(score)
        ds.process()
        for k, f in zip(FEATURES_DICT_SCORE.keys(), ds.getFeaturesAsList()[0][1:-1]):
            features[k] = f

        # # # # #
        # Register features
        # # # # #

        # features['av_pitch'] = np.mean(pitches)
        # features['std_pitch'] = np.std(pitches)
        # features['range_pitch'] = np.max(pitches) - np.min(pitches)  # aka ambitus

        # # # # #
        # Rhythmic features
        # # # # #

        # tempo, pulse_clarity = compute_beat_info(onsets[:n_notes_cropped])
        # features['pulse_clarity'] = pulse_clarity
        # features['tempo'] = tempo
        features['tempo_pm'] = midi.estimate_tempo()

        # # # # #
        # Temporal features
        # # # # #

        features['av_duration'] = np.mean(durations)
        # features['std_duration'] = np.std(durations)
        features['note_density'] = len(notes) / (end_time - start_time)
        # intervals_wo_chords = np.diff(onsets_wo_chords)
        # articulations = [max((i - d) / i, 0) for d, i in zip(durations_wo_chords, intervals_wo_chords) if i != 0]
        # features['articulation'] = np.mean(articulations)
        # features['av_duration_wo_chords'] = np.mean(durations_wo_chords)
        # features['std_duration_wo_chords'] = np.std(durations_wo_chords)

        # # # # #
        # Dynamics features
        # # # # #
        features['av_velocity'] = np.mean(velocities)
        features['std_velocity'] = np.std(velocities)
        features['av_loudness'] = np.mean(loudness_values)
        # features['std_loudness'] = np.std(loudness_values)
        features['range_loudness'] = np.max(loudness_values) - np.min(loudness_values)
        # features['av_integrated_loudness'] = av_total_loudness
        # features['std_integrated_loudness'] = std_total_loudness
        # features['av_velocity_wo_chords'] = np.mean(velocities_wo_chords)
        # features['std_velocity_wo_chords'] = np.std(velocities_wo_chords)
        # features['av_loudness_wo_chords'] = np.mean(loudness_values_wo_chords)
        # features['std_loudness_wo_chords'] = np.std(loudness_values_wo_chords)
        features['range_loudness_wo_chords'] = np.max(loudness_values_wo_chords) - np.min(loudness_values_wo_chords)
        # features['av_integrated_loudness'] = av_total_loudness_wo_chords
        # features['std_integrated_loudness'] = std_total_loudness_wo_chords
        # indices_with_intervals = np.where(intervals_wo_chords > 0.01)
        # features['av_loudness_change'] = np.mean(np.abs(np.diff(np.array(loudness_values_wo_chords)[indices_with_intervals])))  # accentuation
        # features['av_velocity_change'] = np.mean(np.abs(np.diff(np.array(velocities_wo_chords)[indices_with_intervals])))  # accentuation

        # # # # #
        # Harmony features
        # # # # #

        # get major_minor score: https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisDiscrete.html
        music_analysis = score.analyze('key')
        major_score = None
        minor_score = None
        for a in [music_analysis] + music_analysis.alternateInterpretations:
            if 'major' in a.__str__() and a.correlationCoefficient > 0:
                major_score = a.correlationCoefficient
            elif 'minor' in a.__str__() and a.correlationCoefficient > 0:
                minor_score = a.correlationCoefficient
            if major_score is not None and minor_score is not None:
                break
        features['major_minor'] = major_score / (major_score + minor_score)
        features['tonal_certainty'] = music_analysis.tonalCertainty()
        # features['av_sensory_dissonance'] = np.mean(dissonance_score(pianoroll))
        # TODO: only works for chords, do something with melodic intervals, like the proportion that is not a third, fifth or seventh?

        # # # # #
        # Interval features
        # # # # #
        # https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisPatel.html
        # features['melodic_interval_variability'] = analysis.patel.melodicIntervalVariability(notes_without_chords)

        # # # # #
        # Surprise features
        # # # # #
        # https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisMetrical.html
        # analysis.metrical.thomassenMelodicAccent(notes_without_chords)
        # melodic_accents = [n.melodicAccent for n in notes_without_chords]
        # features['melodic_accent'] = np.mean(melodic_accents)

        if save:
            for k, v in features.items():
                features[k] = float(features[k])
            with open(handcoded_rep_path.replace('.mid', '.json'), 'w') as f:
                json.dump(features, f)
        else:
            print(features)
            if os.path.exists(handcoded_rep_path):
                os.remove(handcoded_rep_path)
        if verbose: print(' ' * (level + 2) + 'Success.')
        if return_rep:
            return handcoded_rep_path, np.array([features[k] for k in sorted(features.keys())]), ''
        else:
            return handcoded_rep_path, ''
    except:
        if verbose: print(' ' * (level + 2) + 'Failed.')
        if return_rep:
            return None, None, 'error'
        else:
            return None, 'error'


if __name__ == '__main__':
    processed2handcodedrep(midi_path, '/home/cedric/Desktop/test.mid', save=False)
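A small self-contained example of the Sethares-style dissonance used above, rebuilding the partials for a two-note dyad the same way dissonance_score does (a sketch; the pitches and the local copy of the formula are illustrative):

import numpy as np

def sethares_dissonance(freqs, amps):
    # same pairwise formula as compute_dissonance above
    b1, b2, s1, s2, c1, c2, Dstar = (-3.51, -5.75, 0.0207, 19.96, 5, -5, 0.24)
    idx = np.argsort(freqs)
    f, a = np.asarray(freqs)[idx], np.asarray(amps)[idx]
    D = 0.0
    for i in range(1, f.size):
        Fmin = f[:f.size - i]
        S = Dstar / (s1 * Fmin + s2)
        Fdif = f[i:] - f[:f.size - i]
        am = a[i:] * a[:f.size - i]
        D += (am * (c1 * np.exp(b1 * S * Fdif) + c2 * np.exp(b2 * S * Fdif))).sum()
    return D

F0 = 8.1757989156          # MIDI note 0, as above
rats = np.arange(1, 7)     # first six harmonics
amps = np.exp(-.5 * rats)  # partial amplitudes

def partials(pitch):
    return F0 * 2 ** (pitch / 12.0) * rats

for interval, name in [(1, 'minor second'), (7, 'perfect fifth')]:
    f = np.concatenate([partials(60), partials(60 + interval)])
    a = np.concatenate([amps, amps])
    print(name, sethares_dissonance(f, a))  # the minor second scores higher (more dissonant)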
src/music/pipeline/synth2audio.py
DELETED
@@ -1,170 +0,0 @@
import pynput
import sys
sys.path.append('../../')
from src.music.config import SYNTH_RECORDED_AUDIO_PATH, RATE_AUDIO_SAVE
from datetime import datetime
import numpy as np
import os
import wave

from ctypes import *
from contextlib import contextmanager
import pyaudio

ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p)

def py_error_handler(filename, line, function, err, fmt):
    pass
c_error_handler = ERROR_HANDLER_FUNC(py_error_handler)

@contextmanager
def noalsaerr():
    asound = cdll.LoadLibrary('libasound.so')
    asound.snd_lib_error_set_handler(c_error_handler)
    yield
    asound.snd_lib_error_set_handler(None)

global KEY_PRESSED
KEY_PRESSED = None

def on_press(key):
    global KEY_PRESSED
    try:
        KEY_PRESSED = key.name
    except:
        pass

def on_release(key):
    global KEY_PRESSED
    KEY_PRESSED = None


def is_pressed(key):
    global KEY_PRESSED
    return KEY_PRESSED == key

# keyboard listener
listener = pynput.keyboard.Listener(on_press=on_press, on_release=on_release)
listener.start()

LEN_RECORDINGS = 40
class AudioRecorder:
    def __init__(self, chunk=2**10, rate=44100, place='', len_recording=LEN_RECORDINGS, drop_beginning=0.5):
        self.chunk = chunk
        self.rate = rate
        with noalsaerr():
            self.audio = pyaudio.PyAudio()
        self.channels = 1
        self.format = pyaudio.paInt16
        self.stream = self.audio.open(format=self.format,
                                      channels=self.channels,
                                      rate=rate,
                                      input=True,
                                      frames_per_buffer=chunk)
        self.stream.stop_stream()
        self.drop_beginning_chunks = int(drop_beginning * self.rate / self.chunk)
        self.place = place
        self.len_recordings = len_recording

    def get_filename(self):
        now = datetime.now()
        return self.place + '_' + now.strftime("%b_%d_%Y_%Hh%Mm%Ss") + '.mp3'

    def read_last_chunk(self):
        return self.stream.read(self.chunk)

    def live_read(self):
        if self.stream.is_stopped():
            self.stream.start_stream()
        i = 0
        while not is_pressed('esc'):
            data = np.frombuffer(self.stream.read(self.chunk), dtype=np.int16)
            peak = np.average(np.abs(data)) * 2
            bars = "#" * int(50 * peak / 2 ** 16)
            i += 1
            print("%04d %05d %s" % (i, peak, bars))
        self.stream.stop_stream()

    def record_next_N_seconds(self, n=None, saving_path=None):
        if saving_path is None:
            saving_path = SYNTH_RECORDED_AUDIO_PATH + self.get_filename()
        if n is None:
            n = self.len_recordings

        print(f'Recording the next {n} secs.'
              # f'\n\tRecording starts when the first key is pressed;'
              f'\n\tPress Enter to end the recording;'
              f'\n\tPress BackSpace (<--) to cancel the recording;'
              f'\n\tSaving to {saving_path}')
        try:
            self.stream.start_stream()
            backspace_pressed = False
            self.recording = []
            i_chunk = 0
            while not is_pressed('enter') and self.chunk / self.rate * i_chunk < n:
                self.recording.append(self.read_last_chunk())
                i_chunk += 1
                if is_pressed('backspace'):
                    backspace_pressed = True
                    print('\n \t--> Recording cancelled! (you pressed BackSpace)')
                    break
            self.stream.stop_stream()

            # save the file
            if not backspace_pressed:
                self.recording = self.recording[self.drop_beginning_chunks:]  # drop first chunks to remove keyboard sound
                with wave.open(saving_path[:-4] + '.wav', 'wb') as waveFile:
                    waveFile.setnchannels(self.channels)
                    waveFile.setsampwidth(self.audio.get_sample_size(self.format))
                    waveFile.setframerate(self.rate)
                    waveFile.writeframes(b''.join(self.recording))
                os.system(f'ffmpeg -i "{saving_path[:-4] + ".wav"}" -vn -loglevel panic -y -ac 1 -ar {int(RATE_AUDIO_SAVE)} -b:a 320k "{saving_path}" ')
                os.remove(saving_path[:-4] + '.wav')
                print(f'\n--> Recording saved, duration: {self.chunk / self.rate * i_chunk:.2f} secs.')
            return saving_path
        except:
            print('\n --> The recording failed.')
            return None

    def record_one(self):
        ready_msg = False
        print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)')
        while True:
            if not ready_msg:
                print('-------\nReady to record!')
                print('Press space to start a recording\n')
                ready_msg = True

            if is_pressed('space'):
                saving_path = self.record_next_N_seconds()
                break
        return saving_path

    def run(self):
        # with pynput.Listener(
        #         on_press=self.on_press) as listener:
        #     listener.join()
        ready_msg = False
        print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)')
        while True:
            if not ready_msg:
                print('-------\nReady to record!')
                print('Press space to start a recording\n')
                ready_msg = True

            if is_pressed('space'):
                self.record_next_N_seconds()
                ready_msg = False
            if is_pressed('esc'):
                print('End of the recording session. See you soon!')
                self.close()
                break

    def close(self):
        self.stream.close()
        self.audio.terminate()

if __name__ == '__main__':
    audio_recorder = AudioRecorder(place='home')
    audio_recorder.record_one()
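The recording loop above measures elapsed time in chunks (self.chunk / self.rate * i_chunk); a quick check of that arithmetic with the defaults used here:

chunk, rate = 2 ** 10, 44100      # defaults from AudioRecorder.__init__
chunk_dur = chunk / rate          # about 23 ms of audio per chunk
n = 40                            # LEN_RECORDINGS
print(f'{chunk_dur * 1000:.1f} ms per chunk, ~{int(n / chunk_dur)} chunks for a {n} s recording')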
src/music/pipeline/synth2midi.py
DELETED
@@ -1,146 +0,0 @@
import mido
mido.set_backend('mido.backends.pygame')
from mido import Message, MidiFile, MidiTrack
import time
import pynput
import sys
sys.path.append('../../')
from src.music.config import SYNTH_RECORDED_MIDI_PATH
from datetime import datetime

# TODO: debug this with other cable, keyboard and sound card
global KEY_PRESSED
KEY_PRESSED = None

def on_press(key):
    global KEY_PRESSED
    try:
        KEY_PRESSED = key.name
    except:
        pass

def on_release(key):
    global KEY_PRESSED
    KEY_PRESSED = None


def is_pressed(key):
    global KEY_PRESSED
    return KEY_PRESSED == key

# keyboard listener
listener = pynput.keyboard.Listener(on_press=on_press, on_release=on_release)
listener.start()

LEN_MIDI_RECORDINGS = 30
class MidiRecorder:
    def __init__(self, place='', len_midi_recordings=LEN_MIDI_RECORDINGS):
        self.place = place
        self.len_midi_recordings = len_midi_recordings
        self.port = mido.open_input(mido.get_input_names()[0])

    def get_filename(self):
        now = datetime.now()
        return self.place + '_' + now.strftime("%b_%d_%Y_%Hh%Mm%Ss") + '.mid'

    def read_last_midi_msgs(self):
        return list(self.port.iter_pending())

    def live_read(self):
        while not is_pressed('esc'):
            for msg in self.read_last_midi_msgs():
                print(msg)

    def check_if_recording_started(self, msgs, t_init):
        started = False
        if len(msgs) > 0:
            for m in msgs:
                if m.type == 'note_on':
                    started = True
                    t_init = time.time()
        return started, t_init

    def create_empty_midi(self):
        mid = MidiFile()
        track = MidiTrack()
        mid.tracks.append(track)
        track.append(Message('program_change', program=0, time=0))
        return mid, track

    def record_next_N_seconds(self, n=None, saving_path=None):
        if saving_path is None:
            saving_path = SYNTH_RECORDED_MIDI_PATH + self.get_filename()  # the config constant imported above
        if n is None:
            n = self.len_midi_recordings

        print(f'Recording the next {n} secs.'
              f'\n\tRecording starts when the first key is pressed;'
              f'\n\tPress Enter to end the recording;'
              f'\n\tPress BackSpace (<--) to cancel the recording;'
              f'\n\tSaving to {saving_path}')
        try:
            mid, track = self.create_empty_midi()
            started = False
            backspace_pressed = False
            t_init = time.time()
            while not is_pressed('enter') and (time.time() - t_init) < n:
                msgs = self.read_last_midi_msgs()
                if not started:
                    started, t_init = self.check_if_recording_started(msgs, t_init)
                    if started:
                        print("\n\t--> First note pressed, it's on!")
                for m in msgs:
                    print(m)
                    if m.type == 'note_on' and m.velocity == 0:
                        m_off = Message(type='note_off', velocity=127, note=m.note, channel=m.channel, time=m.time)
                        track.append(m_off)
                    track.append(m)
                if is_pressed('backspace'):
                    backspace_pressed = True
                    print('\n \t--> Recording cancelled! (you pressed BackSpace)')
                    break
            # save the file
            if not backspace_pressed and len(mid.tracks[0]) > 0:
                mid.save(saving_path)
                print(f'\n--> Recording saved, duration: {mid.length:.2f} secs, {len(mid.tracks[0])} events.')
        except:
            print('\n --> The recording failed.')


    def run(self):
        # with pynput.Listener(
        #         on_press=self.on_press) as listener:
        #     listener.join()
        ready_msg = False
        print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)')
        while True:
            if not ready_msg:
                print('-------\nReady to record!')
                print('Press space to start a recording\n')
                ready_msg = True

            if is_pressed('space'):
                self.record_next_N_seconds()
                ready_msg = False
            if is_pressed('esc'):
                print('End of the recording session. See you soon!')
                break


midi_recorder = MidiRecorder(place='home')
midi_recorder.live_read()
# midi_recorder.run()


# try:
#     controls[msg.control] = msg.value
# except:
#     notes.append(msg.note)
# port = mido.open_input()
# while True:
#     for msg in port.iter_pending():
#         print(msg)
#
# print('start pause')
# time.sleep(5)
# print('stop pause')
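Many keyboards signal note releases as note_on messages with velocity 0; the recording loop above rewrites those as explicit note_off events before appending. A minimal sketch of that conversion in isolation:

from mido import Message

msg = Message('note_on', note=60, velocity=0, time=12)  # a release sent as note_on with velocity 0
if msg.type == 'note_on' and msg.velocity == 0:
    off = Message('note_off', velocity=127, note=msg.note, channel=msg.channel, time=msg.time)
    print(off)  # note_off channel=0 note=60 velocity=127 time=12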
src/music/pipeline/url2audio.py
DELETED
@@ -1,119 +0,0 @@
import os
from pytube import YouTube
from src.music.utils import RATE_AUDIO_SAVE, slugify
from src.music.config import MAX_LEN

# define filtering keywords
start_keywords = [' ', '(', ',', ':']
end_keywords = [')', ' ', '.', ',', '!', ':']
def get_all_keywords(k):
    all_keywords = []
    for s in start_keywords:
        for e in end_keywords:
            all_keywords.append(s + k + e)
    return all_keywords
filtered_keywords = ['duet', 'duo', 'quartet', 'orchestre', 'orchestra',
                     'quintet', 'sixtet', 'septet', 'octet', 'backing track', 'accompaniment', 'string',
                     'contrebrasse', 'drums', 'guitar'] + get_all_keywords('live') + get_all_keywords('trio')

# list of playlists for which no filtering should occur on keywords (they were prefiltered already, it's supposed to be only piano)
playlist_and_channel_not_to_filter = ["https://www.youtube.com/c/MySheetMusicTranscriptions",
                                      "https://www.youtube.com/c/PianoNotion",
                                      "https://www.youtube.com/c/PianoNotion",
                                      "https://www.youtube.com/watch?v=3F5glYefwio&list=PLFv3ZQw-ZPxi2DH3Bau7lBC5K6zfPJZxc",
                                      "https://www.youtube.com/user/Mercuziopianist",
                                      "https://www.youtube.com/channel/UCy6NPK6-xeX7MZLaMARa5qg",
                                      "https://www.youtube.com/channel/UCKMRNFV2dWTWIJnymtA9_Iw",
                                      "https://www.youtube.com/c/pianomaedaful",
                                      "https://www.youtube.com/c/FrancescoParrinoMusic",
                                      "https://www.youtube.com/c/itsremco"]
playlist_ok = "https://www.youtube.com/watch?v=sYv_vk6bJtk&list=PLO9E3V4rGLD9-0BEd3t-AvvMcVF1zOJPj"


def should_be_filtered(title, length, url, playlist_url, max_length):
    to_filter = False
    reason = ''
    lower_title = title.lower()
    if length > max_length:
        reason += f'it is too long (>{max_length/60:.1f} min), '
        to_filter = True
    if any([f in lower_title for f in filtered_keywords]) \
            and playlist_url not in playlist_and_channel_not_to_filter \
            and 'to live' not in lower_title and 'alive' not in lower_title \
            and url not in playlist_ok:
        reason += 'it contains a filtered keyword, '
        to_filter = True
    return to_filter, reason

def convert_mp4_to_mp3(path, verbose=True):
    if verbose: print(f"Converting mp4 to mp3, in {path}\n")
    assert '.mp4' == path[-4:]
    os.system(f'ffmpeg -i "{path}" -loglevel panic -y -ac 1 -ar {int(RATE_AUDIO_SAVE)} "{path[:-4] + ".mp3"}" ')
    os.remove(path)
    if verbose: print('\tDone.')

def pipeline_video(video, playlist_path, filename):
    # extract best stream for this video
    stream, kbps = extract_best_stream(video.streams)
    stream.download(output_path=playlist_path, filename=filename + '.mp4')
    # convert to mp3
    convert_mp4_to_mp3(playlist_path + filename + '.mp4', verbose=False)
    return kbps

def extract_best_stream(streams):
    # extract best audio stream
    stream_out = streams.get_audio_only()
    kbps = int(stream_out.abr[:-4])
    return stream_out, kbps

def get_title_and_length(video):
    title = video.title
    filename = slugify(title)
    length = video.length
    return title, filename, length, video.metadata


def url2audio(playlist_path, video_url=None, video=None, playlist_url='', apply_filters=False, verbose=False, level=0):
    assert video_url is not None or video is not None, 'needs either video or url'
    error_msg = 'Error in loading video?'
    audio_path = None  # initialized so that the cleanup in the except branch cannot hit a NameError
    try:
        if not video:
            video = YouTube(video_url)
        error_msg += ' Nope. In extracting title and length?'
        title, filename, length, video_meta_data = get_title_and_length(video)
        if apply_filters:
            to_filter, reason = should_be_filtered(title, length, video_url, playlist_url, MAX_LEN)
        else:
            to_filter = False
        if not to_filter:
            audio_path = playlist_path + filename + ".mp3"
            if verbose: print(' ' * level + f'Downloading {title}, Url: {video_url}')
            if not os.path.exists(audio_path):
                if length > MAX_LEN and verbose: print(' ' * (level + 2) + f'Long video ({int(length/60)} min), will be cut after {int(MAX_LEN/60)} min.')
                error_msg += ' Nope. In pipeline video?'
                kbps = None
                for _ in range(5):
                    try:
                        kbps = pipeline_video(video, playlist_path, filename)
                        break
                    except:
                        pass
                assert kbps is not None
                error_msg += ' Nope. In dict filling?'
                data = dict(title=title, filename=filename, length=length, kbps=kbps, url=video_url, meta=video_meta_data)
                error_msg += ' Nope. '
            else:
                if verbose: print(' ' * (level + 2) + 'Song already downloaded')
                data = None
            return audio_path, data, ''
        else:
            return None, None, f'Filtered because {reason}'
    except:
        if verbose: print(' ' * (level + 2) + f'Download failed with error {error_msg}')
        if audio_path is not None and os.path.exists(audio_path):
            os.remove(audio_path)
        return None, None, error_msg + ' Yes.'
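A quick illustration of the keyword filter above (a sketch assuming should_be_filtered from this module is in scope; the title and url are made up):

# 'live' and 'trio' are filtered only when wrapped by the start/end characters,
# so 'alive' or 'to live' survive while '(live' or ' trio)' do not.
to_filter, reason = should_be_filtered('Autumn Leaves (live trio)', length=200,
                                       url='https://example.com/x', playlist_url='', max_length=600)
print(to_filter, reason)  # True 'it contains a filtered keyword, '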
src/music/representation_analysis/__init__.py
DELETED
File without changes
src/music/representation_analysis/analyze_rep.py
DELETED
@@ -1,146 +0,0 @@
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE
from src.music.utils import get_all_subfiles_with_extension
import matplotlib.pyplot as plt
import pickle
import random
# import umap
import os
from shutil import copy
# install numba==0.51.2
# keyword = '32_represented'
# rep_path = f"/home/cedric/Documents/pianocktail/data/music/{keyword}/"
# plot_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/plots/'
# neighbors_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/neighbors/'
interpolation_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/interpolation/'
keyword = 'b256_r128_represented'
rep_path = f"/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/{keyword}/"
plot_path = '/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/plots/'
neighbors_path = f'/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/neighbors_{keyword}/'
os.makedirs(neighbors_path, exist_ok=True)

def extract_all_reps(rep_path):
    all_rep_path = get_all_subfiles_with_extension(rep_path, max_depth=3, extension='.txt', current_depth=0)
    all_data = []
    new_all_rep_path = []
    for i_r, r in enumerate(all_rep_path):
        if 'mean_std' not in r:
            all_data.append(np.loadtxt(r))
            assert len(all_data[-1]) == 128
            new_all_rep_path.append(r)
    data = np.array(all_data)
    to_save = dict(reps=data,
                   paths=new_all_rep_path)
    with open(rep_path + 'music_reps_unnormalized.pickle', 'wb') as f:
        pickle.dump(to_save, f)
    for sample_size in [100, 200, 500, 1000, 2000, 5000]:
        if sample_size < len(data):
            inds = np.arange(len(data))
            np.random.shuffle(inds)
            to_save = dict(reps=data[inds[:sample_size]],
                           paths=np.array(all_rep_path)[inds[:sample_size]])
            with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'wb') as f:
                pickle.dump(to_save, f)

def load_reps(rep_path, sample_size=None):
    if sample_size:
        with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
            data = pickle.load(f)
    else:
        with open(rep_path + f'music_reps_unnormalized.pickle', 'rb') as f:
            data = pickle.load(f)
    reps = data['reps']
    # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
    playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
    n_data, dim_data = reps.shape
    return reps, data['paths'], playlists, n_data, dim_data


def plot_tsne(reps, playlist_indexes, playlist_colors):
    tsne_reps = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(reps)
    plt.figure()
    keys_to_print = ['spot_piano_solo_blues', 'itsremco', 'piano_solo_classical',
                     'piano_solo_pop', 'piano_jazz_unspecified', 'spot_piano_solo_jazz_1', 'piano_solo_jazz_latin']
    keys_to_print = playlist_indexes.keys()
    for k in sorted(keys_to_print):
        if k in playlist_indexes.keys():
            # plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, label=k, alpha=0.5)
            plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
    plt.legend()
    plt.savefig(plot_path + f'tsne_{keyword}.png')
    fig = plt.gcf()
    plt.close(fig)
    # umap_reps = umap.UMAP().fit_transform(reps)
    # plt.figure()
    # for k in sorted(keys_to_print):
    #     if k in playlist_indexes.keys():
    #         plt.scatter(umap_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
    # plt.legend()
    # plt.savefig(plot_path + f'umap_{keyword}.png')
    # fig = plt.gcf()
    # plt.close(fig)
    return tsne_reps  #, umap_reps

def get_playlist_indexes(playlists):
    playlist_indexes = dict()
    for i in range(n_data):
        if playlists[i] not in playlist_indexes.keys():
            playlist_indexes[playlists[i]] = [i]
        else:
            playlist_indexes[playlists[i]].append(i)
    for k in playlist_indexes.keys():
        playlist_indexes[k] = np.array(playlist_indexes[k])
    set_playlists = sorted(set(playlists))
    playlist_colors = dict(zip(set_playlists, ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(len(set_playlists))]))
    return set_playlists, playlist_indexes, playlist_colors

def convert_rep_path_midi_path(rep_path):
    # playlist = rep_path.split(f'_{keyword}/')[0].split('/')[-1]
    playlist = rep_path.split(f'{keyword}')[1].split('/')[1].replace('_represented', '')
    midi_path = "/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/processed/" + playlist + '_processed/'
    filename = rep_path.split(f'{keyword}')[1].split(f'/')[2].split('_represented.txt')[0] + '_processed.mid'
    # filename = rep_path.split(f'_{keyword}/')[-1].split(f'_{keyword}')[0] + '_processed.mid'
    midi_path = midi_path + filename
    assert os.path.exists(midi_path), midi_path
    return midi_path

def sample_nn(reps, rep_paths, playlists, n_samples=30):
    nn_model = NearestNeighbors(n_neighbors=6, metric='cosine')
    nn_model.fit(reps)
    indexes = np.arange(len(reps))
    np.random.shuffle(indexes)
    for i, ind in enumerate(indexes[:n_samples]):
        out = nn_model.kneighbors(reps[ind].reshape(1, -1))[1][0][1:]
        midi_path = convert_rep_path_midi_path(rep_paths[ind])
        copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[ind]}_target.mid')
        for i_n, neighbor in enumerate(out):
            midi_path = convert_rep_path_midi_path(rep_paths[neighbor])
            copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[neighbor]}_neighbor_{i_n}.mid')
|
120 |
-
|
121 |
-
def interpolate(reps, rep_paths, path):
|
122 |
-
files = os.listdir(path)
|
123 |
-
bounds = [f for f in files if 'interpolation' not in f]
|
124 |
-
b_reps = [np.loadtxt(path + f) for f in bounds]
|
125 |
-
nn_model = NearestNeighbors(n_neighbors=6)
|
126 |
-
nn_model.fit(reps)
|
127 |
-
reps = [alpha * b_reps[0] + (1 - alpha) * b_reps[1] for alpha in np.linspace(0, 1., 5)]
|
128 |
-
copy(convert_rep_path_midi_path(path + bounds[1]), path + 'interpolation_0.mid')
|
129 |
-
copy(convert_rep_path_midi_path(path + bounds[0]), path + 'interpolation_1.mid')
|
130 |
-
for alpha, rep in zip(np.linspace(0, 1, 5)[1:-1], reps[1: -1]):
|
131 |
-
dists, indexes = nn_model.kneighbors(rep.reshape(1, -1))
|
132 |
-
if dists.flatten()[0] == 0:
|
133 |
-
nn = indexes.flatten()[1]
|
134 |
-
else:
|
135 |
-
nn = indexes.flatten()[0]
|
136 |
-
midi_path = convert_rep_path_midi_path(rep_paths[nn])
|
137 |
-
copy(midi_path, path + f'interpolation_{alpha}.mid')
|
138 |
-
|
139 |
-
if __name__ == '__main__':
|
140 |
-
extract_all_reps(rep_path)
|
141 |
-
reps, rep_paths, playlists, n_data, dim_data = load_reps(rep_path)
|
142 |
-
set_playlists, playlist_indexes, playlist_colors = get_playlist_indexes(playlists)
|
143 |
-
# interpolate(reps, rep_paths, interpolation_path + 'trial_1/')
|
144 |
-
sample_nn(reps, rep_paths, playlists)
|
145 |
-
tsne_reps, umap_reps = plot_tsne(reps, playlist_indexes, playlist_colors)
|
146 |
-
|
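
Note on the deleted analysis code above: sample_nn is plain cosine k-NN retrieval over the 128-dimensional track representations, and interpolate linearly blends two representation vectors, then snaps each blend to its nearest neighbor in the dataset so that every interpolant maps back to a real MIDI file. A minimal, self-contained sketch of the retrieval step (the random matrix is a stand-in for the representations loaded from disk):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    reps = np.random.randn(1000, 128)  # stand-in for one 128-d vector per track
    nn_model = NearestNeighbors(n_neighbors=6, metric='cosine')
    nn_model.fit(reps)

    # kneighbors returns (distances, indexes) sorted by distance; index 0 is
    # the query itself, so it is dropped, exactly as in sample_nn above.
    distances, indexes = nn_model.kneighbors(reps[0].reshape(1, -1))
    neighbors = indexes[0][1:]
    print(neighbors)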

src/music/representation_learning/__init__.py
DELETED
File without changes

src/music/representation_learning/mlm_pretrain/__init__.py
DELETED
File without changes

src/music/representation_learning/mlm_pretrain/data_collators.py
DELETED
@@ -1,180 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-from transformers.data.data_collator import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, BatchEncoding, DataCollatorForPermutationLanguageModeling
-from dataclasses import dataclass
-
-
-def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
-    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
-    import numpy as np
-    import torch
-
-    # Tensorize if necessary.
-    if isinstance(examples[0], (list, tuple, np.ndarray)):
-        examples = [torch.tensor(e, dtype=torch.long) for e in examples]
-
-    length_of_first = examples[0].size(0)
-
-    # Check if padding is necessary.
-    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
-        return torch.stack(examples, dim=0)
-
-    # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
-        raise ValueError(
-            "You are attempting to pad samples but the tokenizer you are using"
-            f" ({tokenizer.__class__.__name__}) does not have a pad token."
-        )
-
-    # Creating the full tensor and filling it with our data.
-    max_length = max(x.size(0) for x in examples)
-    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
-        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
-    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
-    for i, example in enumerate(examples):
-        if tokenizer.padding_side == "right":
-            result[i, : example.shape[0]] = example
-        else:
-            result[i, -example.shape[0]:] = example
-    return result
-
-
-@dataclass
-class DataCollatorForMusicModeling(DataCollatorForLanguageModeling):
-    """
-    Data collator used for music language modeling. Inputs are dynamically padded to the maximum length of a batch
-    if they are not all of the same length. Masking operates on whole notes, i.e. on blocks of 5 consecutive tokens.
-    Args:
-        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
-            The tokenizer used for encoding the data.
-        mlm (`bool`, *optional*, defaults to `True`):
-            Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
-            with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
-            tokens and the value to predict for the masked token.
-        mlm_probability (`float`, *optional*, defaults to 0.15):
-            The probability with which to (randomly) mask notes in the input, when `mlm` is set to `True`.
-        pad_to_multiple_of (`int`, *optional*):
-            If set will pad the sequence to a multiple of the provided value.
-        return_tensors (`str`):
-            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
-    <Tip>
-    For best performance, this data collator should be used with a dataset having items that are dictionaries or
-    BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
-    [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
-    </Tip>"""
-
-    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-        # Handle dict or lists with proper padding and conversion to tensor.
-        if isinstance(examples[0], (dict, BatchEncoding)):
-            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
-        else:
-            batch = {
-                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
-            }
-
-        # If special token mask has been preprocessed, pop it from the dict.
-        special_tokens_mask = batch.pop("special_tokens_mask", None)
-        if self.mlm:
-            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
-                batch["input_ids"], special_tokens_mask=special_tokens_mask
-            )
-        else:
-            labels = batch["input_ids"].clone()
-            if self.tokenizer.pad_token_id is not None:
-                labels[labels == self.tokenizer.pad_token_id] = -100
-            batch["labels"] = labels
-        return batch
-
-    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
-        """
-        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
-        Masking decisions are sampled per note (block of 5 tokens), then expanded to all 5 tokens of each note.
-        """
-        import torch
-
-        labels = inputs.clone()
-        # We sample a few notes in each sequence for MLM training (with probability `self.mlm_probability`)
-        notes_shape = (labels.shape[0], labels.shape[1] // 5)
-        probability_matrix = torch.full(notes_shape, self.mlm_probability)
-        # if special_tokens_mask is None:
-        #     special_tokens_mask = [
-        #         self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
-        #     ]
-        #     special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
-        # else:
-        #     special_tokens_mask = special_tokens_mask.bool()
-
-        # probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
-        masked_notes_indices = torch.bernoulli(probability_matrix).bool()
-        masked_indices = torch.repeat_interleave(masked_notes_indices, repeats=5, dim=1)
-        labels[~masked_indices] = -100  # We only compute loss on masked tokens
-
-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_notes_replaced = torch.bernoulli(torch.full(notes_shape, 0.8)).bool() & masked_notes_indices
-        indices_replaced = torch.repeat_interleave(indices_notes_replaced, repeats=5, dim=1)
-        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
-
-        # 10% of the time, we replace masked input tokens with a random word
-        indices_notes_random = torch.bernoulli(torch.full(notes_shape, 0.5)).bool() & masked_notes_indices & ~indices_notes_replaced
-        indices_random = torch.repeat_interleave(indices_notes_random, repeats=5, dim=1)
-        random_words = torch.randint(3, len(self.tokenizer), labels.shape, dtype=torch.long)
-        inputs[indices_random] = random_words[indices_random]
-
-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
-        return inputs, labels
-
-
-@dataclass
-class DataCollatorForSpanMusicModeling(DataCollatorForLanguageModeling):
-    """
-    Data collator used for span-masked music modeling.
-    - collates batches of tensors, honoring their tokenizer's pad_token
-    - preprocesses batches by masking contiguous spans of notes, reusing the span-sampling procedure of XLNet's
-      permutation language modeling collator
-    """
-
-    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
-        """
-        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
-        0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
-        1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
-        2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
-           masked
-        3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
-           span_length]` and mask tokens `start_index:start_index + span_length`
-        4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
-           sequence to be processed), repeat from Step 1.
-        """
-        import torch
-
-        labels = inputs.clone()
-        # We sample spans of notes in each sequence for MLM training (with probability `self.mlm_probability`)
-        notes_shape = (labels.shape[0], labels.shape[1] // 5)
-        masked_notes_indices = torch.full(notes_shape, 0, dtype=torch.bool)
-
-        for i in range(labels.size(0)):
-            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
-            cur_len = 0
-            max_len = notes_shape[1]
-
-            while cur_len < max_len:
-                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of notes to be masked)
-                span_length = torch.randint(1, 5 + 1, (1,)).item()
-                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
-                context_length = int(span_length / self.mlm_probability)
-                # Sample a starting point `start_index` in `[cur_len, cur_len + context_length - span_length]` and mask notes `start_index:start_index + span_length`
-                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
-                masked_notes_indices[i, start_index: start_index + span_length] = 1
-                # Set `cur_len = cur_len + context_length`
-                cur_len += context_length
-
-        masked_indices = torch.repeat_interleave(masked_notes_indices, repeats=5, dim=1)
-        labels[~masked_indices] = -100  # We only compute loss on masked tokens
-
-        inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
-
-        return inputs, labels

src/music/representation_learning/mlm_pretrain/models/music-bert/config.json
DELETED
@@ -1,20 +0,0 @@
-{
-  "attention_probs_dropout_prob": 0.1,
-  "gradient_checkpointing": false,
-  "hidden_act": "gelu",
-  "hidden_dropout_prob": 0.1,
-  "hidden_size": 768,
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 512,
-  "model_type": "bert",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "relative_key_query",
-  "transformers_version": "4.8.2",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 30522
-}
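
This is the stock bert-base shape (12 layers, 12 heads, hidden size 768) with relative key/query position embeddings. Note that vocab_size keeps BERT's default of 30522 even though the word-level tokenizer below defines only 184 entries, so the embedding table is presumably resized at training time. A hedged sketch of how such a file is typically turned into a freshly initialized model for MLM pretraining (path as in this repo):

    from transformers import BertConfig, BertForMaskedLM

    config = BertConfig.from_json_file(
        'src/music/representation_learning/mlm_pretrain/models/music-bert/config.json')
    model = BertForMaskedLM(config)  # random weights: pretraining from scratch
    print(config.position_embedding_type)  # relative_key_query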

src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json
DELETED
@@ -1 +0,0 @@
-{"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"WordLevel","vocab":{"[PAD]":0,"[MASK]":1,"[UNK]":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":"[UNK]"}}
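
The tokenizer is a word-level vocabulary over whitespace-separated integer note tokens: [PAD], [MASK] and [UNK] take ids 0 to 2, and the strings "2" through "182" map to ids 3 through 183. A sketch of loading it with transformers (special-token names taken from the vocabulary above):

    from transformers import PreTrainedTokenizerFast

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file='src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json',
        pad_token='[PAD]', mask_token='[MASK]', unk_token='[UNK]')
    print(tokenizer('12 34 56')['input_ids'])  # [13, 35, 57]: no special tokens are added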

src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json
DELETED
@@ -1,20 +0,0 @@
-{
-  "attention_probs_dropout_prob": 0.1,
-  "gradient_checkpointing": false,
-  "hidden_act": "gelu",
-  "hidden_dropout_prob": 0.1,
-  "hidden_size": 768,
-  "initializer_range": 0.02,
-  "intermediate_size": 3072,
-  "layer_norm_eps": 1e-12,
-  "max_position_embeddings": 512,
-  "model_type": "bert",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "pad_token_id": 0,
-  "position_embedding_type": "relative_key_query",
-  "transformers_version": "4.8.2",
-  "type_vocab_size": 2,
-  "use_cache": true,
-  "vocab_size": 30522
-}
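
This configuration is byte-for-byte identical to the music-bert one above; judging by the files in this commit, the spanbert variant differs only in how tokens are masked during pretraining (the span collator from data_collators.py rather than the per-note one), not in the architecture.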

src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json
DELETED
@@ -1 +0,0 @@
-{"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"WordLevel","vocab":{"[PAD]":0,"[MASK]":1,"[UNK]":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":"[UNK]"}}
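
The span tokenizer is likewise identical to the music-bert one. A hedged sketch of wiring it to the span collator from the deleted data_collators.py above (the import path is the module as it existed before this commit; sequence lengths must be multiples of 5, the tokens-per-note block size):

    from transformers import PreTrainedTokenizerFast
    from src.music.representation_learning.mlm_pretrain.data_collators import DataCollatorForSpanMusicModeling

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file='src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json',
        pad_token='[PAD]', mask_token='[MASK]', unk_token='[UNK]')
    collator = DataCollatorForSpanMusicModeling(tokenizer=tokenizer, mlm_probability=0.15)

    # Two toy sequences of 8 notes (40 tokens) each; the collator masks whole
    # spans of notes and sets labels to -100 everywhere else.
    batch = collator([[5] * 40, [7] * 40])
    print(batch['input_ids'].shape, batch['labels'].shape)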

src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json
DELETED
@@ -1,56 +0,0 @@
-{
-  "architectures": [
-    "T5WithLMHeadModel"
-  ],
-  "d_ff": 2048,
-  "d_kv": 64,
-  "d_model": 512,
-  "decoder_start_token_id": 0,
-  "dropout_rate": 0.1,
-  "eos_token_id": 1,
-  "feed_forward_proj": "relu",
-  "gradient_checkpointing": false,
-  "initializer_factor": 1.0,
-  "is_encoder_decoder": true,
-  "layer_norm_epsilon": 1e-06,
-  "model_type": "t5",
-  "n_positions": 512,
-  "num_decoder_layers": 6,
-  "num_heads": 8,
-  "num_layers": 6,
-  "output_past": true,
-  "pad_token_id": 0,
-  "relative_attention_num_buckets": 32,
-  "task_specific_params": {
-    "summarization": {
-      "early_stopping": true,
-      "length_penalty": 2.0,
-      "max_length": 200,
-      "min_length": 30,
-      "no_repeat_ngram_size": 3,
-      "num_beams": 4,
-      "prefix": "summarize: "
-    },
-    "translation_en_to_de": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to German: "
-    },
-    "translation_en_to_fr": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to French: "
-    },
-    "translation_en_to_ro": {
-      "early_stopping": true,
-      "max_length": 300,
-      "num_beams": 4,
-      "prefix": "translate English to Romanian: "
-    }
-  },
-  "transformers_version": "4.8.2",
-  "use_cache": true,
-  "vocab_size": 32128
-}
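
This is the standard t5-small configuration (6 encoder and 6 decoder layers, d_model 512), still carrying the English summarization and translation task prefixes inherited from the upstream checkpoint, which are irrelevant for music token sequences. A sketch of instantiating it; the legacy "T5WithLMHeadModel" architecture name corresponds to T5ForConditionalGeneration in current transformers releases:

    from transformers import T5Config, T5ForConditionalGeneration

    config = T5Config.from_json_file(
        'src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json')
    model = T5ForConditionalGeneration(config)  # randomly initialized encoder-decoder
    print(config.num_layers, config.num_decoder_layers, config.d_model)  # 6 6 512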