|
|
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
import pandas as pd |
|
import os |
|
|
|
import torch |
|
from transformers import VisionEncoderDecoderModel, TrOCRProcessor, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, EarlyStoppingCallback, TrainerCallback, TrainingArguments, TrainerState, TrainerControl |
|
from peft import LoraConfig, get_peft_model |
|
from transformers import VisionEncoderDecoderConfig |
|
|
|
from data import AphaPenDataset |
|
import evaluate |
|
from sklearn.model_selection import train_test_split |
|
|
|
from src.calibrator import EncoderDecoderCalibrator |
|
from src.loss import MarginLoss, KLRegularization |
|
from src.similarity import CERSimilarity |
|
from datetime import datetime |
|
from torch.utils.data import ConcatDataset |
|
import wandb |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Number of batch-2 training rows used by each run of the training sweep.
samp_list = [1, 15000, 30000, 45000, 60000, 70000]

model_name = "microsoft/trocr-large-handwritten"

# Directory holding both annotation CSVs.
_csv_dir = "/mnt/data1/Datasets/AlphaPen/"


def _load_and_split(csv_path, test_fraction):
    """Read an annotation CSV, drop rows with missing values, and split it.

    Returns ``(full_frame, train_frame, test_frame)``. Both splits are
    re-indexed 0..n-1, and ``random_state=0`` keeps the split reproducible.
    """
    frame = pd.read_csv(csv_path).dropna()
    train_part, test_part = train_test_split(frame, test_size=test_fraction, random_state=0)
    return frame, train_part.reset_index(drop=True), test_part.reset_index(drop=True)


# Batch 1: hold out 2% for evaluation.
df_path = _csv_dir + "training_data.csv"
df, train_df, test_df = _load_and_split(df_path, 0.02)

# Batch 2: hold out 1% for evaluation.
df_path_b2 = _csv_dir + "training_b2.csv"
df_b2, train_df_b2, test_df_b2 = _load_and_split(df_path_b2, 0.01)

# Image path prefixes for each batch; presumably AphaPenDataset appends a
# per-row file name from the CSV to these — verify in data.py.
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
root_dir_b2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/clean_data/cropped_data/cropped_"

processor = TrOCRProcessor.from_pretrained(model_name)


def _first_rows_dataset(images_root, frame):
    """Wrap the first 100 rows of ``frame`` in an AphaPenDataset."""
    # NOTE(review): the batch-1 *training* set is also capped at 100 rows,
    # which looks like a debugging leftover — confirm it is intended.
    return AphaPenDataset(root_dir=images_root, df=frame.iloc[:100, :], processor=processor)


train_dataset_b1 = _first_rows_dataset(root_dir, train_df)
eval_dataset_b1 = _first_rows_dataset(root_dir, test_df)
eval_dataset_b2 = _first_rows_dataset(root_dir_b2, test_df_b2)

# Evaluation always uses held-out samples from both batches together.
eval_dataset = ConcatDataset([eval_dataset_b1, eval_dataset_b2])
|
|
|
|
|
|
|
|
|
|
|
model = VisionEncoderDecoderModel.from_pretrained(model_name)

# Special-token ids and beam-search settings, written onto model.config in
# this order. Evaluation runs with predict_with_generate, so generate() is
# expected to pick these up. NOTE(review): newer transformers versions prefer
# model.generation_config for these — verify against the installed version.
_config_overrides = {
    # Decoding starts from the tokenizer's CLS token; SEP is the end token.
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    "eos_token_id": processor.tokenizer.sep_token_id,
    "max_length": 64,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for _attr_name, _attr_value in _config_overrides.items():
    setattr(model.config, _attr_name, _attr_value)

# Rank-1 LoRA adapters (alpha 8, dropout 0.1) on the modules whose names
# match target_modules.
# NOTE(review): these names look like ViT-style encoder layers; if the decoder
# names its projections differently it gets no adapters — check with
# model.print_trainable_parameters().
lora_config = LoraConfig(
    r=1,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["query", "key", "value", "intermediate.dense", "output.dense"],
)

model = get_peft_model(model, lora_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Every sweep run logs to the same W&B project; set it once.
os.environ["WANDB_PROJECT"] = "Alphapen-TrOCR"

# Load the CER metric once instead of on every evaluation call.
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    """Compute the case-insensitive character error rate for one eval pass.

    ``pred.predictions`` holds generated token ids (predict_with_generate=True)
    and ``pred.label_ids`` the reference ids. Entries equal to -100 (the
    Trainer's ignore/padding value) are replaced with the pad token before
    decoding, because the tokenizer cannot decode negative ids.

    Returns:
        dict: ``{"cer": <character error rate>}``.
    """
    pad_id = processor.tokenizer.pad_token_id
    pred_ids = pred.predictions
    labels_ids = pred.label_ids
    # -100 can appear in predictions too when eval batches are padded together.
    pred_ids[pred_ids == -100] = pad_id
    labels_ids[labels_ids == -100] = pad_id

    pred_str = [text.lower() for text in processor.batch_decode(pred_ids, skip_special_tokens=True)]
    label_str = [text.lower() for text in processor.batch_decode(labels_ids, skip_special_tokens=True)]

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}


# Snapshot of the untrained trainable (LoRA) weights. ``model`` is one object
# shared by every run, so without a reset each run would continue from the
# previous run's fine-tuned adapters and the sample-size sweep would not start
# from a common point.
_initial_trainable = {
    name: param.detach().clone()
    for name, param in model.named_parameters()
    if param.requires_grad
}


def _reset_trainable_params():
    """Restore every trainable parameter of ``model`` to its snapshot value."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in _initial_trainable:
                param.copy_(_initial_trainable[name])


for samp in samp_list:
    # Each run starts from the same freshly initialised adapters.
    _reset_trainable_params()

    # The batch-1 subset is fixed; only the amount of batch-2 data varies.
    train_dataset_b2 = AphaPenDataset(root_dir=root_dir_b2, df=train_df_b2.iloc[:samp, :], processor=processor)
    train_dataset = ConcatDataset([train_dataset_b1, train_dataset_b2])

    run_name = "trocr-LoRA-large_" + str(samp)

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        bf16=True,
        bf16_full_eval=True,
        # One directory per run. A shared "./" let runs overwrite each other's
        # checkpoints, and with push_to_hub the output_dir is the folder synced
        # to the Hub, so the whole working directory could be uploaded.
        output_dir=os.path.join(".", run_name),
        logging_steps=100,
        save_steps=1000,
        eval_steps=500,
        report_to="wandb",
        optim="adamw_torch_fused",
        lr_scheduler_type="cosine",
        gradient_accumulation_steps=2,
        learning_rate=1.0e-4,
        max_steps=15000,
        run_name=run_name,
        push_to_hub=True,
        hub_model_id="hadrakey/alphapen_trocr_large_" + str(samp),
    )

    trainer = Seq2SeqTrainer(
        model=model,
        # Passed as `tokenizer` so it is saved alongside checkpoints; newer
        # transformers versions may expect `processing_class` — verify.
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )

    trainer.train()
    # Close this run's W&B session so the next iteration starts a new one.
    wandb.finish()