from datetime import datetime
import os

import pandas as pd
import torch
import torch.nn.functional as F
import evaluate
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)
from peft import LoraConfig, get_peft_model

from data import AphaPenDataset
from src.calibrator import EncoderDecoderCalibrator
from src.loss import MarginLoss, KLRegularization
from src.similarity import CERSimilarity

os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

# Step 1: Load the dataset
train_df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
test_df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"

# train_df = pd.read_csv(train_df_path)
# train_df.dropna(inplace=True)
train_df = pd.read_csv(test_df_path)[:4000]
train_df.dropna(inplace=True)
test_df = pd.read_csv(test_df_path)[4000:]
test_df.dropna(inplace=True)

# Reset the indices so both splits start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
model_name = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(model_name)
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)

# Step 2: Load the model
model = VisionEncoderDecoderModel.from_pretrained(model_name)

# Set the special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Make sure the vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# For PEFT
model.vocab_size = model.config.decoder.vocab_size

# Set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# LoRA
lora_config = LoraConfig(
    r=1,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=[
        "query",
        "key",
        "value",
        "intermediate.dense",
        "output.dense",
        # "wte",
        # "wpe",
        # "c_attn",
        # "c_proj",
        # "q_attn",
        # "c_fc",
    ],
)
model = get_peft_model(model, lora_config)

tokenizer = processor.tokenizer
# sim = CERSimilarity(tokenizer)
# loss = MarginLoss(sim, beta=0.1, num_samples=60)
# reg = KLRegularization(model)
# calibrator = EncoderDecoderCalibrator(model, loss, reg, 15, 15)

# Step 3: Define the training arguments
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    bf16=True,
    bf16_full_eval=True,
    output_dir="./",
    logging_steps=100,
    save_steps=20000,
    eval_steps=500,
    # report_to="wandb",
    optim="adamw_torch_fused",
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    learning_rate=1.0e-4,
    max_steps=10000,
    run_name=f"trocr-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
)

# Step 4: Define the metric, candidate generation, and calibration losses
cer_metric = evaluate.load("cer")


def compute_cer(pred, target):
    # evaluate's "cer" metric returns the character error rate as a float
    return cer_metric.compute(predictions=[pred], references=[target])


def generate_candidates(model, pixel_values, num_candidates=10):
    # Beam search that returns the candidate sequences along with their scores
    return model.generate(
        pixel_values=pixel_values,
        num_return_sequences=num_candidates,
        num_beams=num_candidates,
        output_scores=True,
        return_dict_in_generate=True,
    )


def rank_loss(positive_scores, negative_scores):
    return F.relu(1 - positive_scores + negative_scores).mean()


def margin_loss(positive_scores, negative_scores, margin=0.1):
    return F.relu(margin - positive_scores + negative_scores).mean()


def calibration_loss(model, pixel_values, labels, processor, loss_type="margin", num_candidates=10):
    """Sequence-level calibration loss: candidates that are closer to the reference
    (higher 1 - CER similarity) should also receive higher beam scores.

    Note: generate() runs without gradient tracking, so the beam scores used here
    are fixed values; re-scoring the candidates with a regular forward pass would
    be needed for this term to update the model.
    """
    candidates = generate_candidates(model, pixel_values, num_candidates)
    candidate_sequences = processor.batch_decode(candidates.sequences, skip_special_tokens=True)

    # Replace the -100 ignore index with the pad token before decoding the references
    labels = labels.clone()
    labels[labels == -100] = processor.tokenizer.pad_token_id
    references = processor.batch_decode(labels, skip_special_tokens=True)

    total_loss = torch.tensor(0.0, device=pixel_values.device)
    num_samples = 0
    for b, reference in enumerate(references):
        # generate() returns the candidates of each sample as a consecutive block
        start = b * num_candidates
        scores = candidates.sequences_scores[start : start + num_candidates]
        sample_candidates = candidate_sequences[start : start + num_candidates]
        similarities = [1 - compute_cer(cand, reference) for cand in sample_candidates]

        positive_pairs = []
        negative_pairs = []
        for i in range(len(similarities)):
            for j in range(i + 1, len(similarities)):
                if similarities[i] > similarities[j]:
                    positive_pairs.append((i, j))
                else:
                    negative_pairs.append((i, j))

        if not positive_pairs or not negative_pairs:
            continue

        positive_scores = scores[torch.tensor(positive_pairs)[:, 0]]
        negative_scores = scores[torch.tensor(negative_pairs)[:, 1]]

        if loss_type == "rank":
            total_loss = total_loss + rank_loss(positive_scores, negative_scores)
        elif loss_type == "margin":
            total_loss = total_loss + margin_loss(positive_scores, negative_scores)
        else:
            raise ValueError("Invalid loss type. Choose 'rank' or 'margin'.")
        num_samples += 1

    return total_loss / max(num_samples, 1)


class CalibratedTrainer(Seq2SeqTrainer):
    def __init__(self, *args, **kwargs):
        self.processor = kwargs.pop("processor", None)
        self.calibration_weight = kwargs.pop("calibration_weight", 0.1)
        self.calibration_loss_type = kwargs.pop("calibration_loss_type", "margin")
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        pixel_values = inputs["pixel_values"]

        # Original cross-entropy loss from a teacher-forced forward pass
        outputs = model(pixel_values=pixel_values, labels=labels)
        ce_loss = outputs.loss

        # Calibration loss over beam-search candidates
        cal_loss = calibration_loss(
            model, pixel_values, labels, self.processor, self.calibration_loss_type
        )

        # Combine losses
        total_loss = ce_loss + self.calibration_weight * cal_loss
        return (total_loss, outputs) if return_outputs else total_loss


def compute_metrics(pred):
    # accuracy_metric = evaluate.load("precision")
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    # accuracy = accuracy_metric.compute(predictions=pred_ids.tolist(), references=labels_ids.tolist())

    return {"cer": cer}


# Step 5: Define the Trainer
trainer = CalibratedTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    # compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    processor=processor,
    calibration_weight=0.1,
    calibration_loss_type="margin",  # or "rank"
)
trainer.train()
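
# Optional: persist the fine-tuned LoRA adapter and the processor so the run can
# be reloaded for inference later. This is a minimal sketch; the output directory
# name is illustrative, and with a PEFT-wrapped model save_pretrained() stores
# only the adapter weights rather than the full TrOCR checkpoint.
model.save_pretrained("./trocr-lora-adapter")
processor.save_pretrained("./trocr-lora-adapter")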