ccolas committed
Commit 981764f
Parent: 8060600

Delete src

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. src/__init__.py +0 -0
  2. src/cocktails/__init__.py +0 -0
  3. src/cocktails/config.py +0 -21
  4. src/cocktails/pipeline/__init__.py +0 -0
  5. src/cocktails/pipeline/cocktail2affect.py +0 -372
  6. src/cocktails/pipeline/cocktailrep2recipe.py +0 -329
  7. src/cocktails/pipeline/get_affect2affective_cluster.py +0 -23
  8. src/cocktails/pipeline/get_cocktail2affective_cluster.py +0 -9
  9. src/cocktails/representation_learning/__init__.py +0 -0
  10. src/cocktails/representation_learning/dataset.py +0 -324
  11. src/cocktails/representation_learning/multihead_model.py +0 -148
  12. src/cocktails/representation_learning/run.py +0 -557
  13. src/cocktails/representation_learning/run_simple_net.py +0 -302
  14. src/cocktails/representation_learning/run_without_vae.py +0 -514
  15. src/cocktails/representation_learning/simple_model.py +0 -54
  16. src/cocktails/representation_learning/vae_model.py +0 -238
  17. src/cocktails/utilities/__init__.py +0 -0
  18. src/cocktails/utilities/analysis_utilities.py +0 -189
  19. src/cocktails/utilities/cocktail_category_detection_utilities.py +0 -221
  20. src/cocktails/utilities/cocktail_generation_utilities/__init__.py +0 -0
  21. src/cocktails/utilities/cocktail_generation_utilities/individual.py +0 -587
  22. src/cocktails/utilities/cocktail_generation_utilities/population.py +0 -213
  23. src/cocktails/utilities/cocktail_utilities.py +0 -220
  24. src/cocktails/utilities/glass_and_volume_utilities.py +0 -42
  25. src/cocktails/utilities/ingredients_utilities.py +0 -209
  26. src/cocktails/utilities/other_scrubbing_utilities.py +0 -240
  27. src/debugger.py +0 -180
  28. src/music/__init__.py +0 -0
  29. src/music/config.py +0 -72
  30. src/music/pipeline/__init__.py +0 -0
  31. src/music/pipeline/audio2midi.py +0 -52
  32. src/music/pipeline/audio2piano_solo_prob.py +0 -47
  33. src/music/pipeline/encoded2rep.py +0 -88
  34. src/music/pipeline/midi2processed.py +0 -152
  35. src/music/pipeline/music_pipeline.py +0 -86
  36. src/music/pipeline/processed2encoded.py +0 -52
  37. src/music/pipeline/processed2handcodedrep.py +0 -343
  38. src/music/pipeline/synth2audio.py +0 -170
  39. src/music/pipeline/synth2midi.py +0 -146
  40. src/music/pipeline/url2audio.py +0 -119
  41. src/music/representation_analysis/__init__.py +0 -0
  42. src/music/representation_analysis/analyze_rep.py +0 -146
  43. src/music/representation_learning/__init__.py +0 -0
  44. src/music/representation_learning/mlm_pretrain/__init__.py +0 -0
  45. src/music/representation_learning/mlm_pretrain/data_collators.py +0 -180
  46. src/music/representation_learning/mlm_pretrain/models/music-bert/config.json +0 -20
  47. src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json +0 -1
  48. src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json +0 -20
  49. src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json +0 -1
  50. src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json +0 -56
src/__init__.py DELETED
File without changes
src/cocktails/__init__.py DELETED
File without changes
src/cocktails/config.py DELETED
@@ -1,21 +0,0 @@
- import os
-
- REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/'
-
- # QUADRUPLETS_PATH = REPO_PATH + 'checkpoints/cocktail_representation/quadruplets.pickle'
- INGREDIENTS_LIST_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_list.csv'
- # ING_MATCH_SCORE_Q_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_q.txt'
- # ING_MATCH_SCORE_COUNT_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_count.txt'
- # COCKTAIL_DATA_FOLDER_PATH = REPO_PATH + 'checkpoints/cocktail_representation/'
- COCKTAILS_CSV_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.csv'
- # COCKTAILS_PKL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.pkl'
- # COCKTAILS_URL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_names_urls.pkl'
- EXPERIMENT_PATH = REPO_PATH + 'experiments/cocktails/representation_learning/'
- # ANALYSIS_PATH = REPO_PATH + 'experiments/cocktails/representation_analysis/'
- # REPRESENTATIONS_PATH = REPO_PATH + 'experiments/cocktails/learned_representations/'
-
- FULL_COCKTAIL_REP_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/cocktail_handcoded_reps_minmax_norm-1_1_dim13_customkeys.txt"
- RECIPE2FEATURES_PATH = REPO_PATH + "/checkpoints/cocktail_representation/"  # get this by running run_without_vae
- COCKTAIL_REP_CHKPT_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/"
- # FULL_COCKTAIL_REP_PATH = REPO_PATH + "experiments/cocktails/representation_analysis/affective_mapping/clustered_representations/all_cocktail_reps_norm-1_1_custom_keys_dim13.txt'
- COCKTAIL_NN_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/nn_model.pickle"
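
For reference, the REPO_PATH line resolves the repository root by dropping the last three components of config.py's absolute path (the file sits at <repo>/src/cocktails/config.py). A minimal sketch of the same logic, illustrative only and assuming a POSIX-style path:

# config.py lives at <repo>/src/cocktails/config.py, so trimming three
# components ('src', 'cocktails', 'config.py') yields the repo root.
import os
import pathlib
repo_path = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/'
assert repo_path == str(pathlib.Path(__file__).resolve().parents[2]) + '/'
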
src/cocktails/pipeline/__init__.py DELETED
File without changes
src/cocktails/pipeline/cocktail2affect.py DELETED
@@ -1,372 +0,0 @@
- import pandas as pd
- import numpy as np
- import os
- from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
- from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
- from src.cocktails.config import COCKTAILS_CSV_DATA
- from src.music.config import CHECKPOINTS_PATH, EXPERIMENT_PATH
- import matplotlib.pyplot as plt
- from sklearn.cluster import KMeans
- from sklearn.mixture import GaussianMixture
- from sklearn.neighbors import NearestNeighbors
- import pickle
- import random
-
- experiment_path = EXPERIMENT_PATH + '/cocktails/representation_analysis/affective_mapping/'
- min_max_path = CHECKPOINTS_PATH + "/cocktail_representation/minmax/"
- cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
- affective_space_dimensions = ((-1, 1), (-1, 1), (-1, 1))  # valence, arousal, dominance
- n_splits = (3, 3, 2)  # number of bins per dimension
- # dimensions_weights = [1, 1, 0.5]
- dimensions_weights = [1, 1, 1]
- total_n_clusters = np.prod(n_splits)  # total number of bins
- affective_boundaries = [np.arange(asd[0], asd[1] + 1e-6, (asd[1] - asd[0]) / n_split) for asd, n_split in zip(affective_space_dimensions, n_splits)]
- for af in affective_boundaries:
-     af[-1] += 1e-6
- all_keys = get_bunch_of_rep_keys()['custom']
- original_affective_keys = get_bunch_of_rep_keys()['affective']
- affective_keys = [a.split(' ')[1] for a in original_affective_keys]
- random.seed(0)
- cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(total_n_clusters)]
-
- clustering_method = 'k_means'  # 'k_means', 'handcoded', 'agglo', 'spectral'
- if clustering_method != 'handcoded':
-     total_n_clusters = 10
- min_arousal = np.loadtxt(min_max_path + 'min_arousal.txt')
- max_arousal = np.loadtxt(min_max_path + 'max_arousal.txt')
- min_val = np.loadtxt(min_max_path + 'min_valence.txt')
- max_val = np.loadtxt(min_max_path + 'max_valence.txt')
- min_dom = np.loadtxt(min_max_path + 'min_dominance.txt')
- max_dom = np.loadtxt(min_max_path + 'max_dominance.txt')
-
- def get_cocktail_reps(path, save=False):
-     cocktail_data = pd.read_csv(path)
-     cocktail_reps = np.array([cocktail_data[k] for k in original_affective_keys]).transpose()
-     n_data, dim_rep = cocktail_reps.shape
-     # print(f'{n_data} data points of {dim_rep} dimensions: {affective_keys}')
-     cocktail_reps = normalize_cocktail_reps_affective(cocktail_reps, save=save)
-     if save:
-         np.savetxt(experiment_path + f'cocktail_reps_for_affective_mapping_-1_1_norm_sigmoid_rescaling_{dim_rep}_keys.txt', cocktail_reps)
-     return cocktail_reps
-
- def sigmoid(x, shift, beta):
-     return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
-
- def normalize_cocktail_reps_affective(cocktail_reps, save=False):
-     if save:
-         min_cr = cocktail_reps.min(axis=0)
-         max_cr = cocktail_reps.max(axis=0)
-         np.savetxt(min_max_path + 'min_cocktail_reps_affective.txt', min_cr)
-         np.savetxt(min_max_path + 'max_cocktail_reps_affective.txt', max_cr)
-     else:
-         min_cr = np.loadtxt(min_max_path + 'min_cocktail_reps_affective.txt')
-         max_cr = np.loadtxt(min_max_path + 'max_cocktail_reps_affective.txt')
-     cocktail_reps = ((cocktail_reps - min_cr) / (max_cr - min_cr) - 0.5) * 2
-     cocktail_reps[:, 0] = sigmoid(cocktail_reps[:, 0], shift=0.05, beta=4)
-     cocktail_reps[:, 1] = sigmoid(cocktail_reps[:, 1], shift=0.3, beta=5)
-     cocktail_reps[:, 2] = sigmoid(cocktail_reps[:, 2], shift=0.15, beta=3)
-     cocktail_reps[:, 3] = sigmoid(cocktail_reps[:, 3], shift=0.9, beta=20)
-     cocktail_reps[:, 4] = sigmoid(cocktail_reps[:, 4], shift=0, beta=4)
-     cocktail_reps[:, 5] = sigmoid(cocktail_reps[:, 5], shift=0.2, beta=3)
-     cocktail_reps[:, 6] = sigmoid(cocktail_reps[:, 6], shift=0.5, beta=5)
-     cocktail_reps[:, 7] = sigmoid(cocktail_reps[:, 7], shift=0.2, beta=6)
-     return cocktail_reps
-
- def plot(cocktail_reps):
-     dim_rep = cocktail_reps.shape[1]
-     for i in range(dim_rep):
-         for j in range(i + 1, dim_rep):
-             plt.figure()
-             plt.scatter(cocktail_reps[:, i], cocktail_reps[:, j], s=150, alpha=0.5)
-             plt.xlabel(affective_keys[i])
-             plt.ylabel(affective_keys[j])
-             plt.savefig(experiment_path + f'scatters/{affective_keys[i]}_vs_{affective_keys[j]}.png', dpi=300)
-             plt.close('all')
-         plt.figure()
-         plt.hist(cocktail_reps[:, i])
-         plt.xlabel(affective_keys[i])
-         plt.savefig(experiment_path + f'hists/{affective_keys[i]}.png', dpi=300)
-         plt.close('all')
-
- def get_clusters(affective_coordinates, save=False):
-     if clustering_method in ['k_means', 'gmm']:
-         if clustering_method == 'k_means': model = KMeans(n_clusters=total_n_clusters)
-         elif clustering_method == 'gmm': model = GaussianMixture(n_components=total_n_clusters, covariance_type="full")
-         model.fit(affective_coordinates * np.array(dimensions_weights))
-
-         def find_cluster(aff_coord):
-             if aff_coord.ndim == 1:
-                 aff_coord = aff_coord.reshape(1, -1)
-             return model.predict(aff_coord * np.array(dimensions_weights))
-         cluster_centers = model.cluster_centers_ if clustering_method == 'k_means' else []
-         if save:
-             to_save = dict(cluster_model=model,
-                            cluster_centers=cluster_centers,
-                            nb_clusters=len(cluster_centers),
-                            dimensions_weights=dimensions_weights)
-             with open(cluster_model_path, 'wb') as f:
-                 pickle.dump(to_save, f)
-             stop = 1
-
-     elif clustering_method == 'handcoded':
-         def find_cluster(aff_coord):
-             if aff_coord.ndim == 1:
-                 aff_coord = aff_coord.reshape(1, -1)
-             cluster_coordinates = []
-             for i in range(aff_coord.shape[0]):
-                 cluster_coordinates.append([np.argwhere(affective_boundaries[j] <= aff_coord[i, j]).flatten()[-1] for j in range(3)])
-             cluster_coordinates = np.array(cluster_coordinates)
-             cluster_ids = cluster_coordinates[:, 0] * np.prod(n_splits[1:]) + cluster_coordinates[:, 1] * n_splits[-1] + cluster_coordinates[:, 2]
-             return cluster_ids
-         # find cluster centers
-         cluster_centers = []
-         for i in range(n_splits[0]):
-             asd = affective_space_dimensions[0]
-             x_coordinate = np.arange(asd[0] + 1 / n_splits[0], asd[1], (asd[1] - asd[0]) / n_splits[0])[i]
-             for j in range(n_splits[1]):
-                 asd = affective_space_dimensions[1]
-                 y_coordinate = np.arange(asd[0] + 1 / n_splits[1], asd[1], (asd[1] - asd[0]) / n_splits[1])[j]
-                 for k in range(n_splits[2]):
-                     asd = affective_space_dimensions[2]
-                     z_coordinate = np.arange(asd[0] + 1 / n_splits[2], asd[1], (asd[1] - asd[0]) / n_splits[2])[k]
-                     cluster_centers.append([x_coordinate, y_coordinate, z_coordinate])
-         cluster_centers = np.array(cluster_centers)
-     else:
-         raise NotImplementedError
-     cluster_ids = find_cluster(affective_coordinates)
-     return cluster_ids, cluster_centers, find_cluster
-
-
- def cocktail2affect(cocktail_reps, save=False):
-     if cocktail_reps.ndim == 1:
-         cocktail_reps = cocktail_reps.reshape(1, -1)
-
-     assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful']
-     all_weights = []
-
-     # valence
-     # + sweet - bitter - booze + colorful
-     weights = np.array([-1, 1, 0, 0, 0, -1, 0, 1])
-     valence = (cocktail_reps * weights).sum(axis=1)
-     if save:
-         min_ = valence.min()
-         max_ = valence.max()
-         np.savetxt(min_max_path + 'min_valence.txt', np.array([min_]))
-         np.savetxt(min_max_path + 'max_valence.txt', np.array([max_]))
-     else:
-         min_ = min_val
-         max_ = max_val
-     valence = 2 * ((valence - min_) / (max_ - min_) - 0.5)
-     valence = sigmoid(valence, shift=0.1, beta=3.5)
-     valence = valence.reshape(-1, 1)
-     all_weights.append(weights.copy())
-
-     # arousal
-     # + fizzy + sour + complex - sweet + spicy + bitter
-     # weights = np.array([0, -1, 1, 1, 1, 1, 1, 0])
-     weights = np.array([0.7, 0, 1.5, 1.5, 0.6, 0, 0.6, 0])
-     arousal = (cocktail_reps * weights).sum(axis=1)
-     if save:
-         min_ = arousal.min()
-         max_ = arousal.max()
-         np.savetxt(min_max_path + 'min_arousal.txt', np.array([min_]))
-         np.savetxt(min_max_path + 'max_arousal.txt', np.array([max_]))
-     else:
-         min_, max_ = min_arousal, max_arousal
-     arousal = 2 * ((arousal - min_) / (max_ - min_) - 0.5)  # normalize to -1, 1
-     arousal = sigmoid(arousal, shift=0.3, beta=4)
-     arousal = arousal.reshape(-1, 1)
-     all_weights.append(weights.copy())
-
-     # dominance
-     # assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful']
-     # + booze + fizzy - complex - bitter - sweet
-     weights = np.array([1.5, -0.8, 0, 0.7, -1, -1.5, 0, 0])
-     dominance = (cocktail_reps * weights).sum(axis=1)
-     if save:
-         min_ = dominance.min()
-         max_ = dominance.max()
-         np.savetxt(min_max_path + 'min_dominance.txt', np.array([min_]))
-         np.savetxt(min_max_path + 'max_dominance.txt', np.array([max_]))
-     else:
-         min_, max_ = min_dom, max_dom
-     dominance = 2 * ((dominance - min_) / (max_ - min_) - 0.5)
-     dominance = sigmoid(dominance, shift=-0.05, beta=5)
-     dominance = dominance.reshape(-1, 1)
-     all_weights.append(weights.copy())
-
-     affective_coordinates = np.concatenate([valence, arousal, dominance], axis=1)
-     # if save:
-     #     assert (affective_coordinates.min(axis=0) == np.array([ac[0] for ac in affective_space_dimensions])).all()
-     #     assert (affective_coordinates.max(axis=0) == np.array([ac[1] for ac in affective_space_dimensions])).all()
-     return affective_coordinates, all_weights
-
- def save_reps(path, affective_cluster_ids):
-     cocktail_data = pd.read_csv(path)
-     rep_keys = get_bunch_of_rep_keys()['custom']
-     cocktail_reps = np.array([cocktail_data[k] for k in rep_keys]).transpose()
-     np.savetxt(experiment_path + 'clustered_representations/' + f'min_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.min(axis=0))
-     np.savetxt(experiment_path + 'clustered_representations/' + f'max_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.max(axis=0))
-     cocktail_reps = ((cocktail_reps - cocktail_reps.min(axis=0)) / (cocktail_reps.max(axis=0) - cocktail_reps.min(axis=0)) - 0.5) * 2  # normalize in -1, 1
-     np.savetxt(experiment_path + 'clustered_representations/' + f'all_cocktail_reps_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps)
-     np.savetxt(experiment_path + 'clustered_representations/' + 'affective_cluster_ids.txt', affective_cluster_ids)
-     for cluster_id in sorted(set(affective_cluster_ids)):
-         indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
-         reps = cocktail_reps[indexes, :]
-         np.savetxt(experiment_path + 'clustered_representations/' + f'rep_cluster{cluster_id}_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', reps)
-
- def study_affects(affective_coordinates, affective_cluster_ids):
-     plt.figure()
-     plt.hist(affective_cluster_ids, bins=total_n_clusters)
-     plt.xlabel('Affective cluster ids')
-     plt.xticks(np.arange(total_n_clusters))
-     plt.savefig(experiment_path + 'affective_cluster_distrib.png')
-     fig = plt.gcf()
-     plt.close(fig)
-
-     fig = plt.figure()
-     ax = fig.add_subplot(projection='3d')
-     ax.set_xlim([-1, 1])
-     ax.set_ylim([-1, 1])
-     ax.set_zlim([-1, 1])
-     for cluster_id in sorted(set(affective_cluster_ids)):
-         indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
-         ax.scatter(affective_coordinates[indexes, 0], affective_coordinates[indexes, 1], affective_coordinates[indexes, 2], c=cluster_colors[cluster_id], s=150)
-     ax.set_xlabel('Valence')
-     ax.set_ylabel('Arousal')
-     ax.set_zlabel('Dominance')
-     stop = 1
-     plt.savefig(experiment_path + 'scatters_affect/affective_mapping.png')
-     fig = plt.gcf()
-     plt.close(fig)
-
-     affects = ['Valence', 'Arousal', 'Dominance']
-     for i in range(3):
-         for j in range(i + 1, 3):
-             fig = plt.figure()
-             ax = fig.add_subplot()
-             for cluster_id in sorted(set(affective_cluster_ids)):
-                 indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
-                 ax.scatter(affective_coordinates[indexes, i], affective_coordinates[indexes, j], alpha=0.5, c=cluster_colors[cluster_id], s=150)
-             ax.set_xlabel(affects[i])
-             ax.set_ylabel(affects[j])
-             plt.savefig(experiment_path + f'scatters_affect/scatter_{affects[i]}_vs_{affects[j]}.png')
-             fig = plt.gcf()
-             plt.close(fig)
-         plt.figure()
-         plt.hist(affective_coordinates[:, i])
-         plt.xlabel(affects[i])
-         plt.savefig(experiment_path + f'hists_affect/hist_{affects[i]}.png')
-         fig = plt.gcf()
-         plt.close(fig)
-     plt.close('all')
-     stop = 1
-
- def sample_clusters(path, cocktail_reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates, n_samples=4):
-     cocktail_data = pd.read_csv(path)
-     these_cocktail_reps = normalize_cocktail_reps_affective(np.array([cocktail_data[k] for k in original_affective_keys]).transpose())
-
-     names = cocktail_data['names']
-     urls = cocktail_data['urls']
-     ingr_str = cocktail_data['ingredients_str']
-     for cluster_id in sorted(set(affective_cluster_ids)):
-         indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten()
-         print('\n\n\n---------\n----------\n-----------\n')
-         cluster_str = ''
-         cluster_str += f'Affective cluster #{cluster_id}' + \
-                        f'\n\tSize: {len(indexes)}' + \
-                        f'\n\tCenter: ' + \
-                        f'\n\t\tVal: {affective_cluster_centers[cluster_id][0]:.2f}, ' + \
-                        f'\n\t\tArousal: {affective_cluster_centers[cluster_id][1]:.2f}, ' + \
-                        f'\n\t\tDominance: {affective_cluster_centers[cluster_id][2]:.2f}'
-         print(cluster_str)
-         if affective_cluster_centers[cluster_id][2] == np.max(affective_cluster_centers[:, 2]):
-             stop = 1
-         sampled_idx = np.random.choice(indexes, size=min(len(indexes), n_samples), replace=False)
-         cocktail_str = ''
-         for i in sampled_idx:
-             assert np.sum(cocktail_reps[i] - these_cocktail_reps[i]) < 1e-9
-             cocktail_str += f'\n\n-------------'
-             cocktail_str += print_recipe(ingr_str[i], name=names[i], to_print=False)
-             cocktail_str += f'\nUrl: {urls[i]}'
-             cocktail_str += '\n\nRepresentation: ' + ', '.join([f'{af}: {cr:.2f}' for af, cr in zip(affective_keys, cocktail_reps[i])]) + '\n'
-             cocktail_str += '\n' + generate_explanation(cocktail_reps[i], all_weights, affective_coordinates[i])
-         print(cocktail_str)
-         stop = 1
-         cluster_str += '\n' + cocktail_str
-         with open(f"/home/cedric/Documents/pianocktail/experiments/cocktails/representation_analysis/affective_mapping/clusters/cluster_{cluster_id}", 'w') as f:
-             f.write(cluster_str)
-         stop = 1
-
- def explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord):
-     names = ['valence', 'arousal', 'dominance']
-     weights = all_weights[i]
-     explanation_str = f'\n{names[i].capitalize()} explanation ({aff_coord[i]:.2f}):'
-     strengths = np.abs(weights * cocktail_rep)
-     strengths /= strengths.sum()
-     indexes = np.flip(np.argsort(strengths))
-     for ind in indexes:
-         if strengths[ind] != 0:
-             if np.sign(weights[ind]) == np.sign(cocktail_rep[ind]):
-                 keyword = 'high' if cocktail_rep[ind] > 0 else 'low'
-                 explanation_str += f'\n\t{int(strengths[ind]*100)}%: higher {names[i]} because {keyword} {affective_keys[ind]}'
-             else:
-                 keyword = 'high' if cocktail_rep[ind] > 0 else 'low'
-                 explanation_str += f'\n\t{int(strengths[ind]*100)}%: low {names[i]} because {keyword} {affective_keys[ind]}'
-     return explanation_str
-
- def generate_explanation(cocktail_rep, all_weights, aff_coord):
-     explanation_str = ''
-     for i in range(3):
-         explanation_str += explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord)
-     return explanation_str
-
- def cocktails2affect_clusters(cocktail_rep):
-     if cocktail_rep.ndim == 1:
-         cocktail_rep = cocktail_rep.reshape(1, -1)
-     affective_coordinates, _ = cocktail2affect(cocktail_rep)
-     affective_cluster_ids, _, _ = get_clusters(affective_coordinates)
-     return affective_cluster_ids
-
-
- def setup_affective_space(path, save=False):
-     cocktail_data = pd.read_csv(path)
-     names = cocktail_data['names']
-     recipes = cocktail_data['ingredients_str']
-     urls = cocktail_data['urls']
-     reps = get_cocktail_reps(path)
-     affective_coordinates, all_weights = cocktail2affect(reps)
-     affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates, save=save)
-     nn_model = NearestNeighbors(n_neighbors=1)
-     nn_model.fit(affective_coordinates)
-     def cocktail2affect_cluster(cocktail_rep):
-         affective_coordinates, _ = cocktail2affect(cocktail_rep)
-         return find_cluster(affective_coordinates)
-
-     affective_clusters = dict(affective_coordinates=affective_coordinates,  # coordinates of cocktails in affective space
-                               affective_cluster_ids=affective_cluster_ids,  # cluster id of cocktails
-                               affective_cluster_centers=affective_cluster_centers,  # cluster centers in affective space
-                               affective_weights=all_weights,  # weights to compute valence, arousal, dominance from cocktail representations
-                               original_affective_keys=original_affective_keys,
-                               cocktail_reps=reps,  # cocktail representations from the dataset (normalized)
-                               find_cluster=find_cluster,  # function to retrieve a cluster from affective coordinates
-                               nn_model=nn_model,  # to predict the nearest neighbor in affective space
-                               names=names,  # names of cocktails in the dataset
-                               urls=urls,  # urls from the dataset
-                               recipes=recipes,  # recipes of the dataset
-                               cocktail2affect=cocktail2affect,  # function to compute affects from cocktail representations
-                               cocktails2affect_clusters=cocktails2affect_clusters,
-                               cocktail2affect_cluster=cocktail2affect_cluster)
-
-     return affective_clusters
-
- if __name__ == '__main__':
-     reps = get_cocktail_reps(COCKTAILS_CSV_DATA, save=True)
-     # plot(reps)
-     affective_coordinates, all_weights = cocktail2affect(reps, save=True)
-     affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates)
-     save_reps(COCKTAILS_CSV_DATA, affective_cluster_ids)
-     study_affects(affective_coordinates, affective_cluster_ids)
-     sample_clusters(COCKTAILS_CSV_DATA, reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates)
-     setup_affective_space(COCKTAILS_CSV_DATA, save=True)
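
Each affective dimension above is a fixed linear readout over the eight flavor keys, followed by dataset min-max normalization and the rescaled sigmoid squash. A worked sketch for valence with hypothetical numbers (illustrative only, not dataset values):

import numpy as np

def sigmoid(x, shift, beta):
    # maps to (-1, 1); `shift` recenters, `beta` sharpens the transition
    return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2

# key order: booze, sweet, sour, fizzy, complex, bitter, spicy, colorful
rep = np.array([0.2, 0.5, -0.1, 0.0, 0.3, -0.4, 0.0, 0.6])  # hypothetical normalized rep
w_valence = np.array([-1, 1, 0, 0, 0, -1, 0, 1])  # +sweet -bitter -booze +colorful
raw = float((rep * w_valence).sum())  # -0.2 + 0.5 + 0.4 + 0.6 = 1.3
# the script then min-max normalizes raw scores to [-1, 1] over the dataset
# before the final squash:
valence = sigmoid(raw, shift=0.1, beta=3.5)
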
src/cocktails/pipeline/cocktailrep2recipe.py DELETED
@@ -1,329 +0,0 @@
- import matplotlib.pyplot as plt
- import pickle
- from src.cocktails.utilities.cocktail_generation_utilities.population import *
- from src.cocktails.utilities.glass_and_volume_utilities import glass_volume
- from src.cocktails.config import RECIPE2FEATURES_PATH
-
- def test_mutation_params(cocktail_reps):
-     indexes = np.arange(cocktail_reps.shape[0])
-     np.random.shuffle(indexes)
-     perfs = []
-     mutated_perfs = []
-     pop_params = dict(mutation_params=dict(p_add_ing=0.7,
-                                            p_remove_ing=0.7,
-                                            p_switch_ing=0.5,
-                                            p_change_q=0.7,
-                                            delta_change_q=0.3,
-                                            asexual_rep=True,
-                                            crossover=True,
-                                            ingredient_addition=(0.1, 0.05)),
-                       nb_generations=100,
-                       pop_size=100,
-                       nb_elites=10,
-                       dist='mse',
-                       n_neighbors=5)
-
-     for i in indexes[:20]:
-         target = cocktail_reps[i]
-         for j in range(100):
-             parent = IndividualCocktail(pop_params=pop_params,
-                                         target_affective_cluster=None,
-                                         target=target.copy())
-             perfs.append(parent.perf)
-             child = parent.get_child()[0]
-             # child.compute_cocktail_rep()
-             # child.compute_perf()
-             if perfs[-1] != child.perf:
-                 mutated_perfs.append(child.perf)
-             else:
-                 perfs.pop(-1)
-     filtered_children = np.argwhere(np.array(mutated_perfs) == -100).flatten()
-     non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs) != -100, np.array(mutated_perfs) != -100)).flatten()
-     print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%')
-     plt.figure()
-     plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5)
-     plt.xlabel('parent perf')
-     plt.ylabel('child perf')
-     print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1])
-     plt.show()
-     stop = 1
-
- def test_crossover(cocktail_reps):
-     indexes = np.arange(cocktail_reps.shape[0])
-     np.random.shuffle(indexes)
-     perfs = []
-     mutated_perfs = []
-     pop_params = dict(mutation_params=dict(p_add_ing=0.7,
-                                            p_remove_ing=0.7,
-                                            p_switch_ing=0.5,
-                                            p_change_q=0.7,
-                                            delta_change_q=0.3,
-                                            asexual_rep=True,
-                                            crossover=True,
-                                            ingredient_addition=(0.1, 0.05)),
-                       nb_generations=100,
-                       pop_size=100,
-                       nb_elites=10,
-                       dist='mse',
-                       n_neighbors=5)
-     for i in indexes[:20]:
-         for j in range(100):
-             target = cocktail_reps[i]
-             parent1 = IndividualCocktail(pop_params=pop_params,
-                                          target_affective_cluster=None,
-                                          target=target.copy())
-             parent2 = IndividualCocktail(pop_params=pop_params,
-                                          target_affective_cluster=None,
-                                          target=target.copy())
-             child = parent1.get_child_with(parent2)[0]
-             # child.compute_cocktail_rep()
-             # child.compute_perf()
-             perfs.append((parent1.perf + parent2.perf) / 2)
-             if perfs[-1] != child.perf:
-                 mutated_perfs.append(child.perf)
-             else:
-                 perfs.pop(-1)
-     filtered_children = np.argwhere(np.array(mutated_perfs) == -100).flatten()
-     non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs) > -45, np.array(mutated_perfs) != -100)).flatten()
-     print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%')
-     plt.figure()
-     plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5)
-     plt.xlabel('parent perf')
-     plt.ylabel('child perf')
-     print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1])
-     plt.show()
-     stop = 1
-
- def run_comparisons():
-     np.random.seed(0)
-     indexes = np.arange(cocktail_reps.shape[0])
-     np.random.shuffle(indexes)
-     for n_neighbors in [0, 5]:
-         id_str_neigh = '5neigh_' if n_neighbors == 5 else '0_neigh_'
-         for asexual_rep in [True, False]:
-             id_str_as = id_str_neigh + 'asexual_' if asexual_rep else id_str_neigh
-             for crossover in [True, False]:
-                 id_str = id_str_as + 'crossover_' if crossover else id_str_as
-                 if crossover or asexual_rep:
-                     mutation_params = dict(p_add_ing=0.5,
-                                            p_remove_ing=0.5,
-                                            p_change_q=0.5,
-                                            delta_change_q=0.3,
-                                            asexual_rep=asexual_rep,
-                                            crossover=crossover,
-                                            ingredient_addition=(0.1, 0.05))
-                     nb_generations = 100
-                     pop_size = 100
-                     nb_elites = 10
-                     dist = 'mse'
-                     results = dict()
-                     print(id_str)
-                     for i, ind in enumerate(indexes[:30]):
-                         print(i + 1)
-                         target_ing_str = data['ingredients_str'][ind]
-                         target = cocktail_reps[ind]
-                         population = Population(nb_generations=nb_generations, pop_size=pop_size, nb_elite=nb_elites,
-                                                 target=target, dist=dist, mutation_params=mutation_params,
-                                                 n_neighbors=n_neighbors, target_ing_str=target_ing_str, true_prep_type=data['category'][ind])
-                         population.run_evolution(verbose=False)
-                         best_scores, best_ind = population.get_best_score()
-                         recipes = [ind.get_recipe()[3] for ind in best_ind[:5]]
-                         results[str(ind)] = dict(best_scores=best_scores[:5], recipes=recipes, target=population.target_individual.get_recipe()[3])
-                     with open(f'/home/cedric/Desktop/ga_tests_{id_str}.pickle', 'wb') as f:
-                         pickle.dump(results, f)
-
- def get_cocktail_distribution(cocktail_reps):
-     return (np.mean(cocktail_reps, axis=0), np.cov(cocktail_reps, rowvar=0))
-
- def sample_cocktails(cocktail_reps, n=10, target_affective_cluster=None, to_print=True):
-     distrib = get_cocktail_distribution(cocktail_reps)
-     sampled_cocktail_reps = np.random.multivariate_normal(distrib[0], distrib[1], size=n)
-     recipes = []
-     closest_recipes = []
-     for i_c, cr in enumerate(sampled_cocktail_reps):
-         population = setup_recipe_generation(cr.copy(), target_affective_cluster=target_affective_cluster)
-         closest_recipes.append(population.nn_recipes[0])
-         best_scores, best_individuals = population.run_evolution()
-         recipes.append(best_individuals[0].get_recipe()[3])
-         if to_print:
-             print(f'Sample #{len(recipes)}:')
-             print(recipes[-1])
-             print('Closest from dataset:')
-             print(closest_recipes[-1])
-             stop = 1
-     return recipes, closest_recipes
-
- def setup_recipe_generation(target, known_target_dict=None, target_affective_cluster=None):
-     # pop_params = dict(mutation_params=dict(p_add_ing=0.7,
-     #                                        p_remove_ing=0.7,
-     #                                        p_switch_ing=0.5,
-     #                                        p_change_q=0.7,
-     #                                        delta_change_q=0.3,
-     #                                        asexual_rep=True,
-     #                                        crossover=True,
-     #                                        ingredient_addition=(0.1, 0.05)),
-     #                   nb_generations=2,  # 100
-     #                   pop_size=5,  # 100
-     #                   nb_elites=2,  # 10
-     #                   dist='mse',
-     #                   n_neighbors=3)  # 5
-     pop_params = dict(mutation_params=dict(p_add_ing=0.4,
-                                            p_remove_ing=1,
-                                            p_switch_ing=0.5,
-                                            p_change_q=1,
-                                            delta_change_q=0.3,
-                                            asexual_rep=True,
-                                            crossover=True,
-                                            ingredient_addition=(0.1, 0.05)),
-                       nb_generations=100,  # 100
-                       pop_size=100,  # 100
-                       nb_elites=10,  # 10
-                       dist='mse',
-                       n_neighbors=5)  # 5
-
-     population = Population(target=target, target_affective_cluster=target_affective_cluster, known_target_dict=known_target_dict, pop_params=pop_params)
-     return population
-
- def cocktailrep2recipe(cocktail_rep, unit='mL', target_affective_cluster=None, known_target_dict=None, n_output=1, return_ind=False, verbose=True, full_verbose=False, level=0):
-     init_time = time.time()
-     if verbose: print(' ' * level + 'Generating cocktail..')
-     if cocktail_rep.ndim > 1:
-         assert cocktail_rep.shape[0] == 1
-         cocktail_rep = cocktail_rep.flatten()
-         # target_affective_cluster = target_affective_cluster[0]
-     population = setup_recipe_generation(cocktail_rep.copy(), known_target_dict=known_target_dict, target_affective_cluster=target_affective_cluster)
-     if full_verbose:
-         print(' ' * (level + 2) + '3 nearest neighbors:')
-         for i, recipe, score in zip(range(3), population.nn_recipes[:3], population.nn_scores[:3]):
-             print(' ' * (level + 4) + f'#{i+1}, score: {score:.2f}')
-             print(' ' * (level + 4) + recipe[1:].replace('None ()', '').replace('\t\t', ' ' * (level + 6)))
-     best_scores, best_individuals = population.run_evolution(verbose=full_verbose, level=level+2)
-     for i in range(n_output):
-         best_individuals[i].make_recipe_fit_the_glass()
-     instructions = [ind.get_instructions() for ind in best_individuals[:n_output]]
-     recipes = [ind.get_recipe(unit=unit)[3] for ind in best_individuals[:n_output]]
-     glasses = [ind.glass for ind in best_individuals[:n_output]]
-     prep_types = [ind.prep_type for ind in best_individuals[:n_output]]
-     for i, g, p, inst in zip(range(len(recipes)), glasses, prep_types, instructions):
-         recipes[i] = recipes[i].replace('Recipe', 'Ingredients') + f'Serve in:\n {g.capitalize()} glass.\n' + inst
-     if full_verbose:
-         print(f'\n--------------\n{n_output} best results:')
-         for i, recipe, score in zip(range(n_output), recipes, best_scores[:n_output]):
-             print(f'#{i+1}, score: {score:.2f}')
-             print(recipe)
-     if verbose: print(' ' * (level + 2) + f'Generated in {int(time.time() - init_time)} seconds.')
-     if return_ind:
-         return recipes, best_scores[:n_output], best_individuals[:n_output]
-     else:
-         return recipes, best_scores[:n_output]
-
-
- def interpolate(cocktail_rep1, cocktail_rep2, alpha, verbose=False):
-     recipe, score = cocktailrep2recipe(alpha * cocktail_rep1 + (1 - alpha) * cocktail_rep2, verbose=verbose)
-     return recipe[0], score
-
- def interpolation_study(n_steps, cocktail_reps):
-     alphas = np.arange(0, 1 + 1e-6, 1 / (n_steps + 1))
-     indexes = np.random.choice(np.arange(cocktail_reps.shape[0]), size=2, replace=False)
-     target_ing_str1, target_ing_str2 = data['ingredients_str'][indexes[0]], data['ingredients_str'][indexes[1]]
-     cocktail_rep1, cocktail_rep2 = cocktail_reps[indexes[0]], cocktail_reps[indexes[1]]
-     recipes, scores = [], []
-     for alpha in alphas:
-         recipe, score = interpolate(cocktail_rep1, cocktail_rep2, alpha)
-         recipes.append(recipe)
-         scores.append(score[0])
-     print('Point A:')
-     print_recipe(ingredient_str=target_ing_str2)
-     for i, alpha in enumerate(alphas):
-         print(f'Alpha = {alpha}, score = {scores[i]}')
-         print(recipes[i])
-     print('Point B:')
-     print_recipe(ingredient_str=target_ing_str1)
-     stop = 1
-
- def test_robustness_affective_cluster(cocktail_reps):
-     indexes = np.arange(cocktail_reps.shape[0])
-     np.random.shuffle(indexes)
-     matches = []
-     for i in indexes:
-         target_ing_str = data['ingredients_str'][i]
-         true_prep_type = data['category'][i]
-         target = cocktail_reps[i]
-         # get affective cluster
-         recipes, best_scores, best_inds = cocktailrep2recipe(cocktail_rep=target, target_ing_str=target_ing_str, true_prep_type=true_prep_type, n_output=1, verbose=False,
-                                                              return_ind=True)
-
-         matches.append(best_inds[0].does_affective_cluster_match())
-         print(np.mean(matches))
-
- def test(cocktail_reps):
-     indexes = np.arange(these_cocktail_reps.shape[0])
-     unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose()
-
-     for i in indexes:
-         target_ing_str = data['ingredients_str'][i]
-         true_prep_type = data['category'][i]
-         target = these_cocktail_reps[i]
-         # print('preptype:', true_prep_type)
-         # print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i])
-         # print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i]))
-         # print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i])
-         # print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i])
-
-         population = setup_recipe_generation(target.copy(), target_ing_str=target_ing_str, target_affective_cluster=None, true_prep_type=true_prep_type)
-         target = population.target_individual
-         target.compute_perf()
-         if target.perf < -50:
-             print(i)
-             print_recipe(target_ing_str)
-             if not target.is_alcohol_present(): print('No alcohol')
-             if not target.is_total_volume_enough(): print('small volume')
-             if not target.does_fit_glass():
-                 print(target.end_volume)
-                 print(glass_volume[target.get_glass_type()] * 0.81)
-                 print('too much volume')
-             if not target.is_alcohol_reasonable():
-                 print(f'amount of alcohol too small or too large: {target.alcohol_precentage}')
-             stop = 1
-
-
- if __name__ == '__main__':
-     these_cocktail_reps = COCKTAIL_REPS.copy()
-     # test_crossover(these_cocktail_reps)
-     # test_mutation_params(these_cocktail_reps)
-     # test(these_cocktail_reps)
-     # recipes, closest_recipes = sample_cocktails(these_cocktail_reps, n=10)
-     # interpolation_study(n_steps=4, cocktail_reps=these_cocktail_reps)
-     # test_robustness_affective_cluster(these_cocktail_reps)
-     indexes = np.arange(these_cocktail_reps.shape[0])
-     np.random.shuffle(indexes)
-     # test_crossover(mutation_params, dist)
-     # test_mutation_params(mutation_params, dist)
-     stop = 1
-     unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose()
-     for i in indexes:
-         print(i)
-         target_ing_str = data['ingredients_str'][i]
-         target_prep_type = data['category'][i]
-         target_glass = data['glass'][i]
-
-         print('preptype:', target_prep_type)
-         print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i])
-         print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i]))
-         print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i])
-         print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i])
-         print(i)
-
-         print('___________Target')
-         nn_model = NearestNeighbors()
-         nn_model.fit(these_cocktail_reps)
-         dists, indexes = nn_model.kneighbors(these_cocktail_reps[i].reshape(1, -1))
-         print(indexes)
-         print_recipe(target_ing_str)
-         target = these_cocktail_reps[i]
-         known_target_dict = dict(prep_type=target_prep_type,
-                                  ing_str=target_ing_str,
-                                  glass=target_glass)
-         recipes, best_scores = cocktailrep2recipe(cocktail_rep=target, known_target_dict=known_target_dict, n_output=1, verbose=True, full_verbose=True)
-
-         stop = 1
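
The public entry point here is cocktailrep2recipe: it seeds a genetic Population around a target representation and returns the best recipes with their scores. A minimal, hypothetical driver (COCKTAIL_REPS is assumed to come from the starred population import):

# Hypothetical usage sketch, not part of the commit.
target = COCKTAIL_REPS[0].copy()  # regenerate the first dataset cocktail
recipes, scores = cocktailrep2recipe(target, unit='mL', n_output=3, verbose=False)
for rank, (recipe, score) in enumerate(zip(recipes, scores), start=1):
    print(f'#{rank} (score {score:.2f})')
    print(recipe)
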
src/cocktails/pipeline/get_affect2affective_cluster.py DELETED
@@ -1,23 +0,0 @@
- from src.music.config import CHECKPOINTS_PATH
- import pickle
- import numpy as np
-
- # can be computed from cocktail2affect
- cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"
-
- def get_affect2affective_cluster():
-     with open(cluster_model_path, 'rb') as f:
-         data = pickle.load(f)
-     model = data['cluster_model']
-     dimensions_weights = data['dimensions_weights']
-     def find_cluster(aff_coord):
-         if aff_coord.ndim == 1:
-             aff_coord = aff_coord.reshape(1, -1)
-         return model.predict(aff_coord * np.array(dimensions_weights))
-     return find_cluster
-
- def get_affective_cluster_centers():
-     with open(cluster_model_path, 'rb') as f:
-         data = pickle.load(f)
-     return data['cluster_centers']
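
Both helpers read the same pickled checkpoint, and get_affect2affective_cluster returns a closure so callers never touch the pickle directly. An illustrative call, assuming the checkpoint exists at cluster_model_path:

import numpy as np
find_cluster = get_affect2affective_cluster()
coords = np.array([[0.2, -0.5, 0.1],   # hypothetical (valence, arousal, dominance)
                   [-0.3, 0.4, 0.8]])  # points in [-1, 1]^3
cluster_ids = find_cluster(coords)     # one cluster id per row
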
src/cocktails/pipeline/get_cocktail2affective_cluster.py DELETED
@@ -1,9 +0,0 @@
- from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster
- from src.cocktails.pipeline.cocktail2affect import cocktail2affect
-
- def get_cocktail2affective_cluster():
-     find_cluster = get_affect2affective_cluster()
-     def cocktail2affect_cluster(cocktail_rep):
-         affective_coordinates, _ = cocktail2affect(cocktail_rep)
-         return find_cluster(affective_coordinates)
-     return cocktail2affect_cluster
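
This module simply composes the two stages (cocktail representation → affective coordinates → cluster id) and again returns a closure. Illustrative use, with a hypothetical normalized representation:

cocktail2affect_cluster = get_cocktail2affective_cluster()
cluster_id = cocktail2affect_cluster(some_cocktail_rep)  # some_cocktail_rep: hypothetical 8-dim array
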
src/cocktails/representation_learning/__init__.py DELETED
File without changes
src/cocktails/representation_learning/dataset.py DELETED
@@ -1,324 +0,0 @@
- from torch.utils.data import Dataset
- import pickle
- from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles, ingredients_per_type
- from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
- import numpy as np
-
- def get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index, params):
-     assert len(ingredients) == len(quantities)
-     ing, q = ingredients[index], quantities[index]
-     proportion = q / np.sum(quantities)
-     index_ing = ingredient_list.index(ing)
-     # add keys of profile
-     rep_ingredient = []
-     rep_ingredient += [ingredient_profiles[k][index_ing] for k in params['ing_keys']]
-     # add category encoding
-     # rep_ingredient += list(params['category_encodings'][ingredient_profiles['type'][index_ing]])
-     # add quantity and relative quantity
-     rep_ingredient += [q / max_q_per_ing[ing], proportion]
-     ing_one_hot = np.zeros(len(ingredient_list))
-     ing_one_hot[index_ing] = 1
-     rep_ingredient += list(ing_one_hot)
-     indexes_to_normalize = list(range(len(params['ing_keys'])))
-     # TODO: should we add ing one hot? Or make sure no 2 ing have same embedding
-     return np.array(rep_ingredient), indexes_to_normalize
-
- def get_max_n_ingredients(data):
-     max_count = 0
-     ingredient_set = set()
-     alcohol_set = set()
-     liqueur_set = set()
-     ing_str = np.array(data['ingredients_str'])
-     for i in range(len(data['names'])):
-         ingredients, quantities = extract_ingredients(ing_str[i])
-         max_count = max(max_count, len(ingredients))
-         for ing in ingredients:
-             ingredient_set.add(ing)
-             if ing in ingredients_per_type['liquor']:
-                 alcohol_set.add(ing)
-             if ing in ingredients_per_type['liqueur']:
-                 liqueur_set.add(ing)
-     return max_count, ingredient_set, alcohol_set, liqueur_set
-
- # Add your custom dataset class here
- class MyDataset(Dataset):
-     def __init__(self, split, params):
-         data = params['raw_data']
-         self.dim_rep_ingredient = params['dim_rep_ingredient']
-         n_data = len(data["names"])
-
-         preparation_list = sorted(set(data['category']))
-         categories_list = sorted(set(data['subcategory']))
-         glasses_list = sorted(set(data['glass']))
-
-         max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
-         ingredient_set = sorted(ingredient_set)
-         self.ingredient_set = ingredient_set
-
-         ingredient_quantities = []  # output of our network
-         ingr_strs = np.array(data['ingredients_str'])
-         for i in range(n_data):
-             ingredients, quantities = extract_ingredients(ingr_strs[i])
-             # get ingredient presence and quantity
-             ingredient_q_rep = np.zeros([len(ingredient_set)])
-             for ing, q in zip(ingredients, quantities):
-                 ingredient_q_rep[ingredient_set.index(ing)] = q
-             ingredient_quantities.append(ingredient_q_rep)
-
-         # take care of ingredient quantities (OUTPUTS)
-         ingredient_quantities = np.array(ingredient_quantities)
-         ingredients_presence = (ingredient_quantities > 0).astype(int)
-
-         min_ing_quantities = np.min(ingredient_quantities, axis=0)
-         max_ing_quantities = np.max(ingredient_quantities, axis=0)
-         def normalize_ing_quantities(ing_quantities):
-             return ((ing_quantities - min_ing_quantities) / (max_ing_quantities - min_ing_quantities)).copy()
-
-         def denormalize_ing_quantities(normalized_ing_quantities):
-             return (normalized_ing_quantities * (max_ing_quantities - min_ing_quantities) + min_ing_quantities).copy()
-         ing_q_when_present = ingredient_quantities.copy()
-         for i in range(len(ing_q_when_present)):
-             ing_q_when_present[i, np.where(ing_q_when_present[i, :] == 0)] = np.nan
-         self.min_when_present_ing_quantities = np.nanmin(ing_q_when_present, axis=0)
-
-         def filter_decoder_output(output):
-             output_unnormalized = output * max_ing_quantities
-             if output.ndim == 1:
-                 output_unnormalized[np.where(output_unnormalized < self.min_when_present_ing_quantities)] = 0
-             else:
-                 for i in range(output.shape[0]):
-                     output_unnormalized[i, np.where(output_unnormalized[i] < self.min_when_present_ing_quantities)] = 0
-             return output_unnormalized.copy()
-         self.filter_decoder_output = filter_decoder_output
-         # arg_mins = np.nanargmin(ing_q_when_present, axis=0)
-         #
-         # for ing, minq, argminq in zip(ingredient_set, self.min_when_present_ing_quantities, arg_mins):
-         #     print(f'__\n{ing}: {minq}')
-         #     print_recipe(ingr_strs[argminq])
-         #     ingredients, quantities = extract_ingredients(ingr_strs[argminq])
-         #     # get ingredient presence and quantity
-         #     ingredient_q_rep = np.zeros([len(ingredient_set)])
-         #     for ing, q in zip(ingredients, quantities):
-         #         ingredient_q_rep[ingredient_set.index(ing)] = q
-         #     print(np.array(data['urls'])[argminq])
-         #     stop = 1
-
-         self.max_ing_quantities = max_ing_quantities
-         self.mean_ing_quantities = np.mean(ingredient_quantities, axis=0)
-         self.std_ing_quantities = np.std(ingredient_quantities, axis=0)
-         if split == 'train':
-             np.savetxt(params['save_path'] + 'min_when_present_ing_quantities.txt', self.min_when_present_ing_quantities)
-             np.savetxt(params['save_path'] + 'max_ing_quantities.txt', max_ing_quantities)
-             np.savetxt(params['save_path'] + 'mean_ing_quantities.txt', self.mean_ing_quantities)
-             np.savetxt(params['save_path'] + 'std_ing_quantities.txt', self.std_ing_quantities)
-
-         # print(ingredient_quantities[0])
-         # ingredient_quantities = (ingredient_quantities - self.mean_ing_quantities) / self.std_ing_quantities
-         # print(ingredient_quantities[0])
-         # print(ingredient_quantities[0] * self.std_ing_quantities + self.mean_ing_quantities)
-         ingredient_quantities = ingredient_quantities / max_ing_quantities  # = normalize_ing_quantities(ingredient_quantities)
-
-         max_q_per_ing = dict(zip(ingredient_set, max_ing_quantities))
-         # print(ingredient_quantities[0])
-         #########
-         # Process input representation_analysis: list of ingredient representation_analysis
-         #########
-         input_data = []  # input of ingredient encoders
-         all_ing_reps = []
-         for i in range(n_data):
-             ingredients, quantities = extract_ingredients(ingr_strs[i])
-             # get ingredient presence and quantity
-             ingredient_q_rep = np.zeros([len(ingredient_set)])
-             for ing, q in zip(ingredients, quantities):
-                 ingredient_q_rep[ingredient_set.index(ing)] = q
-             # get main liquor
-             cocktail_rep = []
-             for j in range(len(ingredients)):
-                 cocktail_rep.append(get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index=j, params=params)[0])
-                 all_ing_reps.append(cocktail_rep[-1].copy())
-             input_data.append(cocktail_rep)
-
-         all_ing_reps = np.array(all_ing_reps)
-         min_ing_reps = np.min(all_ing_reps[:, params['indexes_ing_to_normalize']], axis=0)
-         max_ing_reps = np.max(all_ing_reps[:, params['indexes_ing_to_normalize']], axis=0)
-
-         def normalize_ing_reps(ing_reps):
-             if ing_reps.ndim == 1:
-                 ing_reps = ing_reps.reshape(1, -1)
-             out = ing_reps.copy()
-             out[:, params['indexes_ing_to_normalize']] = (out[:, params['indexes_ing_to_normalize']] - min_ing_reps) / (max_ing_reps - min_ing_reps)
-             return out
-
-         def denormalize_ing_reps(normalized_ing_reps):
-             if normalized_ing_reps.ndim == 1:
-                 normalized_ing_reps = normalized_ing_reps.reshape(1, -1)
-             out = normalized_ing_reps.copy()
-             out[:, params['indexes_ing_to_normalize']] = out[:, params['indexes_ing_to_normalize']] * (max_ing_reps - min_ing_reps) + min_ing_reps
-             return out
-
-         # put everything in a big matrix
-         dim_cocktail_rep = max_ingredients * self.dim_rep_ingredient
-         input_data2 = []
-         nb_ingredients = []
-         for d in input_data:
-             cocktail_rep = np.zeros([dim_cocktail_rep])
-             cocktail_rep.fill(np.nan)
-             index = 0
-             nb_ingredients.append(len(d))
-             for dj in d:
-                 cocktail_rep[index:index + self.dim_rep_ingredient] = normalize_ing_reps(dj)
-                 index += self.dim_rep_ingredient
-             input_data2.append(cocktail_rep)
-         input_data = np.array(input_data2)
-         nb_ingredients = np.array(nb_ingredients)
-
-         # let us now extract various possible output we might want to predict:
-         #########
-         # Process output cocktail representation_analysis (computed from ingredient reps)
-         #########
-         # quantities_indexes = np.arange(20, 456, 57)
-         # qs = input_data[0, quantities_indexes]
-         # ingredient_quantities[0]
-         # get final volume
-         volumes = np.array(params['raw_data']['end volume'])
-
-         min_vol = volumes.min()
-         max_vol = volumes.max()
-         def normalize_vol(volume):
-             return (volume - min_vol) / (max_vol - min_vol)
-
-         def denormalize_vol(normalized_vol):
-             return normalized_vol * (max_vol - min_vol) + min_vol
-
-         volumes = normalize_vol(volumes)
-
-         # computed cocktail representation
-         computed_cocktail_reps = params['cocktail_reps']
-         self.dim_rep = computed_cocktail_reps.shape[1]
-
-         #########
-         # Process output sub categories
-         #########
-         categories = np.array([categories_list.index(sc) for sc in data['subcategory']])
-         counts = dict(zip(categories_list, [0] * len(categories)))
-         for c in data['subcategory']:
-             counts[c] += 1
-         for k in counts.keys():
-             counts[k] /= len(data['subcategory'])
-         self.categories = categories_list
-         self.categories_weights = []
-         for c in self.categories:
-             self.categories_weights.append(1 / len(self.categories) / counts[c])
-         print(counts)
-
-         #########
-         # Process output glass type
-         #########
-         glasses = np.array([glasses_list.index(sc) for sc in data['glass']])
-         counts = dict(zip(glasses_list, [0] * len(set(data['glass']))))
-         for c in data['glass']:
-             counts[c] += 1
-         for k in counts.keys():
-             counts[k] /= len(data['glass'])
-         self.glasses = glasses_list
-         self.glasses_weights = []
-         for c in self.glasses:
-             self.glasses_weights.append(1 / len(self.glasses) / counts[c])
-         print(counts)
-
-         #########
-         # Process output preparation type
-         #########
-         prep_type = np.array([preparation_list.index(sc) for sc in data['category']])
-         counts = dict(zip(preparation_list, [0] * len(preparation_list)))
-         for c in data['category']:
-             counts[c] += 1
-         for k in counts.keys():
-             counts[k] /= len(data['category'])
-         self.prep_types = preparation_list
-         self.prep_types_weights = []
-         for c in self.prep_types:
-             self.prep_types_weights.append(1 / len(self.prep_types) / counts[c])
-         print(counts)
-
-         taste_reps = list(data['taste_rep'])
-         taste_rep_ground_truth = []
-         taste_rep_valid = []
-         for tr in taste_reps:
-             if len(tr) > 2:
-                 taste_rep_valid.append(True)
-                 taste_rep_ground_truth.append([float(tr.split('[')[1].split(',')[0]), float(tr.split(']')[0].split(',')[1][1:])])
-             else:
-                 taste_rep_valid.append(False)
-                 taste_rep_ground_truth.append([np.nan, np.nan])
-         taste_rep_ground_truth = np.array(taste_rep_ground_truth)
-         taste_rep_valid = np.array(taste_rep_valid)
-         taste_rep_ground_truth /= 10
-
-         auxiliary_data = dict(categories=categories,
-                               glasses=glasses,
-                               prep_type=prep_type,
-                               cocktail_reps=computed_cocktail_reps,
-                               ingredients_presence=ingredients_presence,
-                               taste_reps=taste_rep_ground_truth,
-                               volume=volumes,
-                               ingredients_quantities=ingredient_quantities)
-         self.auxiliary_keys = sorted(params['auxiliaries_dict'].keys())
-         assert self.auxiliary_keys == sorted(auxiliary_data.keys())
-
-         data_preprocessing = dict(min_max_ing_quantities=(min_ing_quantities, max_ing_quantities),
-                                   min_max_ing_reps=(min_ing_reps, max_ing_reps),
-                                   min_max_vol=(min_vol, max_vol))
-
-         if split == 'train':
-             with open(params['save_path'] + 'normalization_funcs.pickle', 'wb') as f:
-                 pickle.dump(data_preprocessing, f)
-
-         n_data = len(input_data)
-         assert len(ingredient_quantities) == n_data
-         for aux in self.auxiliary_keys:
-             assert len(auxiliary_data[aux]) == n_data
-
-         if split == 'train':
-             indexes = np.arange(int(0.9 * n_data))
-         elif split == 'test':
-             indexes = np.arange(int(0.9 * n_data), n_data)
-         elif split == 'all':
-             indexes = np.arange(n_data)
-         else:
-             raise ValueError
-
-         # np.random.shuffle(indexes)
-         self.taste_rep_valid = taste_rep_valid[indexes]
-         self.input_ingredients = input_data[indexes]
-         self.ingredient_quantities = ingredient_quantities[indexes]
-         self.computed_cocktail_reps = computed_cocktail_reps[indexes]
-         self.auxiliaries = dict()
-         for aux in self.auxiliary_keys:
-             self.auxiliaries[aux] = auxiliary_data[aux][indexes]
-         self.nb_ingredients = nb_ingredients[indexes]
-
-     def __len__(self):
-         return len(self.input_ingredients)
-
-     def get_auxiliary_data(self, idx):
-         out = dict()
-         for aux in self.auxiliary_keys:
-             out[aux] = self.auxiliaries[aux][idx]
-         return out
-
-     def __getitem__(self, idx):
-         assert self.nb_ingredients[idx] == np.argwhere(~np.isnan(self.input_ingredients[idx])).flatten().size / self.dim_rep_ingredient
-         return [self.nb_ingredients[idx], self.input_ingredients[idx], self.ingredient_quantities[idx], self.computed_cocktail_reps[idx], self.get_auxiliary_data(idx),
-                 self.taste_rep_valid[idx]]
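
Since __getitem__ returns a six-element list, the dataset plugs straight into a torch DataLoader (the default collate also batches the auxiliary dict). A usage sketch, leaving `params` abstract because its keys ('raw_data', 'cocktail_reps', 'ing_keys', ...) are assembled by the run scripts:

# Hypothetical usage, not part of the commit.
from torch.utils.data import DataLoader

train_set = MyDataset(split='train', params=params)  # params built by the run scripts
loader = DataLoader(train_set, batch_size=32, shuffle=True)
for nb_ing, ing_reps, ing_qs, cocktail_reps, auxiliaries, taste_valid in loader:
    pass  # feed the batch to the encoder/decoder models
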
src/cocktails/representation_learning/multihead_model.py DELETED
@@ -1,148 +0,0 @@
- import torch; torch.manual_seed(0)
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils
- import torch.distributions
- import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
- from src.cocktails.representation_learning.simple_model import SimpleNet
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def get_activation(activation):
-     if activation == 'tanh':
-         activ = F.tanh
-     elif activation == 'relu':
-         activ = F.relu
-     elif activation == 'mish':
-         activ = F.mish
-     elif activation == 'sigmoid':
-         activ = F.sigmoid
-     elif activation == 'leakyrelu':
-         activ = F.leaky_relu
-     elif activation == 'exp':
-         activ = torch.exp
-     else:
-         raise ValueError
-     return activ
-
- class IngredientEncoder(nn.Module):
-     def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
-         super(IngredientEncoder, self).__init__()
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         dims = [input_dim] + hidden_dims + [deepset_latent_dim]
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-
-     def forward(self, x):
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = layer(x)
-             if i_layer != self.n_layers - 1:
-                 x = self.activation(dropout(x))
-         return x  # do not use dropout on last layer?
-
- class DeepsetCocktailEncoder(nn.Module):
-     def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
-                  hidden_dims_cocktail, latent_dim, aggregation, dropout):
-         super(DeepsetCocktailEncoder, self).__init__()
-         self.input_dim = input_dim  # dimension of ingredient representation + quantity
-         self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout)  # encode each ingredient separately
-         self.deepset_latent_dim = deepset_latent_dim  # dimension of the deepset aggregation
-         self.aggregation = aggregation
-         self.latent_dim = latent_dim
-         # post aggregation network
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         dims = [deepset_latent_dim] + hidden_dims_cocktail
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
-         self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
-         self.softplus = nn.Softplus()
-
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-
-     def forward(self, nb_ingredients, x):
-
-         # reshape x in (batch size * nb ingredients, dim_ing_rep)
-         batch_size = x.shape[0]
-         all_ingredients = []
-         for i in range(batch_size):
-             for j in range(nb_ingredients[i]):
-                 all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1))
-         x = torch.cat(all_ingredients, dim=0)
-         # encode ingredients in parallel
-         ingredients_encodings = self.ingredient_encoder(x)
-         assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim)
-
-         # aggregate
-         x = []
-         index_first = 0
-         for i in range(batch_size):
-             index_last = index_first + nb_ingredients[i]
-             # aggregate
-             if self.aggregation == 'sum':
-                 x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
-             elif self.aggregation == 'mean':
-                 x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
-             else:
-                 raise ValueError
-             index_first = index_last
-         x = torch.cat(x, dim=0)
-         assert x.shape[0] == batch_size
-
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = self.activation(dropout(layer(x)))
-         mean = self.FC_mean(x)
-         logvar = self.FC_logvar(x)
-         return mean, logvar
-
-
- class MultiHeadModel(nn.Module):
-     def __init__(self, encoder, auxiliaries_dict, activation, hidden_dims_decoder):
-         super(MultiHeadModel, self).__init__()
-         self.encoder = encoder
-         self.latent_dim = self.encoder.output_dim
-         self.auxiliaries_str = []
-         self.auxiliaries = nn.ModuleList()
-         for aux_str in sorted(auxiliaries_dict.keys()):
-             if aux_str == 'taste_reps':
-                 self.taste_reps_decoder = SimpleNet(input_dim=self.latent_dim, hidden_dims=[], output_dim=auxiliaries_dict[aux_str]['dim_output'],
-                                                     activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ'])
-             else:
-                 self.auxiliaries_str.append(aux_str)
-                 if aux_str == 'ingredients_quantities':
-                     hd = hidden_dims_decoder
-                 else:
-                     hd = []
-                 self.auxiliaries.append(SimpleNet(input_dim=self.latent_dim, hidden_dims=hd, output_dim=auxiliaries_dict[aux_str]['dim_output'],
-                                                   activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ']))
-
-     def get_all_auxiliaries(self, x):
-         return [aux(x) for aux in self.auxiliaries]
-
-     def get_auxiliary(self, z, aux_str):
-         if aux_str == 'taste_reps':
-             return self.taste_reps_decoder(z)
-         else:
-             index = self.auxiliaries_str.index(aux_str)
-             return self.auxiliaries[index](z)
-
-     def forward(self, x, aux_str=None):
-         z = self.encoder(x)
-         if aux_str is not None:
-             return z, self.get_auxiliary(z, aux_str), [aux_str]
-         else:
-             return z, self.get_all_auxiliaries(z), self.auxiliaries_str
-
- def get_multihead_model(input_dim, activation, hidden_dims_cocktail, latent_dim, dropout, auxiliaries_dict, hidden_dims_decoder):
-     encoder = SimpleNet(input_dim, hidden_dims_cocktail, latent_dim, activation, dropout)
-     model = MultiHeadModel(encoder, auxiliaries_dict, activation, hidden_dims_decoder)
-     return model
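The heart of DeepsetCocktailEncoder.forward above is the per-cocktail pooling of ingredient encodings: each cocktail's variable-size set of encoded ingredients is reduced to a single fixed-size vector by a sum or a mean. A minimal standalone sketch of that aggregation step, assuming only PyTorch (the helper name aggregate is illustrative, not from the source):

import torch

def aggregate(encodings: torch.Tensor, nb_ingredients: torch.Tensor, mode: str = 'mean') -> torch.Tensor:
    # encodings: (total ingredients across the batch, d); nb_ingredients: (batch,)
    pooled, first = [], 0
    for n in nb_ingredients.tolist():
        chunk = encodings[first:first + n]  # this cocktail's ingredient encodings
        pooled.append(chunk.sum(dim=0) if mode == 'sum' else chunk.mean(dim=0))
        first += n
    return torch.stack(pooled)  # (batch, d)

enc = torch.randn(5, 8)                     # 5 ingredient encodings in total
out = aggregate(enc, torch.tensor([2, 3]))  # two cocktails: 2 and 3 ingredients
assert out.shape == (2, 8)

Because sum and mean are permutation-invariant, the encoder is insensitive to ingredient order, which is the point of the deepset construction.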
src/cocktails/representation_learning/run.py DELETED
@@ -1,557 +0,0 @@
- import torch; torch.manual_seed(0)
- import torch.utils
- from torch.utils.data import DataLoader
- import torch.distributions
- import torch.nn as nn
- import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
- from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
- import json
- import pandas as pd
- import numpy as np
- import os
- from src.cocktails.representation_learning.vae_model import get_vae_model
- from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
- from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
- from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
- from resource import getrusage
- from resource import RUSAGE_SELF
- import gc
- gc.collect(2)
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def get_params():
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
-     num_ingredients = len(ingredient_set)
-     rep_keys = get_bunch_of_rep_keys()['custom']
-     ing_keys = [k.split(' ')[1] for k in rep_keys]
-     ing_keys.remove('volume')
-     nb_ing_categories = len(set(ingredient_profiles['type']))
-     category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
-
-     params = dict(trial_id='test',
-                   save_path=EXPERIMENT_PATH + "/deepset_vae/",
-                   nb_epochs=2000,
-                   print_every=50,
-                   plot_every=100,
-                   batch_size=64,
-                   lr=0.001,
-                   dropout=0.,
-                   nb_epoch_switch_beta=600,
-                   latent_dim=10,
-                   beta_vae=0.2,
-                   ing_keys=ing_keys,
-                   nb_ingredients=len(ingredient_set),
-                   hidden_dims_ingredients=[128],
-                   hidden_dims_cocktail=[32],
-                   hidden_dims_decoder=[32],
-                   agg='mean',
-                   activation='relu',
-                   auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
-                                         glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
-                                         prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
-                                         cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
-                                         volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1),
-                                         taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
-                                         ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
-                   category_encodings=category_encodings
-                   )
-     # params = dict(trial_id='test',
-     #               save_path=EXPERIMENT_PATH + "/deepset_vae/",
-     #               nb_epochs=1000,
-     #               print_every=50,
-     #               plot_every=100,
-     #               batch_size=64,
-     #               lr=0.001,
-     #               dropout=0.,
-     #               nb_epoch_switch_beta=500,
-     #               latent_dim=64,
-     #               beta_vae=0.3,
-     #               ing_keys=ing_keys,
-     #               nb_ingredients=len(ingredient_set),
-     #               hidden_dims_ingredients=[128],
-     #               hidden_dims_cocktail=[128, 128],
-     #               hidden_dims_decoder=[128, 128],
-     #               agg='mean',
-     #               activation='mish',
-     #               auxiliaries_dict=dict(categories=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
-     #                                     glasses=dict(weight=0.03, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
-     #                                     prep_type=dict(weight=0.02, type='classif', final_activ=None, dim_output=len(set(data['category']))),
-     #                                     cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),
-     #                                     volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),
-     #                                     taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),
-     #                                     ingredients_presence=dict(weight=1.5, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
-     #               category_encodings=category_encodings
-     #               )
-     water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
-                                                                          max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
-                                                                          params=params)
-     dim_rep_ingredient = water_rep.size
-     params['indexes_ing_to_normalize'] = indexes_to_normalize
-     params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
-     params['input_dim'] = dim_rep_ingredient
-     params['dim_rep_ingredient'] = dim_rep_ingredient
-     params = compute_expe_name_and_save_path(params)
-     del params['category_encodings']  # so params can be dumped to JSON
-     with open(params['save_path'] + 'params.json', 'w') as f:
-         json.dump(params, f)
-
-     params = complete_params(params)
-     return params
-
- def complete_params(params):
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
-     nb_ing_categories = len(set(ingredient_profiles['type']))
-     category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
-     params['cocktail_reps'] = cocktail_reps
-     params['raw_data'] = data
-     params['category_encodings'] = category_encodings
-     return params
-
- def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
-     losses = dict()
-     accuracies = dict()
-     other_metrics = dict()
-     for i_k, k in enumerate(auxiliaries_str):
-         # get ground truth
-         # compute loss
-         if k == 'volume':
-             outputs[i_k] = outputs[i_k].flatten()
-         ground_truth = auxiliaries[k]
-         if ground_truth.dtype == torch.float64:
-             losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float()
-         elif ground_truth.dtype == torch.int64:
-             if str(loss_functions[k]) != "BCEWithLogitsLoss()":
-                 losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float()
-             else:
-                 losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float()
-         else:
-             losses[k] = loss_functions[k](outputs[i_k], ground_truth).float()
-         # compute accuracies
-         if str(loss_functions[k]) == 'CrossEntropyLoss()':
-             bs, n_options = outputs[i_k].shape
-             predicted = outputs[i_k].argmax(dim=1).detach().numpy()
-             true = ground_truth.int().detach().numpy()
-             confusion_matrix = np.zeros([n_options, n_options])
-             for i in range(bs):
-                 confusion_matrix[true[i], predicted[i]] += 1
-             acc = confusion_matrix.diagonal().sum() / bs
-             for i in range(n_options):
-                 if confusion_matrix[i].sum() != 0:
-                     confusion_matrix[i] /= confusion_matrix[i].sum()
-             other_metrics[k + '_confusion'] = confusion_matrix
-             accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy())
-             assert (acc - accuracies[k]) < 1e-5
-
-         elif str(loss_functions[k]) == 'BCEWithLogitsLoss()':
-             assert k == 'ingredients_presence'
-             outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
-             predicted_presence = (outputs_rescaled > 0).astype(bool)
-             presence = ground_truth.detach().numpy().astype(bool)
-             other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool)))
-             other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool)))
-             accuracies[k] = np.mean(predicted_presence == presence)  # accuracy for multi-class labeling
-         elif str(loss_functions[k]) == 'MSELoss()':
-             accuracies[k] = np.nan
-         else:
-             raise ValueError
-     return losses, accuracies, other_metrics
-
- def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
-     ing_q = ingredient_quantities.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
-     ing_presence = (ing_q > 0)
-     x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
-     # abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
-     abs_diff = np.abs(ing_q - x_hat)
-     ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
-     for i in range(ingredient_quantities.shape[0]):
-         ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])]))
-         ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])]))
-     aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
-     aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
-     return aux_other_metrics
-
- def run_epoch(opt, train, model, data, loss_functions, weights, params):
-     if train:
-         model.train()
-     else:
-         model.eval()
-
-     # prepare logging of losses
-     losses = dict(kld_loss=[],
-                   mse_loss=[],
-                   vae_loss=[],
-                   volume_loss=[],
-                   global_loss=[])
-     accuracies = dict()
-     other_metrics = dict()
-     for aux in params['auxiliaries_dict'].keys():
-         losses[aux] = []
-         accuracies[aux] = []
-     if train: opt.zero_grad()
-
-     for d in data:
-         nb_ingredients = d[0]
-         batch_size = nb_ingredients.shape[0]
-         x_ingredients = d[1].float()
-         ingredient_quantities = d[2]
-         cocktail_reps = d[3]
-         auxiliaries = d[4]
-         for k in auxiliaries.keys():
-             if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
-         taste_valid = d[-1]
-         x = x_ingredients.to(device)
-         x_hat, z, mean, log_var, outputs, auxiliaries_str = model.forward_direct(ingredient_quantities.float())
-         # get auxiliary losses and accuracies
-         aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data)
-
-         # compute vae loss
-         mse_loss = ((ingredient_quantities - x_hat) ** 2).mean().float()
-         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)).float()
-         vae_loss = mse_loss + params['beta_vae'] * (params['latent_dim'] / params['nb_ingredients']) * kld_loss
-         # compute total volume loss to train decoder
-         # volume_loss = ((ingredient_quantities.sum(dim=1) - x_hat.sum(dim=1)) ** 2).mean().float()
-         volume_loss = torch.FloatTensor([0])
-
-         aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat)
-
-         indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
-         if indexes_taste_valid.size > 0:
-             outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
-             gt = auxiliaries['taste_reps'][indexes_taste_valid]
-             factor_loss = indexes_taste_valid.size / (0.3 * batch_size)  # loss factor: equals 1 when the batch matches the dataset's ratio of taste-labeled samples (~30%); it decreases with fewer labeled samples and increases with more
-             aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
-         else:
-             aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
-             aux_accuracies['taste_reps'] = 0
-
-         # aggregate losses
-         global_loss = torch.sum(torch.cat([torch.atleast_1d(vae_loss), torch.atleast_1d(volume_loss)] + [torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()]))
-         # for k in params['auxiliaries_dict'].keys():
-         #     global_loss += aux_losses[k] * weights[k]
-
-         if train:
-             global_loss.backward()
-             opt.step()
-             opt.zero_grad()
-
-         # logging
-         losses['global_loss'].append(float(global_loss))
-         losses['mse_loss'].append(float(mse_loss))
-         losses['vae_loss'].append(float(vae_loss))
-         losses['volume_loss'].append(float(volume_loss))
-         losses['kld_loss'].append(float(kld_loss))
-         for k in params['auxiliaries_dict'].keys():
-             losses[k].append(float(aux_losses[k]))
-             accuracies[k].append(float(aux_accuracies[k]))
-         for k in aux_other_metrics.keys():
-             if k not in other_metrics.keys():
-                 other_metrics[k] = [aux_other_metrics[k]]
-             else:
-                 other_metrics[k].append(aux_other_metrics[k])
-
-     for k in losses.keys():
-         losses[k] = np.mean(losses[k])
-     for k in accuracies.keys():
-         accuracies[k] = np.mean(accuracies[k])
-     for k in other_metrics.keys():
-         other_metrics[k] = np.mean(other_metrics[k], axis=0)
-     return model, losses, accuracies, other_metrics
-
- def prepare_data_and_loss(params):
-     train_data = MyDataset(split='train', params=params)
-     test_data = MyDataset(split='test', params=params)
-
-     train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
-     test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
-
-     loss_functions = dict()
-     weights = dict()
-     for k in sorted(params['auxiliaries_dict'].keys()):
-         if params['auxiliaries_dict'][k]['type'] == 'classif':
-             if k == 'glasses':
-                 classif_weights = train_data.glasses_weights
-             elif k == 'prep_type':
-                 classif_weights = train_data.prep_types_weights
-             elif k == 'categories':
-                 classif_weights = train_data.categories_weights
-             else:
-                 raise ValueError
-             loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
-         elif params['auxiliaries_dict'][k]['type'] == 'multiclassif':
-             loss_functions[k] = nn.BCEWithLogitsLoss()
-         elif params['auxiliaries_dict'][k]['type'] == 'regression':
-             loss_functions[k] = nn.MSELoss()
-         else:
-             raise ValueError
-         weights[k] = params['auxiliaries_dict'][k]['weight']
-
-     return loss_functions, train_data_loader, test_data_loader, weights
-
- def print_losses(train, losses, accuracies, other_metrics):
-     keyword = 'Train' if train else 'Eval'
-     print(f'\t{keyword} logs:')
-     keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
-     for k in keys:
-         print(f'\t\t{k} - Loss: {losses[k]:.2f}')
-     for k in sorted(accuracies.keys()):
-         print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
-     for k in sorted(other_metrics.keys()):
-         if 'confusion' not in k:
-             print(f'\t\t{k} - {other_metrics[k]:.2f}')
-
-
- def run_experiment(params, verbose=True):
-     loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)
-     params['filter_decoder_output'] = train_data_loader.dataset.filter_decoder_output
-
-     model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
-                                         "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
-                                         "filter_decoder_output"]]
-     model = get_vae_model(*model_params)
-     opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
-
-     all_train_losses = []
-     all_eval_losses = []
-     all_train_accuracies = []
-     all_eval_accuracies = []
-     all_eval_other_metrics = []
-     all_train_other_metrics = []
-     best_loss = np.inf
-     model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
-                                                                         weights=weights, params=params)
-     all_eval_losses.append(eval_losses)
-     all_eval_accuracies.append(eval_accuracies)
-     all_eval_other_metrics.append(eval_other_metrics)
-     if verbose: print(f'\n--------\nEpoch #0')
-     if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
-     for epoch in range(params['nb_epochs']):
-         if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
-         model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions,
-                                                                                weights=weights, params=params)
-         if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
-         model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
-                                                                             weights=weights, params=params)
-         if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
-         if eval_losses['global_loss'] < best_loss:
-             best_loss = eval_losses['global_loss']
-             if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
-             torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
-
-         # log
-         all_train_losses.append(train_losses)
-         all_train_accuracies.append(train_accuracies)
-         all_eval_losses.append(eval_losses)
-         all_eval_accuracies.append(eval_accuracies)
-         all_eval_other_metrics.append(eval_other_metrics)
-         all_train_other_metrics.append(train_other_metrics)
-
-         # if epoch == params['nb_epoch_switch_beta']:
-         #     params['beta_vae'] = 2.5
-         #     params['auxiliaries_dict']['prep_type']['weight'] /= 10
-         #     params['auxiliaries_dict']['glasses']['weight'] /= 10
-
-         if (epoch + 1) % params['plot_every'] == 0:
-             plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
-                          all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights)
-
-     return model
-
- def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
-                  all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
-
-     steps = np.arange(len(all_eval_accuracies))
-
-     loss_keys = sorted(all_train_losses[0].keys())
-     acc_keys = sorted(all_train_accuracies[0].keys())
-     metrics_keys = sorted(all_train_other_metrics[0].keys())
-
-     plt.figure()
-     plt.title('Train losses')
-     for k in loss_keys:
-         factor = 1 if k == 'mse_loss' else 1
-         if k not in weights.keys():
-             plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
-         else:
-             if weights[k] != 0:
-                 plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
-
-     plt.legend()
-     plt.ylim([0, 4])
-     plt.savefig(plot_path + 'train_losses.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Train accuracies')
-     for k in acc_keys:
-         if weights[k] != 0:
-             plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k)
-     plt.legend()
-     plt.ylim([0, 1])
-     plt.savefig(plot_path + 'train_acc.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Train other metrics')
-     for k in metrics_keys:
-         if 'confusion' not in k and 'presence' in k:
-             plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
-     plt.legend()
-     plt.ylim([0, 1])
-     plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Train other metrics')
-     for k in metrics_keys:
-         if 'confusion' not in k and 'presence' not in k:
-             plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
-     plt.legend()
-     plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Eval losses')
-     for k in loss_keys:
-         factor = 1 if k == 'mse_loss' else 1
-         if k not in weights.keys():
-             plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
-         else:
-             if weights[k] != 0:
-                 plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
-     plt.legend()
-     plt.ylim([0, 4])
-     plt.savefig(plot_path + 'eval_losses.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Eval accuracies')
-     for k in acc_keys:
-         if weights[k] != 0:
-             plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k)
-     plt.legend()
-     plt.ylim([0, 1])
-     plt.savefig(plot_path + 'eval_acc.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Eval other metrics')
-     for k in metrics_keys:
-         if 'confusion' not in k and 'presence' in k:
-             plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
-     plt.legend()
-     plt.ylim([0, 1])
-     plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Eval other metrics')
-     for k in metrics_keys:
-         if 'confusion' not in k and 'presence' not in k:
-             plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
-     plt.legend()
-     plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     for k in metrics_keys:
-         if 'confusion' in k:
-             plt.figure()
-             plt.title(k)
-             plt.ylabel('True')
-             plt.xlabel('Predicted')
-             plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1)
-             plt.colorbar()
-             plt.savefig(plot_path + f'eval_{k}.png', dpi=200)
-             fig = plt.gcf()
-             plt.close(fig)
-
-     for k in metrics_keys:
-         if 'confusion' in k:
-             plt.figure()
-             plt.title(k)
-             plt.ylabel('True')
-             plt.xlabel('Predicted')
-             plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1)
-             plt.colorbar()
-             plt.savefig(plot_path + f'train_{k}.png', dpi=200)
-             fig = plt.gcf()
-             plt.close(fig)
-
-     plt.close('all')
-
-
- def get_model(model_path):
-     with open(model_path + 'params.json', 'r') as f:
-         params = json.load(f)
-     params['save_path'] = model_path
-     max_ing_quantities = np.loadtxt(params['save_path'] + 'max_ing_quantities.txt')
-     mean_ing_quantities = np.loadtxt(params['save_path'] + 'mean_ing_quantities.txt')
-     std_ing_quantities = np.loadtxt(params['save_path'] + 'std_ing_quantities.txt')
-     min_when_present_ing_quantities = np.loadtxt(params['save_path'] + 'min_when_present_ing_quantities.txt')
-     def filter_decoder_output(output):
-         output = output.detach().numpy()
-         output_unnormalized = output * std_ing_quantities + mean_ing_quantities
-         if output.ndim == 1:
-             output_unnormalized[np.where(output_unnormalized < min_when_present_ing_quantities)] = 0
-         else:
-             for i in range(output.shape[0]):
-                 output_unnormalized[i, np.where(output_unnormalized[i] < min_when_present_ing_quantities)] = 0
-         return output_unnormalized.copy()
-     params['filter_decoder_output'] = filter_decoder_output
-     model_chkpt = model_path + "checkpoint_best.save"
-     model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
-                                         "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
-                                         "filter_decoder_output"]]
-     model = get_vae_model(*model_params)
-     model.load_state_dict(torch.load(model_chkpt))
-     model.eval()
-     return model, filter_decoder_output, params
-
-
- def compute_expe_name_and_save_path(params):
-     weights_str = '['
-     for aux in params['auxiliaries_dict'].keys():
-         weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
-     weights_str = weights_str[:-2] + ']'
-     save_path = params['save_path'] + params["trial_id"]
-     save_path += f'_lr{params["lr"]}'
-     save_path += f'_betavae{params["beta_vae"]}'
-     save_path += f'_bs{params["batch_size"]}'
-     save_path += f'_latentdim{params["latent_dim"]}'
-     save_path += f'_hding{params["hidden_dims_ingredients"]}'
-     save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
-     save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
-     save_path += f'_agg{params["agg"]}'
-     save_path += f'_activ{params["activation"]}'
-     save_path += f'_w{weights_str}'
-     counter = 0
-     while os.path.exists(save_path + f"_{counter}"):
-         counter += 1
-     save_path = save_path + f"_{counter}" + '/'
-     params["save_path"] = save_path
-     os.makedirs(save_path)
-     os.makedirs(save_path + 'plots/')
-     params['plot_path'] = save_path + 'plots/'
-     print(f'logging to {save_path}')
-     return params
-
-
- if __name__ == '__main__':
-     params = get_params()
-     run_experiment(params)
-
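The objective assembled in run_epoch above is a beta-VAE loss: reconstruction MSE on the ingredient quantities plus a KL term against a standard normal prior, scaled by beta_vae * latent_dim / nb_ingredients. A minimal sketch of that computation, assuming only PyTorch (the helper name vae_loss is illustrative, not from the source):

import torch

def vae_loss(x, x_hat, mean, log_var, beta_vae, latent_dim, nb_ingredients):
    mse = ((x - x_hat) ** 2).mean()
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, averaged over the batch
    kld = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1))
    return mse + beta_vae * (latent_dim / nb_ingredients) * kld

x = torch.randn(4, 10)
loss = vae_loss(x, x_hat=torch.zeros(4, 10),
                mean=torch.zeros(4, 3), log_var=torch.zeros(4, 3),
                beta_vae=0.2, latent_dim=3, nb_ingredients=10)
assert loss.ndim == 0  # scalar loss, ready for backward()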
src/cocktails/representation_learning/run_simple_net.py DELETED
@@ -1,302 +0,0 @@
- import torch; torch.manual_seed(0)
- import torch.utils
- from torch.utils.data import DataLoader
- import torch.distributions
- import torch.nn as nn
- import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
- from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
- import json
- import pandas as pd
- import numpy as np
- import os
- from src.cocktails.representation_learning.simple_model import SimpleNet
- from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
- from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
- from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
- from resource import getrusage
- from resource import RUSAGE_SELF
- import gc
- gc.collect(2)
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def get_params():
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
-     num_ingredients = len(ingredient_set)
-     rep_keys = get_bunch_of_rep_keys()['custom']
-     ing_keys = [k.split(' ')[1] for k in rep_keys]
-     ing_keys.remove('volume')
-     nb_ing_categories = len(set(ingredient_profiles['type']))
-     category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
-
-     params = dict(trial_id='test',
-                   save_path=EXPERIMENT_PATH + "/simple_net/",
-                   nb_epochs=100,
-                   print_every=50,
-                   plot_every=50,
-                   batch_size=128,
-                   lr=0.001,
-                   dropout=0.15,
-                   output_keyword='glasses',
-                   ing_keys=ing_keys,
-                   nb_ingredients=len(ingredient_set),
-                   hidden_dims=[16],
-                   activation='sigmoid',
-                   auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
-                                         glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
-                                         prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
-                                         cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
-                                         volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1),
-                                         taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
-                                         ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),
-                                         ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)),
-                   category_encodings=category_encodings
-                   )
-     params['output_dim'] = params['auxiliaries_dict'][params['output_keyword']]['dim_output']
-     water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
-                                                                          max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
-                                                                          params=params)
-     dim_rep_ingredient = water_rep.size
-     params['indexes_ing_to_normalize'] = indexes_to_normalize
-     params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
-     params['dim_rep_ingredient'] = dim_rep_ingredient
-     params['input_dim'] = params['nb_ingredients']
-     params = compute_expe_name_and_save_path(params)
-     del params['category_encodings']  # so params can be dumped to JSON
-     with open(params['save_path'] + 'params.json', 'w') as f:
-         json.dump(params, f)
-
-     params = complete_params(params)
-     return params
-
- def complete_params(params):
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
-     nb_ing_categories = len(set(ingredient_profiles['type']))
-     category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
-     params['cocktail_reps'] = cocktail_reps
-     params['raw_data'] = data
-     params['category_encodings'] = category_encodings
-     return params
-
- def compute_confusion_matrix_and_accuracy(predictions, ground_truth):
-     bs, n_options = predictions.shape
-     predicted = predictions.argmax(dim=1).detach().numpy()
-     true = ground_truth.int().detach().numpy()
-     confusion_matrix = np.zeros([n_options, n_options])
-     for i in range(bs):
-         confusion_matrix[true[i], predicted[i]] += 1
-     acc = confusion_matrix.diagonal().sum() / bs
-     for i in range(n_options):
-         if confusion_matrix[i].sum() != 0:
-             confusion_matrix[i] /= confusion_matrix[i].sum()
-     acc2 = np.mean(predicted == true)
-     assert (acc - acc2) < 1e-5
-     return confusion_matrix, acc
-
-
- def run_epoch(opt, train, model, data, loss_function, params):
-     if train:
-         model.train()
-     else:
-         model.eval()
-
-     # prepare logging of losses
-     losses = []
-     accuracies = []
-     cf_matrices = []
-     if train: opt.zero_grad()
-
-     for d in data:
-         nb_ingredients = d[0]
-         batch_size = nb_ingredients.shape[0]
-         x_ingredients = d[1].float()
-         ingredient_quantities = d[2].float()
-         cocktail_reps = d[3].float()
-         auxiliaries = d[4]
-         for k in auxiliaries.keys():
-             if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
-         taste_valid = d[-1]
-         predictions = model(ingredient_quantities)
-         loss = loss_function(predictions, auxiliaries[params['output_keyword']].long()).float()
-         cf_matrix, accuracy = compute_confusion_matrix_and_accuracy(predictions, auxiliaries[params['output_keyword']])
-         if train:
-             loss.backward()
-             opt.step()
-             opt.zero_grad()
-
-         losses.append(float(loss))
-         cf_matrices.append(cf_matrix)
-         accuracies.append(accuracy)
-
-     return model, np.mean(losses), np.mean(accuracies), np.mean(cf_matrices, axis=0)
-
- def prepare_data_and_loss(params):
-     train_data = MyDataset(split='train', params=params)
-     test_data = MyDataset(split='test', params=params)
-
-     train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
-     test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
-
-     if params['auxiliaries_dict'][params['output_keyword']]['type'] == 'classif':
-         if params['output_keyword'] == 'glasses':
-             classif_weights = train_data.glasses_weights
-         elif params['output_keyword'] == 'prep_type':
-             classif_weights = train_data.prep_types_weights
-         elif params['output_keyword'] == 'categories':
-             classif_weights = train_data.categories_weights
-         else:
-             raise ValueError
-         # classif_weights = (np.array(classif_weights) * 2 + np.ones(len(classif_weights))) / 3
-         loss_function = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
-         # loss_function = nn.CrossEntropyLoss()
-
-     elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'multiclassif':
-         loss_function = nn.BCEWithLogitsLoss()
-     elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'regression':
-         loss_function = nn.MSELoss()
-     else:
-         raise ValueError
-
-     return loss_function, train_data_loader, test_data_loader
-
- def print_losses(train, loss, accuracy):
-     keyword = 'Train' if train else 'Eval'
-     print(f'\t{keyword} logs:')
-     print(f'\t\t Loss: {loss:.2f}, Acc: {accuracy:.2f}')
-
-
- def run_experiment(params, verbose=True):
-     loss_function, train_data_loader, test_data_loader = prepare_data_and_loss(params)
-
-     model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout'])
-     opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
-
-     all_train_losses = []
-     all_eval_losses = []
-     all_eval_cf_matrices = []
-     all_train_accuracies = []
-     all_eval_accuracies = []
-     all_train_cf_matrices = []
-     best_loss = np.inf
-     model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params)
-     all_eval_losses.append(eval_loss)
-     all_eval_accuracies.append(eval_accuracy)
-     if verbose: print(f'\n--------\nEpoch #0')
-     if verbose: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss)
-     for epoch in range(params['nb_epochs']):
-         if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
-         model, train_loss, train_accuracy, train_cf_matrix = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_function=loss_function, params=params)
-         if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracy=train_accuracy, loss=train_loss)
-         model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params)
-         if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss)
-         if eval_loss < best_loss:
-             best_loss = eval_loss
-             if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
-             torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
-
-         # log
-         all_train_losses.append(train_loss)
-         all_train_accuracies.append(train_accuracy)
-         all_eval_losses.append(eval_loss)
-         all_eval_accuracies.append(eval_accuracy)
-         all_eval_cf_matrices.append(eval_cf_matrix)
-         all_train_cf_matrices.append(train_cf_matrix)
-
-         if (epoch + 1) % params['plot_every'] == 0:
-             plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices,
-                          all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, params['plot_path'])
-
-     return model
-
- def plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices,
-                  all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, plot_path):
-
-     steps = np.arange(len(all_eval_accuracies))
-
-     plt.figure()
-     plt.title('Losses')
-     plt.plot(steps[1:], all_train_losses, label='train')
-     plt.plot(steps, all_eval_losses, label='eval')
-     plt.legend()
-     plt.ylim([0, 4])
-     plt.savefig(plot_path + 'losses.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Accuracies')
-     plt.plot(steps[1:], all_train_accuracies, label='train')
-     plt.plot(steps, all_eval_accuracies, label='eval')
-     plt.legend()
-     plt.ylim([0, 1])
-     plt.savefig(plot_path + 'accs.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Train confusion matrix')
-     plt.ylabel('True')
-     plt.xlabel('Predicted')
-     plt.imshow(all_train_cf_matrices[-1], vmin=0, vmax=1)
-     plt.colorbar()
-     plt.savefig(plot_path + f'train_confusion_matrix.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.figure()
-     plt.title('Eval confusion matrix')
-     plt.ylabel('True')
-     plt.xlabel('Predicted')
-     plt.imshow(all_eval_cf_matrices[-1], vmin=0, vmax=1)
-     plt.colorbar()
-     plt.savefig(plot_path + f'eval_confusion_matrix.png', dpi=200)
-     fig = plt.gcf()
-     plt.close(fig)
-
-     plt.close('all')
-
-
- def get_model(model_path):
-     with open(model_path + 'params.json', 'r') as f:
-         params = json.load(f)
-     params['save_path'] = model_path
-     model_chkpt = model_path + "checkpoint_best.save"
-     model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout'])
-     model.load_state_dict(torch.load(model_chkpt))
-     model.eval()
-     return model, params
-
-
- def compute_expe_name_and_save_path(params):
-     weights_str = '['
-     for aux in params['auxiliaries_dict'].keys():
-         weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
-     weights_str = weights_str[:-2] + ']'
-     save_path = params['save_path'] + params["trial_id"]
-     save_path += f'_lr{params["lr"]}'
-     save_path += f'_bs{params["batch_size"]}'
-     save_path += f'_hd{params["hidden_dims"]}'
-     save_path += f'_activ{params["activation"]}'
-     save_path += f'_w{weights_str}'
-     counter = 0
-     while os.path.exists(save_path + f"_{counter}"):
-         counter += 1
-     save_path = save_path + f"_{counter}" + '/'
-     params["save_path"] = save_path
-     os.makedirs(save_path)
-     os.makedirs(save_path + 'plots/')
-     params['plot_path'] = save_path + 'plots/'
-     print(f'logging to {save_path}')
-     return params
-
-
- if __name__ == '__main__':
-     params = get_params()
-     run_experiment(params)
-
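compute_confusion_matrix_and_accuracy above first counts (true, predicted) pairs, takes accuracy as the diagonal mass over the batch size, then normalises each row, so entry (i, j) reads as the frequency of predicting class j when the true class is i. A minimal NumPy-only sketch of the same computation (the helper name is illustrative, not from the source):

import numpy as np

def confusion_and_accuracy(predicted: np.ndarray, true: np.ndarray, n_options: int):
    cm = np.zeros((n_options, n_options))
    for t, p in zip(true, predicted):
        cm[t, p] += 1
    acc = cm.diagonal().sum() / len(true)
    row_sums = cm.sum(axis=1, keepdims=True)
    # normalise rows, leaving all-zero rows (classes absent from the batch) at zero
    cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
    return cm, acc

cm, acc = confusion_and_accuracy(np.array([0, 1, 1]), np.array([0, 1, 0]), n_options=2)
assert abs(acc - 2 / 3) < 1e-9 and cm[0, 0] == 0.5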
src/cocktails/representation_learning/run_without_vae.py DELETED
@@ -1,514 +0,0 @@
- import torch; torch.manual_seed(0)
- import torch.utils
- from torch.utils.data import DataLoader
- import torch.distributions
- import torch.nn as nn
- import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
- from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
- import json
- import pandas as pd
- import numpy as np
- import os
- from src.cocktails.representation_learning.multihead_model import get_multihead_model
- from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
- from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
- from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
- from resource import getrusage
- from resource import RUSAGE_SELF
- import gc
- gc.collect(2)
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def get_params():
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
-     num_ingredients = len(ingredient_set)
-     rep_keys = get_bunch_of_rep_keys()['custom']
-     ing_keys = [k.split(' ')[1] for k in rep_keys]
-     ing_keys.remove('volume')
-     nb_ing_categories = len(set(ingredient_profiles['type']))
-     category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
-
-     params = dict(trial_id='test',
-                   save_path=EXPERIMENT_PATH + "/multihead_model/",
-                   nb_epochs=500,
-                   print_every=50,
-                   plot_every=50,
-                   batch_size=128,
-                   lr=0.001,
-                   dropout=0.,
-                   nb_epoch_switch_beta=600,
-                   latent_dim=10,
-                   beta_vae=0.2,
-                   ing_keys=ing_keys,
-                   nb_ingredients=len(ingredient_set),
-                   hidden_dims_ingredients=[128],
-                   hidden_dims_cocktail=[64],
-                   hidden_dims_decoder=[32],
-                   agg='mean',
-                   activation='relu',
-                   auxiliaries_dict=dict(categories=dict(weight=5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),  # 0.5
-                                         glasses=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['glass']))),  # 0.1
-                                         prep_type=dict(weight=0.1, type='classif', final_activ=None, dim_output=len(set(data['category']))),  # 1
-                                         cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),  # 1
-                                         volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),  # 1
-                                         taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),  # 1
-                                         ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),  # 10
-                                         ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)),
-                   category_encodings=category_encodings
-                   )
-     water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
-                                                                          max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
-                                                                          params=params)
-     dim_rep_ingredient = water_rep.size
-     params['indexes_ing_to_normalize'] = indexes_to_normalize
-     params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
-     params['dim_rep_ingredient'] = dim_rep_ingredient
-     params['input_dim'] = params['nb_ingredients']
-     params = compute_expe_name_and_save_path(params)
-     del params['category_encodings']  # so params can be dumped to JSON
-     with open(params['save_path'] + 'params.json', 'w') as f:
-         json.dump(params, f)
-
-     params = complete_params(params)
-     return params
-
- def complete_params(params):
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
-     nb_ing_categories = len(set(ingredient_profiles['type']))
-     category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))
-     params['cocktail_reps'] = cocktail_reps
-     params['raw_data'] = data
-     params['category_encodings'] = category_encodings
-     return params
-
- def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
-     losses = dict()
-     accuracies = dict()
-     other_metrics = dict()
-     for i_k, k in enumerate(auxiliaries_str):
-         # get ground truth
-         # compute loss
-         if k == 'volume':
-             outputs[i_k] = outputs[i_k].flatten()
-         ground_truth = auxiliaries[k]
-         if ground_truth.dtype == torch.float64:
-             losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float()
-         elif ground_truth.dtype == torch.int64:
-             if str(loss_functions[k]) != "BCEWithLogitsLoss()":
-                 losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float()
-             else:
-                 losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float()
-         else:
-             losses[k] = loss_functions[k](outputs[i_k], ground_truth).float()
-         # compute accuracies
-         if str(loss_functions[k]) == 'CrossEntropyLoss()':
-             bs, n_options = outputs[i_k].shape
-             predicted = outputs[i_k].argmax(dim=1).detach().numpy()
-             true = ground_truth.int().detach().numpy()
-             confusion_matrix = np.zeros([n_options, n_options])
-             for i in range(bs):
-                 confusion_matrix[true[i], predicted[i]] += 1
-             acc = confusion_matrix.diagonal().sum() / bs
-             for i in range(n_options):
-                 if confusion_matrix[i].sum() != 0:
-                     confusion_matrix[i] /= confusion_matrix[i].sum()
-             other_metrics[k + '_confusion'] = confusion_matrix
-             accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy())
-             assert (acc - accuracies[k]) < 1e-5
-
-         elif str(loss_functions[k]) == 'BCEWithLogitsLoss()':
-             assert k == 'ingredients_presence'
-             outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
-             predicted_presence = (outputs_rescaled > 0).astype(bool)
-             presence = ground_truth.detach().numpy().astype(bool)
-             other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool)))
-             other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool)))
-             accuracies[k] = np.mean(predicted_presence == presence)  # accuracy for multi-class labeling
-         elif str(loss_functions[k]) == 'MSELoss()':
-             accuracies[k] = np.nan
-         else:
-             raise ValueError
-     return losses, accuracies, other_metrics
-
- def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
-     ing_q = ingredient_quantities.detach().numpy()  # * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
-     ing_presence = (ing_q > 0)
-     x_hat = x_hat.detach().numpy()
-     # x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
-     abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
-     # abs_diff = np.abs(ing_q - x_hat)
-     ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
-     for i in range(ingredient_quantities.shape[0]):
-         ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])]))
-         ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])]))
-     aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
-     aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
-     return aux_other_metrics
-
- def run_epoch(opt, train, model, data, loss_functions, weights, params):
-     if train:
-         model.train()
-     else:
-         model.eval()
-
-     # prepare logging of losses
-     losses = dict(kld_loss=[],
-                   mse_loss=[],
-                   vae_loss=[],
-                   volume_loss=[],
-                   global_loss=[])
-     accuracies = dict()
-     other_metrics = dict()
-     for aux in params['auxiliaries_dict'].keys():
-         losses[aux] = []
-         accuracies[aux] = []
-     if train: opt.zero_grad()
-
-     for d in data:
-         nb_ingredients = d[0]
-         batch_size = nb_ingredients.shape[0]
-         x_ingredients = d[1].float()
-         ingredient_quantities = d[2]
-         cocktail_reps = d[3]
-         auxiliaries = d[4]
-         for k in auxiliaries.keys():
-             if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float()
-         taste_valid = d[-1]
-         z, outputs, auxiliaries_str = model.forward(ingredient_quantities.float())
-         # get auxiliary losses and accuracies
-         aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data)
-
-         # compute reconstruction metrics on the predicted ingredient quantities (no VAE loss in this variant)
-         aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, outputs[auxiliaries_str.index('ingredients_quantities')])
-
-         indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
-         if indexes_taste_valid.size > 0:
-             outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
-             gt = auxiliaries['taste_reps'][indexes_taste_valid]
-             factor_loss = indexes_taste_valid.size / (0.3 * batch_size)  # loss factor: equals 1 when the batch matches the dataset's ratio of taste-labeled samples (~30%); it decreases with fewer labeled samples and increases with more
-             aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
-         else:
-             aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
-             aux_accuracies['taste_reps'] = 0
-
-         # aggregate losses
-         global_loss = torch.sum(torch.cat([torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()]))
-         # for k in params['auxiliaries_dict'].keys():
-         #     global_loss += aux_losses[k] * weights[k]
-
-         if train:
-             global_loss.backward()
-             opt.step()
-             opt.zero_grad()
-
-         # logging
-         losses['global_loss'].append(float(global_loss))
-         for k in params['auxiliaries_dict'].keys():
-             losses[k].append(float(aux_losses[k]))
-             accuracies[k].append(float(aux_accuracies[k]))
-         for k in aux_other_metrics.keys():
-             if k not in other_metrics.keys():
-                 other_metrics[k] = [aux_other_metrics[k]]
-             else:
-                 other_metrics[k].append(aux_other_metrics[k])
-
-     for k in losses.keys():
-         losses[k] = np.mean(losses[k])
-     for k in accuracies.keys():
-         accuracies[k] = np.mean(accuracies[k])
-     for k in other_metrics.keys():
-         other_metrics[k] = np.mean(other_metrics[k], axis=0)
-     return model, losses, accuracies, other_metrics
-
- def prepare_data_and_loss(params):
-     train_data = MyDataset(split='train', params=params)
-     test_data = MyDataset(split='test', params=params)
-
-     train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
-     test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)
-
-     loss_functions = dict()
-     weights = dict()
-     for k in sorted(params['auxiliaries_dict'].keys()):
-         if params['auxiliaries_dict'][k]['type'] == 'classif':
-             if k == 'glasses':
-                 classif_weights = train_data.glasses_weights
-             elif k == 'prep_type':
-                 classif_weights = train_data.prep_types_weights
-             elif k == 'categories':
-                 classif_weights = train_data.categories_weights
-             else:
-                 raise ValueError
-             loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
-         elif params['auxiliaries_dict'][k]['type'] == 'multiclassif':
-             loss_functions[k] = nn.BCEWithLogitsLoss()
-         elif params['auxiliaries_dict'][k]['type'] == 'regression':
-             loss_functions[k] = nn.MSELoss()
-         else:
-             raise ValueError
-         weights[k] = params['auxiliaries_dict'][k]['weight']
-
-     return loss_functions, train_data_loader, test_data_loader, weights
-
- def print_losses(train, losses, accuracies, other_metrics):
-     keyword = 'Train' if train else 'Eval'
-     print(f'\t{keyword} logs:')
-     keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
-     for k in keys:
-         print(f'\t\t{k} - Loss: {losses[k]:.2f}')
-     for k in sorted(accuracies.keys()):
-         print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
-     for k in sorted(other_metrics.keys()):
-         if 'confusion' not in k:
-             print(f'\t\t{k} - {other_metrics[k]:.2f}')
-
-
- def run_experiment(params, verbose=True):
-     loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)
-
-     model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
-     model = get_multihead_model(*model_params)
-     opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])
-
-     all_train_losses = []
-     all_eval_losses = []
-     all_train_accuracies = []
-     all_eval_accuracies = []
-     all_eval_other_metrics = []
-     all_train_other_metrics = []
-     best_loss = np.inf
-     model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
-                                                                         weights=weights, params=params)
-     all_eval_losses.append(eval_losses)
-     all_eval_accuracies.append(eval_accuracies)
-     all_eval_other_metrics.append(eval_other_metrics)
-     if verbose: print(f'\n--------\nEpoch #0')
-     if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
-     for epoch in range(params['nb_epochs']):
-         if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}')
-         model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions,
-                                                                                weights=weights, params=params)
-         if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
-         model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions,
297
- weights=weights, params=params)
298
- if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)
299
- if eval_losses['global_loss'] < best_loss:
300
- best_loss = eval_losses['global_loss']
301
- if verbose: print(f'Saving new best model with loss {best_loss:.2f}')
302
- torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save')
303
-
304
- # log
305
- all_train_losses.append(train_losses)
306
- all_train_accuracies.append(train_accuracies)
307
- all_eval_losses.append(eval_losses)
308
- all_eval_accuracies.append(eval_accuracies)
309
- all_eval_other_metrics.append(eval_other_metrics)
310
- all_train_other_metrics.append(train_other_metrics)
311
-
312
- # if epoch == params['nb_epoch_switch_beta']:
313
- # params['beta_vae'] = 2.5
314
- # params['auxiliaries_dict']['prep_type']['weight'] /= 10
315
- # params['auxiliaries_dict']['glasses']['weight'] /= 10
316
-
317
- if (epoch + 1) % params['plot_every'] == 0:
318
-
319
- plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
320
- all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights)
321
-
322
- return model
323
-
324
- def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
325
- all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
326
-
327
- steps = np.arange(len(all_eval_accuracies))
328
-
329
- loss_keys = sorted(all_train_losses[0].keys())
330
- acc_keys = sorted(all_train_accuracies[0].keys())
331
- metrics_keys = sorted(all_train_other_metrics[0].keys())
332
-
333
- plt.figure()
334
- plt.title('Train losses')
335
- for k in loss_keys:
336
- factor = 1 if k == 'mse_loss' else 1
337
- if k not in weights.keys():
338
- plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
339
- else:
340
- if weights[k] != 0:
341
- plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k)
342
-
343
- plt.legend()
344
- plt.ylim([0, 4])
345
- plt.savefig(plot_path + 'train_losses.png', dpi=200)
346
- fig = plt.gcf()
347
- plt.close(fig)
348
-
349
- plt.figure()
350
- plt.title('Train accuracies')
351
- for k in acc_keys:
352
- if weights[k] != 0:
353
- plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k)
354
- plt.legend()
355
- plt.ylim([0, 1])
356
- plt.savefig(plot_path + 'train_acc.png', dpi=200)
357
- fig = plt.gcf()
358
- plt.close(fig)
359
-
360
- plt.figure()
361
- plt.title('Train other metrics')
362
- for k in metrics_keys:
363
- if 'confusion' not in k and 'presence' in k:
364
- plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
365
- plt.legend()
366
- plt.ylim([0, 1])
367
- plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200)
368
- fig = plt.gcf()
369
- plt.close(fig)
370
-
371
- plt.figure()
372
- plt.title('Train other metrics')
373
- for k in metrics_keys:
374
- if 'confusion' not in k and 'presence' not in k:
375
- plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k)
376
- plt.legend()
377
- plt.ylim([0, 15])
378
- plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200)
379
- fig = plt.gcf()
380
- plt.close(fig)
381
-
382
- plt.figure()
383
- plt.title('Eval losses')
384
- for k in loss_keys:
385
- factor = 1 if k == 'mse_loss' else 1
386
- if k not in weights.keys():
387
- plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
388
- else:
389
- if weights[k] != 0:
390
- plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k)
391
- plt.legend()
392
- plt.ylim([0, 4])
393
- plt.savefig(plot_path + 'eval_losses.png', dpi=200)
394
- fig = plt.gcf()
395
- plt.close(fig)
396
-
397
- plt.figure()
398
- plt.title('Eval accuracies')
399
- for k in acc_keys:
400
- if weights[k] != 0:
401
- plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k)
402
- plt.legend()
403
- plt.ylim([0, 1])
404
- plt.savefig(plot_path + 'eval_acc.png', dpi=200)
405
- fig = plt.gcf()
406
- plt.close(fig)
407
-
408
- plt.figure()
409
- plt.title('Eval other metrics')
410
- for k in metrics_keys:
411
- if 'confusion' not in k and 'presence' in k:
412
- plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
413
- plt.legend()
414
- plt.ylim([0, 1])
415
- plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200)
416
- fig = plt.gcf()
417
- plt.close(fig)
418
-
419
- plt.figure()
420
- plt.title('Eval other metrics')
421
- for k in metrics_keys:
422
- if 'confusion' not in k and 'presence' not in k:
423
- plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k)
424
- plt.legend()
425
- plt.ylim([0, 15])
426
- plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200)
427
- fig = plt.gcf()
428
- plt.close(fig)
429
-
430
-
431
- for k in metrics_keys:
432
- if 'confusion' in k:
433
- plt.figure()
434
- plt.title(k)
435
- plt.ylabel('True')
436
- plt.xlabel('Predicted')
437
- plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1)
438
- plt.colorbar()
439
- plt.savefig(plot_path + f'eval_{k}.png', dpi=200)
440
- fig = plt.gcf()
441
- plt.close(fig)
442
-
443
- for k in metrics_keys:
444
- if 'confusion' in k:
445
- plt.figure()
446
- plt.title(k)
447
- plt.ylabel('True')
448
- plt.xlabel('Predicted')
449
- plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1)
450
- plt.colorbar()
451
- plt.savefig(plot_path + f'train_{k}.png', dpi=200)
452
- fig = plt.gcf()
453
- plt.close(fig)
454
-
455
- plt.close('all')
456
-
457
-
458
- def get_model(model_path):
459
-
460
- with open(model_path + 'params.json', 'r') as f:
461
- params = json.load(f)
462
- params['save_path'] = model_path
463
- model_chkpt = model_path + "checkpoint_best.save"
464
- model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
465
- model = get_multihead_model(*model_params)
466
- model.load_state_dict(torch.load(model_chkpt))
467
- model.eval()
468
- max_ing_quantities = np.loadtxt(model_path + 'max_ing_quantities.txt')
469
- def predict(ing_qs, aux_str):
470
- ing_qs /= max_ing_quantities
471
- input_model = torch.FloatTensor(ing_qs).reshape(1, -1)
472
- _, outputs, auxiliaries_str = model.forward(input_model, )
473
- if isinstance(aux_str, str):
474
- return outputs[auxiliaries_str.index(aux_str)].detach().numpy()
475
- elif isinstance(aux_str, list):
476
- return [outputs[auxiliaries_str.index(aux)].detach().numpy() for aux in aux_str]
477
- else:
478
- raise ValueError
479
- return predict, params
480
-
481
-
482
- def compute_expe_name_and_save_path(params):
483
- weights_str = '['
484
- for aux in params['auxiliaries_dict'].keys():
485
- weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, '
486
- weights_str = weights_str[:-2] + ']'
487
- save_path = params['save_path'] + params["trial_id"]
488
- save_path += f'_lr{params["lr"]}'
489
- save_path += f'_betavae{params["beta_vae"]}'
490
- save_path += f'_bs{params["batch_size"]}'
491
- save_path += f'_latentdim{params["latent_dim"]}'
492
- save_path += f'_hding{params["hidden_dims_ingredients"]}'
493
- save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
494
- save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
495
- save_path += f'_agg{params["agg"]}'
496
- save_path += f'_activ{params["activation"]}'
497
- save_path += f'_w{weights_str}'
498
- counter = 0
499
- while os.path.exists(save_path + f"_{counter}"):
500
- counter += 1
501
- save_path = save_path + f"_{counter}" + '/'
502
- params["save_path"] = save_path
503
- os.makedirs(save_path)
504
- os.makedirs(save_path + 'plots/')
505
- params['plot_path'] = save_path + 'plots/'
506
- print(f'logging to {save_path}')
507
- return params
508
-
509
-
510
-
511
- if __name__ == '__main__':
512
- params = get_params()
513
- run_experiment(params)
514
-
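Note on the deleted script above: `get_model` rebuilds the multihead predictor from `params.json` plus the best checkpoint and returns a `predict` closure that renormalizes raw ingredient quantities before the forward pass. A minimal usage sketch, assuming a trained checkpoint directory (the path is hypothetical; `input_dim` comes from the saved params):

    import numpy as np

    predict, params = get_model('checkpoints/multi_predictor/')  # hypothetical checkpoint dir
    ing_qs = np.zeros(params['input_dim'])  # one quantity slot (in ml) per ingredient
    ing_qs[0] = 45.0
    glass_logits = predict(ing_qs, aux_str='glasses')  # logits over glass classes; argmax picks the glass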
src/cocktails/representation_learning/simple_model.py DELETED
@@ -1,54 +0,0 @@
- import torch; torch.manual_seed(0)
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils
- import torch.distributions
- import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def get_activation(activation):
-     if activation == 'tanh':
-         activ = F.tanh
-     elif activation == 'relu':
-         activ = F.relu
-     elif activation == 'mish':
-         activ = F.mish
-     elif activation == 'sigmoid':
-         activ = torch.sigmoid
-     elif activation == 'leakyrelu':
-         activ = F.leaky_relu
-     elif activation == 'exp':
-         activ = torch.exp
-     else:
-         raise ValueError
-     return activ
-
-
- class SimpleNet(nn.Module):
-     def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None):
-         super(SimpleNet, self).__init__()
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         self.output_dim = output_dim
-         dims = [input_dim] + hidden_dims + [output_dim]
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-         if final_activ != None:
-             self.final_activ = get_activation(final_activ)
-             self.use_final_activ = True
-         else:
-             self.use_final_activ = False
-
-     def forward(self, x):
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = layer(x)
-             if i_layer != self.n_layers - 1:
-                 x = self.activation(dropout(x))
-         if self.use_final_activ: x = self.final_activ(x)
-         return x
-
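Note on `SimpleNet` above: dropout and the hidden activation are applied to every layer except the last, and `final_activ` optionally squashes the output. A minimal sketch (dimensions and settings are illustrative, not from the original configs):

    import torch

    net = SimpleNet(input_dim=16, hidden_dims=[64, 64], output_dim=4,
                    activation='relu', dropout=0.1, final_activ='sigmoid')
    y = net(torch.randn(8, 16))  # -> shape (8, 4), in (0, 1) thanks to the final sigmoid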
src/cocktails/representation_learning/vae_model.py DELETED
@@ -1,238 +0,0 @@
- import torch; torch.manual_seed(0)
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils
- import torch.distributions
- import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def get_activation(activation):
-     if activation == 'tanh':
-         activ = F.tanh
-     elif activation == 'relu':
-         activ = F.relu
-     elif activation == 'mish':
-         activ = F.mish
-     elif activation == 'sigmoid':
-         activ = F.sigmoid
-     elif activation == 'leakyrelu':
-         activ = F.leaky_relu
-     elif activation == 'exp':
-         activ = torch.exp
-     else:
-         raise ValueError
-     return activ
-
- class IngredientEncoder(nn.Module):
-     def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
-         super(IngredientEncoder, self).__init__()
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         dims = [input_dim] + hidden_dims + [deepset_latent_dim]
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-
-     def forward(self, x):
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = layer(x)
-             if i_layer != self.n_layers - 1:
-                 x = self.activation(dropout(x))
-         return x  # do not use dropout on last layer?
-
- class DeepsetCocktailEncoder(nn.Module):
-     def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
-                  hidden_dims_cocktail, latent_dim, aggregation, dropout):
-         super(DeepsetCocktailEncoder, self).__init__()
-         self.input_dim = input_dim  # dimension of ingredient representation + quantity
-         self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout)  # encode each ingredient separately
-         self.deepset_latent_dim = deepset_latent_dim  # dimension of the deepset aggregation
-         self.aggregation = aggregation
-         self.latent_dim = latent_dim
-         # post aggregation network
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         dims = [deepset_latent_dim] + hidden_dims_cocktail
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
-         self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim)
-         self.softplus = nn.Softplus()
-
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-
-     def forward(self, nb_ingredients, x):
-
-         # reshape x in (batch size * nb ingredients, dim_ing_rep)
-         batch_size = x.shape[0]
-         all_ingredients = []
-         for i in range(batch_size):
-             for j in range(nb_ingredients[i]):
-                 all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1))
-         x = torch.cat(all_ingredients, dim=0)
-         # encode ingredients in parallel
-         ingredients_encodings = self.ingredient_encoder(x)
-         assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim)
-
-         # aggregate
-         x = []
-         index_first = 0
-         for i in range(batch_size):
-             index_last = index_first + nb_ingredients[i]
-             # aggregate
-             if self.aggregation == 'sum':
-                 x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
-             elif self.aggregation == 'mean':
-                 x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1))
-             else:
-                 raise ValueError
-             index_first = index_last
-         x = torch.cat(x, dim=0)
-         assert x.shape[0] == batch_size
-
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = self.activation(dropout(layer(x)))
-         mean = self.FC_mean(x)
-         logvar = self.FC_logvar(x)
-         return mean, logvar
-
- class Decoder(nn.Module):
-     def __init__(self, latent_dim, hidden_dims, num_ingredients, activation, dropout, filter_output=None):
-         super(Decoder, self).__init__()
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         dims = [latent_dim] + hidden_dims + [num_ingredients]
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-         self.filter = filter_output
-
-     def forward(self, x, to_filter=False):
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = layer(x)
-             if i_layer != self.n_layers - 1:
-                 x = self.activation(dropout(x))
-         if to_filter:
-             x = self.filter(x)
-         return x
-
- class PredictorHead(nn.Module):
-     def __init__(self, latent_dim, dim_output, final_activ):
-         super(PredictorHead, self).__init__()
-         self.linear = nn.Linear(latent_dim, dim_output)
-         if final_activ != None:
-             self.final_activ = get_activation(final_activ)
-             self.use_final_activ = True
-         else:
-             self.use_final_activ = False
-
-     def forward(self, x):
-         x = self.linear(x)
-         if self.use_final_activ: x = self.final_activ(x)
-         return x
-
-
- class VAEModel(nn.Module):
-     def __init__(self, encoder, decoder, auxiliaries_dict):
-         super(VAEModel, self).__init__()
-         self.encoder = encoder
-         self.decoder = decoder
-         self.latent_dim = self.encoder.latent_dim
-         self.auxiliaries_str = []
-         self.auxiliaries = nn.ModuleList()
-         for aux_str in sorted(auxiliaries_dict.keys()):
-             if aux_str == 'taste_reps':
-                 self.taste_reps_decoder = PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ'])
-             else:
-                 self.auxiliaries_str.append(aux_str)
-                 self.auxiliaries.append(PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ']))
-
-     def reparameterization(self, mean, logvar):
-         std = torch.exp(0.5 * logvar)
-         epsilon = torch.randn_like(std).to(device)  # sampling epsilon
-         z = mean + std * epsilon  # reparameterization trick
-         return z
-
-
-     def sample(self, n=1):
-         z = torch.randn(size=(n, self.latent_dim))
-         return self.decoder(z)
-
-     def get_all_auxiliaries(self, x):
-         return [aux(x) for aux in self.auxiliaries]
-
-     def get_auxiliary(self, z, aux_str):
-         if aux_str == 'taste_reps':
-             return self.taste_reps_decoder(z)
-         else:
-             index = self.auxiliaries_str.index(aux_str)
-             return self.auxiliaries[index](z)
-
-     def forward_direct(self, x, aux_str=None, to_filter=False):
-         mean, logvar = self.encoder(x)
-         z = self.reparameterization(mean, logvar)  # takes exponential function (log var -> std)
-         x_hat = self.decoder(mean, to_filter=to_filter)
-         if aux_str is not None:
-             return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str]
-         else:
-             return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str
-
-     def forward(self, nb_ingredients, x, aux_str=None, to_filter=False):
-         assert False
-         mean, std = self.encoder(nb_ingredients, x)
-         z = self.reparameterization(mean, std)  # takes exponential function (log var -> std)
-         x_hat = self.decoder(mean, to_filter=to_filter)
-         if aux_str is not None:
-             return x_hat, z, mean, std, self.get_auxiliary(z, aux_str), [aux_str]
-         else:
-             return x_hat, z, mean, std, self.get_all_auxiliaries(z), self.auxiliaries_str
-
-
- class SimpleEncoder(nn.Module):
-
-     def __init__(self, input_dim, hidden_dims, latent_dim, activation, dropout):
-         super(SimpleEncoder, self).__init__()
-         self.latent_dim = latent_dim
-         # post aggregation network
-         self.linears = nn.ModuleList()
-         self.dropouts = nn.ModuleList()
-         dims = [input_dim] + hidden_dims
-         for d_in, d_out in zip(dims[:-1], dims[1:]):
-             self.linears.append(nn.Linear(d_in, d_out))
-             self.dropouts.append(nn.Dropout(dropout))
-         self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim)
-         self.FC_logvar = nn.Linear(hidden_dims[-1], latent_dim)
-         # self.softplus = nn.Softplus()
-
-         self.activation = get_activation(activation)
-         self.n_layers = len(self.linears)
-         self.layer_range = range(self.n_layers)
-
-     def forward(self, x):
-         for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts):
-             x = self.activation(dropout(layer(x)))
-         mean = self.FC_mean(x)
-         logvar = self.FC_logvar(x)
-         return mean, logvar
-
- def get_vae_model(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
-                   hidden_dims_cocktail, hidden_dims_decoder, num_ingredients, latent_dim, aggregation, dropout, auxiliaries_dict,
-                   filter_decoder_output):
-     # encoder = DeepsetCocktailEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation,
-     #                                  hidden_dims_cocktail, latent_dim, aggregation, dropout)
-     encoder = SimpleEncoder(num_ingredients, hidden_dims_cocktail, latent_dim, activation, dropout)
-     decoder = Decoder(latent_dim, hidden_dims_decoder, num_ingredients, activation, dropout, filter_output=filter_decoder_output)
-     vae = VAEModel(encoder, decoder, auxiliaries_dict)
-     return vae
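Note on `VAEModel` above: `reparameterization` is the standard Gaussian reparameterization trick, and the `kld_loss`/`mse_loss` terms logged by the training scripts match the usual beta-VAE objective. The loss itself is not part of this file; a textbook sketch of what those terms compute (`beta` would be the `beta_vae` hyperparameter):

    import torch
    import torch.nn.functional as F

    def vae_loss(x, x_hat, mean, logvar, beta):
        mse = F.mse_loss(x_hat, x, reduction='mean')  # reconstruction term
        # KL divergence between N(mean, exp(logvar)) and the N(0, I) prior
        kld = -0.5 * torch.mean(torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1))
        return mse + beta * kld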
src/cocktails/utilities/__init__.py DELETED
File without changes
src/cocktails/utilities/analysis_utilities.py DELETED
@@ -1,189 +0,0 @@
- import numpy as np
- import matplotlib.pyplot as plt
-
- from src.cocktails.utilities.ingredients_utilities import ingredient_list, extract_ingredients, ingredients_per_type
-
- color_codes = dict(ancestral='#000000',
-                    spirit_forward='#2320D2',
-                    duo='#6E20D2',
-                    champagne_cocktail='#25FFCA',
-                    complex_highball='#068F25',
-                    simple_highball='#25FF57',
-                    collins='#77FF96',
-                    julep='#25B8FF',
-                    simple_sour='#FBD756',
-                    complex_sour='#DCAD07',
-                    simple_sour_with_juice='#FF5033',
-                    complex_sour_with_juice='#D42306',
-                    # simple_sour_with_egg='#FF9C54',
-                    # complex_sour_with_egg='#CF5700',
-                    # almost_simple_sor='#FF5033',
-                    # almost_sor='#D42306',
-                    # almost_sor_with_egg='#D42306',
-                    other='#9B9B9B'
-                    )
-
- def get_subcategories(data):
-     subcategories = np.array(data['subcategory'])
-     sub_categories_list = sorted(set(subcategories))
-     subcat_count = dict(zip(sub_categories_list, [0] * len(sub_categories_list)))
-     for sc in data['subcategory']:
-         subcat_count[sc] += 1
-     return subcategories, sub_categories_list, subcat_count
-
- def get_ingredient_count(data):
-     ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list)))
-     for ing_str in data['ingredients_str']:
-         ingredients, _ = extract_ingredients(ing_str)
-         for ing in ingredients:
-             ingredient_counts[ing] += 1
-     return ingredient_counts
-
- def compute_eucl_dist(a, b):
-     return np.sqrt(np.sum((a - b)**2))
-
- def recipe_contains(ingredients, stuff):
-     if stuff in ingredient_list:
-         return stuff in ingredients
-     elif stuff == 'juice':
-         return any(['juice' in ing and 'lemon' not in ing and 'lime' not in ing for ing in ingredients])
-     elif stuff == 'bubbles':
-         return any([ing in ['soda', 'tonic', 'cola', 'sparkling wine', 'ginger beer'] for ing in ingredients])
-     elif stuff == 'acid':
-         return any([ing in ['lemon juice', 'lime juice'] for ing in ingredients])
-     elif stuff == 'vermouth':
-         return any([ing in ingredients_per_type['vermouth'] for ing in ingredients])
-     elif stuff == 'plain sweet':
-         plain_sweet = ingredients_per_type['sweeteners']
-         return any([ing in plain_sweet for ing in ingredients])
-     elif stuff == 'sweet':
-         sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc']
-         return any([ing in sweet for ing in ingredients])
-     elif stuff == 'spirit':
-         return any([ing in ingredients_per_type['liquor'] for ing in ingredients])
-     else:
-         raise ValueError
-
-
- def radar_factory(num_vars, frame='circle'):
-     # from stackoverflow's post? Or matplotlib's blog
-     """
-     Create a radar chart with `num_vars` axes.
-
-     This function creates a RadarAxes projection and registers it.
-
-     Parameters
-     ----------
-     num_vars : int
-         Number of variables for radar chart.
-     frame : {'circle', 'polygon'}
-         Shape of frame surrounding axes.
-
-     """
-     import numpy as np
-
-     from matplotlib.patches import Circle, RegularPolygon
-     from matplotlib.path import Path
-     from matplotlib.projections.polar import PolarAxes
-     from matplotlib.projections import register_projection
-     from matplotlib.spines import Spine
-     from matplotlib.transforms import Affine2D
-     # calculate evenly-spaced axis angles
-     theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
-
-     class RadarAxes(PolarAxes):
-
-         name = 'radar'
-         # use 1 line segment to connect specified points
-         RESOLUTION = 1
-
-         def __init__(self, *args, **kwargs):
-             super().__init__(*args, **kwargs)
-             # rotate plot such that the first axis is at the top
-             self.set_theta_zero_location('N')
-
-         def fill(self, *args, closed=True, **kwargs):
-             """Override fill so that line is closed by default"""
-             return super().fill(closed=closed, *args, **kwargs)
-
-         def plot(self, *args, **kwargs):
-             """Override plot so that line is closed by default"""
-             lines = super().plot(*args, **kwargs)
-             for line in lines:
-                 self._close_line(line)
-
-         def _close_line(self, line):
-             x, y = line.get_data()
-             # FIXME: markers at x[0], y[0] get doubled-up
-             if x[0] != x[-1]:
-                 x = np.append(x, x[0])
-                 y = np.append(y, y[0])
-                 line.set_data(x, y)
-
-         def set_varlabels(self, labels):
-             self.set_thetagrids(np.degrees(theta), labels)
-
-         def _gen_axes_patch(self):
-             # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
-             # in axes coordinates.
-             if frame == 'circle':
-                 return Circle((0.5, 0.5), 0.5)
-             elif frame == 'polygon':
-                 return RegularPolygon((0.5, 0.5), num_vars,
-                                       radius=.5, edgecolor="k")
-             else:
-                 raise ValueError("Unknown value for 'frame': %s" % frame)
-
-         def _gen_axes_spines(self):
-             if frame == 'circle':
-                 return super()._gen_axes_spines()
-             elif frame == 'polygon':
-                 # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
-                 spine = Spine(axes=self,
-                               spine_type='circle',
-                               path=Path.unit_regular_polygon(num_vars))
-                 # unit_regular_polygon gives a polygon of radius 1 centered at
-                 # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
-                 # 0.5) in axes coordinates.
-                 spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
-                                     + self.transAxes)
-                 return {'polar': spine}
-             else:
-                 raise ValueError("Unknown value for 'frame': %s" % frame)
-
-     register_projection(RadarAxes)
-     return theta
-
- def plot_radar_cocktail(representation, labels_dim, labels_cocktails, save_path=None, to_show=False, to_save=False):
-     assert to_show or to_save, 'either show or save'
-     assert representation.ndim == 2
-     n_data, dim_rep = representation.shape
-     assert len(labels_cocktails) == n_data
-     assert len(labels_dim) == dim_rep
-     assert n_data <= 5, 'max 5 representation_analysis please'
-
-     theta = radar_factory(dim_rep, frame='circle')
-
-     fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(projection='radar'))
-     fig.subplots_adjust(wspace=0.25, hspace=0.20, top=0.85, bottom=0.05)
-
-     colors = ['b', 'r', 'g', 'm', 'y']
-     # Plot the four cases from the example data on separate axes
-     ax.set_rgrids([0.2, 0.4, 0.6, 0.8])
-     for d, color in zip(representation, colors):
-         ax.plot(theta, d, color=color)
-     for d, color in zip(representation, colors):
-         ax.fill(theta, d, facecolor=color, alpha=0.25)
-     ax.set_varlabels(labels_dim)
-
-     # add legend relative to top-left plot
-     legend = ax.legend(labels_cocktails, loc=(0.9, .95),
-                        labelspacing=0.1, fontsize='small')
-
-     if to_save:
-         plt.savefig(save_path, bbox_artists=(legend,), dpi=200)
-     else:
-         plt.show()
-
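Note on `plot_radar_cocktail` above: it expects a 2D array with one row per cocktail (at most 5) and one column per representation dimension, with values roughly in [0, 1] given the fixed radial grid. A usage sketch (labels and values are illustrative):

    import numpy as np

    reps = np.random.uniform(0, 1, size=(2, 8))
    plot_radar_cocktail(reps,
                        labels_dim=['sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb', 'fizzy', 'eggy'],
                        labels_cocktails=['cocktail A', 'cocktail B'],
                        to_show=True)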
src/cocktails/utilities/cocktail_category_detection_utilities.py DELETED
@@ -1,221 +0,0 @@
- # The following functions check whether a cocktail belongs to any of N categories
- import numpy as np
- from src.cocktails.utilities.ingredients_utilities import ingredient_profiles, ingredients_per_type, ingredient2ingredient_id, extract_ingredients
-
-
- def is_ancestral(n, ingredient_indexes, ingredients, quantities):
-     # ancestrals have a strong spirit and some sweetness from sugar, syrup or liqueurs, no citrus.
-     # absinthe can be added up to 3 dashes.
-     # Liqueurs are there to bring sweetness, thus must stay below 15 ml (if not, it's a duo)
-     if n['spirit'] > 0 and n['citrus'] == 0 and n['plain_sweet'] + n['liqueur'] <= 2:
-         if n['spirit'] > 1 and 'absinthe' in ingredients:
-             if quantities[ingredients.index('absinthe')] < 3:
-                 pass
-             else:
-                 return False
-         if n['sugar'] < 2 and n['liqueur'] < 3:
-             if n['all'] - n['spirit'] - n['sugar'] - n['syrup'] - n['liqueur'] - n['inconsequentials'] == 0:
-                 if n['liqueur'] == 0:
-                     return True
-                 else:
-                     q_liqueur = np.sum([quantities[i_ing]
-                                         for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients)))
-                                         if ingredient_profiles['type'][i_ind].lower() == 'liqueur'])
-                     if q_liqueur <= 15:
-                         return True
-                     else:
-                         return False
-     return False
-
-
- def is_simple_sour(n, ingredient_indexes, ingredients, quantities):
-     # simple sours contain a citrus, at least 1 spirit and non-alcoholic sweetness
-     if n['citrus'] + n['coffee'] > 0 and n['spirit'] > 0 and n['plain_sweet'] > 0 and n['juice'] == 0:
-         if n['all'] - n['citrus'] - n['coffee'] - n['spirit'] - n['plain_sweet'] - n['juice'] - n['egg'] - n['inconsequentials'] == 0:
-             return True
-     return False
-
- def is_complex_sour(n, ingredient_indexes, ingredients, quantities):
-     # complex sours are simple sours that use alcoholic sweetness, at least in part
-     if n['citrus'] + n['coffee'] > 0 and n['all_sweet'] > 0 and n['juice'] == 0:
-         if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0:
-             if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0:
-                 if n['all'] - n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \
-                         - n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0:
-                     return True
-     return False
-
- def is_spirit_forward(n, ingredient_indexes, ingredients, quantities):
-     # spirit forwards contain at least a spirit and vermouth, no citrus. Can contain sweet (sugar, syrups, liqueurs)
-     if n['spirit'] > 0 and n['citrus'] == 0 and n['vermouth'] > 0:
-         if n['all'] - n['spirit'] - n['sugar'] - n['syrup'] - n['liqueur'] - n['egg'] - n['vermouth'] - n['inconsequentials'] == 0:
-             return True
-     return False
-
- def is_duo(n, ingredient_indexes, ingredients, quantities):
-     # duos are made of one spirit and one liqueur (above 15 ml; below that it's an ancestral), no citrus.
-     if n['spirit'] >= 1 and n['citrus'] == 0 and n['sugar'] == 0 and n['liqueur'] > 0 and n['vermouth'] == 0:
-         if n['all'] - n['spirit'] - n['sugar'] - n['liqueur'] - n['vermouth'] - n['inconsequentials'] == 0:
-             q_liqueur = np.sum([quantities[i_ing]
-                                 for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients)))
-                                 if ingredient_profiles['type'][i_ind].lower() == 'liqueur'])
-             if q_liqueur > 15:
-                 return True
-             else:
-                 return False
-     return False
-
- def is_champagne_cocktail(n, ingredient_indexes, ingredients, quantities):
-     if n['sparkling'] > 0:
-         return True
-     else:
-         return False
-
- def is_simple_highball(n, ingredient_indexes, ingredients, quantities):
-     # simple highballs have one alcoholic ingredient and bubbles
-     if n['alcoholic'] == 1 and n['bubbles'] > 0:
-         if n['all'] - n['alcoholic'] - n['bubbles'] - n['inconsequentials'] == 0:
-             return True
-     return False
-
- def is_complex_highball(n, ingredient_indexes, ingredients, quantities):
-     # complex highballs have at least one alcoholic ingredient and bubbles (possibly alcoholic). They also contain extra sugar under any form, and juice
-     if n['alcoholic'] > 0 and (n['bubbles'] + n['sparkling']) == 1 and n['juice'] + n['all_sweet'] + n['sugar_bubbles'] > 0:
-         if n['all'] - n['spirit'] - n['bubbles'] - n['sparkling'] - n['citrus'] - n['juice'] - n['liqueur'] \
-                 - n['syrup'] - n['sugar'] - n['vermouth'] - n['egg'] - n['inconsequentials'] == 0:
-             if not is_collins(n, ingredient_indexes, ingredients, quantities) and not is_simple_highball(n, ingredient_indexes, ingredients, quantities):
-                 return True
-     return False
-
- def is_collins(n, ingredient_indexes, ingredients, quantities):
-     # collins are a particular kind of highball with sugar and citrus
-     if n['alcoholic'] == 1 and n['bubbles'] == 1 and n['citrus'] > 0 and n['plain_sweet'] + n['sugar_bubbles'] > 0:
-         if n['all'] - n['spirit'] - n['bubbles'] - n['citrus'] - n['sugar'] - n['inconsequentials'] == 0:
-             return True
-     return False
-
- def is_julep(n, ingredient_indexes, ingredients, quantities):
-     # juleps involve smashed mint, sugar and a spirit, no citrus.
-     if 'mint' in ingredients and n['sugar'] > 0 and n['spirit'] > 0 and n['vermouth'] == 0 and n['citrus'] == 0:
-         return True
-     return False
-
- def is_simple_sour_with_juice(n, ingredient_indexes, ingredients, quantities):
-     # simple sours with juice are simple sours that also contain juice
-     if n['juice'] > 0 and n['spirit'] > 0 and n['plain_sweet'] > 0:
-         if n['all'] - n['citrus'] - n['coffee'] - n['juice'] - n['spirit'] - n['sugar'] - n['syrup'] - n['egg'] - n['inconsequentials'] == 0:
-             return True
-     return False
-
-
- def is_complex_sour_with_juice(n, ingredient_indexes, ingredients, quantities):
-     # complex sours with juice are complex sours that also contain juice
-     if n['juice'] > 0 and n['all_sweet'] > 0:
-         if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0:
-             if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0:
-                 if n['all'] - n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \
-                         - n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0:
-                     return True
-     return False
-
-
- is_sub_category = [is_ancestral, is_complex_sour, is_simple_sour, is_duo, is_champagne_cocktail,
-                    is_spirit_forward, is_simple_highball, is_complex_highball, is_collins,
-                    is_julep, is_simple_sour_with_juice, is_complex_sour_with_juice]
- sub_categories = ['ancestral', 'complex_sour', 'simple_sour', 'duo', 'champagne_cocktail',
-                   'spirit_forward', 'simple_highball', 'complex_highball', 'collins',
-                   'julep', 'simple_sour_with_juice', 'complex_sour_with_juice']
-
-
- # compute the cocktail category as a function of ingredients and quantities; uses the name to check the match between name and category (e.g. XXX Collins should be a collins)
- # Category definitions are based on https://www.seriouseats.com/cocktail-style-guide-categories-of-cocktails-glossary-families-of-drinks
- def find_cocktail_sub_category(ingredients, quantities, name=None):
-     ingredient_indexes = [ingredient2ingredient_id[ing] for ing in ingredients]
-     n_spirit = np.sum([ingredient_profiles['type'][i].lower() == 'liquor' for i in ingredient_indexes])
-     n_citrus = np.sum([ingredient_profiles['type'][i].lower() == 'acid' for i in ingredient_indexes])
-     n_sugar = np.sum([ingredient_profiles['ingredient'][i].lower() in ['double syrup', 'simple syrup', 'honey syrup'] for i in ingredient_indexes])
-     plain_sweet = ingredients_per_type['sweeteners']
-     all_sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc']
-     n_plain_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in plain_sweet for i in ingredient_indexes])
-     n_all_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in all_sweet for i in ingredient_indexes])
-     n_sugar_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['cola', 'ginger beer', 'tonic'] for i in ingredient_indexes])
-     n_juice = np.sum([ingredient_profiles['type'][i].lower() == 'juice' for i in ingredient_indexes])
-     n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
-     alcoholic = ingredients_per_type['liquor'] + ingredients_per_type['liqueur'] + ingredients_per_type['vermouth']
-     n_alcoholic = np.sum([ingredient_profiles['ingredient'][i].lower() in alcoholic for i in ingredient_indexes])
-     n_bitter = np.sum([ingredient_profiles['type'][i].lower() == 'bitters' for i in ingredient_indexes])
-     n_egg = np.sum([ingredient_profiles['ingredient'][i].lower() == 'egg' for i in ingredient_indexes])
-     n_vermouth = np.sum([ingredient_profiles['type'][i].lower() == 'vermouth' for i in ingredient_indexes])
-     n_sparkling = np.sum([ingredient_profiles['ingredient'][i].lower() == 'sparkling wine' for i in ingredient_indexes])
-     n_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['soda', 'tonic', 'cola', 'ginger beer'] for i in ingredient_indexes])
-     n_syrup = np.sum([ingredient_profiles['ingredient'][i].lower() in ['grenadine', 'raspberry syrup'] for i in ingredient_indexes])
-     n_coffee = np.sum([ingredient_profiles['ingredient'][i].lower() == 'espresso' for i in ingredient_indexes])
-     inconsequentials = ['water', 'salt', 'angostura', 'orange bitters', 'mint']
-     n_inconsequentials = np.sum([ingredient_profiles['ingredient'][i].lower() in inconsequentials for i in ingredient_indexes])
-     n = dict(all=len(ingredients),
-              inconsequentials=n_inconsequentials,
-              sugar_bubbles=n_sugar_bubbles,
-              bubbles=n_bubbles,
-              plain_sweet=n_plain_sweet,
-              all_sweet=n_all_sweet,
-              coffee=n_coffee,
-              alcoholic=n_alcoholic,
-              syrup=n_syrup,
-              sparkling=n_sparkling,
-              sugar=n_sugar,
-              spirit=n_spirit,
-              citrus=n_citrus,
-              juice=n_juice,
-              liqueur=n_liqueur,
-              bitter=n_bitter,
-              egg=n_egg,
-              vermouth=n_vermouth)
-
-     sub_cats = [c for c, test_c in zip(sub_categories, is_sub_category) if test_c(n, ingredient_indexes, ingredients, quantities)]
-     if name != None:
-         name = name.lower()
-         keywords_to_test = ['julep', 'collins', 'highball', 'sour', 'champagne']
-         for k in keywords_to_test:
-             if k in name and not any([k in cat for cat in sub_cats]):
-                 print(k)
-                 for ing, q in zip(ingredients, quantities):
-                     print(f'{ing}: {q} ml')
-                 print(n)
-                 break
-     if sorted(sub_cats) == ['champagne_cocktail', 'complex_highball']:
-         sub_cats = ['champagne_cocktail']
-     elif sorted(sub_cats) == ['collins', 'complex_highball']:
-         sub_cats = ['collins']
-     elif sorted(sub_cats) == ['champagne_cocktail', 'complex_highball', 'julep']:
-         sub_cats = ['champagne_cocktail']
-     elif sorted(sub_cats) == ['ancestral', 'julep']:
-         sub_cats = ['julep']
-     elif sorted(sub_cats) == ['complex_highball', 'julep']:
-         sub_cats = ['complex_highball']
-     elif sorted(sub_cats) == ['julep', 'simple_sour_with_juice']:
-         sub_cats = ['simple_sour_with_juice']
-     elif sorted(sub_cats) == ['complex_sour_with_juice', 'julep']:
-         sub_cats = ['complex_sour_with_juice']
-     if len(sub_cats) != 1:
-         # print(sub_cats)
-         # for ing, q in zip(ingredients, quantities):
-         #     print(f'{ing}: {q} ml')
-         # print(n)
-         # if len(sub_cats) == 0:
-         sub_cats = ['other']
-     assert len(sub_cats) == 1, sub_cats
-     return sub_cats[0], n
-
- def get_cocktails_attributes(ing_strs):
-     attributes = dict()
-     cats = []
-     for ing_str in ing_strs:
-         ingredients, quantities = extract_ingredients(ing_str)
-         cat, atts = find_cocktail_sub_category(ingredients, quantities)
-         for k in atts.keys():
-             if k not in attributes.keys():
-                 attributes[k] = [atts[k]]
-             else:
-                 attributes[k].append(atts[k])
-         cats.append(cat)
-     return cats, attributes
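Note on `find_cocktail_sub_category` above: it counts ingredient roles, applies every `is_*` predicate, then resolves multi-matches with the hard-coded priority rules. A usage sketch (ingredient names must match the project's `ingredient_list`, so these are hypothetical; a daiquiri-style recipe like this would be expected to land in `simple_sour` under the rules above):

    ingredients = ['rum', 'lime juice', 'simple syrup']  # hypothetical names
    quantities = [60, 25, 15]  # in ml
    category, counts = find_cocktail_sub_category(ingredients, quantities)
    print(category, counts['spirit'], counts['citrus'])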
src/cocktails/utilities/cocktail_generation_utilities/__init__.py DELETED
File without changes
src/cocktails/utilities/cocktail_generation_utilities/individual.py DELETED
@@ -1,587 +0,0 @@
1
- from src.cocktails.utilities.ingredients_utilities import get_ingredients_info, format_ingredients, extract_ingredients, ingredients_per_type, bubble_ingredients
2
- import numpy as np
3
- from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
4
- from src.cocktails.utilities.cocktail_utilities import get_cocktail_rep, get_profile, get_bunch_of_rep_keys
5
- from src.cocktails.utilities.glass_and_volume_utilities import glass_volume
6
- from src.cocktails.representation_learning.run import get_model
7
- from src.cocktails.pipeline.get_cocktail2affective_cluster import get_cocktail2affective_cluster
8
- from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, REPO_PATH, COCKTAIL_REP_CHKPT_PATH, RECIPE2FEATURES_PATH
9
- from src.cocktails.representation_learning.run_without_vae import get_model
10
- from src.cocktails.utilities.cocktail_category_detection_utilities import find_cocktail_sub_category
11
-
12
- import pandas as pd
13
- import torch
14
- import time
15
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
-
17
- density_ingredients = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'density_ingredients.txt')
18
- max_ingredients, ingredient_list, ind_alcohol = get_ingredients_info()
19
- min_ingredients = 2
20
- factor_max = 1.2 # generated recipes can go up to 1.2 times the max quantity of the ingredient found in the dataset
21
-
22
- prep_model = get_model(RECIPE2FEATURES_PATH + 'multi_predictor/')[0]
23
-
24
- all_rep_path = FULL_COCKTAIL_REP_PATH
25
- all_reps = np.loadtxt(all_rep_path)
26
- experiment_dir = REPO_PATH + '/experiments/cocktails/'
27
- rep_keys = get_bunch_of_rep_keys()['custom']
28
- dict_weights_mse_computation = {'end volume': .1, 'end sour': 2, 'end sweet': 2, 'end booze': 4, 'end bitter': 2, 'end fruit': 1, 'end herb': 1,
29
- 'end complex': 1, 'end spicy': 5, 'end oaky': 1, 'end fizzy': 10, 'end colorful': 1, 'end eggy': 10}
30
- assert sorted(dict_weights_mse_computation.keys()) == sorted(rep_keys)
31
- weights_mse_computation = np.array([dict_weights_mse_computation[k] for k in rep_keys])
32
- weights_mse_computation /= weights_mse_computation.sum()
33
- data = pd.read_csv(COCKTAILS_CSV_DATA)
34
- preparation_list = sorted(set(data['category']))
35
- glasses_list = sorted(set(data['glass']))
36
-
37
- weights_perf_n_ing = {2:0.71, 3:0.81, 4:0.93, 5:1., 6:1.03, 7:1.08, 8:1.05}
38
-
39
- # weights_perf_n_ing = {2:0.75, 3:0.8, 4:0.95, 5:1.05, 6:1.05, 7:1.05, 8:1.05}
40
- min_ingredients_quantities_when_present = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'ingredients_min_quantities_when_present.txt')
41
- min_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'ingredients_min_quantities.txt')
42
- max_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'ingredients_max_quantities.txt')
43
- min_cocktail_rep, max_cocktail_rep = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'cocktail_minmax_dim13_customkeys.txt')
44
- distrib_nb_ings_2_8 = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'distrib_nb_ing.txt')[2:]
45
- def normalize_cocktail(cocktail_rep):
46
- return ((cocktail_rep - min_cocktail_rep) / (max_cocktail_rep - min_cocktail_rep) - 0.5) * 2
47
-
48
- def denormalize_cocktail(cocktail_rep):
49
- return (cocktail_rep / 2 + 0.5) * (max_cocktail_rep - min_cocktail_rep) + min_cocktail_rep
50
-
51
- def normalize_ingredient_q_rep(ingredients_q):
52
- return (ingredients_q - min_ingredients_quantities_when_present) / (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present)
53
-
54
- COCKTAIL_REPS = normalize_cocktail(np.array([data[k] for k in rep_keys]).transpose())
55
- assert np.abs(COCKTAIL_REPS - all_reps).sum() < 1e-8
56
-
57
- cocktail2affective_cluster = get_cocktail2affective_cluster()
58
-
59
- original_affective_keys = get_bunch_of_rep_keys()['affective']
60
- def sigmoid(x, shift, beta):
61
- return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
62
-
63
- def get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(cocktail_rep):
64
- indexes = np.array([rep_keys.index(key) for key in original_affective_keys])
65
- cocktail_rep = cocktail_rep[indexes]
66
- cocktail_rep[0] = sigmoid(cocktail_rep[0], shift=0.05, beta=4)
67
- cocktail_rep[1] = sigmoid(cocktail_rep[1], shift=0.3, beta=5)
68
- cocktail_rep[2] = sigmoid(cocktail_rep[2], shift=0.15, beta=3)
69
- cocktail_rep[3] = sigmoid(cocktail_rep[3], shift=0.9, beta=20)
70
- cocktail_rep[4] = sigmoid(cocktail_rep[4], shift=0, beta=4)
71
- cocktail_rep[5] = sigmoid(cocktail_rep[5], shift=0.2, beta=3)
72
- cocktail_rep[6] = sigmoid(cocktail_rep[6], shift=0.5, beta=5)
73
- cocktail_rep[7] = sigmoid(cocktail_rep[7], shift=0.2, beta=6)
74
- return cocktail_rep
75
-
76
- class IndividualCocktail():
77
- def __init__(self, pop_params, target, target_affective_cluster, genes_presence=None, genes_quantity=None,
78
- compute_perf=True, known_target_dict=None, run_hard_check=False):
79
-
80
- self.pop_params = pop_params
81
- self.n_genes = len(ingredient_list)
82
- self.max_ingredients = max_ingredients
83
- self.min_ingredients = min_ingredients
84
- self.mutation_params = pop_params['mutation_params']
85
- self.dist = pop_params['dist']
86
- self.target = target
87
- self.is_known = known_target_dict is not None
88
- self.known_target_dict = known_target_dict
89
- self.perf = None
90
- self.cocktail_rep = None
91
- self.affective_cluster = None
92
- self.target_affective_cluster = target_affective_cluster
93
- self.ing_list = np.array(ingredient_list)
94
- self.ing_set = set(ingredient_list)
95
-
96
- self.ing_ids_per_cat = dict(bubbles=set(self.get_ingredients_ids_from_list(bubble_ingredients)),
97
- liquor=set(self.get_ingredients_ids_from_list(ingredients_per_type['liquor'])),
98
- liqueur=set(self.get_ingredients_ids_from_list(ingredients_per_type['liqueur'])),
99
- citrus=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'] + ['orange juice'])),
100
- alcohol=set(ind_alcohol),
101
- sweeteners=set(self.get_ingredients_ids_from_list(ingredients_per_type['sweeteners'])),
102
- vermouth=set(self.get_ingredients_ids_from_list(ingredients_per_type['vermouth'])),
103
- bitters=set(self.get_ingredients_ids_from_list(ingredients_per_type['bitters'])),
104
- juice=set(self.get_ingredients_ids_from_list(ingredients_per_type['juice'])),
105
- acid=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'])),
106
- egg=set(self.get_ingredients_ids_from_list(['egg']))
107
- )
108
-
109
- if genes_presence is not None:
110
- assert len(genes_presence) == self.n_genes
111
- assert len(genes_quantity) == self.n_genes
112
- self.genes_presence = genes_presence
113
- self.genes_quantity = genes_quantity
114
- if compute_perf:
115
- self.compute_cocktail_rep()
116
- self.compute_perf()
117
- else:
118
- self.sample_initial_genes()
119
- self.compute_cocktail_rep()
120
- # self.make_recipe_fit_the_glass()
121
- self.compute_perf()
122
-
123
-
124
- # # # # # # # # # # # # # # # # # # # # # # # #
125
- # Sample initial genes with smart rules
126
- # # # # # # # # # # # # # # # # # # # # # # # #
127
-
128
- def sample_initial_genes(self):
129
- # rules:
130
- # - between min_ingredients and max_ingredients
131
- # - at most one type of bubbles
132
- # - at least one alcohol
133
- # - no egg without lime or lemon
134
- # - at most two liqueurs
135
- # - at most three liquors
136
- # - at most two sweetener
137
- self.genes_quantity = np.random.uniform(0, 1, size=self.n_genes) # holds quantities for each ingredient
138
- n_ingredients = np.random.choice(np.arange(min_ingredients, max_ingredients + 1), p=distrib_nb_ings_2_8)
139
- self.genes_presence = np.zeros(self.n_genes)
140
- # add one alchohol
141
- self.genes_presence[np.random.choice(ind_alcohol)] = 1
142
- while self.get_ing_count() < n_ingredients:
143
- candidate_ids = self.get_candidate_ingredients_ids(self.genes_presence)
144
- probas = density_ingredients[candidate_ids] / np.sum(density_ingredients[candidate_ids])
145
- self.genes_presence[np.random.choice(candidate_ids, p=probas)] = 1
146
-
147
- def get_candidate_ingredients_ids(self, genes_presence):
148
- candidates = set(np.argwhere(genes_presence==0).flatten())
149
- present_ids = set(np.argwhere(genes_presence==1).flatten())
150
-
151
- if self.count_in_genes(present_ids, 'bubbles') >= 1: # at most one type of bubbles
152
- candidates = candidates - self.ing_ids_per_cat['bubbles']
153
- if self.count_in_genes(present_ids, 'liquor') >= 3: # at most three liquors
154
- candidates = candidates - self.ing_ids_per_cat['liquor']
155
- if self.count_in_genes(present_ids, 'liqueur') >= 2: # at most two liqueurs
156
- candidates = candidates - self.ing_ids_per_cat['liqueur']
157
- if self.count_in_genes(present_ids, 'sweeteners') >= 2: # at most two sweetener
158
- candidates = candidates - self.ing_ids_per_cat['sweeteners']
159
- if self.count_in_genes(present_ids, 'citrus') == 0: # no egg without lime or lemon
160
- candidates = candidates - self.ing_ids_per_cat['egg']
161
- return np.array(sorted(candidates))
162
-
163
- def count_in_genes(self, present_ids, keyword):
164
- if keyword == 'citrus': return len(present_ids & self.ing_ids_per_cat['citrus'])
165
- elif keyword == 'bubbles': return len(present_ids & self.ing_ids_per_cat['bubbles'])
166
- elif keyword == 'liquor': return len(present_ids & self.ing_ids_per_cat['liquor'])
167
- elif keyword == 'liqueur': return len(present_ids & self.ing_ids_per_cat['liqueur'])
168
- elif keyword == 'alcohol': return len(present_ids & self.ing_ids_per_cat['alcohol'])
169
- elif keyword == 'sweeteners': return len(present_ids & self.ing_ids_per_cat['sweeteners'])
170
- else: raise ValueError
171
-
172
- def get_ingredients_ids_from_list(self, ing_list):
173
- return [ingredient_list.index(ing) for ing in ing_list]
174
-
175
- def get_ing_count(self):
176
- return np.sum(self.genes_presence)
177
-
-     # # # # # # # # # # # # # # # # # # # # # # # #
-     # Compute cocktail representations
-     # # # # # # # # # # # # # # # # # # # # # # # #
-
-     def get_absent_ing(self):
-         return np.argwhere(self.genes_presence == 0).flatten()
-
-     def get_present_ing(self):
-         return np.argwhere(self.genes_presence == 1).flatten()
-
-     def get_ingredient_quantities(self):
-         # unnormalize quantities to get the real ones (in mL)
-         return (self.genes_quantity * (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present)
-                 + min_ingredients_quantities_when_present) * self.genes_presence
-
-     def get_ing_and_q_from_genes(self):
-         present_ings = self.get_present_ing()
-         ing_quantities = self.get_ingredient_quantities()
-         ingredients, quantities = [], []
-         for i_ing in present_ings:
-             ingredients.append(ingredient_list[i_ing])
-             quantities.append(ing_quantities[i_ing])
-         return ingredients, quantities, ing_quantities
-
-     def compute_cocktail_rep(self):
-         # only call when the genes have changed
-         ingredients, quantities, ing_quantities = self.get_ing_and_q_from_genes()
-         # compute the cocktail category and preparation type
-         self.category = find_cocktail_sub_category(ingredients, quantities)[0]
-         self.prep_type = self.get_prep_type(ing_quantities)
-         cocktail_rep, self.end_volume, self.end_alcohol = get_cocktail_rep(self.prep_type, ingredients, quantities,
-                                                                            keys=rep_keys[1:])  # volume is added later
-         self.cocktail_rep = normalize_cocktail(cocktail_rep)
-         self.glass = self.get_glass_type(ing_quantities)
-         if self.is_known:
-             assert np.abs(self.cocktail_rep - self.target).sum() < 1e-6
-         return self.cocktail_rep
-
-     def get_prep_type(self, quantities=None):
-         if self.is_known:
-             return self.known_target_dict['prep_type']
-         else:
-             if quantities is None:
-                 quantities = self.get_ingredient_quantities()
-             if quantities[ingredient_list.index('egg')] > 0:
-                 prep_cat = 'egg_shaken'
-             elif self.category in ['spirit_forward', 'simple_sour_with_juice', 'julep', 'duo', 'ancestral', 'complex_sour_with_juice']:
-                 # use hard-coded rules for the most obvious cases, determined with the correlations_glass_cat_prep script
-                 if self.category in ['ancestral', 'spirit_forward', 'duo']:
-                     prep_cat = 'stirred'
-                 elif self.category in ['complex_sour_with_juice', 'julep', 'simple_sour_with_juice']:
-                     prep_cat = 'shaken'
-                 else:
-                     raise ValueError
-             else:
-                 output = prep_model(quantities, aux_str='prep_type').flatten()
-                 output[preparation_list.index('egg_shaken')] = -np.inf
-                 prep_cat = preparation_list[np.argmax(output)]
-             return prep_cat
-
-     def get_glass_type(self, quantities=None):
-         if self.is_known:
-             return self.known_target_dict['glass']
-         else:
-             if self.category in ['collins', 'complex_highball', 'simple_highball', 'champagne_cocktail', 'complex_sour']:
-                 # use hard-coded rules for the most obvious cases, determined with the correlations_glass_cat_prep script
-                 if self.category in ['collins', 'complex_highball', 'simple_highball']:
-                     glass = 'collins'
-                 elif self.category in ['champagne_cocktail', 'complex_sour']:
-                     glass = 'coupe'
-             else:
-                 if quantities is None:
-                     quantities = self.get_ingredient_quantities()
-                 output = prep_model(quantities, aux_str='glasses').flatten()
-                 glass = glasses_list[np.argmax(output)]
-             return glass
-
-     # # # # # # # # # # # # # # # # # # # # # # # #
-     # Adapt recipe to fit the glass
-     # # # # # # # # # # # # # # # # # # # # # # # #
-
-     def is_too_large_for_glass(self):
-         return self.end_volume > glass_volume[self.glass] * 0.80
-
-     def is_too_small_for_glass(self):
-         return self.end_volume < glass_volume[self.glass] * 0.3
-
-     def scale_ing_quantities(self, present_ings, factor):
-         qs = self.get_ingredient_quantities().copy()
-         qs[present_ings] *= factor
-         self.set_genes_from_quantities(present_ings, qs)
-
-     def set_genes_from_quantities(self, present_ings, quantities):
-         genes_quantity = np.clip((quantities - min_ingredients_quantities_when_present) /
-                                  (factor_max * max_ingredients_quantities - min_ingredients_quantities_when_present), 0, 1)
-         self.genes_quantity[present_ings] = genes_quantity[present_ings]
-
-     def make_recipe_fit_the_glass(self):
-         # check for citrus; if there is none, remove the egg
-         present_ids = np.argwhere(self.genes_presence == 1).flatten()
-         ing_list = self.ing_list[present_ids]
-         present_ids = set(present_ids)
-         if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list:
-             if self.genes_presence.sum() > 2:
-                 i_egg = ingredient_list.index('egg')
-                 self.genes_presence[i_egg] = 0.
-                 self.compute_cocktail_rep()
-
-         # rescale quantities towards 80% of the glass volume until the end volume fits (at most 10 attempts)
-         i_trial = 0
-         present_ings = self.get_present_ing()
-         while self.is_too_large_for_glass() or self.is_too_small_for_glass():
-             i_trial += 1
-             end_volume = self.end_volume
-             desired_volume = glass_volume[self.glass] * 0.80
-             ratio = desired_volume / end_volume
-             self.scale_ing_quantities(present_ings, factor=ratio)
-             self.compute_cocktail_rep()  # dilution changes, so the end volume must be recomputed
-             if end_volume == self.end_volume: break
-             if i_trial == 10: break
-
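Note: a quick numeric check of the quantity (un)normalization used by get_ingredient_quantities and set_genes_from_quantities, with hypothetical bounds for a single ingredient:

min_q, max_q, factor_max = 10.0, 60.0, 1.2                    # made-up bounds, in mL
gene = 0.5
q = gene * (max_q * factor_max - min_q) + min_q               # -> 41.0 mL
gene_back = (q - min_q) / (factor_max * max_q - min_q)        # -> 0.5, the round trip is exact

And a worked pass of the fitting loop: with a 200 mL coupe, a 190 mL cocktail exceeds 0.8 * 200 = 160 mL, so all present quantities get scaled by 160 / 190 ≈ 0.84; dilution is then recomputed, which is why several passes may be needed.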
315
- # # # # # # # # # # # # # # # # # # # # # # # #
316
- # Compute performance
317
- # # # # # # # # # # # # # # # # # # # # # # # #
318
-
319
- def passes_checks(self):
320
- present_ids = np.argwhere(self.genes_presence==1).flatten()
321
- # ing_list = self.ing_list[present_ids]
322
- present_ids = set(present_ids)
323
- if len(present_ids) < 2 or len(present_ids) > 8: return False
324
- # if self.is_too_large_for_glass(): return False
325
- # if self.is_too_small_for_glass(): return False
326
- if self.end_alcohol < 0.05 or self.end_alcohol > 0.31: return False
327
- if self.count_in_genes(present_ids, 'sweeteners') > 2: return False
328
- if self.count_in_genes(present_ids, 'liqueur') > 2: return False
329
- if self.count_in_genes(present_ids, 'liquor') > 3: return False
330
- # if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list: return False
331
- if self.count_in_genes(present_ids, 'bubbles') > 1: return False
332
- else: return True
333
-
334
- def get_affective_cluster(self):
335
- cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(self.cocktail_rep)
336
- self.affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0]
337
- return self.affective_cluster
338
-
339
- def does_affective_cluster_match(self):
340
- return True#self.get_affective_cluster() == self.target_affective_cluster
341
-
342
- def compute_perf(self):
343
- if not self.passes_checks(): self.perf = -100
344
- else:
345
- if self.dist == 'mse':
346
- # self.perf = - np.sqrt(((self.cocktail_rep - self.target)**2).mean())
347
- self.perf = - np.sqrt(np.dot((self.cocktail_rep - self.target)**2, weights_mse_computation))
348
- self.perf *= weights_perf_n_ing[int(self.genes_presence.sum())]
349
- if not self.does_affective_cluster_match():
350
- self.perf *= 2
351
- else: raise NotImplemented
352
-
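Note: the fitness is a weighted RMSE to the target representation, scaled by an ingredient-count prior; a minimal sketch with hypothetical weights:

import numpy as np
cocktail_rep = np.array([0.2, 0.5]); target = np.array([0.3, 0.3])
weights_mse_computation = np.array([0.5, 0.5])                # hypothetical per-dimension weights
perf = -np.sqrt(np.dot((cocktail_rep - target) ** 2, weights_mse_computation))  # ≈ -0.158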
-     # # # # # # # # # # # # # # # # # # # # # # # #
-     # Mutations and crossover
-     # # # # # # # # # # # # # # # # # # # # # # # #
-
-     def get_child(self):
-         time_dict = dict()
-         init_time = time.time()
-         child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster,
-                                    target=self.target, genes_presence=self.genes_presence.copy(),
-                                    genes_quantity=self.genes_quantity.copy(), compute_perf=False)
-         time_dict[' asexual child creation'] = [time.time() - init_time]
-         init_time = time.time()
-         this_time_dict = child.mutate()
-         time_dict = self.update_time_dict(time_dict, this_time_dict)
-         time_dict[' asexual child mutation'] = [time.time() - init_time]
-         return child, time_dict
-
-     def get_child_with(self, other_parent):
-         time_dict = dict()
-         init_time = time.time()
-         new_genes_presence = np.zeros(self.n_genes)
-         present_ing = self.get_present_ing()
-         other_present_ing = other_parent.get_present_ing()
-         new_genes_quantity = np.random.uniform(0, 1, size=self.n_genes)
-         shared_ingredients = sorted(set(present_ing) & set(other_present_ing))
-         unique_ingredients_one = sorted(set(present_ing) - set(other_present_ing))
-         unique_ingredients_two = sorted(set(other_present_ing) - set(present_ing))
-         # shared ingredients are kept; their quantities are averaged
-         for i in shared_ingredients:
-             new_genes_presence[i] = 1
-             new_genes_quantity[i] = (self.genes_quantity[i] + other_parent.genes_quantity[i]) / 2
-         time_dict[' crossover child creation'] = [time.time() - init_time]
-         init_time = time.time()
-         # add one alcohol if none is present
-         if len(set(np.argwhere(new_genes_presence == 1).flatten()).intersection(ind_alcohol)) == 0:
-             new_genes_presence[np.random.choice(ind_alcohol)] = 1
-         # up to here, we respect the constraints (assuming both parents do)
-         candidate_genes = np.array(unique_ingredients_one + unique_ingredients_two)
-         candidate_quantities = np.array([self.genes_quantity[i] for i in unique_ingredients_one]
-                                         + [other_parent.genes_quantity[i] for i in unique_ingredients_two])
-         indexes = np.arange(len(candidate_genes))
-         np.random.shuffle(indexes)
-         candidate_genes = candidate_genes[indexes]
-         candidate_quantities = candidate_quantities[indexes]
-         time_dict[' crossover prepare selection'] = [time.time() - init_time]
-         init_time = time.time()
-         # now try to add each of them while respecting the constraints
-         for i in range(len(indexes)):
-             if np.random.rand() < 0.5 or np.sum(new_genes_presence) < self.min_ingredients:  # only try to add one ingredient out of two
-                 ing_id = candidate_genes[i]
-                 q = candidate_quantities[i]
-                 new_genes_presence[ing_id] = 1
-                 new_genes_quantity[ing_id] = q
-             if np.sum(new_genes_presence) == self.max_ingredients:
-                 break
-         time_dict[' crossover do selection'] = [time.time() - init_time]
-         init_time = time.time()
-         # create the new child
-         child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster, target=self.target,
-                                    genes_presence=new_genes_presence.copy(), genes_quantity=new_genes_quantity.copy(), compute_perf=False)
-         time_dict[' crossover create child'] = [time.time() - init_time]
-         init_time = time.time()
-         this_time_dict = child.mutate()
-         time_dict = self.update_time_dict(time_dict, this_time_dict)
-         time_dict[' crossover child mutation'] = [time.time() - init_time]
-         return child, time_dict
-
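Note: both reproduction paths return a (child, time_dict) pair, so callers must unpack the tuple rather than store it directly — see the selection code in population.py below.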
-     def mutate(self):
-         time_dict = dict()
-         # remove an ingredient
-         init_time = time.time()
-         present_ids = set(np.argwhere(self.genes_presence == 1).flatten())
-         if np.random.rand() < self.mutation_params['p_remove_ing']:
-             if self.get_ing_count() > self.min_ingredients:
-                 candidate_ings = self.get_present_ing()
-                 if self.count_in_genes(present_ids, 'alcohol') == 1:  # make sure we keep at least one liquor
-                     candidate_ings = np.array(sorted(set(candidate_ings) - set(ind_alcohol)))
-                 index_to_remove = np.random.choice(candidate_ings)
-                 self.genes_presence[index_to_remove] = 0
-         time_dict[' mutation remove ing'] = [time.time() - init_time]
-         init_time = time.time()
-         # add an ingredient
-         if np.random.rand() < self.mutation_params['p_add_ing']:
-             if self.get_ing_count() < self.max_ingredients:
-                 candidate_ings = self.get_candidate_ingredients_ids(self.genes_presence.copy())
-                 index_to_add = np.random.choice(candidate_ings, p=density_ingredients[candidate_ings] / np.sum(density_ingredients[candidate_ings]))
-                 self.genes_presence[index_to_add] = 1
-         time_dict[' mutation add ing'] = [time.time() - init_time]
-
-         init_time = time.time()
-         # replace an ingredient with another one from the same family
-         if np.random.rand() < self.mutation_params['p_switch_ing']:
-             i = np.random.choice(self.get_present_ing())
-             ing_str = ingredient_list[i]
-             if ing_str not in ['sparkling wine', 'orange juice']:
-                 if ing_str in bubble_ingredients:
-                     candidates_ids = np.array(sorted(self.ing_ids_per_cat['bubbles'] - set([i])))
-                     new_bubble = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids]))
-                     self.genes_presence[i] = 0
-                     self.genes_presence[new_bubble] = 1
-                     self.genes_quantity[new_bubble] = self.genes_quantity[i]  # copy quantity
-                 categories = ['acid', 'bitters', 'juice', 'liqueur', 'liquor', 'sweeteners', 'vermouth']
-                 for cat in categories:
-                     if ing_str in ingredients_per_type[cat]:
-                         present_ings = self.get_present_ing()
-                         candidates_ids = np.array(sorted(self.ing_ids_per_cat[cat] - set([i]) - set(present_ings)))
-                         if len(candidates_ids) > 0:
-                             replacing_ing = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids]))
-                             self.genes_presence[i] = 0
-                             self.genes_presence[replacing_ing] = 1
-                             self.genes_quantity[replacing_ing] = self.genes_quantity[i]  # copy quantity
-                         break
-         time_dict[' mutation switch ing'] = [time.time() - init_time]
-         init_time = time.time()
-         # add noise to the quantity of each present ingredient
-         for i in self.get_present_ing():
-             if np.random.rand() < self.mutation_params['p_change_q']:
-                 self.genes_quantity[i] += np.random.randn() * self.mutation_params['delta_change_q']
-         self.genes_quantity = np.clip(self.genes_quantity, 0, 1)
-         time_dict[' mutation change quantity'] = [time.time() - init_time]
-
-         init_time = time.time()
-         self.compute_cocktail_rep()
-         time_dict[' mutation compute cocktail rep'] = [time.time() - init_time]
-         init_time = time.time()
-         # self.make_recipe_fit_the_glass()
-         time_dict[' mutation check glass fit'] = [time.time() - init_time]
-         init_time = time.time()
-         self.compute_perf()
-         time_dict[' mutation compute perf'] = [time.time() - init_time]
-         return time_dict
-
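Note: mutate() only reads five keys from mutation_params; a plausible, entirely hypothetical configuration:

mutation_params = {'p_remove_ing': 0.1,     # probability to drop an ingredient
                   'p_add_ing': 0.1,        # probability to add a compatible ingredient
                   'p_switch_ing': 0.1,     # probability to swap within the same family
                   'p_change_q': 0.3,       # per-ingredient probability of quantity noise
                   'delta_change_q': 0.05,  # std of the Gaussian noise on normalized quantities
                   'asexual_rep': True, 'crossover': True}  # reproduction flags read by Population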
-     def update_time_dict(self, main_dict, new_dict):
-         for k in new_dict.keys():
-             if k in main_dict.keys():
-                 main_dict[k].append(np.sum(new_dict[k]))
-             else:
-                 main_dict[k] = [np.sum(new_dict[k])]
-         return main_dict
-
-     # # # # # # # # # # # # # # # # # # # # # # # #
-     # Get recipe and print
-     # # # # # # # # # # # # # # # # # # # # # # # #
-
-     def get_recipe(self, unit='mL', name=None):
-         ing_quantities = self.get_ingredient_quantities()
-         ingredients, quantities = [], []
-         for i_ing, q_ing in enumerate(ing_quantities):
-             if q_ing > 0.8:
-                 ingredients.append(ingredient_list[i_ing])
-                 quantities.append(round(q_ing))
-         recipe_str = format_ingredients(ingredients, quantities)
-         recipe_str_readable = print_recipe(unit=unit, ingredient_str=recipe_str, name=name, to_print=False)
-         return ingredients, quantities, recipe_str, recipe_str_readable
-
-     def get_instructions(self):
-         ing_quantities = self.get_ingredient_quantities()
-         ingredients, quantities = [], []
-         for i_ing, q_ing in enumerate(ing_quantities):
-             if q_ing > 0.8:
-                 ingredients.append(ingredient_list[i_ing])
-                 quantities.append(round(q_ing))
-         str_out = 'Instructions:\n '
-
-         if 'mint' in ingredients:
-             i_mint = ingredients.index('mint')
-             n_leaves = quantities[i_mint]
-             str_out += f'Add {n_leaves} mint leaves to a shaker, followed by an ice cube.\n Muddle the mint and ice together with a muddler.\n '
-         bubbles = ['sparkling wine', 'tonic', 'soda', 'ginger beer']
-         other_ings = [ing for ing in ingredients if ing not in ['egg', 'angostura', 'orange bitters'] + bubbles]
-
-         if self.prep_type == 'built':
-             str_out += 'Add a large ice cube in the glass.\n '
-         # list the ingredients to pour
-         str_out += 'Pour'
-         for i, ing in enumerate(other_ings):
-             if i == len(other_ings) - 2:
-                 str_out += f' {ing} and'
-             elif i == len(other_ings) - 1:
-                 str_out += f' {ing}'
-             else:
-                 str_out += f' {ing},'
-
-         if self.prep_type == 'built' and 'mint' not in ingredients:
-             str_out += ' into the glass.\n '
-         else:
-             str_out += ' into the shaker.\n '
-
-         if self.prep_type == 'egg_shaken' and 'egg' in ingredients:
-             str_out += 'Add the egg white.\n Dry-shake for 15s (without ice), then fill with ice and shake for another 15s.\n Serve into the glass through a strainer.\n '
-         elif 'shaken' in self.prep_type:
-             str_out += 'Fill with ice and shake for 15s.\n Serve into the glass through a strainer.\n '
-         elif self.prep_type == 'stirred':
-             str_out += 'Add ice and stir the cocktail with a spoon for 15s.\n Serve into the glass through a strainer.\n '
-         elif self.prep_type == 'built':
-             str_out += 'Stir two turns with a spoon.\n '
-
-         bubble_ing = [ing for ing in ingredients if ing in bubbles]
-         if len(bubble_ing) > 0:
-             str_out += 'Top up with ' + ', '.join(bubble_ing) + '.\n '
-
-         # one dash is about 0.6 mL, hence the conversion below
-         bitter_ing = [ing for ing in ingredients if ing in ['angostura', 'orange bitters']]
-         dash_strs = []
-         for ing in bitter_ing:
-             q = quantities[ingredients.index(ing)]
-             n_dashes = max(1, int(q / 0.6))
-             plural = 'es' if n_dashes > 1 else ''
-             dash_strs.append(f'{n_dashes} dash{plural} of {ing}')
-         if len(dash_strs) > 0:
-             str_out += 'Add ' + ' and '.join(dash_strs) + '.\n '
-         str_out += 'Enjoy!'
-         return str_out
-
-     def print_recipe(self, name=None):
-         # get_recipe takes (unit, name): pass name as a keyword to avoid overwriting the unit
-         print(self.get_recipe(name=name)[3])
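Note: the dash conversion above assumes one dash ≈ 0.6 mL; a 2.4 mL angostura quantity, for instance, is rendered as max(1, int(2.4 / 0.6)) = 4 dashes.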
 
src/cocktails/utilities/cocktail_generation_utilities/population.py DELETED
@@ -1,213 +0,0 @@
- from src.cocktails.utilities.cocktail_generation_utilities.individual import *
- from sklearn.neighbors import NearestNeighbors
- import time
- import pickle
- import pandas as pd
- from src.cocktails.config import COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
-
-
- class Population:
-     def __init__(self, target, pop_params, target_affective_cluster=None, known_target_dict=None):
-         self.pop_params = pop_params
-         self.pop_size = pop_params['pop_size']
-         self.nb_elite = pop_params['nb_elites']
-         self.nb_generations = pop_params['nb_generations']
-         self.target = target
-         self.mutation_params = pop_params['mutation_params']
-         self.dist = pop_params['dist']
-         self.n_neighbors = pop_params['n_neighbors']
-         self.known_target_dict = known_target_dict
-
-         with open(COCKTAIL_NN_PATH, 'rb') as f:
-             data = pickle.load(f)
-         self.nn_model_cocktail = data['nn_model']
-         self.dim_rep_cocktail = data['dim_rep_cocktail']
-         self.n_cocktails = data['n_cocktails']
-         self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
-
-         if target_affective_cluster is None:
-             cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(target)
-             self.target_affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0]
-         else:
-             self.target_affective_cluster = target_affective_cluster
-
-         self.pop_elite = []
-         self.pop = []
-         self.add_target_individual()        # create a target individual (not in pop)
-         self.add_nearest_neighbors_in_pop() # add nearest neighbors from the dataset to the population
-
-         # fill the population
-         while self.get_pop_size() < self.pop_size:
-             self.add_individual()
-         while len(self.pop_elite) < self.nb_elite:
-             self.pop_elite.append(IndividualCocktail(pop_params=self.pop_params,
-                                                      target=self.target.copy(),
-                                                      target_affective_cluster=self.target_affective_cluster))
-         self.update_elite_and_get_next_pop()
-
-     def add_target_individual(self):
-         if self.known_target_dict is not None:
-             genes_presence, genes_quantity = self.get_q_rep(*extract_ingredients(self.known_target_dict['ing_str']))
-             self.target_individual = IndividualCocktail(pop_params=self.pop_params,
-                                                         target=self.target.copy(),
-                                                         known_target_dict=self.known_target_dict,
-                                                         target_affective_cluster=self.target_affective_cluster,
-                                                         genes_presence=genes_presence,
-                                                         genes_quantity=genes_quantity)
-         else:
-             self.target_individual = None
-
-     def add_nearest_neighbors_in_pop(self):
-         # add the nearest neighbors from the dataset to the population
-         if self.n_neighbors > 0:
-             dists, indexes = self.nn_model_cocktail.kneighbors(self.target.reshape(1, -1))
-             dists, indexes = dists.flatten(), indexes.flatten()
-             first = 1 if dists[0] == 0 else 0  # avoid taking the target itself when testing with known targets from the dataset
-             indexes = indexes[first:first + self.n_neighbors]
-             self.ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes]
-             recipes = [extract_ingredients(ing_str) for ing_str in self.ing_strs]
-             for r in recipes:
-                 genes_presence, genes_quantity = self.get_q_rep(r[0], r[1])
-                 genes_presence[-1] = 0  # remove the water ingredient
-                 self.add_individual(genes_presence=genes_presence.copy(), genes_quantity=genes_quantity.copy())
-             self.nn_recipes = [ind.get_recipe()[3] for ind in self.pop]
-             self.nn_scores = [ind.perf for ind in self.pop]
-         else:
-             self.ing_strs = None
-
-     def add_individual(self, genes_presence=None, genes_quantity=None):
-         self.pop.append(IndividualCocktail(pop_params=self.pop_params,
-                                            target=self.target.copy(),
-                                            target_affective_cluster=self.target_affective_cluster,
-                                            genes_presence=genes_presence,
-                                            genes_quantity=genes_quantity))
-
-     def get_elite_perf(self):
-         return np.array([e.perf for e in self.pop_elite])
-
-     def get_pop_perf(self):
-         return np.array([ind.perf for ind in self.pop])
-
-     def update_elite_and_get_next_pop(self):
-         time_dict = dict()
-         init_time = time.time()
-         elite_perfs = self.get_elite_perf()
-         pop_perfs = self.get_pop_perf()
-         all_perfs = np.concatenate([elite_perfs, pop_perfs])
-         temp_list = self.pop_elite + self.pop
-         time_dict[' get pop perfs'] = [time.time() - init_time]
-         init_time = time.time()
-         # update the elite population with the new best individuals
-         indexes_sorted = np.flip(np.argsort(all_perfs))
-         new_pop_elite = [IndividualCocktail(pop_params=self.pop_params,
-                                             target=self.target.copy(),
-                                             target_affective_cluster=self.target_affective_cluster,
-                                             genes_presence=temp_list[i_new_e].genes_presence.copy(),
-                                             genes_quantity=temp_list[i_new_e].genes_quantity.copy()) for i_new_e in indexes_sorted[:self.nb_elite]]
-         time_dict[' recreate elite individuals'] = [time.time() - init_time]
-         init_time = time.time()
-         # select parents with rank-based sampling probabilities
-         rank_perfs = np.flip(np.arange(len(temp_list)))
-         sampling_probs = rank_perfs / np.sum(rank_perfs)
-         if self.mutation_params['asexual_rep'] and not self.mutation_params['crossover']:
-             new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size)
-             self.pop = [temp_list[i].get_child()[0] for i in new_pop_indexes]  # get_child returns (child, time_dict)
-         elif self.mutation_params['crossover'] and not self.mutation_params['asexual_rep']:
-             self.pop = []
-             while len(self.pop) < self.pop_size:
-                 parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False)
-                 child, _ = temp_list[parents[0]].get_child_with(temp_list[parents[1]])  # returns (child, time_dict)
-                 self.pop.append(child)
-         elif self.mutation_params['crossover'] and self.mutation_params['asexual_rep']:
-             new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size // 2)
-             time_dict[' choose asexual parent indexes'] = [time.time() - init_time]
-             init_time = time.time()
-             self.pop = []
-             for i in new_pop_indexes:
-                 child, this_time_dict = temp_list[i].get_child()
-                 self.pop.append(child)
-                 time_dict = self.update_time_dict(time_dict, this_time_dict)
-             time_dict[' get asexual children'] = [time.time() - init_time]
-             init_time = time.time()
-             while len(self.pop) < self.pop_size:
-                 parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False)
-                 child, this_time_dict = temp_list[parents[0]].get_child_with(temp_list[parents[1]])
-                 self.pop.append(child)
-                 time_dict = self.update_time_dict(time_dict, this_time_dict)
-             time_dict[' get sexual children'] = [time.time() - init_time]
-         self.pop_elite = new_pop_elite
-         return time_dict
-
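Note the pairing in the rank-based selection: indexes_sorted is ordered best-first while rank_perfs counts down from N-1, so with N = 4 candidates the best parent is drawn with probability 3/6, the next with 2/6, then 1/6, and the worst is never selected.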
-     def get_pop_size(self):
-         return len(self.pop)
-
-     def get_q_rep(self, ingredients, quantities):
-         ingredient_q_rep = np.zeros([len(ingredient_list)])
-         genes_presence = np.zeros([len(ingredient_list)])
-         for ing, q in zip(ingredients, quantities):
-             ingredient_q_rep[ingredient_list.index(ing)] = q
-             genes_presence[ingredient_list.index(ing)] = 1
-         return genes_presence.copy(), normalize_ingredient_q_rep(ingredient_q_rep)
-
-     def get_best_score(self, affective_cluster_check=False):
-         elite_perfs = self.get_elite_perf()
-         pop_perfs = self.get_pop_perf()
-         all_perfs = np.concatenate([elite_perfs, pop_perfs])
-         temp_list = self.pop_elite + self.pop
-         if affective_cluster_check:
-             indexes = np.array([i for i in range(len(temp_list)) if temp_list[i].does_affective_cluster_match()])
-             if indexes.size > 0:
-                 temp_list = np.array(temp_list)[indexes]
-                 all_perfs = all_perfs[indexes]
-         indexes_best = np.flip(np.argsort(all_perfs))
-         return np.array(all_perfs)[indexes_best], np.array(temp_list)[indexes_best]
-
-     def update_time_dict(self, main_dict, new_dict):
-         for k in new_dict.keys():
-             if k in main_dict.keys():
-                 main_dict[k].append(np.sum(new_dict[k]))
-             else:
-                 main_dict[k] = [np.sum(new_dict[k])]
-         return main_dict
-
-     def run_one_generation(self, verbose=True, affective_cluster_check=False):
-         time_dict = dict()
-         init_time = time.time()
-         this_time_dict = self.update_elite_and_get_next_pop()
-         time_dict['update_elite_and_pop'] = [time.time() - init_time]
-         time_dict = self.update_time_dict(time_dict, this_time_dict)
-         init_time = time.time()
-         best_perfs, best_individuals = self.get_best_score(affective_cluster_check)
-         time_dict['get best scores'] = [time.time() - init_time]
-         return best_perfs[0], time_dict
-
-     def run_evolution(self, verbose=False, print_every=10, affective_cluster_check=False, level=0):
-         best_score = -np.inf
-         time_dict = dict()
-         init_time = time.time()
-         for i in range(self.nb_generations):
-             best_score, this_time_dict = self.run_one_generation(verbose, affective_cluster_check=affective_cluster_check)
-             time_dict = self.update_time_dict(time_dict, this_time_dict)
-             if verbose and (i + 1) % print_every == 0:
-                 print(' ' * level + f'Gen #{i+1} - Current best perf: {best_score:.2f}, time: {time.time() - init_time:.4f}')
-                 init_time = time.time()
-         if verbose: print(' ' * level + f'Evolution over, best perf: {best_score:.2f}')
-         return self.get_best_score()
-
-     def print_results(self, n=3):
-         best_scores, best_ind = self.get_best_score()
-         for i in range(n):
-             best_ind[i].print_recipe(f'Candidate #{i+1}, Score: {best_scores[i]:.2f}')
-
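A hedged end-to-end sketch — the pop_params keys are the ones this class actually reads, the values are made up:

pop_params = {'pop_size': 100, 'nb_elites': 10, 'nb_generations': 100,
              'n_neighbors': 3, 'dist': 'mse',
              'mutation_params': mutation_params}  # see the sketch under individual.py
pop = Population(target=target_rep, pop_params=pop_params)  # target_rep: a normalized cocktail representation
best_scores, best_individuals = pop.run_evolution(verbose=True)
pop.print_results(n=3)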
 
src/cocktails/utilities/cocktail_utilities.py DELETED
@@ -1,220 +0,0 @@
- import numpy as np
- from src.cocktails.utilities.ingredients_utilities import ingredient2ingredient_id, ingredient_profiles, ingredients_per_type, ingredient_list, find_ingredient_from_str
- from src.cocktails.utilities.cocktail_category_detection_utilities import *
- import time
-
- # previous set of keys: ['pH', 'sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb',
- #                        'complex', 'spicy', 'strong', 'oaky', 'fizzy', 'colorful', 'eggy']
- representation_keys = ['sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb',
-                        'complex', 'spicy', 'oaky', 'fizzy', 'colorful', 'eggy']
- representation_keys_linear = list(set(representation_keys) - set(['pH', 'complex']))
-
- ing_reps = np.array([[ingredient_profiles[k][ing_id] for ing_id in ingredient2ingredient_id.values()] for k in representation_keys]).transpose()
-
- def compute_cocktail_representation(profile, ingredients, quantities):
-     # computes the representation of a cocktail from the recipe (ingredients, quantities) and volume
-     n = len(ingredients)
-     assert n == len(quantities)
-     quantities = np.array(quantities)
-     weights = quantities / np.sum(quantities)
-     rep = dict()
-
-     ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients])
-     # compute features as a linear combination of ingredient features
-     for k in representation_keys_linear:
-         k_ing = np.array([ingredient_profiles[k][ing_id] for ing_id in ing_ids])
-         rep[k] = np.dot(weights, k_ing)
-
-     # pH is not linear: pH = -log10(concentration), so mix the H+ concentrations, not the pHs
-     phs = np.array([ingredient_profiles['pH'][ing_id] for ing_id in ing_ids])
-     concentrations = 10 ** (-phs)
-     mix_c = np.dot(weights, concentrations)
-     rep['pH'] = - np.log10(mix_c)
-
-     # complexity increases with the number of ingredients
-     rep['complex'] = np.mean([ingredient_profiles['complex'][ing_id] for ing_id in ing_ids]) + len(ing_ids)
-
-     # compute the profile after dilution
-     volume_ratio = profile['mix volume'] / profile['end volume']
-     for k in representation_keys:
-         rep['end ' + k] = rep[k] * volume_ratio
-     concentration = 10 ** (-rep['pH'])
-     end_concentration = concentration * volume_ratio
-     rep['end pH'] = - np.log10(end_concentration)
-     return rep
-
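Sanity check of the pH mixing rule: for equal parts at pH 2 and pH 4, the mixed concentration is (10^-2 + 10^-4) / 2 ≈ 5.05e-3, i.e. pH ≈ 2.30 — much closer to the acidic part than the naive average of 3.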
- def get_alcohol_profile(ingredients, quantities):
-     ingredients = ingredients.copy()
-     quantities = quantities.copy()
-     assert len(ingredients) == len(quantities)
-     if 'mint' in ingredients:
-         mint_ind = ingredients.index('mint')
-         ingredients.pop(mint_ind)
-         quantities.pop(mint_ind)
-     alcohol = []
-     volume_mix = np.sum(quantities)
-     weights = quantities / volume_mix
-     assert np.abs(np.sum(weights) - 1) < 1e-4
-     ingredients_list = [ing.lower() for ing in ingredient_list]
-     for ing, q in zip(ingredients, quantities):
-         ing_id = ingredients_list.index(ing)
-         alcohol.append(ingredient_profiles['ethanol'][ing_id])
-     alcohol = np.dot(alcohol, weights)
-     return alcohol, volume_mix
-
- def get_mix_profile(ingredients, quantities):
-     ingredients = ingredients.copy()
-     quantities = quantities.copy()
-     assert len(ingredients) == len(quantities)
-     if 'mint' in ingredients:
-         mint_ind = ingredients.index('mint')
-         ingredients.pop(mint_ind)
-         quantities.pop(mint_ind)
-     alcohol, sugar, acid = [], [], []
-     volume_mix = np.sum(quantities)
-     weights = quantities / volume_mix
-     assert np.abs(np.sum(weights) - 1) < 1e-4
-     ingredients_list = [ing.lower() for ing in ingredient_list]
-     for ing, q in zip(ingredients, quantities):
-         ing_id = ingredients_list.index(ing)
-         sugar.append(ingredient_profiles['sugar'][ing_id])
-         alcohol.append(ingredient_profiles['ethanol'][ing_id])
-         acid.append(ingredient_profiles['acid'][ing_id])
-     sugar = np.dot(sugar, weights)
-     acid = np.dot(acid, weights)
-     alcohol = np.dot(alcohol, weights)
-     return alcohol, sugar, acid
-
- def extract_preparation_type(instructions, recipe):
-     flag = False
-     instructions = instructions.lower()
-     egg_in_recipe = any([find_ingredient_from_str(ing_str)[1] == 'egg' for ing_str in recipe[1]])
-     if 'shake' in instructions:
-         prep_type = 'egg_shaken' if egg_in_recipe else 'shaken'
-     elif 'stir' in instructions:
-         prep_type = 'stirred'
-     elif 'blend' in instructions:
-         prep_type = 'blended'
-     else:
-         # 'build', 'mix', 'pour', 'combine', 'place' and anything unrecognized default to built
-         prep_type = 'built'
-     return flag, prep_type
-
- def get_dilution_ratio(category, alcohol):
-     # formulas from the Liquid Intelligence book; the formula for 'built' was invented
-     if category == 'stirred':
-         return -1.21 * alcohol ** 2 + 1.246 * alcohol + 0.145
-     elif category in ['shaken', 'egg_shaken']:
-         return -1.567 * alcohol ** 2 + 1.742 * alcohol + 0.203
-     elif category == 'built':
-         return (-1.21 * alcohol ** 2 + 1.246 * alcohol + 0.145) / 2
-     else:
-         return 1
-
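Worked example of the dilution model, with alcohol expressed as a fraction (consistent with the 0.05-0.31 bounds used in the generation code): a stirred mix at 30% ABV gives a ratio of -1.21 * 0.09 + 1.246 * 0.30 + 0.145 ≈ 0.41, so 100 mL of mix ends up as about 141 mL at roughly 21% ABV once the melted ice is accounted for.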
- def get_cocktail_rep(category, ingredients, quantities, keys):
-     ingredients = ingredients.copy()
-     quantities = quantities.copy()
-     assert len(ingredients) == len(quantities)
-
-     volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
-
-     # compute the alcohol content without the mint ingredient
-     ingredients2 = [ing for ing in ingredients if ing != 'mint']
-     quantities2 = [q for ing, q in zip(ingredients, quantities) if ing != 'mint']
-     weights2 = quantities2 / np.sum(quantities2)
-     assert np.abs(np.sum(weights2) - 1) < 1e-4
-     ing_ids2 = np.array([ingredient2ingredient_id[ing] for ing in ingredients2])
-     alcohol = np.array([ingredient_profiles['ethanol'][ing_id] for ing_id in ing_ids2])
-     alcohol = np.dot(alcohol, weights2)
-     dilution_ratio = get_dilution_ratio(category, alcohol)
-     end_volume = volume_mix + volume_mix * dilution_ratio
-     volume_ratio = volume_mix / end_volume
-     end_alcohol = alcohol * volume_ratio
-
-     # computes the representation of a cocktail from the recipe (ingredients, quantities) and volume
-     weights = quantities / np.sum(quantities)
-     assert np.abs(np.sum(weights) - 1) < 1e-4
-     ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients])
-     reps = ing_reps[ing_ids]
-     cocktail_rep = np.dot(weights, reps)
-     i_complex = keys.index('end complex')
-     cocktail_rep[i_complex] = np.mean(reps[:, i_complex]) + len(ing_ids)  # complexity increases with the number of ingredients
-
-     # compute the profile after dilution, and prepend the end volume
-     cocktail_rep = cocktail_rep * volume_ratio
-     cocktail_rep = np.concatenate([[end_volume], cocktail_rep])
-     return cocktail_rep, end_volume, end_alcohol
-
- def get_profile(category, ingredients, quantities):
-     volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
-     alcohol, sugar, acid = get_mix_profile(ingredients, quantities)
-     dilution_ratio = get_dilution_ratio(category, alcohol)
-     end_volume = volume_mix + volume_mix * dilution_ratio
-     volume_ratio = volume_mix / end_volume
-     profile = {'mix volume': volume_mix,
-                'mix alcohol': alcohol,
-                'mix sugar': sugar,
-                'mix acid': acid,
-                'dilution ratio': dilution_ratio,
-                'end volume': end_volume,
-                'end alcohol': alcohol * volume_ratio,
-                'end sugar': sugar * volume_ratio,
-                'end acid': acid * volume_ratio}
-     cocktail_rep = compute_cocktail_representation(profile, ingredients, quantities)
-     profile.update(cocktail_rep)
-     return profile
-
- profile_keys = ['mix volume', 'end volume',
-                 'dilution ratio',
-                 'mix alcohol', 'end alcohol',
-                 'mix sugar', 'end sugar',
-                 'mix acid', 'end acid'] \
-                + representation_keys \
-                + ['end ' + k for k in representation_keys]
-
- def update_profile_in_datapoint(datapoint, category, ingredients, quantities):
-     profile = get_profile(category, ingredients, quantities)
-     for k in profile_keys:
-         datapoint[k] = profile[k]
-     return datapoint
-
- # define named subsets of representation keys
- def get_bunch_of_rep_keys():
-     dict_rep_keys = dict()
-     # all
-     dict_rep_keys['all'] = profile_keys
-     # only_end
-     dict_rep_keys['only_end'] = [k for k in profile_keys if 'end' in k]
-     # except_end
-     dict_rep_keys['except_end'] = [k for k in profile_keys if 'end' not in k]
-     # custom
-     to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong']
-     rep_keys = [k for k in profile_keys if 'end' in k]
-     for k in to_remove:
-         if k in rep_keys:
-             rep_keys.remove(k)
-     dict_rep_keys['custom'] = rep_keys
-     # custom restricted
-     to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong', 'end spicy', 'end oaky']
-     rep_keys = [k for k in profile_keys if 'end' in k]
-     for k in to_remove:
-         if k in rep_keys:
-             rep_keys.remove(k)
-     dict_rep_keys['restricted'] = rep_keys
-     dict_rep_keys['affective'] = ['end booze', 'end sweet', 'end sour', 'end fizzy', 'end complex', 'end bitter', 'end spicy', 'end colorful']
-     return dict_rep_keys
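A plausible usage sketch — the 'custom' subset starts with 'end volume', which lines up with get_cocktail_rep prepending the end volume:

rep_keys = get_bunch_of_rep_keys()['custom']
cocktail_rep, end_volume, end_alcohol = get_cocktail_rep('shaken', ingredients, quantities, keys=rep_keys[1:])
# cocktail_rep[0] is the end volume; the remaining entries line up with rep_keys[1:]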
 
src/cocktails/utilities/glass_and_volume_utilities.py DELETED
@@ -1,42 +0,0 @@
- # maps scraped glass names to the four canonical glass types, and gives their volumes (mL)
- glass_conversion = {'coupe': 'coupe',
-                     'martini': 'martini',
-                     'collins': 'collins',
-                     'oldfashion': 'oldfashion',
-                     'Coupe glass': 'coupe',
-                     'Old-fashioned glass': 'oldfashion',
-                     'Martini glass': 'martini',
-                     'Nick & Nora glass': 'coupe',
-                     'Julep tin': 'oldfashion',
-                     'Collins or Pineapple shell glass': 'collins',
-                     'Collins glass': 'collins',
-                     'Rocks glass': 'oldfashion',
-                     'Highball (max 10oz/300ml)': 'collins',
-                     'Wine glass': 'coupe',
-                     'Flute glass': 'coupe',
-                     'Double old-fashioned': 'oldfashion',
-                     'Copa glass': 'coupe',
-                     'Toddy glass': 'oldfashion',
-                     'Sling glass': 'collins',
-                     'Goblet glass': 'oldfashion',
-                     'Fizz or Highball (8oz to 10oz)': 'collins',
-                     'Copper mug or Collins glass': 'collins',
-                     'Tiki mug or collins': 'collins',
-                     'Snifter glass': 'oldfashion',
-                     'Coconut shell or Collins glass': 'collins',
-                     'Martini (large 10oz) glass': 'martini',
-                     'Hurricane glass': 'collins',
-                     'Absinthe glass or old-fashioned glass': 'oldfashion'
-                     }
- glass_volume = dict(coupe=200,
-                     collins=350,
-                     martini=200,
-                     oldfashion=320)
- assert set(glass_conversion.values()) == set(glass_volume.keys())
-
- # typical mix-volume ranges (mL) per preparation type
- volume_ranges = dict(stirred=(90, 97),
-                      built=(70, 75),
-                      shaken=(98, 112),
-                      egg_shaken=(130, 143),
-                      carbonated=(150, 150))
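A tiny usage sketch combining the two tables, with the 0.3-0.8 fill window used by the generator in individual.py:

glass = glass_conversion['Julep tin']      # -> 'oldfashion'
capacity = glass_volume[glass]             # -> 320 (mL)
fits = 0.3 * capacity <= end_volume <= 0.8 * capacity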
 
src/cocktails/utilities/ingredients_utilities.py DELETED
@@ -1,209 +0,0 @@
- # This script loads the list and profiles of our ingredient selection.
- # It defines rules to recognize listed ingredients in recipes, and the functions to extract that information from ingredient strings.
-
- import pandas as pd
- from src.cocktails.config import INGREDIENTS_LIST_PATH, COCKTAILS_CSV_DATA
- import numpy as np
-
- ingredient_profiles = pd.read_csv(INGREDIENTS_LIST_PATH)
- ingredient_list = [ing.lower() for ing in ingredient_profiles['ingredient']]
- n_ingredients = len(ingredient_list)
- ingredient2ingredient_id = dict(zip(ingredient_list, range(n_ingredients)))
-
- ingredients_types = sorted(set(ingredient_profiles['type']))
- # for each type, collect all its ingredients
- ing_per_type = [[ing for ing in ingredient_list if ingredient_profiles['type'][ingredient_list.index(ing)] == ing_type] for ing_type in ingredients_types]
- ingredients_per_type = dict(zip(ingredients_types, ing_per_type))
-
- bubble_ingredients = ['soda', 'ginger beer', 'tonic', 'sparkling wine']
- # rules to recognize ingredients in recipes:
- # in [] are separate rules with an OR relation: only one needs to be satisfied;
- # within [], rules apply with an AND relation: all rules need to be satisfied;
- # ~ indicates that the following expression must NOT appear;
- # a plain expression indicates that the expression MUST appear.
- ingredient_search = {  # 'salt': ['salt'],
-     'lime juice': [['lime', '~soda', '~lemonade', '~cordial']],
-     'lemon juice': [['lemon', '~soda', '~lemonade']],
-     'angostura': [['angostura', '~orange'],
-                   ['bitter', '~campari', '~orange', '~red', '~italian', '~fernet']],
-     'orange bitters': [['orange', 'bitter', '~bittersweet']],
-     'orange juice': [['orange', '~bitter', '~jam', '~marmalade', '~liqueur', '~water'],
-                      ['orange', 'squeeze']],
-     'pineapple juice': [['pineapple']],
-     # 'apple juice': [['apple', 'juice', '~pine']],
-     'cranberry juice': [['cranberry', 'juice']],
-     'cointreau': ['cointreau', 'triple sec', 'grand marnier', 'curaçao', 'curacao'],
-     'luxardo maraschino': ['luxardo', 'maraschino', 'kirsch'],
-     'amaretto': ['amaretto'],
-     'benedictine': ['benedictine', 'bénédictine', 'bénedictine', 'benédictine'],
-     'campari': ['campari', ['italian', 'red', 'bitter'], 'aperol', 'bittersweet', 'aperitivo', 'orange-red'],
-     # 'campari': ['campari', ['italian', 'red', 'bitter']],
-     # 'crème de violette': [['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']],
-     # 'aperol': ['aperol', 'bittersweet', 'aperitivo', 'orange-red'],
-     'green chartreuse': ['chartreuse'],
-     'black raspberry liqueur': [['cassis', 'liqueur'],
-                                 ['black raspberry', 'liqueur'],
-                                 ['raspberry', 'liqueur'],
-                                 ['strawberry', 'liqueur'],
-                                 ['blackberry', 'liqueur'],
-                                 ['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']],
-     # 'simple syrup': [],
-     # 'drambuie': ['drambuie'],
-     # 'fernet branca': ['fernet', 'branca'],
-     'gin': [['gin', '~sloe', '~ginger']],
-     'vodka': ['vodka'],
-     'cuban rum': [['rum', 'puerto rican'], ['light', 'rum'], ['white', 'rum'], ['rum', 'havana', '~7'], ['rum', 'bacardi']],
-     'cognac': [['cognac', '~grand marnier', '~cointreau', '~orange']],
-     # 'bourbon': [['bourbon', '~liqueur']],
-     # 'tequila': ['tequila', 'pisco'],
-     # 'tequila': ['tequila'],
-     'scotch': ['scotch'],
-     'dark rum': [['rum', 'age', '~bacardi', '~havana'],
-                  ['rum', 'dark', '~bacardi', '~havana'],
-                  ['rum', 'old', '~bacardi', '~havana'],
-                  ['rum', 'old', '7'],
-                  ['rum', 'havana', '7'],
-                  ['havana', 'rum', 'especial']],
-     'absinthe': ['absinthe'],
-     'rye whiskey': ['rye', ['bourbon', '~liqueur']],
-     # 'rye whiskey': ['rye'],
-     'apricot brandy': [['apricot', 'brandy']],
-     # 'pisco': ['pisco'],
-     # 'cachaça': ['cachaça', 'cachaca'],
-     'egg': [['egg', 'white', '~yolk', '~whole']],
-     'soda': [['soda', 'water', '~lemon', '~lime']],
-     'mint': ['mint'],
-     'sparkling wine': ['sparkling wine', 'prosecco', 'champagne'],
-     'ginger beer': [['ginger', 'beer'], ['ginger', 'ale']],
-     'tonic': [['tonic'], ['7up'], ['sprite']],
-     # 'espresso': ['espresso', 'expresso', ['café', '~liqueur', '~cream'],
-     #              ['cafe', '~liqueur', '~cream'],
-     #              ['coffee', '~liqueur', '~cream']],
-     # 'southern comfort': ['southern comfort'],
-     # 'cola': ['cola', 'coke', 'pepsi'],
-     'double syrup': [['sugar', '~raspberry'], ['simple', 'syrup'], ['double', 'syrup']],
-     # 'grenadine': ['grenadine', ['pomegranate', 'syrup']],
-     'grenadine': ['grenadine', ['pomegranate', 'syrup'], ['raspberry', 'syrup', '~black']],
-     'honey syrup': ['honey', ['maple', 'syrup']],
-     # 'raspberry syrup': [['raspberry', 'syrup', '~black']],
-     'dry vermouth': [['vermouth', 'dry'], ['vermouth', 'white'], ['vermouth', 'french'], 'lillet'],
-     'sweet vermouth': [['vermouth', 'sweet'], ['vermouth', 'red'], ['vermouth', 'italian']],
-     # 'lillet blanc': ['lillet'],
-     'water': [['water', '~sugar', '~coconut', '~soda', '~tonic', '~honey', '~orange', '~melon']]
- }
- # check that there is a rule for every ingredient in the list
- assert sorted(ingredient_list) == sorted(ingredient_search.keys()), 'ing search dict keys do not match ingredient list'
-
- def get_ingredients_info():
-     data = pd.read_csv(COCKTAILS_CSV_DATA)
-     max_ingredients, ingredient_set, liquor_set, liqueur_set, vermouth_set = get_max_n_ingredients(data)
-     ingredient_list = sorted(ingredient_set)
-     alcohol = sorted(liquor_set.union(liqueur_set).union(vermouth_set).union(set(['sparkling wine'])))
-     ind_alcohol = [i for i in range(len(ingredient_list)) if ingredient_list[i] in alcohol]
-     return max_ingredients, ingredient_list, ind_alcohol
-
- def get_max_n_ingredients(data):
-     max_count = 0
-     ingredient_set = set()
-     alcohol_set = set()
-     liqueur_set = set()
-     vermouth_set = set()
-     ing_str = np.array(data['ingredients_str'])
-     for i in range(len(data['names'])):
-         ingredients, quantities = extract_ingredients(ing_str[i])
-         max_count = max(max_count, len(ingredients))
-         for ing in ingredients:
-             ingredient_set.add(ing)
-             if ing in ingredients_per_type['liquor']:
-                 alcohol_set.add(ing)
-             if ing in ingredients_per_type['liqueur']:
-                 liqueur_set.add(ing)
-             if ing in ingredients_per_type['vermouth']:
-                 vermouth_set.add(ing)
-     return max_count, ingredient_set, alcohol_set, liqueur_set, vermouth_set
-
- def find_ingredient_from_str(ing_str):
-     # assigns an ingredient string to one of the listed ingredients if possible, following the rules defined above.
-     # returns (flag, ingredient): flag is True when matching is not clean (ambiguous or no match);
-     # a None ingredient means nothing matched, in which case the cocktail is rejected.
-     ing_str = ing_str.lower()
-     flags = []
-     for k in ingredient_list:
-         or_flags = []  # one flag per alternative rule
-         for pattern in ingredient_search[k]:
-             or_flags.append(True)
-             if isinstance(pattern, str):
-                 if pattern[0] == '~' and pattern[1:] in ing_str:
-                     or_flags[-1] = False
-                 elif pattern[0] != '~' and pattern not in ing_str:
-                     or_flags[-1] = False
-             elif isinstance(pattern, list):
-                 for element in pattern:
-                     if element[0] == '~':
-                         or_flags[-1] = or_flags[-1] and element[1:] not in ing_str
-                     else:
-                         or_flags[-1] = or_flags[-1] and element in ing_str
-             else:
-                 raise ValueError
-         flags.append(any(or_flags))
-     if sum(flags) > 1:  # ambiguous: several ingredients match
-         print(ing_str)
-         for i_f, f in enumerate(flags):
-             if f:
-                 print(ingredient_list[i_f])
-         return True, ingredient_list[flags.index(True)]
-     elif sum(flags) == 0:  # no match
-         return True, None
-     else:  # exactly one match
-         return False, ingredient_list[flags.index(True)]
-
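A plausible call on a scraped line — 'lime' appears and none of 'soda', 'lemonade', 'cordial' do, so the match is unique:

flag, ing = find_ingredient_from_str('2 oz freshly squeezed lime juice')
# -> flag == False (exactly one match), ing == 'lime juice'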
- def get_cocktails_per_ingredient(ing_strs):
-     cocktails_per_ing = dict(zip(ingredient_list, [[] for _ in range(len(ingredient_list))]))
-     for i_ing, ing_str in enumerate(ing_strs):
-         ingredients, _ = extract_ingredients(ing_str)
-         for ing in ingredients:
-             cocktails_per_ing[ing].append(i_ing)
-     return cocktails_per_ing
-
- def extract_ingredients(ingredient_str):
-     # extract the lists of ingredients and quantities from a formatted ingredient string (reverse of format_ingredients)
-     ingredient_str = ingredient_str[1:-1]  # strip the enclosing brackets
-     words = ingredient_str.split(',')
-     ingredients = []
-     quantities = []
-     for i in range(len(words) // 2):
-         ingredients.append(words[2 * i][1:])              # drop the opening parenthesis
-         quantities.append(float(words[2 * i + 1][:-1]))   # drop the closing parenthesis
-     return ingredients, quantities
-
- def format_ingredients(ingredients, quantities):
-     # format an ingredient string from the lists of ingredients and quantities (reverse of extract_ingredients)
-     out = '['
-     for ing, q in zip(ingredients, quantities):
-         ingre = ing[:-1] if ing[-1] == ' ' else ing
-         out += f'({ingre},{q}),'
-     out = out[:-1] + ']'
-     return out
-
- def get_ingredient_count(data):
-     # count ingredient occurrences over the whole dataset
-     ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list)))
-     for i in range(len(data['names'])):
-         if data['to_keep'][i]:
-             ingredients, _ = extract_ingredients(data['ingredients_str'][i])
-             for ing in ingredients:
-                 ingredient_counts[ing] += 1
-     return ingredient_counts
-
- def add_counts_to_ingredient_list(data):
-     # update the ingredient list with each ingredient's occurrence count in the dataset
-     ingredient_counts = get_ingredient_count(data)
-     counts = [ingredient_counts[k] for k in ingredient_list]
-     ingredient_profiles['counts'] = counts
-     ingredient_profiles.to_csv(INGREDIENTS_LIST_PATH, index=False)
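Round-trip example of the '[(ingredient,quantity),...]' serialization used throughout the dataset:

s = format_ingredients(['gin', 'lime juice'], [50.0, 25.0])   # '[(gin,50.0),(lime juice,25.0)]'
ings, qs = extract_ingredients(s)                             # (['gin', 'lime juice'], [50.0, 25.0])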
 
src/cocktails/utilities/other_scrubbing_utilities.py DELETED
@@ -1,240 +0,0 @@
- import numpy as np
- import pickle
- from src.cocktails.utilities.cocktail_utilities import get_profile, profile_keys
- from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles
- from src.cocktails.utilities.glass_and_volume_utilities import glass_volume, volume_ranges
-
- # unit conversions to mL
- one_dash = 1
- one_splash = 6
- one_tablespoon = 15
- one_barspoon = 5
- fill_rate = 0.8
- quantity_factors = {'ml': 1,
-                     'cl': 10,
-                     'splash': one_splash,
-                     'splashes': one_splash,
-                     'dash': one_dash,
-                     'dashes': one_dash,
-                     'spoon': one_barspoon,
-                     'spoons': one_barspoon,
-                     'tablespoon': one_tablespoon,
-                     'tablespoons': one_tablespoon,
-                     'barspoon': one_barspoon,
-                     'barspoons': one_barspoon,
-                     'bar spoon': one_barspoon,
-                     'bar spoons': one_barspoon,
-                     'teaspoon': 5,
-                     'teaspoons': 5,
-                     'drop': 0.05,
-                     'drops': 0.05}
- # sort unit names by decreasing length so that longer units match before their substrings
- quantity_keys = sorted(quantity_factors.keys())
- indexes_keys = np.flip(np.argsort([len(k) for k in quantity_keys]))
- quantity_factors_keys = list(np.array(quantity_keys)[indexes_keys])
-
- keys_to_track = ['names', 'urls', 'glass', 'garnish', 'recipe', 'how_to', 'review', 'taste_rep', 'valid']
- keys_to_add = ['category', 'subcategory', 'ingredients_str', 'ingredients', 'quantities', 'to_keep']
- keys_to_update = ['glass']
- keys_for_csv = ['names', 'category', 'subcategory', 'ingredients_str', 'urls', 'glass', 'garnish', 'how_to', 'review', 'taste_rep'] + profile_keys
-
- to_replace_q = {' fresh': ''}
- to_replace_ing = {'maple syrup': 'honey syrup',
-                   'agave syrup': 'honey syrup',
-                   'basil': 'mint'}
-
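Because quantity_factors_keys is sorted by decreasing length, a scraped '2 bar spoons' resolves against 'bar spoons' (2 * 5 = 10 mL) before the shorter 'spoons' can match.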
- def print_recipe(unit='mL', ingredient_str=None, ingredients=None, quantities=None, name='', cat='', to_print=True):
-     str_out = ''
-     if ingredient_str is None:
-         assert len(ingredients) == len(quantities), 'provide either ingredient_str, or the ingredients and quantities lists'
-     else:
-         assert ingredients is None and quantities is None, 'provide either ingredient_str, or the ingredients and quantities lists'
-         ingredients, quantities = extract_ingredients(ingredient_str)
-
-     str_out += '\nRecipe:'
-     if name != '' and name is not None: str_out += f' {name}'
-     if cat != '': str_out += f' ({cat})'
-     str_out += '\n'
-     for i in range(len(ingredients)):
-         # pick the quantifier
-         if ingredients[i] == 'egg':
-             quantities[i] = 1
-             ingredients[i] = 'egg white'
-             if unit == 'mL':
-                 quantifier = ' (30 mL)'
-             elif unit == 'oz':
-                 quantifier = ' (1 fl oz)'
-             else:
-                 raise ValueError
-         elif ingredients[i] in ['angostura', 'orange bitters']:
-             quantities[i] = max(1, int(quantities[i] / 0.6))  # one dash is about 0.6 mL
-             quantifier = ' dash'
-             if quantities[i] > 1: quantifier += 'es'
-         elif ingredients[i] == 'mint':
-             quantifier = ' leaves' if quantities[i] > 1 else ' leaf'
-         else:
-             if unit == 'oz':
-                 quantities[i] = float(f'{quantities[i] * 0.033814:.3f}')  # convert mL to fl oz
-                 quantifier = ' fl oz'
-             else:
-                 quantifier = ' mL'
-         str_out += f' {quantities[i]}{quantifier} - {ingredients[i]}\n'
-
-     if to_print:
-         print(str_out)
-     return str_out
-
87
- def test_datapoint(datapoint, category, ingredients, quantities):
88
- # run checks
89
- ingredient_indexes = [ingredient_list.index(ing) for ing in ingredients]
90
- profile = get_profile(category, ingredients, quantities)
91
- volume = profile['end volume']
92
- alcohol = profile['end alcohol']
93
- acid = profile['end acid']
94
- sugar = profile['end sugar']
95
- # check volume
96
- if datapoint['glass'] != None:
97
- if volume > glass_volume[datapoint['glass']] * fill_rate:
98
- # recompute quantities for it to match
99
- ratio = fill_rate * glass_volume[datapoint['glass']] / volume
100
- for i_q in range(len(quantities)):
101
- quantities[i_q] = float(f'{quantities[i_q] * ratio:.2f}')
102
- # check alcohol
103
- assert alcohol < 30, 'too boozy'
104
- assert alcohol < 5, 'not boozy enough'
105
- assert acid < 2, 'too much acid'
106
- assert sugar < 20, 'too much sugar'
107
- assert len(ingredients) > 1, 'only one ingredient'
108
- if len(set(ingredients)) != len(ingredients):
109
- i_doubles = []
110
- s_ing = set()
111
- for i, ing in enumerate(ingredients):
112
- if ing in s_ing:
113
- i_doubles.append(i)
114
- else:
115
- s_ing.add(ing)
116
- ingredient_double_ok = ['mint', 'cointreau', 'lemon juice', 'cuban rum', 'double syrup']
117
- if len(i_doubles) == 1 and ingredients[i_doubles[0]] in ingredient_double_ok:
118
- ing_double = ingredients[i_doubles[0]]
119
- double_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == ing_double])
120
- ingredients.pop(i_doubles[0])
121
- quantities.pop(i_doubles[0])
122
- quantities[ingredients.index(ing_double)] = double_q
123
- else:
124
- assert False, f'double ingredient, not {ingredient_double_ok}'
125
- lemon_lime_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['lime juice', 'lemon juice']])
126
- assert lemon_lime_q <= 45, 'too much lemon and lime'
127
- salt_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'salt'])
128
- assert salt_q <= 8, 'too much salt'
129
- bitter_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['angostura', 'orange bitters']])
130
- assert bitter_q <= 5 * one_dash, 'too much bitter'
131
- absinthe_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'absinthe'])
132
- if absinthe_q > 4 * one_dash:
133
- mix_volume = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
134
- assert absinthe_q < 0.5 * mix_volume, 'filter absinthe glasses'
135
- if any([w in datapoint['how_to'] or any([w in ing.lower() for ing in datapoint['recipe'][1]]) for w in ['warm', 'boil', 'hot']]) and 'shot' not in datapoint['how_to']:
136
- assert False, 'hot preparation detected'
137
- water_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'water'])
138
- assert water_q < 40, 'too much water'
139
- # n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
140
- # assert n_liqueur <= 2
141
- n_liqueur_and_vermouth = np.sum([ingredient_profiles['type'][i].lower() in ['liqueur', 'vermouth'] for i in ingredient_indexes])
142
- assert n_liqueur_and_vermouth <= 3
143
- return ingredients, quantities
144
-
145
- def run_battery_checks_difford(datapoint, category, ingredients, quantities):
146
- flag = False
147
- try:
148
- ingredients, quantities = test_datapoint(datapoint, category, ingredients, quantities)
149
- except:
150
- flag = True
151
- print(datapoint["names"])
152
- print(datapoint["urls"])
153
- ingredients, quantities = None, None
154
-
155
- return flag, ingredients, quantities
156
-
157
- def tambouille(q, ingredients_scrubbed, quantities_scrubbed, cat):
158
- # ugly
159
- ing_scrubbed = ingredients_scrubbed[len(quantities_scrubbed)]
160
- if q == '4 cube' and ing_scrubbed == 'pineapple juice':
161
- q = '20 ml'
162
- elif 'top up with' in q:
163
- volume_so_far = np.sum([quantities_scrubbed[i] for i in range(len(quantities_scrubbed)) if ingredients_scrubbed[i] != 'mint'])
164
- volume_mix = np.sum(volume_ranges[cat]) / 2
165
- if (volume_mix - volume_so_far) < 15:
166
- q = '15 ml'
167
- else:
168
- q = str(int(volume_mix - volume_so_far)) + ' ml'
169
- elif q == '1 pinch' and ing_scrubbed == 'salt':
170
- q = '2 drops'
171
- elif 'cube' in q and ing_scrubbed == 'double syrup':
172
- q = f'{float(q.split(" ")[0]) * 2 * 1.7:.2f} ml' #2g per cube, 1.7 is ratio solid / syrup
173
- elif 'wedge' in q:
174
- if ing_scrubbed == 'orange juice':
175
- vol = 70
176
- elif ing_scrubbed == 'lime juice':
177
- vol = 30
178
- elif ing_scrubbed == 'lemon juice':
179
- vol = 45
180
- elif ing_scrubbed == 'pineapple juice':
181
- vol = 140
- else:
- raise ValueError(f'no wedge volume for {ing_scrubbed}')
182
- factor = float(q.split(' ')[0]) * 0.15 # consider a wedge to be 0.15*the fruit.
183
- q = f'{factor * vol:.2f} ml'
184
- elif 'slice' in q:
185
- if ing_scrubbed == 'orange juice':
186
- vol = 70
187
- elif ing_scrubbed == 'lime juice':
188
- vol = 30
189
- elif ing_scrubbed == 'lemon juice':
190
- vol = 45
191
- elif ing_scrubbed == 'pineapple juice':
192
- vol = 140
- else:
- raise ValueError(f'no slice volume for {ing_scrubbed}')
193
- f = q.split(' ')[0]
194
- if len(f.split('⁄')) > 1:
195
- frac = f.split('⁄')
196
- factor = float(frac[0]) / float(frac[1])
197
- else:
198
- factor = float(f)
199
- factor *= 0.1 # consider a slice to be 0.1*the fruit.
200
- q = f'{factor * vol:.2f} ml'
201
- elif q == '1 whole' and ing_scrubbed == 'luxardo maraschino':
202
- q = '10 ml'
203
- elif ing_scrubbed == 'egg' and 'ml' not in q:
204
- q = f'{float(q) * 30:.2f} ml' # 30 ml per egg
205
- return q
206
-
207
-
208
- def compute_eucl_dist(a, b):
209
- return np.sqrt(np.sum((a - b)**2))
210
-
211
- def evaluate_with_quadruplets(representations, strategy='all'):
212
- with open(QUADRUPLETS_PATH, 'rb') as f:
213
- data = pickle.load(f)
214
- data = list(data.values())
215
- quadruplets = []
216
- if strategy == 'all':
217
- for d in data:
218
- quadruplets.append(d[1:])
219
- else:
220
- # keep only quadruplets collected with the requested strategy
221
- for d in data:
222
- if d[0] == strategy:
223
- quadruplets.append(d[1:])
225
-
226
- scores = []
227
- for q in quadruplets:
228
- close = q[0]
229
- if len(close) == 2:
230
- far = q[1]
231
- distance_close = compute_eucl_dist(representations[close[0]], representations[close[1]])
232
- distances_far = [compute_eucl_dist(representations[far[i][0]], representations[far[i][1]]) for i in range(len(far))]
233
- scores.append(distance_close < np.min(distances_far))
234
- if len(scores) == 0:
235
- score = np.nan
236
- else:
237
- score = np.mean(scores)
238
- return score
239
-
240
-
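
A minimal usage sketch of evaluate_with_quadruplets (hypothetical: the shape of `representations` and the presence of a quadruplets pickle at QUADRUPLETS_PATH are assumptions):

    import numpy as np
    representations = np.random.randn(100, 40)  # placeholder cocktail embeddings, indexed like the quadruplets
    score = evaluate_with_quadruplets(representations, strategy='all')
    print(f'fraction of quadruplets where the close pair is the closest: {score:.2f}')
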
src/debugger.py DELETED
@@ -1,180 +0,0 @@
1
- import os.path
2
-
3
- # from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob
4
- from src.music.utils import load_audio
5
- from src.music.config import FPS
6
- import pretty_midi as pm
7
- import numpy as np
8
- from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH
9
- from sklearn.neighbors import NearestNeighbors
10
- from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
11
- # from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers
12
- from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
13
- from src.music.utils import get_all_subfiles_with_extension
14
- import os
15
- import pickle
16
- import pandas as pd
17
- import time
18
-
19
- keyword = 'b256_r128_represented'
20
- def load_reps(rep_path, sample_size=None):
21
- if sample_size:
22
- with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
23
- data = pickle.load(f)
24
- else:
25
- with open(rep_path + f'music_reps_unnormalized.pickle', 'rb') as f:
26
- data = pickle.load(f)
27
- reps = data['reps']
28
- # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
29
- playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
30
- n_data, dim_data = reps.shape
31
- return reps, data['paths'], playlists, n_data, dim_data
32
-
33
- class Debugger():
34
- def __init__(self, verbose=True):
35
-
36
- if verbose: print('Setting up debugger.')
37
- if not os.path.exists(MUSIC_NN_PATH):
38
- reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle'
39
- if not os.path.exists(reps_path):
40
- all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt', current_depth=0)
41
- all_data = []
42
- new_all_rep_path = []
43
- for i_r, r in enumerate(all_rep_path):
44
- if 'mean_std' not in r:
45
- all_data.append(np.loadtxt(r))
46
- assert len(all_data[-1]) == 128
47
- new_all_rep_path.append(r)
48
- data = np.array(all_data)
49
- to_save = dict(reps=data,
50
- paths=new_all_rep_path)
51
- with open(reps_path, 'wb') as f:
52
- pickle.dump(to_save, f)
53
-
54
- reps, self.rep_paths, playlists, n_data, self.dim_rep_music = load_reps(MUSIC_REP_PATH)
55
- self.nn_model_music = NearestNeighbors(n_neighbors=6, metric='cosine')
56
- self.nn_model_music.fit(reps)
57
- to_save = dict(nn_model=self.nn_model_music,
58
- rep_paths=self.rep_paths,
59
- dim_rep_music=self.dim_rep_music)
60
- with open(MUSIC_NN_PATH, 'wb') as f:
61
- pickle.dump(to_save, f)
62
- else:
63
- with open(MUSIC_NN_PATH, 'rb') as f:
64
- data = pickle.load(f)
65
- self.nn_model_music = data['nn_model']
66
- self.rep_paths = data['rep_paths']
67
- self.dim_rep_music = data['dim_rep_music']
68
- if verbose: print(f' {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}')
69
- self.rep_paths = np.array(self.rep_paths)
70
- if not os.path.exists(COCKTAIL_NN_PATH):
71
- cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
72
- # cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0)
73
- self.nn_model_cocktail = NearestNeighbors(n_neighbors=6)
74
- self.nn_model_cocktail.fit(cocktail_reps)
75
- self.dim_rep_cocktail = cocktail_reps.shape[1]
76
- self.n_cocktails = cocktail_reps.shape[0]
77
- to_save = dict(nn_model=self.nn_model_cocktail,
78
- dim_rep_cocktail=self.dim_rep_cocktail,
79
- n_cocktails=self.n_cocktails)
80
- with open(COCKTAIL_NN_PATH, 'wb') as f:
81
- pickle.dump(to_save, f)
82
- else:
83
- with open(COCKTAIL_NN_PATH, 'rb') as f:
84
- data = pickle.load(f)
85
- self.nn_model_cocktail = data['nn_model']
86
- self.dim_rep_cocktail = data['dim_rep_cocktail']
87
- self.n_cocktails = data['n_cocktails']
88
- if verbose: print(f' {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}')
89
-
90
- self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
91
- # self.affective_cluster_centers = get_affective_cluster_centers()
92
- self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls',
93
- 'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len', 'piano_solo_prob', 'recipe_score', 'cocktail_rep']
94
- # 'affect', 'affective_cluster_id', 'affective_cluster_center',
95
-
96
-
97
- def get_nearest_songs(self, music_rep):
98
- dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1))
99
- indexes = indexes.flatten()[:5]
100
- rep_paths = [r.split('/')[-1] for r in self.rep_paths[indexes[:5]]]
101
- return rep_paths, dists.flatten().tolist()
102
-
103
- def get_nearest_cocktails(self, cocktail_rep):
104
- dists, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1))
105
- indexes = indexes.flatten()
106
- nn_names = np.array(self.cocktail_data['names'])[indexes].tolist()
107
- nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist()
108
- nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in np.array(self.cocktail_data['ingredients_str'])[indexes]]
109
- nn_ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes].tolist()
110
- return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs
111
-
112
- def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score, verbose=False, level=0):
113
- if verbose: print(' ' * level + 'Extracting debug info..')
114
- init_time = time.time()
115
- debug_dict = dict()
116
- debug_dict['all_paths'] = all_paths
117
- debug_dict['recipe_score'] = recipe_score
118
-
119
- if all_paths['audio_path'] != None:
120
- # is it piano?
121
- debug_dict['piano_solo_prob'] = None#float(calculate_piano_solo_prob(all_paths['audio_path'])[0])
122
- # how long is the audio
123
- (audio, _) = load_audio(all_paths['audio_path'], sr=FPS, mono=True)
124
- debug_dict['audio_len'] = int(len(audio) / FPS)
125
- else:
126
- debug_dict['piano_solo_prob'] = None
127
- debug_dict['audio_len'] = None
128
-
129
- # how many notes?
130
- midi = pm.PrettyMIDI(all_paths['processed_path'])
131
- debug_dict['nb_notes'] = len(midi.instruments[0].notes)
132
-
133
- # dimension of music rep
134
- representation = np.loadtxt(all_paths['representation_path'])
135
- debug_dict['dim_rep'] = representation.shape[0]
136
-
137
- # closest songs in dataset
138
- debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation)
139
-
140
- # get affective cluster info
141
- # debug_dict['affective_cluster_id'] = affective_cluster_id[0]
142
- # debug_dict['affective_cluster_center'] = self.affective_cluster_centers[affective_cluster_id].flatten().tolist()
143
- # debug_dict['affect'] = affect.flatten().tolist()
144
- indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep)
145
- debug_dict['cocktail_rep'] = cocktail_rep.copy().tolist()
146
- debug_dict['nearest_cocktail_indexes'] = indexes.tolist()
147
- debug_dict['nn_ing_strs'] = nn_ing_strs
148
- debug_dict['nearest_cocktail_names'] = nn_names
149
- debug_dict['nearest_cocktail_urls'] = nn_urls
150
- debug_dict['nearest_cocktail_recipes'] = nn_recipes
151
-
152
- debug_dict['music_reconstruction'] = music_reconstruction.tolist()
153
- debug_dict['mse_reconstruction'] = ((music_reconstruction - representation) ** 2).mean()
154
- self.debug_dict = debug_dict
155
- if verbose: print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.')
156
-
157
- return self.debug_dict
158
-
159
- def print_debug(self, level=0):
160
- print(' ' * level + '__DEBUGGING INFO__')
161
- for k in self.keys_to_print:
162
- to_print = self.debug_dict[k]
163
- if k == 'nearest_cocktail_recipes':
164
- to_print = self.debug_dict[k].copy()
165
- for i in range(len(to_print)):
166
- to_print[i] = to_print[i].replace('\n', '').replace('\t', '').replace('()', '')
167
- if k == "nn_music":
168
- to_print = self.debug_dict[k].copy()
169
- for i in range(len(to_print)):
170
- to_print[i] = to_print[i].replace('encoded_new_structured_', '').replace('_represented.txt', '')
171
- to_print_str = f'{to_print}'
172
- if isinstance(to_print, float):
173
- to_print_str = f'{to_print:.2f}'
174
- elif isinstance(to_print, list):
175
- if isinstance(to_print[0], float):
176
- to_print_str = '['
177
- for element in to_print:
178
- to_print_str += f'{element:.2f}, '
179
- to_print_str = to_print_str[:-2] + ']'
180
- print(' ' * (level + 2) + f'{k} : ' + to_print_str)
 
src/music/__init__.py DELETED
File without changes
src/music/config.py DELETED
@@ -1,72 +0,0 @@
1
- import numpy as np
2
- import os
3
-
4
- REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/'
5
- AUDIO_PATH = REPO_PATH + 'data/music/audio/'
6
- MIDI_PATH = REPO_PATH + 'data/music/midi/'
7
- MUSIC_PATH = REPO_PATH + 'data/music/'
8
- PROCESSED_PATH = REPO_PATH + 'data/music/processed/'
9
- ENCODED_PATH = REPO_PATH + 'data/music/encoded/'
10
- HANDCODED_REP_PATH = MUSIC_PATH + 'handcoded_reps/'
11
- DATASET_PATH = REPO_PATH + 'data/music/encoded_new_structured/diverse_piano/'
12
- SYNTH_RECORDED_AUDIO_PATH = AUDIO_PATH + 'synth_audio_recorded/'
13
- SYNTH_RECORDED_MIDI_PATH = MIDI_PATH + 'synth_midi_recorded/'
14
- CHECKPOINTS_PATH = REPO_PATH + 'checkpoints/'
15
- EXPERIMENT_PATH = REPO_PATH + 'experiments/'
16
- SEED = 0
17
-
18
- # params for data download
19
- ALL_URL_PATH = REPO_PATH + 'data/music/audio/all_urls.pickle'
20
- ALL_FAILED_URL_PATH = REPO_PATH + 'data/music/audio/all_failed_urls.pickle'
21
- RATE_AUDIO_SAVE = 16000
22
- FROM_URL_PATH = AUDIO_PATH + 'from_url/'
23
-
24
- # params transcription
25
- CHKPT_PATH_TRANSCRIPTION = REPO_PATH + 'checkpoints/piano_transcription/note_F1=0.9677_pedal_F1=0.9186.pth' # transcriptor chkpt path
26
- FPS = 16000
27
- RANDOM_CROP = True # whether to use random crops in case of cropped audio
28
- CROP_LEN = 26 * 60
29
-
30
- # params midi scrubbing and processing
31
- MAX_DEPTH = 5 # max depth when searching in folders for audio files
32
- MAX_GAP_IN_SONG = 10 # in secs
33
- MIN_LEN = 20 # actual min len could go down to MIN_LEN - 2 * (REMOVE_FIRST_AND_LAST / 5)
34
- MAX_LEN = 25 * 60 # maximum audio len for playlist downloads, and maximum audio length for transcription (in sec)
35
- MIN_NB_NOTES = 80 # min nb of notes per minute of recording
36
- REMOVE_FIRST_AND_LAST = 10 # will be divided by 5 if cutting this makes the song fall below min len
37
-
38
- # parameters encoding
39
- NOISE_INJECTED = True
40
- AUGMENTATION = True
41
- NB_AUG = 4 if AUGMENTATION else 0
42
- RANGE_NOTE_ON = 128
43
- RANGE_NOTE_OFF = 128
44
- RANGE_VEL = 32
45
- RANGE_TIME_SHIFT = 100
46
- MAX_EMBEDDING = RANGE_VEL + RANGE_NOTE_OFF + RANGE_TIME_SHIFT + RANGE_NOTE_ON
47
- MAX_TEST_SIZE = 1000
48
- CHECKSUM_PATH = REPO_PATH + 'data/music/midi/checksum.pickle'
49
- CHUNK_SIZE = 512
50
-
51
- ALL_AUGMENTATIONS = []
52
- for p in [-3, -2, -1, 1, 2, 3]:
53
- ALL_AUGMENTATIONS.append((p))
54
- ALL_AUGMENTATIONS = np.array(ALL_AUGMENTATIONS)
55
-
56
- ALL_NOISE = []
57
- for s in [-5, -2.5, 0, 2.5, 5]:
58
- for p in np.arange(-6, 7):
59
- if not ((s == 0) and (p==0)):
60
- ALL_NOISE.append((s, p))
61
- ALL_NOISE = np.array(ALL_NOISE)
62
-
63
- # music transformer params
64
- REP_MODEL_NAME = REPO_PATH + "checkpoints/music_representation/sentence_embedding/smallbert_b256_r128_1/best_model"
65
- MUSIC_REP_PATH = REPO_PATH + "checkpoints/b256_r128_represented/"
66
- MUSIC_NN_PATH = REPO_PATH + "checkpoints/music_representation/b256_r128_represented/nn_model.pickle"
67
-
68
- TRANSLATION_VAE_CHKP_PATH = REPO_PATH + "checkpoints/music2cocktails/music2flavor/b256_r128_classif001_ld40_meanstd_regground2.5_egg_bubbles/"
69
-
70
- # piano solo evaluation
71
- # META_DATA_PIANO_EVAL_PATH = REPO_PATH + 'data/music/audio/is_piano.csv'
72
- # CHKPT_PATH_PIANO_EVAL = REPO_PATH + 'data/checkpoints/piano_detection/piano_solo_model_32k.pth'
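
A quick sanity check on the grids defined above: 5 onset-shift values times 13 pitch values, minus the excluded (0, 0) pair, gives 64 noise settings, and there are 6 pitch augmentations.

    assert ALL_NOISE.shape == (64, 2)
    assert ALL_AUGMENTATIONS.shape == (6,)
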
src/music/pipeline/__init__.py DELETED
File without changes
src/music/pipeline/audio2midi.py DELETED
@@ -1,52 +0,0 @@
1
- import torch
2
- import piano_transcription_inference
3
- import numpy as np
4
- import os
5
- import sys
6
- sys.path.append('../../')
7
- from src.music.utils import get_out_path, load_audio
8
- from src.music.config import CHKPT_PATH_TRANSCRIPTION, FPS, MIN_LEN, CROP_LEN
9
- # import librosa
10
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
- TRANSCRIPTOR = piano_transcription_inference.PianoTranscription(device=device,
12
- checkpoint_path=CHKPT_PATH_TRANSCRIPTION)
13
-
14
- def audio2midi(audio_path, midi_path=None, crop=CROP_LEN, random_crop=True, verbose=False, level=0):
15
- if verbose and crop < MIN_LEN + 2:
16
- print('crop is inferior to the minimal length of a tune')
17
- assert '.mp3' == audio_path[-4:]
18
- if midi_path is None:
19
- midi_path, _, _ = get_out_path(in_path=audio_path, in_word='audio', out_word='midi', out_extension='.mid')
20
-
21
- if verbose: print(' ' * level + f'Transcribing {audio_path}.')
22
- if os.path.exists(midi_path):
23
- if verbose: print(' ' * (level + 2) + 'Midi file already exists.')
24
- return midi_path, ''
25
-
26
- error_msg = 'Error in transcription. '
27
- try:
28
- error_msg += 'Maybe in audio loading?'
29
- (audio, _) = load_audio(audio_path,
30
- sr=FPS,
31
- mono=True)
32
- error_msg += ' Nope. Cropping?'
33
- if isinstance(crop, int) and len(audio) > FPS * crop:
34
- rc_str = ' (random crop)' if random_crop else ' (start crop)'
35
- if verbose: print(' ' * (level + 2) + f'Cropping the song to {crop}s before transcription{rc_str}. ')
36
- size_crop = FPS * crop
37
- if random_crop:
38
- index_begining = np.random.randint(len(audio) - size_crop - 1)
39
- else:
40
- index_begining = 0
41
- audio = audio[index_begining: index_begining + size_crop]
42
- error_msg += ' Nope. Transcription?'
43
- TRANSCRIPTOR.transcribe(audio, midi_path)
44
- error_msg += ' Nope.'
45
- extra = f' Saved to {midi_path}' if midi_path else ''
46
- if verbose: print(' ' * (level + 2) + f'Success! {extra}')
47
- return midi_path, ''
48
- except:
49
- if verbose: print(' ' * (level + 2) + 'Transcription failed.')
50
- if os.path.exists(midi_path):
51
- os.remove(midi_path)
52
- return None, error_msg + ' Yes.'
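
A hypothetical call sketch for the transcriptor above (the mp3 path is a placeholder):

    midi_path, err = audio2midi('data/music/audio/from_url/example.mp3',
                                crop=CROP_LEN, random_crop=False, verbose=True)
    if midi_path is None:
        print('transcription failed:', err)
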
src/music/pipeline/audio2piano_solo_prob.py DELETED
@@ -1,47 +0,0 @@
1
- import numpy as np
2
- import librosa
3
- import sys
4
- sys.path.append('../../../data/')
5
- from src.music.utilities.processing_models import piano_detection_model
6
- from src.music.config import CHKPT_PATH_PIANO_EVAL
7
-
8
- PIANO_SOLO_DETECTOR = piano_detection_model.PianoSoloDetector(CHKPT_PATH_PIANO_EVAL)
9
- exclude_playlist_folders = ['synth_audio_recorded', 'from_url']
10
-
11
- def clean_start_and_end_blanks(probs):
12
- if len(probs) > 20:
13
- # clean up to 10s in each direction
14
- n_zeros_start = 0
15
- for i in range(10):
16
- if probs[i] <= 0.001:
17
- n_zeros_start += 1
18
- else:
19
- break
20
- n_zeros_end = 0
21
- for i in range(10):
22
- if probs[-(i + 1)] <= 0.001:
23
- n_zeros_end += 1
24
- else:
25
- break
26
- if n_zeros_end == 0:
27
- return probs[n_zeros_start:]
28
- else:
29
- return probs[n_zeros_start:-n_zeros_end]
30
- else:
31
- return probs
32
-
33
- def calculate_piano_solo_prob(audio_path, verbose=False):
34
- """Calculate the piano solo probability of all downloaded mp3s, and append
35
- the probability to the meta csv file. Code from https://github.com/bytedance/GiantMIDI-Piano
36
- """
37
- try:
38
- error_msg = 'Error in audio loading?'
39
- (audio, _) = librosa.core.load(audio_path, sr=piano_detection_model.SR, mono=True)
40
- error_msg += ' Nope. Error in solo prediction?'
41
- probs = PIANO_SOLO_DETECTOR.predict(audio)
42
- # probs = clean_start_and_end_blanks(probs) # remove blanks at start and end (<=10s each way). If not piano, the rest of the song will be enough to tell.
43
- piano_solo_prob = np.mean(probs)
44
- error_msg += ' Nope. '
45
- return piano_solo_prob, ''
46
- except:
47
- return None, error_msg + 'Yes.'
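
A hypothetical call sketch (the path and the 0.5 decision threshold are assumptions):

    prob, err = calculate_piano_solo_prob('data/music/audio/from_url/example.mp3')
    if prob is not None and prob > 0.5:
        print(f'likely a piano solo (p={prob:.2f})')
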
src/music/pipeline/encoded2rep.py DELETED
@@ -1,88 +0,0 @@
1
- from src.music.utilities.representation_learning_utilities.constants import *
2
- from src.music.config import REP_MODEL_NAME
3
- from src.music.utils import get_out_path
4
- import pickle
5
- import numpy as np
6
- # from transformers import AutoModel, AutoTokenizer
7
- from torch import nn
8
- from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer
9
-
10
- class Argument(object):
11
- def __init__(self, adict):
12
- self.__dict__.update(adict)
13
-
14
- class RepModel(nn.Module):
15
- def __init__(self, model, model_name):
16
- super().__init__()
17
- if 't5' in model_name:
18
- self.model = model.get_encoder()
19
- else:
20
- self.model = model
21
- self.model.eval()
22
-
23
- def forward(self, inputs):
24
- with torch.no_grad():
25
- out = self.model(inputs, output_hidden_states=True)
26
- embeddings = out.hidden_states[-1]
27
- return torch.mean(embeddings[0], dim=0)
28
-
29
- # def get_trained_music_LM(model_name):
30
- # tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
31
- # model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name)
32
- #
33
- # return model, tokenizer
34
-
35
- def get_trained_sentence_embedder(model_name):
36
- model = SentenceTransformer(model_name)
37
- return model
38
-
39
- MODEL = get_trained_sentence_embedder(REP_MODEL_NAME)
40
-
41
- def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
42
- if not rep_path:
43
- rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt')
44
-
45
- error_msg = 'Error in music transformer mapping.'
46
- if verbose: print(' ' * level + 'Mapping to final music representations')
47
- try:
48
- error_msg += ' Error in encoded file loading?'
49
- with open(encoded_path, 'rb') as f:
50
- data = pickle.load(f)
51
- performance = [str(w) for w in data['main'] if w != 1]
52
- assert len(performance) % 5 == 0
53
- if len(performance) == 0:
54
- error_msg += " Error: No midi messages in primer file"
55
- assert False
56
- error_msg += ' Nope, error in tokenization?'
57
- perf = ' '.join(performance)
58
- # tokenized = torch.IntTensor(TOKENIZER.encode(perf)).unsqueeze(dim=0)
59
- error_msg += ' Nope. Maybe in performance encoding?'
60
- # reps = []
61
- # for i_chunk in range(min(tokenized.shape[1] // 510 - 1, 8)):
62
- # chunk_tokenized = tokenized[:, i_chunk * 510: (i_chunk + 1) * 510 + 2]
63
- # rep = MODEL(chunk_tokenized)
64
- # reps.append(rep.detach().numpy())
65
- # representation = np.mean(reps, axis=0)
66
- p = [int(p) for p in perf.split(' ')]
67
- print('PERF:', np.sum(p), perf)
68
- representation = MODEL.encode(perf)
69
- print('model weights sum: ', torch.sum(torch.Tensor([param.sum() for param in list(MODEL.parameters())])))
70
- print('reprep', representation)
71
- error_msg += ' Nope. Saving performance?'
72
- np.savetxt(rep_path, representation)
73
- error_msg += ' Nope.'
74
- if verbose: print(' ' * (level + 2) + 'Success.')
75
- if return_rep:
76
- return rep_path, representation, ''
77
- else:
78
- return rep_path, ''
79
- except:
80
- if verbose: print(' ' * (level + 2) + f'Failed with error: {error_msg}')
81
- if return_rep:
82
- return None, None, error_msg
83
- else:
84
- return None, error_msg
85
-
86
- if __name__ == "__main__":
87
- representation = encoded2rep("/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle")
88
- stop = 1
src/music/pipeline/midi2processed.py DELETED
@@ -1,152 +0,0 @@
1
- import time
2
- import os
3
- import sys
4
- sys.path.append('../../')
5
- import pretty_midi as pm
6
- import numpy as np
7
-
8
- from src.music.utils import get_out_path
9
- from src.music.config import MIN_LEN, MIN_NB_NOTES, MAX_GAP_IN_SONG, REMOVE_FIRST_AND_LAST
10
-
11
-
12
- def sort_notes(notes):
13
- starts = np.array([n.start for n in notes])
14
- index_sorted = np.argsort(starts)
15
- return [notes[i] for i in index_sorted].copy()
16
-
17
-
18
- def delete_notes_end_after_start(notes):
19
- indexes_to_keep = [i for i, n in enumerate(notes) if n.start < n.end]
20
- return [notes[i] for i in indexes_to_keep].copy()
21
-
22
- def compute_largest_gap(notes):
23
- gaps = []
24
- latest_note_end_so_far = notes[0].end
25
- for i in range(len(notes) - 1):
26
- note_start = notes[i + 1].start
27
- if latest_note_end_so_far < note_start:
28
- gaps.append(note_start - latest_note_end_so_far)
29
- latest_note_end_so_far = max(latest_note_end_so_far, notes[i+1].end)
30
- if len(gaps) > 0:
31
- largest_gap = np.max(gaps)
32
- else:
33
- largest_gap = 0
34
- return largest_gap
35
-
36
- def analyze_instrument(inst):
37
- # test that piano plays throughout
38
- init = time.time()
39
- notes = inst.notes.copy()
40
- nb_notes = len(notes)
41
- start = notes[0].start
42
- end = inst.get_end_time()
43
- duration = end - start
44
- largest_gap = compute_largest_gap(notes)
45
- return nb_notes, start, end, duration, largest_gap
46
-
47
- def remove_beginning_and_end(midi, end_time):
48
- notes = midi.instruments[0].notes.copy()
49
- new_notes = [n for n in notes if n.start > REMOVE_FIRST_AND_LAST and n.end < end_time - REMOVE_FIRST_AND_LAST]
50
- midi.instruments[0].notes = new_notes
51
- return midi
52
-
53
- def remove_blanks_beginning_and_end(midi):
54
- # remove blanks and the beginning and the end
55
- shift = midi.instruments[0].notes[0].start
56
- for n in midi.instruments[0].notes:
57
- n.start = max(0, n.start - shift)
58
- n.end = max(0, n.end - shift)
59
- for ksc in midi.key_signature_changes:
60
- ksc.time = max(0, ksc.time - shift)
61
- for tsc in midi.time_signature_changes:
62
- tsc.time = max(0, tsc.time - shift)
63
- for pb in midi.instruments[0].pitch_bends:
64
- pb.time = max(0, pb.time - shift)
65
- for cc in midi.instruments[0].control_changes:
66
- cc.time = max(0, cc.time - shift)
67
- return midi
68
-
69
- def is_valid_inst(largest_gap, duration, nb_notes, gap_counts=True):
70
- error_msg = ''
71
- valid = True
72
- if largest_gap > MAX_GAP_IN_SONG and gap_counts:
73
- valid = False
74
- error_msg += f'wide gap ({largest_gap:.2f} secs), '
75
- if duration < (MIN_LEN + 2 * REMOVE_FIRST_AND_LAST):
76
- valid = False
77
- error_msg += f'too short ({duration:.2f} secs), '
78
- if nb_notes < MIN_NB_NOTES * duration / 60: # the note count must be at least MIN_NB_NOTES per minute times the duration in minutes
79
- valid = False
80
- error_msg += f'too few notes ({nb_notes}), '
81
- return valid, error_msg
82
-
83
- def midi2processed(midi_path, processed_path=None, apply_filtering=True, verbose=False, level=0):
84
- assert midi_path.split('.')[-1] in ['mid', 'midi']
85
- if not processed_path:
86
- processed_path, _, _ = get_out_path(in_path=midi_path, in_word='midi', out_word='processed', out_extension='.mid')
87
-
88
- if verbose: print(' ' * level + f'Processing {midi_path}.')
89
-
90
- if os.path.exists(processed_path):
91
- if verbose: print(' ' * (level + 2) + 'Processed midi file already exists.')
92
- return processed_path, ''
93
- error_msg = 'Error in scrubbing. '
94
- # try:
95
- inst_error_msg = ''
96
- # load mid file
97
- error_msg += 'Error in midi loading?'
98
- midi = pm.PrettyMIDI(midi_path)
99
- error_msg += ' Nope. Removing invalid notes?'
100
- midi.remove_invalid_notes() # filter invalid notes
101
- error_msg += ' Nope. Filtering instruments?'
102
- # filter instruments
103
- instruments = midi.instruments.copy()
104
- new_instru = []
105
- instruments_data = []
106
- for i_inst, inst in enumerate(instruments):
107
- if inst.program <= 7 and not inst.is_drum and len(inst.notes) > 5:
108
- # inst is a piano
109
- # check data
110
- inst.notes = sort_notes(inst.notes) # sort notes
111
- inst.notes = delete_notes_end_after_start(inst.notes) # delete invalid notes
112
- nb_notes, start, end, duration, largest_gap = analyze_instrument(inst)
113
- is_valid, err_msg = is_valid_inst(largest_gap=largest_gap, duration=duration, nb_notes=nb_notes, gap_counts='maestro' not in midi_path)
114
- if is_valid or not apply_filtering:
115
- new_instru.append(inst)
116
- instruments_data.append([nb_notes, start, end, duration, largest_gap])
117
- else:
118
- inst_error_msg += f'inst{i_inst}: ' + err_msg + '\n'
119
- instruments_data = np.array(instruments_data)
120
- error_msg += ' Nope. Taking one instrument?'
121
-
122
- if len(new_instru) == 0:
123
- error_msg = f'No piano instrument. {inst_error_msg}'
124
- assert False
125
- elif len(new_instru) > 1:
126
- # take instrument playing the most notes
127
- instrument = new_instru[np.argmax(instruments_data[:, 0])]
128
- else:
129
- instrument = new_instru[0]
130
- instrument.program = 0 # set the instrument to Grand Piano.
131
- midi.instruments = [instrument] # put instrument in midi file
132
- error_msg += ' Nope. Removing blanks?'
133
- # remove first and last REMOVE_FIRST_AND_LAST seconds (avoid clapping and jingles)
134
- end_time = midi.get_end_time()
135
- if apply_filtering: midi = remove_beginning_and_end(midi, end_time)
136
-
137
- # remove beginning and end
138
- midi = remove_blanks_beginning_and_end(midi)
139
- error_msg += ' Nope. Saving?'
140
-
141
- # save midi file
142
- midi.write(processed_path)
143
- error_msg += ' Nope.'
144
- if verbose:
145
- extra = f' Saved to {processed_path}' if midi_path else ''
146
- print(' ' * (level + 2) + f'Success! {extra}')
147
- return processed_path, ''
148
- #except:
149
- # if verbose: print(' ' * (level + 2) + 'Scrubbing failed.')
150
- # if os.path.exists(processed_path):
151
- # os.remove(processed_path)
152
- # return None, error_msg + ' Yes.'
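
A toy check of the gap logic above: two one-second notes separated by four seconds of silence give a largest gap of 4 s, so the instrument would be rejected whenever MAX_GAP_IN_SONG < 4.

    notes = [pm.Note(velocity=80, pitch=60, start=0.0, end=1.0),
             pm.Note(velocity=80, pitch=62, start=5.0, end=6.0)]
    assert compute_largest_gap(notes) == 4.0
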
src/music/pipeline/music_pipeline.py DELETED
@@ -1,86 +0,0 @@
1
- from src.music.pipeline.url2audio import url2audio
2
- from src.music.pipeline.audio2midi import audio2midi
3
- from src.music.pipeline.midi2processed import midi2processed
4
- from src.music.pipeline.processed2encoded import processed2encoded
5
- from src.music.pipeline.encoded2rep import encoded2rep
6
- from src.music.config import RANDOM_CROP, NB_AUG, FROM_URL_PATH
7
- # from src.music.pipeline.synth2audio import AudioRecorder
8
- # from src.music.pipeline.processed2handcodedrep import processed2handcodedrep
9
- import time
10
- import hashlib
11
-
12
- VERBOSE = True
13
- AUGMENTATION, NOISE_INJECTED = False, False
14
- CROP = 10  # crop 10s before transcription
15
-
16
- # AUDIO_RECORDER = AudioRecorder(place='home')
17
-
18
- def encode_music(url=None,
19
- audio_path=None,
20
- midi_path=None,
21
- processed_path=None,
22
- record=False,
23
- crop=CROP,
24
- random_crop=RANDOM_CROP,
25
- augmentation=AUGMENTATION,
26
- noise_injection=NOISE_INJECTED,
27
- apply_filtering=True,
28
- nb_aug=NB_AUG,
29
- level=0,
30
- verbose=VERBOSE):
31
- if not record: assert url is not None or audio_path is not None or midi_path is not None or processed_path is not None
32
- init_time = time.time()
33
- error = ''
34
- try:
35
- if record:
36
- assert audio_path is None and midi_path is None
37
- if verbose: print(' ' * level + 'Processing music, recorded from mic.')
38
- audio_path = AUDIO_RECORDER.record_one()
39
- error = ''
40
- if processed_path is None:
41
- if midi_path is None:
42
- if audio_path is None:
43
- if verbose and not record: print(' ' * level + 'Processing music, from audio source.')
44
- init_t = time.time()
45
- audio_path, _, error = url2audio(playlist_path=FROM_URL_PATH, video_url=url, verbose=verbose, level=level+2)
46
- if verbose: print(' ' * (level + 4) + f'Audio downloaded in {int(time.time() - init_t)} seconds.')
47
- else:
48
- if verbose and not record: print(' ' * level + 'Processing music, from midi source.')
49
- init_t = time.time()
50
- midi_path, error = audio2midi(audio_path, crop=crop, random_crop=random_crop, verbose=verbose, level=level+2)
51
- if verbose: print(' ' * (level + 4) + f'Audio transcribed to midi in {int(time.time() - init_t)} seconds.')
52
- init_t = time.time()
53
- processed_path, error = midi2processed(midi_path, apply_filtering=apply_filtering, verbose=verbose, level=level+2)
54
- if verbose: print(' ' * (level + 4) + f'Midi preprocessed in {int(time.time() - init_t)} seconds.')
55
- init_t = time.time()
56
- encoded_path, error = processed2encoded(processed_path, augmentation=augmentation, nb_aug=nb_aug, noise_injection=noise_injection, verbose=verbose, level=level+2)
57
- if verbose: print(' ' * (level + 4) + f'Midi encoded in {int(time.time() - init_t)} seconds.')
58
- init_t = time.time()
59
- representation_path, representation, error = encoded2rep(encoded_path, return_rep=True, level=level+2, verbose=verbose)
60
- if verbose: print(' ' * (level + 4) + f'Music representation computed in {int(time.time() - init_t)} seconds.')
61
- init_t = time.time()
62
- handcoded_rep_path, handcoded_rep, error = None, None, ''
63
- # handcoded_rep_path, handcoded_rep, error = processed2handcodedrep(processed_path, return_rep=True, level=level+2, verbose=verbose)
64
- if verbose: print(' ' * (level + 4) + f'Music handcoded representation computed in {int(time.time() - init_t)} seconds.')
65
- # assert handcoded_rep_path is not None and representation_path is not None
66
- all_paths = dict(url=url, audio_path=audio_path, midi_path=midi_path, processed_path=processed_path, encoded_path=encoded_path,
67
- representation_path=representation_path, handcoded_rep_path=handcoded_rep_path)
68
- print('audio hash: ', hashlib.md5(open(audio_path, 'rb').read()).hexdigest())
69
- print('midi hash: ', hashlib.md5(open(midi_path, 'rb').read()).hexdigest())
70
- print('processed hash: ', hashlib.md5(open(processed_path, 'rb').read()).hexdigest())
71
- print('encoded hash: ', hashlib.md5(open(encoded_path, 'rb').read()).hexdigest())
72
- print('rep hash: ', hashlib.md5(open(representation_path, 'rb').read()).hexdigest())
73
- print("rep:", representation[:10])
74
- if verbose: print(' ' * (level + 2) + f'Music processed in {int(time.time() - init_time)} seconds.')
75
- except Exception as err:
76
- print(err, error)
77
- if verbose: print(' ' * (level + 2) + f'Music FAILED to process in {int(time.time() - init_time)} seconds.')
78
- representation = None
79
- handcoded_rep = None
80
- all_paths = dict()
81
-
82
- return representation, handcoded_rep, all_paths, error
83
-
84
- if __name__ == '__main__':
85
- representation = encode_music(url="https://www.youtube.com/watch?v=a2LFVWBmoiw")[0]
86
- # representation = encode_music(record=True)[0]
src/music/pipeline/processed2encoded.py DELETED
@@ -1,52 +0,0 @@
1
- import os
2
- import sys
3
- import numpy as np
4
- import pickle
5
- sys.path.append('../../')
6
-
7
- from src.music.utils import get_out_path
8
- from src.music.config import ALL_NOISE, ALL_AUGMENTATIONS, NB_AUG, NOISE_INJECTED
9
- from src.music.utilities.midi_processor import encode_midi_structured, encode_midi_chunks_structured
10
-
11
- nb_noise = ALL_NOISE.shape[0]
12
- nb_aug = ALL_AUGMENTATIONS.shape[0]
13
-
14
- def sample_augmentations(n):
15
- return ALL_AUGMENTATIONS[np.random.choice(np.arange(nb_aug), size=n, replace=False)]
16
-
17
- def sample_noise():
18
- return ALL_NOISE[np.random.choice(np.arange(nb_noise))]
19
-
20
- def processed2encoded(processed_path, encoded_path=None, augmentation=False, nb_aug=None, noise_injection=False, verbose=False, level=0):
21
- assert processed_path.split('.')[-1] in ['mid', 'midi']
22
- if not encoded_path:
23
- encoded_path, _, _ = get_out_path(in_path=processed_path, in_word='processed', out_word='encoded', out_extension='.pickle')
24
-
25
- if verbose: print(' ' * level + f'Encoding {processed_path}')
26
- if os.path.exists(encoded_path):
27
- if verbose: print(' ' * (level + 2) + 'Midi file is already encoded.')
28
- return encoded_path, ''
29
-
30
- if augmentation:
31
- assert isinstance(nb_aug, int)
32
- error_msg = 'Error in encoding. '
33
- try:
34
- error_msg += 'Error in encoding midi?'
35
- nb_noise = 1 if noise_injection else 0
36
- encoded_main, encoded_aug, encoded_noisy = encode_midi_structured(processed_path, nb_aug, nb_noise)
37
-
38
- # make sure augmentations are not out of bounds
39
- error_msg += ' Nope. Error in saving encoding?'
40
- with open(encoded_path, 'wb') as f:
41
- pickle.dump(dict(main=encoded_main, aug=encoded_aug, noisy=encoded_noisy), f)
42
- error_msg += ' Nope.'
43
- if verbose:
44
- extra = f' Saved to {encoded_path}' if encoded_path else ''
45
- print(' ' * (level + 2) + f'Success! {extra}')
46
- return encoded_path, ''
47
- except:
48
- if verbose: print(' ' * (level + 2) + 'Encoding failed.')
49
- if os.path.exists(encoded_path):
50
- os.remove(encoded_path)
51
- return None, error_msg + ' Yes.'
52
-
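
An illustrative sketch of the samplers above (outputs are random draws):

    pitch_shifts = sample_augmentations(NB_AUG)  # NB_AUG distinct pitch shifts, e.g. [-2, 3, 1, -1]
    onset_shift, pitch_shift = sample_noise()    # one (onset shift, pitch shift) pair from ALL_NOISE
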
src/music/pipeline/processed2handcodedrep.py DELETED
@@ -1,343 +0,0 @@
1
- import numpy as np
2
- from music21 import *
3
- from music21.features import native, jSymbolic, DataSet
4
- import pretty_midi as pm
5
- from src.music.utils import get_out_path
6
- from src.music.utilities.handcoded_rep_utilities.tht import tactus_hypothesis_tracker, tracker_analysis
7
- from src.music.utilities.handcoded_rep_utilities.loudness import get_loudness, compute_total_loudness, amplitude2db, velocity2amplitude, get_db_of_equivalent_loudness_at_440hz, pitch2freq
8
- import json
9
- import os
10
- environment.set('musicxmlPath', '/home/cedric/Desktop/test/')
11
- midi_path = "/home/cedric/Documents/pianocktail/data/music/processed/doug_mckenzie_processed/allthethings_reharmonized_processed.mid"
12
-
13
- FEATURES_DICT_SCORE = dict(
14
- # strongest pulse: measures how fast the melody is
15
- # stronger_pulse=jSymbolic.StrongestRhythmicPulseFeature,
16
- # weights of the two strongest pulse, measures rhythmic consistency: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#combinedstrengthoftwostrongestrhythmicpulsesfeature
17
- pulse_strength_two=jSymbolic.CombinedStrengthOfTwoStrongestRhythmicPulsesFeature,
18
- # weights of the strongest pulse, measures rhythmic consistency: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#combinedstrengthoftwostrongestrhythmicpulsesfeature
19
- pulse_strength = jSymbolic.StrengthOfStrongestRhythmicPulseFeature,
20
- # variability of attacks: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#variabilityoftimebetweenattacksfeature
21
-
22
- )
23
- FEATURES_DICT = dict(
24
- # bass register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature
25
- # bass_register=jSymbolic.ImportanceOfBassRegisterFeature,
26
- # high register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature
27
- # high_register=jSymbolic.ImportanceOfHighRegisterFeature,
28
- # medium register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature
29
- # medium_register=jSymbolic.ImportanceOfMiddleRegisterFeature,
30
- # number of common pitches (at least 9% of all): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#numberofcommonmelodicintervalsfeature
31
- # common_pitches=jSymbolic.NumberOfCommonPitchesFeature,
32
- # pitch class variety (used at least once): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#pitchvarietyfeature
33
- # pitch_variety=jSymbolic.PitchVarietyFeature,
34
- # attack_variability = jSymbolic.VariabilityOfTimeBetweenAttacksFeature,
35
- # staccato fraction: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#staccatoincidencefeature
36
- # staccato_score = jSymbolic.StaccatoIncidenceFeature,
37
- # mode analysis: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesNative.html
38
- av_melodic_interval = jSymbolic.AverageMelodicIntervalFeature,
39
- # chromatic motion: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#chromaticmotionfeature
40
- chromatic_motion = jSymbolic.ChromaticMotionFeature,
41
- # direction of motion (fraction of rising intervals: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#directionofmotionfeature
42
- motion_direction = jSymbolic.DirectionOfMotionFeature,
43
- # duration of melodic arcs: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#durationofmelodicarcsfeature
44
- melodic_arcs_duration = jSymbolic.DurationOfMelodicArcsFeature,
45
- # melodic arcs size: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#sizeofmelodicarcsfeature
46
- melodic_arcs_size = jSymbolic.SizeOfMelodicArcsFeature,
47
- # number of common melodic interval (at least 9% of all): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#numberofcommonmelodicintervalsfeature
48
- # common_melodic_intervals = jSymbolic.NumberOfCommonMelodicIntervalsFeature,
49
- # https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#amountofarpeggiationfeature
50
- # arpeggiato=jSymbolic.AmountOfArpeggiationFeature,
51
- )
52
-
53
-
54
-
55
-
56
-
57
-
58
- def compute_beat_info(onsets):
59
- onsets_in_ms = np.array(onsets) * 1000
60
-
61
- tht = tactus_hypothesis_tracker.default_tht()
62
- trackers = tht(onsets_in_ms)
63
- top_hts = tracker_analysis.top_hypothesis(trackers, len(onsets_in_ms))
64
- beats = tracker_analysis.produce_beats_information(onsets_in_ms, top_hts, adapt_period=250 is not None,
65
- adapt_phase=tht.eval_f, max_delta_bpm=250, avoid_quickturns=None)
66
- tempo = 1 / (np.mean(np.diff(beats)) / 1000) * 60 # in bpm
67
- conf_values = tracker_analysis.tht_tracking_confs(trackers, len(onsets_in_ms))
68
- pulse_clarity = np.mean(np.array(conf_values), axis=0)[1]
69
- return tempo, pulse_clarity
70
-
71
- def dissonance_score(A):
72
- """
73
- Given a piano-roll indicator matrix representation of a musical work (128 pitches x beats),
74
- return the dissonance as a function of beats.
75
- Input:
76
- A - 128 x beats indicator matrix of MIDI pitch number
77
-
78
- """
79
- freq_rats = np.arange(1, 7) # Harmonic series ratios
80
- amps = np.exp(-.5 * freq_rats) # Partial amplitudes
81
- F0 = 8.1757989156 # base frequency for MIDI (note 0)
82
- diss = [] # List for dissonance values
83
- thresh = 1e-3
84
- for beat in A.T:
85
- idx = np.where(beat>thresh)[0]
86
- if len(idx):
87
- freqs, mags = [], [] # lists for frequencies, mags
88
- for i in idx:
89
- freqs.extend(F0*2**(i/12.0)*freq_rats)
90
- mags.extend(amps)
91
- freqs = np.array(freqs)
92
- mags = np.array(mags)
93
- sortIdx = freqs.argsort()
94
- d = compute_dissonance(freqs[sortIdx],mags[sortIdx])
95
- diss.extend([d])
96
- else:
97
- diss.extend([-1]) # Null value
98
- diss = np.array(diss)
99
- return diss[np.where(diss != -1)]
100
-
101
- def compute_dissonance(freqs, amps):
102
- """
103
- From https://notebook.community/soundspotter/consonance/week1_consonance
104
- Compute dissonance between partials with center frequencies in freqs and amplitudes in amps,
106
- using a model of critical bandwidth. Based on Sethares "Tuning, Timbre, Spectrum, Scale" (1998) after Plomp and Levelt (1965).
106
-
107
- inputs:
108
- freqs - list of partial frequencies
109
- amps - list of corresponding amplitudes [default, uniformly 1]
110
- """
111
- b1, b2, s1, s2, c1, c2, Dstar = (-3.51, -5.75, 0.0207, 19.96, 5, -5, 0.24)
112
- f = np.array(freqs)
113
- a = np.array(amps)
114
- idx = np.argsort(f)
115
- f = f[idx]
116
- a = a[idx]
117
- N = f.size
118
- D = 0
119
- for i in range(1, N):
120
- Fmin = f[ 0 : N - i ]
121
- S = Dstar / ( s1 * Fmin + s2)
122
- Fdif = f[ i : N ] - f[ 0 : N - i ]
123
- am = a[ i : N ] * a[ 0 : N - i ]
124
- Dnew = am * (c1 * np.exp (b1 * S * Fdif) + c2 * np.exp(b2 * S * Fdif))
125
- D += Dnew.sum()
126
- return D
127
-
128
-
129
-
130
-
131
- def store_new_midi(notes, out_path):
132
- midi = pm.PrettyMIDI()
133
- midi.instruments.append(pm.Instrument(program=0, is_drum=False))
134
- midi.instruments[0].notes = notes
135
- midi.write(out_path)
136
- return midi
137
-
138
-
139
- def processed2handcodedrep(midi_path, handcoded_rep_path=None, crop=30, verbose=False, save=True, return_rep=False, level=0):
140
- try:
141
- if not handcoded_rep_path:
142
- handcoded_rep_path, _, _ = get_out_path(in_path=midi_path, in_word='processed', out_word='handcoded_reps', out_extension='.mid')
143
- features = dict()
144
- if verbose: print(' ' * level + 'Computing handcoded representations')
145
- if os.path.exists(handcoded_rep_path):
146
- with open(handcoded_rep_path.replace('.mid', '.json'), 'r') as f:
147
- features = json.load(f)
148
- rep = np.array([features[k] for k in sorted(features.keys())])
149
- if rep.size == 49:
150
- os.remove(handcoded_rep_path)
151
- else:
152
- if verbose: print(' ' * (level + 2) + 'Already computed.')
153
- if return_rep:
154
- return handcoded_rep_path, np.array([features[k] for k in sorted(features.keys())]), ''
155
- else:
156
- return handcoded_rep_path, ''
157
- midi = pm.PrettyMIDI(midi_path) # load midi with pretty midi
158
- notes = midi.instruments[0].notes # get notes
159
- notes.sort(key=lambda x: (x.start, x.pitch)) # sort notes per start and pitch
160
- onsets, offsets, pitches, durations, velocities = [], [], [], [], []
161
- n_notes_cropped = len(notes)
162
- for i_n, n in enumerate(notes):
163
- onsets.append(n.start)
164
- offsets.append(n.end)
165
- durations.append(n.end-n.start)
166
- pitches.append(n.pitch)
167
- velocities.append(n.velocity)
168
- if crop is not None: # find how many notes to keep
169
- if n.start > crop and n_notes_cropped == len(notes):
170
- n_notes_cropped = i_n
171
- break
172
- notes = notes[:n_notes_cropped]
173
- midi = store_new_midi(notes, handcoded_rep_path)
174
- # pianoroll = midi.get_piano_roll() # extract piano roll representation
175
-
176
- # compute loudness
177
- amplitudes = velocity2amplitude(np.array(velocities))
178
- power_dbs = amplitude2db(amplitudes)
179
- frequencies = pitch2freq(np.array(pitches))
180
- loudness_values = get_loudness(power_dbs, frequencies)
181
- # compute average perceived loudness
182
- # for each power, compute loudness, then compute power such that the loudness at 440 Hz would be equivalent.
183
- # equivalent_powers_dbs = get_db_of_equivalent_loudness_at_440hz(frequencies, power_dbs)
184
- # then get the corresponding amplitudes
185
- # equivalent_amplitudes = 10 ** (equivalent_powers_dbs / 20)
186
- # not use a amplitude model across the sample to compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz
187
- # av_total_loudness, std_total_loudness = compute_total_loudness(equivalent_amplitudes, onsets, offsets)
188
-
189
- end_time = np.max(offsets)
190
- start_time = notes[0].start
191
-
192
-
193
- score = converter.parse(handcoded_rep_path)
194
- score.chordify()
195
- notes_without_chords = stream.Stream(score.flatten().getElementsByClass('Note'))
196
-
197
- velocities_wo_chords, pitches_wo_chords, amplitudes_wo_chords, dbs_wo_chords = [], [], [], []
198
- frequencies_wo_chords, loudness_values_wo_chords, onsets_wo_chords, offsets_wo_chords, durations_wo_chords = [], [], [], [], []
199
- for i_n in range(len(notes_without_chords)):
200
- n = notes_without_chords[i_n]
201
- velocities_wo_chords.append(n.volume.velocity)
202
- pitches_wo_chords.append(n.pitch.midi)
203
- onsets_wo_chords.append(n.offset)
204
- offsets_wo_chords.append(onsets_wo_chords[-1] + n.seconds)
205
- durations_wo_chords.append(n.seconds)
206
-
207
- amplitudes_wo_chords = velocity2amplitude(np.array(velocities_wo_chords))
208
- power_dbs_wo_chords = amplitude2db(amplitudes_wo_chords)
209
- frequencies_wo_chords = pitch2freq(np.array(pitches_wo_chords))
210
- loudness_values_wo_chords = get_loudness(power_dbs_wo_chords, frequencies_wo_chords)
211
- # compute average perceived loudness
212
- # for each power, compute loudness, then compute power such that the loudness at 440 Hz would be equivalent.
213
- # equivalent_powers_dbs_wo_chords = get_db_of_equivalent_loudness_at_440hz(frequencies_wo_chords, power_dbs_wo_chords)
214
- # then get the corresponding amplitudes
215
- # equivalent_amplitudes_wo_chords = 10 ** (equivalent_powers_dbs_wo_chords / 20)
216
- # not use a amplitude model across the sample to compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz
217
- # av_total_loudness_wo_chords, std_total_loudness_wo_chords = compute_total_loudness(equivalent_amplitudes_wo_chords, onsets_wo_chords, offsets_wo_chords)
218
-
219
- ds = DataSet(classLabel='test')
220
- f = list(FEATURES_DICT.values())
221
- ds.addFeatureExtractors(f)
222
- ds.addData(notes_without_chords)
223
- ds.process()
224
- for k, f in zip(FEATURES_DICT.keys(), ds.getFeaturesAsList()[0][1:-1]):
225
- features[k] = f
226
-
227
- ds = DataSet(classLabel='test')
228
- f = list(FEATURES_DICT_SCORE.values())
229
- ds.addFeatureExtractors(f)
230
- ds.addData(score)
231
- ds.process()
232
- for k, f in zip(FEATURES_DICT_SCORE.keys(), ds.getFeaturesAsList()[0][1:-1]):
233
- features[k] = f
234
-
235
- # # # # #
236
- # Register features
237
- # # # # #
238
-
239
- # features['av_pitch'] = np.mean(pitches)
240
- # features['std_pitch'] = np.std(pitches)
241
- # features['range_pitch'] = np.max(pitches) - np.min(pitches) # aka ambitus
242
-
243
- # # # # #
244
- # Rhythmic features
245
- # # # # #
246
-
247
- # tempo, pulse_clarity = compute_beat_info(onsets[:n_notes_cropped])
248
- # features['pulse_clarity'] = pulse_clarity
249
- # features['tempo'] = tempo
250
- features['tempo_pm'] = midi.estimate_tempo()
251
-
252
- # # # # #
253
- # Temporal features
254
- # # # # #
255
-
256
- features['av_duration'] = np.mean(durations)
257
- # features['std_duration'] = np.std(durations)
258
- features['note_density'] = len(notes) / (end_time - start_time)
259
- # intervals_wo_chords = np.diff(onsets_wo_chords)
260
- # articulations = [max((i-d)/i, 0) for d, i in zip(durations_wo_chords, intervals_wo_chords) if i != 0]
261
- # features['articulation'] = np.mean(articulations)
262
- # features['av_duration_wo_chords'] = np.mean(durations_wo_chords)
263
- # features['std_duration_wo_chords'] = np.std(durations_wo_chords)
264
-
265
- # # # # #
266
- # Dynamics features
267
- # # # # #
268
- features['av_velocity'] = np.mean(velocities)
269
- features['std_velocity'] = np.std(velocities)
270
- features['av_loudness'] = np.mean(loudness_values)
271
- # features['std_loudness'] = np.std(loudness_values)
272
- features['range_loudness'] = np.max(loudness_values) - np.min(loudness_values)
273
- # features['av_integrated_loudness'] = av_total_loudness
274
- # features['std_integrated_loudness'] = std_total_loudness
275
- # features['av_velocity_wo_chords'] = np.mean(velocities_wo_chords)
276
- # features['std_velocity_wo_chords'] = np.std(velocities_wo_chords)
277
- # features['av_loudness_wo_chords'] = np.mean(loudness_values_wo_chords)
278
- # features['std_loudness_wo_chords'] = np.std(loudness_values_wo_chords)
279
- features['range_loudness_wo_chords'] = np.max(loudness_values_wo_chords) - np.min(loudness_values_wo_chords)
280
- # features['av_integrated_loudness'] = av_total_loudness_wo_chords
281
- # features['std_integrated_loudness'] = std_total_loudness_wo_chords
282
- # indices_with_intervals = np.where(intervals_wo_chords > 0.01)
283
- # features['av_loudness_change'] = np.mean(np.abs(np.diff(np.array(loudness_values_wo_chords)[indices_with_intervals]))) # accentuation
284
- # features['av_velocity_change'] = np.mean(np.abs(np.diff(np.array(velocities_wo_chords)[indices_with_intervals]))) # accentuation
285
-
286
- # # # # #
287
- # Harmony features
288
- # # # # #
289
-
290
- # get major_minor score: https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisDiscrete.html
291
- music_analysis = score.analyze('key')
292
- major_score = None
293
- minor_score = None
294
- for a in [music_analysis] + music_analysis.alternateInterpretations:
295
- if 'major' in a.__str__() and a.correlationCoefficient > 0:
296
- major_score = a.correlationCoefficient
297
- elif 'minor' in a.__str__() and a.correlationCoefficient > 0:
298
- minor_score = a.correlationCoefficient
299
- if major_score is not None and minor_score is not None:
300
- break
301
- features['major_minor'] = major_score / (major_score + minor_score)
302
- features['tonal_certainty'] = music_analysis.tonalCertainty()
303
- # features['av_sensory_dissonance'] = np.mean(dissonance_score(pianoroll))
304
- # TODO: only works for chords; do something with melodic intervals instead, e.g. the proportion that is not a third, fifth, or seventh.
305
-
306
- # # # # #
307
- # Interval features
308
- # # # # #
309
- #https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisPatel.html
310
- # features['melodic_interval_variability'] = analysis.patel.melodicIntervalVariability(notes_without_chords)
311
-
312
- # # # # #
313
- # Surprise features
314
- # # # # #
315
- # https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisMetrical.html
316
- # analysis.metrical.thomassenMelodicAccent(notes_without_chords)
317
- # melodic_accents = [n.melodicAccent for n in notes_without_chords]
318
- # features['melodic_accent'] = np.mean(melodic_accents)
319
-
320
- if save:
321
- for k, v in features.items():
322
- features[k] = float(features[k])
323
- with open(handcoded_rep_path.replace('.mid', '.json'), 'w') as f:
324
- json.dump(features, f)
325
- else:
326
- print(features)
327
- if os.path.exists(handcoded_rep_path):
328
- os.remove(handcoded_rep_path)
329
- if verbose: print(' ' * (level + 2) + 'Success.')
330
- if return_rep:
331
- return handcoded_rep_path, np.array([features[k] for k in sorted(features.keys())]), ''
332
- else:
333
- return handcoded_rep_path, ''
334
- except:
335
- if verbose: print(' ' * (level + 2) + 'Failed.')
336
- if return_rep:
337
- return None, None, 'error'
338
- else:
339
- return None, 'error'
340
-
341
-
342
- if __name__ == '__main__':
343
- processed2handcodedrep(midi_path, '/home/cedric/Desktop/test.mid', save=False)
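
A toy check of dissonance_score above: under the Sethares model a semitone clash (C4 + C#4) should score higher dissonance than a perfect fifth (C4 + G4).

    roll = np.zeros((128, 2))
    roll[60, 0] = roll[61, 0] = 1.0  # beat 0: C4 + C#4
    roll[60, 1] = roll[67, 1] = 1.0  # beat 1: C4 + G4
    diss = dissonance_score(roll)
    assert diss[0] > diss[1]
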
src/music/pipeline/synth2audio.py DELETED
@@ -1,170 +0,0 @@
- import pynput
- import sys
- sys.path.append('../../')
- from src.music.config import SYNTH_RECORDED_AUDIO_PATH, RATE_AUDIO_SAVE
- from datetime import datetime
- import numpy as np
- import os
- import wave
-
- from ctypes import *
- from contextlib import contextmanager
- import pyaudio
-
- ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p)
-
- def py_error_handler(filename, line, function, err, fmt):
-     pass
- c_error_handler = ERROR_HANDLER_FUNC(py_error_handler)
-
- @contextmanager
- def noalsaerr():
-     # silence ALSA's C-level error spam while PyAudio probes devices
-     asound = cdll.LoadLibrary('libasound.so')
-     asound.snd_lib_error_set_handler(c_error_handler)
-     yield
-     asound.snd_lib_error_set_handler(None)
-
- global KEY_PRESSED
- KEY_PRESSED = None
-
- def on_press(key):
-     global KEY_PRESSED
-     try:
-         KEY_PRESSED = key.name
-     except:
-         pass
-
- def on_release(key):
-     global KEY_PRESSED
-     KEY_PRESSED = None
-
-
- def is_pressed(key):
-     global KEY_PRESSED
-     return KEY_PRESSED == key
-
- # keyboard listener
- listener = pynput.keyboard.Listener(on_press=on_press, on_release=on_release)
- listener.start()
-
- LEN_RECORDINGS = 40
- class AudioRecorder:
-     def __init__(self, chunk=2**10, rate=44100, place='', len_recording=LEN_RECORDINGS, drop_beginning=0.5):
-         self.chunk = chunk
-         self.rate = rate
-         with noalsaerr():
-             self.audio = pyaudio.PyAudio()
-         self.channels = 1
-         self.format = pyaudio.paInt16
-         self.stream = self.audio.open(format=self.format,
-                                       channels=self.channels,
-                                       rate=rate,
-                                       input=True,
-                                       frames_per_buffer=chunk)
-         self.stream.stop_stream()
-         self.drop_beginning_chunks = int(drop_beginning * self.rate / self.chunk)
-         self.place = place
-         self.len_recordings = len_recording
-
-     def get_filename(self):
-         now = datetime.now()
-         return self.place + '_' + now.strftime("%b_%d_%Y_%Hh%Mm%Ss") + '.mp3'
-
-     def read_last_chunk(self):
-         return self.stream.read(self.chunk)
-
-     def live_read(self):
-         if self.stream.is_stopped():
-             self.stream.start_stream()
-         i = 0
-         while not is_pressed('esc'):
-             data = np.frombuffer(self.stream.read(self.chunk), dtype=np.int16)
-             peak = np.average(np.abs(data)) * 2
-             bars = "#" * int(50 * peak / 2 ** 16)
-             i += 1
-             print("%04d %05d %s" % (i, peak, bars))
-         self.stream.stop_stream()
-
-     def record_next_N_seconds(self, n=None, saving_path=None):
-         if saving_path is None:
-             saving_path = SYNTH_RECORDED_AUDIO_PATH + self.get_filename()
-         if n is None:
-             n = self.len_recordings
-
-         print(f'Recording the next {n} secs.'
-               # f'\n\tRecording starts when the first key is pressed;'
-               f'\n\tPress Enter to end the recording;'
-               f'\n\tPress BackSpace (<--) to cancel the recording;'
-               f'\n\tSaving to {saving_path}')
-         try:
-             self.stream.start_stream()
-             backspace_pressed = False
-             self.recording = []
-             i_chunk = 0
-             while not is_pressed('enter') and self.chunk / self.rate * i_chunk < n:
-                 self.recording.append(self.read_last_chunk())
-                 i_chunk += 1
-                 if is_pressed('backspace'):
-                     backspace_pressed = True
-                     print('\n \t--> Recording cancelled! (you pressed BackSpace)')
-                     break
-             self.stream.stop_stream()
-
-             # save the file
-             if not backspace_pressed:
-                 self.recording = self.recording[self.drop_beginning_chunks:]  # drop first chunks to remove keyboard sound
-                 with wave.open(saving_path[:-4] + '.wav', 'wb') as waveFile:
-                     waveFile.setnchannels(self.channels)
-                     waveFile.setsampwidth(self.audio.get_sample_size(self.format))
-                     waveFile.setframerate(self.rate)
-                     waveFile.writeframes(b''.join(self.recording))
-                 os.system(f'ffmpeg -i "{saving_path[:-4] + ".wav"}" -vn -loglevel panic -y -ac 1 -ar {int(RATE_AUDIO_SAVE)} -b:a 320k "{saving_path}" ')
-                 os.remove(saving_path[:-4] + '.wav')
-                 print(f'\n--> Recording saved, duration: {self.chunk / self.rate * i_chunk:.2f} secs.')
-             return saving_path
-         except:
-             print('\n --> The recording failed.')
-             return None
-
-     def record_one(self):
-         ready_msg = False
-         print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)')
-         while True:
-             if not ready_msg:
-                 print('-------\nReady to record!')
-                 print('Press space to start a recording\n')
-                 ready_msg = True
-
-             if is_pressed('space'):
-                 saving_path = self.record_next_N_seconds()
-                 break
-         return saving_path
-
-     def run(self):
-         # with pynput.Listener(
-         #         on_press=self.on_press) as listener:
-         #     listener.join()
-         ready_msg = False
-         print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)')
-         while True:
-             if not ready_msg:
-                 print('-------\nReady to record!')
-                 print('Press space to start a recording\n')
-                 ready_msg = True
-
-             if is_pressed('space'):
-                 self.record_next_N_seconds()
-                 ready_msg = False
-             if is_pressed('esc'):
-                 print('End of the recording session. See you soon!')
-                 self.close()
-                 break
-
-     def close(self):
-         self.stream.close()
-         self.audio.terminate()
-
- if __name__ == '__main__':
-     audio_recorder = AudioRecorder(place='home')
-     audio_recorder.record_one()
-
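
A hypothetical usage sketch of the recorder above, assuming a working input device and the keyboard listener; the tag and save path are made up for illustration:

    recorder = AudioRecorder(place='studio', len_recording=10)
    # blocks until Enter is pressed or 10 seconds elapse, then writes wav and converts to mp3
    path = recorder.record_next_N_seconds(n=10, saving_path='/tmp/take_01.mp3')
    recorder.close()
    print('saved to', path)
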
src/music/pipeline/synth2midi.py DELETED
@@ -1,146 +0,0 @@
- import mido
- mido.set_backend('mido.backends.pygame')
- from mido import Message, MidiFile, MidiTrack
- import time
- import pynput
- import sys
- sys.path.append('../../')
- from src.music.config import SYNTH_RECORDED_MIDI_PATH
- from datetime import datetime
-
- # TODO: debug this with other cable, keyboard and sound card
- global KEY_PRESSED
- KEY_PRESSED = None
-
- def on_press(key):
-     global KEY_PRESSED
-     try:
-         KEY_PRESSED = key.name
-     except:
-         pass
-
- def on_release(key):
-     global KEY_PRESSED
-     KEY_PRESSED = None
-
-
- def is_pressed(key):
-     global KEY_PRESSED
-     return KEY_PRESSED == key
-
- # keyboard listener
- listener = pynput.keyboard.Listener(on_press=on_press, on_release=on_release)
- listener.start()
-
- LEN_MIDI_RECORDINGS = 30
- class MidiRecorder:
-     def __init__(self, place='', len_midi_recordings=LEN_MIDI_RECORDINGS):
-         self.place = place
-         self.len_midi_recordings = len_midi_recordings
-         self.port = mido.open_input(mido.get_input_names()[0])
-
-     def get_filename(self):
-         now = datetime.now()
-         return self.place + '_' + now.strftime("%b_%d_%Y_%Hh%Mm%Ss") + '.mid'
-
-     def read_last_midi_msgs(self):
-         return list(self.port.iter_pending())
-
-     def live_read(self):
-         while not is_pressed('esc'):
-             for msg in self.read_last_midi_msgs():
-                 print(msg)
-
-     def check_if_recording_started(self, msgs, t_init):
-         started = False
-         if len(msgs) > 0:
-             for m in msgs:
-                 if m.type == 'note_on':
-                     started = True
-                     t_init = time.time()
-         return started, t_init
-
-     def create_empty_midi(self):
-         mid = MidiFile()
-         track = MidiTrack()
-         mid.tracks.append(track)
-         track.append(Message('program_change', program=0, time=0))
-         return mid, track
-
-     def record_next_N_seconds(self, n=None, saving_path=None):
-         if saving_path is None:
-             saving_path = SYNTH_RECORDED_MIDI_PATH + self.get_filename()
-         if n is None:
-             n = self.len_midi_recordings
-
-         print(f'Recording the next {n} secs.'
-               f'\n\tRecording starts when the first key is pressed;'
-               f'\n\tPress Enter to end the recording;'
-               f'\n\tPress BackSpace (<--) to cancel the recording;'
-               f'\n\tSaving to {saving_path}')
-         try:
-             mid, track = self.create_empty_midi()
-             started = False
-             backspace_pressed = False
-             t_init = time.time()
-             while not is_pressed('enter') and (time.time() - t_init) < n:
-                 msgs = self.read_last_midi_msgs()
-                 if not started:
-                     started, t_init = self.check_if_recording_started(msgs, t_init)
-                     if started:
-                         print("\n\t--> First note pressed, it's on!")
-                 for m in msgs:
-                     print(m)
-                     if m.type == 'note_on' and m.velocity == 0:
-                         m_off = Message('note_off', velocity=127, note=m.note, channel=m.channel, time=m.time)
-                         track.append(m_off)
-                     track.append(m)
-                 if is_pressed('backspace'):
-                     backspace_pressed = True
-                     print('\n \t--> Recording cancelled! (you pressed BackSpace)')
-                     break
-             # save the file
-             if not backspace_pressed and len(mid.tracks[0]) > 0:
-                 mid.save(saving_path)
-                 print(f'\n--> Recording saved, duration: {mid.length:.2f} secs, {len(mid.tracks[0])} events.')
-         except:
-             print('\n --> The recording failed.')
-
-
-     def run(self):
-         # with pynput.Listener(
-         #         on_press=self.on_press) as listener:
-         #     listener.join()
-         ready_msg = False
-         print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)')
-         while True:
-             if not ready_msg:
-                 print('-------\nReady to record!')
-                 print('Press space to start a recording\n')
-                 ready_msg = True
-
-             if is_pressed('space'):
-                 self.record_next_N_seconds()
-                 ready_msg = False
-             if is_pressed('esc'):
-                 print('End of the recording session. See you soon!')
-                 break
-
-
- if __name__ == '__main__':
-     midi_recorder = MidiRecorder(place='home')
-     midi_recorder.live_read()
-     # midi_recorder.run()
-
-
- # try:
- #     controls[msg.control] = msg.value
- # except:
- #     notes.append(msg.note)
- # port = mido.open_input()
- # while True:
- #     for msg in port.iter_pending():
- #         print(msg)
- #
- # print('start pause')
- # time.sleep(5)
- # print('stop pause')
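
One detail worth noting in the loop above: many keyboards never emit `note_off` and instead send `note_on` with velocity 0, which is why the recorder synthesizes explicit `note_off` messages. A minimal standalone sketch of that normalization (mido is the only dependency; the velocity value mirrors the code above):

    import mido

    def normalize(msg):
        # MIDI convention: note_on with velocity 0 means note_off.
        if msg.type == 'note_on' and msg.velocity == 0:
            return mido.Message('note_off', note=msg.note, velocity=127,
                                channel=msg.channel, time=msg.time)
        return msg
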
src/music/pipeline/url2audio.py DELETED
@@ -1,119 +0,0 @@
- import os
- from pytube import YouTube
- from src.music.utils import RATE_AUDIO_SAVE, slugify
- from src.music.config import MAX_LEN
-
- # define filtering keywords
- start_keywords = [' ', '(', ',', ':']
- end_keywords = [')', ' ', '.', ',', '!', ':']
- def get_all_keywords(k):
-     all_keywords = []
-     for s in start_keywords:
-         for e in end_keywords:
-             all_keywords.append(s + k + e)
-     return all_keywords
- filtered_keywords = ['duet', 'duo', 'quartet', 'orchestre', 'orchestra',
-                      'quintet', 'sixtet', 'septet', 'octet', 'backing track', 'accompaniment', 'string',
-                      'contrebrasse', 'drums', 'guitar'] + get_all_keywords('live') + get_all_keywords('trio')
-
- # playlists and channels for which no keyword filtering should occur (they were prefiltered already, supposed to be piano only)
- playlist_and_channel_not_to_filter = ["https://www.youtube.com/c/MySheetMusicTranscriptions",
-                                       "https://www.youtube.com/c/PianoNotion",
-                                       "https://www.youtube.com/watch?v=3F5glYefwio&list=PLFv3ZQw-ZPxi2DH3Bau7lBC5K6zfPJZxc",
-                                       "https://www.youtube.com/user/Mercuziopianist",
-                                       "https://www.youtube.com/channel/UCy6NPK6-xeX7MZLaMARa5qg",
-                                       "https://www.youtube.com/channel/UCKMRNFV2dWTWIJnymtA9_Iw",
-                                       "https://www.youtube.com/c/pianomaedaful",
-                                       "https://www.youtube.com/c/FrancescoParrinoMusic",
-                                       "https://www.youtube.com/c/itsremco"]
- playlist_ok = "https://www.youtube.com/watch?v=sYv_vk6bJtk&list=PLO9E3V4rGLD9-0BEd3t-AvvMcVF1zOJPj"
-
-
- def should_be_filtered(title, length, url, playlist_url, max_length):
-     to_filter = False
-     reason = ''
-     lower_title = title.lower()
-     if length > max_length:
-         reason += f'it is too long (>{max_length/60:.1f} min), '
-         to_filter = True
-     if any([f in lower_title for f in filtered_keywords]) \
-             and playlist_url not in playlist_and_channel_not_to_filter \
-             and 'to live' not in lower_title and 'alive' not in lower_title \
-             and url not in playlist_ok:
-         reason += 'it contains a filtered keyword, '
-         to_filter = True
-     return to_filter, reason
-
- def convert_mp4_to_mp3(path, verbose=True):
-     if verbose: print(f"Converting mp4 to mp3, in {path}\n")
-     assert '.mp4' == path[-4:]
-     os.system(f'ffmpeg -i "{path}" -loglevel panic -y -ac 1 -ar {int(RATE_AUDIO_SAVE)} "{path[:-4] + ".mp3"}" ')
-     os.remove(path)
-     if verbose: print('\tDone.')
-
- def pipeline_video(video, playlist_path, filename):
-     # extract best stream for this video
-     stream, kbps = extract_best_stream(video.streams)
-     stream.download(output_path=playlist_path, filename=filename + '.mp4')
-     # convert to mp3
-     convert_mp4_to_mp3(playlist_path + filename + '.mp4', verbose=False)
-     return kbps
-
- def extract_best_stream(streams):
-     # extract best audio stream
-     stream_out = streams.get_audio_only()
-     kbps = int(stream_out.abr[:-4])
-     return stream_out, kbps
-
- def get_title_and_length(video):
-     title = video.title
-     filename = slugify(title)
-     length = video.length
-     return title, filename, length, video.metadata
-
-
- def url2audio(playlist_path, video_url=None, video=None, playlist_url='', apply_filters=False, verbose=False, level=0):
-     assert video_url is not None or video is not None, 'needs either video or url'
-     error_msg = 'Error in loading video?'
-     audio_path = None  # defined up front so the except branch can check it safely
-     try:
-         if not video:
-             video = YouTube(video_url)
-         error_msg += ' Nope. In extracting title and length?'
-         title, filename, length, video_meta_data = get_title_and_length(video)
-         if apply_filters:
-             to_filter, reason = should_be_filtered(title, length, video_url, playlist_url, MAX_LEN)
-         else:
-             to_filter = False
-         if not to_filter:
-             audio_path = playlist_path + filename + ".mp3"
-             if verbose: print(' ' * level + f'Downloading {title}, Url: {video_url}')
-             if not os.path.exists(audio_path):
-                 if length > MAX_LEN and verbose: print(' ' * (level + 2) + f'Long video ({int(length/60)} min), will be cut after {int(MAX_LEN/60)} min.')
-                 error_msg += ' Nope. In pipeline video?'
-                 kbps = None
-                 for _ in range(5):
-                     try:
-                         kbps = pipeline_video(video, playlist_path, filename)
-                         break
-                     except:
-                         pass
-                 assert kbps is not None
-                 error_msg += ' Nope. In dict filling?'
-                 data = dict(title=title, filename=filename, length=length, kbps=kbps, url=video_url, meta=video_meta_data)
-                 error_msg += ' Nope. '
-             else:
-                 if verbose: print(' ' * (level + 2) + 'Song already downloaded')
-                 data = None
-             return audio_path, data, ''
-         else:
-             return None, None, f'Filtered because {reason}'
-     except:
-         if verbose: print(' ' * (level + 2) + f'Download failed with error {error_msg}')
-         if audio_path is not None and os.path.exists(audio_path):
-             os.remove(audio_path)
-         return None, None, error_msg + ' Yes.'
-
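
A hypothetical call to the download helper above; the output directory and video URL are placeholders:

    audio_path, data, error = url2audio('/tmp/playlist/',
                                        video_url='https://www.youtube.com/watch?v=XXXXXXXXXXX',
                                        apply_filters=True, verbose=True)
    if error:
        print('skipped:', error)
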
src/music/representation_analysis/__init__.py DELETED
File without changes
src/music/representation_analysis/analyze_rep.py DELETED
@@ -1,146 +0,0 @@
- import numpy as np
- from sklearn.cluster import KMeans
- from sklearn.neighbors import NearestNeighbors
- from sklearn.manifold import TSNE
- from src.music.utils import get_all_subfiles_with_extension
- import matplotlib.pyplot as plt
- import pickle
- import random
- # import umap
- import os
- from shutil import copy
- # install numba==0.51.2
- # keyword = '32_represented'
- # rep_path = f"/home/cedric/Documents/pianocktail/data/music/{keyword}/"
- # plot_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/plots/'
- # neighbors_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/neighbors/'
- interpolation_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/interpolation/'
- keyword = 'b256_r128_represented'
- rep_path = f"/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/{keyword}/"
- plot_path = '/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/plots/'
- neighbors_path = f'/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/neighbors_{keyword}/'
- os.makedirs(neighbors_path, exist_ok=True)
-
- def extract_all_reps(rep_path):
-     all_rep_path = get_all_subfiles_with_extension(rep_path, max_depth=3, extension='.txt', current_depth=0)
-     all_data = []
-     new_all_rep_path = []
-     for i_r, r in enumerate(all_rep_path):
-         if 'mean_std' not in r:
-             all_data.append(np.loadtxt(r))
-             assert len(all_data[-1]) == 128
-             new_all_rep_path.append(r)
-     data = np.array(all_data)
-     to_save = dict(reps=data,
-                    paths=new_all_rep_path)
-     with open(rep_path + 'music_reps_unnormalized.pickle', 'wb') as f:
-         pickle.dump(to_save, f)
-     for sample_size in [100, 200, 500, 1000, 2000, 5000]:
-         if sample_size < len(data):
-             inds = np.arange(len(data))
-             np.random.shuffle(inds)
-             to_save = dict(reps=data[inds[:sample_size]],
-                            paths=np.array(all_rep_path)[inds[:sample_size]])
-             with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'wb') as f:
-                 pickle.dump(to_save, f)
-
- def load_reps(rep_path, sample_size=None):
-     if sample_size:
-         with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f:
-             data = pickle.load(f)
-     else:
-         with open(rep_path + 'music_reps_unnormalized.pickle', 'rb') as f:
-             data = pickle.load(f)
-     reps = data['reps']
-     # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']]
-     playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']]
-     n_data, dim_data = reps.shape
-     return reps, data['paths'], playlists, n_data, dim_data
-
-
- def plot_tsne(reps, playlist_indexes, playlist_colors):
-     tsne_reps = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(reps)
-     plt.figure()
-     keys_to_print = ['spot_piano_solo_blues', 'itsremco', 'piano_solo_classical',
-                      'piano_solo_pop', 'piano_jazz_unspecified', 'spot_piano_solo_jazz_1', 'piano_solo_jazz_latin']
-     keys_to_print = playlist_indexes.keys()
-     for k in sorted(keys_to_print):
-         if k in playlist_indexes.keys():
-             # plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, label=k, alpha=0.5)
-             plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
-     plt.legend()
-     plt.savefig(plot_path + f'tsne_{keyword}.png')
-     fig = plt.gcf()
-     plt.close(fig)
-     # umap_reps = umap.UMAP().fit_transform(reps)
-     # plt.figure()
-     # for k in sorted(keys_to_print):
-     #     if k in playlist_indexes.keys():
-     #         plt.scatter(umap_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5)
-     # plt.legend()
-     # plt.savefig(plot_path + f'umap_{keyword}.png')
-     # fig = plt.gcf()
-     # plt.close(fig)
-     return tsne_reps  # , umap_reps
-
- def get_playlist_indexes(playlists):
-     playlist_indexes = dict()
-     for i in range(n_data):
-         if playlists[i] not in playlist_indexes.keys():
-             playlist_indexes[playlists[i]] = [i]
-         else:
-             playlist_indexes[playlists[i]].append(i)
-     for k in playlist_indexes.keys():
-         playlist_indexes[k] = np.array(playlist_indexes[k])
-     set_playlists = sorted(set(playlists))
-     playlist_colors = dict(zip(set_playlists, ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(len(set_playlists))]))
-     return set_playlists, playlist_indexes, playlist_colors
-
- def convert_rep_path_midi_path(rep_path):
-     # playlist = rep_path.split(f'_{keyword}/')[0].split('/')[-1]
-     playlist = rep_path.split(f'{keyword}')[1].split('/')[1].replace('_represented', '')
-     midi_path = "/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/processed/" + playlist + '_processed/'
-     filename = rep_path.split(f'{keyword}')[1].split('/')[2].split('_represented.txt')[0] + '_processed.mid'
-     # filename = rep_path.split(f'_{keyword}/')[-1].split(f'_{keyword}')[0] + '_processed.mid'
-     midi_path = midi_path + filename
-     assert os.path.exists(midi_path), midi_path
-     return midi_path
-
- def sample_nn(reps, rep_paths, playlists, n_samples=30):
-     nn_model = NearestNeighbors(n_neighbors=6, metric='cosine')
-     nn_model.fit(reps)
-     indexes = np.arange(len(reps))
-     np.random.shuffle(indexes)
-     for i, ind in enumerate(indexes[:n_samples]):
-         out = nn_model.kneighbors(reps[ind].reshape(1, -1))[1][0][1:]
-         midi_path = convert_rep_path_midi_path(rep_paths[ind])
-         copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[ind]}_target.mid')
-         for i_n, neighbor in enumerate(out):
-             midi_path = convert_rep_path_midi_path(rep_paths[neighbor])
-             copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[neighbor]}_neighbor_{i_n}.mid')
-
- def interpolate(reps, rep_paths, path):
-     files = os.listdir(path)
-     bounds = [f for f in files if 'interpolation' not in f]
-     b_reps = [np.loadtxt(path + f) for f in bounds]
-     nn_model = NearestNeighbors(n_neighbors=6)
-     nn_model.fit(reps)
-     reps = [alpha * b_reps[0] + (1 - alpha) * b_reps[1] for alpha in np.linspace(0, 1., 5)]
-     copy(convert_rep_path_midi_path(path + bounds[1]), path + 'interpolation_0.mid')
-     copy(convert_rep_path_midi_path(path + bounds[0]), path + 'interpolation_1.mid')
-     for alpha, rep in zip(np.linspace(0, 1, 5)[1:-1], reps[1:-1]):
-         dists, indexes = nn_model.kneighbors(rep.reshape(1, -1))
-         if dists.flatten()[0] == 0:
-             nn = indexes.flatten()[1]
-         else:
-             nn = indexes.flatten()[0]
-         midi_path = convert_rep_path_midi_path(rep_paths[nn])
-         copy(midi_path, path + f'interpolation_{alpha}.mid')
-
- if __name__ == '__main__':
-     extract_all_reps(rep_path)
-     reps, rep_paths, playlists, n_data, dim_data = load_reps(rep_path)
-     set_playlists, playlist_indexes, playlist_colors = get_playlist_indexes(playlists)
-     # interpolate(reps, rep_paths, interpolation_path + 'trial_1/')
-     sample_nn(reps, rep_paths, playlists)
-     tsne_reps = plot_tsne(reps, playlist_indexes, playlist_colors)  # plot_tsne returns only the t-SNE projection (umap is commented out)
-
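
For clarity, the interpolation above computes rep(alpha) = alpha * r0 + (1 - alpha) * r1 and snaps each blend to its nearest real representation, so every intermediate point maps back to an actual MIDI clip. A toy numpy version of that snap, with random vectors standing in for the real 128-d representations:

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    dataset = np.random.randn(100, 128)           # stand-in for the real representation matrix
    r0, r1 = dataset[0], dataset[1]
    nn = NearestNeighbors(n_neighbors=1).fit(dataset)
    for alpha in np.linspace(0, 1, 5):
        blend = alpha * r0 + (1 - alpha) * r1
        idx = nn.kneighbors(blend.reshape(1, -1))[1][0][0]   # index of the closest real clip
        print(f'alpha={alpha:.2f} -> dataset item {idx}')
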
src/music/representation_learning/__init__.py DELETED
File without changes
src/music/representation_learning/mlm_pretrain/__init__.py DELETED
File without changes
src/music/representation_learning/mlm_pretrain/data_collators.py DELETED
@@ -1,180 +0,0 @@
- from typing import Any, Dict, List, Optional, Tuple, Union
- from transformers.data.data_collator import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, BatchEncoding, DataCollatorForPermutationLanguageModeling
- from dataclasses import dataclass
-
-
- def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
-     """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
-     import numpy as np
-     import torch
-
-     # Tensorize if necessary.
-     if isinstance(examples[0], (list, tuple, np.ndarray)):
-         examples = [torch.tensor(e, dtype=torch.long) for e in examples]
-
-     length_of_first = examples[0].size(0)
-
-     # Check if padding is necessary.
-     are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-     if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
-         return torch.stack(examples, dim=0)
-
-     # If yes, check if we have a `pad_token`.
-     if tokenizer._pad_token is None:
-         raise ValueError(
-             "You are attempting to pad samples but the tokenizer you are using"
-             f" ({tokenizer.__class__.__name__}) does not have a pad token."
-         )
-
-     # Creating the full tensor and filling it with our data.
-     max_length = max(x.size(0) for x in examples)
-     if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
-         max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
-     result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
-     for i, example in enumerate(examples):
-         if tokenizer.padding_side == "right":
-             result[i, : example.shape[0]] = example
-         else:
-             result[i, -example.shape[0]:] = example
-     return result
-
-
- @dataclass
- class DataCollatorForMusicModeling(DataCollatorForLanguageModeling):
-     """
-     Data collator used for masked music modeling. Inputs are dynamically padded to the maximum length of a batch if
-     they are not all of the same length. Masking operates at the note level: each note spans 5 consecutive tokens.
-     Args:
-         tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
-             The tokenizer used for encoding the data.
-         mlm (`bool`, *optional*, defaults to `True`):
-             Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
-             with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
-             tokens and the value to predict for the masked token.
-         mlm_probability (`float`, *optional*, defaults to 0.15):
-             The probability with which to (randomly) mask notes in the input, when `mlm` is set to `True`.
-         pad_to_multiple_of (`int`, *optional*):
-             If set will pad the sequence to a multiple of the provided value.
-         return_tensors (`str`):
-             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
-     <Tip>
-     For best performance, this data collator should be used with a dataset having items that are dictionaries or
-     BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
-     [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
-     </Tip>"""
-
-     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
-         # Handle dict or lists with proper padding and conversion to tensor.
-         if isinstance(examples[0], (dict, BatchEncoding)):
-             batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
-         else:
-             batch = {
-                 "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
-             }
-
-         # If special token mask has been preprocessed, pop it from the dict.
-         special_tokens_mask = batch.pop("special_tokens_mask", None)
-         if self.mlm:
-             batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
-                 batch["input_ids"], special_tokens_mask=special_tokens_mask
-             )
-         else:
-             labels = batch["input_ids"].clone()
-             if self.tokenizer.pad_token_id is not None:
-                 labels[labels == self.tokenizer.pad_token_id] = -100
-             batch["labels"] = labels
-         return batch
-
-     def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
-         """
-         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
-         Masking decisions are drawn per note (5 tokens) and expanded to token resolution.
-         """
-         import torch
-
-         labels = inputs.clone()
-         # We sample a few notes in each sequence for MLM training (with probability `self.mlm_probability`)
-         notes_shape = (labels.shape[0], labels.shape[1] // 5)
-         probability_matrix = torch.full(notes_shape, self.mlm_probability)
-         # if special_tokens_mask is None:
-         #     special_tokens_mask = [
-         #         self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
-         #     ]
-         #     special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
-         # else:
-         #     special_tokens_mask = special_tokens_mask.bool()
-
-         # probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
-         masked_notes_indices = torch.bernoulli(probability_matrix).bool()
-         masked_indices = torch.repeat_interleave(masked_notes_indices, repeats=5, dim=1)
-         labels[~masked_indices] = -100  # We only compute loss on masked tokens
-
-         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-         indices_notes_replaced = torch.bernoulli(torch.full(notes_shape, 0.8)).bool() & masked_notes_indices
-         indices_replaced = torch.repeat_interleave(indices_notes_replaced, repeats=5, dim=1)
-         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
-
-         # 10% of the time, we replace masked input tokens with a random word
-         indices_notes_random = torch.bernoulli(torch.full(notes_shape, 0.5)).bool() & masked_notes_indices & ~indices_notes_replaced
-         indices_random = torch.repeat_interleave(indices_notes_random, repeats=5, dim=1)
-         random_words = torch.randint(3, len(self.tokenizer), labels.shape, dtype=torch.long)
-         inputs[indices_random] = random_words[indices_random]
-
-         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
-         return inputs, labels
-
-
- @dataclass
- class DataCollatorForSpanMusicModeling(DataCollatorForLanguageModeling):
-     """
-     Data collator used for span-masked music modeling:
-     - collates batches of tensors, honoring their tokenizer's pad_token
-     - masks contiguous spans of notes (5 tokens per note) rather than independent tokens
-     """
-
-     def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
-         """
-         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
-         0. Start from the beginning of the sequence by setting `cur_len = 0` (number of notes processed so far).
-         1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of notes to be masked)
-         2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
-            masked
-         3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
-            span_length]` and mask notes `start_index:start_index + span_length`
-         4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are notes remaining in the
-            sequence to be processed), repeat from Step 1.
-         """
-
-         import torch
-
-         labels = inputs.clone()
-         # We sample a few notes in each sequence for MLM training (with probability `self.mlm_probability`)
-         notes_shape = (labels.shape[0], labels.shape[1] // 5)
-         masked_notes_indices = torch.full(notes_shape, 0, dtype=torch.bool)
-
-         for i in range(labels.size(0)):
-             # Start from the beginning of the sequence by setting `cur_len = 0` (number of notes processed so far).
-             cur_len = 0
-             max_len = notes_shape[1]
-
-             while cur_len < max_len:
-                 # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of notes to be masked)
-                 span_length = torch.randint(1, 5 + 1, (1,)).item()
-                 # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
-                 context_length = int(span_length / self.mlm_probability)
-                 # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask notes `start_index:start_index + span_length`
-                 start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
-                 masked_notes_indices[i, start_index: start_index + span_length] = 1
-                 # Set `cur_len = cur_len + context_length`
-                 cur_len += context_length
-
-         masked_indices = torch.repeat_interleave(masked_notes_indices, repeats=5, dim=1)
-         labels[~masked_indices] = -100  # We only compute loss on masked tokens
-
-         inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
-
-         return inputs, labels
-
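
Both collators above treat a note as 5 consecutive tokens: masking decisions are drawn at note resolution and stretched to token resolution with `repeat_interleave`, so whole notes are masked, never partial ones. For the span collator, with `mlm_probability=0.15` and a sampled `span_length=3`, the reserved context is `int(3 / 0.15) = 20` notes, so roughly 15% of notes end up masked. A toy check of the mask expansion (shapes are illustrative):

    import torch

    torch.manual_seed(0)
    batch, seq_len = 2, 20                        # 4 notes x 5 tokens per sequence
    note_mask = torch.bernoulli(torch.full((batch, seq_len // 5), 0.15)).bool()
    token_mask = torch.repeat_interleave(note_mask, repeats=5, dim=1)
    assert token_mask.shape == (batch, seq_len)   # one decision per note, expanded x5
    print(note_mask.int())
    print(token_mask.int())
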
src/music/representation_learning/mlm_pretrain/models/music-bert/config.json DELETED
@@ -1,20 +0,0 @@
- {
-   "attention_probs_dropout_prob": 0.1,
-   "gradient_checkpointing": false,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "intermediate_size": 3072,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 512,
-   "model_type": "bert",
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "position_embedding_type": "relative_key_query",
-   "transformers_version": "4.8.2",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 30522
- }
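
A hypothetical way to instantiate a fresh model from a config like this one (assumes transformers is installed; the directory path is a placeholder for wherever the config.json lives):

    from transformers import BertConfig, BertForMaskedLM

    config = BertConfig.from_pretrained('models/music-bert')  # reads config.json from the directory
    model = BertForMaskedLM(config)                           # fresh weights with an MLM head
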
src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json DELETED
@@ -1 +0,0 @@
- {"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"WordLevel","vocab":{"[PAD]":0,"[MASK]":1,"[UNK]":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":"[UNK]"}}
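
A sketch of loading this word-level tokenizer with the tokenizers library (the path is a placeholder). Note that the config above declares `vocab_size: 30522` (the BERT default), which is larger than the tokenizer's actual vocabulary; presumably the embedding matrix is sized or resized accordingly at training time:

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file('models/music-bert/tokenizer.json')
    print(tok.get_vocab_size())  # 184 entries: [PAD], [MASK], [UNK] plus note tokens "2"-"182"
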
src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json DELETED
@@ -1,20 +0,0 @@
- {
-   "attention_probs_dropout_prob": 0.1,
-   "gradient_checkpointing": false,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "intermediate_size": 3072,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 512,
-   "model_type": "bert",
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "position_embedding_type": "relative_key_query",
-   "transformers_version": "4.8.2",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 30522
- }
src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json DELETED
@@ -1 +0,0 @@
- {"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"WordLevel","vocab":{"[PAD]":0,"[MASK]":1,"[UNK]":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":"[UNK]"}}
src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json DELETED
@@ -1,56 +0,0 @@
- {
-   "architectures": [
-     "T5WithLMHeadModel"
-   ],
-   "d_ff": 2048,
-   "d_kv": 64,
-   "d_model": 512,
-   "decoder_start_token_id": 0,
-   "dropout_rate": 0.1,
-   "eos_token_id": 1,
-   "feed_forward_proj": "relu",
-   "gradient_checkpointing": false,
-   "initializer_factor": 1.0,
-   "is_encoder_decoder": true,
-   "layer_norm_epsilon": 1e-06,
-   "model_type": "t5",
-   "n_positions": 512,
-   "num_decoder_layers": 6,
-   "num_heads": 8,
-   "num_layers": 6,
-   "output_past": true,
-   "pad_token_id": 0,
-   "relative_attention_num_buckets": 32,
-   "task_specific_params": {
-     "summarization": {
-       "early_stopping": true,
-       "length_penalty": 2.0,
-       "max_length": 200,
-       "min_length": 30,
-       "no_repeat_ngram_size": 3,
-       "num_beams": 4,
-       "prefix": "summarize: "
-     },
-     "translation_en_to_de": {
-       "early_stopping": true,
-       "max_length": 300,
-       "num_beams": 4,
-       "prefix": "translate English to German: "
-     },
-     "translation_en_to_fr": {
-       "early_stopping": true,
-       "max_length": 300,
-       "num_beams": 4,
-       "prefix": "translate English to French: "
-     },
-     "translation_en_to_ro": {
-       "early_stopping": true,
-       "max_length": 300,
-       "num_beams": 4,
-       "prefix": "translate English to Romanian: "
-     }
-   },
-   "transformers_version": "4.8.2",
-   "use_cache": true,
-   "vocab_size": 32128
- }