import os
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from vae_model import get_gml_vae_models
from utils import get_dataloaders, compute_swd_loss
from src.music.config import MUSIC_REP_PATH
from src.cocktails.config import FULL_COCKTAIL_REP_PATH
import json
import argparse

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    print('Using GPUs')
else:
    print('Using CPUs')

# music_rep_path = "/home/cedric/Documents/pianocktail/data/music/represented_small/"
# music_rep_path = "/home/cedric/Documents/pianocktail/data/music/32_represented/reps.pickle"
music_rep_path = MUSIC_REP_PATH + "music_reps_normalized_meanstd.pickle"

LOSS = nn.CrossEntropyLoss()


def run_epoch(epoch, model, data, params, opt, train):
    # once the music pretraining phase is over, switch the samplers to the main batch size
    if epoch == params['n_epochs_music_pretrain']:
        print(f'Switching to bs: {params["batch_size"]}')
        for k in data.keys():
            prefix = 'train' if train else 'test'
            data[k].batch_sampler.update_epoch_size_and_batch(params[prefix + '_epoch_size'], params['batch_size'])
    if train:
        model.train()
    else:
        model.eval()
    keys_to_track = params['keys_to_track']
    losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
    step = 0
    cf_matrices_music = []
    cf_matrices_cocktail = []
    for i_batch, data_music, data_cocktail, data_music_lab, data_cocktail_lab, data_reg_grounding \
            in zip(range(len(data['music'])), data['music'], data['cocktail'],
                   data['music_labeled'], data['cocktail_labeled'], data['reg_grounding']):
        x_music, _ = data_music
        x_cocktail, _, contains_egg, contains_bubbles = data_cocktail
        x_music_lab, labels_music = data_music_lab
        x_cocktail_lab, labels_cocktail = data_cocktail_lab
        x_reg_music, x_reg_cocktail = data_reg_grounding
        step += x_music.shape[0]
        if train:
            opt.zero_grad()

        # weight more the examples that contain bubbles or egg in the mse computation
        bubbles_egg_weights = torch.ones([contains_bubbles.shape[0]])
        bubbles_egg_weights[contains_bubbles] += 1
        bubbles_egg_weights[contains_egg] += 3

        # vae
        x_hat_cocktail, z_cocktail, mu_cocktail, log_var_cocktail = model(x_cocktail, modality_in='cocktail', modality_out='cocktail')
        mse_loss_cocktail = torch.sum(((x_cocktail - x_hat_cocktail) ** 2).mean(axis=1) * bubbles_egg_weights) / bubbles_egg_weights.sum()
        # per-feature reconstruction errors on the bubble (-3) and egg (-1) dimensions
        if contains_bubbles.sum() > 0:
            bubble_mse = float(((x_cocktail - x_hat_cocktail) ** 2)[contains_bubbles, -3].mean())
        else:
            bubble_mse = np.nan
        if contains_egg.sum() > 0:
            egg_mse = float(((x_cocktail - x_hat_cocktail) ** 2)[contains_egg, -1].mean())
        else:
            egg_mse = np.nan
        # closed-form KL( N(mu, sigma^2) || N(0, I) ) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
        kld_loss_cocktail = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
        x_hat_music, z_music, mu_music, log_var_music = model(x_music, modality_in='music', modality_out='music')
        mse_loss_music = ((x_music - x_hat_music) ** 2).mean()
        kld_loss_music = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
        music_vae_loss = mse_loss_music + params['beta_vae'] * kld_loss_music
        cocktail_vae_loss = mse_loss_cocktail + params['beta_vae'] * kld_loss_cocktail
        vae_loss = cocktail_vae_loss + params['beta_music'] * music_vae_loss

        brb_kld_loss_cocktail, brb_kld_loss_music, brb_mse_loss_music, brb_mse_loss_cocktail, \
            brb_mse_latent_loss, brb_music_vae_loss, brb_vae_loss = [0] * 7
        if params['use_brb_vae']:
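            # Back-to-back ("brb") pass: encode the input modality, decode it into
            # the other modality, re-encode that intermediate output, and decode
            # back to the input modality. Judging from how the outputs are unpacked
            # below, forward_b2b returns the final reconstruction, the intermediate
            # decoding, and (mu, log_var, z) for both encoding steps; the latent
            # MSE then ties the two latent codes together.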
            # vae back to back
            out = model.forward_b2b(x_cocktail, modality_in_out='cocktail', modality_intermediate='music')
            x_hat_cocktail, x_intermediate_music, mu_cocktail, log_var_cocktail, z_cocktail, mu_music, log_var_music, z_music = out
            brb_mse_loss_cocktail = ((x_cocktail - x_hat_cocktail) ** 2).mean()
            brb_mse_latent_loss_1 = ((z_music - z_cocktail) ** 2).mean()
            brb_kld_loss_cocktail_1 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
            brb_kld_loss_music_1 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
            # brb_cocktail_in_loss = mse_loss_cocktail + mse_latents_1 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music)

            out = model.forward_b2b(x_music, modality_in_out='music', modality_intermediate='cocktail')
            x_hat_music, x_intermediate_cocktail, mu_music, log_var_music, z_music, mu_cocktail, log_var_cocktail, z_cocktail = out
            brb_mse_loss_music = ((x_music - x_hat_music) ** 2).mean()
            brb_mse_latent_loss_2 = ((z_music - z_cocktail) ** 2).mean()
            brb_kld_loss_cocktail_2 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
            brb_kld_loss_music_2 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
            # brb_music_in_loss = mse_loss_music + mse_latents_2 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music)

            brb_mse_latent_loss = (brb_mse_latent_loss_1 + brb_mse_latent_loss_2) / 2
            brb_kld_loss_music = (brb_kld_loss_music_1 + brb_kld_loss_music_2) / 2
            brb_kld_loss_cocktail = (brb_kld_loss_cocktail_1 + brb_kld_loss_cocktail_2) / 2
            brb_vae_loss = brb_mse_latent_loss + brb_mse_loss_cocktail + brb_mse_loss_music + \
                params['beta_vae'] * (brb_kld_loss_music + brb_kld_loss_cocktail)
            brb_music_vae_loss = brb_mse_loss_music + params['beta_vae'] * brb_kld_loss_music + brb_mse_latent_loss

        # swd
        if params['beta_swd'] > 0:
            swd_loss = compute_swd_loss(z_music, z_cocktail, params['latent_dim'])
        else:
            swd_loss = 0

        # classif losses
        if params['beta_classif'] > 0:
            pred_music = model.classify(x_music_lab, modality_in='music')
            classif_loss_music = LOSS(pred_music, labels_music)
            accuracy_music = torch.mean((torch.argmax(pred_music, dim=1) == labels_music).float())
            cf_matrices_music.append(get_cf_matrix(pred_music, labels_music))
            pred_cocktail = model.classify(x_cocktail_lab, modality_in='cocktail')
            classif_loss_cocktail = LOSS(pred_cocktail, labels_cocktail)
            accuracy_cocktail = torch.mean((torch.argmax(pred_cocktail, dim=1) == labels_cocktail).float())
            cf_matrices_cocktail.append(get_cf_matrix(pred_cocktail, labels_cocktail))
        else:
            classif_loss_cocktail, classif_loss_music = 0, 0
            accuracy_music, accuracy_cocktail = 0, 0
            cf_matrices_cocktail.append(np.zeros((2, 2)))
            cf_matrices_music.append(np.zeros((2, 2)))

        # grounding regularization: map music through the frozen cocktail decoder
        if params['beta_reg_grounding'] > 0:
            x_hat_cocktail, _, _, _ = model(x_reg_music, modality_in='music', modality_out='cocktail', freeze_decoder=True)
            mse_reg_grounding = ((x_reg_cocktail - x_hat_cocktail) ** 2).mean()
        else:
            mse_reg_grounding = 0
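        # Global objective assembled below; in the default (use_brb_vae=False) case:
        #   L = beta_vae_loss * [ mse_c + beta_vae * kld_c + beta_music * (mse_m + beta_vae * kld_m) ]
        #       + beta_swd * SWD(z_music, z_cocktail)
        #       + beta_classif * (CE_cocktail + CE_music)
        #       + beta_reg_grounding * mse_grounding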
        if params['use_brb_vae']:
            global_minus_classif = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss
            global_loss = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss + \
                params['beta_classif'] * (classif_loss_cocktail + params['beta_music_classif'] * classif_loss_music)
        else:
            global_minus_classif = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss
            global_loss = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss + \
                params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \
                params['beta_reg_grounding'] * mse_reg_grounding
        # global_loss = params['beta_vae_loss'] * cocktail_vae_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \
        #     params['beta_reg_grounding'] * mse_reg_grounding

        losses['brb_vae_loss'].append(float(brb_vae_loss))
        losses['brb_mse_latent_loss'].append(float(brb_mse_latent_loss))
        losses['brb_kld_loss_cocktail'].append(float(brb_kld_loss_cocktail))
        losses['brb_kld_loss_music'].append(float(brb_kld_loss_music))
        losses['brb_mse_loss_music'].append(float(brb_mse_loss_music))
        losses['brb_mse_loss_cocktail'].append(float(brb_mse_loss_cocktail))
        losses['swd_losses'].append(float(swd_loss))
        losses['vae_losses'].append(float(vae_loss))
        losses['kld_losses_music'].append(float(kld_loss_music))
        losses['kld_losses_cocktail'].append(float(kld_loss_cocktail))
        losses['mse_losses_music'].append(float(mse_loss_music))
        losses['mse_losses_cocktail'].append(float(mse_loss_cocktail))
        losses['global_losses'].append(float(global_loss))
        losses['classif_losses_music'].append(float(classif_loss_music))
        losses['classif_losses_cocktail'].append(float(classif_loss_cocktail))
        losses['classif_acc_cocktail'].append(float(accuracy_cocktail))
        losses['classif_acc_music'].append(float(accuracy_music))
        losses['beta_reg_grounding'].append(float(mse_reg_grounding))
        losses['bubble_mse'].append(bubble_mse)
        losses['egg_mse'].append(egg_mse)

        if train:
            # staged training schedule, currently disabled:
            # if epoch < params['n_epochs_music_pretrain']:
            #     music_vae_loss.backward()
            # elif params['n_epochs_music_pretrain'] <= epoch < (params['n_epochs_music_pretrain'] + params['n_epochs_train']):
            #     global_minus_classif.backward()
            # elif epoch >= (params['n_epochs_music_pretrain'] + params['n_epochs_train']):
            global_loss.backward()
            opt.step()

        if params['log_every'] != 0:
            if step != 0 and step % params['log_every'] == 0:
                print(f'\tBatch #{i_batch}')
                for k in params['keys_to_print']:
                    if k != 'steps':
                        print(f'\t {k}: Train: {np.nanmean(losses[k][-params["log_every"]:]):.3f}')
    return losses, [np.mean(cf_matrices_music, axis=0), np.mean(cf_matrices_cocktail, axis=0)]


def get_cf_matrix(pred, labels):
    # row-normalized confusion matrix: rows are true classes, columns are predicted classes
    bs, dim = pred.shape
    labels = labels.detach().numpy()
    pred_labels = np.argmax(pred.detach().numpy(), axis=1)
    confusion_matrix = np.zeros((dim, dim))
    for i in range(bs):
        confusion_matrix[labels[i], pred_labels[i]] += 1
    for i in range(dim):
        if np.sum(confusion_matrix[i]) != 0:
            confusion_matrix[i] /= np.sum(confusion_matrix[i])
    return confusion_matrix
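
# Worked example for get_cf_matrix (illustrative, not executed): with
#   pred = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.3, 0.9]])  # argmax -> [0, 1, 1]
#   labels = torch.tensor([0, 1, 0])
# the raw counts are [[1, 1], [0, 1]], and after row-normalization
# get_cf_matrix(pred, labels) returns [[0.5, 0.5], [0.0, 1.0]]:
# row i gives the distribution of predictions for true class i.
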
def train(model, dataloaders, params):
    keys_to_track = params['keys_to_track']
    opt = torch.optim.AdamW(list(model.parameters()), lr=params['lr'])
    if params['decay_step'] > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=params['decay_step'], gamma=0.5)
    all_train_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
    all_eval_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
    best_eval_loss = np.inf
    data_train = dict()
    data_test = dict()
    for k in dataloaders.keys():
        if '_train' in k:
            data_train[k[:-6]] = dataloaders[k]
        elif '_test' in k:
            data_test[k[:-5]] = dataloaders[k]
        else:
            raise ValueError
    # run first eval
    eval_losses, _ = run_epoch(0, model, data_test, params, opt, train=False)
    for k in params['keys_to_track']:
        if k == 'steps':
            all_train_losses[k].append(0)
            all_eval_losses[k].append(0)
        else:
            all_train_losses[k].append(np.nan)
            all_eval_losses[k].append(np.mean(eval_losses[k]))
    print('Initial evaluation')
    for k in params['keys_to_print']:
        to_print = all_eval_losses[k][-1]
        print(f' {k}: Eval: {to_print:.3f}')
    step = 0
    for epoch in range(params['epochs']):
        print(f'\n------------\nEpoch #{epoch}')
        # run training epoch
        train_losses, train_cf_matrices = run_epoch(epoch, model, data_train, params, opt, train=True)
        # run eval epoch
        eval_losses, eval_cf_matrices = run_epoch(epoch, model, data_test, params, opt, train=False)
        if epoch < params['n_epochs_music_pretrain']:
            epoch_size = params['pretrain_train_epoch_size']
        else:
            epoch_size = params['train_epoch_size']
        step += epoch_size
        for k in params['keys_to_track']:
            if k == 'steps':
                all_train_losses[k].append(epoch)
                all_eval_losses[k].append(epoch)
            else:
                all_train_losses[k].append(np.nanmean(train_losses[k]))
                all_eval_losses[k].append(np.nanmean(eval_losses[k]))
        if params['decay_step'] > 0:
            scheduler.step()
        # logging
        print(f'----\n\tEval epoch #{epoch}')
        for k in params['keys_to_print']:
            to_print_eval = all_eval_losses[k][-1]
            to_print_train = all_train_losses[k][-1]
            print(f'\t {k}: Eval: {to_print_eval:.3f} / Train: {to_print_train:.3f}')
        if epoch % params['plot_every'] == 0:
            plot_all_losses(all_train_losses.copy(), all_eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params)
        # saving models
        save_losses(all_train_losses, all_eval_losses, params['save_path'] + 'results.txt')
        if params['save_every'] != 0:
            if epoch % params['save_every'] == 0:
                print('Saving model.')
                save_model(model, path=params['save_path'], name=f'epoch_{epoch}')
        if all_eval_losses['global_losses'][-1] < best_eval_loss:
            best_eval_loss = all_eval_losses['global_losses'][-1]
            print(f'New best eval loss: {best_eval_loss:.3f}, saving model.')
            save_model(model, path=params['save_path'], name='best_eval')
    print('Saving last model.')
    save_model(model, path=params['save_path'], name='last')
    return model, all_train_losses, all_eval_losses, train_cf_matrices, eval_cf_matrices


def save_losses(train_losses, eval_losses, path):
    # losses are stored as plain floats, so each key maps to one row of the output file
    results = []
    keys = sorted(train_losses.keys())
    for k in keys:
        results.append(train_losses[k])
    for k in keys:
        results.append(eval_losses[k])
    np.savetxt(path, np.array(results))


def save_model(model, path, name):
    torch.save(model.state_dict(), path + f'checkpoints_{name}.save')
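
# Reloading a checkpoint written by save_model (a minimal sketch; the model must
# first be rebuilt with the same architecture arguments used at training time):
#   model = get_gml_vae_models(...)  # same layer_type, input dims, latent_dim, etc.
#   model.load_state_dict(torch.load(save_path + 'checkpoints_best_eval.save'))
#   model.eval()
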
def run_training(params):
    params = compute_expe_name_and_save_path(params)
    dataloaders, n_labels, stats = get_dataloaders(cocktail_rep_path=params['cocktail_rep_path'],
                                                   music_rep_path=params['music_rep_path'],
                                                   batch_size=params['pretrain_batch_size'],
                                                   train_epoch_size=params['pretrain_train_epoch_size'],
                                                   test_epoch_size=params['pretrain_test_epoch_size'])
    params['nb_classes'] = n_labels
    params['stats'] = stats
    params['classif_classes'] = dataloaders['music_labeled_train'].dataset.classes
    vae_gml_model = get_gml_vae_models(layer_type=params['layer_type'],
                                       input_dim_music=dataloaders['music_train'].dataset.dim_music,
                                       input_dim_cocktail=dataloaders['cocktail_train'].dataset.dim_cocktail,
                                       hidden_dim=params['hidden_dim'],
                                       n_hidden=params['n_hidden'],
                                       latent_dim=params['latent_dim'],
                                       nb_classes=params['nb_classes'],
                                       dropout=params['dropout'])
    params['dim_music'] = dataloaders['music_train'].dataset.dim_music
    params['dim_cocktail'] = dataloaders['cocktail_train'].dataset.dim_cocktail
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)
    models, train_losses, eval_losses, train_cf_matrices, eval_cf_matrices = train(vae_gml_model, dataloaders, params)
    plot_all_losses(train_losses.copy(), eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params)
    return models, train_losses, eval_losses


def plot_all_losses(train_losses, eval_losses, train_cf_matrices, eval_cf_matrices, params):
    plot_losses(train_losses, train_cf_matrices, 'train', params)
    plot_losses(eval_losses, eval_cf_matrices, 'eval', params)


def plot_losses(losses, cf_matrices, split, params):
    save_path = params['save_path'] + 'plots/'
    os.makedirs(save_path, exist_ok=True)
    steps = losses['steps']
    for k in losses.keys():
        losses[k] = np.array(losses[k])
    losses['sum_loss_classif'] = losses['classif_losses_music'] + losses['classif_losses_cocktail']
    losses['av_acc_classif'] = (losses['classif_acc_cocktail'] + losses['classif_acc_music']) / 2
    losses['sum_mse_vae'] = losses['mse_losses_cocktail'] + losses['mse_losses_music']
    losses['sum_kld_vae'] = losses['kld_losses_cocktail'] + losses['kld_losses_music']

    plt.figure()
    for k in ['global_losses', 'vae_losses', 'swd_losses', 'sum_mse_vae', 'sum_kld_vae']:
        factor = 10 if k == 'swd_losses' else 1  # scale up the swd loss for readability
        plt.plot(steps, losses[k] * factor, label=k)
    plt.title(split)
    plt.legend()
    plt.ylim([0, 2.5])
    plt.savefig(save_path + f'plot_high_level_losses_{split}.png')
    plt.close(plt.gcf())

    plt.figure()
    for k in ['classif_acc_cocktail', 'classif_acc_music']:
        plt.plot(steps, losses[k], label=k)
    plt.title(split)
    plt.ylim([0, 1])
    plt.legend()
    plt.savefig(save_path + f'plot_classif_accuracies_{split}.png')
    plt.close(plt.gcf())

    plt.figure()
    for k in ['mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail', 'kld_losses_music', 'swd_losses',
              'classif_losses_cocktail', 'classif_losses_music', 'beta_reg_grounding', 'bubble_mse', 'egg_mse']:
        factor = 10 if k == 'swd_losses' else 1
        plt.plot(steps, losses[k] * factor, label=k)
    plt.title(split)
    plt.ylim([0, 2.5])
    plt.legend()
    plt.savefig(save_path + f'plot_detailed_losses_{split}.png')
    plt.close(plt.gcf())

    for i_k, k in enumerate(['music', 'cocktail']):
        plt.figure()
        plt.imshow(cf_matrices[i_k], vmin=0, vmax=1)
        plt.xticks(range(len(params['classif_classes'])), params['classif_classes'], rotation=45)
        plt.yticks(range(len(params['classif_classes'])), params['classif_classes'])
        plt.xlabel('predicted')
        plt.ylabel('true')
        plt.title(split + ' ' + k)
        plt.colorbar()
        # bbox_inches='tight' keeps the rotated tick labels inside the saved figure
        plt.savefig(save_path + f'cf_matrix_{split}_{k}.png', bbox_inches='tight')
        plt.close(plt.gcf())

    if params['use_brb_vae']:
        plt.figure()
        for k in ['brb_vae_loss', 'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music',
                  'brb_mse_loss_cocktail', 'mse_losses_music', 'brb_mse_latent_loss']:
            plt.plot(steps, losses[k], label=k)
        plt.title(split)
        plt.ylim([0, 2.5])
        plt.legend()
        plt.savefig(save_path + f'plot_detailed_brb_losses_{split}.png')
        plt.close(plt.gcf())
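
# compute_swd_loss is imported from utils. For reference, a minimal sketch of the
# standard sliced Wasserstein distance estimator such a function typically
# implements (an assumption for illustration only -- the actual utils
# implementation may differ in projection count, norm, or reduction):
def _swd_reference_sketch(z_a, z_b, latent_dim, n_proj=50):
    # random unit-norm projection directions (assumes z_a and z_b have equal batch sizes)
    theta = torch.randn(n_proj, latent_dim)
    theta = theta / theta.norm(dim=1, keepdim=True)
    # project both samples to 1D and sort: the 1D Wasserstein-2 distance between
    # equal-size samples is the mean squared difference of sorted projections
    proj_a = torch.sort(z_a @ theta.T, dim=0).values
    proj_b = torch.sort(z_b @ theta.T, dim=0).values
    return ((proj_a - proj_b) ** 2).mean()
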
def parse_args():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--save_path", type=str,
                        default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/latent_translation/")
    parser.add_argument("--trial_id", type=str, default="b256_r128_classif001_ld40_meanstd")
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--n_hidden", type=int, default=1)
    parser.add_argument("--latent_dim", type=int, default=40)
    parser.add_argument("--n_epochs_music_pretrain", type=int, default=0)
    parser.add_argument("--n_epochs_train", type=int, default=200)
    parser.add_argument("--n_epochs_classif_finetune", type=int, default=0)
    parser.add_argument("--beta_vae_loss", type=float, default=1.)
    parser.add_argument("--beta_vae", type=float, default=1.2)  # keep this low (~1) to allow music classification
    parser.add_argument("--beta_swd", type=float, default=1)
    parser.add_argument("--beta_reg_grounding", type=float, default=2.5)
    parser.add_argument("--beta_classif", type=float, default=0.01)  # TODO: try 0.1, default 0.01
    parser.add_argument("--beta_music", type=float, default=100)  # weight the music losses higher, they need more to converge
    parser.add_argument("--beta_music_classif", type=float, default=300)  # try 300: weight the music classification loss higher
    parser.add_argument("--pretrain_batch_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--decay_step", type=int, default=0)
    parser.add_argument("--cocktail_rep_path", type=str, default=FULL_COCKTAIL_REP_PATH)
    parser.add_argument("--music_rep_path", type=str, default=music_rep_path)
    parser.add_argument("--use_brb_vae", action="store_true")  # type=bool would parse any non-empty string as True
    parser.add_argument("--layer_type", type=str, default='gml')
    parser.add_argument("--dropout", type=float, default=0.2)
    # best parameters: identical to the defaults above, except beta_vae=1 (instead of 1.2)
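    # Example invocation (illustrative; the script name is hypothetical and the
    # default paths are machine-specific):
    #   python latent_translation.py --batch_size 32 --latent_dim 40 --beta_classif 0.01 --use_brb_vae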
    args = parser.parse_args()
    return args


def compute_expe_name_and_save_path(params):
    save_path = params['save_path'] + params["trial_id"]
    if params["use_brb_vae"]:
        save_path += '_usebrb'
    save_path += f'_lr{params["lr"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_bmusic{params["beta_music"]}'
    save_path += f'_bswd{params["beta_swd"]}'
    save_path += f'_bclassif{params["beta_classif"]}'
    save_path += f'_bvae{params["beta_vae_loss"]}'
    save_path += f'_bvaekld{params["beta_vae"]}'
    save_path += f'_lat{params["latent_dim"]}'
    save_path += f'_hd{params["n_hidden"]}x{params["hidden_dim"]}'
    save_path += f'_drop{params["dropout"]}'
    save_path += f'_decay{params["decay_step"]}'
    save_path += f'_layertype{params["layer_type"]}'
    # if the path exists, append a counter suffix and bump it until the path is free
    number_added = False
    counter = 1
    while os.path.exists(save_path):
        if number_added:
            save_path = '_'.join(save_path.split('_')[:-1]) + f'_{counter}'
        else:
            save_path += f'_{counter}'
            number_added = True
        counter += 1
    params["save_path"] = save_path + '/'
    os.makedirs(save_path)
    print(f'logging to {save_path}')
    return params


if __name__ == '__main__':
    keys_to_track = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music',
                     'kld_losses_cocktail', 'kld_losses_music', 'swd_losses', 'classif_losses_cocktail',
                     'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music', 'brb_kld_loss_cocktail',
                     'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'brb_mse_latent_loss',
                     'brb_vae_loss', 'beta_reg_grounding', 'bubble_mse', 'egg_mse']
    keys_to_print = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music',
                     'kld_losses_cocktail', 'kld_losses_music', 'swd_losses', 'classif_losses_cocktail',
                     'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music', 'beta_reg_grounding']
    # TODO: first phase: vae pretraining for music,
    # then in a second phase: cocktail and music vaes, brb vaes
    args = parse_args()
    params = dict(nb_classes=None,
                  save_every=0,  # in epochs; 0 disables periodic checkpoints
                  log_every=0,  # in samples, e.g. 32 * 500; 0 disables batch logging
                  plot_every=10,  # in epochs
                  keys_to_track=keys_to_track,
                  keys_to_print=keys_to_print)
    params.update(vars(args))
    params['train_epoch_size'] = params['batch_size'] * 100
    params['test_epoch_size'] = params['batch_size'] * 10
    params['pretrain_train_epoch_size'] = params['pretrain_batch_size'] * 100
    params['pretrain_test_epoch_size'] = params['pretrain_batch_size'] * 10
    params['epochs'] = params['n_epochs_music_pretrain'] + params['n_epochs_train'] + params['n_epochs_classif_finetune']
    run_training(params)
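
# The experiment directory created by compute_expe_name_and_save_path ends up
# containing: params.json, results.txt (loss histories via np.savetxt),
# plots/*.png, and checkpoints_{best_eval,last}.save (plus
# checkpoints_epoch_*.save when save_every > 0).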