"""Fine-tune TrOCR-large with LoRA on AlphaPen batch 1 plus a slice of batch 2.

For every size in ``samp_list`` a freshly loaded, LoRA-wrapped model is trained
on all of batch 1 plus the first ``samp`` rows of batch 2, evaluated (CER) on
the held-out splits of both batches, logged to Weights & Biases and pushed to
the Hub under a size-specific repo id.
"""
from dataclasses import dataclass, field
from typing import Optional
import pandas as pd
import os
import gc
import torch
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    EarlyStoppingCallback,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
)
from peft import LoraConfig, get_peft_model
from transformers import VisionEncoderDecoderConfig
from data import AphaPenDataset
import evaluate
from sklearn.model_selection import train_test_split
from src.calibrator import EncoderDecoderCalibrator
from src.loss import MarginLoss, KLRegularization
from src.similarity import CERSimilarity
from datetime import datetime
from torch.utils.data import ConcatDataset
import wandb

# @dataclass
# class ScriptArguments:
#     """
#     The name of the OCR model we wish to fine-tune with Seq2SeqTrainer
#     """
#     samp_size: Optional[int] = field(default=0, metadata={"help": "the additional sample size"})
# parser = HfArgumentParser(ScriptArguments)
# script_args = parser.parse_args_into_dataclasses()[0]

# Number of batch-2 training rows added on top of batch 1, one training run each.
samp_list = [1, 15000, 30000, 45000, 60000, 70000]
model_name = "microsoft/trocr-large-handwritten"


def load_splits(csv_path, test_size):
    """Read ``csv_path``, drop rows with any missing value and split train/test.

    The split uses ``random_state=0`` so it is reproducible between runs.
    Both returned frames get a fresh 0..n-1 index (presumably
    ``AphaPenDataset`` indexes rows positionally — TODO confirm in ``data``).
    """
    df = pd.read_csv(csv_path).dropna()
    train, test = train_test_split(df, test_size=test_size, random_state=0)
    # we reset the indices to start from zero
    return train.reset_index(drop=True), test.reset_index(drop=True)


# Step 1: Load the datasets
train_df, test_df = load_splits("/mnt/data1/Datasets/AlphaPen/" + "training_data.csv", test_size=0.02)
train_df_b2, test_df_b2 = load_splits("/mnt/data1/Datasets/AlphaPen/" + "training_b2.csv", test_size=0.01)

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
root_dir_b2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/clean_data/cropped_data/cropped_"

processor = TrOCRProcessor.from_pretrained(model_name)

# Batch-1 training data and both evaluation sets are the same for every run.
train_dataset_b1 = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
eval_dataset_b1 = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
eval_dataset_b2 = AphaPenDataset(root_dir=root_dir_b2, df=test_df_b2, processor=processor)
eval_dataset = ConcatDataset([eval_dataset_b1, eval_dataset_b2])

# config = VisionEncoderDecoderConfig.from_pretrained(model_name)
# config.decoder.vocab_size = config.decoder.decoder_vocab_size


def build_model():
    """Load a fresh pretrained TrOCR model, configure generation, wrap it with LoRA.

    Called once per run so that every sample size starts from the pretrained
    weights instead of continuing from the adapter trained in the previous run.

    Returns:
        The PEFT-wrapped ``VisionEncoderDecoderModel``.
    """
    # Step 2: Load the model
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    # set special tokens used for creating the decoder_input_ids from the labels
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    # make sure vocab size is set correctly
    # model.config.vocab_size = model.config.decoder.vocab_size
    # for peft
    # model.vocab_size = model.config.decoder.vocab_size

    # set beam search parameters
    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    # LoRA
    lora_config = LoraConfig(
        r=1,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=[
            "query",
            "key",
            "value",
            "intermediate.dense",
            "output.dense",
            # "wte",
            # "wpe",
            # "c_attn",
            # "c_proj",
            # "q_attn",
            # "c_fc",
        ],
        # task_type="SEQ_2_SEQ_LM"
    )
    # model.add_adapter(lora_config)
    return get_peft_model(model, lora_config)


# tokenizer = processor.tokenizer
# sim = CERSimilarity(tokenizer)
# loss = MarginLoss(sim, beta=0.1, num_samples=60)
# reg = KLRegularization(model)
# calibrator = EncoderDecoderCalibrator(model, loss, reg, 15, 15)
# from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

# Step 4: Define a metric
# Loaded once instead of on every evaluation call.
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    """Return the case-insensitive character error rate of predictions vs. labels.

    Args:
        pred: Evaluation output whose ``predictions`` are generated token ids
            and whose ``label_ids`` are the reference token ids.

    Returns:
        ``{"cer": <float>}``.
    """
    pad_id = processor.tokenizer.pad_token_id
    # Copy so the caller's arrays are not mutated. Both may contain -100
    # padding (labels from the ignore index; predictions from the Trainer
    # padding batches of different lengths), which the tokenizer cannot
    # decode, so map it to the pad token first.
    pred_ids = pred.predictions.copy()
    labels_ids = pred.label_ids.copy()
    pred_ids[pred_ids == -100] = pad_id
    labels_ids[labels_ids == -100] = pad_id

    pred_str = [word.lower() for word in processor.batch_decode(pred_ids, skip_special_tokens=True)]
    label_str = [word.lower() for word in processor.batch_decode(labels_ids, skip_special_tokens=True)]
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}


os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

for samp in samp_list:
    train_dataset_b2 = AphaPenDataset(root_dir=root_dir_b2, df=train_df_b2.iloc[:samp, :], processor=processor)
    train_dataset = ConcatDataset([train_dataset_b1, train_dataset_b2])

    # Fresh model per run so sample sizes are compared from the same starting point.
    model = build_model()

    # Step 3: Define the training arguments
    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        bf16=True,
        bf16_full_eval=True,
        # One directory per run: a shared "./" let runs overwrite each other's
        # checkpoints and made push_to_hub publish the whole working directory.
        output_dir="./trocr_large_" + str(samp),
        logging_steps=100,
        save_steps=1000,
        eval_steps=500,
        report_to="wandb",
        optim="adamw_torch_fused",
        lr_scheduler_type="cosine",
        gradient_accumulation_steps=2,
        learning_rate=1.0e-4,
        max_steps=15000,
        # run_name=f"trocr-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%s')}",
        run_name="trocr-LoRA-large_" + str(samp),
        push_to_hub=True,
        hub_model_id="hadrakey/alphapen_trocr_large_" + str(samp),
    )

    # Step 5: Define the Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
        # callbacks=[SavePeftModelCallback]
    )
    trainer.train()
    wandb.finish()

    # Release this run's model and optimizer state before loading the next model.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()