"""Train the AlignScore alignment model on a collection of NLU datasets."""

from argparse import ArgumentParser, BooleanOptionalAction  # BooleanOptionalAction requires Python 3.9+
import os

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from alignscore.dataloader import DSTDataLoader
from alignscore.model import BERTAlignModel


def train(datasets, args):
    # Build the unified data module over all selected training datasets.
    dm = DSTDataLoader(
        dataset_config=datasets,
        model_name=args.model_name,
        sample_mode='seq',
        train_batch_size=args.batch_size,
        eval_batch_size=16,
        num_workers=args.num_workers,
        train_eval_split=0.95,
        need_mlm=args.do_mlm
    )
    dm.setup()

    model = BERTAlignModel(
        model=args.model_name,
        using_pretrained=args.use_pretrained_model,
        adam_epsilon=args.adam_epsilon,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps_portion=args.warm_up_proportion
    )
    model.need_mlm = args.do_mlm

    # Encode the training configuration into the checkpoint file name.
    training_dataset_used = '_'.join(datasets.keys())
    checkpoint_name = '_'.join((
        f"{args.ckpt_comment}{args.model_name.replace('/', '-')}",
        f"{'scratch_' if not args.use_pretrained_model else ''}"
        f"{'no_mlm_' if not args.do_mlm else ''}{training_dataset_used}",
        str(args.max_samples_per_dataset),
        f"{args.batch_size}x{len(args.devices)}x{args.accumulate_grad_batch}"
    ))

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_save_path,
        filename=checkpoint_name + "_{epoch:02d}_{step}",
        every_n_train_steps=10000,
        save_top_k=1
    )

    trainer = Trainer(
        accelerator='gpu',
        max_epochs=args.num_epoch,
        devices=args.devices,
        strategy='dp',
        precision=32,
        callbacks=[checkpoint_callback],
        accumulate_grad_batches=args.accumulate_grad_batch,
        val_check_interval=args.val_check_interval
    )

    trainer.fit(model, datamodule=dm)
    trainer.save_checkpoint(os.path.join(args.ckpt_save_path, f"{checkpoint_name}_final.ckpt"))
    print("Training is finished.")


if __name__ == "__main__":
    ALL_TRAINING_DATASETS = {
        ### NLI
        'mnli': {'task_type': 'nli', 'data_path': 'mnli.json'},
        'doc_nli': {'task_type': 'bin_nli', 'data_path': 'doc_nli.json'},
        'snli': {'task_type': 'nli', 'data_path': 'snli.json'},
        'anli_r1': {'task_type': 'nli', 'data_path': 'anli_r1.json'},
        'anli_r2': {'task_type': 'nli', 'data_path': 'anli_r2.json'},
        'anli_r3': {'task_type': 'nli', 'data_path': 'anli_r3.json'},

        ### Fact checking
        'nli_fever': {'task_type': 'fact_checking', 'data_path': 'nli_fever.json'},
        'vitaminc': {'task_type': 'fact_checking', 'data_path': 'vitaminc.json'},

        ### Paraphrase
        'paws': {'task_type': 'paraphrase', 'data_path': 'paws.json'},
        'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'paws_qqp.json'},
        'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'paws_unlabeled.json'},
        'qqp': {'task_type': 'paraphrase', 'data_path': 'qqp.json'},
        'wiki103': {'task_type': 'paraphrase', 'data_path': 'wiki103.json'},

        ### QA
        'squad_v2': {'task_type': 'qa', 'data_path': 'squad_v2_new.json'},
        'race': {'task_type': 'qa', 'data_path': 'race.json'},
        'adversarial_qa': {'task_type': 'qa', 'data_path': 'adversarial_qa.json'},
        'drop': {'task_type': 'qa', 'data_path': 'drop.json'},
        'hotpot_qa_distractor': {'task_type': 'qa', 'data_path': 'hotpot_qa_distractor.json'},
        'hotpot_qa_fullwiki': {'task_type': 'qa', 'data_path': 'hotpot_qa_fullwiki.json'},
        'newsqa': {'task_type': 'qa', 'data_path': 'newsqa.json'},
        'quoref': {'task_type': 'qa', 'data_path': 'quoref.json'},
        'ropes': {'task_type': 'qa', 'data_path': 'ropes.json'},
        'boolq': {'task_type': 'qa', 'data_path': 'boolq.json'},
        'eraser_multi_rc': {'task_type': 'qa', 'data_path': 'eraser_multi_rc.json'},
        'quail': {'task_type': 'qa', 'data_path': 'quail.json'},
        'sciq': {'task_type': 'qa', 'data_path': 'sciq.json'},
        'strategy_qa': {'task_type': 'qa', 'data_path': 'strategy_qa.json'},

        ### Coreference
        'gap': {'task_type': 'coreference', 'data_path': 'gap.json'},

        ### Summarization
        'wikihow': {'task_type': 'summarization', 'data_path': 'wikihow.json'},

        ### Information retrieval
        'msmarco': {'task_type': 'ir', 'data_path': 'msmarco.json'},

        ### STS
        'stsb': {'task_type': 'sts', 'data_path': 'stsb.json'},
        'sick': {'task_type': 'sts', 'data_path': 'sick.json'},
    }

    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=2022)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--accumulate-grad-batch', type=int, default=1)
    parser.add_argument('--num-epoch', type=int, default=3)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--warm-up-proportion', type=float, default=0.06)
    parser.add_argument('--adam-epsilon', type=float, default=1e-6)
    parser.add_argument('--weight-decay', type=float, default=0.1)
    parser.add_argument('--learning-rate', type=float, default=1e-5)
    parser.add_argument('--val-check-interval', type=float, default=1. / 4)
    parser.add_argument('--devices', nargs='+', type=int, required=True)
    parser.add_argument('--model-name', type=str, default='roberta-large')
    parser.add_argument('--ckpt-save-path', type=str, required=True)
    parser.add_argument('--ckpt-comment', type=str, default='')
    parser.add_argument('--training-datasets', nargs='+', type=str,
                        default=list(ALL_TRAINING_DATASETS.keys()),
                        choices=list(ALL_TRAINING_DATASETS.keys()))
    parser.add_argument('--data-path', type=str, required=True)
    parser.add_argument('--max-samples-per-dataset', type=int, default=500000)
    # Use real boolean flag actions: argparse's `type=bool` would treat any
    # non-empty string (including "False") as True.
    parser.add_argument('--do-mlm', action='store_true', default=False)
    parser.add_argument('--use-pretrained-model', action=BooleanOptionalAction, default=True)

    args = parser.parse_args()
    seed_everything(args.seed)

    # Resolve each selected dataset to its on-disk path and cap its size.
    datasets = {
        name: {
            **ALL_TRAINING_DATASETS[name],
            'size': args.max_samples_per_dataset,
            'data_path': os.path.join(args.data_path, ALL_TRAINING_DATASETS[name]['data_path'])
        }
        for name in args.training_datasets
    }

    train(datasets, args)
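
# A minimal example invocation, assuming this file is saved as train.py; the
# data directory, checkpoint directory, device IDs, and dataset selection
# below are illustrative, not fixed values:
#
#   python train.py \
#       --devices 0 1 \
#       --data-path /path/to/training_data \
#       --ckpt-save-path checkpoints/ \
#       --training-datasets mnli snli vitaminc \
#       --batch-size 32 --num-epoch 3 --do-mlm
#
# Each name passed to --training-datasets must have a matching JSON file
# (e.g. mnli.json) under --data-path.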