Arnold / microsofttt.py
Libroru's picture
Upload 15 files
612b7f5 verified
raw
history blame contribute delete
No virus
4.61 kB
"""
Microsoft Table Transformer Extension
By Neils:
https://docs.llamaindex.ai/en/stable/examples/multi_modal/multi_modal_pdf_tables.html#experiment-3-let-s-use-microsoft-table-transformer-to-crop-tables-from-the-images-and-see-if-it-gives-the-correct-answer
"""
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
import numpy as np
import csv
import pandas as pd
from torchvision import transforms
from transformers import AutoModelForObjectDetection
import torch
import openai
import os
import fitz
# Run the models on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class MaxResize:
    """Resize an image so that its longer side equals ``max_size``.

    Both sides are multiplied by the same factor, so the aspect ratio is
    preserved up to rounding. Smaller images are scaled up, larger ones are
    scaled down.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image.resize(...)`` with the longer side set to ``max_size``.

        ``image`` must expose ``size`` as ``(width, height)`` and a
        ``resize((width, height))`` method, as PIL images do.
        """
        width, height = image.size
        scale = self.max_size / max(width, height)
        # Clamp to one pixel: on very elongated images the shorter side would
        # otherwise round down to 0, which is not a valid target size.
        new_width = max(1, int(round(scale * width)))
        new_height = max(1, int(round(scale * height)))
        return image.resize((new_width, new_height))
# Preprocessing for the table detection model: scale the longer image side to
# 800 px, convert to a tensor, then normalize the three channels with the
# given per-channel mean and standard deviation.
detection_transform = transforms.Compose([
    MaxResize(800),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Preprocessing for the table structure recognition model: same steps as the
# detection pipeline, but the longer image side is scaled to 1000 px.
structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Table detection model. The checkpoint is loaded first and then moved to the
# selected device. Inputs are prepared with detection_transform rather than a
# Hugging Face image processor.
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection",
    revision="no_timm",
)
model = model.to(device)
# Table structure recognition model. The checkpoint is loaded first and then
# moved to the selected device. Inputs are meant to be prepared with
# structure_transform rather than a Hugging Face image processor.
structure_model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition-v1.1-all",
)
structure_model = structure_model.to(device)
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from center/size form to corner form.

    Args:
        x: Tensor whose last dimension has size 4 and holds
            ``(center_x, center_y, width, height)``. Any number of leading
            dimensions is accepted, including none (a single box).

    Returns:
        Tensor of the same shape holding ``(x_min, y_min, x_max, y_max)``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    # Stack along the last axis (the one that was unbound) so the result is
    # also correct for 1-D and batched inputs, not only for shape (N, 4).
    return torch.stack(corners, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format scaled by the image size.

    Args:
        out_bbox: Tensor of boxes as ``(center_x, center_y, width, height)``.
            The values are multiplied by the image width/height, so they are
            expected to be expressed relative to the image size.
        size: ``(width, height)`` of the image.

    Returns:
        Tensor of boxes as ``(x_min, y_min, x_max, y_max)`` where x values are
        multiplied by ``width`` and y values by ``height``.
    """
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale factors on the same device as the boxes so the
    # multiplication also works for tensors that have not been moved to the
    # CPU. The dtype stays float32 to keep the result dtype unchanged.
    scale = torch.tensor(
        [width, height, width, height],
        dtype=torch.float32,
        device=boxes.device,
    )
    return boxes * scale
def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw detection outputs into a list of labelled boxes.

    Only the first item of the batch is read. For every prediction the class
    with the highest softmax probability is selected; predictions whose class
    name is "no object" are dropped.

    Args:
        outputs: Model output exposing ``logits`` and ``["pred_boxes"]``.
        img_size: ``(width, height)`` passed on to ``rescale_bboxes``.
        id2label: Mapping from integer class index to class name.

    Returns:
        List of dicts with keys ``"label"`` (str), ``"score"`` (float) and
        ``"bbox"`` (list of four floats as returned by ``rescale_bboxes``).
    """
    best = outputs.logits.softmax(-1).max(-1)
    labels = best.indices.detach().cpu()[0].tolist()
    scores = best.values.detach().cpu()[0].tolist()
    raw_boxes = outputs["pred_boxes"].detach().cpu()[0]
    boxes = rescale_bboxes(raw_boxes, img_size).tolist()

    objects = []
    for label, score, box in zip(labels, scores, boxes):
        class_label = id2label[label]
        if class_label == "no object":
            continue
        objects.append(
            {
                "label": class_label,
                "score": score,
                "bbox": box,
            }
        )
    return objects
def detect_and_crop_save_table(
    file_path, cropped_table_directory="./table_images/"
):
    """Detect tables in an image file and save each one as a cropped PNG.

    Args:
        file_path: Path of the image to analyse.
        cropped_table_directory: Directory that receives the crops; it is
            created when missing. Crops are named
            ``<image name without extension>_<index>.png``.

    Returns:
        None. The number of detected tables is printed.
    """
    # Keep the file open only for as long as the pixels are needed.
    with Image.open(file_path) as image:
        # basename handles the platform's path separators, unlike split("/").
        filename, _ = os.path.splitext(os.path.basename(file_path))
        os.makedirs(cropped_table_directory, exist_ok=True)

        # detection_transform normalizes exactly three channels, so the model
        # is fed an RGB copy; crops are still taken from the original image.
        pixel_values = (
            detection_transform(image.convert("RGB")).unsqueeze(0).to(device)
        )

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        # Copy the label map before adding the extra "no object" class.
        # Writing into model.config.id2label directly would add one more
        # entry on every call.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"
        detected_tables = outputs_to_objects(outputs, image.size, id2label)
        print(f"number of tables detected {len(detected_tables)}")

        for idx, table in enumerate(detected_tables):
            cropped_table = image.crop(table["bbox"])
            # os.path.join keeps absolute output directories intact; a
            # hard-coded "./" prefix would turn them into relative paths.
            cropped_table.save(
                os.path.join(cropped_table_directory, f"{filename}_{idx}.png")
            )
def plot_images(image_paths):
    """Draw up to six images from ``image_paths`` on a 2 x 3 subplot grid.

    Entries that are not existing files are skipped. A new 16 x 9 figure is
    created; this function does not show or save it.

    Args:
        image_paths: Iterable of image file paths.
    """
    rows, cols = 2, 3
    capacity = rows * cols
    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths:
        if not os.path.isfile(img_path):
            continue
        # Close the file handle as soon as the image has been drawn.
        with Image.open(img_path) as image:
            plt.subplot(rows, cols, images_shown + 1)
            plt.imshow(image)
        plt.xticks([])
        plt.yticks([])
        images_shown += 1
        # The grid has rows * cols slots; requesting a subplot index beyond
        # that raises, so stop once the grid is full.
        if images_shown >= capacity:
            break