import os
import json
import argparse
import torch; torch.manual_seed(0)
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from vae_model import get_gml_vae_models
from utils import get_dataloaders, compute_swd_loss
from src.music.config import MUSIC_REP_PATH
from src.cocktails.config import FULL_COCKTAIL_REP_PATH
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
print('Using GPUs')
else:
print('Using CPUs')
music_rep_path = "/home/cedric/Documents/pianocktail/data/music/represented_small/"
music_rep_path = MUSIC_REP_PATH + "music_reps_normalized_meanstd.pickle"
# music_rep_path = "/home/cedric/Documents/pianocktail/data/music/32_represented/reps.pickle"
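# cross-entropy criterion shared by the music and cocktail classification heads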
LOSS = nn.CrossEntropyLoss()
def run_epoch(epoch, model, data, params, opt, train):
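    """Run one epoch over the zipped modality dataloaders.

    Computes the per-modality VAE losses, the optional back-to-back VAE losses,
    the sliced Wasserstein latent-alignment loss, the classification losses and
    the music-to-cocktail grounding loss; backpropagates their weighted sum when
    train is True. Returns the per-batch loss logs and the averaged confusion
    matrices of the music and cocktail classifiers.
    """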
if epoch == params['n_epochs_music_pretrain']:
print(f'Switching to bs: {params["batch_size"]}')
for k in data.keys():
prefix = 'train' if train else 'test'
data[k].batch_sampler.update_epoch_size_and_batch(params[prefix + '_epoch_size'], params['batch_size'])
if train:
model.train()
else:
model.eval()
keys_to_track = params['keys_to_track']
losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
step = 0
cf_matrices_music = []
cf_matrices_cocktail = []
for i_batch, data_music, data_cocktail, data_music_lab, data_cocktail_lab, data_reg_grounding \
in zip(range(len(data['music'])), data['music'], data['cocktail'], data['music_labeled'], data['cocktail_labeled'], data['reg_grounding']):
x_music, _ = data_music
x_cocktail, _, contains_egg, contains_bubbles = data_cocktail
x_music_lab, labels_music = data_music_lab
x_cocktail_lab, labels_cocktail = data_cocktail_lab
x_reg_music, x_reg_cocktail = data_reg_grounding
step += x_music.shape[0]
if train: opt.zero_grad()
        # up-weight examples containing bubbles (+1) or egg (+3, on a base weight of 1) in the cocktail reconstruction MSE
bubbles_egg_weights = torch.ones([contains_bubbles.shape[0]])
bubbles_egg_weights[contains_bubbles] += 1
bubbles_egg_weights[contains_egg] += 3
# vae
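        # reconstruction (MSE) + KL to the unit Gaussian prior, using the closed form:
        # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)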
x_hat_cocktail, z_cocktail, mu_cocktail, log_var_cocktail = model(x_cocktail, modality_in='cocktail', modality_out='cocktail')
mse_loss_cocktail = torch.sum(((x_cocktail - x_hat_cocktail)**2).mean(axis=1) * bubbles_egg_weights) / bubbles_egg_weights.sum()
if contains_bubbles.sum() > 0:
bubble_mse = float(((x_cocktail - x_hat_cocktail)**2)[contains_bubbles, -3].mean())
else:
bubble_mse = np.nan
if contains_egg.sum() > 0:
egg_mse = float(((x_cocktail - x_hat_cocktail)**2)[contains_egg, -1].mean())
else:
egg_mse = np.nan
kld_loss_cocktail = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
x_hat_music, z_music, mu_music, log_var_music = model(x_music, modality_in='music', modality_out='music')
mse_loss_music = ((x_music - x_hat_music)**2).mean()
kld_loss_music = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
music_vae_loss = mse_loss_music + params['beta_vae'] * kld_loss_music
cocktail_vae_loss = mse_loss_cocktail + params['beta_vae'] * kld_loss_cocktail
vae_loss = cocktail_vae_loss + params['beta_music'] * music_vae_loss
brb_kld_loss_cocktail, brb_kld_loss_music, brb_mse_loss_music, brb_mse_loss_cocktail, brb_mse_latent_loss, brb_music_vae_loss, brb_vae_loss = [0] * 7
if params['use_brb_vae']:
# vae back to back
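            # encode one modality, decode into the other, re-encode, and decode back to the
            # input modality; an extra MSE between the two latent codes keeps the round trip consistent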
out = model.forward_b2b(x_cocktail, modality_in_out='cocktail', modality_intermediate='music')
x_hat_cocktail, x_intermediate_music, mu_cocktail, log_var_cocktail, z_cocktail, mu_music, log_var_music, z_music = out
brb_mse_loss_cocktail = ((x_cocktail - x_hat_cocktail) ** 2).mean()
brb_mse_latent_loss_1 = ((z_music - z_cocktail) ** 2).mean()
brb_kld_loss_cocktail_1 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
brb_kld_loss_music_1 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
# brb_cocktail_in_loss = mse_loss_cocktail + mse_latents_1 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music)
out = model.forward_b2b(x_music, modality_in_out='music', modality_intermediate='cocktail')
x_hat_music, x_intermediate_cocktail, mu_music, log_var_music, z_music, mu_cocktail, log_var_cocktail, z_cocktail = out
brb_mse_loss_music = ((x_music - x_hat_music) ** 2).mean()
brb_mse_latent_loss_2 = ((z_music - z_cocktail) ** 2).mean()
brb_kld_loss_cocktail_2 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
brb_kld_loss_music_2 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
# brb_music_in_loss = mse_loss_music + mse_latents_2 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music)
brb_mse_latent_loss = (brb_mse_latent_loss_1 + brb_mse_latent_loss_2) / 2
brb_kld_loss_music = (brb_kld_loss_music_1 + brb_kld_loss_music_2) / 2
brb_kld_loss_cocktail = (brb_kld_loss_cocktail_1 + brb_kld_loss_cocktail_2) / 2
brb_vae_loss = brb_mse_latent_loss + brb_mse_loss_cocktail + brb_mse_loss_music + params['beta_vae'] * (brb_kld_loss_music + brb_kld_loss_cocktail)
brb_music_vae_loss = brb_mse_loss_music + params['beta_vae'] * brb_kld_loss_music + brb_mse_latent_loss
        # sliced Wasserstein distance between the batches of music and cocktail latent codes, to align the two latent distributions
if params['beta_swd'] > 0:
swd_loss = compute_swd_loss(z_music, z_cocktail, params['latent_dim'])
else:
swd_loss = 0
# classif losses
if params['beta_classif'] > 0:
pred_music = model.classify(x_music_lab, modality_in='music')
classif_loss_music = LOSS(pred_music, labels_music)
accuracy_music = torch.mean((torch.argmax(pred_music, dim=1) == labels_music).float())
cf_matrices_music.append(get_cf_matrix(pred_music, labels_music))
pred_cocktail = model.classify(x_cocktail_lab, modality_in='cocktail')
classif_loss_cocktail = LOSS(pred_cocktail, labels_cocktail)
accuracy_cocktail = torch.mean((torch.argmax(pred_cocktail, dim=1) == labels_cocktail).float())
cf_matrices_cocktail.append(get_cf_matrix(pred_cocktail, labels_cocktail))
else:
classif_loss_cocktail, classif_loss_music = 0, 0
accuracy_music, accuracy_cocktail = 0, 0
cf_matrices_cocktail.append(np.zeros((2, 2)))
cf_matrices_music.append(np.zeros((2, 2)))
if params['beta_reg_grounding'] > 0:
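            # ground the music-to-cocktail translation on paired examples; the cocktail
            # decoder is frozen, so the gradient only shapes the music encoder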
x_hat_cocktail, _, _, _ = model(x_reg_music, modality_in='music', modality_out='cocktail', freeze_decoder=True)
mse_reg_grounding = ((x_reg_cocktail - x_hat_cocktail) ** 2).mean()
else:
mse_reg_grounding = 0
if params['use_brb_vae']:
global_minus_classif = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss
global_loss = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss + \
params['beta_classif'] * (classif_loss_cocktail + params['beta_music_classif'] * classif_loss_music)
else:
global_minus_classif = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss
global_loss = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \
params['beta_reg_grounding'] * mse_reg_grounding
# global_loss = params['beta_vae_loss'] * cocktail_vae_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \
# params['beta_reg_grounding'] * mse_reg_grounding
losses['brb_vae_loss'].append(float(brb_vae_loss))
losses['brb_mse_latent_loss'].append(float(brb_mse_latent_loss))
losses['brb_kld_loss_cocktail'].append(float(brb_kld_loss_cocktail))
losses['brb_kld_loss_music'].append(float(brb_kld_loss_music))
losses['brb_mse_loss_music'].append(float(brb_mse_loss_music))
losses['brb_mse_loss_cocktail'].append(float(brb_mse_loss_cocktail))
losses['swd_losses'].append(float(swd_loss))
losses['vae_losses'].append(float(vae_loss))
losses['kld_losses_music'].append(float(kld_loss_music))
losses['kld_losses_cocktail'].append(float(kld_loss_cocktail))
losses['mse_losses_music'].append(float(mse_loss_music))
losses['mse_losses_cocktail'].append(float(mse_loss_cocktail))
losses['global_losses'].append(float(global_loss))
losses['classif_losses_music'].append(float(classif_loss_music))
losses['classif_losses_cocktail'].append(float(classif_loss_cocktail))
losses['classif_acc_cocktail'].append(float(accuracy_cocktail))
losses['classif_acc_music'].append(float(accuracy_music))
losses['beta_reg_grounding'].append(float(mse_reg_grounding))
losses['bubble_mse'].append(bubble_mse)
losses['egg_mse'].append(egg_mse)
if train:
# if epoch < params['n_epochs_music_pretrain']:
# music_vae_loss.backward()
# elif epoch >= params['n_epochs_music_pretrain'] and epoch < (params['n_epochs_music_pretrain'] + params['n_epochs_train']):
# global_minus_classif.backward()
# elif epoch >= (params['n_epochs_music_pretrain'] + params['n_epochs_train']):
global_loss.backward()
opt.step()
if params['log_every'] != 0:
if step != 0 and step % params['log_every'] == 0:
print(f'\tBatch #{i_batch}')
for k in params['keys_to_print']:
if k != 'steps':
print(f'\t {k}: Train: {np.nanmean(losses[k][-params["log_every"]:]):.3f}')
return losses, [np.mean(cf_matrices_music, axis=0), np.mean(cf_matrices_cocktail, axis=0)]
def get_cf_matrix(pred, labels):
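    """Return the row-normalized confusion matrix (rows: true labels, columns: predicted labels)."""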
    bs, dim = pred.shape
    labels = labels.detach().cpu().numpy()
    pred_labels = np.argmax(pred.detach().cpu().numpy(), axis=1)
confusion_matrix = np.zeros((dim, dim))
for i in range(bs):
confusion_matrix[labels[i], pred_labels[i]] += 1
for i in range(dim):
if np.sum(confusion_matrix[i]) != 0:
confusion_matrix[i] /= np.sum(confusion_matrix[i])
return confusion_matrix
def train(model, dataloaders, params):
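    """Train for params['epochs'] epochs, evaluating after every epoch, plotting
    periodically, and checkpointing both the best-eval and the last model."""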
keys_to_track = params['keys_to_track']
opt = torch.optim.AdamW(list(model.parameters()), lr=params['lr'])
if params['decay_step'] > 0: scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=params['decay_step'], gamma=0.5)
all_train_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
all_eval_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
best_eval_loss = np.inf
data_train = dict()
data_test = dict()
for k in dataloaders.keys():
if '_train' in k:
data_train[k[:-6]] = dataloaders[k]
elif '_test' in k:
data_test[k[:-5]] = dataloaders[k]
else:
raise ValueError
# run first eval
eval_losses, _ = run_epoch(0, model, data_test, params, opt, train=False)
for k in params['keys_to_track']:
if k == 'steps':
all_train_losses[k].append(0)
all_eval_losses[k].append(0)
else:
            all_train_losses[k].append(np.nan)
            all_eval_losses[k].append(np.nanmean(eval_losses[k]))  # nanmean: bubble/egg MSEs are NaN for batches without such examples
    print('Initial evaluation')
    for k in params['keys_to_print']:
        to_print = all_eval_losses[k][-1]
        print(f' {k}: Eval: {to_print:.3f}')
step = 0
for epoch in range(params['epochs']):
print(f'\n------------\nEpoch #{epoch}')
# run training epoch
train_losses, train_cf_matrices = run_epoch(epoch, model, data_train, params, opt, train=True)
# run eval epoch
eval_losses, eval_cf_matrices = run_epoch(epoch, model, data_test, params, opt, train=False)
if epoch < params['n_epochs_music_pretrain']:
epoch_size = params['pretrain_train_epoch_size']
else:
epoch_size = params['train_epoch_size']
step += epoch_size
for k in params['keys_to_track']:
if k == 'steps':
all_train_losses[k].append(epoch)
all_eval_losses[k].append(epoch)
else:
all_train_losses[k].append(np.nanmean(train_losses[k]))
all_eval_losses[k].append(np.nanmean(eval_losses[k]))
        if params['decay_step'] > 0: scheduler.step()
# logging
print(f'----\n\tEval epoch #{epoch}')
for k in params['keys_to_print']:
            to_print_eval = all_eval_losses[k][-1]
            to_print_train = all_train_losses[k][-1]
            print(f'\t {k}: Eval: {to_print_eval:.3f} / Train: {to_print_train:.3f}')
if epoch % params['plot_every'] == 0:
plot_all_losses(all_train_losses.copy(), all_eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params)
# saving models
save_losses(all_train_losses, all_eval_losses, params['save_path'] + 'results.txt')
if params['save_every'] != 0:
if epoch % params['save_every'] == 0:
print('Saving model.')
save_model(model, path=params['save_path'], name=f'epoch_{epoch}')
if all_eval_losses['global_losses'][-1] < best_eval_loss:
best_eval_loss = all_eval_losses['global_losses'][-1]
print(f'New best eval loss: {best_eval_loss:.3f}, saving model.')
save_model(model, path=params['save_path'], name='best_eval')
print('Saving last model.')
    save_model(model, path=params['save_path'], name='last')
return model, all_train_losses, all_eval_losses, train_cf_matrices, eval_cf_matrices
def save_losses(train_losses, eval_losses, path):
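    """Dump the loss histories to a text file: one row per tracked key, train rows
    first, then eval rows, with keys in sorted order."""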
    results = []
    keys = sorted(train_losses.keys())
    for k in keys:
        results.append(train_losses[k])
    for k in keys:
        results.append(eval_losses[k])
np.savetxt(path, np.array(results))
def save_model(model, path, name):
torch.save(model.state_dict(), path + f'checkpoints_{name}.save')
def run_training(params):
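    """Build the dataloaders and VAE models from params, save the config, train,
    plot the final losses and return the trained models with their loss histories."""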
params = compute_expe_name_and_save_path(params)
dataloaders, n_labels, stats = get_dataloaders(cocktail_rep_path=params['cocktail_rep_path'],
music_rep_path=params['music_rep_path'],
batch_size=params['pretrain_batch_size'],
train_epoch_size=params['pretrain_train_epoch_size'],
test_epoch_size=params['pretrain_test_epoch_size'])
params['nb_classes'] = n_labels
params['stats'] = stats
params['classif_classes'] = dataloaders['music_labeled_train'].dataset.classes
vae_gml_model = get_gml_vae_models(layer_type=params['layer_type'],
input_dim_music=dataloaders['music_train'].dataset.dim_music,
input_dim_cocktail=dataloaders['cocktail_train'].dataset.dim_cocktail,
hidden_dim=params['hidden_dim'],
n_hidden=params['n_hidden'],
latent_dim=params['latent_dim'],
nb_classes=params['nb_classes'],
dropout=params['dropout'])
params['dim_music'] = dataloaders['music_train'].dataset.dim_music
params['dim_cocktail'] = dataloaders['cocktail_train'].dataset.dim_cocktail
with open(params['save_path'] + 'params.json', 'w') as f:
json.dump(params, f)
models, train_losses, eval_losses, train_cf_matrices, eval_cf_matrices = train(vae_gml_model, dataloaders, params)
plot_all_losses(train_losses.copy(), eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params)
return models, train_losses, eval_losses
def plot_all_losses(train_losses, eval_losses, train_cf_matrices, eval_cf_matrices, params):
plot_losses(train_losses, train_cf_matrices, 'train', params)
plot_losses(eval_losses, eval_cf_matrices, 'eval', params)
def plot_losses(losses, cf_matrices, split, params):
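    """Plot high-level, classification-accuracy and detailed loss curves, the classifier
    confusion matrices, and (if enabled) the back-to-back VAE losses for one split."""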
save_path = params['save_path'] + 'plots/'
os.makedirs(save_path, exist_ok=True)
steps = losses['steps']
    for k in losses.keys():
        losses[k] = np.array(losses[k])
losses['sum_loss_classif'] = losses['classif_losses_music'] + losses['classif_losses_cocktail']
losses['av_acc_classif'] = (losses['classif_acc_cocktail'] + losses['classif_acc_music'])/2
losses['sum_mse_vae'] = losses['mse_losses_cocktail'] + losses['mse_losses_music']
losses['sum_kld_vae'] = losses['kld_losses_cocktail'] + losses['kld_losses_music']
plt.figure()
for k in ['global_losses', 'vae_losses', 'swd_losses', 'sum_mse_vae', 'sum_kld_vae']:
factor = 10 if k == 'swd_losses' else 1
plt.plot(steps, losses[k] * factor, label=k)
plt.title(split)
plt.legend()
plt.ylim([0, 2.5])
plt.savefig(save_path + f'plot_high_level_losses_{split}.png')
plt.close(plt.gcf())
plt.figure()
for k in ['classif_acc_cocktail', 'classif_acc_music']:
plt.plot(steps, losses[k], label=k)
plt.title(split)
plt.ylim([0, 1])
plt.legend()
plt.savefig(save_path + f'plot_classif_accuracies_{split}.png')
plt.close(plt.gcf())
plt.figure()
for k in ['mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail',
'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'beta_reg_grounding',
'bubble_mse', 'egg_mse']:
factor = 10 if k == 'swd_losses' else 1
plt.plot(steps, losses[k] * factor, label=k)
plt.title(split)
plt.ylim([0, 2.5])
plt.legend()
plt.savefig(save_path + f'plot_detailed_losses_{split}.png')
plt.close(plt.gcf())
for i_k, k in enumerate(['music', 'cocktail']):
plt.figure()
plt.imshow(cf_matrices[i_k], vmin=0, vmax=1)
        plt.xticks(range(len(params['classif_classes'])), params['classif_classes'], rotation=45)
        plt.yticks(range(len(params['classif_classes'])), params['classif_classes'])
        plt.xlabel('predicted')
        plt.ylabel('true')
        plt.title(split + ' ' + k)
        plt.colorbar()
        # savefig has no 'artists' keyword; bbox_inches='tight' keeps the rotated tick labels inside the saved figure
        plt.savefig(save_path + f'cf_matrix_{split}_{k}.png', bbox_inches='tight')
plt.close(plt.gcf())
if params['use_brb_vae']:
plt.figure()
for k in ['brb_vae_loss', 'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'mse_losses_music', 'brb_mse_latent_loss']:
            plt.plot(steps, losses[k], label=k)
plt.title(split)
plt.ylim([0, 2.5])
plt.legend()
plt.savefig(save_path + f'plot_detailed_brb_losses_{split}.png')
plt.close(plt.gcf())
def parse_args():
parser = argparse.ArgumentParser(description="")
parser.add_argument("--save_path", type=str, default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/latent_translation/")
parser.add_argument("--trial_id", type=str, default="b256_r128_classif001_ld40_meanstd")
parser.add_argument("--hidden_dim", type=int, default=256) #128
parser.add_argument("--n_hidden", type=int, default=1)
parser.add_argument("--latent_dim", type=int, default=40) #40
parser.add_argument("--n_epochs_music_pretrain", type=int, default=0)
parser.add_argument("--n_epochs_train", type=int, default=200)
parser.add_argument("--n_epochs_classif_finetune", type=int, default=0)
parser.add_argument("--beta_vae_loss", type=float, default=1.)
parser.add_argument("--beta_vae", type=float, default=1.2) # keep this low~1 to allow music classification...
parser.add_argument("--beta_swd", type=float, default=1)
parser.add_argument("--beta_reg_grounding", type=float, default=2.5)
parser.add_argument("--beta_classif", type=float, default=0.01)#0.01) #TODO: try 0.1, default 0.01
parser.add_argument("--beta_music", type=float, default=100) # higher loss on the music that needs more to converge
parser.add_argument("--beta_music_classif", type=float, default=300) # try300# higher loss on the music that needs more to converge
parser.add_argument("--pretrain_batch_size", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--decay_step", type=int, default=0)
parser.add_argument("--cocktail_rep_path", type=str, default=FULL_COCKTAIL_REP_PATH)
parser.add_argument("--music_rep_path", type=str, default=music_rep_path)
parser.add_argument("--use_brb_vae", type=bool, default=False)
parser.add_argument("--layer_type", type=str, default='gml')
parser.add_argument("--dropout", type=float, default=0.2)
    # best parameters found so far: identical to the defaults above, except beta_vae=1 (instead of 1.2)
args = parser.parse_args()
return args
def compute_expe_name_and_save_path(params):
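    """Build the experiment save directory, encoding the main hyperparameters in its
    name and appending a counter when the directory already exists."""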
save_path = params['save_path'] + params["trial_id"]
if params["use_brb_vae"]:
save_path += '_usebrb'
save_path += f'_lr{params["lr"]}'
save_path += f'_bs{params["batch_size"]}'
save_path += f'_bmusic{params["beta_music"]}'
save_path += f'_bswd{params["beta_swd"]}'
save_path += f'_bclassif{params["beta_classif"]}'
save_path += f'_bvae{params["beta_vae_loss"]}'
save_path += f'_bvaekld{params["beta_vae"]}'
save_path += f'_lat{params["latent_dim"]}'
save_path += f'_hd{params["n_hidden"]}x{params["hidden_dim"]}'
save_path += f'_drop{params["dropout"]}'
save_path += f'_decay{params["decay_step"]}'
save_path += f'_layertype{params["layer_type"]}'
    number_added = False
    counter = 1
    while os.path.exists(save_path):
        # append a counter to avoid overwriting a previous run, replacing it on later collisions
        if number_added:
            save_path = '_'.join(save_path.split('_')[:-1]) + f'_{counter}'
        else:
            save_path += f'_{counter}'
            number_added = True
        counter += 1
params["save_path"] = save_path + '/'
os.makedirs(save_path)
print(f'logging to {save_path}')
return params
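# Example invocation (illustrative; substitute this file's actual name and override the
# machine-specific default paths as needed):
#   python <this_script>.py --trial_id my_run --batch_size 32 --beta_swd 1 --beta_classif 0.01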
if __name__ == '__main__':
keys_to_track = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail',
'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music',
'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'brb_mse_latent_loss', 'brb_vae_loss', 'beta_reg_grounding',
'bubble_mse', 'egg_mse']
keys_to_print = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail',
'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music', 'beta_reg_grounding']
#TODO: first phase vae pretraining for music
# then in second phase: vae cocktail and music, brb vaes
args = parse_args()
params = dict(nb_classes=None,
                  save_every=0,  # in epochs; 0 disables periodic checkpointing
                  log_every=0,  # in steps; 0 disables intra-epoch logging
plot_every=10, # in epochs
keys_to_track=keys_to_track,
keys_to_print=keys_to_print,)
params.update(vars(args))
params['train_epoch_size'] = params['batch_size'] * 100
params['test_epoch_size'] = params['batch_size'] * 10
params['pretrain_train_epoch_size'] = params['pretrain_batch_size'] * 100
params['pretrain_test_epoch_size'] = params['pretrain_batch_size'] * 10
params['epochs'] = params['n_epochs_music_pretrain'] + params['n_epochs_train'] + params['n_epochs_classif_finetune']
run_training(params)