"""Train, evaluate, plot and reload the multi-head cocktail model.

The module builds the experiment parameters, runs train / eval epochs over the
auxiliary prediction heads, saves the best checkpoint and diagnostic plots, and
exposes `get_model` to reload a trained model as a prediction function.
"""
import torch
torch.manual_seed(0)
import torch.utils
from torch.utils.data import DataLoader
import torch.distributions
import torch.nn as nn
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
import json
import pandas as pd
import numpy as np
import os
from src.cocktails.representation_learning.multihead_model import get_multihead_model
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
from resource import getrusage
from resource import RUSAGE_SELF
import gc

gc.collect(2)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Keys of `params` that are forwarded, in this order, to `get_multihead_model`.
_MODEL_PARAM_KEYS = ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout",
                     "auxiliaries_dict", "hidden_dims_decoder"]


def get_params():
    """Build the experiment parameters and create the experiment directory.

    Reads the cocktail CSV, assembles the hyper-parameter dictionary, creates a
    fresh save directory (via `compute_expe_name_and_save_path`) and dumps the
    parameters to `params.json` in that directory. `category_encodings` is
    removed before the dump and re-added afterwards by `complete_params`.

    Returns:
        dict: the parameters, completed with `cocktail_reps`, `raw_data` and
        `category_encodings`.
    """
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
    num_ingredients = len(ingredient_set)
    rep_keys = get_bunch_of_rep_keys()['custom']
    ing_keys = [k.split(' ')[1] for k in rep_keys]
    ing_keys.remove('volume')
    nb_ing_categories = len(set(ingredient_profiles['type']))
    category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))

    # One entry per auxiliary head. Insertion order matters: it is reused to
    # build the experiment name. The trailing numbers are kept from the
    # original source; presumably they are previously tried weights.
    auxiliaries_dict = dict(
        categories=dict(weight=5, type='classif', final_activ=None,
                        dim_output=len(set(data['subcategory']))),  # 0.5
        glasses=dict(weight=0.5, type='classif', final_activ=None,
                     dim_output=len(set(data['glass']))),  # 0.1
        prep_type=dict(weight=0.1, type='classif', final_activ=None,
                       dim_output=len(set(data['category']))),  # 1
        cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),  # 1
        volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),  # 1
        taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),  # 1
        ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None,
                                  dim_output=num_ingredients),  # 10
        ingredients_quantities=dict(weight=0, type='regression', final_activ=None,
                                    dim_output=num_ingredients),
    )
    params = dict(
        trial_id='test',
        save_path=EXPERIMENT_PATH + "/multihead_model/",
        nb_epochs=500,
        print_every=50,
        plot_every=50,
        batch_size=128,
        lr=0.001,
        dropout=0.,
        nb_epoch_switch_beta=600,
        latent_dim=10,
        beta_vae=0.2,
        ing_keys=ing_keys,
        nb_ingredients=len(ingredient_set),
        hidden_dims_ingredients=[128],
        hidden_dims_cocktail=[64],
        hidden_dims_decoder=[32],
        agg='mean',
        activation='relu',
        auxiliaries_dict=auxiliaries_dict,
        category_encodings=category_encodings,
    )

    # The representation of a single ingredient gives the per-ingredient dimension.
    water_rep, indexes_to_normalize = get_representation_from_ingredient(
        ingredients=['water'],
        quantities=[1],
        max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)),
        index=0,
        params=params,
    )
    dim_rep_ingredient = water_rep.size
    params['indexes_ing_to_normalize'] = indexes_to_normalize
    params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
    params['dim_rep_ingredient'] = dim_rep_ingredient
    params['input_dim'] = params['nb_ingredients']
    params = compute_expe_name_and_save_path(params)

    del params['category_encodings']  # numpy arrays: removed so that the dict can be dumped
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)

    params = complete_params(params)
    return params


def complete_params(params):
    """Add the entries that are not stored in `params.json` to `params`.

    Sets `cocktail_reps` (loaded from FULL_COCKTAIL_REP_PATH), `raw_data` (the
    cocktail CSV as a DataFrame) and `category_encodings` (one-hot vectors per
    ingredient type). The dict is modified in place and also returned.
    """
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
    ingredient_types = sorted(set(ingredient_profiles['type']))
    category_encodings = dict(zip(ingredient_types, np.eye(len(ingredient_types))))
    params['cocktail_reps'] = cocktail_reps
    params['raw_data'] = data
    params['category_encodings'] = category_encodings
    return params


def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
    """Compute the loss, accuracy and extra metrics of every auxiliary head.

    Args:
        loss_functions: dict mapping a head name to its loss module.
        auxiliaries: dict mapping a head name to its ground-truth tensor.
        auxiliaries_str: head names, in the same order as `outputs`.
        outputs: list of head outputs; the 'volume' entry is flattened in place.
        data: data loader; `data.dataset` provides the ingredient-quantity
            statistics used by the 'ingredients_presence' head.

    Returns:
        (losses, accuracies, other_metrics): three dicts keyed by head name
        (or by head name plus a suffix for `other_metrics`). Accuracy is NaN
        for MSE heads.

    Raises:
        ValueError: if a loss module is not cross-entropy, BCE-with-logits or MSE.
    """
    losses = dict()
    accuracies = dict()
    other_metrics = dict()
    for i_k, k in enumerate(auxiliaries_str):
        loss_fn = loss_functions[k]
        is_bce = isinstance(loss_fn, nn.BCEWithLogitsLoss)

        if k == 'volume':
            outputs[i_k] = outputs[i_k].flatten()
        head_output = outputs[i_k]
        ground_truth = auxiliaries[k]

        # compute loss, casting the operands to what the loss module expects
        if ground_truth.dtype == torch.float64:
            losses[k] = loss_fn(head_output, ground_truth.float()).float()
        elif ground_truth.dtype == torch.int64:
            if is_bce:
                losses[k] = loss_fn(head_output.float(), ground_truth.float()).float()
            else:
                losses[k] = loss_fn(head_output.float(), ground_truth.long()).float()
        else:
            losses[k] = loss_fn(head_output, ground_truth).float()

        # compute accuracies
        if isinstance(loss_fn, nn.CrossEntropyLoss):
            bs, n_options = head_output.shape
            predicted = head_output.argmax(dim=1).detach().numpy()
            true = ground_truth.int().detach().numpy()
            confusion_matrix = np.zeros([n_options, n_options])
            # unbuffered add so that repeated (true, predicted) pairs all count
            np.add.at(confusion_matrix, (true, predicted), 1)
            acc = confusion_matrix.diagonal().sum() / bs
            # normalise each non-empty row so that rows sum to 1
            row_sums = confusion_matrix.sum(axis=1, keepdims=True)
            np.divide(confusion_matrix, row_sums, out=confusion_matrix, where=row_sums != 0)
            other_metrics[k + '_confusion'] = confusion_matrix
            accuracies[k] = np.mean(predicted == true)
            # sanity check: both ways of computing the accuracy must agree
            assert abs(acc - accuracies[k]) < 1e-5
        elif is_bce:
            assert k == 'ingredients_presence'
            outputs_rescaled = (head_output.detach().numpy() * data.dataset.std_ing_quantities
                                + data.dataset.mean_ing_quantities)
            predicted_presence = (outputs_rescaled > 0).astype(bool)
            presence = ground_truth.detach().numpy().astype(bool)
            other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence, ~presence))
            other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence, presence))
            accuracies[k] = np.mean(predicted_presence == presence)  # accuracy for multi class labeling
        elif isinstance(loss_fn, nn.MSELoss):
            accuracies[k] = np.nan
        else:
            raise ValueError
    return losses, accuracies, other_metrics


def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
    """Add the ingredient-quantity reconstruction errors to `aux_other_metrics`.

    The absolute error between `ingredient_quantities` and `x_hat` is scaled by
    `data.dataset.max_ing_quantities`, then averaged per sample separately over
    the ingredients that are present (quantity > 0) and absent, and finally
    averaged over the batch. A sample with no present (or no absent) ingredient
    yields NaN for the corresponding metric.

    Returns:
        dict: `aux_other_metrics`, updated in place with the keys
        'ing_q_abs_loss_when_present' and 'ing_q_abs_loss_when_absent'.
    """
    ing_q = ingredient_quantities.detach().numpy()  # * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    ing_presence = (ing_q > 0)
    x_hat = x_hat.detach().numpy()
    # x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities
    abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities
    # abs_diff = np.abs(ing_q - x_hat)
    ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], []
    for i in range(ingredient_quantities.shape[0]):
        ing_q_abs_loss_when_present.append(np.mean(abs_diff[i][ing_presence[i]]))
        ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i][~ing_presence[i]]))
    aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present)
    aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent)
    return aux_other_metrics


def run_epoch(opt, train, model, data, loss_functions, weights, params):
    """Run one pass over `data`, optimising the model when `train` is true.

    Args:
        opt: optimizer, only stepped when `train` is true.
        train: whether to train (True) or only evaluate (False).
        model: the multi-head model; its forward pass returns
            `(z, outputs, auxiliaries_str)`.
        data: data loader yielding batches in which index 0 is read for the
            batch size, index 2 holds the ingredient quantities, index 4 the
            dict of auxiliary targets and the last index the taste-validity mask.
        loss_functions: dict mapping a head name to its loss module.
        weights: dict mapping a head name to its weight in the global loss.
        params: experiment parameters; only `auxiliaries_dict` is read.

    Returns:
        (model, losses, accuracies, other_metrics), where the three dicts hold
        per-batch values averaged over the epoch. Loss keys that are never
        filled ('kld_loss', 'mse_loss', 'vae_loss', 'volume_loss') are NaN.
    """
    if train:
        model.train()
    else:
        model.eval()

    aux_keys = list(params['auxiliaries_dict'].keys())

    # prepare logging of losses
    losses = dict(kld_loss=[], mse_loss=[], vae_loss=[], volume_loss=[], global_loss=[])
    accuracies = dict()
    other_metrics = dict()
    for aux in aux_keys:
        losses[aux] = []
        accuracies[aux] = []

    if train:
        opt.zero_grad()
    for d in data:
        batch_size = d[0].shape[0]
        ingredient_quantities = d[2]
        auxiliaries = d[4]
        for k in auxiliaries.keys():
            if auxiliaries[k].dtype == torch.float64:
                auxiliaries[k] = auxiliaries[k].float()
        taste_valid = d[-1]

        # No autograd graph is needed during evaluation.
        with torch.set_grad_enabled(bool(train)):
            z, outputs, auxiliaries_str = model.forward(ingredient_quantities.float())

            # get auxiliary losses and accuracies
            aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(
                loss_functions, auxiliaries, auxiliaries_str, outputs, data)
            aux_other_metrics = compute_metric_output(
                aux_other_metrics, data, ingredient_quantities,
                outputs[auxiliaries_str.index('ingredients_quantities')])

            # The taste loss is recomputed on the samples flagged as valid only.
            indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
            if indexes_taste_valid.size > 0:
                outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
                gt = auxiliaries['taste_reps'][indexes_taste_valid]
                # factor on the loss: if same ratio as actual dataset factor = 1;
                # if there is less data, then the factor decreases, more data, it increases
                # (0.3 is presumably the dataset-wide ratio of taste-valid samples; confirm)
                factor_loss = indexes_taste_valid.size / (0.3 * batch_size)
                aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
            else:
                aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
                # NOTE(review): placement reconstructed from a whitespace-collapsed
                # source; assumed to belong to this `else` branch.
                aux_accuracies['taste_reps'] = 0

            # aggregate losses
            global_loss = torch.sum(torch.cat(
                [torch.atleast_1d(aux_losses[k] * weights[k]) for k in aux_keys]))

        if train:
            global_loss.backward()
            opt.step()
            opt.zero_grad()

        # logging
        losses['global_loss'].append(float(global_loss))
        for k in aux_keys:
            losses[k].append(float(aux_losses[k]))
            accuracies[k].append(float(aux_accuracies[k]))
        for k, value in aux_other_metrics.items():
            other_metrics.setdefault(k, []).append(value)

    # An empty log averages to NaN; spelled out to avoid numpy's empty-slice warning.
    for k in losses.keys():
        losses[k] = np.mean(losses[k]) if losses[k] else np.nan
    for k in accuracies.keys():
        accuracies[k] = np.mean(accuracies[k]) if accuracies[k] else np.nan
    for k in other_metrics.keys():
        other_metrics[k] = np.mean(other_metrics[k], axis=0)
    return model, losses, accuracies, other_metrics


def prepare_data_and_loss(params):
    """Create the train / test data loaders and one loss module per head.

    Classification heads use a cross-entropy loss weighted by the class weights
    stored on the training dataset, multi-label heads a BCE-with-logits loss,
    and regression heads an MSE loss.

    Returns:
        (loss_functions, train_data_loader, test_data_loader, weights)

    Raises:
        ValueError: on an unknown head type or an unknown classification head.
    """
    train_data = MyDataset(split='train', params=params)
    test_data = MyDataset(split='test', params=params)
    train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
    # NOTE(review): the test loader is shuffled too, so per-batch eval averages
    # can differ slightly between epochs; kept as in the original.
    test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True)

    # Name of the training-set attribute holding the class weights of each head.
    class_weight_attributes = dict(glasses='glasses_weights',
                                   prep_type='prep_types_weights',
                                   categories='categories_weights')
    loss_functions = dict()
    weights = dict()
    for k in sorted(params['auxiliaries_dict'].keys()):
        head_type = params['auxiliaries_dict'][k]['type']
        if head_type == 'classif':
            if k not in class_weight_attributes:
                raise ValueError
            classif_weights = getattr(train_data, class_weight_attributes[k])
            loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
        elif head_type == 'multiclassif':
            loss_functions[k] = nn.BCEWithLogitsLoss()
        elif head_type == 'regression':
            loss_functions[k] = nn.MSELoss()
        else:
            raise ValueError
        weights[k] = params['auxiliaries_dict'][k]['weight']
    return loss_functions, train_data_loader, test_data_loader, weights


def print_losses(train, losses, accuracies, other_metrics):
    """Print the epoch losses, accuracies and scalar metrics to stdout.

    Confusion matrices (metric keys containing 'confusion') are skipped.
    """
    keyword = 'Train' if train else 'Eval'
    print(f'\t{keyword} logs:')
    keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss']
    for k in keys:
        print(f'\t\t{k} - Loss: {losses[k]:.2f}')
    for k in sorted(accuracies.keys()):
        print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}')
    for k in sorted(other_metrics.keys()):
        if 'confusion' not in k:
            print(f'\t\t{k} - {other_metrics[k]:.2f}')


def run_experiment(params, verbose=True):
    """Train the multi-head model and keep the best checkpoint.

    The model is evaluated once before training, then trained and evaluated
    every epoch. Whenever the evaluation global loss improves, the state dict
    is saved to `checkpoint_best.save` in `params['save_path']`. Plots are
    refreshed every `params['plot_every']` epochs.

    Args:
        params: experiment parameters, as returned by `get_params`.
        verbose: whether to print logs every `params['print_every']` epochs.

    Returns:
        The model in its state after the last epoch (not necessarily the best one).
    """
    loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params)

    model_params = [params[k] for k in _MODEL_PARAM_KEYS]
    model = get_multihead_model(*model_params)
    opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])

    all_train_losses = []
    all_eval_losses = []
    all_train_accuracies = []
    all_eval_accuracies = []
    all_eval_other_metrics = []
    all_train_other_metrics = []
    best_loss = np.inf

    # Evaluation of the untrained model (epoch 0).
    model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(
        opt=opt, train=False, model=model, data=test_data_loader,
        loss_functions=loss_functions, weights=weights, params=params)
    all_eval_losses.append(eval_losses)
    all_eval_accuracies.append(eval_accuracies)
    all_eval_other_metrics.append(eval_other_metrics)
    if verbose:
        print('\n--------\nEpoch #0')
        print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses,
                     other_metrics=eval_other_metrics)

    for epoch in range(params['nb_epochs']):
        should_print = verbose and (epoch + 1) % params['print_every'] == 0
        if should_print:
            print(f'\n--------\nEpoch #{epoch+1}')

        model, train_losses, train_accuracies, train_other_metrics = run_epoch(
            opt=opt, train=True, model=model, data=train_data_loader,
            loss_functions=loss_functions, weights=weights, params=params)
        if should_print:
            print_losses(train=True, accuracies=train_accuracies, losses=train_losses,
                         other_metrics=train_other_metrics)

        model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(
            opt=opt, train=False, model=model, data=test_data_loader,
            loss_functions=loss_functions, weights=weights, params=params)
        if should_print:
            print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses,
                         other_metrics=eval_other_metrics)

        if eval_losses['global_loss'] < best_loss:
            best_loss = eval_losses['global_loss']
            if verbose:
                print(f'Saving new best model with loss {best_loss:.2f}')
            # Saved regardless of `verbose`.
            torch.save(model.state_dict(), params['save_path'] + 'checkpoint_best.save')

        # log
        all_train_losses.append(train_losses)
        all_train_accuracies.append(train_accuracies)
        all_eval_losses.append(eval_losses)
        all_eval_accuracies.append(eval_accuracies)
        all_eval_other_metrics.append(eval_other_metrics)
        all_train_other_metrics.append(train_other_metrics)

        # if epoch == params['nb_epoch_switch_beta']:
        #     params['beta_vae'] = 2.5
        #     params['auxiliaries_dict']['prep_type']['weight'] /= 10
        #     params['auxiliaries_dict']['glasses']['weight'] /= 10

        if (epoch + 1) % params['plot_every'] == 0:
            plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                         all_eval_losses, all_eval_accuracies, all_eval_other_metrics,
                         params['plot_path'], weights)
    return model


def _save_line_plot(title, x_values, series, y_limits, file_path):
    """Draw one labelled curve per `(label, values)` pair and save the figure."""
    fig = plt.figure()
    plt.title(title)
    for label, values in series:
        plt.plot(x_values, values, label=label)
    plt.legend()
    plt.ylim(y_limits)
    plt.savefig(file_path, dpi=200)
    plt.close(fig)


def _save_confusion_plots(metrics, metrics_keys, plot_path, prefix):
    """Save one image per confusion matrix found in `metrics`."""
    for k in metrics_keys:
        if 'confusion' in k:
            fig = plt.figure()
            plt.title(k)
            plt.ylabel('True')
            plt.xlabel('Predicted')
            plt.imshow(metrics[k], vmin=0, vmax=1)
            plt.colorbar()
            plt.savefig(plot_path + f'{prefix}_{k}.png', dpi=200)
            plt.close(fig)


def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics, all_eval_losses,
                 all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
    """Save the learning curves and the latest confusion matrices as PNG files.

    For both the train and eval logs, four figures are written to `plot_path`:
    losses, accuracies, ingredient-presence errors and ingredient-quantity
    errors. Heads whose weight is 0 are left out of the loss and accuracy
    figures. The eval logs contain one more entry than the train logs (the
    evaluation done before training), hence the shifted x axis for train curves.

    Args:
        all_*_losses / all_*_accuracies / all_*_other_metrics: lists with one
            dict per logged epoch, as returned by `run_epoch`.
        plot_path: directory prefix (with trailing separator) for the images.
        weights: dict mapping a head name to its loss weight.
    """
    steps = np.arange(len(all_eval_accuracies))
    loss_keys = sorted(all_train_losses[0].keys())
    acc_keys = sorted(all_train_accuracies[0].keys())
    metrics_keys = sorted(all_train_other_metrics[0].keys())

    splits = (
        ('train', 'Train', steps[1:], all_train_losses, all_train_accuracies, all_train_other_metrics),
        ('eval', 'Eval', steps, all_eval_losses, all_eval_accuracies, all_eval_other_metrics),
    )
    for prefix, title_prefix, x_values, loss_log, acc_log, metric_log in splits:
        _save_line_plot(
            f'{title_prefix} losses', x_values,
            [(k, [entry[k] for entry in loss_log])
             for k in loss_keys if k not in weights or weights[k] != 0],
            [0, 4], plot_path + f'{prefix}_losses.png')
        _save_line_plot(
            f'{title_prefix} accuracies', x_values,
            [(k, [entry[k] for entry in acc_log]) for k in acc_keys if weights[k] != 0],
            [0, 1], plot_path + f'{prefix}_acc.png')
        _save_line_plot(
            f'{title_prefix} other metrics', x_values,
            [(k, [entry[k] for entry in metric_log])
             for k in metrics_keys if 'confusion' not in k and 'presence' in k],
            [0, 1], plot_path + f'{prefix}_ing_presence_errors.png')
        _save_line_plot(
            f'{title_prefix} other metrics', x_values,
            [(k, [entry[k] for entry in metric_log])
             for k in metrics_keys if 'confusion' not in k and 'presence' not in k],
            [0, 15], plot_path + f'{prefix}_ing_q_error.png')

    # Only the most recent confusion matrices are drawn.
    _save_confusion_plots(all_eval_other_metrics[-1], metrics_keys, plot_path, 'eval')
    _save_confusion_plots(all_train_other_metrics[-1], metrics_keys, plot_path, 'train')

    plt.close('all')


def get_model(model_path):
    """Load a trained model and wrap it in a prediction function.

    Reads `params.json`, `checkpoint_best.save` and `max_ing_quantities.txt`
    from `model_path` (a directory prefix with trailing separator).

    Returns:
        (predict, params): `predict(ing_qs, aux_str)` divides the ingredient
        quantities by the stored maxima, runs the model on them as a single
        sample and returns the output of head `aux_str` as a numpy array, or a
        list of arrays when `aux_str` is a list of head names. It raises
        ValueError for any other `aux_str` type.
    """
    with open(model_path + 'params.json', 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path
    model_chkpt = model_path + "checkpoint_best.save"
    model_params = [params[k] for k in _MODEL_PARAM_KEYS]
    model = get_multihead_model(*model_params)
    # Predictions are converted with `.numpy()`, so the weights are loaded on CPU.
    model.load_state_dict(torch.load(model_chkpt, map_location='cpu'))
    model.eval()
    max_ing_quantities = np.loadtxt(model_path + 'max_ing_quantities.txt')

    def predict(ing_qs, aux_str):
        # Out-of-place division: the caller's array must not be rescaled, and
        # integer inputs must be accepted.
        scaled_qs = np.asarray(ing_qs, dtype=float) / max_ing_quantities
        input_model = torch.FloatTensor(scaled_qs).reshape(1, -1)
        with torch.no_grad():
            _, outputs, auxiliaries_str = model.forward(input_model)
        if isinstance(aux_str, str):
            return outputs[auxiliaries_str.index(aux_str)].detach().numpy()
        elif isinstance(aux_str, list):
            return [outputs[auxiliaries_str.index(aux)].detach().numpy() for aux in aux_str]
        else:
            raise ValueError

    return predict, params


def compute_expe_name_and_save_path(params):
    """Derive a unique experiment directory from the hyper-parameters and create it.

    The directory name concatenates the trial id and the main hyper-parameters,
    followed by the first integer suffix for which no path exists yet. Both the
    directory and its `plots/` sub-directory are created, and `params` is
    updated in place with `save_path` and `plot_path`.

    Returns:
        dict: the updated `params`.
    """
    weights_str = '[' + ', '.join(
        f'{aux_params["weight"]}' for aux_params in params['auxiliaries_dict'].values()) + ']'

    save_path = params['save_path'] + params["trial_id"]
    save_path += f'_lr{params["lr"]}'
    save_path += f'_betavae{params["beta_vae"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_latentdim{params["latent_dim"]}'
    save_path += f'_hding{params["hidden_dims_ingredients"]}'
    save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
    save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
    save_path += f'_agg{params["agg"]}'
    save_path += f'_activ{params["activation"]}'
    save_path += f'_w{weights_str}'

    # Pick the first run index that is not taken yet.
    counter = 0
    while os.path.exists(save_path + f"_{counter}"):
        counter += 1
    save_path = save_path + f"_{counter}" + '/'
    params["save_path"] = save_path
    os.makedirs(save_path)
    os.makedirs(save_path + 'plots/')
    params['plot_path'] = save_path + 'plots/'
    print(f'logging to {save_path}')
    return params


if __name__ == '__main__':
    params = get_params()
    run_experiment(params)