|
""" |
|
Microsoft Table Transformer Extension |
|
By Niels:
|
https://docs.llamaindex.ai/en/stable/examples/multi_modal/multi_modal_pdf_tables.html#experiment-3-let-s-use-microsoft-table-transformer-to-crop-tables-from-the-images-and-see-if-it-gives-the-correct-answer |
|
""" |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as patches |
|
from matplotlib.patches import Patch |
|
import io |
|
from PIL import Image, ImageDraw |
|
import numpy as np |
|
import csv |
|
import pandas as pd |
|
|
|
from torchvision import transforms |
|
|
|
from transformers import AutoModelForObjectDetection |
|
import torch |
|
import openai |
|
import os |
|
import fitz |
|
|
|
# Run on GPU when available; the models and tensors below are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
class MaxResize(object):
    """Callable transform that rescales an image so its longer side is ``max_size``.

    Aspect ratio is preserved. The scale factor is applied unconditionally,
    so images whose longer side is below ``max_size`` are scaled *up* too.
    """

    def __init__(self, max_size=800):
        # Target length, in pixels, of the longer image side.
        self.max_size = max_size

    def __call__(self, image):
        # PIL reports size as (width, height).
        w, h = image.size
        factor = self.max_size / max(w, h)
        return image.resize((int(round(w * factor)), int(round(h * factor))))
|
|
|
|
|
# Preprocessing for the table *detection* model: cap the longer image side at
# 800 px, convert to a CHW float tensor, and normalize with the standard
# ImageNet channel mean/std.
detection_transform = transforms.Compose(
    [
        MaxResize(800),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Preprocessing for the table *structure* model; identical except the longer
# side is capped at 1000 px.
structure_transform = transforms.Compose(
    [
        MaxResize(1000),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
|
|
|
|
|
|
|
# Table *detection* model: locates table bounding boxes in a page image.
# The "no_timm" revision avoids the optional timm backbone dependency.
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection", revision="no_timm"
).to(device)


# Table *structure* model (rows/columns/cells inside a cropped table).
# Loaded here but not referenced in this chunk — presumably used by code
# elsewhere in the file; verify before removing.
structure_model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all"
).to(device)
|
|
|
|
|
|
|
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner format.

    Parameters
    ----------
    x : torch.Tensor
        Tensor whose last dimension holds (cx, cy, w, h). Any leading batch
        dimensions are preserved; a bare 1-D box of shape (4,) also works.

    Returns
    -------
    torch.Tensor
        Same shape as ``x`` with (x1, y1, x2, y2) in the last dimension.
    """
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    # Stack on the last dim (was dim=1, which only handled rank-2 input);
    # identical for the [N, 4] tensors used elsewhere in this file.
    return torch.stack(b, dim=-1)
|
|
|
|
|
def rescale_bboxes(out_bbox, size):
    """Map normalized center-format boxes to absolute corner-format pixel boxes.

    ``out_bbox`` is an [N, 4] tensor of (cx, cy, w, h) values normalized to
    [0, 1]; ``size`` is the image's (width, height). Returns an [N, 4] tensor
    of (x1, y1, x2, y2) pixel coordinates.
    """
    img_w, img_h = size
    # Scale x-coordinates by width and y-coordinates by height.
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale
|
|
|
|
|
def outputs_to_objects(outputs, img_size, id2label):
    """Convert raw detection outputs into labeled, scored pixel-space boxes.

    Parameters
    ----------
    outputs : detection model output exposing ``logits`` and ``pred_boxes``
        (only the first batch element is read).
    img_size : (width, height) of the source image, used to rescale boxes.
    id2label : mapping from class index to label name; predictions whose
        label is "no object" are dropped.

    Returns
    -------
    list of dicts with keys "label" (str), "score" (float), and
    "bbox" ([x1, y1, x2, y2] floats).
    """
    # Highest-probability class and its confidence for each query slot.
    best = outputs.logits.softmax(-1).max(-1)
    labels = best.indices.detach().cpu().numpy()[0]
    scores = best.values.detach().cpu().numpy()[0]
    boxes = rescale_bboxes(outputs["pred_boxes"].detach().cpu()[0], img_size)

    return [
        {
            "label": id2label[int(lbl)],
            "score": float(conf),
            "bbox": [float(coord) for coord in box.tolist()],
        }
        for lbl, conf, box in zip(labels, scores, boxes)
        if id2label[int(lbl)] != "no object"
    ]
|
|
|
|
|
def detect_and_crop_save_table(
    file_path, cropped_table_directory="./table_images/"
):
    """Detect tables in an image and save each detection as a cropped PNG.

    Parameters
    ----------
    file_path : str
        Path to the input image to scan for tables.
    cropped_table_directory : str
        Directory the crops are written to; created if missing. Each crop is
        saved as ``<input stem>_<index>.png``.
    """
    image = Image.open(file_path)

    # Stem of the input file, used to name the cropped outputs.
    # (Bug fix: this was computed but the save path used a literal
    # "(unknown)" placeholder instead of the stem.)
    filename, _ = os.path.splitext(os.path.basename(file_path))

    os.makedirs(cropped_table_directory, exist_ok=True)

    # Preprocess and run the detection model without tracking gradients.
    pixel_values = detection_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(pixel_values)

    # Work on a copy: mutating model.config.id2label in place would append
    # one extra "no object" entry on every call to this function.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    print(f"number of tables detected {len(detected_tables)}")

    for idx, table in enumerate(detected_tables):
        cropped_table = image.crop(table["bbox"])
        cropped_table.save(
            os.path.join(cropped_table_directory, f"{filename}_{idx}.png")
        )
|
|
|
|
|
def plot_images(image_paths):
    """Display up to nine images from ``image_paths`` in a 3x3 grid.

    Paths that are not existing files are silently skipped. The figure is
    created but not shown; call ``plt.show()`` afterwards if needed.

    Parameters
    ----------
    image_paths : iterable of str
        Candidate image file paths; at most the first nine existing files
        are plotted.
    """
    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths:
        if os.path.isfile(img_path):
            image = Image.open(img_path)

            # 3x3 grid: the loop allows up to 9 images, and the previous
            # 2x3 grid raised ValueError on the seventh subplot.
            plt.subplot(3, 3, images_shown + 1)
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])

            images_shown += 1
            if images_shown >= 9:
                break