Spaces:
Sleeping
Sleeping
import os, json | |
import argparse | |
import numpy as np | |
from datetime import datetime | |
from const import lftkplus_names | |
from copy import deepcopy | |
def parse_args(ckpt=None): | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--data_dir', default='/data/mohamed/data') | |
parser.add_argument('--data', default='ling_conversion') | |
parser.add_argument('--data_sources') | |
parser.add_argument('--data_type', default='text') | |
parser.add_argument('--aim_repo', default='/data/mohamed/') | |
parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints') | |
parser.add_argument('--kld_annealing', default='cyclic') | |
parser.add_argument('--lingpred_annealing', default='mono') | |
parser.add_argument('--ling_embed_type', default = 'one-layer') | |
parser.add_argument('--combine_weight', default=1, type=float) | |
parser.add_argument('--alpha_kld', default=1, type=float) | |
parser.add_argument('--alpha_lingpred', default=1, type=float) | |
parser.add_argument('--alpha_sem', default=1, type=float) | |
parser.add_argument('--max_grad_norm', default=10, type=float) | |
parser.add_argument('--sem_loss_tao', default=0.5, type=float) | |
parser.add_argument('--sem_loss_eps', default=1, type=float) | |
parser.add_argument('--ckpt') | |
parser.add_argument('--disc_ckpt') | |
parser.add_argument('--sem_ckpt') | |
parser.add_argument('--lng_ids') | |
parser.add_argument('--lng_ids_idx', type=int) | |
parser.add_argument('--lng_ids_path', default='/data/mohamed/indices') | |
parser.add_argument('--preds_dir', default='/data/mohamed/preds') | |
parser.add_argument('--model_name', default="google/flan-t5-base") | |
parser.add_argument('--disc_type', default="t5") | |
parser.add_argument('--aim_exp', default='ling-conversion') | |
parser.add_argument('--sem_loss_type', default='dedicated') | |
parser.add_argument('--combine_method', default='none') | |
parser.add_argument('--train_log', type=int, default=200) | |
parser.add_argument('--val_log', type=int, default=2000) | |
parser.add_argument('--batch_size', type=int, default=64) | |
parser.add_argument('--eval_batch_size', type=int, default=32) | |
parser.add_argument('--max_eval_samples', type=int, default=1000) | |
parser.add_argument('--test_batch_size', type=int, default=1) | |
parser.add_argument('--hidden_dim', type=int, default=500) | |
parser.add_argument('--latent_dim', type=int, default=150) | |
parser.add_argument('--lng_dim', type=int, default=40) | |
parser.add_argument('--disc_lng_dim', type=int) | |
parser.add_argument('--use_lora', action='store_true') | |
parser.add_argument('--lora_r', type=int, default=64) | |
parser.add_argument('--gpu', type=str, default='0') | |
parser.add_argument('--epochs', type=int, default=10) | |
parser.add_argument('--grad_accumulation', type=int, default=1) | |
parser.add_argument('--n_ica', type=int, default=10) | |
parser.add_argument('--max_length', type=int, default=200) | |
parser.add_argument('--total_steps', type=int) | |
parser.add_argument('--kld_const', type=float, default=1) | |
parser.add_argument('--lr', type=float, default=1e-4) | |
parser.add_argument('--kl_weight', type=float, default=1e-1) | |
parser.add_argument('--weight_decay', type=float, default=1e-2) | |
parser.add_argument('--ling_dropout', type=float, default=0.1) | |
parser.add_argument('--predict_fn', default = 'logs/test.txt') | |
parser.add_argument('--save_predict', action='store_true') | |
parser.add_argument('--use_ica', action='store_true') | |
parser.add_argument('--pretrain_gen', action='store_true') | |
parser.add_argument('--pretrain_sem', action='store_true') | |
parser.add_argument('--pretrain_disc', action='store_true') | |
parser.add_argument('--linggen_type', default='none') | |
parser.add_argument('--linggen_input', default='s+l') | |
parser.add_argument('--aug_same', action='store_true') | |
parser.add_argument('--ling_vae', action='store_true') | |
parser.add_argument('--process_lingpred', action='store_true') | |
parser.add_argument('--fudge_lambda', type=float, default=1.0) | |
parser.add_argument('--use_lingpred', action='store_true') | |
parser.add_argument('--ling2_only', action='store_true') | |
parser.add_argument('--cycle_loss', action='store_true') | |
parser.add_argument('--disc_loss', action='store_true') | |
parser.add_argument('--sem_loss', action='store_true') | |
parser.add_argument('--sim_loss', action='store_true') | |
parser.add_argument('--optuna', action='store_true') | |
parser.add_argument('--debug', action='store_true') | |
parser.add_argument('--demo', action='store_true') | |
parser.add_argument('--fudge', action='store_true') | |
parser.add_argument('--fb_log', default='feedback_logs/default.txt') | |
parser.add_argument('--eval_only', action='store_true') | |
parser.add_argument('--predict_with_feedback', action='store_true') | |
parser.add_argument('--feedback_param', default = 's') | |
parser.add_argument('--eval_ling', action='store_true') | |
parser.add_argument('--seed', type=int, default=0) | |
parser.add_argument('--major_arg', default = 0, type=int) | |
parser.add_argument('--quantize_lng', action='store_true') | |
parser.add_argument('--quant_nbins', type=int, default=20) | |
parser.add_argument('--src_lng', default = 'ling') | |
parser.add_argument('--to_restore', nargs='+', default=[]) | |
# args = parser.parse_args() | |
args, unknown = parser.parse_known_args() | |
args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}' | |
major_arg = args.major_arg | |
to_restore = [ | |
] + args.to_restore | |
to_restore = {k: args.__dict__[k] for k in to_restore} | |
if not args.disc_loss or args.disc_ckpt: | |
args.disc_steps = 0 | |
if args.data_sources is not None: | |
args.data_sources = args.data_sources.split(',') | |
if ckpt is not None: | |
args.ckpt = ckpt | |
args_list = [args] | |
if args.ckpt: | |
if ',' in args.ckpt: | |
ckpts = args.ckpt.split(',') | |
args_list = [deepcopy(args) for _ in range(len(ckpts))] | |
for i in range(len(ckpts)): | |
args_path = ckpts[i].replace('_best', '').replace('.pt', '.json') | |
with open(args_path) as f: | |
args_list[i].__dict__.update(json.load(f)) | |
args_list[i].__dict__.update(to_restore) | |
args_list[i].ckpt = ckpts[i] | |
else: | |
args_path = args.ckpt.replace('_best', '').replace('.pt', '.json') | |
ckpt = args.ckpt | |
with open(args_path) as f: | |
args.__dict__.update(json.load(f)) | |
args.__dict__.update(to_restore) | |
args.ckpt = ckpt | |
lng_names = lftkplus_names | |
for i in range(len(args_list)): | |
if args_list[i].lng_ids or args_list[i].lng_ids_idx: | |
if args_list[i].lng_ids_idx: | |
lng_ids = np.load(os.path.join(args_list[i].lng_ids_path, f'{args_list[i].lng_ids_idx}.npy')) | |
elif args_list[i].lng_ids[0].isnumeric(): | |
lng_ids = [int(x) for x in args_list[i].lng_ids.split(',')] | |
elif ',' in args_list[i].lng_ids: | |
lng_ids = [lng_names.index(x) for x in args_list[i].lng_ids.split(',')] | |
else: | |
lng_ids = np.load(args_list[i].lng_ids) | |
args_list[i].lng_dim = len(lng_ids) | |
args_list[i].lng_ids = lng_ids.tolist() | |
# lng_names = [lng_names[i] for i in lng_ids] | |
elif args_list[i].use_ica: | |
args_list[i].lng_dim = args_list[i].n_ica | |
if args_list[i].disc_lng_dim is None: | |
args_list[i].disc_lng_dim = args_list[i].lng_dim | |
if not args.ckpt and not args.eval_only: | |
args_path = os.path.join(args.ckpt_dir, '%s.json'%args.name) | |
with open(args_path, 'w') as f: | |
s = json.dumps(args.__dict__) | |
f.write(s) | |
return args_list[major_arg], args_list, lng_names | |