|
|
|
import os
from datetime import datetime

import pandas as pd
import torch
import torch.nn.functional as F

import evaluate
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)
from peft import LoraConfig, get_peft_model

from data import AphaPenDataset

os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

train_df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
test_df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"

# Load each split from its own CSV and drop rows with missing transcriptions.
train_df = pd.read_csv(train_df_path)
train_df.dropna(inplace=True)

test_df = pd.read_csv(test_df_path)
test_df.dropna(inplace=True)

train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"

model_name = "microsoft/trocr-large-handwritten"

# The processor bundles the image feature extractor and the text tokenizer.
processor = TrOCRProcessor.from_pretrained(model_name)
train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor)
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)
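
# Quick sanity check on one example (assumes AphaPenDataset returns a dict of
# tensors keyed "pixel_values" and "labels", which default_data_collator expects).
sample = train_dataset[0]
print({k: v.shape for k, v in sample.items()})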
|
|
|
|
|
# Load the pretrained TrOCR checkpoint.
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
|
|
|
|
# Wire up the special tokens the decoder needs for training and generation.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search defaults picked up by predict_with_generate.
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

lora_config = LoraConfig(
    r=1,
    lora_alpha=8,
    lora_dropout=0.1,
    # Module-name suffixes matched in both the ViT encoder and the text decoder.
    target_modules=[
        'query',
        'key',
        'value',
        'intermediate.dense',
        'output.dense',
    ],
)
model = get_peft_model(model, lora_config)
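
# Sanity check: report how many parameters LoRA leaves trainable;
# print_trainable_parameters() is part of the PEFT model API.
model.print_trainable_parameters()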
|
|
|
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    bf16=True,
    bf16_full_eval=True,
    output_dir="./",
    logging_steps=100,
    save_steps=20000,
    eval_steps=500,
    optim="adamw_torch_fused",
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=2,
    learning_rate=1.0e-4,
    max_steps=10000,
    report_to="wandb",  # WANDB_PROJECT is set above
    run_name=f"trocr-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
)
|
|
|
|
|
|
|
cer_metric = evaluate.load("cer")


def compute_cer(pred, target):
    # evaluate's "cer" metric returns the error rate directly as a float.
    return cer_metric.compute(predictions=[pred], references=[target])
|
|
|
def generate_candidates(model, pixel_values, num_candidates=10):
    # Beam search that returns every beam together with its sequence-level score.
    return model.generate(
        pixel_values,
        num_return_sequences=num_candidates,
        num_beams=num_candidates,
        output_scores=True,
        return_dict_in_generate=True,
    )
|
|
|
def rank_loss(positive_scores, negative_scores):
    return F.relu(1 - positive_scores + negative_scores).mean()


def margin_loss(positive_scores, negative_scores, margin=0.1):
    return F.relu(margin - positive_scores + negative_scores).mean()
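
# Tiny check of the margin loss: the better candidate's score (0.2) beats the
# worse one's (0.1) by exactly the 0.1 margin, so the hinge is inactive.
assert margin_loss(torch.tensor([0.2]), torch.tensor([0.1])).item() <= 1e-6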
|
|
|
def calibration_loss(model, pixel_values, labels, processor, loss_type='margin'):
    candidates = generate_candidates(model, pixel_values)
    candidate_sequences = processor.batch_decode(candidates.sequences, skip_special_tokens=True)

    # Labels use -100 as the CE ignore index; restore the pad token before decoding.
    labels = labels.clone()
    labels[labels == -100] = processor.tokenizer.pad_token_id
    references = processor.batch_decode(labels, skip_special_tokens=True)

    # generate returns k candidates per image, grouped consecutively.
    k = candidates.sequences.size(0) // len(references)
    similarities = [1 - compute_cer(cand, references[idx // k])
                    for idx, cand in enumerate(candidate_sequences)]

    # Collect (better, worse) index pairs, comparing only candidates of the same image.
    pairs = []
    for i in range(len(similarities)):
        for j in range(i + 1, len(similarities)):
            if i // k != j // k:
                continue
            if similarities[i] > similarities[j]:
                pairs.append((i, j))
            elif similarities[j] > similarities[i]:
                pairs.append((j, i))

    if not pairs:
        return torch.tensor(0.0, device=pixel_values.device)

    pairs = torch.tensor(pairs, device=candidates.sequences_scores.device)
    positive_scores = candidates.sequences_scores[pairs[:, 0]]
    negative_scores = candidates.sequences_scores[pairs[:, 1]]

    if loss_type == 'rank':
        return rank_loss(positive_scores, negative_scores)
    elif loss_type == 'margin':
        return margin_loss(positive_scores, negative_scores)
    else:
        raise ValueError("Invalid loss type. Choose 'rank' or 'margin'.")
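
# NOTE: HF generate() runs under torch.no_grad(), so the sequences_scores used
# above carry no gradient and the calibration term cannot, by itself, update the
# model. A minimal sketch of a differentiable re-scoring pass that could replace
# them; rescore_candidates is a hypothetical helper, and it assumes candidates
# start with the decoder-start token and are padded with the pad token:
def rescore_candidates(model, pixel_values, sequences, pad_token_id):
    """Length-normalized log-probs of generated sequences, with gradients."""
    # Drop the leading decoder-start token; the model re-adds it when shifting labels.
    labels = sequences[:, 1:].clone()
    labels[labels == pad_token_id] = -100
    # Each image was decoded into k candidates, so repeat pixel_values to match.
    k = sequences.size(0) // pixel_values.size(0)
    outputs = model(pixel_values=pixel_values.repeat_interleave(k, dim=0), labels=labels)
    log_probs = F.log_softmax(outputs.logits, dim=-1)
    token_scores = log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    mask = (labels != -100).float()
    return (token_scores * mask).sum(-1) / mask.sum(-1).clamp(min=1)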
|
|
|
class CalibratedTrainer(Seq2SeqTrainer):
    def __init__(self, *args, **kwargs):
        self.processor = kwargs.pop('processor', None)
        self.calibration_weight = kwargs.pop('calibration_weight', 0.1)
        self.calibration_loss_type = kwargs.pop('calibration_loss_type', 'margin')
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        pixel_values = inputs["pixel_values"]

        # Teacher-forced forward pass provides the token-level cross-entropy term.
        outputs = model(pixel_values=pixel_values, labels=labels)
        ce_loss = outputs.loss

        # Sequence-level calibration term computed over beam-search candidates.
        cal_loss = calibration_loss(model, pixel_values, labels,
                                    self.processor, self.calibration_loss_type)

        total_loss = ce_loss + self.calibration_weight * cal_loss
        return (total_loss, outputs) if return_outputs else total_loss
|
|
|
|
|
|
|
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    # Corpus-level character error rate over the whole eval set.
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
|
|
|
|
|
|
|
trainer = CalibratedTrainer(
    model=model,
    tokenizer=processor.feature_extractor,  # saved alongside checkpoints
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    processor=processor,
    calibration_weight=0.1,
    calibration_loss_type='margin',
)
|
|
|
|
|
trainer.train() |
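
# Persist the trained LoRA adapter and the processor; the output path is an
# assumption, adjust as needed.
model.save_pretrained("./trocr-lora-adapter")
processor.save_pretrained("./trocr-lora-adapter")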