# Fine-tune microsoft/Phi-3-vision-128k-instruct for OCR on the AlphaPen dataset,
# using 4-bit NF4 quantization (bitsandbytes) plus LoRA adapters, with character
# error rate (CER) as the evaluation metric.

import os
from datetime import datetime

import evaluate
import pandas as pd
import torch
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)

from data import AlphaPenPhi3Dataset
|
# Log training runs to the "Alphapen" Weights & Biases project.
os.environ["WANDB_PROJECT"] = "Alphapen"

# Base checkpoint; Phi-3-Vision ships custom modeling/processing code,
# hence trust_remote_code=True below.
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
# Load the annotation CSV, drop incomplete rows, and split off 15% for evaluation.
df_path = "/mnt/data1/Datasets/AlphaPen/" + "training_data.csv"
df = pd.read_csv(df_path)
df.dropna(inplace=True)
train_df, test_df = train_test_split(df, test_size=0.15, random_state=0)
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

# Path prefix of the cropped and rotated images referenced by the dataframe.
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
|
# The Phi-3-Vision processor bundles the image processor and the tokenizer.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
tokenizer = processor.tokenizer

train_dataset = AlphaPenPhi3Dataset(root_dir=root_dir, dataframe=train_df, tokenizer=tokenizer, max_length=128, image_size=128)
# Only the first 10 test rows are used for evaluation, keeping generation-based eval fast.
eval_dataset = AlphaPenPhi3Dataset(root_dir=root_dir, dataframe=test_df.iloc[:10], tokenizer=tokenizer, max_length=128, image_size=128)

# Sanity check: inspect one processed training example.
print(train_dataset[0])
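
# Note: AlphaPenPhi3Dataset is a local module (data.py), not shown here; it is
# assumed to return per-example tensors in the format the model expects
# (e.g. input_ids, pixel_values, labels).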
|
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
|
# Load the base model in 4-bit; device_map="auto" dispatches it across available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype="auto",
    quantization_config=nf4_config,
)
|
# Generation settings used by Seq2SeqTrainer when predict_with_generate=True.
# Phi-3's tokenizer defines no CLS/SEP tokens, so use its BOS/EOS/PAD ids instead
# (decoder_start_token_id is not used by a decoder-only model, but is kept here).
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_new_tokens = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
|
# LoRA adapter configuration. Note: in the Phi-3 language model the attention
# projections are fused as "qkv_proj" (plus "o_proj"), so "q_proj"/"k_proj"/"v_proj"
# may only match the vision tower; adjust target_modules if the language-model
# attention layers are the intended target.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
)
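
# Optional step (not in the original script): peft's prepare_model_for_kbit_training
# is commonly applied to a 4-bit base model before attaching adapters; it casts
# layer norms to fp32 and enables gradient checkpointing.
# from peft import prepare_model_for_kbit_training
# model = prepare_model_for_kbit_training(model)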
|
# Attach the LoRA adapters. The 4-bit base model was already placed by
# device_map="auto", and bitsandbytes-quantized models do not support .to(device),
# so no explicit device move is needed.
model = get_peft_model(model, lora_config)
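
# Optional sanity check (not in the original script): report how many parameters
# the LoRA adapters leave trainable.
model.print_trainable_parameters()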
|
# Seq2Seq-style arguments so evaluation can call generate() (predict_with_generate).
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1.0e-4,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    max_steps=10000,
    bf16=True,
    bf16_full_eval=True,
    logging_steps=100,
    eval_steps=100,
    save_steps=1000,
    report_to="wandb",
    run_name=f"phi3-vision-LoRA-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
    push_to_hub=True,
    hub_model_id="hadrakey/alphapen_phi3",
)
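
# Note: report_to="wandb" and push_to_hub=True assume prior authentication
# (e.g. `wandb login` and `huggingface-cli login`).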
|
|
|
# Character error rate (CER) between generated text and ground-truth labels.
cer_metric = evaluate.load("cer")


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # Decode predictions and labels the same way, skipping special tokens;
    # the -100 ignore index in the labels is mapped back to the pad token first.
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
|
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
|
trainer.train() |
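
# After training, the adapter can also be pushed explicitly if desired
# (push_to_hub=True above already uploads checkpoints during training):
# trainer.push_to_hub()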