import os
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from vae_model import get_gml_vae_models
from utils import get_dataloaders, compute_swd_loss
from src.music.config import MUSIC_REP_PATH
from src.cocktails.config import FULL_COCKTAIL_REP_PATH
import json
import argparse

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    print('Using GPUs')
else:
    print('Using CPUs')

# music_rep_path = "/home/cedric/Documents/pianocktail/data/music/represented_small/"
music_rep_path = MUSIC_REP_PATH + "music_reps_normalized_meanstd.pickle"
# music_rep_path = "/home/cedric/Documents/pianocktail/data/music/32_represented/reps.pickle"
LOSS = nn.CrossEntropyLoss()

def run_epoch(epoch, model, data, params, opt, train):
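    """Run one epoch over the paired music / cocktail dataloaders.

    Computes the per-modality VAE losses, the optional back-to-back (brb) VAE losses,
    the SWD loss aligning the two latent spaces, the classification losses and the
    grounding regression loss, then backpropagates their weighted sum when train is
    True. Returns the per-batch loss logs and the averaged confusion matrices.
    """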
    if epoch == params['n_epochs_music_pretrain']:
        print(f'Switching to bs: {params["batch_size"]}')
        for k in data.keys():
            prefix = 'train' if train else 'test'
            data[k].batch_sampler.update_epoch_size_and_batch(params[prefix + '_epoch_size'], params['batch_size'])
    if train:
        model.train()
    else:
        model.eval()
    keys_to_track = params['keys_to_track']
    losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
    step = 0
    cf_matrices_music = []
    cf_matrices_cocktail = []
    for i_batch, data_music, data_cocktail, data_music_lab, data_cocktail_lab, data_reg_grounding \
            in zip(range(len(data['music'])), data['music'], data['cocktail'], data['music_labeled'], data['cocktail_labeled'], data['reg_grounding']):
        x_music, _ = data_music
        x_cocktail, _, contains_egg, contains_bubbles = data_cocktail
        x_music_lab, labels_music = data_music_lab
        x_cocktail_lab, labels_cocktail = data_cocktail_lab
        x_reg_music, x_reg_cocktail = data_reg_grounding
        step += x_music.shape[0]
        if train: opt.zero_grad()
        # weigh examples that contain bubbles or egg more heavily in the mse computation
        bubbles_egg_weights = torch.ones([contains_bubbles.shape[0]])
        bubbles_egg_weights[contains_bubbles] += 1
        bubbles_egg_weights[contains_egg] += 3
        # vae reconstruction and kld losses for each modality
        x_hat_cocktail, z_cocktail, mu_cocktail, log_var_cocktail = model(x_cocktail, modality_in='cocktail', modality_out='cocktail')
        mse_loss_cocktail = torch.sum(((x_cocktail - x_hat_cocktail)**2).mean(axis=1) * bubbles_egg_weights) / bubbles_egg_weights.sum()
        if contains_bubbles.sum() > 0:
            bubble_mse = float(((x_cocktail - x_hat_cocktail)**2)[contains_bubbles, -3].mean())
        else:
            bubble_mse = np.nan
        if contains_egg.sum() > 0:
            egg_mse = float(((x_cocktail - x_hat_cocktail)**2)[contains_egg, -1].mean())
        else:
            egg_mse = np.nan
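        # closed-form KL divergence between the posterior N(mu, sigma^2) and the
        # standard normal prior: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)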
        kld_loss_cocktail = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
        x_hat_music, z_music, mu_music, log_var_music = model(x_music, modality_in='music', modality_out='music')
        mse_loss_music = ((x_music - x_hat_music)**2).mean()
        kld_loss_music = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
        music_vae_loss = mse_loss_music + params['beta_vae'] * kld_loss_music
        cocktail_vae_loss = mse_loss_cocktail + params['beta_vae'] * kld_loss_cocktail
        vae_loss = cocktail_vae_loss + params['beta_music'] * music_vae_loss
        brb_kld_loss_cocktail, brb_kld_loss_music, brb_mse_loss_music, brb_mse_loss_cocktail, brb_mse_latent_loss, brb_music_vae_loss, brb_vae_loss = [0] * 7
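        # back-to-back (brb) VAEs: encode an input, decode it into the other modality,
        # re-encode that intermediate and decode back to the original modality,
        # penalizing both the final reconstruction and the gap between the two latents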
        if params['use_brb_vae']:
            # vae back to back, cocktail -> music -> cocktail
            out = model.forward_b2b(x_cocktail, modality_in_out='cocktail', modality_intermediate='music')
            x_hat_cocktail, x_intermediate_music, mu_cocktail, log_var_cocktail, z_cocktail, mu_music, log_var_music, z_music = out
            brb_mse_loss_cocktail = ((x_cocktail - x_hat_cocktail) ** 2).mean()
            brb_mse_latent_loss_1 = ((z_music - z_cocktail) ** 2).mean()
            brb_kld_loss_cocktail_1 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
            brb_kld_loss_music_1 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
            # brb_cocktail_in_loss = mse_loss_cocktail + mse_latents_1 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music)
            # vae back to back, music -> cocktail -> music
            out = model.forward_b2b(x_music, modality_in_out='music', modality_intermediate='cocktail')
            x_hat_music, x_intermediate_cocktail, mu_music, log_var_music, z_music, mu_cocktail, log_var_cocktail, z_cocktail = out
            brb_mse_loss_music = ((x_music - x_hat_music) ** 2).mean()
            brb_mse_latent_loss_2 = ((z_music - z_cocktail) ** 2).mean()
            brb_kld_loss_cocktail_2 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1))
            brb_kld_loss_music_2 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1))
            # brb_music_in_loss = mse_loss_music + mse_latents_2 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music)
            brb_mse_latent_loss = (brb_mse_latent_loss_1 + brb_mse_latent_loss_2) / 2
            brb_kld_loss_music = (brb_kld_loss_music_1 + brb_kld_loss_music_2) / 2
            brb_kld_loss_cocktail = (brb_kld_loss_cocktail_1 + brb_kld_loss_cocktail_2) / 2
            brb_vae_loss = brb_mse_latent_loss + brb_mse_loss_cocktail + brb_mse_loss_music + params['beta_vae'] * (brb_kld_loss_music + brb_kld_loss_cocktail)
            brb_music_vae_loss = brb_mse_loss_music + params['beta_vae'] * brb_kld_loss_music + brb_mse_latent_loss
        # swd
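        # the SWD term (presumably the sliced Wasserstein distance, given compute_swd_loss)
        # pushes the music and cocktail latent distributions to overlap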
        if params['beta_swd'] > 0:
            swd_loss = compute_swd_loss(z_music, z_cocktail, params['latent_dim'])
        else:
            swd_loss = 0
        # classif losses
        if params['beta_classif'] > 0:
            pred_music = model.classify(x_music_lab, modality_in='music')
            classif_loss_music = LOSS(pred_music, labels_music)
            accuracy_music = torch.mean((torch.argmax(pred_music, dim=1) == labels_music).float())
            cf_matrices_music.append(get_cf_matrix(pred_music, labels_music))
            pred_cocktail = model.classify(x_cocktail_lab, modality_in='cocktail')
            classif_loss_cocktail = LOSS(pred_cocktail, labels_cocktail)
            accuracy_cocktail = torch.mean((torch.argmax(pred_cocktail, dim=1) == labels_cocktail).float())
            cf_matrices_cocktail.append(get_cf_matrix(pred_cocktail, labels_cocktail))
        else:
            classif_loss_cocktail, classif_loss_music = 0, 0
            accuracy_music, accuracy_cocktail = 0, 0
            cf_matrices_cocktail.append(np.zeros((2, 2)))
            cf_matrices_music.append(np.zeros((2, 2)))
        # grounding regression: translate music to cocktail through the frozen cocktail decoder
        if params['beta_reg_grounding'] > 0:
            x_hat_cocktail, _, _, _ = model(x_reg_music, modality_in='music', modality_out='cocktail', freeze_decoder=True)
            mse_reg_grounding = ((x_reg_cocktail - x_hat_cocktail) ** 2).mean()
        else:
            mse_reg_grounding = 0
        if params['use_brb_vae']:
            global_minus_classif = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss
            global_loss = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss + \
                          params['beta_classif'] * (classif_loss_cocktail + params['beta_music_classif'] * classif_loss_music)
        else:
            global_minus_classif = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss
            global_loss = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \
                          params['beta_reg_grounding'] * mse_reg_grounding
        # global_loss = params['beta_vae_loss'] * cocktail_vae_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \
        #               params['beta_reg_grounding'] * mse_reg_grounding
        losses['brb_vae_loss'].append(float(brb_vae_loss))
        losses['brb_mse_latent_loss'].append(float(brb_mse_latent_loss))
        losses['brb_kld_loss_cocktail'].append(float(brb_kld_loss_cocktail))
        losses['brb_kld_loss_music'].append(float(brb_kld_loss_music))
        losses['brb_mse_loss_music'].append(float(brb_mse_loss_music))
        losses['brb_mse_loss_cocktail'].append(float(brb_mse_loss_cocktail))
        losses['swd_losses'].append(float(swd_loss))
        losses['vae_losses'].append(float(vae_loss))
        losses['kld_losses_music'].append(float(kld_loss_music))
        losses['kld_losses_cocktail'].append(float(kld_loss_cocktail))
        losses['mse_losses_music'].append(float(mse_loss_music))
        losses['mse_losses_cocktail'].append(float(mse_loss_cocktail))
        losses['global_losses'].append(float(global_loss))
        losses['classif_losses_music'].append(float(classif_loss_music))
        losses['classif_losses_cocktail'].append(float(classif_loss_cocktail))
        losses['classif_acc_cocktail'].append(float(accuracy_cocktail))
        losses['classif_acc_music'].append(float(accuracy_music))
        losses['beta_reg_grounding'].append(float(mse_reg_grounding))
        losses['bubble_mse'].append(bubble_mse)
        losses['egg_mse'].append(egg_mse)
        if train:
            # if epoch < params['n_epochs_music_pretrain']:
            #     music_vae_loss.backward()
            # elif epoch >= params['n_epochs_music_pretrain'] and epoch < (params['n_epochs_music_pretrain'] + params['n_epochs_train']):
            #     global_minus_classif.backward()
            # elif epoch >= (params['n_epochs_music_pretrain'] + params['n_epochs_train']):
            global_loss.backward()
            opt.step()
        if params['log_every'] != 0:
            if step != 0 and step % params['log_every'] == 0:
                print(f'\tBatch #{i_batch}')
                for k in params['keys_to_print']:
                    if k != 'steps':
                        print(f'\t {k}: Train: {np.nanmean(losses[k][-params["log_every"]:]):.3f}')
    return losses, [np.mean(cf_matrices_music, axis=0), np.mean(cf_matrices_cocktail, axis=0)]

def get_cf_matrix(pred, labels):
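    """Build a row-normalized confusion matrix (rows are true labels, columns are predictions)."""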
    bs, dim = pred.shape
    labels = labels.detach().cpu().numpy()
    pred_labels = np.argmax(pred.detach().cpu().numpy(), axis=1)
    confusion_matrix = np.zeros((dim, dim))
    for i in range(bs):
        confusion_matrix[labels[i], pred_labels[i]] += 1
    # normalize each row so it sums to one (per-class prediction frequencies)
    for i in range(dim):
        if np.sum(confusion_matrix[i]) != 0:
            confusion_matrix[i] /= np.sum(confusion_matrix[i])
    return confusion_matrix

def train(model, dataloaders, params):
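    """Train with AdamW (optionally with step-wise lr decay), evaluating after each epoch.

    Tracks every key in params['keys_to_track'], plots every params['plot_every'] epochs,
    and saves both the model with the best eval global loss and the last model.
    """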
    keys_to_track = params['keys_to_track']
    opt = torch.optim.AdamW(list(model.parameters()), lr=params['lr'])
    if params['decay_step'] > 0: scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=params['decay_step'], gamma=0.5)
    all_train_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
    all_eval_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))]))
    best_eval_loss = np.inf
    data_train = dict()
    data_test = dict()
    for k in dataloaders.keys():
        if '_train' in k:
            data_train[k[:-6]] = dataloaders[k]
        elif '_test' in k:
            data_test[k[:-5]] = dataloaders[k]
        else:
            raise ValueError
    # run a first evaluation before any training
    eval_losses, _ = run_epoch(0, model, data_test, params, opt, train=False)
    for k in params['keys_to_track']:
        if k == 'steps':
            all_train_losses[k].append(0)
            all_eval_losses[k].append(0)
        else:
            all_train_losses[k].append(np.nan)
            all_eval_losses[k].append(np.mean(eval_losses[k]))
    print('Initial evaluation')
    for k in params['keys_to_print']:
        to_print = all_eval_losses[k][-1]
        print(f' {k}: Eval: {to_print:.3f}')
    step = 0
    for epoch in range(params['epochs']):
        print(f'\n------------\nEpoch #{epoch}')
        # run training epoch
        train_losses, train_cf_matrices = run_epoch(epoch, model, data_train, params, opt, train=True)
        # run eval epoch
        eval_losses, eval_cf_matrices = run_epoch(epoch, model, data_test, params, opt, train=False)
        if epoch < params['n_epochs_music_pretrain']:
            epoch_size = params['pretrain_train_epoch_size']
        else:
            epoch_size = params['train_epoch_size']
        step += epoch_size
        for k in params['keys_to_track']:
            if k == 'steps':
                all_train_losses[k].append(epoch)
                all_eval_losses[k].append(epoch)
            else:
                all_train_losses[k].append(np.nanmean(train_losses[k]))
                all_eval_losses[k].append(np.nanmean(eval_losses[k]))
        if params['decay_step'] > 0: scheduler.step()
        # logging
        print(f'----\n\tEval epoch #{epoch}')
        for k in params['keys_to_print']:
            to_print_eval = all_eval_losses[k][-1]
            to_print_train = all_train_losses[k][-1]
            print(f'\t {k}: Eval: {to_print_eval:.3f} / Train: {to_print_train:.3f}')
        if epoch % params['plot_every'] == 0:
            plot_all_losses(all_train_losses.copy(), all_eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params)
        # saving models
        save_losses(all_train_losses, all_eval_losses, params['save_path'] + 'results.txt')
        if params['save_every'] != 0:
            if epoch % params['save_every'] == 0:
                print('Saving model.')
                save_model(model, path=params['save_path'], name=f'epoch_{epoch}')
        if all_eval_losses['global_losses'][-1] < best_eval_loss:
            best_eval_loss = all_eval_losses['global_losses'][-1]
            print(f'New best eval loss: {best_eval_loss:.3f}, saving model.')
            save_model(model, path=params['save_path'], name='best_eval')
    print('Saving last model.')
    save_model(model, path=params['save_path'], name='last')
    return model, all_train_losses, all_eval_losses, train_cf_matrices, eval_cf_matrices

def save_losses(train_losses, eval_losses, path):
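    """Save the train and eval loss histories to a single text file, one row per (sorted) key."""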
    results = []
    keys = sorted(train_losses.keys())
    for k in keys:
        results.append(train_losses[k])
    for k in keys:
        results.append(eval_losses[k])
    np.savetxt(path, np.array(results))

def save_model(model, path, name):
    torch.save(model.state_dict(), path + f'checkpoints_{name}.save')

def run_training(params):
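    """Build the dataloaders and the gml vae models, train them, and plot the resulting losses."""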
    params = compute_expe_name_and_save_path(params)
    dataloaders, n_labels, stats = get_dataloaders(cocktail_rep_path=params['cocktail_rep_path'],
                                                   music_rep_path=params['music_rep_path'],
                                                   batch_size=params['pretrain_batch_size'],
                                                   train_epoch_size=params['pretrain_train_epoch_size'],
                                                   test_epoch_size=params['pretrain_test_epoch_size'])
    params['nb_classes'] = n_labels
    params['stats'] = stats
    params['classif_classes'] = dataloaders['music_labeled_train'].dataset.classes
    vae_gml_model = get_gml_vae_models(layer_type=params['layer_type'],
                                       input_dim_music=dataloaders['music_train'].dataset.dim_music,
                                       input_dim_cocktail=dataloaders['cocktail_train'].dataset.dim_cocktail,
                                       hidden_dim=params['hidden_dim'],
                                       n_hidden=params['n_hidden'],
                                       latent_dim=params['latent_dim'],
                                       nb_classes=params['nb_classes'],
                                       dropout=params['dropout'])
    params['dim_music'] = dataloaders['music_train'].dataset.dim_music
    params['dim_cocktail'] = dataloaders['cocktail_train'].dataset.dim_cocktail
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)
    models, train_losses, eval_losses, train_cf_matrices, eval_cf_matrices = train(vae_gml_model, dataloaders, params)
    plot_all_losses(train_losses.copy(), eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params)
    return models, train_losses, eval_losses

def plot_all_losses(train_losses, eval_losses, train_cf_matrices, eval_cf_matrices, params):
    plot_losses(train_losses, train_cf_matrices, 'train', params)
    plot_losses(eval_losses, eval_cf_matrices, 'eval', params)

def plot_losses(losses, cf_matrices, split, params):
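    """Plot aggregated losses, classification accuracies, detailed losses and confusion matrices for one split."""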
    save_path = params['save_path'] + 'plots/'
    os.makedirs(save_path, exist_ok=True)
    steps = losses['steps']
    for k in losses.keys():
        losses[k] = np.array(losses[k])
    losses['sum_loss_classif'] = losses['classif_losses_music'] + losses['classif_losses_cocktail']
    losses['av_acc_classif'] = (losses['classif_acc_cocktail'] + losses['classif_acc_music']) / 2
    losses['sum_mse_vae'] = losses['mse_losses_cocktail'] + losses['mse_losses_music']
    losses['sum_kld_vae'] = losses['kld_losses_cocktail'] + losses['kld_losses_music']
    plt.figure()
    for k in ['global_losses', 'vae_losses', 'swd_losses', 'sum_mse_vae', 'sum_kld_vae']:
        factor = 10 if k == 'swd_losses' else 1  # scale the swd loss up so it stays visible on shared axes
        plt.plot(steps, losses[k] * factor, label=k)
    plt.title(split)
    plt.legend()
    plt.ylim([0, 2.5])
    plt.savefig(save_path + f'plot_high_level_losses_{split}.png')
    plt.close(plt.gcf())
    plt.figure()
    for k in ['classif_acc_cocktail', 'classif_acc_music']:
        plt.plot(steps, losses[k], label=k)
    plt.title(split)
    plt.ylim([0, 1])
    plt.legend()
    plt.savefig(save_path + f'plot_classif_accuracies_{split}.png')
    plt.close(plt.gcf())
    plt.figure()
    for k in ['mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail',
              'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'beta_reg_grounding',
              'bubble_mse', 'egg_mse']:
        factor = 10 if k == 'swd_losses' else 1
        plt.plot(steps, losses[k] * factor, label=k)
    plt.title(split)
    plt.ylim([0, 2.5])
    plt.legend()
    plt.savefig(save_path + f'plot_detailed_losses_{split}.png')
    plt.close(plt.gcf())
    for i_k, k in enumerate(['music', 'cocktail']):
        plt.figure()
        plt.imshow(cf_matrices[i_k], vmin=0, vmax=1)
        plt.xticks(range(len(params['classif_classes'])), params['classif_classes'], rotation=45)
        plt.yticks(range(len(params['classif_classes'])), params['classif_classes'])
        plt.xlabel('predicted')
        plt.ylabel('true')
        plt.title(split + ' ' + k)
        plt.colorbar()
        # bbox_inches='tight' keeps the rotated tick labels inside the saved figure
        plt.savefig(save_path + f'cf_matrix_{split}_{k}.png', bbox_inches='tight')
        plt.close(plt.gcf())
    if params['use_brb_vae']:
        plt.figure()
        for k in ['brb_vae_loss', 'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'mse_losses_music', 'brb_mse_latent_loss']:
            plt.plot(steps, losses[k], label=k)
        plt.title(split)
        plt.ylim([0, 2.5])
        plt.legend()
        plt.savefig(save_path + f'plot_detailed_brb_losses_{split}.png')
        plt.close(plt.gcf())

def parse_args():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--save_path", type=str, default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/latent_translation/")
    parser.add_argument("--trial_id", type=str, default="b256_r128_classif001_ld40_meanstd")
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--n_hidden", type=int, default=1)
    parser.add_argument("--latent_dim", type=int, default=40)
    parser.add_argument("--n_epochs_music_pretrain", type=int, default=0)
    parser.add_argument("--n_epochs_train", type=int, default=200)
    parser.add_argument("--n_epochs_classif_finetune", type=int, default=0)
    parser.add_argument("--beta_vae_loss", type=float, default=1.)
    parser.add_argument("--beta_vae", type=float, default=1.2)  # keep this low (~1) to allow music classification
    parser.add_argument("--beta_swd", type=float, default=1)
    parser.add_argument("--beta_reg_grounding", type=float, default=2.5)
    parser.add_argument("--beta_classif", type=float, default=0.01)  # TODO: try 0.1 (default 0.01)
    parser.add_argument("--beta_music", type=float, default=100)  # higher weight on the music losses, which need more training to converge
    parser.add_argument("--beta_music_classif", type=float, default=300)  # try 300; higher weight on the music classification loss
    parser.add_argument("--pretrain_batch_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--decay_step", type=int, default=0)
    parser.add_argument("--cocktail_rep_path", type=str, default=FULL_COCKTAIL_REP_PATH)
    parser.add_argument("--music_rep_path", type=str, default=music_rep_path)
    parser.add_argument("--use_brb_vae", action="store_true")  # type=bool parses any non-empty string as True, so use a flag instead
    parser.add_argument("--layer_type", type=str, default='gml')
    parser.add_argument("--dropout", type=float, default=0.2)
    # best parameters found so far: identical to the defaults above, except beta_vae=1
    args = parser.parse_args()
    return args

def compute_expe_name_and_save_path(params):
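    """Build a unique experiment directory name from the main hyperparameters, create it, and store it in params."""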
    save_path = params['save_path'] + params["trial_id"]
    if params["use_brb_vae"]:
        save_path += '_usebrb'
    save_path += f'_lr{params["lr"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_bmusic{params["beta_music"]}'
    save_path += f'_bswd{params["beta_swd"]}'
    save_path += f'_bclassif{params["beta_classif"]}'
    save_path += f'_bvae{params["beta_vae_loss"]}'
    save_path += f'_bvaekld{params["beta_vae"]}'
    save_path += f'_lat{params["latent_dim"]}'
    save_path += f'_hd{params["n_hidden"]}x{params["hidden_dim"]}'
    save_path += f'_drop{params["dropout"]}'
    save_path += f'_decay{params["decay_step"]}'
    save_path += f'_layertype{params["layer_type"]}'
    # append a counter and increment it until the path is unique
    number_added = False
    counter = 1
    while os.path.exists(save_path):
        if number_added:
            save_path = '_'.join(save_path.split('_')[:-1]) + f'_{counter}'
        else:
            save_path += f'_{counter}'
            number_added = True  # bug fix: this flag was never set, so '_{counter}' kept being appended instead of replaced
        counter += 1
    params["save_path"] = save_path + '/'
    os.makedirs(save_path)
    print(f'logging to {save_path}')
    return params

if __name__ == '__main__':
    keys_to_track = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail',
                     'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music',
                     'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'brb_mse_latent_loss', 'brb_vae_loss', 'beta_reg_grounding',
                     'bubble_mse', 'egg_mse']
    keys_to_print = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail',
                     'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music', 'beta_reg_grounding']
    # TODO: first phase: vae pretraining for music;
    #       then, in a second phase: vae cocktail and music, brb vaes
    args = parse_args()
    params = dict(nb_classes=None,
                  save_every=0,  # in epochs
                  log_every=0,  # in steps, e.g. 32 * 500
                  plot_every=10,  # in epochs
                  keys_to_track=keys_to_track,
                  keys_to_print=keys_to_print)
    params.update(vars(args))
    params['train_epoch_size'] = params['batch_size'] * 100
    params['test_epoch_size'] = params['batch_size'] * 10
    params['pretrain_train_epoch_size'] = params['pretrain_batch_size'] * 100
    params['pretrain_test_epoch_size'] = params['pretrain_batch_size'] * 10
    params['epochs'] = params['n_epochs_music_pretrain'] + params['n_epochs_train'] + params['n_epochs_classif_finetune']
    run_training(params)