|
from PIL import Image |
|
import requests |
|
from transformers import AutoModelForCausalLM |
|
from transformers import AutoProcessor |
|
from transformers import BitsAndBytesConfig |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForVision2Seq |
|
import torch |
|
import pandas as pd |
|
from torchmetrics.text import CharErrorRate |
|
from peft import PeftModel, PeftConfig |
|
from torchmetrics.text import CharErrorRate |
|
from datasets import Dataset, DatasetDict, Image |
|
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
TRAIN_SAMPLES = 1000
TEST_SAMPLES = 200
TEST_SIZE = 0.166
DEVICE = "cuda:0"

# Hub id of the PEFT (LoRA) adapter holding the fine-tuned IDEFICS-2 weights.
peft_model_id = "hadrakey/alphapen_idefics2_finetune_v1"

# Resolve the base checkpoint from the adapter's config, load processor and
# base model, then attach the adapter and move everything to the target GPU.
config = PeftConfig.from_pretrained(peft_model_id)
processor = AutoProcessor.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
base_model = AutoModelForVision2Seq.from_pretrained(
    config.base_model_name_or_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype="auto",
)
model = PeftModel.from_pretrained(base_model, peft_model_id).to(DEVICE)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Load the AlphaPen test CSV and build a HF Dataset for evaluation.
# ---------------------------------------------------------------------------
df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
df = pd.read_csv(df_path)
df.dropna(inplace=True)

# Evaluate on the first 5000 rows. `.copy()` detaches the slice from `df`;
# without it, the in-place mutations below operate on a view and raise
# pandas' SettingWithCopyWarning (with unspecified copy/alias semantics).
sample = df.iloc[:5000, :].copy()
sample.reset_index(inplace=True)
sample["id"] = range(sample.shape[0])
sample["query"] = "What is shown in this image?"

# `root_dir` is a filename *prefix* (it ends with "_"), so plain string
# concatenation — not os.path.join — is intentional here.
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
image_paths = [root_dir + img for img in sample.filename]

ids = sample['id'].tolist()
queries = sample['query'].tolist()
answers = sample['text'].tolist()  # ground-truth transcriptions

dataset_dict = {
    'id': ids,
    'image': image_paths,
    'query': queries,
    'answers': answers
}

dataset = Dataset.from_dict(dataset_dict)

# Cast the path strings to the HF `Image` feature so each row decodes to a
# PIL image on access (note: this is datasets.Image, imported at file top).
dataset = dataset.cast_column("image", Image())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Generate a transcription for every dataset row and score per-sample CER.
# ---------------------------------------------------------------------------
cer_metric = CharErrorRate()
cer_idefics = []      # per-sample character error rates
idefics_output = []   # raw generated texts, one per row

for idx in range(len(dataset)):
    example = dataset[idx]
    image = example["image"]
    query = example["query"]

    # Chat-style prompt: brief-answer instruction, image slot, then the query.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer briefly."},
                {"type": "image"},
                {"type": "text", "text": query}
            ]
        }
    ]

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[prompt.strip()], images=[image], return_tensors="pt", padding=True)
    inputs = {key: tensor.to(DEVICE) for key, tensor in inputs.items()}

    generated_ids = model.generate(**inputs, max_new_tokens=64)

    # Slice off the prompt tokens so only the generated continuation is decoded.
    prompt_len = inputs["input_ids"].size(1)
    decoded = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)

    prediction = decoded[0]
    idefics_output.append(prediction)
    # Case-insensitive CER against the ground-truth answer.
    cer_idefics.append(cer_metric(prediction.lower(), example["answers"].lower()).detach().numpy())
|
|
|
|
|
# Attach the model outputs and CER scores to the evaluation DataFrame and
# persist everything for later analysis.
sample["idefics"] = idefics_output  # generated transcription per row

sample["cer"] = cer_idefics  # character error rate vs. ground truth per row

# NOTE(review): index=False is not passed, so the DataFrame index is written
# as an extra leading column — presumably intentional; confirm downstream use.
sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "sample_idefics_v1.csv")