from dataclasses import dataclass, field
from typing import Optional

import evaluate
import pandas as pd
import torch
import torchvision.transforms as transforms
from accelerate import Accelerator
from peft import LoraConfig
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import (
    BitsAndBytesConfig,
    EarlyStoppingCallback,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    default_data_collator,
)
from trl import SFTTrainer, is_xpu_available

from data import AphaPenDataset

tqdm.pandas()
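
# Fine-tune a TrOCR VisionEncoderDecoderModel on the AlphaPen handwriting crops,
# with optional 8/4-bit quantization and LoRA adapters.
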
@dataclass
class ScriptArguments:
    """
    Arguments for fine-tuning the TrOCR model on the AlphaPen dataset.
    """

    model_name: Optional[str] = field(default="microsoft/trocr-base-handwritten", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}
    )
    log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    batch_size: Optional[int] = field(default=8, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=16, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bit precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bit precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use PEFT to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=True, metadata={"help": "use the HF auth token to access the model"})
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    max_length: Optional[int] = field(default=10, metadata={"help": "the maximum generation length"})
    no_repeat_ngram_size: Optional[int] = field(default=3, metadata={"help": "n-gram size that may not be repeated during generation"})
    length_penalty: Optional[float] = field(default=2.0, metadata={"help": "the beam search length penalty"})
    num_beams: Optional[int] = field(default=3, metadata={"help": "the number of beams for beam search"})
    early_stopping: Optional[bool] = field(default=True, metadata={"help": "stop beam search once enough finished hypotheses are found"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "number of update steps between two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "limits the total number of checkpoints"})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "push the model to the HF Hub"})
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "whether to use gradient checkpointing"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "the name of the model on the HF Hub"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
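# Illustrative invocation (the script name and flag values here are examples, not requirements):
#   python train_trocr_alphapen.py --model_name microsoft/trocr-base-handwritten --use_peft --output_dir ./output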

# Load the AlphaPen annotations, drop incomplete rows, and hold out 15% for evaluation.
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
df = pd.read_csv(df_path)
df.dropna(inplace=True)
train_df, test_df = train_test_split(df, test_size=0.15, random_state=0)

train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
processor = TrOCRProcessor.from_pretrained(script_args.model_name)

# Rotation augmentation is applied to the training images only.
transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.RandomRotation(degrees=(0, 180)),
])

train_dataset = AphaPenDataset(root_dir=root_dir, df=train_df, processor=processor, transform=transform)
eval_dataset = AphaPenDataset(root_dir=root_dir, df=test_df, processor=processor)

# Optional 8-bit / 4-bit loading via bitsandbytes.
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
    )
    # Pin the whole model to this process's device (XPU if available, otherwise GPU).
    device_map = (
        {"": f"xpu:{Accelerator().local_process_index}"}
        if is_xpu_available()
        else {"": Accelerator().local_process_index}
    )
    torch_dtype = torch.bfloat16
else:
    device_map = None
    quantization_config = None
    torch_dtype = None

model = VisionEncoderDecoderModel.from_pretrained(
    script_args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    token=script_args.use_auth_token,
)

# Wire up the special tokens the decoder expects: generation starts from the tokenizer's
# CLS token, padding uses the pad token, and the SEP token marks end-of-sequence.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search settings used when evaluating with generation.
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = script_args.max_length
model.config.early_stopping = script_args.early_stopping
model.config.no_repeat_ngram_size = script_args.no_repeat_ngram_size
model.config.length_penalty = script_args.length_penalty
model.config.num_beams = script_args.num_beams

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    fp16=True,
    output_dir=script_args.output_dir,
    logging_steps=script_args.logging_steps,
    save_steps=script_args.save_steps,
    eval_steps=100,
    save_total_limit=script_args.save_total_limit,
    load_best_model_at_end=True,
    report_to=script_args.log_with,
    num_train_epochs=script_args.num_train_epochs,
    push_to_hub=script_args.push_to_hub,
    hub_model_id=script_args.hub_model_id,
    gradient_checkpointing=script_args.gradient_checkpointing,
    auto_find_batch_size=True,
    # compute_metrics returns {"cer": ...}, which the Trainer logs as "eval_cer";
    # CER is an error rate, so a lower value is better.
    metric_for_best_model="cer",
    greater_is_better=False,
)

# Character error rate (CER) of the generated transcriptions against the reference labels.
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # Label padding is stored as -100; restore the pad token id before decoding.
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
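# For intuition: CER is the character-level edit distance divided by the number of
# reference characters, e.g. predicting "hallo" for the reference "hello" gives 1/5 = 0.2.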


# Stop training if the tracked metric (eval CER) has not improved by at least 0.001 for 10 evaluations.
early_stop = EarlyStoppingCallback(early_stopping_patience=10, early_stopping_threshold=0.001)

# Optionally train LoRA adapters instead of full fine-tuning.
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )
else:
    peft_config = None

trainer = SFTTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    peft_config=peft_config,
    callbacks=[early_stop],
)

trainer.train()
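
# Optional: persist the final weights and the processor so the run can be reloaded for inference.
trainer.save_model(script_args.output_dir)
processor.save_pretrained(script_args.output_dir)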