import streamlit as st
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule used below
import cv2
import numpy as np
import pandas as pd
import torch
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
from paddleocr import PaddleOCR
import pytesseract
from pytesseract import Output
import postprocess


def _cache_resource(func):
    """Cache *func*'s result across Streamlit reruns when the installed Streamlit supports it.

    Streamlit re-executes the whole script on every interaction, so without a
    cache the OCR engine and both YOLOv5 models (including the forced hub
    re-download) would be rebuilt on every button click. Uses
    ``st.cache_resource`` or, on older releases, ``st.experimental_singleton``;
    falls back to returning *func* uncached if neither exists. The spinner is
    disabled so no page element is emitted before ``st.set_page_config``.
    Note: cached objects are shared by all sessions of the app.
    """
    cacher = getattr(st, 'cache_resource', None) or getattr(st, 'experimental_singleton', None)
    if cacher is None:
        return func
    return cacher(show_spinner=False)(func)


@_cache_resource
def _load_models():
    """Build the PaddleOCR engine and load the detection and structure YOLOv5 models.

    Returns:
        ``(ocr_engine, detection_model, structure_model)``, loaded in that order.
    """
    ocr_engine = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
    detector = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
    structurer = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
    return ocr_engine, detector, structurer


ocr_instance, detection_model, structure_model = _load_models()

# Inference size passed to both YOLOv5 models.
imgsz = 640

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
# Class name -> class id (index into structure_class_names).
structure_class_map = {name: class_id for class_id, name in enumerate(structure_class_names)}
# Minimum score for a structure object to be kept/drawn. Scores are presumably
# in [0, 1], so the threshold of 10 effectively disables 'no object'.
structure_class_thresholds = {
    'table': 0.5,
    'table column': 0.5,
    'table row': 0.5,
    'table column header': 0.25,
    'table projected row header': 0.25,
    'table spanning cell': 0.25,
    'no object': 10
}


def PIL_to_cv(pil_img):
    """Convert an (assumed RGB) PIL image to an OpenCV BGR ndarray."""
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)


def cv_to_PIL(cv_img):
    """Convert an OpenCV BGR ndarray to an RGB PIL image."""
    return PIL.Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))


def _run_yolo(model, pil_img):
    """Run a YOLOv5 hub *model* on *pil_img* and return its detections as an ndarray.

    Each row is ``[x_center, y_center, width, height, score, class_id]`` with the
    box normalized to the image size (YOLOv5 ``xywhn``).
    """
    # NOTE(review): YOLOv5 AutoShape documents numpy inputs as RGB, but a BGR
    # array is passed here; confirm whether the weights expect this.
    pred = model(PIL_to_cv(pil_img), size=imgsz)
    return pred.xywhn[0].cpu().numpy()


def table_detection(pil_img):
    """Detect tables on a page image; see :func:`_run_yolo` for the row layout."""
    return _run_yolo(detection_model, pil_img)


def table_structure(pil_img):
    """Detect table-structure objects (rows, columns, headers, ...) in a table crop."""
    return _run_yolo(structure_model, pil_img)


def _xywhn_to_xyxy(result, width, height, pad=0.0):
    """Convert a normalized center/size box to integer pixel corners ``(x1, y1, x2, y2)``.

    *pad* is an extra margin, as a fraction of the image size, added on every side.
    The result is not clamped to the image.
    """
    x_c, y_c, w, h = result[:4]
    x1 = int((x_c - w / 2 - pad) * width)
    y1 = int((y_c - h / 2 - pad) * height)
    x2 = int((x_c + w / 2 + pad) * width)
    y2 = int((y_c + h / 2 + pad) * height)
    return x1, y1, x2, y2


def crop_image(pil_img, detection_result, padding=0.02):
    """Crop every detected table out of *pil_img*.

    Args:
        pil_img: Page image (assumed RGB).
        detection_result: Rows as returned by :func:`table_detection`.
        padding: Margin added around each box, as a fraction of the image size.

    Returns:
        ``(crops, visualization)``: one PIL crop per non-empty box, and a copy of
        the page with each cropped box outlined in green.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]
    # Draw on a separate copy so earlier outlines never leak into later,
    # possibly overlapping, crops.
    vis_image = image.copy()
    crop_images = []
    for result in detection_result:
        x1, y1, x2, y2 = _xywhn_to_xyxy(result, width, height, pad=padding)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            continue  # box lies entirely outside the image
        crop_images.append(cv_to_PIL(image[y1:y2, x1:x2, :]))
        cv2.rectangle(vis_image, (x1, y1), (x2, y2), color=(0, 255, 0))
    return crop_images, cv_to_PIL(vis_image)


def ocr(pil_img):
    """Run PaddleOCR on *pil_img*.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, one per
        recognized text line, where bbox is the axis-aligned hull of the detected
        polygon. Returns an empty list when nothing is recognized.
    """
    result = ocr_instance.ocr(PIL_to_cv(pil_img))
    # PaddleOCR may return [None] (or an empty list) when no text is found.
    lines = result[0] if result else None
    if not lines:
        return []
    ocr_res = []
    for points, (text, _score) in lines:
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        ocr_res.append({'bbox': [min(xs), min(ys), max(xs), max(ys)], 'text': text})
    return ocr_res


def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure detections plus OCR tokens into table cells.

    Args:
        page_tokens: OCR tokens as returned by :func:`ocr`.
        pil_img: The table crop the detections refer to.
        structure_result: Rows as returned by :func:`table_structure`.

    Returns:
        ``(table_structures, cells, confidence_score)`` as produced by
        ``postprocess.objects_to_cells``.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]

    table_objects = [
        {
            'bbox': list(_xywhn_to_xyxy(result, width, height)),
            'score': float(result[4]),
            'label': int(result[5]),
        }
        for result in structure_result
    ]
    table = {'objects': table_objects, 'page_num': 0}

    # The highest-scoring 'table' object bounds the table; without one,
    # treat the whole crop as the table.
    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
    if table_class_objects:
        table_bbox = list(max(table_class_objects, key=lambda obj: obj['score'])['bbox'])
    else:
        table_bbox = [0, 0, width, height]

    # Keep tokens that postprocess.iob scores >= 0.5 against the table box
    # (presumably intersection over the token's own box).
    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)
    return table_structures, cells, confidence_score


def _draw_boxes(pil_img, items, color):
    """Return a copy of *pil_img* with each ``item['bbox']`` ([x1, y1, x2, y2]) outlined in *color* (BGR)."""
    image = PIL_to_cv(pil_img)
    for item in items:
        x1, y1, x2, y2 = (int(v) for v in item['bbox'][:4])
        cv2.rectangle(image, (x1, y1), (x2, y2), color=color)
    return cv_to_PIL(image)


def visualize_ocr(pil_img, ocr_result):
    """Outline every OCR token box in green."""
    return _draw_boxes(pil_img, ocr_result, (0, 255, 0))


def visualize_structure(pil_img, structure_result):
    """Outline, in red, every structure object whose score meets its class threshold."""
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]
    for result in structure_result:
        class_id = int(result[5])
        score = float(result[4])
        if score >= structure_class_thresholds[structure_class_names[class_id]]:
            x1, y1, x2, y2 = _xywhn_to_xyxy(result, width, height)
            cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
    return cv_to_PIL(image)


def visualize_cells(pil_img, cells):
    """Outline every cell box in green."""
    return _draw_boxes(pil_img, cells, (0, 255, 0))


def pytess(cell_pil_img):
    """OCR a single cell image with Tesseract and return its words joined by single spaces."""
    data = pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')
    # image_to_data emits empty strings for non-word layout levels; skip them
    # so words are not separated by runs of spaces.
    return ' '.join(word for word in data['text'] if word.strip()).strip()


def resize(pil_img, size=1800):
    """Upscale *pil_img* so its width is at least *size* pixels (never downscales).

    Returns:
        ``(resized_image, factor)`` where *factor* >= 1 is the applied scale.
    """
    width, height = pil_img.size
    factor = max(1, size / width)
    new_size = int(factor * width), int(factor * height)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    return pil_img.resize(new_size, PIL.Image.LANCZOS), factor
def image_smoothening(img):
    """Binarize a grayscale image: fixed threshold at 180, then Otsu, blur, and Otsu again.

    The (1, 1) Gaussian kernel is effectively a no-op; it is kept as-is.
    """
    ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    blur = cv2.GaussianBlur(th2, (1, 1), 0)
    ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return th3


def remove_noise_and_smooth(pil_img):
    """Binarize an RGB PIL image for OCR and return it as a (single-channel) PIL image.

    Combines an adaptive-mean threshold (block 41, C=3) with
    :func:`image_smoothening` via bitwise OR. The 1x1 open/close morphology
    steps are effectively no-ops.
    """
    img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
    kernel = np.ones((1, 1), np.uint8)
    opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
    img = image_smoothening(img)
    or_image = cv2.bitwise_or(img, closing)
    return PIL.Image.fromarray(or_image)


def extract_text_from_cells(cells, sep=' '):
    """Set ``cell['cell_text']`` to the cell's span texts joined by *sep*.

    Spans without a ``'text'`` key are skipped; a cell without spans gets ``''``.
    Mutates *cells* in place and returns it.
    """
    for cell in cells:
        cell['cell_text'] = sep.join(span['text'] for span in cell.get('spans', []) if 'text' in span)
    return cells


def cells_to_csv(cells):
    """Lay *cells* out on a grid and export it as a DataFrame and CSV text.

    Each cell needs ``row_nums``, ``column_nums``, ``header`` and ``cell_text``.
    Rows up to the last header row become the column names: per column, the
    distinct non-empty header texts joined with ``' | '``.

    Returns:
        ``(df, csv_text)``; an empty DataFrame and its CSV when *cells* is empty.
    """
    if not cells:
        empty = pd.DataFrame()
        return empty, empty.to_csv(index=False)

    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1
    header_rows = [max(cell['row_nums']) for cell in cells if cell['header']]
    max_header_row = max(header_rows) if header_rows else -1

    # Pre-fill with '' so grid positions covered by no cell are not None,
    # which would break the header join below.
    table_array = np.full((num_rows, num_columns), '', dtype=object)
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    # A spanning header cell repeats its text in each row it covers;
    # dict.fromkeys de-duplicates while keeping order.
    header = table_array[:max_header_row + 1, :]
    flattened_header = [' | '.join(text for text in dict.fromkeys(col) if text) for col in header.T]
    df = pd.DataFrame(table_array[max_header_row + 1:, :], columns=flattened_header)
    return df, df.to_csv(index=False)


def cells_to_html(cells):
    """Render *cells* as an HTML ``<table>`` string.

    Cells are ordered by (first row, first column). A row goes to ``<thead>``
    with ``<th>`` cells when its first cell is a header, otherwise to
    ``<tbody>`` with ``<td>`` cells. Multi-row/column cells get
    ``rowspan``/``colspan``. Cell text is XML-escaped by ElementTree.
    """
    ordered = sorted(cells, key=lambda c: (min(c['row_nums']), min(c['column_nums'])))
    thead = ET.Element('thead')
    tbody = ET.Element('tbody')
    current_row = -1
    row = None
    cell_tag = 'td'
    for cell in ordered:
        this_row = min(cell['row_nums'])
        if this_row > current_row:
            current_row = this_row
            if cell['header']:
                cell_tag, section = 'th', thead
            else:
                cell_tag, section = 'td', tbody
            row = ET.SubElement(section, 'tr')
        attrib = {}
        colspan = len(cell['column_nums'])
        if colspan > 1:
            attrib['colspan'] = str(colspan)
        rowspan = len(cell['row_nums'])
        if rowspan > 1:
            attrib['rowspan'] = str(rowspan)
        tcell = ET.SubElement(row, cell_tag, attrib=attrib)
        tcell.text = cell['cell_text']

    table = ET.Element('table')
    for section in (thead, tbody):
        if len(section):
            table.append(section)
    return ET.tostring(table, encoding='unicode', short_empty_elements=False)


def main():
    """Streamlit entry point: upload a page, detect tables, and show per-table results."""
    st.set_page_config(layout='wide')
    st.title('Table Extraction Demo')
    st.write('\n')
    cols = st.columns((1, 1))
    cols[0].subheader('Input page')
    cols[1].subheader('Table(s) detected')
    st.sidebar.title('Image upload')
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except Exception:
        # Best effort: newer Streamlit releases appear to have removed this
        # option and reject unknown keys; the app works without it.
        pass

    filename = st.sidebar.file_uploader('Upload files', type=['png', 'jpeg', 'jpg'])
    if not st.sidebar.button('Analyze image'):
        return
    if filename is None:
        st.sidebar.write('Please upload an image')
        return

    print(filename)
    # Normalize RGBA / grayscale / palette uploads to the 3-channel RGB that
    # PIL_to_cv's RGB->BGR conversion expects.
    pil_img = PIL.Image.open(filename).convert('RGB')
    cols[0].image(pil_img)

    detection_result = table_detection(pil_img)
    crop_images, vis_det_img = crop_image(pil_img, detection_result)
    cols[1].image(vis_det_img)
    if not crop_images:
        # st.columns rejects zero weights, so there is nothing to lay out.
        st.write('No tables detected.')
        return

    str_cols = st.columns(5)
    str_cols[0].subheader('Table image')
    str_cols[1].subheader('OCR result')
    str_cols[2].subheader('Structure result')
    str_cols[3].subheader('Cells result')
    str_cols[4].subheader('CSV result')

    for i, img in enumerate(crop_images):
        ocr_result = ocr(img)
        structure_result = table_structure(img)
        _, cells, _ = convert_stucture(ocr_result, img, structure_result)
        cells = extract_text_from_cells(cells)
        html_result = cells_to_html(cells)
        _, csv_result = cells_to_csv(cells)

        str_cols[0].image(img)
        str_cols[1].image(visualize_ocr(img, ocr_result))
        str_cols[2].image(visualize_structure(img, structure_result))
        str_cols[3].image(visualize_cells(img, cells))
        str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
        st.markdown(html_result, unsafe_allow_html=True)


if __name__ == '__main__':
    main()