import glob
import json

import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset


def prepare_data_frame(json_path):
    """Build a (filename, text) DataFrame from a single annotation JSON file."""
    with open(json_path) as f:
        d = json.load(f)
    filename = [d[i]["word_id"] + ".png" for i in range(len(d))]
    text = [d[i]["text"] for i in range(len(d))]
    data = {"filename": filename, "text": text}
    df = pd.DataFrame(data=data)
    return df
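
# Expected annotation format, inferred from the keys read above: a JSON list of
# records such as [{"word_id": "<id>", "text": "<transcription>"}, ...].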


class AlphaPenDataset(Dataset):
    """Image-to-text dataset for TrOCR-style processors.

    Returns pixel_values plus label ids with padding replaced by -100.
    """

    def __init__(self, root_dir, df, processor, transform=None, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # iloc tolerates non-contiguous DataFrame indices (e.g. after concat/split).
        file_name = self.df["filename"].iloc[idx]
        text = self.df["text"].iloc[idx]

        image = Image.open(self.root_dir + file_name).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
            # Persist the transformed image for visual inspection.
            img = transforms.ToPILImage()(image)
            img.save("/mnt/data1/Datasets/AlphaPen/transformed_images/" + file_name)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
        ).input_ids
        # Mask pad tokens with -100 so the loss ignores them.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100
                  for label in labels]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

    def prepare_data(self, path_json):
        # Companion helper to prepare_data_frame for annotations keyed by "image_id".
        with open(path_json) as f:
            d = json.load(f)
        filename = [d[i]["image_id"] + ".png" for i in range(len(d))]
        text = [d[i]["text"] for i in range(len(d))]
        return filename, text
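

# Example usage, as a minimal sketch (the checkpoint name below is an
# assumption, not something this file pins down):
#
#   from transformers import TrOCRProcessor
#   processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
#   df = prepare_data_frame("/path/to/annotations.json")
#   dataset = AlphaPenDataset(root_dir="/path/to/images/", df=df, processor=processor)
#   sample = dataset[0]  # {"pixel_values": Tensor, "labels": Tensor}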


class AlphaPenPhi3Dataset(Dataset):
    """Dataset pairing a Phi-3-style chat prompt with a resized image tensor."""

    def __init__(self, root_dir, dataframe, tokenizer, max_length, image_size):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "left"
        self.max_length = max_length
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        text = f"<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\n {row['text']} <|end|>"
        image_path = self.root_dir + row["filename"]

        encodings = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_length)

        try:
            image = Image.open(image_path).convert("RGB")
            image = self.image_transform_function(image)
        except (FileNotFoundError, IOError):
            # Missing or unreadable image: signal the collate_fn to skip this sample.
            return None

        labels = self.tokenizer(
            row["text"],
            padding="max_length",
            max_length=self.max_length,
        ).input_ids
        # Mask pad tokens with -100 so the loss ignores them.
        labels = [label if label != self.tokenizer.pad_token_id else -100 for label in labels]
        encodings["pixel_values"] = image
        encodings["labels"] = labels

        # as_tensor leaves pixel_values (already a tensor) as-is instead of copying it.
        return {key: torch.as_tensor(val) for key, val in encodings.items()}

    def image_transform_function(self, image):
        return self.transform(image)
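

# Minimal collate_fn sketch (not part of the original file): AlphaPenPhi3Dataset
# returns None for unreadable images, and PyTorch's default collate would fail
# on None, so drop those samples before stacking, e.g.
# DataLoader(dataset, batch_size=8, collate_fn=collate_fn_skip_missing).
def collate_fn_skip_missing(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return {}
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0]}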


if __name__ == "__main__":
    json_path = "/mnt/data1/Datasets/OCR/Alphapen/label_check/"
    json_path_b2 = "/mnt/data1/Datasets/OCR/Alphapen/DataBatch2/label_check/"
    root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
    root_dir_b2 = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/final_cropped_rotated_"
    json_files = glob.glob(json_path + "*.json")
    json_files_b2 = glob.glob(json_path_b2 + "*.json")
    df_list_b1 = [prepare_data_frame(file) for file in json_files]
    df_list_b2 = [prepare_data_frame(file) for file in json_files_b2]

    # ignore_index avoids duplicate row labels across the concatenated frames.
    df_b1 = pd.concat(df_list_b1, ignore_index=True)
    df_b2 = pd.concat(df_list_b2, ignore_index=True)

    df_b1.to_csv("/mnt/data1/Datasets/AlphaPen/" + "testing_data_b1.csv")
    df_b2.to_csv("/mnt/data1/Datasets/AlphaPen/" + "testing_data_b2.csv")
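
    # Optional validation split (a sketch; the ratio and seed are assumptions,
    # not values from the original script).
    train_df, test_df = train_test_split(df_b1, test_size=0.1, random_state=42)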