Spaces:
Runtime error
Runtime error
import torch; torch.manual_seed(0) | |
import torch.utils | |
from torch.utils.data import DataLoader | |
import torch.distributions | |
import torch.nn as nn | |
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 | |
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients | |
import json | |
import pandas as pd | |
import numpy as np | |
import os | |
from src.cocktails.representation_learning.multihead_model import get_multihead_model | |
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH | |
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys | |
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles | |
from resource import getrusage | |
from resource import RUSAGE_SELF | |
import gc | |
# Run a full garbage-collection pass (oldest generation, 2) when the module is imported.
gc.collect(2)
# Preferred torch device. NOTE(review): nothing visible in this file reads `device`; the model
# and batches are never moved to it here.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_params():
    """Build the hyper-parameters of a training run, save them to disk and return them.

    Steps: read the cocktail CSV to size the auxiliary heads; compute the ingredient
    representation size from a single-ingredient ('water') example; create a fresh experiment
    folder with ``compute_expe_name_and_save_path``; dump the params (minus the numpy-valued
    'category_encodings') to ``<save_path>params.json``; then re-attach the non-JSON entries
    with ``complete_params`` and return the result.
    """
    cocktails = pd.read_csv(COCKTAILS_CSV_DATA)
    max_ingredients, ingredient_set, _liquor_set, _liqueur_set = get_max_n_ingredients(cocktails)
    num_ingredients = len(ingredient_set)

    # Keep the second space-separated token of each 'custom' representation key, except 'volume'.
    ing_keys = [rep_key.split(' ')[1] for rep_key in get_bunch_of_rep_keys()['custom']]
    ing_keys.remove('volume')

    # One identity-matrix row (one-hot vector) per ingredient type, types in sorted order.
    ingredient_types = sorted(set(ingredient_profiles['type']))
    category_encodings = dict(zip(ingredient_types, np.eye(len(ingredient_types))))

    # Auxiliary heads of the multi-head model: loss weight, head type and output size.
    # Insertion order matters: it is the order of the weights in the experiment folder name.
    auxiliaries_dict = {
        'categories': dict(weight=5, type='classif', final_activ=None, dim_output=len(set(cocktails['subcategory']))),
        'glasses': dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(cocktails['glass']))),
        'prep_type': dict(weight=0.1, type='classif', final_activ=None, dim_output=len(set(cocktails['category']))),
        'cocktail_reps': dict(weight=1, type='regression', final_activ=None, dim_output=13),
        'volume': dict(weight=1, type='regression', final_activ='relu', dim_output=1),
        'taste_reps': dict(weight=1, type='regression', final_activ='relu', dim_output=2),
        'ingredients_presence': dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),
        'ingredients_quantities': dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients),
    }

    params = {
        'trial_id': 'test',
        'save_path': EXPERIMENT_PATH + "/multihead_model/",
        'nb_epochs': 500,
        'print_every': 50,
        'plot_every': 50,
        'batch_size': 128,
        'lr': 0.001,
        'dropout': 0.,
        'nb_epoch_switch_beta': 600,
        'latent_dim': 10,
        'beta_vae': 0.2,
        'ing_keys': ing_keys,
        'nb_ingredients': num_ingredients,
        'hidden_dims_ingredients': [128],
        'hidden_dims_cocktail': [64],
        'hidden_dims_decoder': [32],
        'agg': 'mean',
        'activation': 'relu',
        'auxiliaries_dict': auxiliaries_dict,
        'category_encodings': category_encodings,
    }

    # Size of one ingredient representation, measured on a dummy one-ingredient cocktail.
    water_rep, indexes_to_normalize = get_representation_from_ingredient(
        ingredients=['water'], quantities=[1],
        max_q_per_ing=dict.fromkeys(ingredient_set, 1), index=0,
        params=params)
    dim_rep_ingredient = water_rep.size
    params['indexes_ing_to_normalize'] = indexes_to_normalize
    params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
    params['dim_rep_ingredient'] = dim_rep_ingredient
    params['input_dim'] = params['nb_ingredients']

    params = compute_expe_name_and_save_path(params)
    # The encodings are numpy arrays, which json cannot serialise; complete_params re-adds them.
    params.pop('category_encodings')
    with open(params['save_path'] + 'params.json', 'w') as json_file:
        json.dump(params, json_file)
    return complete_params(params)
def complete_params(params):
    """Attach the entries that are not stored in params.json to ``params`` and return it.

    Mutates ``params`` in place, adding:
      - 'cocktail_reps': the array loaded with ``np.loadtxt(FULL_COCKTAIL_REP_PATH)``,
      - 'raw_data': the cocktail DataFrame read from ``COCKTAILS_CSV_DATA``,
      - 'category_encodings': one identity-matrix row per ingredient type (types sorted).
    """
    raw_data = pd.read_csv(COCKTAILS_CSV_DATA)
    cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
    ingredient_types = sorted(set(ingredient_profiles['type']))
    params['cocktail_reps'] = cocktail_reps
    params['raw_data'] = raw_data
    params['category_encodings'] = dict(zip(ingredient_types, np.eye(len(ingredient_types))))
    return params
def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
    """Compute the loss, accuracy and extra metrics of every auxiliary head for one batch.

    Args:
        loss_functions: dict auxiliary name -> torch loss module (CrossEntropyLoss,
            BCEWithLogitsLoss or MSELoss).
        auxiliaries: dict auxiliary name -> ground-truth tensor for the batch.
        auxiliaries_str: auxiliary names, in the same order as ``outputs``.
        outputs: list of model output tensors, one per name in ``auxiliaries_str``.
            NOTE: the 'volume' entry of this list is replaced in place by its flattened version.
        data: the DataLoader of the batch; only read for the BCE head
            ('ingredients_presence'), through ``data.dataset.std_ing_quantities`` and
            ``data.dataset.mean_ing_quantities``.

    Returns:
        (losses, accuracies, other_metrics), dicts keyed by auxiliary name. CrossEntropy heads
        also get a row-normalised confusion matrix under '<name>_confusion' (rows: true class,
        columns: predicted class); the BCE head gets '<name>_false_positive' and
        '<name>_false_negative' rates; MSE heads get an accuracy of NaN.

    Raises:
        ValueError: if a loss function is not one of the three supported types.
    """
    losses = dict()
    accuracies = dict()
    other_metrics = dict()
    for i_k, k in enumerate(auxiliaries_str):
        loss_fn = loss_functions[k]
        if k == 'volume':
            # The volume head presumably outputs (bs, 1) while its target is 1-D.
            outputs[i_k] = outputs[i_k].flatten()
        output = outputs[i_k]
        ground_truth = auxiliaries[k]

        # Loss, with the target cast to the dtype each loss expects.
        if ground_truth.dtype == torch.float64:
            losses[k] = loss_fn(output, ground_truth.float()).float()
        elif ground_truth.dtype == torch.int64:
            if isinstance(loss_fn, nn.BCEWithLogitsLoss):
                losses[k] = loss_fn(output.float(), ground_truth.float()).float()
            else:
                losses[k] = loss_fn(output.float(), ground_truth.long()).float()
        else:
            losses[k] = loss_fn(output, ground_truth).float()

        # Accuracy and extra metrics, depending on the kind of head.
        if isinstance(loss_fn, nn.CrossEntropyLoss):
            n_options = output.shape[1]
            predicted = output.argmax(dim=1).detach().cpu().numpy()
            true = ground_truth.int().detach().cpu().numpy()
            confusion_matrix = np.zeros([n_options, n_options])
            # Unbuffered accumulation: repeated (true, predicted) pairs are all counted.
            np.add.at(confusion_matrix, (true, predicted), 1)
            accuracies[k] = np.mean(predicted == true)
            # Normalise each true-class row to sum to 1; rows of classes absent from the batch stay 0.
            row_sums = confusion_matrix.sum(axis=1, keepdims=True)
            np.divide(confusion_matrix, row_sums, out=confusion_matrix, where=row_sums != 0)
            other_metrics[k + '_confusion'] = confusion_matrix
        elif isinstance(loss_fn, nn.BCEWithLogitsLoss):
            assert k == 'ingredients_presence'
            # NOTE(review): the logits are rescaled with the ingredient-quantity statistics before
            # thresholding at 0 (not the usual sigmoid(logit) > 0.5 rule); kept as-is — confirm intent.
            outputs_rescaled = (output.detach().cpu().numpy() * data.dataset.std_ing_quantities
                                + data.dataset.mean_ing_quantities)
            predicted_presence = outputs_rescaled > 0
            presence = ground_truth.detach().cpu().numpy().astype(bool)
            other_metrics[k + '_false_positive'] = np.mean(predicted_presence & ~presence)
            other_metrics[k + '_false_negative'] = np.mean(~predicted_presence & presence)
            accuracies[k] = np.mean(predicted_presence == presence)  # multi-label accuracy
        elif isinstance(loss_fn, nn.MSELoss):
            accuracies[k] = np.nan
        else:
            raise ValueError(f'Unsupported loss function for auxiliary {k!r}: {loss_fn}')
    return losses, accuracies, other_metrics
def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
    """Add ingredient-quantity reconstruction errors to ``aux_other_metrics`` and return it.

    The absolute error ``|ingredient_quantities - x_hat|`` is multiplied by
    ``data.dataset.max_ing_quantities`` (presumably undoing a max-normalisation — confirm).
    It is then averaged per row separately over the ingredients present (quantity > 0) and absent
    (quantity <= 0) in the ground truth. The per-row averages are averaged over the batch and
    stored under 'ing_q_abs_loss_when_present' and 'ing_q_abs_loss_when_absent'.

    Rows with no present (resp. no absent) ingredient are left out of the corresponding average
    instead of turning the whole batch metric into NaN; if no row qualifies, the metric is NaN.

    Args:
        aux_other_metrics: dict of metrics to extend (mutated in place and returned).
        data: object exposing ``data.dataset.max_ing_quantities`` (the DataLoader in run_epoch).
        ingredient_quantities: ground-truth tensor, 2-D (batch, n_ingredients).
        x_hat: predicted tensor with the same shape as ``ingredient_quantities``.
    """
    ing_q = ingredient_quantities.detach().cpu().numpy()
    predicted_q = x_hat.detach().cpu().numpy()
    abs_diff = np.abs(ing_q - predicted_q) * data.dataset.max_ing_quantities
    ing_presence = ing_q > 0
    aux_other_metrics['ing_q_abs_loss_when_present'] = _mean_of_masked_row_means(abs_diff, ing_presence)
    aux_other_metrics['ing_q_abs_loss_when_absent'] = _mean_of_masked_row_means(abs_diff, ~ing_presence)
    return aux_other_metrics


def _mean_of_masked_row_means(values, mask):
    """Mean over rows of mean(values[i][mask[i]]), skipping rows with an empty mask (NaN if all are empty)."""
    counts = mask.sum(axis=1)
    valid_rows = counts > 0
    if not valid_rows.any():
        return np.nan
    row_sums = np.where(mask, values, 0.).sum(axis=1)
    return np.mean(row_sums[valid_rows] / counts[valid_rows])
def run_epoch(opt, train, model, data, loss_functions, weights, params):
    """Run one pass over ``data``; update the model with ``opt`` if ``train`` else only evaluate.

    For each batch:
      - forward pass on the ingredient quantities;
      - per-auxiliary losses and metrics (compute_losses_and_accuracies / compute_metric_output);
      - the 'taste_reps' loss recomputed on the samples whose taste label is flagged valid;
      - global loss = sum over ``params['auxiliaries_dict']`` of weight * loss.
    When training, the global loss is back-propagated and ``opt`` steps once per batch.
    Gradient tracking is disabled when ``train`` is False.

    Returns:
        (model, losses, accuracies, other_metrics), each metric averaged (unweighted) over batches.
        NOTE: 'kld_loss', 'mse_loss', 'vae_loss' and 'volume_loss' are never filled here, so their
        averages are NaN; they are kept because print_losses and plot_results read them.
    """
    if train:
        model.train()
    else:
        model.eval()
    # prepare logging of losses
    losses = dict(kld_loss=[],
                  mse_loss=[],
                  vae_loss=[],
                  volume_loss=[],
                  global_loss=[])
    accuracies = dict()
    other_metrics = dict()
    for aux in params['auxiliaries_dict'].keys():
        losses[aux] = []
        accuracies[aux] = []
    if train:
        opt.zero_grad()
    # Evaluation batches never call backward(), so no autograd graph is built for them.
    with torch.set_grad_enabled(bool(train)):
        for d in data:
            batch_size = d[0].shape[0]
            ingredient_quantities = d[2]
            auxiliaries = d[4]
            for k in auxiliaries.keys():
                if auxiliaries[k].dtype == torch.float64:
                    auxiliaries[k] = auxiliaries[k].float()
            taste_valid = d[-1]
            z, outputs, auxiliaries_str = model.forward(ingredient_quantities.float())
            # auxiliary losses, accuracies and reconstruction metrics
            aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(
                loss_functions, auxiliaries, auxiliaries_str, outputs, data)
            aux_other_metrics = compute_metric_output(
                aux_other_metrics, data, ingredient_quantities,
                outputs[auxiliaries_str.index('ingredients_quantities')])
            # The taste loss is recomputed on the samples flagged as having a valid taste label.
            indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
            if indexes_taste_valid.size > 0:
                outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
                gt = auxiliaries['taste_reps'][indexes_taste_valid]
                # Factor is 1 when the batch has the same share of valid taste labels as the dataset
                # (0.3 per the original comment); fewer valid samples -> smaller factor, more -> larger.
                factor_loss = indexes_taste_valid.size / (0.3 * batch_size)
                aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
            else:
                aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
            aux_accuracies['taste_reps'] = 0
            # aggregate losses
            global_loss = torch.sum(torch.cat([torch.atleast_1d(aux_losses[k] * weights[k])
                                               for k in params['auxiliaries_dict'].keys()]))
            if train:
                global_loss.backward()
                opt.step()
                opt.zero_grad()
            # logging
            losses['global_loss'].append(float(global_loss))
            for k in params['auxiliaries_dict'].keys():
                losses[k].append(float(aux_losses[k]))
                accuracies[k].append(float(aux_accuracies[k]))
            for k, value in aux_other_metrics.items():
                other_metrics.setdefault(k, []).append(value)
    losses = {k: np.mean(v) for k, v in losses.items()}
    accuracies = {k: np.mean(v) for k, v in accuracies.items()}
    other_metrics = {k: np.mean(v, axis=0) for k, v in other_metrics.items()}
    return model, losses, accuracies, other_metrics
def prepare_data_and_loss(params):
    """Build the train/test data loaders and one loss function and weight per auxiliary head.

    Loss chosen from ``params['auxiliaries_dict'][k]['type']``:
      - 'classif'      -> CrossEntropyLoss weighted by the class weights stored on the train dataset;
      - 'multiclassif' -> BCEWithLogitsLoss;
      - 'regression'   -> MSELoss.

    The test loader is not shuffled. run_epoch averages per-batch metrics without weighting by
    batch size. Shuffling would therefore change the composition of the smaller last batch every
    epoch, adding noise to the eval loss used to select the best checkpoint.

    Returns:
        (loss_functions, train_data_loader, test_data_loader, weights)

    Raises:
        ValueError: for an unknown auxiliary type, or a 'classif' head with no known class weights.
    """
    train_data = MyDataset(split='train', params=params)
    test_data = MyDataset(split='test', params=params)
    train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
    test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=False)
    # Attribute of the train dataset that holds the class weights of each classification head.
    class_weight_attributes = dict(glasses='glasses_weights',
                                   prep_type='prep_types_weights',
                                   categories='categories_weights')
    loss_functions = dict()
    weights = dict()
    for k in sorted(params['auxiliaries_dict'].keys()):
        aux_type = params['auxiliaries_dict'][k]['type']
        if aux_type == 'classif':
            if k not in class_weight_attributes:
                raise ValueError(f'No class weights known for classification auxiliary {k!r}')
            classif_weights = getattr(train_data, class_weight_attributes[k])
            loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
        elif aux_type == 'multiclassif':
            loss_functions[k] = nn.BCEWithLogitsLoss()
        elif aux_type == 'regression':
            loss_functions[k] = nn.MSELoss()
        else:
            raise ValueError(f'Unknown auxiliary type {aux_type!r} for auxiliary {k!r}')
        weights[k] = params['auxiliaries_dict'][k]['weight']
    return loss_functions, train_data_loader, test_data_loader, weights
def print_losses(train, losses, accuracies, other_metrics):
    """Print the epoch-level logs of one split.

    Prints the five fixed loss entries first. It then prints loss and accuracy for every key of
    ``accuracies`` in sorted order, and finally every non-confusion entry of ``other_metrics``
    in sorted order. All values use two decimals.
    """
    mode = 'Train' if train else 'Eval'
    print(f'\t{mode} logs:')
    for name in ('global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss'):
        print(f'\t\t{name} - Loss: {losses[name]:.2f}')
    for name in sorted(accuracies):
        print(f'\t\t{name} (aux) - Loss: {losses[name]:.2f}, Acc: {accuracies[name]:.2f}')
    # Confusion matrices are arrays, not scalars: they are plotted, not printed.
    scalar_metric_names = [name for name in sorted(other_metrics) if 'confusion' not in name]
    for name in scalar_metric_names:
        print(f'\t\t{name} - {other_metrics[name]:.2f}')
def run_experiment(params, verbose=True):
    """Train a multi-head model as configured by ``params`` and return it (state after the last epoch).

    The model is evaluated once on the test split before training (logged as epoch 0). It is
    then trained for ``params['nb_epochs']`` epochs, each followed by an evaluation. Whenever
    the eval global loss improves, the weights are saved to ``<save_path>checkpoint_best.save``.
    Logs are printed every ``params['print_every']`` epochs when ``verbose``, and curves are
    written to ``params['plot_path']`` every ``params['plot_every']`` epochs.
    """
    loss_functions, train_loader, test_loader, weights = prepare_data_and_loss(params)
    architecture_keys = ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout",
                         "auxiliaries_dict", "hidden_dims_decoder"]
    model = get_multihead_model(*(params[key] for key in architecture_keys))
    opt = torch.optim.AdamW(model.parameters(), lr=params['lr'])

    train_losses_log, train_acc_log, train_other_log = [], [], []
    eval_losses_log, eval_acc_log, eval_other_log = [], [], []

    def evaluate(current_model):
        # One evaluation pass over the test split (no optimisation step).
        return run_epoch(opt=opt, train=False, model=current_model, data=test_loader,
                         loss_functions=loss_functions, weights=weights, params=params)

    model, eval_losses, eval_accuracies, eval_other_metrics = evaluate(model)
    eval_losses_log.append(eval_losses)
    eval_acc_log.append(eval_accuracies)
    eval_other_log.append(eval_other_metrics)
    if verbose:
        print(f'\n--------\nEpoch #0')
        print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)

    best_loss = np.inf
    for epoch in range(params['nb_epochs']):
        epoch_number = epoch + 1
        report = verbose and epoch_number % params['print_every'] == 0
        if report:
            print(f'\n--------\nEpoch #{epoch_number}')
        model, train_losses, train_accuracies, train_other_metrics = run_epoch(
            opt=opt, train=True, model=model, data=train_loader,
            loss_functions=loss_functions, weights=weights, params=params)
        if report:
            print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)
        model, eval_losses, eval_accuracies, eval_other_metrics = evaluate(model)
        if report:
            print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)

        # Keep the checkpoint with the lowest eval global loss seen so far.
        if eval_losses['global_loss'] < best_loss:
            best_loss = eval_losses['global_loss']
            if verbose:
                print(f'Saving new best model with loss {best_loss:.2f}')
            torch.save(model.state_dict(), params['save_path'] + 'checkpoint_best.save')

        train_losses_log.append(train_losses)
        train_acc_log.append(train_accuracies)
        eval_losses_log.append(eval_losses)
        eval_acc_log.append(eval_accuracies)
        eval_other_log.append(eval_other_metrics)
        train_other_log.append(train_other_metrics)

        if epoch_number % params['plot_every'] == 0:
            plot_results(train_losses_log, train_acc_log, train_other_log,
                         eval_losses_log, eval_acc_log, eval_other_log, params['plot_path'], weights)
    return model
def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                 all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
    """Save learning curves and confusion matrices as PNG files in ``plot_path``.

    Each ``all_*`` argument is a list with one dict per logged epoch (metric name -> value). The
    eval lists hold one more entry than the train lists (the evaluation before training), so eval
    curves are drawn at steps 0..E and train curves at steps 1..E.

    Files written for each split ('train', 'eval'):
      - '<split>_losses.png': losses without a weight entry, plus those with a non-zero weight;
      - '<split>_acc.png': accuracies of auxiliaries with a non-zero weight;
      - '<split>_ing_presence_errors.png': scalar metrics whose name contains 'presence';
      - '<split>_ing_q_error.png': the other scalar (non-confusion) metrics;
      - '<split>_<name>.png': each '*confusion*' matrix of the last logged epoch.
    """
    steps = np.arange(len(all_eval_accuracies))
    loss_keys = sorted(all_train_losses[0].keys())
    acc_keys = sorted(all_train_accuracies[0].keys())
    metrics_keys = sorted(all_train_other_metrics[0].keys())

    plotted_loss_keys = [k for k in loss_keys if k not in weights or weights[k] != 0]
    plotted_acc_keys = [k for k in acc_keys if weights[k] != 0]
    presence_keys = [k for k in metrics_keys if 'confusion' not in k and 'presence' in k]
    quantity_keys = [k for k in metrics_keys if 'confusion' not in k and 'presence' not in k]
    confusion_keys = [k for k in metrics_keys if 'confusion' in k]

    splits = (('train', 'Train', steps[1:], all_train_losses, all_train_accuracies, all_train_other_metrics),
              ('eval', 'Eval', steps, all_eval_losses, all_eval_accuracies, all_eval_other_metrics))
    for split, title, split_steps, split_losses, split_accuracies, split_metrics in splits:
        _plot_curves(f'{title} losses', split_steps, split_losses, plotted_loss_keys, [0, 4],
                     plot_path + f'{split}_losses.png')
        _plot_curves(f'{title} accuracies', split_steps, split_accuracies, plotted_acc_keys, [0, 1],
                     plot_path + f'{split}_acc.png')
        _plot_curves(f'{title} other metrics', split_steps, split_metrics, presence_keys, [0, 1],
                     plot_path + f'{split}_ing_presence_errors.png')
        _plot_curves(f'{title} other metrics', split_steps, split_metrics, quantity_keys, [0, 15],
                     plot_path + f'{split}_ing_q_error.png')

    for k in confusion_keys:
        _plot_confusion_matrix(k, all_eval_other_metrics[-1][k], plot_path + f'eval_{k}.png')
    for k in confusion_keys:
        _plot_confusion_matrix(k, all_train_other_metrics[-1][k], plot_path + f'train_{k}.png')
    plt.close('all')


def _plot_curves(title, steps, history, keys, ylim, out_file):
    """Plot one line per key of the per-epoch dicts in ``history`` against ``steps`` and save it."""
    fig = plt.figure()
    plt.title(title)
    for k in keys:
        plt.plot(steps, [entry[k] for entry in history], label=k)
    plt.legend()
    plt.ylim(ylim)
    plt.savefig(out_file, dpi=200)
    plt.close(fig)


def _plot_confusion_matrix(title, matrix, out_file):
    """Draw ``matrix`` (rows: true class, columns: predicted class) on a [0, 1] colour scale and save it."""
    fig = plt.figure()
    plt.title(title)
    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.imshow(matrix, vmin=0, vmax=1)
    plt.colorbar()
    plt.savefig(out_file, dpi=200)
    plt.close(fig)
def get_model(model_path):
    """Load a trained multi-head model from an experiment folder and return a prediction helper.

    Reads 'params.json', 'checkpoint_best.save' and 'max_ing_quantities.txt' from ``model_path``.
    File names are appended directly, so ``model_path`` must end with a path separator.

    Returns:
        (predict, params). ``predict(ing_qs, aux_str)`` divides the raw ingredient quantities
        ``ing_qs`` by the stored maxima without modifying the caller's array. It runs the model on
        them as a single sample (no gradient tracking) and returns the numpy output of head
        ``aux_str`` (a str), or a list of outputs when ``aux_str`` is a list of head names.
        ``predict`` raises ValueError for any other ``aux_str`` type.
    """
    with open(model_path + 'params.json', 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path
    model_chkpt = model_path + "checkpoint_best.save"
    model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim",
                                        "dropout", "auxiliaries_dict", "hidden_dims_decoder"]]
    model = get_multihead_model(*model_params)
    # The model is used on CPU here; map_location lets a checkpoint saved from a GPU load anyway.
    model.load_state_dict(torch.load(model_chkpt, map_location='cpu'))
    model.eval()
    max_ing_quantities = np.loadtxt(model_path + 'max_ing_quantities.txt')

    def predict(ing_qs, aux_str):
        # Build a new array: the former in-place `ing_qs /= ...` altered the caller's array
        # (re-normalising it on every call) and failed on integer arrays.
        normalised_qs = np.asarray(ing_qs) / max_ing_quantities
        input_model = torch.FloatTensor(normalised_qs).reshape(1, -1)
        with torch.no_grad():
            _, outputs, auxiliaries_str = model.forward(input_model)
        if isinstance(aux_str, str):
            return outputs[auxiliaries_str.index(aux_str)].detach().numpy()
        elif isinstance(aux_str, list):
            return [outputs[auxiliaries_str.index(aux)].detach().numpy() for aux in aux_str]
        else:
            raise ValueError(f'aux_str must be a str or a list of str, got {type(aux_str).__name__}')
    return predict, params
def compute_expe_name_and_save_path(params):
    """Create a new, uniquely numbered experiment folder and record its paths in ``params``.

    The folder name starts with ``params['save_path'] + params['trial_id']``. It then lists the
    main hyper-parameters and the auxiliary weights (in ``auxiliaries_dict`` order, e.g.
    '_w[5, 0.5]'), followed by the first suffix '_0', '_1', ... not already present on disk.
    The folder and its 'plots/' sub-folder are created. ``params['save_path']`` and
    ``params['plot_path']`` are set to them (both ending with '/'). ``params`` is mutated in
    place and returned.
    """
    # join() also yields a well-formed '[]' when there are no auxiliaries.
    weights_str = '[' + ', '.join(f'{aux["weight"]}' for aux in params['auxiliaries_dict'].values()) + ']'
    save_path = params['save_path'] + params["trial_id"]
    save_path += f'_lr{params["lr"]}'
    save_path += f'_betavae{params["beta_vae"]}'
    save_path += f'_bs{params["batch_size"]}'
    save_path += f'_latentdim{params["latent_dim"]}'
    save_path += f'_hding{params["hidden_dims_ingredients"]}'
    save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}'
    save_path += f'_hddecoder{params["hidden_dims_decoder"]}'
    save_path += f'_agg{params["agg"]}'
    save_path += f'_activ{params["activation"]}'
    save_path += f'_w{weights_str}'
    # Pick the first counter suffix not already used by a previous run.
    counter = 0
    while os.path.exists(save_path + f"_{counter}"):
        counter += 1
    save_path = save_path + f"_{counter}" + '/'
    params["save_path"] = save_path
    os.makedirs(save_path)
    os.makedirs(save_path + 'plots/')
    params['plot_path'] = save_path + 'plots/'
    print(f'logging to {save_path}')
    return params
if __name__ == '__main__':
    # Build the hyper-parameters (this also creates the experiment folder), then train.
    experiment_params = get_params()
    run_experiment(experiment_params, verbose=True)