def train(
    push_to_hub: bool,
    num_epoch: int,
    train_batch_size: int,
    eval_batch_size: int,
):
    import torch
    import numpy as np

    # 1. Dataset
    from datasets import load_dataset
    dataset = load_dataset("Adapting/abstract-keyphrases")

    # 2. Model
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from lrt.clustering.models import KeyBartAdapter

    tokenizer = AutoTokenizer.from_pretrained("Adapting/KeyBartAdapter")
    '''
    Or you can just use the initial model weights from Huggingface:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "Adapting/KeyBartAdapter",
        revision='9c3ed39c6ed5c7e141363e892d77cf8f589d5999'
    )
    '''
    model = KeyBartAdapter(256)

    # 3. Preprocess dataset
    dataset = dataset.shuffle()

    def preprocess_function(examples):
        inputs = examples['Abstract']
        targets = examples['Keywords']
        model_inputs = tokenizer(inputs, truncation=True)

        # Set up the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # 4. Evaluation metrics: set-level precision/recall/F1 over ';'-separated keyphrases
    def compute_metrics(eval_preds):
        preds = eval_preds.predictions
        labels = eval_preds.label_ids
        if isinstance(preds, tuple):
            preds = preds[0]
        print(preds.shape)
        # If the predictions are logits (batch, seq_len, vocab), take the argmax over the vocab axis.
        if len(preds.shape) == 3:
            preds = preds.argmax(axis=-1)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing: split the generated text into individual keyphrases.
        decoded_preds = [a.strip().split(';') for a in decoded_preds]
        decoded_labels = [a.strip().split(';') for a in decoded_labels]

        precs, recalls, f_scores = [], [], []
        num_match, num_pred, num_gold = [], [], []
        for pred, label in zip(decoded_preds, decoded_labels):
            pred_set = set(pred)
            label_set = set(label)
            match_set = label_set.intersection(pred_set)
            p = float(len(match_set)) / float(len(pred_set)) if len(pred_set) > 0 else 0.0
            r = float(len(match_set)) / float(len(label_set)) if len(label_set) > 0 else 0.0
            f1 = float(2 * (p * r)) / (p + r) if (p + r) > 0 else 0.0
            precs.append(p)
            recalls.append(r)
            f_scores.append(f1)
            num_match.append(len(match_set))
            num_pred.append(len(pred_set))
            num_gold.append(len(label_set))

            # print(f'raw_PRED: {raw_pred}')
            print(f'PRED: num={len(pred_set)} - {pred_set}')
            print(f'GT: num={len(label_set)} - {label_set}')
            print(f'p={p}, r={r}, f1={f1}')
            print('-' * 20)

        result = {
            'precision@M': np.mean(precs) * 100.0,
            'recall@M': np.mean(recalls) * 100.0,
            'fscore@M': np.mean(f_scores) * 100.0,
            'num_match': np.mean(num_match),
            'num_pred': np.mean(num_pred),
            'num_gold': np.mean(num_gold),
        }
        result = {k: round(v, 2) for k, v in result.items()}
        return result

    # 5. Train
    from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    model_name = 'KeyBartAdapter'
    args = Seq2SeqTrainingArguments(
        model_name,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=num_epoch,
        logging_steps=4,
        load_best_model_at_end=True,
        metric_for_best_model='fscore@M',  # key returned by compute_metrics (reported as 'eval_fscore@M')
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),  # speeds up training on modern GPUs
        # eval_accumulation_steps=10,
    )

    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["train"],  # note: evaluation runs on the training split
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # 6. Push to the Hugging Face Hub
    if push_to_hub:
        commit_msg = f'{model_name}_{num_epoch}'
        tokenizer.push_to_hub(commit_message=commit_msg, repo_id=model_name)
        model.push_to_hub(commit_message=commit_msg, repo_id=model_name)

    return model, tokenizer


if __name__ == '__main__':
    import sys
    from pathlib import Path

    # Make the project root importable so `lrt.clustering.models` resolves.
    project_root = Path(__file__).parent.parent.parent.absolute()
    sys.path.append(str(project_root))

    # code
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", help="number of epochs", default=30)
    parser.add_argument("--train_batch_size", help="training batch size", default=16)
    parser.add_argument("--eval_batch_size", help="evaluation batch size", default=16)
    parser.add_argument("--push", help="whether to push the model to the hub", action='store_true')
    args = parser.parse_args()
    print(args)

    model, tokenizer = train(
        push_to_hub=bool(args.push),
        num_epoch=int(args.epoch),
        train_batch_size=int(args.train_batch_size),
        eval_batch_size=int(args.eval_batch_size),
    )
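
# ---------------------------------------------------------------------------
# Usage sketch (not executed): how this script might be invoked and how a
# trained checkpoint could be used for inference afterwards. The script
# filename, abstract text, and generation settings below are assumptions for
# illustration, not part of this script.
#
#   python train.py --epoch 30 --train_batch_size 16 --eval_batch_size 16 --push
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#   tokenizer = AutoTokenizer.from_pretrained("Adapting/KeyBartAdapter")
#   model = AutoModelForSeq2SeqLM.from_pretrained("Adapting/KeyBartAdapter")
#   inputs = tokenizer("An abstract about keyphrase generation ...",
#                      return_tensors="pt", truncation=True)
#   outputs = model.generate(**inputs, max_new_tokens=64)
#   # Keyphrases are generated as a single ';'-separated string, as in training.
#   print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ---------------------------------------------------------------------------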