"""Compare a baseline TrOCR model with fine-tuned checkpoints on test samples.

The script loads the baseline handwritten TrOCR model and six fine-tuned
checkpoints, runs each of them on the first ``SAMPLE_SIZE`` labelled rows of
the test CSV, and writes the decoded text and the per-sample character error
rate (CER) of every model to ``inference_results.csv``.
"""
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import pandas as pd
from PIL import Image
from torchmetrics.text import CharErrorRate

# Fine-tuned checkpoints, in the order they are numbered in the output columns
# (Finetune_1 ... Finetune_6 / cer_1 ... cer_6).
FINETUNED_CHECKPOINTS = [
    "hadrakey/alphapen_new_large_1",
    "hadrakey/alphapen_new_large_15000",
    "hadrakey/alphapen_new_large_30000",
    "hadrakey/alphapen_new_large_45000",
    "hadrakey/alphapen_new_large_60000",
    "hadrakey/alphapen_new_large_70000",
]
BASELINE_CHECKPOINT = "microsoft/trocr-base-handwritten"

DATASET_DIR = "/mnt/data1/Datasets/AlphaPen/"
IMAGE_PREFIX = "final_cropped_rotated_"
SAMPLE_SIZE = 50

# Finetuned models
finetuned_models = [
    VisionEncoderDecoderModel.from_pretrained(checkpoint)
    for checkpoint in FINETUNED_CHECKPOINTS
]

# Baseline
model_base = VisionEncoderDecoderModel.from_pretrained(BASELINE_CHECKPOINT)
# NOTE(review): the processor of the baseline checkpoint is used for every
# model, including the fine-tuned ones -- confirm they share the same image
# preprocessing and tokenizer.
processor = TrOCRProcessor.from_pretrained(BASELINE_CHECKPOINT)

# Checked labels
df_path = DATASET_DIR + "testing_data.csv"
data = pd.read_csv(df_path)
data.dropna(inplace=True)
# Keeps the old index as an "index" column, which ends up in the output CSV.
data.reset_index(inplace=True)
# Explicit copy: new columns are added to `sample` below, and assigning to a
# bare slice of `data` triggers pandas' SettingWithCopyWarning.
sample = data.iloc[:SAMPLE_SIZE, :].copy()

root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"

# One entry per evaluated model: (prediction column, CER column, model).
runs = [("Baseline", "cer_base", model_base)] + [
    (f"Finetune_{number}", f"cer_{number}", model)
    for number, model in enumerate(finetuned_models, start=1)
]

predictions = {text_column: [] for text_column, _, _ in runs}
error_rates = {cer_column: [] for _, cer_column, _ in runs}

cer_metric = CharErrorRate()

for filename, label in zip(sample["filename"], sample["text"]):
    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(root_dir + IMAGE_PREFIX + filename) as raw_image:
        image = raw_image.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    reference = label.lower()

    for text_column, cer_column, model in runs:
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        predictions[text_column].append(generated_text)
        # CER is computed case-insensitively; the stored text keeps its case.
        error_rates[cer_column].append(
            cer_metric(generated_text.lower(), reference).detach().numpy()
        )

# Column order of the output file: Baseline, Finetune_1..6, cer_1..6, cer_base.
for text_column, _, _ in runs:
    sample[text_column] = predictions[text_column]
for _, cer_column, _ in runs[1:] + runs[:1]:
    sample[cer_column] = error_rates[cer_column]

sample.to_csv(DATASET_DIR + "inference_results.csv")