bachpc's picture
Modify threshold
6353415
raw
history blame
17.1 kB
import streamlit as st
import PIL
import cv2
import numpy as np
import pandas as pd
import torch
import io
# import sys
# import json
from collections import OrderedDict, defaultdict
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from paddleocr import PaddleOCR
import pytesseract
from pytesseract import Output
import postprocess
# Shared PaddleOCR engine used by ocr(): English model, no angle classifier,
# GPU inference requested.
ocr_instance = PaddleOCR(lang='en', use_angle_cls=False, use_gpu=True)

# Custom YOLOv5 checkpoints loaded through torch.hub: one finds tables on a
# page, the other finds rows/columns/headers inside a cropped table.
# NOTE(review): force_reload=True asks torch.hub to refresh its cached copy of
# the repository on every load - confirm that this is intended.
detection_model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    'weights/detection_wts.pt',
    force_reload=True,
)
structure_model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    'weights/structure_wts.pt',
    force_reload=True,
)
# Inference size handed to both YOLOv5 models (``size=imgsz``).
imgsz = 640

# Class-id -> name tables; the list position is the class id.
detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]

# Reverse lookup: class name -> class id.
structure_class_map = {
    name: index for index, name in enumerate(structure_class_names)
}

# Minimum score for a structure detection of each class to be kept.
# "no object" uses 10, which presumably can never be reached
# (assumes scores stay within [0, 1] - TODO confirm).
structure_class_thresholds = {
    "table": 0.42,
    "table column": 0.56,
    "table row": 0.5,
    "table column header": 0.38,
    "table projected row header": 0.27,
    "table spanning cell": 0.4,
    "no object": 10,
}
def PIL_to_cv(pil_img):
    """Convert a PIL image into an OpenCV-style BGR ``numpy`` array.

    The image is normalised to RGB first, so grayscale, palette and 1-bit
    inputs (common for scanned PNGs) no longer make the colour conversion
    fail on an unexpected channel count.

    Args:
        pil_img: A ``PIL.Image.Image`` in any mode Pillow can convert to RGB.

    Returns:
        A 3-channel array in BGR channel order.
    """
    rgb_pixels = np.array(pil_img.convert('RGB'))
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Convert an OpenCV BGR array back into a PIL image.

    Args:
        cv_img: Image array in BGR channel order.

    Returns:
        The PIL image built from the RGB-reordered pixels.
    """
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb_pixels)
def table_detection(pil_img):
    """Run the table-detection model on a page image.

    Args:
        pil_img: PIL image of the page.

    Returns:
        ``numpy`` array taken from the model's ``xywhn`` output for the first
        (only) image. ``crop_image`` reads each row as
        ``(x, y, w, h, score, class_id)`` with normalised box values.
    """
    bgr_page = PIL_to_cv(pil_img)
    detections = detection_model(bgr_page, size=imgsz).xywhn[0]
    return detections.cpu().numpy()
def table_structure(pil_img):
    """Run the table-structure model on a cropped table image.

    Args:
        pil_img: PIL image of a single table.

    Returns:
        ``numpy`` array taken from the model's ``xywhn`` output for the first
        (only) image; rows are consumed elsewhere in this file as
        ``(x, y, w, h, score, class_id)`` with normalised box values.
    """
    bgr_table = PIL_to_cv(pil_img)
    detections = structure_model(bgr_table, size=imgsz).xywhn[0]
    return detections.cpu().numpy()
def crop_image(pil_img, detection_result, padding=30):
    """Cut every detected table out of the page and draw the detections.

    Crops are taken from an untouched copy of the page, while the boxes are
    drawn on a separate copy. Drawing on the array that is also cropped would
    let the outline of an earlier detection leak into a later, overlapping
    crop.

    Args:
        pil_img: PIL image of the full page.
        detection_result: Iterable of rows ``(cx, cy, w, h, score, class_id)``
            with the box given as normalised centre and size.
        padding: Extra pixels added on every side of a box (clamped to the
            image bounds).

    Returns:
        ``(crops, annotated)`` where ``crops`` is a list of PIL images (one
        per usable detection, rotated upright for class 1 "table rotated")
        and ``annotated`` is the page as a PIL image with all boxes drawn.
    """
    source = PIL_to_cv(pil_img)
    height, width = source.shape[:2]
    annotated = source.copy()

    crops = []
    for detection in detection_result:
        center_x, center_y, box_w, box_h = detection[:4]
        class_id = int(detection[5])

        x1 = max(0, int((center_x - box_w / 2) * width) - padding)
        y1 = max(0, int((center_y - box_h / 2) * height) - padding)
        x2 = min(width, int((center_x + box_w / 2) * width) + padding)
        y2 = min(height, int((center_y + box_h / 2) * height) + padding)

        # A box that collapses to nothing cannot be converted to an image.
        if x2 <= x1 or y2 <= y1:
            continue

        crop = cv_to_PIL(source[y1:y2, x1:x2, :])
        if class_id == 1:  # table rotated
            crop = crop.rotate(270, expand=True)
        crops.append(crop)

        cv2.rectangle(annotated, (x1, y1), (x2, y2), color=(0, 0, 255))

    return crops, cv_to_PIL(annotated)
def ocr(pil_img):
    """Run PaddleOCR on an image and return axis-aligned word boxes.

    Args:
        pil_img: PIL image to read.

    Returns:
        List of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, one per
        recognised text line. Empty when nothing was recognised.
    """
    result = ocr_instance.ocr(PIL_to_cv(pil_img))

    # PaddleOCR can report "no text" as an empty result or as a ``None``
    # entry for the image; iterating that directly would raise TypeError.
    lines = result[0] if result else None
    if not lines:
        return []

    ocr_res = []
    for points, (text, _score) in lines:
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]
        ocr_res.append({
            'bbox': [min(xs), min(ys), max(xs), max(ys)],
            'text': text,
        })
    return ocr_res
def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure detections plus OCR tokens into table cells.

    Args:
        page_tokens: OCR tokens, each a dict with a pixel ``'bbox'``.
        pil_img: PIL image of the table the detections refer to.
        structure_result: Iterable of rows ``(cx, cy, w, h, score, class_id)``
            with the box given as normalised centre and size.

    Returns:
        The ``(table_structures, cells, confidence_score)`` tuple produced by
        ``postprocess.objects_to_cells``.
    """
    # Only the dimensions are needed, so read them straight from the PIL
    # image instead of converting the whole picture to an array.
    width, height = pil_img.size

    table_objects = []
    for detection in structure_result:
        center_x, center_y, box_w, box_h = detection[:4]
        table_objects.append({
            'bbox': [
                int((center_x - box_w / 2) * width),
                int((center_y - box_h / 2) * height),
                int((center_x + box_w / 2) * width),
                int((center_y + box_h / 2) * height),
            ],
            'score': float(detection[4]),
            'label': int(detection[5]),
        })

    table = {'objects': table_objects, 'page_num': 0}

    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        # Highest-scoring "table" detection defines the table region.
        best_table = max(table_candidates, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No table box detected: keep the original fixed fallback region.
        table_bbox = (0, 0, 1000, 1000)

    tokens_in_table = [
        token for token in page_tokens
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table, table_objects, tokens_in_table,
        structure_class_names, structure_class_thresholds,
    )
    return table_structures, cells, confidence_score
def visualize_ocr(pil_img, ocr_result):
    """Draw every OCR box and its text on a copy of the image.

    Args:
        pil_img: PIL image the OCR was run on.
        ocr_result: Iterable of dicts with ``'bbox'`` (x1, y1, x2, y2) and
            ``'text'``.

    Returns:
        PIL image with the boxes and labels drawn.
    """
    canvas = PIL_to_cv(pil_img)
    for word in ocr_result:
        left, top, right, bottom = (int(value) for value in word['bbox'][:4])
        cv2.rectangle(canvas, (left, top), (right, bottom), color=(0, 255, 0))
        cv2.putText(
            canvas, word['text'], (left, top), cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.25, color=(255, 0, 0),
        )
    return cv_to_PIL(canvas)
def get_bbox_decorations(data_type, label):
    """Look up the drawing style for a class label.

    Args:
        data_type: ``'detection'`` selects the hatched variant for label 0;
            any other value selects the plain variant.
        label: Integer class id.

    Returns:
        ``(color, alpha, linewidth, hatch)``; unknown labels fall back to an
        invisible gray style.
    """
    if label == 0:
        is_detection = data_type == 'detection'
        return ('brown', 0.05, 3, '//') if is_detection else ('brown', 0, 3, None)

    styles = {
        1: ('red', 0.15, 2, None),
        2: ('blue', 0.15, 2, None),
        3: ('magenta', 0.2, 3, '//'),
        4: ('cyan', 0.2, 4, '//'),
        5: ('green', 0.2, 4, '\\\\'),
    }
    return styles.get(label, ('gray', 0, 0, None))
def visualize_structure(pil_img, structure_result):
    """Render the structure detections that pass their class threshold.

    The figure is saved through its own handle and always closed afterwards;
    otherwise every call would leave an open matplotlib figure behind.

    Args:
        pil_img: PIL image of the table.
        structure_result: Iterable of rows ``(cx, cy, w, h, score, class_id)``
            with the box given as normalised centre and size.

    Returns:
        PIL image of the rendered overlay.
    """
    width, height = pil_img.size

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(pil_img, interpolation='lanczos')

        for detection in structure_result:
            center_x, center_y, box_w, box_h = detection[:4]
            score = float(detection[4])
            class_id = int(detection[5])

            if score < structure_class_thresholds[structure_class_names[class_id]]:
                continue

            x1 = int((center_x - box_w / 2) * width)
            y1 = int((center_y - box_h / 2) * height)
            x2 = int((center_x + box_w / 2) * width)
            y2 = int((center_y + box_h / 2) * height)
            origin = (x1, y1)
            rect_w = x2 - x1
            rect_h = y2 - y1

            color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)

            # Three stacked rectangles: translucent fill, hatch, outline.
            ax.add_patch(patches.Rectangle(
                origin, rect_w, rect_h,
                linewidth=linewidth, alpha=alpha,
                edgecolor='none', facecolor=color,
                linestyle=None,
            ))
            ax.add_patch(patches.Rectangle(
                origin, rect_w, rect_h,
                linewidth=1, alpha=0.4,
                edgecolor=color, facecolor='none',
                linestyle='--', hatch=hatch,
            ))
            ax.add_patch(patches.Rectangle(
                origin, rect_w, rect_h,
                linewidth=linewidth,
                edgecolor=color, facecolor='none',
                linestyle="--",
            ))

        ax.axis('off')
        img_buf = io.BytesIO()
        fig.savefig(img_buf, bbox_inches='tight', dpi=1000)
    finally:
        plt.close(fig)

    img_buf.seek(0)
    return PIL.Image.open(img_buf)
def visualize_cells(pil_img, cells):
    """Render the extracted cells on top of the table image.

    Header cells get a stronger fill than body cells. The figure is saved
    through its own handle and always closed afterwards; otherwise every call
    would leave an open matplotlib figure behind.

    Args:
        pil_img: PIL image of the table.
        cells: Iterable of dicts with ``'bbox'`` (x1, y1, x2, y2) and a
            truthy/falsy ``'header'`` flag.

    Returns:
        PIL image of the rendered overlay.
    """
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(pil_img, interpolation='lanczos')

        for cell in cells:
            bbox = cell['bbox']
            origin = bbox[:2]
            rect_w = bbox[2] - bbox[0]
            rect_h = bbox[3] - bbox[1]
            fill_alpha = 0.3 if cell['header'] else 0.125

            # Three stacked rectangles: translucent fill, hatch, outline.
            ax.add_patch(patches.Rectangle(
                origin, rect_w, rect_h, linewidth=1,
                edgecolor='none', facecolor="magenta", alpha=fill_alpha,
            ))
            ax.add_patch(patches.Rectangle(
                origin, rect_w, rect_h, linewidth=1,
                edgecolor="magenta", facecolor='none', linestyle="--",
                alpha=0.08, hatch='///',
            ))
            ax.add_patch(patches.Rectangle(
                origin, rect_w, rect_h, linewidth=1,
                edgecolor="magenta", facecolor='none', linestyle="--",
            ))

        ax.axis('off')
        img_buf = io.BytesIO()
        fig.savefig(img_buf, bbox_inches='tight', dpi=1000)
    finally:
        plt.close(fig)

    img_buf.seek(0)
    return PIL.Image.open(img_buf)
def pytess(cell_pil_img):
    """OCR a single cell image with Tesseract and return its text.

    Args:
        cell_pil_img: PIL image of one table cell.

    Returns:
        All recognised text entries joined with single spaces, stripped of
        leading and trailing whitespace.
    """
    tess_config = '-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6'
    data = pytesseract.image_to_data(
        cell_pil_img, output_type=Output.DICT, config=tess_config,
    )
    joined = ' '.join(data['text'])
    return joined.strip()
def resize(pil_img, size=1800):
    """Upscale an image so that its width is at least ``size`` pixels.

    Images already at least that wide keep their dimensions (factor 1).

    Args:
        pil_img: PIL image to scale.
        size: Minimum target width in pixels.

    Returns:
        ``(resized_image, factor)`` where ``factor`` is the scale applied to
        both dimensions.
    """
    width, height = pil_img.size
    factor = max(1, size / width)
    new_size = int(factor * width), int(factor * height)
    # ``Image.ANTIALIAS`` was removed in Pillow 10; ``LANCZOS`` is the same
    # filter under its current name.
    resized = pil_img.resize(new_size, PIL.Image.LANCZOS)
    return resized, factor
def image_smoothening(img):
    """Binarise an image with a fixed threshold followed by Otsu passes.

    Args:
        img: Single-channel image array.

    Returns:
        The thresholded image after the final Otsu pass.
    """
    otsu_binary = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, fixed = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    _, first_otsu = cv2.threshold(fixed, 0, 255, otsu_binary)
    blurred = cv2.GaussianBlur(first_otsu, (1, 1), 0)
    _, second_otsu = cv2.threshold(blurred, 0, 255, otsu_binary)
    return second_otsu
def remove_noise_and_smooth(pil_img):
    """Produce a cleaned binary version of an image.

    Combines (bitwise OR) an adaptive-threshold image that went through
    morphological open/close with the output of ``image_smoothening``.

    Args:
        pil_img: PIL image in RGB mode.

    Returns:
        PIL image holding the combined binary result.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)

    cleaned = cv2.adaptiveThreshold(
        gray.astype(np.uint8), 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3,
    )
    unit_kernel = np.ones((1, 1), np.uint8)
    for morph_op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        cleaned = cv2.morphologyEx(cleaned, morph_op, unit_kernel)

    smoothed = image_smoothening(gray)
    return PIL.Image.fromarray(cv2.bitwise_or(smoothed, cleaned))
# def extract_text_from_cells(pil_img, cells):
# pil_img, factor = resize(pil_img)
# #pil_img = remove_noise_and_smooth(pil_img)
# #display(pil_img)
# for cell in cells:
# bbox = [x * factor for x in cell['bbox']]
# cell_pil_img = pil_img.crop(bbox)
# #cell_pil_img = remove_noise_and_smooth(cell_pil_img)
# #cell_pil_img = tess_prep(cell_pil_img)
# cell['cell text'] = pytess(cell_pil_img)
# return cells
def extract_text_from_cells(cells, sep=' '):
    """Fill ``cell['cell_text']`` from the text of each cell's spans.

    Span texts are joined with ``sep`` between them, so the result no longer
    carries a dangling separator after the last span.

    Args:
        cells: List of cell dicts, each with a ``'spans'`` list; spans
            without a ``'text'`` key are ignored.
        sep: String placed between consecutive span texts.

    Returns:
        The same ``cells`` list, modified in place.
    """
    for cell in cells:
        cell['cell_text'] = sep.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )
    return cells
def cells_to_csv(cells):
    """Lay the cells out on a grid and export it as a DataFrame and CSV.

    Header rows are collapsed into a single header line: for every column the
    distinct header texts are joined top to bottom with ``' | '``.

    Args:
        cells: List of cell dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'``.

    Returns:
        ``(df, csv_text)``. For an empty ``cells`` list this is an empty
        DataFrame and an empty string, so callers can always unpack two
        values.
    """
    if not cells:
        return pd.DataFrame(), ''

    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    num_rows = max(max(cell['row_nums']) for cell in cells) + 1
    max_header_row = max(
        (max(cell['row_nums']) for cell in cells if cell['header']),
        default=-1,
    )

    table_array = np.empty([num_rows, num_columns], dtype='object')
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    header_rows = table_array[:max_header_row + 1, :]
    # Grid positions not covered by any cell stay ``None``; skip them so the
    # join does not fail on a non-string entry.
    flattened_header = [
        ' | '.join(OrderedDict.fromkeys(
            text for text in column if text is not None
        ))
        for column in header_rows.transpose()
    ]

    df = pd.DataFrame(
        table_array[max_header_row + 1:, :], index=None, columns=flattened_header,
    )
    return df, df.to_csv(index=None)
def cells_to_html(cells):
    """Serialise the cells as an HTML ``<table>`` string.

    Cells are emitted in reading order (first row, then first column). A new
    ``<tr>`` starts whenever a cell begins on a later row than any seen so
    far, and the first cell of that row decides whether the row's cells are
    written as ``<th>`` or ``<td>``.

    Args:
        cells: Iterable of cell dicts with ``'row_nums'``, ``'column_nums'``,
            ``'header'`` and ``'cell_text'``.

    Returns:
        The table markup as a string, without self-closing empty elements.
    """
    ordered = sorted(
        cells,
        key=lambda c: (min(c['row_nums']), min(c['column_nums'])),
    )

    root = ET.Element('table')
    last_row = -1
    for cell in ordered:
        first_row = min(cell['row_nums'])

        # Only spans larger than one are written as attributes.
        span_attrs = {}
        for attr_name, indices in (
            ('colspan', cell['column_nums']),
            ('rowspan', cell['row_nums']),
        ):
            if len(indices) > 1:
                span_attrs[attr_name] = str(len(indices))

        if first_row > last_row:
            last_row = first_row
            tag = 'th' if cell['header'] else 'td'
            tr = ET.SubElement(root, 'tr')

        element = ET.SubElement(tr, tag, attrib=span_attrs)
        element.text = cell['cell_text']

    return str(ET.tostring(root, encoding='unicode', short_empty_elements=False))
# def cells_to_html(cells):
# for cell in cells:
# cell['column_nums'].sort()
# cell['row_nums'].sort()
# n_cols = max(cell['column_nums'][-1] for cell in cells) + 1
# n_rows = max(cell['row_nums'][-1] for cell in cells) + 1
# html_code = ''
# for r in range(n_rows):
# r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
# r_cells.sort(key=lambda x: x['column_nums'][0])
# r_html = ''
# for cell in r_cells:
# rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
# colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
# r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>'
# html_code += f'<tr>{r_html}</tr>'
# html_code = '''<html>
# <head>
# <meta charset='UTF-8'>
# <style>
# table, th, td {
# border: 1px solid black;
# font-size: 10px;
# }
# </style>
# </head>
# <body>
# <table frame='hsides' rules='groups' width='100%%'>
# %s
# </table>
# </body>
# </html>''' % html_code
# soup = bs(html_code)
# html_code = soup.prettify()
# return html_code
def main():
    """Streamlit entry point: upload an image, detect tables, show results.

    The first tab shows the page with the detected tables outlined. The
    second tab shows, for every detected table, the crop, the OCR overlay,
    the structure overlay, the cell overlay, a CSV download button and the
    table rendered as HTML.
    """
    st.set_page_config(layout='wide')
    st.title('Table Extraction Demo')
    st.write('\n')
    filename = st.file_uploader('Upload image', type=['png', 'jpeg', 'jpg'])

    if not st.button('Analyze image'):
        return
    if filename is None:
        st.write('Please upload an image')
        return

    tabs = st.tabs(
        ['Table Detection', 'Table Structure Recognition']
    )
    print(filename)
    pil_img = PIL.Image.open(filename)

    detection_result = table_detection(pil_img)
    crop_images, vis_det_img = crop_image(pil_img, detection_result)

    with tabs[0]:
        st.image(vis_det_img)

    with tabs[1]:
        # Without any crop there is nothing to lay out; a column spec built
        # from a zero table count would be rejected by Streamlit.
        if not crop_images:
            st.write('No tables were detected in this image')
            return

        str_cols = st.columns(5)
        str_cols[0].subheader('Table image')
        str_cols[1].subheader('OCR result')
        str_cols[2].subheader('Structure result')
        str_cols[3].subheader('Cells result')
        str_cols[4].subheader('CSV result')

        for i, img in enumerate(crop_images):
            ocr_result = ocr(img)
            structure_result = table_structure(img)
            table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
            cells = extract_text_from_cells(cells)
            html_result = cells_to_html(cells)

            vis_ocr_img = visualize_ocr(img, ocr_result)
            vis_str_img = visualize_structure(img, structure_result)
            vis_cells_img = visualize_cells(img, cells)

            str_cols[0].image(img)
            str_cols[1].image(vis_ocr_img)
            str_cols[2].image(vis_str_img)
            str_cols[3].image(vis_cells_img)

            if cells:
                # Only export when there is something to export; with no
                # cells there is no CSV to unpack or download.
                csv_result = cells_to_csv(cells)[1]
                str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
            else:
                str_cols[4].write('No cells found')

            st.write('\n')
            st.markdown(html_result, unsafe_allow_html=True)


if __name__ == '__main__':
    main()