LingConv / options.py
mohdelgaar's picture
Update layout and samples
674b430
raw
history blame
7.9 kB
import os, json
import argparse
import numpy as np
from datetime import datetime
from const import lftkplus_names
from copy import deepcopy
def parse_args(ckpt=None):
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='/data/mohamed/data')
parser.add_argument('--data', default='ling_conversion')
parser.add_argument('--data_sources')
parser.add_argument('--data_type', default='text')
parser.add_argument('--aim_repo', default='/data/mohamed/')
parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
parser.add_argument('--kld_annealing', default='cyclic')
parser.add_argument('--lingpred_annealing', default='mono')
parser.add_argument('--ling_embed_type', default = 'one-layer')
parser.add_argument('--combine_weight', default=1, type=float)
parser.add_argument('--alpha_kld', default=1, type=float)
parser.add_argument('--alpha_lingpred', default=1, type=float)
parser.add_argument('--alpha_sem', default=1, type=float)
parser.add_argument('--max_grad_norm', default=10, type=float)
parser.add_argument('--sem_loss_tao', default=0.5, type=float)
parser.add_argument('--sem_loss_eps', default=1, type=float)
parser.add_argument('--ckpt')
parser.add_argument('--disc_ckpt')
parser.add_argument('--sem_ckpt')
parser.add_argument('--lng_ids')
parser.add_argument('--lng_ids_idx', type=int)
parser.add_argument('--lng_ids_path', default='/data/mohamed/indices')
parser.add_argument('--preds_dir', default='/data/mohamed/preds')
parser.add_argument('--model_name', default="google/flan-t5-base")
parser.add_argument('--disc_type', default="t5")
parser.add_argument('--aim_exp', default='ling-conversion')
parser.add_argument('--sem_loss_type', default='dedicated')
parser.add_argument('--combine_method', default='none')
parser.add_argument('--train_log', type=int, default=200)
parser.add_argument('--val_log', type=int, default=2000)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--eval_batch_size', type=int, default=32)
parser.add_argument('--max_eval_samples', type=int, default=1000)
parser.add_argument('--test_batch_size', type=int, default=1)
parser.add_argument('--hidden_dim', type=int, default=500)
parser.add_argument('--latent_dim', type=int, default=150)
parser.add_argument('--lng_dim', type=int, default=40)
parser.add_argument('--disc_lng_dim', type=int)
parser.add_argument('--use_lora', action='store_true')
parser.add_argument('--lora_r', type=int, default=64)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--grad_accumulation', type=int, default=1)
parser.add_argument('--n_ica', type=int, default=10)
parser.add_argument('--max_length', type=int, default=200)
parser.add_argument('--total_steps', type=int)
parser.add_argument('--kld_const', type=float, default=1)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--kl_weight', type=float, default=1e-1)
parser.add_argument('--weight_decay', type=float, default=1e-2)
parser.add_argument('--ling_dropout', type=float, default=0.1)
parser.add_argument('--predict_fn', default = 'logs/test.txt')
parser.add_argument('--save_predict', action='store_true')
parser.add_argument('--use_ica', action='store_true')
parser.add_argument('--pretrain_gen', action='store_true')
parser.add_argument('--pretrain_sem', action='store_true')
parser.add_argument('--pretrain_disc', action='store_true')
parser.add_argument('--linggen_type', default='none')
parser.add_argument('--linggen_input', default='s+l')
parser.add_argument('--aug_same', action='store_true')
parser.add_argument('--ling_vae', action='store_true')
parser.add_argument('--process_lingpred', action='store_true')
parser.add_argument('--fudge_lambda', type=float, default=1.0)
parser.add_argument('--use_lingpred', action='store_true')
parser.add_argument('--ling2_only', action='store_true')
parser.add_argument('--cycle_loss', action='store_true')
parser.add_argument('--disc_loss', action='store_true')
parser.add_argument('--sem_loss', action='store_true')
parser.add_argument('--sim_loss', action='store_true')
parser.add_argument('--optuna', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--demo', action='store_true')
parser.add_argument('--fudge', action='store_true')
parser.add_argument('--fb_log', default='feedback_logs/default.txt')
parser.add_argument('--eval_only', action='store_true')
parser.add_argument('--predict_with_feedback', action='store_true')
parser.add_argument('--feedback_param', default = 's')
parser.add_argument('--eval_ling', action='store_true')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--major_arg', default = 0, type=int)
parser.add_argument('--quantize_lng', action='store_true')
parser.add_argument('--quant_nbins', type=int, default=20)
parser.add_argument('--src_lng', default = 'ling')
parser.add_argument('--to_restore', nargs='+', default=[])
# args = parser.parse_args()
args, unknown = parser.parse_known_args()
args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'
major_arg = args.major_arg
to_restore = [
] + args.to_restore
to_restore = {k: args.__dict__[k] for k in to_restore}
if not args.disc_loss or args.disc_ckpt:
args.disc_steps = 0
if args.data_sources is not None:
args.data_sources = args.data_sources.split(',')
if ckpt is not None:
args.ckpt = ckpt
args_list = [args]
if args.ckpt:
if ',' in args.ckpt:
ckpts = args.ckpt.split(',')
args_list = [deepcopy(args) for _ in range(len(ckpts))]
for i in range(len(ckpts)):
args_path = ckpts[i].replace('_best', '').replace('.pt', '.json')
with open(args_path) as f:
args_list[i].__dict__.update(json.load(f))
args_list[i].__dict__.update(to_restore)
args_list[i].ckpt = ckpts[i]
else:
args_path = args.ckpt.replace('_best', '').replace('.pt', '.json')
ckpt = args.ckpt
with open(args_path) as f:
args.__dict__.update(json.load(f))
args.__dict__.update(to_restore)
args.ckpt = ckpt
lng_names = lftkplus_names
for i in range(len(args_list)):
if args_list[i].lng_ids or args_list[i].lng_ids_idx:
if args_list[i].lng_ids_idx:
lng_ids = np.load(os.path.join(args_list[i].lng_ids_path, f'{args_list[i].lng_ids_idx}.npy'))
elif args_list[i].lng_ids[0].isnumeric():
lng_ids = [int(x) for x in args_list[i].lng_ids.split(',')]
elif ',' in args_list[i].lng_ids:
lng_ids = [lng_names.index(x) for x in args_list[i].lng_ids.split(',')]
else:
lng_ids = np.load(args_list[i].lng_ids)
args_list[i].lng_dim = len(lng_ids)
args_list[i].lng_ids = lng_ids.tolist()
# lng_names = [lng_names[i] for i in lng_ids]
elif args_list[i].use_ica:
args_list[i].lng_dim = args_list[i].n_ica
if args_list[i].disc_lng_dim is None:
args_list[i].disc_lng_dim = args_list[i].lng_dim
if not args.ckpt and not args.eval_only:
args_path = os.path.join(args.ckpt_dir, '%s.json'%args.name)
with open(args_path, 'w') as f:
s = json.dumps(args.__dict__)
f.write(s)
return args_list[major_arg], args_list, lng_names