|
|
|
|
|
|
|
from PIL import Image |
|
import requests |
|
from transformers import AutoModelForCausalLM |
|
from transformers import AutoProcessor |
|
from transformers import BitsAndBytesConfig |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
import torch |
|
import pandas as pd |
|
from torchmetrics.text import CharErrorRate |
|
from peft import PeftModel, PeftConfig |
|
|
|
|
|
# Hugging Face Hub checkpoints used by this script.
model_id = "microsoft/Phi-3-vision-128k-instruct"  # Phi-3 vision model evaluated below

# Adapter repositories (PeftModel/PeftConfig are imported for these).
# NOTE(review): neither name is referenced by the code below — confirm they
# are still needed before relying on them.
peft_model_id, peft_model_id_new = (
    "hadrakey/alphapen_phi3",
    "hadrakey/alphapen_new_large",
)

# Phi-3 processor: used both for its tokenizer (chat template) and to turn
# (prompt, images) into model inputs. trust_remote_code lets it load the
# repository's custom processor code.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TrOCR models compared against Phi-3. "hadrakey/alphapen_large" is presumably
# a TrOCR checkpoint fine-tuned on AlphaPen data (TODO confirm); the other is
# Microsoft's handwritten base model. No device is given, so both are loaded
# with the library's default placement.
_trocr_base_id = "microsoft/trocr-base-handwritten"

model_finetune = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_large")
model_base = VisionEncoderDecoderModel.from_pretrained(_trocr_base_id)

# One processor, taken from the base checkpoint, preprocesses images and
# decodes generated ids for BOTH TrOCR models.
processor_ocr = TrOCRProcessor.from_pretrained(_trocr_base_id)
|
|
|
|
|
# bitsandbytes settings: 4-bit NF4 weights with double quantization, and
# bfloat16 as the compute dtype.
nf4_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Loader options for Phi-3 vision: place it on CUDA, allow the repository's
# custom model code, and quantize with the config above.
_phi3_load_kwargs = {
    "device_map": "cuda",
    "trust_remote_code": True,
    "torch_dtype": "auto",
    "quantization_config": nf4_config,
}
model = AutoModelForCausalLM.from_pretrained(model_id, **_phi3_load_kwargs)
|
|
|
|
|
|
|
|
|
|
|
# Single user turn sent to Phi-3 vision. "<|image_1|>" is presumably the
# placeholder the Phi-3 processor binds to the first image in the list passed
# alongside the prompt (TODO confirm); the remainder is the OCR instruction.
# The literal's exact text, including its blank lines, is the prompt content.
_ocr_instruction = """<|image_1|>\nThis image contains handwritten French characters forming a complete or partial word. The image is blurred, which makes recognition challenging. Please analyze the image to the best of your ability and provide your best guess of the French word or partial word shown, even if you're not certain. Follow these guidelines:



1. Examine the overall shape and any discernible character features.

2. Consider common French letter combinations and word patterns.

3. If you can only identify some characters, provide those as a partial word.

4. Make an educated guess based on what you can see, even if it's just a few letters.

5. If you can see any characters at all, avoid responding with "indiscernible."



Your response should be only the predicted French word or partial word, using lowercase letters unless capital letters are clearly visible. If you can see any characters or shapes at all, provide the OCR from the image.

"""

messages = [
    {
        "role": "user",
        "content": _ocr_instruction,
    },
]

# Sample image URL; not referenced by the evaluation code below.
url = "https://images.unsplash.com/photo-1528834342297-fdefb9a5a92b?ixlib=rb-4.0.3&q=85&fm=jpg&crop=entropy&cs=srgb&dl=roonz-nl-vjDbHCjHlEY-unsplash.jpg&w=640"
|
|
|
|
|
# --- Evaluation: Phi-3 vision vs. base and fine-tuned TrOCR -----------------
# For the first N_SAMPLES rows of the test CSV, run all three models on the
# corresponding image, record each transcription and its per-sample,
# case-insensitive character error rate (CER), and write the augmented table.

DATASET_DIR = "/mnt/data1/Datasets/AlphaPen/"
df_path = DATASET_DIR + "testing_data.csv"
OUTPUT_CSV = DATASET_DIR + "sample_data.csv"
N_SAMPLES = 5000

data = pd.read_csv(df_path)
data.dropna(inplace=True)
# Renumbers rows 0..n-1; the previous row label is kept as an "index" column
# (and therefore also appears in the output CSV).
data.reset_index(inplace=True)
# Explicit copy: new columns are assigned to `sample` at the end, and
# assigning into a slice of `data` triggers pandas' SettingWithCopyWarning.
sample = data.iloc[:N_SAMPLES, :].copy()

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"
# Image files on disk are the CSV's `filename` with this prefix prepended.
IMAGE_PREFIX = "final_cropped_rotated_"

# Chat-formatted text prompt for Phi-3; identical for every image.
prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

cer_metric = CharErrorRate()

phi_output = []
phi_finetune_output = []  # not populated in this script
inf_baseline = []
inf_finetune = []
inf_finetune_new = []  # not populated in this script

cer_phi = []
cer_phi_finetune = []  # not populated in this script
cer_trocr_fine_new = []  # not populated in this script
cer_trocr_fine = []
cer_trocr_base = []


def _cer(prediction, target):
    """Return the case-insensitive CER of ``prediction`` against ``target``.

    Both values are coerced to ``str`` first: pandas may parse a purely
    numeric label as int/float, which has no ``.lower()``. The result is the
    metric's value as a 0-d numpy array (same representation as before).
    """
    return cer_metric(str(prediction).lower(), str(target).lower()).detach().numpy()


def _phi3_transcribe(image):
    """Greedily decode Phi-3's answer to ``prompt`` for ``image``; return the text."""
    inputs = processor(prompt, [image], return_tensors="pt").to("cuda:0")
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=False,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
    return processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def _trocr_transcribe(trocr_model, pixel_values):
    """Run a TrOCR model on preprocessed ``pixel_values``; return the decoded text."""
    generated_ids = trocr_model.generate(pixel_values)
    return processor_ocr.batch_decode(generated_ids, skip_special_tokens=True)[0]


for filename, target in zip(sample["filename"], sample["text"]):
    # The context manager closes the file; convert() yields an independent copy.
    with Image.open(root_dir + IMAGE_PREFIX + filename) as img:
        image = img.convert("RGB")

    response = _phi3_transcribe(image)
    phi_output.append(response)
    cer_phi.append(_cer(response, target))

    # Both TrOCR models consume the same preprocessed pixels.
    pixel_values = processor_ocr(image, return_tensors="pt").pixel_values
    generated_text_base = _trocr_transcribe(model_base, pixel_values)
    generated_text_fine = _trocr_transcribe(model_finetune, pixel_values)

    inf_baseline.append(generated_text_base)
    inf_finetune.append(generated_text_fine)
    cer_trocr_fine.append(_cer(generated_text_fine, target))
    cer_trocr_base.append(_cer(generated_text_base, target))

sample["phi3"] = phi_output
sample["Baseline"] = inf_baseline
sample["Finetune"] = inf_finetune
sample["cer_phi"] = cer_phi
sample["cer_trocr_base"] = cer_trocr_base
sample["cer_trocr_fine"] = cer_trocr_fine

sample.to_csv(OUTPUT_CSV)