from dataclasses import dataclass, field
from typing import Optional

import pandas as pd
import torch
from accelerate import Accelerator
from datasets import load_dataset, Dataset, load_metric
from peft import LoraConfig
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    EarlyStoppingCallback,
)
from trl import SFTTrainer, is_xpu_available
from data import AphaPenDataset
import evaluate
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
# from utils import compute_metrics
from src.calibrator import EncoderDecoderCalibrator
from src.loss import MarginLoss, KLRegularization
from src.similarity import CERSimilarity
import os

tqdm.pandas()

os.environ["WANDB_PROJECT"] = "Alphapen"


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The name of the OCR model we wish to fine-tune with Seq2SeqTrainer.
    """

    model_name: Optional[str] = field(
        default="microsoft/trocr-base-handwritten", metadata={"help": "the model name"}
    )
    dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
    log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=16, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    max_length: Optional[int] = field(default=10, metadata={"help": "the maximum generation length"})
    no_repeat_ngram_size: Optional[int] = field(default=3, metadata={"help": "the no-repeat n-gram size"})
    length_penalty: Optional[float] = field(default=2.0, metadata={"help": "the length penalty"})
    num_beams: Optional[int] = field(default=3, metadata={"help": "the number of beams for beam search"})
    early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether beam search stops early"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of update steps between two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments to be passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Step 1: Load the dataset
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
df = pd.read_csv(df_path)
df.dropna(inplace=True)
train_df, test_df = train_test_split(df, test_size=0.15, random_state=0)
# reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
processor = TrOCRProcessor.from_pretrained(script_args.model_name)
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)

# Step 2: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8-bit and 4-bit precision at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Copy the model to each device
    device_map = (
        {"": f"xpu:{Accelerator().local_process_index}"}
        if is_xpu_available()
        else {"": Accelerator().local_process_index}
    )
    torch_dtype = torch.bfloat16
else:
    device_map = None
    quantization_config = None
    torch_dtype = None

model = VisionEncoderDecoderModel.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    token=script_args.use_auth_token,
)

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = script_args.max_length
model.config.early_stopping = script_args.early_stopping
model.config.no_repeat_ngram_size = script_args.no_repeat_ngram_size
model.config.length_penalty = script_args.length_penalty
model.config.num_beams = script_args.num_beams

# Calibration components: CER-based token similarity, margin loss, and KL regularization
tokenizer = processor.tokenizer
sim = CERSimilarity(tokenizer)
loss = MarginLoss(sim, beta=0.1, num_samples=60)
reg = KLRegularization(model)
calibrator = EncoderDecoderCalibrator(model, loss, reg, 15, 15)

# Step 3: Define the training arguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=script_args.batch_size,
    per_device_eval_batch_size=script_args.batch_size,
    fp16=True,
    output_dir=script_args.output_dir,
    logging_steps=script_args.logging_steps,
    save_steps=script_args.save_steps,
    eval_steps=100,
    save_total_limit=script_args.save_total_limit,
    load_best_model_at_end=True,
    report_to=script_args.log_with,
    num_train_epochs=script_args.num_train_epochs,
    push_to_hub=script_args.push_to_hub,
    hub_model_id=script_args.hub_model_id,
    gradient_checkpointing=script_args.gradient_checkpointing,
    # metric_for_best_model="eval/cer",
    # TODO: uncomment on the next release
    # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)


# Step 4: Define a metric
def compute_metrics(pred):
    # accuracy_metric = evaluate.load("precision")
    cer_metric = evaluate.load("cer")
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # replace the -100 loss-masking value with the real pad token before decoding
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    # accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())
    return {"cer": cer}


early_stop = EarlyStoppingCallback(early_stopping_patience=10, early_stopping_threshold=0.001)

# Step 5: Define the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    # callbacks=[early_stop],
)
trainer.train()

# Step 6: Save the model
# trainer.save_model(script_args.output_dir)
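# Example invocation (a sketch only; the script filename "train_trocr.py" is an assumption,
# and the dataset paths hard-coded above must exist on the machine running it):
#
#   python train_trocr.py \
#       --model_name microsoft/trocr-base-handwritten \
#       --batch_size 8 \
#       --num_train_epochs 3 \
#       --output_dir output \
#       --log_with wandb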