# NOTE: the original import order is kept on purpose; only `import PIL.Image`
# was added so that `PIL.Image` is guaranteed to be loaded.
import streamlit as st
import PIL
import PIL.Image
import cv2
import numpy as np
import pandas as pd
import torch
import io
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from paddleocr import PaddleOCR
import pytesseract
from pytesseract import Output

import postprocess


ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
# NOTE(review): force_reload=True asks torch.hub to refresh its cached copy of
# the repository on every load; consider whether that is wanted at each start.
detection_model = torch.hub.load('ultralytics/yolov5', 'custom',
                                 'weights/detection_wts.pt', force_reload=True)
structure_model = torch.hub.load('ultralytics/yolov5', 'custom',
                                 'weights/structure_wts.pt', force_reload=True)
imgsz = 640

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
# Class name -> class index (position in structure_class_names).
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
# Minimum score per class; visualize_structure() skips boxes below it and the
# mapping is also handed to postprocess.objects_to_cells().
structure_class_thresholds = {
    "table": 0.42,
    "table column": 0.56,
    "table row": 0.5,
    "table column header": 0.38,
    "table projected row header": 0.27,
    "table spanning cell": 0.4,
    "no object": 10
}


def PIL_to_cv(pil_img):
    """Return *pil_img* as a numpy array with RGB channels reordered to BGR."""
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    """Return the BGR array *cv_img* as an RGB PIL image."""
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def _xywhn_to_xyxy(result, width, height):
    """Convert one prediction row to integer pixel corners (x1, y1, x2, y2).

    The first four entries of *result* are read as (center_x, center_y, w, h)
    expressed as fractions of the image size, which is how the rows taken
    from ``pred.xywhn`` are used throughout this file.
    """
    center_x, center_y = result[0], result[1]
    box_w, box_h = result[2], result[3]
    x1 = int((center_x - box_w / 2) * width)
    y1 = int((center_y - box_h / 2) * height)
    x2 = int((center_x + box_w / 2) * width)
    y2 = int((center_y + box_h / 2) * height)
    return x1, y1, x2, y2


def table_detection(pil_img):
    """Run the detection model on *pil_img*.

    Returns the first entry of the model's ``xywhn`` output as a numpy array.
    """
    image = PIL_to_cv(pil_img)
    pred = detection_model(image, size=imgsz)
    return pred.xywhn[0].cpu().numpy()


def table_structure(pil_img):
    """Run the structure model on *pil_img*.

    Returns the first entry of the model's ``xywhn`` output as a numpy array.
    """
    image = PIL_to_cv(pil_img)
    pred = structure_model(image, size=imgsz)
    return pred.xywhn[0].cpu().numpy()


def crop_image(pil_img, detection_result, padding=30):
    """Crop every detected table out of *pil_img*.

    Args:
        pil_img: page image.
        detection_result: rows as returned by table_detection(); index 5 is
            read as the class id.
        padding: pixels added on each side of a box before clamping it to the
            image bounds.

    Returns:
        ``(crop_images, annotated)``: the list of cropped PIL images (class 1,
        'table rotated', is rotated by 270 degrees) and a copy of the page
        with every padded box outlined.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]
    # Outlines are drawn on a separate copy so that a later crop never
    # contains the outline of a previously processed, nearby table.
    annotated = image.copy()

    crop_images = []
    for result in detection_result:
        class_id = int(result[5])
        x1, y1, x2, y2 = _xywhn_to_xyxy(result, width, height)
        x1 = max(0, x1 - padding)
        y1 = max(0, y1 - padding)
        x2 = min(width, x2 + padding)
        y2 = min(height, y2 + padding)

        cropped = cv_to_PIL(image[y1:y2, x1:x2, :])
        if class_id == 1:  # table rotated
            cropped = cropped.rotate(270, expand=True)
        crop_images.append(cropped)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color=(0, 0, 255))

    return crop_images, cv_to_PIL(annotated)


def ocr(pil_img):
    """Run PaddleOCR on *pil_img*.

    Returns a list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, where
    the bbox is the axis-aligned extent of the points reported for each text
    line. Returns an empty list when the engine reports nothing for the image.
    """
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)
    # The per-image entry can be None/empty when no text is found; iterating
    # it unguarded would raise.
    if not result or not result[0]:
        return []

    ocr_res = []
    for points, (text, _score) in result[0]:
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        ocr_res.append({
            'bbox': [min(xs), min(ys), max(xs), max(ys)],
            'text': text
        })
    return ocr_res


def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure predictions into table cells via ``postprocess``.

    Args:
        page_tokens: dicts with at least a ``'bbox'`` key (e.g. from ocr()).
        pil_img: the image the predictions refer to.
        structure_result: rows as returned by table_structure(); index 4 is
            read as the score and index 5 as the class id.

    Returns:
        ``(table_structures, cells, confidence_score)`` exactly as returned by
        ``postprocess.objects_to_cells``.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]

    table_objects = []
    for result in structure_result:
        table_objects.append({
            'bbox': list(_xywhn_to_xyxy(result, width, height)),
            'score': float(result[4]),
            'label': int(result[5])
        })
    table = {'objects': table_objects, 'page_num': 0}

    table_label = structure_class_map['table']
    table_class_objects = [obj for obj in table_objects
                           if obj['label'] == table_label]
    if table_class_objects:
        # Highest-scoring 'table' box; the first one wins on ties.
        best_table = max(table_class_objects, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No 'table' object predicted: keep the original fixed fallback box.
        # NOTE(review): tokens beyond 1000 px are dropped in this case.
        table_bbox = (0, 0, 1000, 1000)

    tokens_in_table = [token for token in page_tokens
                       if postprocess.iob(token['bbox'], table_bbox) >= 0.5]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table,
        structure_class_names, structure_class_thresholds)

    return table_structures, cells, confidence_score


def visualize_ocr(pil_img, ocr_result):
    """Return a copy of *pil_img* with each OCR box and its text drawn on it."""
    image = PIL_to_cv(pil_img)
    for res in ocr_result:
        x1, y1, x2, y2 = (int(v) for v in res['bbox'][:4])
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
        cv2.putText(image, res['text'], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.25, color=(255, 0, 0))
    return cv_to_PIL(image)


def get_bbox_decorations(data_type, label):
    """Return ``(color, alpha, linewidth, hatch)`` for a class *label*.

    Label 0 is styled differently when *data_type* is ``'detection'``;
    unknown labels fall back to an invisible gray style.
    """
    if label == 0:
        if data_type == 'detection':
            return 'brown', 0.05, 3, '//'
        return 'brown', 0, 3, None

    decorations = {
        1: ('red', 0.15, 2, None),
        2: ('blue', 0.15, 2, None),
        3: ('magenta', 0.2, 3, '//'),
        4: ('cyan', 0.2, 4, '//'),
        5: ('green', 0.2, 4, '\\\\'),
    }
    return decorations.get(label, ('gray', 0, 0, None))


def _figure_to_pil(fig):
    """Save *fig* into an in-memory buffer and open it as a PIL image."""
    img_buf = io.BytesIO()
    fig.savefig(img_buf, bbox_inches='tight', dpi=1000)
    img_buf.seek(0)
    return PIL.Image.open(img_buf)


def visualize_structure(pil_img, structure_result):
    """Render the structure predictions that pass their class threshold.

    Each kept box is drawn as three overlaid rectangles (fill, hatch, edge)
    styled by get_bbox_decorations(). Returns the rendering as a PIL image.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(pil_img, interpolation='lanczos')

        for result in structure_result:
            class_id = int(result[5])
            score = float(result[4])
            x1, y1, x2, y2 = _xywhn_to_xyxy(result, width, height)
            threshold = structure_class_thresholds[structure_class_names[class_id]]
            if score < threshold:
                continue

            origin = [x1, y1]
            box_w = x2 - x1
            box_h = y2 - y1
            color, alpha, linewidth, hatch = get_bbox_decorations(
                'recognition', class_id)
            # Fill
            ax.add_patch(patches.Rectangle(
                origin, box_w, box_h, linewidth=linewidth, alpha=alpha,
                edgecolor='none', facecolor=color, linestyle=None))
            # Hatch
            ax.add_patch(patches.Rectangle(
                origin, box_w, box_h, linewidth=1, alpha=0.4,
                edgecolor=color, facecolor='none', linestyle='--',
                hatch=hatch))
            # Edge
            ax.add_patch(patches.Rectangle(
                origin, box_w, box_h, linewidth=linewidth,
                edgecolor=color, facecolor='none', linestyle="--"))

        ax.axis('off')
        return _figure_to_pil(fig)
    finally:
        # Figures created through pyplot stay registered until closed.
        plt.close(fig)


def visualize_cells(pil_img, cells):
    """Render *cells* on top of *pil_img*; header cells get a stronger fill.

    Each cell dict must provide ``'bbox'`` and ``'header'``. Returns the
    rendering as a PIL image.
    """
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(pil_img, interpolation='lanczos')

        for cell in cells:
            bbox = cell['bbox']
            origin = bbox[:2]
            box_w = bbox[2] - bbox[0]
            box_h = bbox[3] - bbox[1]
            alpha = 0.3 if cell['header'] else 0.125

            # Fill
            ax.add_patch(patches.Rectangle(
                origin, box_w, box_h, linewidth=1,
                edgecolor='none', facecolor="magenta", alpha=alpha))
            # Hatch
            ax.add_patch(patches.Rectangle(
                origin, box_w, box_h, linewidth=1,
                edgecolor="magenta", facecolor='none', linestyle="--",
                alpha=0.08, hatch='///'))
            # Edge
            ax.add_patch(patches.Rectangle(
                origin, box_w, box_h, linewidth=1,
                edgecolor="magenta", facecolor='none', linestyle="--"))

        ax.axis('off')
        return _figure_to_pil(fig)
    finally:
        # Figures created through pyplot stay registered until closed.
        plt.close(fig)


def pytess(cell_pil_img):
    """Run Tesseract on *cell_pil_img* and return the words joined by spaces."""
    data = pytesseract.image_to_data(
        cell_pil_img, output_type=Output.DICT,
        config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')
    return ' '.join(data['text']).strip()


def resize(pil_img, size=1800):
    """Upscale *pil_img* so its width is at least *size* pixels.

    The image is never shrunk (the factor is at least 1).

    Returns:
        ``(resized_image, factor)``.
    """
    length_x, width_y = pil_img.size
    factor = max(1, size / length_x)
    new_size = int(factor * length_x), int(factor * width_y)
    # LANCZOS is the filter formerly exposed as ANTIALIAS, which newer Pillow
    # releases no longer provide.
    pil_img = pil_img.resize(new_size, PIL.Image.LANCZOS)
    return pil_img, factor


def image_smoothening(img):
    """Binarize *img* with a fixed threshold followed by two Otsu passes."""
    _, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    _, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    blur = cv2.GaussianBlur(th2, (1, 1), 0)
    _, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return th3


def remove_noise_and_smooth(pil_img):
    """Return a binarized version of *pil_img* as a PIL image.

    Combines (bitwise OR) an adaptive-threshold result that went through
    morphological open/close with the output of image_smoothening().
    """
    img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255,
                                     cv2.ADAPTIVE_THRESH_MEAN_C,
                                     cv2.THRESH_BINARY, 41, 3)
    kernel = np.ones((1, 1), np.uint8)
    opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
    img = image_smoothening(img)
    or_image = cv2.bitwise_or(img, closing)
    return PIL.Image.fromarray(or_image)


# def extract_text_from_cells(pil_img, cells):
#     pil_img, factor = resize(pil_img)
#     #pil_img = remove_noise_and_smooth(pil_img)
#     #display(pil_img)
#     for cell in cells:
#         bbox = [x * factor for x in cell['bbox']]
#         cell_pil_img = pil_img.crop(bbox)
#         #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
#         #cell_pil_img = tess_prep(cell_pil_img)
#         cell['cell text'] = pytess(cell_pil_img)
#     return cells


def extract_text_from_cells(cells, sep=' '):
    """Store the concatenated span texts of each cell under ``'cell_text'``.

    Every span text is followed by *sep*, so a non-empty result ends with
    *sep*. Spans without a ``'text'`` key are skipped. *cells* is modified in
    place and also returned.
    """
    for cell in cells:
        cell['cell_text'] = ''.join(
            span['text'] + sep for span in cell['spans'] if 'text' in span)
    return cells


def cells_to_csv(cells):
    """Build a DataFrame and its CSV text from table *cells*.

    Each cell needs ``'row_nums'``, ``'column_nums'``, ``'header'`` and
    ``'cell_text'``. Header rows are flattened per column into one
    ``' | '``-joined label with duplicates removed.

    Returns:
        ``(df, csv_text)``, or ``None`` when *cells* is empty.
    """
    if len(cells) == 0:
        return None

    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1

    header_rows = [max(cell['row_nums']) for cell in cells if cell['header']]
    max_header_row = max(header_rows, default=-1)

    table_array = np.empty([num_rows, num_columns], dtype='object')
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row + 1, :]
    flattened_header = []
    for col in header.transpose():
        # Grid positions not covered by any cell are None; joining them would
        # raise, so they are left out of the column label.
        parts = OrderedDict.fromkeys(text for text in col if text is not None)
        flattened_header.append(' | '.join(parts))

    df = pd.DataFrame(table_array[max_header_row + 1:, :], index=None,
                      columns=flattened_header)
    return df, df.to_csv(index=None)


def cells_to_html(cells):
    """Serialize *cells* as an HTML ``<table>`` string.

    Cells are ordered by their first row, then first column. A new ``<tr>``
    is opened whenever a cell starts on a later row than any seen so far; the
    first cell of that row decides whether the row's cells are emitted as
    ``<th>`` (header) or ``<td>``.
    """
    ordered = sorted(
        cells, key=lambda c: (min(c['row_nums']), min(c['column_nums'])))

    table = ET.Element('table')
    current_row = -1

    for cell in ordered:
        this_row = min(cell['row_nums'])

        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)

        if this_row > current_row:
            current_row = this_row
            cell_tag = 'th' if cell['header'] else 'td'
            row = ET.SubElement(table, 'tr')

        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    return str(ET.tostring(table, encoding='unicode',
                           short_empty_elements=False))


# def cells_to_html(cells):
#     for cell in cells:
#         cell['column_nums'].sort()
#         cell['row_nums'].sort()
#     n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
#     n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
#     html_code = ''
#     for r in range(n_rows):
#         r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
#         r_cells.sort(key=lambda x: x['column_nums'][0])
#         r_html = ''
#         for cell in r_cells:
#             rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
#             colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
#             r_html += f'