# Training entry point for AlignScore: fine-tunes a BERTAlignModel on a
# configurable mix of NLI / fact-checking / paraphrase / QA / summarization
# datasets and writes checkpoints to disk.
from pytorch_lightning import Trainer, seed_everything
from alignscore.dataloader import DSTDataLoader
from alignscore.model import BERTAlignModel
from pytorch_lightning.callbacks import ModelCheckpoint
from argparse import ArgumentParser
import os
def train(datasets, args):
    """Fine-tune a BERTAlignModel on the given dataset mix and save checkpoints.

    Args:
        datasets: mapping of dataset name -> config dict (``task_type``,
            ``data_path``, ``size``) consumed by ``DSTDataLoader``.
        args: parsed CLI namespace (see the argument parser in ``__main__``).

    Side effects: trains on GPU, writes periodic checkpoints and a final
    ``*_final.ckpt`` under ``args.ckpt_save_path``.
    """
    dm = DSTDataLoader(
        dataset_config=datasets,
        model_name=args.model_name,
        sample_mode='seq',
        train_batch_size=args.batch_size,
        eval_batch_size=16,
        num_workers=args.num_workers,
        train_eval_split=0.95,
        need_mlm=args.do_mlm,
    )
    dm.setup()

    model = BERTAlignModel(
        model=args.model_name,
        using_pretrained=args.use_pretrained_model,
        adam_epsilon=args.adam_epsilon,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps_portion=args.warm_up_proportion,
    )
    model.need_mlm = args.do_mlm

    # Encode the run configuration into the checkpoint name so separate runs
    # are distinguishable on disk.
    training_dataset_used = '_'.join(datasets.keys())
    checkpoint_name = '_'.join((
        f"{args.ckpt_comment}{args.model_name.replace('/', '-')}",
        f"{'scratch_' if not args.use_pretrained_model else ''}{'no_mlm_' if not args.do_mlm else ''}{training_dataset_used}",
        str(args.max_samples_per_dataset),
        f"{args.batch_size}x{len(args.devices)}x{args.accumulate_grad_batch}",
    ))

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.ckpt_save_path,
        filename=checkpoint_name + "_{epoch:02d}_{step}",
        every_n_train_steps=10000,
        save_top_k=1,
    )

    trainer = Trainer(
        accelerator='gpu',
        max_epochs=args.num_epoch,
        devices=args.devices,
        strategy="dp",
        precision=32,
        callbacks=[checkpoint_callback],
        accumulate_grad_batches=args.accumulate_grad_batch,
        # BUGFIX: --val-check-interval was parsed but never forwarded, so the
        # flag was silently ignored and validation ran at Lightning's default
        # cadence. Forward it so the CLI option actually takes effect.
        val_check_interval=args.val_check_interval,
    )

    trainer.fit(model, datamodule=dm)
    trainer.save_checkpoint(os.path.join(args.ckpt_save_path, f"{checkpoint_name}_final.ckpt"))
    print("Training is finished.")
if __name__ == "__main__": | |
ALL_TRAINING_DATASETS = { | |
### NLI | |
'mnli': {'task_type': 'nli', 'data_path': 'mnli.json'}, | |
'doc_nli': {'task_type': 'bin_nli', 'data_path': 'doc_nli.json'}, | |
'snli': {'task_type': 'nli', 'data_path': 'snli.json'}, | |
'anli_r1': {'task_type': 'nli', 'data_path': 'anli_r1.json'}, | |
'anli_r2': {'task_type': 'nli', 'data_path': 'anli_r2.json'}, | |
'anli_r3': {'task_type': 'nli', 'data_path': 'anli_r3.json'}, | |
### fact checking | |
'nli_fever': {'task_type': 'fact_checking', 'data_path': 'nli_fever.json'}, | |
'vitaminc': {'task_type': 'fact_checking', 'data_path': 'vitaminc.json'}, | |
### paraphrase | |
'paws': {'task_type': 'paraphrase', 'data_path': 'paws.json'}, | |
'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'paws_qqp.json'}, | |
'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'paws_unlabeled.json'}, | |
'qqp': {'task_type': 'paraphrase', 'data_path': 'qqp.json'}, | |
'wiki103': {'task_type': 'paraphrase', 'data_path': 'wiki103.json'}, | |
### QA | |
'squad_v2': {'task_type': 'qa', 'data_path': 'squad_v2_new.json'}, | |
'race': {'task_type': 'qa', 'data_path': 'race.json'}, | |
'adversarial_qa': {'task_type': 'qa', 'data_path': 'adversarial_qa.json'}, | |
'drop': {'task_type': 'qa', 'data_path': 'drop.json'}, | |
'hotpot_qa_distractor': {'task_type': 'qa', 'data_path': 'hotpot_qa_distractor.json'}, | |
'hotpot_qa_fullwiki': {'task_type': 'qa', 'data_path': 'hotpot_qa_fullwiki.json'}, | |
'newsqa': {'task_type': 'qa', 'data_path': 'newsqa.json'}, | |
'quoref': {'task_type': 'qa', 'data_path': 'quoref.json'}, | |
'ropes': {'task_type': 'qa', 'data_path': 'ropes.json'}, | |
'boolq': {'task_type': 'qa', 'data_path': 'boolq.json'}, | |
'eraser_multi_rc': {'task_type': 'qa', 'data_path': 'eraser_multi_rc.json'}, | |
'quail': {'task_type': 'qa', 'data_path': 'quail.json'}, | |
'sciq': {'task_type': 'qa', 'data_path': 'sciq.json'}, | |
'strategy_qa': {'task_type': 'qa', 'data_path': 'strategy_qa.json'}, | |
### Coreference | |
'gap': {'task_type': 'coreference', 'data_path': 'gap.json'}, | |
### Summarization | |
'wikihow': {'task_type': 'summarization', 'data_path': 'wikihow.json'}, | |
### Information Retrieval | |
'msmarco': {'task_type': 'ir', 'data_path': 'msmarco.json'}, | |
### STS | |
'stsb': {'task_type': 'sts', 'data_path': 'stsb.json'}, | |
'sick': {'task_type': 'sts', 'data_path': 'sick.json'}, | |
} | |
parser = ArgumentParser() | |
parser.add_argument('--seed', type=int, default=2022) | |
parser.add_argument('--batch-size', type=int, default=32) | |
parser.add_argument('--accumulate-grad-batch', type=int, default=1) | |
parser.add_argument('--num-epoch', type=int, default=3) | |
parser.add_argument('--num-workers', type=int, default=8) | |
parser.add_argument('--warm-up-proportion', type=float, default=0.06) | |
parser.add_argument('--adam-epsilon', type=float, default=1e-6) | |
parser.add_argument('--weight-decay', type=float, default=0.1) | |
parser.add_argument('--learning-rate', type=float, default=1e-5) | |
parser.add_argument('--val-check-interval', type=float, default=1. / 4) | |
parser.add_argument('--devices', nargs='+', type=int, required=True) | |
parser.add_argument('--model-name', type=str, default="roberta-large") | |
parser.add_argument('--ckpt-save-path', type=str, required=True) | |
parser.add_argument('--ckpt-comment', type=str, default="") | |
parser.add_argument('--trainin-datasets', nargs='+', type=str, default=list(ALL_TRAINING_DATASETS.keys()), choices=list(ALL_TRAINING_DATASETS.keys())) | |
parser.add_argument('--data-path', type=str, required=True) | |
parser.add_argument('--max-samples-per-dataset', type=int, default=500000) | |
parser.add_argument('--do-mlm', type=bool, default=False) | |
parser.add_argument('--use-pretrained-model', type=bool, default=True) | |
args = parser.parse_args() | |
seed_everything(args.seed) | |
datasets = { | |
name: { | |
**ALL_TRAINING_DATASETS[name], | |
"size": args.max_samples_per_dataset, | |
"data_path": os.path.join(args.data_path, ALL_TRAINING_DATASETS[name]['data_path']) | |
} | |
for name in args.trainin_datasets | |
} | |
train(datasets, args) | |