File size: 7,901 Bytes
674b430
20b7679
674b430
20b7679
674b430
20b7679
674b430
20b7679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674b430
20b7679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os, json
import argparse
import numpy as np
from datetime import datetime
from const import lftkplus_names
from copy import deepcopy


def parse_args(ckpt=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='/data/mohamed/data')
    parser.add_argument('--data', default='ling_conversion')
    parser.add_argument('--data_sources')
    parser.add_argument('--data_type', default='text')
    parser.add_argument('--aim_repo', default='/data/mohamed/')
    parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
    parser.add_argument('--kld_annealing', default='cyclic')
    parser.add_argument('--lingpred_annealing', default='mono')
    parser.add_argument('--ling_embed_type', default = 'one-layer')
    parser.add_argument('--combine_weight', default=1, type=float)
    parser.add_argument('--alpha_kld', default=1, type=float)
    parser.add_argument('--alpha_lingpred', default=1, type=float)
    parser.add_argument('--alpha_sem', default=1, type=float)
    parser.add_argument('--max_grad_norm', default=10, type=float)
    parser.add_argument('--sem_loss_tao', default=0.5, type=float)
    parser.add_argument('--sem_loss_eps', default=1, type=float)
    parser.add_argument('--ckpt')
    parser.add_argument('--disc_ckpt')
    parser.add_argument('--sem_ckpt')
    parser.add_argument('--lng_ids')
    parser.add_argument('--lng_ids_idx', type=int)
    parser.add_argument('--lng_ids_path', default='/data/mohamed/indices')
    parser.add_argument('--preds_dir', default='/data/mohamed/preds')
    parser.add_argument('--model_name', default="google/flan-t5-base")
    parser.add_argument('--disc_type', default="t5")
    parser.add_argument('--aim_exp', default='ling-conversion')
    parser.add_argument('--sem_loss_type', default='dedicated')
    parser.add_argument('--combine_method', default='none')
    parser.add_argument('--train_log', type=int, default=200)
    parser.add_argument('--val_log', type=int, default=2000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--eval_batch_size', type=int, default=32)
    parser.add_argument('--max_eval_samples', type=int, default=1000)
    parser.add_argument('--test_batch_size', type=int, default=1)
    parser.add_argument('--hidden_dim', type=int, default=500)
    parser.add_argument('--latent_dim', type=int, default=150)
    parser.add_argument('--lng_dim', type=int, default=40)
    parser.add_argument('--disc_lng_dim', type=int)
    parser.add_argument('--use_lora', action='store_true')
    parser.add_argument('--lora_r', type=int, default=64)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--grad_accumulation', type=int, default=1)
    parser.add_argument('--n_ica', type=int, default=10)
    parser.add_argument('--max_length', type=int, default=200)
    parser.add_argument('--total_steps', type=int)
    parser.add_argument('--kld_const', type=float, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--kl_weight', type=float, default=1e-1)
    parser.add_argument('--weight_decay', type=float, default=1e-2)
    parser.add_argument('--ling_dropout', type=float, default=0.1)
    parser.add_argument('--predict_fn', default = 'logs/test.txt')
    parser.add_argument('--save_predict', action='store_true')
    parser.add_argument('--use_ica', action='store_true')
    parser.add_argument('--pretrain_gen', action='store_true')
    parser.add_argument('--pretrain_sem', action='store_true')
    parser.add_argument('--pretrain_disc', action='store_true')
    parser.add_argument('--linggen_type', default='none')
    parser.add_argument('--linggen_input', default='s+l')
    parser.add_argument('--aug_same', action='store_true')
    parser.add_argument('--ling_vae', action='store_true')
    parser.add_argument('--process_lingpred', action='store_true')
    parser.add_argument('--fudge_lambda', type=float, default=1.0)
    parser.add_argument('--use_lingpred', action='store_true')
    parser.add_argument('--ling2_only', action='store_true')
    parser.add_argument('--cycle_loss', action='store_true')
    parser.add_argument('--disc_loss', action='store_true')
    parser.add_argument('--sem_loss', action='store_true')
    parser.add_argument('--sim_loss', action='store_true')
    parser.add_argument('--optuna', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--demo', action='store_true')
    parser.add_argument('--fudge', action='store_true')
    parser.add_argument('--fb_log', default='feedback_logs/default.txt')
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--predict_with_feedback', action='store_true')
    parser.add_argument('--feedback_param', default = 's')
    parser.add_argument('--eval_ling', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--major_arg', default = 0, type=int)
    parser.add_argument('--quantize_lng', action='store_true')
    parser.add_argument('--quant_nbins', type=int, default=20)
    parser.add_argument('--src_lng', default = 'ling')
    parser.add_argument('--to_restore', nargs='+', default=[])
    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()
    args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'

    major_arg = args.major_arg
    to_restore = [
            ] + args.to_restore
    to_restore = {k: args.__dict__[k] for k in to_restore}

    if not args.disc_loss or args.disc_ckpt:
        args.disc_steps = 0

    if args.data_sources is not None:
        args.data_sources = args.data_sources.split(',')

    if ckpt is not None:
        args.ckpt = ckpt

    args_list = [args]
    if args.ckpt:
        if ',' in args.ckpt:
            ckpts = args.ckpt.split(',')
            args_list = [deepcopy(args) for _ in range(len(ckpts))]
            for i in range(len(ckpts)):
                args_path = ckpts[i].replace('_best', '').replace('.pt', '.json')
                with open(args_path) as f:
                    args_list[i].__dict__.update(json.load(f))
                args_list[i].__dict__.update(to_restore)
                args_list[i].ckpt = ckpts[i]
        else:
            args_path = args.ckpt.replace('_best', '').replace('.pt', '.json')
            ckpt = args.ckpt
            with open(args_path) as f:
                args.__dict__.update(json.load(f))
                args.__dict__.update(to_restore)
                args.ckpt = ckpt

    lng_names = lftkplus_names
    for i in range(len(args_list)):
        if args_list[i].lng_ids or args_list[i].lng_ids_idx:
            if args_list[i].lng_ids_idx:
                lng_ids = np.load(os.path.join(args_list[i].lng_ids_path, f'{args_list[i].lng_ids_idx}.npy'))
            elif args_list[i].lng_ids[0].isnumeric():
                lng_ids = [int(x) for x in args_list[i].lng_ids.split(',')]
            elif ',' in args_list[i].lng_ids:
                lng_ids = [lng_names.index(x) for x in args_list[i].lng_ids.split(',')]
            else:
                lng_ids = np.load(args_list[i].lng_ids)
            args_list[i].lng_dim = len(lng_ids)
            args_list[i].lng_ids = lng_ids.tolist()
            # lng_names = [lng_names[i] for i in lng_ids]
        elif args_list[i].use_ica:
            args_list[i].lng_dim = args_list[i].n_ica
        if args_list[i].disc_lng_dim is None:
            args_list[i].disc_lng_dim = args_list[i].lng_dim

    if not args.ckpt and not args.eval_only:
        args_path = os.path.join(args.ckpt_dir, '%s.json'%args.name)
        with open(args_path, 'w') as f:
            s = json.dumps(args.__dict__)
            f.write(s)

    return args_list[major_arg], args_list, lng_names