# Page-scrape residue (not part of the program): "bachpc's picture" · commit "Fix bug and clean" (17ae8b6) · raw · history · blame · 14.2 kB
import streamlit as st
import PIL
import cv2
import numpy as np
import pandas as pd
import torch
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
from paddleocr import PaddleOCR
import pytesseract
from pytesseract import Output
import postprocess
# Shared PaddleOCR engine used by ocr(): English model, angle classifier disabled, GPU requested.
ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
# Custom YOLOv5 weights loaded through torch.hub when this module is imported.
# NOTE(review): force_reload=True presumably bypasses the local torch.hub cache on every start — confirm this is intended.
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
# Inference size handed to both YOLOv5 models as `size=`.
imgsz = 640
# NOTE(review): not referenced anywhere in the visible code; presumably the label names of the detection model.
detection_class_names = ['table', 'table rotated']
# Label names of the structure model; a prediction's class id is used as an index into this list.
structure_class_names = [
'table', 'table column', 'table row', 'table column header',
'table projected row header', 'table spanning cell', 'no object'
]
# Reverse lookup: class name -> position in structure_class_names.
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}
# Per-class minimum score. visualize_structure draws a box only when score >= threshold,
# and the same dict is passed to postprocess.objects_to_cells.
structure_class_thresholds = {
'table': 0.5,
'table column': 0.5,
'table row': 0.5,
'table column header': 0.25,
'table projected row header': 0.25,
'table spanning cell': 0.25,
# NOTE(review): 10 presumably cannot be reached by a confidence score, so this class is always filtered out — confirm.
'no object': 10
}
def PIL_to_cv(pil_img):
    """Convert a PIL image into an OpenCV-style BGR ``numpy`` array.

    The image is normalised to 3-channel RGB first. Palette ('P') and
    grayscale ('L') images -- both common for PNG uploads -- would otherwise
    become 2-D arrays that ``cv2.cvtColor(..., COLOR_RGB2BGR)`` rejects.

    Args:
        pil_img: a ``PIL.Image.Image`` in any mode ``convert('RGB')`` accepts.

    Returns:
        A 3-channel array in BGR channel order.
    """
    rgb_pixels = np.array(pil_img.convert('RGB'))
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Turn an OpenCV BGR array into a PIL image (channels swapped to RGB)."""
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    pil_image = PIL.Image.fromarray(rgb_pixels)
    return pil_image
def table_detection(pil_img):
    """Run the table-detection YOLOv5 model on a page image.

    Args:
        pil_img: the page as a PIL image.

    Returns:
        A ``numpy`` array with one row per detection, taken from the model's
        ``xywhn`` output. Callers in this file read the columns as
        (x_center, y_center, width, height, score, class_id), with the box
        values scaled by the image width/height.
    """
    # The YOLOv5 hub wrapper documents numpy inputs as RGB (its own OpenCV
    # example flips BGR to RGB first), so hand it RGB rather than the BGR
    # array produced for OpenCV drawing elsewhere in this file.
    image = np.array(pil_img.convert('RGB'))
    pred = detection_model(image, size=imgsz)
    return pred.xywhn[0].cpu().numpy()
def table_structure(pil_img):
    """Run the table-structure YOLOv5 model on a cropped table image.

    Args:
        pil_img: the table crop as a PIL image.

    Returns:
        A ``numpy`` array with one row per predicted structure element, taken
        from the model's ``xywhn`` output. Callers in this file read the
        columns as (x_center, y_center, width, height, score, class_id), with
        the box values scaled by the image width/height.
    """
    # The YOLOv5 hub wrapper documents numpy inputs as RGB (its own OpenCV
    # example flips BGR to RGB first), so hand it RGB rather than the BGR
    # array produced for OpenCV drawing elsewhere in this file.
    image = np.array(pil_img.convert('RGB'))
    pred = structure_model(image, size=imgsz)
    return pred.xywhn[0].cpu().numpy()
def crop_image(pil_img, detection_result):
    """Cut every detected table out of the page and draw the detections.

    Args:
        pil_img: the full page as a PIL image.
        detection_result: rows of (x_center, y_center, width, height, score,
            class_id) with box values normalised to the image size, as
            returned by ``table_detection``.

    Returns:
        ``(crop_images, annotated)``: a list of PIL crops (one per usable
        detection, each padded by 2% of the page size on every side and
        clamped to the page) and a PIL copy of the page with a green
        rectangle around each crop region.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[0], image.shape[1]
    # Rectangles go on a separate copy: drawing on the array the crops are cut
    # from would leak box lines into the crops of later, overlapping tables.
    annotated = image.copy()
    margin = 0.02
    crop_images = []
    for result in detection_result:
        center_x, center_y, w, h = result[0], result[1], result[2], result[3]
        x1 = max(0, int((center_x - w / 2 - margin) * width))
        y1 = max(0, int((center_y - h / 2 - margin) * height))
        x2 = min(width, int((center_x + w / 2 + margin) * width))
        y2 = min(height, int((center_y + h / 2 + margin) * height))
        if x2 <= x1 or y2 <= y1:
            # A degenerate box yields an empty slice, which cvtColor rejects.
            continue
        crop_images.append(cv_to_PIL(image[y1:y2, x1:x2, :]))
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color=(0, 255, 0))
    return crop_images, cv_to_PIL(annotated)
def ocr(pil_img):
    """Run PaddleOCR on an image and return axis-aligned token boxes.

    Args:
        pil_img: the image to read, as a PIL image.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, one per
        recognised text line; ``bbox`` is the bounding rectangle of the
        polygon PaddleOCR reports. Empty when nothing was recognised.
    """
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)
    # NOTE(review): some PaddleOCR releases appear to report a page without
    # text as [None] (or an empty list); iterating that used to raise.
    lines = result[0] if result else None
    if not lines:
        return []
    ocr_res = []
    for polygon, (text, _score) in lines:
        xs = [point[0] for point in polygon]
        ys = [point[1] for point in polygon]
        ocr_res.append({
            'bbox': [min(xs), min(ys), max(xs), max(ys)],
            'text': text,
        })
    return ocr_res
def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure predictions plus OCR tokens into table cells.

    Args:
        page_tokens: OCR tokens, each a dict with at least a ``'bbox'`` entry
            in pixel coordinates of ``pil_img``.
        pil_img: the table crop the predictions refer to (PIL image).
        structure_result: rows of (x_center, y_center, width, height, score,
            class_id) with box values normalised to the image size.

    Returns:
        The ``(table_structures, cells, confidence_score)`` triple produced by
        ``postprocess.objects_to_cells``.
    """
    width, height = pil_img.size

    # Convert every normalised centre/size box to integer pixel corners.
    table_objects = []
    for result in structure_result:
        center_x, center_y, w, h = result[0], result[1], result[2], result[3]
        bbox = [
            int((center_x - w / 2) * width),
            int((center_y - h / 2) * height),
            int((center_x + w / 2) * width),
            int((center_y + h / 2) * height),
        ]
        table_objects.append({
            'bbox': bbox,
            'score': float(result[4]),
            'label': int(result[5]),
        })

    table = {'objects': table_objects, 'page_num': 0}

    table_label = structure_class_map['table']
    table_class_objects = [obj for obj in table_objects if obj['label'] == table_label]
    if table_class_objects:
        # Highest-scoring 'table' prediction defines the table region.
        best_table = max(table_class_objects, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No 'table' prediction: fall back to the whole image. The earlier
        # fixed 1000x1000 box silently dropped tokens on larger crops.
        table_bbox = [0, 0, width, height]

    # Keep tokens whose overlap measure with the table box is at least 0.5
    # (postprocess.iob presumably returns the fraction of the token box that
    # lies inside the table box — confirm against postprocess).
    tokens_in_table = [
        token for token in page_tokens
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table,
        structure_class_names, structure_class_thresholds)
    return table_structures, cells, confidence_score
def visualize_ocr(pil_img, ocr_result):
    """Return a copy of ``pil_img`` with a green box around every OCR token.

    ``ocr_result`` is a sequence of dicts whose ``'bbox'`` holds
    ``[x1, y1, x2, y2]`` in pixels; values are truncated to ints for drawing.
    """
    canvas = PIL_to_cv(pil_img)
    green = (0, 255, 0)
    for token in ocr_result:
        box = token['bbox']
        left, top, right, bottom = map(int, (box[0], box[1], box[2], box[3]))
        cv2.rectangle(canvas, (left, top), (right, bottom), color=green)
    return cv_to_PIL(canvas)
def visualize_structure(pil_img, structure_result):
    """Return a copy of ``pil_img`` with red boxes for confident predictions.

    Each row of ``structure_result`` is read as (x_center, y_center, width,
    height, score, class_id) with box values normalised to the image size.
    A box is drawn only when its score reaches the threshold configured for
    its class in ``structure_class_thresholds``.
    """
    canvas = PIL_to_cv(pil_img)
    img_h, img_w = canvas.shape[0], canvas.shape[1]
    red = (0, 0, 255)
    for detection in structure_result:
        label_idx = int(detection[5])
        confidence = float(detection[4])
        center_x, center_y = detection[0], detection[1]
        box_w, box_h = detection[2], detection[3]
        top_left = (
            int((center_x - box_w / 2) * img_w),
            int((center_y - box_h / 2) * img_h),
        )
        bottom_right = (
            int((center_x + box_w / 2) * img_w),
            int((center_y + box_h / 2) * img_h),
        )
        threshold = structure_class_thresholds[structure_class_names[label_idx]]
        if confidence < threshold:
            continue
        cv2.rectangle(canvas, top_left, bottom_right, color=red)
    return cv_to_PIL(canvas)
def visualize_cells(pil_img, cells):
    """Return a copy of ``pil_img`` with a green box around every cell.

    ``cells`` is a sequence of dicts whose ``'bbox'`` holds
    ``[x1, y1, x2, y2]`` in pixels; values are truncated to ints for drawing.
    """
    canvas = PIL_to_cv(pil_img)
    green = (0, 255, 0)
    for cell in cells:
        box = cell['bbox']
        left, top, right, bottom = map(int, (box[0], box[1], box[2], box[3]))
        cv2.rectangle(canvas, (left, top), (right, bottom), color=green)
    return cv_to_PIL(canvas)
def pytess(cell_pil_img):
    """Run Tesseract on one cell image and return its text as a single string.

    The ``'text'`` entries of ``image_to_data``'s dict output are joined with
    single spaces and the result is stripped of surrounding whitespace.
    """
    tess_config = '-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6'
    data = pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config=tess_config)
    joined = ' '.join(data['text'])
    return joined.strip()
def resize(pil_img, size=1800):
    """Upscale an image so that its width is at least ``size`` pixels.

    Images already ``size`` pixels wide or wider are resampled at factor 1
    (never shrunk). The aspect ratio is kept.

    Args:
        pil_img: the PIL image to scale.
        size: minimum target width in pixels.

    Returns:
        ``(resized_image, factor)`` where ``factor`` is the scale applied to
        both dimensions.
    """
    src_width, src_height = pil_img.size
    factor = max(1, size / src_width)
    new_size = int(factor * src_width), int(factor * src_height)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # under its current name and also exists in older Pillow releases.
    pil_img = pil_img.resize(new_size, PIL.Image.LANCZOS)
    return pil_img, factor
def image_smoothening(img):
    """Binarise ``img`` through a fixed threshold, Otsu, a blur and Otsu again.

    Steps, in order: fixed threshold at 180, Otsu threshold, Gaussian blur
    with a (1, 1) kernel, Otsu threshold. Returns the final thresholded array.
    """
    otsu_mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, fixed_cut = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    _, first_otsu = cv2.threshold(fixed_cut, 0, 255, otsu_mode)
    # NOTE(review): a (1, 1) kernel presumably leaves the image unchanged — confirm.
    blurred = cv2.GaussianBlur(first_otsu, (1, 1), 0)
    _, second_otsu = cv2.threshold(blurred, 0, 255, otsu_mode)
    return second_otsu
def remove_noise_and_smooth(pil_img):
    """Produce a binarised PIL image for OCR from an RGB PIL image.

    The grayscale image goes through two branches -- adaptive mean threshold
    followed by morphological open and close, and ``image_smoothening`` --
    whose results are combined with a bitwise OR.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)

    # Branch 1: adaptive threshold, then open/close with a 1x1 kernel.
    adaptive = cv2.adaptiveThreshold(
        gray.astype(np.uint8), 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
    unit_kernel = np.ones((1, 1), np.uint8)
    opened = cv2.morphologyEx(adaptive, cv2.MORPH_OPEN, unit_kernel)
    cleaned = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, unit_kernel)

    # Branch 2: global thresholding chain on the untouched grayscale image.
    smoothed = image_smoothening(gray)

    combined = cv2.bitwise_or(smoothed, cleaned)
    return PIL.Image.fromarray(combined)
# def extract_text_from_cells(pil_img, cells):
# pil_img, factor = resize(pil_img)
# #pil_img = remove_noise_and_smooth(pil_img)
# #display(pil_img)
# for cell in cells:
# bbox = [x * factor for x in cell['bbox']]
# cell_pil_img = pil_img.crop(bbox)
# #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
# #cell_pil_img = tess_prep(cell_pil_img)
# cell['cell text'] = pytess(cell_pil_img)
# return cells
def extract_text_from_cells(cells, sep=' '):
    """Fill ``cell['cell_text']`` from the text of each cell's spans.

    Spans without a ``'text'`` key are skipped. The span texts are joined
    with ``sep`` *between* them, so the result carries no trailing separator
    (the previous implementation appended ``sep`` after every span).

    Args:
        cells: list of dicts, each with a ``'spans'`` list of dicts.
        sep: separator placed between consecutive span texts.

    Returns:
        The same ``cells`` list, modified in place.
    """
    for cell in cells:
        cell['cell_text'] = sep.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )
    return cells
def cells_to_csv(cells):
    """Lay table cells out on a grid and export it as DataFrame and CSV text.

    Header rows (every row up to the last one containing a header cell) are
    flattened into one column label per column by joining the distinct header
    texts of that column with ``' | '``. Spanning cells repeat their text in
    every grid position they cover.

    Args:
        cells: list of dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'`` entries.

    Returns:
        ``(df, csv_text)``. For an empty ``cells`` list this is an empty
        DataFrame and an empty string, so callers can always unpack two
        values (a bare ``None`` used to be returned here).
    """
    if not cells:
        return pd.DataFrame(), ''

    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1

    header_cells = [cell for cell in cells if cell['header']]
    if header_cells:
        max_header_row = max(max(cell['row_nums']) for cell in header_cells)
    else:
        max_header_row = -1

    # Object array: positions no cell covers stay None.
    table_array = np.empty([num_rows, num_columns], dtype='object')
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row + 1, :]
    flattened_header = []
    for col in header.transpose():
        # Skip uncovered (None) positions; str.join would raise on them.
        distinct_texts = OrderedDict.fromkeys(text for text in col if text is not None)
        flattened_header.append(' | '.join(distinct_texts))

    df = pd.DataFrame(table_array[max_header_row + 1:, :], index=None, columns=flattened_header)
    return df, df.to_csv(index=None)
def cells_to_html(cells):
    """Serialise table cells to an HTML ``<table>`` string.

    Cells are visited in (first row, first column) order. Whenever a cell
    starts a new row, a new container is appended to the table: ``<thead>``
    if that cell is a header, otherwise ``<tr>``; the cell tag (``th``/``td``)
    chosen at that point is reused for the rest of the row. Cells spanning
    several rows/columns get ``rowspan``/``colspan`` attributes.
    """
    ordered = sorted(
        cells,
        key=lambda c: (min(c['row_nums']), min(c['column_nums'])),
    )

    table = ET.Element('table')
    last_row = -1
    for cell in ordered:
        first_row = min(cell['row_nums'])

        span_attrs = {}
        n_columns = len(cell['column_nums'])
        if n_columns > 1:
            span_attrs['colspan'] = str(n_columns)
        n_rows = len(cell['row_nums'])
        if n_rows > 1:
            span_attrs['rowspan'] = str(n_rows)

        if first_row > last_row:
            last_row = first_row
            # The first cell of a row decides both container and cell tag.
            cell_tag, container_tag = ('th', 'thead') if cell['header'] else ('td', 'tr')
            row = ET.SubElement(table, container_tag)

        tcell = ET.SubElement(row, cell_tag, attrib=span_attrs)
        tcell.text = cell['cell_text']

    markup = ET.tostring(table, encoding='unicode', short_empty_elements=False)
    return str(markup)
# def cells_to_html(cells):
# for cell in cells:
# cell['column_nums'].sort()
# cell['row_nums'].sort()
# n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
# n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
# html_code = ''
# for r in range(n_rows):
# r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
# r_cells.sort(key=lambda x: x['column_nums'][0])
# r_html = ''
# for cell in r_cells:
# rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
# colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
# r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>'
# html_code += f'<tr>{r_html}</tr>'
# html_code = '''<html>
# <head>
# <meta charset='UTF-8'>
# <style>
# table, th, td {
# border: 1px solid black;
# font-size: 10px;
# }
# </style>
# </head>
# <body>
# <table frame='hsides' rules='groups' width='100%%'>
# %s
# </table>
# </body>
# </html>''' % html_code
# soup = bs(html_code)
# html_code = soup.prettify()
# return html_code
def main():
    """Streamlit entry point: upload a page image and extract its tables.

    After the user uploads an image and presses 'Analyze image', the page is
    run through table detection; every detected table is then OCR'd, its
    structure recognised and converted to cells, and the intermediate images,
    a CSV download button and an HTML rendering are shown.
    """
    st.set_page_config(layout='wide')
    st.title('Table Extraction Demo')
    st.write('\n')

    cols = st.columns((1, 1))
    cols[0].subheader('Input page')
    cols[1].subheader('Table(s) detected')

    st.sidebar.title('Image upload')
    st.set_option('deprecation.showfileUploaderEncoding', False)
    filename = st.sidebar.file_uploader('Upload files', type=['png', 'jpeg', 'jpg'])

    if not st.sidebar.button('Analyze image'):
        return
    if filename is None:
        st.sidebar.write('Please upload an image')
        return

    print(filename)
    pil_img = PIL.Image.open(filename)
    cols[0].image(pil_img)

    detection_result = table_detection(pil_img)
    crop_images, vis_det_img = crop_image(pil_img, detection_result)
    cols[1].image(vis_det_img)

    if not crop_images:
        # Without this guard the column weights below were all 0, which
        # Streamlit rejects.
        st.write('No table detected in the uploaded image')
        return

    # Five equal-width columns (the previous weights were all identical).
    str_cols = st.columns(5)
    str_cols[0].subheader('Table image')
    str_cols[1].subheader('OCR result')
    str_cols[2].subheader('Structure result')
    str_cols[3].subheader('Cells result')
    str_cols[4].subheader('CSV result')

    for i, img in enumerate(crop_images):
        ocr_result = ocr(img)
        structure_result = table_structure(img)
        table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
        cells = extract_text_from_cells(cells)
        html_result = cells_to_html(cells)

        # cells_to_csv may yield nothing when no cells were recognised, so
        # do not unpack its result blindly.
        csv_output = cells_to_csv(cells)
        csv_result = csv_output[1] if csv_output else ''

        vis_ocr_img = visualize_ocr(img, ocr_result)
        vis_str_img = visualize_structure(img, structure_result)
        vis_cells_img = visualize_cells(img, cells)

        str_cols[0].image(img)
        str_cols[1].image(vis_ocr_img)
        str_cols[2].image(vis_str_img)
        str_cols[3].image(vis_cells_img)
        if csv_result:
            str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
        else:
            str_cols[4].write('No cells recognised')
        st.markdown(html_result, unsafe_allow_html=True)


if __name__ == '__main__':
    main()