|
from pathlib import Path

import pandas as pd
from PIL import Image
from torchmetrics.text import CharErrorRate
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
|
|
|
# Evaluation script: run the baseline TrOCR checkpoint and six fine-tuned
# checkpoints over the first SAMPLE_SIZE rows of the test CSV, score every
# prediction with a per-image character error rate (CER) against the reference
# text, and write predictions and scores to a results CSV.

# One entry per model: (prediction column, CER column, Hugging Face checkpoint).
# The baseline must stay first; the result-writing code below relies on it.
MODEL_SPECS = [
    ("Baseline", "cer_base", "microsoft/trocr-base-handwritten"),
    ("Finetune_1", "cer_1", "hadrakey/alphapen_new_large_1"),
    ("Finetune_2", "cer_2", "hadrakey/alphapen_new_large_15000"),
    ("Finetune_3", "cer_3", "hadrakey/alphapen_new_large_30000"),
    ("Finetune_4", "cer_4", "hadrakey/alphapen_new_large_45000"),
    ("Finetune_5", "cer_5", "hadrakey/alphapen_new_large_60000"),
    ("Finetune_6", "cer_6", "hadrakey/alphapen_new_large_70000"),
]

# Number of leading rows of the cleaned test set that are evaluated.
SAMPLE_SIZE = 50

# NOTE(review): the CSV lives under ".../AlphaPen/" while the images live under
# ".../OCR/Alphapen/..." (different casing and parent) - confirm both are
# intentional.
dataset_dir = Path("/mnt/data1/Datasets/AlphaPen/")
root_dir = Path("/mnt/data1/Datasets/OCR/Alphapen/clean_data/")

models = {
    pred_col: VisionEncoderDecoderModel.from_pretrained(checkpoint)
    for pred_col, _, checkpoint in MODEL_SPECS
}

# A single processor (image preprocessing + token decoding) is shared by all
# models.
# NOTE(review): this is the *base* handwritten processor, yet the fine-tuned
# checkpoints are named "large" - confirm that they share preprocessing and
# tokenizer settings.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

df_path = dataset_dir / "testing_data.csv"
data = pd.read_csv(df_path)
data.dropna(inplace=True)
# Keeps the pre-dropna row labels as an "index" column, as before.
data.reset_index(inplace=True)
# Explicit copy: new columns are added to `sample` below, and assigning into a
# slice of `data` would trigger pandas' SettingWithCopyWarning.
sample = data.iloc[:SAMPLE_SIZE, :].copy()

cer_metric = CharErrorRate()

predictions = {pred_col: [] for pred_col, _, _ in MODEL_SPECS}
cer_scores = {cer_col: [] for _, cer_col, _ in MODEL_SPECS}

for filename, reference in zip(sample["filename"], sample["text"]):
    # Close the underlying file as soon as the RGB copy has been made.
    with Image.open(root_dir / f"final_cropped_rotated_{filename}") as raw_image:
        image = raw_image.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values
    # CER is computed case-insensitively.
    target = reference.lower()

    for pred_col, cer_col, _ in MODEL_SPECS:
        generated_ids = models[pred_col].generate(pixel_values)
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        predictions[pred_col].append(generated_text)
        # .item() stores a plain Python float instead of a 0-d array.
        cer_scores[cer_col].append(
            cer_metric(generated_text.lower(), target).item()
        )

for pred_col, _, _ in MODEL_SPECS:
    sample[pred_col] = predictions[pred_col]

# CER columns keep the established layout: fine-tuned checkpoints first,
# baseline last.
for _, cer_col, _ in [*MODEL_SPECS[1:], MODEL_SPECS[0]]:
    sample[cer_col] = cer_scores[cer_col]

sample.to_csv(dataset_dir / "inference_results.csv")