Spaces:
Build error
Build error
import streamlit as st | |
import PIL | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import torch | |
import io | |
# import sys | |
# import json | |
from collections import OrderedDict, defaultdict | |
import xml.etree.ElementTree as ET | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
from paddleocr import PaddleOCR | |
import pytesseract | |
from pytesseract import Output | |
import postprocess | |
# Shared model instances used by the helpers below.
# NOTE: Streamlit re-executes this script from the top on every user
# interaction, so these loaders run again on each rerun. force_reload=True
# made torch.hub re-download the ultralytics/yolov5 repository every time;
# force_reload=False reuses the local torch.hub cache (delete the cache
# manually if it ever becomes stale).
ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
# Custom YOLOv5 checkpoints: one locates tables on a page, the other
# recognises table structure inside a cropped table.
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=False)
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=False)
# Inference resolution (pixels) passed to both YOLOv5 models.
imgsz = 640

# Label names, indexed by the class id emitted by the detection model.
detection_class_names = ['table', 'table rotated']

# Label names, indexed by the class id emitted by the structure model.
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]

# Reverse lookup: structure label name -> class id.
structure_class_map = dict(zip(structure_class_names, range(len(structure_class_names))))

# Minimum confidence per structure label, given in the same order as
# structure_class_names. 'no object' is set to 10, presumably so that
# detections of that label are never accepted.
structure_class_thresholds = dict(zip(
    structure_class_names,
    (0.42, 0.56, 0.5, 0.38, 0.27, 0.4, 10),
))
def PIL_to_cv(pil_img):
    """Convert a PIL image into an OpenCV-style array in BGR channel order.

    The pixels are read with ``np.array`` and their channels are reordered
    from RGB to BGR by ``cv2.cvtColor``, which returns a new array.
    """
    rgb_pixels = np.array(pil_img)
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return bgr_pixels
def cv_to_PIL(cv_img):
    """Convert an OpenCV-style BGR array back into a PIL image.

    Inverse of PIL_to_cv: the channels are reordered BGR -> RGB before the
    array is wrapped with ``PIL.Image.fromarray``.
    """
    rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(rgb_pixels)
def table_detection(pil_img):
    """Run the table-detection model on a page image.

    Args:
        pil_img: the page as a PIL image.

    Returns:
        numpy array with one row per detection, read by the callers as
        normalised (x_center, y_center, width, height, confidence, class_id).
    """
    # YOLOv5's AutoShape wrapper treats numpy input as HWC RGB; converting to
    # BGR first (as PIL_to_cv does) would feed the model swapped channels.
    # convert('RGB') also copes with palette/grayscale/RGBA uploads.
    rgb_pixels = np.asarray(pil_img.convert('RGB'))
    pred = detection_model(rgb_pixels, size=imgsz)
    return pred.xywhn[0].cpu().numpy()
def table_structure(pil_img):
    """Run the table-structure model on a cropped table image.

    Args:
        pil_img: a single table as a PIL image.

    Returns:
        numpy array with one row per detected element, read by the callers as
        normalised (x_center, y_center, width, height, confidence, class_id),
        where class_id indexes structure_class_names.
    """
    # YOLOv5's AutoShape wrapper treats numpy input as HWC RGB; converting to
    # BGR first (as PIL_to_cv does) would feed the model swapped channels.
    rgb_pixels = np.asarray(pil_img.convert('RGB'))
    pred = structure_model(rgb_pixels, size=imgsz)
    return pred.xywhn[0].cpu().numpy()
def crop_image(pil_img, detection_result, padding=30):
    """Cut every detected table out of a page image.

    Args:
        pil_img: the page as a PIL image.
        detection_result: rows of normalised (x_center, y_center, width,
            height, confidence, class_id), as returned by table_detection.
        padding: extra pixels added on every side of each box before
            cropping; the padded box is clamped to the image bounds.

    Returns:
        (crops, preview): ``crops`` is a list of PIL images, one per row of
        ``detection_result``, rotated by 270 degrees when the class id is 1
        ('table rotated'); ``preview`` is the page with every padded crop box
        outlined in red.
    """
    image = PIL_to_cv(pil_img)
    height, width = image.shape[:2]
    # Draw the boxes on a separate copy: drawing on `image` itself would leak
    # the red borders into later crops whose padded regions overlap.
    preview = image.copy()
    crops = []
    for row in detection_result:
        x_center, y_center, box_w, box_h = row[0], row[1], row[2], row[3]
        class_id = int(row[5])
        x1 = max(0, int((x_center - box_w / 2) * width) - padding)
        y1 = max(0, int((y_center - box_h / 2) * height) - padding)
        x2 = min(width, int((x_center + box_w / 2) * width) + padding)
        y2 = min(height, int((y_center + box_h / 2) * height) + padding)
        table_img = cv_to_PIL(image[y1:y2, x1:x2, :])
        if class_id == 1:  # 'table rotated'
            table_img = table_img.rotate(270, expand=True)
        crops.append(table_img)
        cv2.rectangle(preview, (x1, y1), (x2, y2), color=(0, 0, 255))
    return crops, cv_to_PIL(preview)
def ocr(pil_img):
    """Run PaddleOCR on an image and return one entry per recognised text line.

    Args:
        pil_img: image to read, as a PIL image.

    Returns:
        list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, where bbox is
        the axis-aligned hull of the polygon PaddleOCR reports. An empty list
        is returned when nothing is recognised.
    """
    image = PIL_to_cv(pil_img)
    result = ocr_instance.ocr(image)
    # result[0] holds the detections for the single input image. PaddleOCR
    # may return None there when it finds no text, which would make the loop
    # below raise TypeError.
    lines = result[0] if result else None
    if not lines:
        return []
    words = []
    for polygon, (text, _score) in lines:
        xs = [point[0] for point in polygon]
        ys = [point[1] for point in polygon]
        words.append({'bbox': [min(xs), min(ys), max(xs), max(ys)], 'text': text})
    return words
def convert_stucture(page_tokens, pil_img, structure_result):
    """Turn raw structure-model detections into table cells.

    Args:
        page_tokens: OCR words (``{'bbox', 'text'}`` dicts) for ``pil_img``.
        pil_img: the cropped table image the detections refer to.
        structure_result: rows of normalised (x_center, y_center, width,
            height, confidence, class_id), as returned by table_structure.

    Returns:
        The ``(table_structures, cells, confidence_score)`` triple produced by
        ``postprocess.objects_to_cells``.
    """
    width, height = pil_img.size

    # Convert every detection to pixel [x1, y1, x2, y2] plus score and label.
    table_objects = []
    for row in structure_result:
        x_center, y_center, box_w, box_h = row[0], row[1], row[2], row[3]
        table_objects.append({
            'bbox': [
                int((x_center - box_w / 2) * width),
                int((y_center - box_h / 2) * height),
                int((x_center + box_w / 2) * width),
                int((y_center + box_h / 2) * height),
            ],
            'score': float(row[4]),
            'label': int(row[5]),
        })
    table = {'objects': table_objects, 'page_num': 0}

    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        # The most confident 'table' box decides which tokens belong to it.
        table_bbox = list(max(table_candidates, key=lambda obj: obj['score'])['bbox'])
    else:
        # No 'table' box was found. The input is already a table crop, so use
        # the whole image rather than a fixed 1000x1000 region that would drop
        # tokens in larger crops.
        table_bbox = [0, 0, width, height]

    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]
    return postprocess.objects_to_cells(
        table, table_objects, tokens_in_table,
        structure_class_names, structure_class_thresholds,
    )
def visualize_ocr(pil_img, ocr_result):
    """Draw OCR boxes and their text onto a copy of an image.

    Each ``bbox`` is outlined with OpenCV colour (0, 255, 0) and its ``text``
    is written with colour (255, 0, 0) at the box's top-left corner, on the
    BGR array produced by PIL_to_cv.

    Returns:
        The annotated image as a PIL image.
    """
    canvas = PIL_to_cv(pil_img)
    for word in ocr_result:
        left, top, right, bottom = (int(word['bbox'][k]) for k in range(4))
        cv2.rectangle(canvas, (left, top), (right, bottom), color=(0, 255, 0))
        cv2.putText(
            canvas, word['text'], (left, top), cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.25, color=(255, 0, 0),
        )
    return cv_to_PIL(canvas)
def get_bbox_decorations(data_type, label):
    """Return matplotlib styling for a box of the given class label.

    Args:
        data_type: 'detection' selects the hatched style for label 0; any
            other value selects the plain style for label 0.
        label: integer class id.

    Returns:
        (color, alpha, linewidth, hatch). Labels outside 0-5 get
        ('gray', 0, 0, None).
    """
    if label == 0:
        if data_type == 'detection':
            return 'brown', 0.05, 3, '//'
        return 'brown', 0, 3, None
    styles = {
        1: ('red', 0.15, 2, None),
        2: ('blue', 0.15, 2, None),
        3: ('magenta', 0.2, 3, '//'),
        4: ('cyan', 0.2, 4, '//'),
        5: ('green', 0.2, 4, '\\\\'),
    }
    return styles.get(label, ('gray', 0, 0, None))
def visualize_structure(pil_img, structure_result):
    """Render structure-model detections over an image with matplotlib.

    Only detections whose confidence reaches the threshold for their label in
    ``structure_class_thresholds`` are drawn. Each one gets a translucent
    fill, a hatched overlay and a dashed edge, styled by get_bbox_decorations.

    Args:
        pil_img: the table image the detections refer to.
        structure_result: rows of normalised (x_center, y_center, width,
            height, confidence, class_id), as returned by table_structure.

    Returns:
        PIL image of the rendered figure (axes hidden, saved at 1000 dpi).
    """
    width, height = pil_img.size
    fig, ax = plt.subplots(1)
    ax.imshow(pil_img, interpolation='lanczos')
    for row in structure_result:
        x_center, y_center, box_w, box_h = row[0], row[1], row[2], row[3]
        score = float(row[4])
        class_id = int(row[5])
        if score >= structure_class_thresholds[structure_class_names[class_id]]:
            x1 = int((x_center - box_w / 2) * width)
            y1 = int((y_center - box_h / 2) * height)
            x2 = int((x_center + box_w / 2) * width)
            y2 = int((y_center + box_h / 2) * height)
            origin, rect_w, rect_h = (x1, y1), x2 - x1, y2 - y1
            color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
            # Fill
            ax.add_patch(patches.Rectangle(origin, rect_w, rect_h,
                                           linewidth=linewidth, alpha=alpha,
                                           edgecolor='none', facecolor=color,
                                           linestyle=None))
            # Hatch
            ax.add_patch(patches.Rectangle(origin, rect_w, rect_h,
                                           linewidth=1, alpha=0.4,
                                           edgecolor=color, facecolor='none',
                                           linestyle='--', hatch=hatch))
            # Edge
            ax.add_patch(patches.Rectangle(origin, rect_w, rect_h,
                                           linewidth=linewidth,
                                           edgecolor=color, facecolor='none',
                                           linestyle='--'))
    ax.axis('off')
    img_buf = io.BytesIO()
    fig.savefig(img_buf, bbox_inches='tight', dpi=1000)
    # pyplot keeps every figure alive until it is closed explicitly; without
    # this each call leaks a figure in the long-running app process.
    plt.close(fig)
    img_buf.seek(0)
    return PIL.Image.open(img_buf)
def visualize_cells(pil_img, cells):
    """Render table cells over an image with matplotlib.

    Every cell gets a magenta translucent fill (alpha 0.3 for header cells,
    0.125 otherwise), a faint hatched overlay and a dashed magenta edge.

    Args:
        pil_img: the table image the cell boxes refer to.
        cells: dicts with a pixel ``'bbox'`` [x1, y1, x2, y2] and a
            ``'header'`` flag.

    Returns:
        PIL image of the rendered figure (axes hidden, saved at 1000 dpi).
    """
    fig, ax = plt.subplots(1)
    ax.imshow(pil_img, interpolation='lanczos')
    for cell in cells:
        bbox = cell['bbox']
        origin = bbox[:2]
        rect_w = bbox[2] - bbox[0]
        rect_h = bbox[3] - bbox[1]
        alpha = 0.3 if cell['header'] else 0.125
        # Fill
        ax.add_patch(patches.Rectangle(origin, rect_w, rect_h, linewidth=1,
                                       edgecolor='none', facecolor="magenta", alpha=alpha))
        # Hatch
        ax.add_patch(patches.Rectangle(origin, rect_w, rect_h, linewidth=1,
                                       edgecolor="magenta", facecolor='none', linestyle="--",
                                       alpha=0.08, hatch='///'))
        # Edge
        ax.add_patch(patches.Rectangle(origin, rect_w, rect_h, linewidth=1,
                                       edgecolor="magenta", facecolor='none', linestyle="--"))
    ax.axis('off')
    img_buf = io.BytesIO()
    fig.savefig(img_buf, bbox_inches='tight', dpi=1000)
    # pyplot keeps every figure alive until it is closed explicitly; without
    # this each call leaks a figure in the long-running app process.
    plt.close(fig)
    img_buf.seek(0)
    return PIL.Image.open(img_buf)
def pytess(cell_pil_img):
    """OCR a single cell image with Tesseract and return its words.

    Tesseract runs with the bundled ``tessdata`` directory, LSTM engine
    (--oem 3) and single-block segmentation (--psm 6), and a blacklist of
    stray characters.

    Returns:
        The recognised words joined by single spaces.
    """
    data = pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')
    # image_to_data also reports entries with empty text (e.g. for block and
    # line levels); skipping them avoids runs of spaces between words.
    return ' '.join(word for word in data['text'] if word.strip()).strip()
def resize(pil_img, size=1800):
    """Upscale an image so that its width is at least ``size`` pixels.

    The aspect ratio is preserved. The scale factor is never below 1, so
    images already at least ``size`` pixels wide keep their dimensions.

    Args:
        pil_img: image to scale.
        size: minimum target width in pixels.

    Returns:
        (resized_image, factor), where ``factor`` is the scale that was applied.
    """
    width, height = pil_img.size
    factor = max(1, size / width)
    new_size = (int(factor * width), int(factor * height))
    # PIL.Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # and exists in both old and new Pillow releases.
    return pil_img.resize(new_size, PIL.Image.LANCZOS), factor
def image_smoothening(img):
    """Binarise a grayscale image through a chain of thresholds.

    Steps: a fixed binary threshold at 180, an Otsu threshold of that result,
    a Gaussian blur with a (1, 1) kernel, and a final Otsu threshold.

    Returns:
        The last binary image.
    """
    _, binary = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    _, binary = cv2.threshold(binary, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # NOTE: a (1, 1) Gaussian kernel leaves the image unchanged.
    blurred = cv2.GaussianBlur(binary, (1, 1), 0)
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary
def remove_noise_and_smooth(pil_img):
    """Clean up an RGB image for OCR and return a single-channel PIL image.

    An adaptive mean threshold (41-pixel block, C=3) is opened and then
    closed morphologically with a 1x1 kernel. The result is OR-ed with
    image_smoothening applied to the grayscale image.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    adaptive = cv2.adaptiveThreshold(
        gray.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY, 41, 3,
    )
    unit_kernel = np.ones((1, 1), np.uint8)
    opened = cv2.morphologyEx(adaptive, cv2.MORPH_OPEN, unit_kernel)
    cleaned = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, unit_kernel)
    smoothed = image_smoothening(gray)
    return PIL.Image.fromarray(cv2.bitwise_or(smoothed, cleaned))
# def extract_text_from_cells(pil_img, cells): | |
# pil_img, factor = resize(pil_img) | |
# #pil_img = remove_noise_and_smooth(pil_img) | |
# #display(pil_img) | |
# for cell in cells: | |
# bbox = [x * factor for x in cell['bbox']] | |
# cell_pil_img = pil_img.crop(bbox) | |
# #cell_pil_img = remove_noise_and_smooth(cell_pil_img) | |
# #cell_pil_img = tess_prep(cell_pil_img) | |
# cell['cell text'] = pytess(cell_pil_img) | |
# return cells | |
def extract_text_from_cells(cells, sep=' '):
    """Fill each cell's ``'cell_text'`` from the OCR tokens assigned to it.

    Args:
        cells: cell dicts whose ``'spans'`` list holds the tokens matched to
            the cell; spans without a ``'text'`` key are skipped.
        sep: separator placed between consecutive span texts (no trailing
            separator is added).

    Returns:
        The same ``cells`` list, with ``'cell_text'`` set on every cell
        (the dicts are modified in place).
    """
    for cell in cells:
        cell['cell_text'] = sep.join(span['text'] for span in cell['spans'] if 'text' in span)
    return cells
def cells_to_csv(cells):
    """Lay table cells out on a grid and export it as a DataFrame and CSV text.

    Rows up to the last row touched by a header cell (``cell['header']``
    truthy) form the column names. For each column, the distinct non-empty
    header texts, top to bottom, are joined with ' | '. The remaining rows
    form the DataFrame body. A cell spanning several rows or columns has its
    text copied into every grid position it covers; uncovered positions
    hold ''.

    Args:
        cells: dicts with ``'row_nums'``, ``'column_nums'``, ``'header'`` and
            ``'cell_text'`` (see extract_text_from_cells).

    Returns:
        (df, csv_text). For an empty ``cells`` list, an empty DataFrame and
        its CSV serialisation are returned instead of ``None``, so callers
        can always unpack the result.
    """
    if not cells:
        empty = pd.DataFrame()
        return empty, empty.to_csv(index=False)

    num_rows = max(max(cell['row_nums']) for cell in cells) + 1
    num_columns = max(max(cell['column_nums']) for cell in cells) + 1
    max_header_row = max(
        (max(cell['row_nums']) for cell in cells if cell['header']),
        default=-1,
    )

    # Start from '' rather than None: a grid position that no cell covers
    # would otherwise make the header ' | '.join below raise TypeError.
    table_array = np.full((num_rows, num_columns), '', dtype=object)
    for cell in cells:
        for row_num in cell['row_nums']:
            for column_num in cell['column_nums']:
                table_array[row_num, column_num] = cell['cell_text']

    header = table_array[:max_header_row + 1, :]
    # OrderedDict.fromkeys de-duplicates texts repeated by spanning header
    # cells while keeping their top-to-bottom order.
    columns = [
        ' | '.join(text for text in OrderedDict.fromkeys(column) if text)
        for column in header.T
    ]
    df = pd.DataFrame(table_array[max_header_row + 1:, :], columns=columns)
    return df, df.to_csv(index=False)
def cells_to_html(cells):
    """Render table cells as an HTML ``<table>`` string.

    Cells are ordered by (first row, first column). A new ``<tr>`` is opened
    whenever a cell starts on a later row than the previous one. The first
    cell of that row decides whether every cell in the row is written as
    ``<th>`` (header) or ``<td>``. Cells covering several columns or rows get
    ``colspan``/``rowspan`` attributes. Cell text is escaped by ElementTree
    when serialised.

    Args:
        cells: dicts with ``'row_nums'``, ``'column_nums'``, ``'header'`` and
            ``'cell_text'``.

    Returns:
        The serialised table; empty elements are written with explicit
        closing tags.
    """
    ordered = sorted(cells, key=lambda c: (min(c['row_nums']), min(c['column_nums'])))
    table = ET.Element('table')
    row_element = None
    tag = None
    last_row = -1
    for cell in ordered:
        first_row = min(cell['row_nums'])
        if first_row > last_row:
            last_row = first_row
            tag = 'th' if cell['header'] else 'td'
            row_element = ET.SubElement(table, 'tr')
        attrib = {}
        n_cols = len(cell['column_nums'])
        if n_cols > 1:
            attrib['colspan'] = str(n_cols)
        n_rows = len(cell['row_nums'])
        if n_rows > 1:
            attrib['rowspan'] = str(n_rows)
        ET.SubElement(row_element, tag, attrib=attrib).text = cell['cell_text']
    return ET.tostring(table, encoding='unicode', short_empty_elements=False)
# def cells_to_html(cells): | |
# for cell in cells: | |
# cell['column_nums'].sort() | |
# cell['row_nums'].sort() | |
# n_cols = max(cell['column_nums'][-1] for cell in cells) + 1 | |
# n_rows = max(cell['row_nums'][-1] for cell in cells) + 1 | |
# html_code = '' | |
# for r in range(n_rows): | |
# r_cells = [cell for cell in cells if cell['row_nums'][0] == r] | |
# r_cells.sort(key=lambda x: x['column_nums'][0]) | |
# r_html = '' | |
# for cell in r_cells: | |
# rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1 | |
# colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1 | |
# r_html += f'<td rowspan='{rowspan}' colspan='{colspan}'>{escape(cell['text'])}</td>' | |
# html_code += f'<tr>{r_html}</tr>' | |
# html_code = '''<html> | |
# <head> | |
# <meta charset='UTF-8'> | |
# <style> | |
# table, th, td { | |
# border: 1px solid black; | |
# font-size: 10px; | |
# } | |
# </style> | |
# </head> | |
# <body> | |
# <table frame='hsides' rules='groups' width='100%%'> | |
# %s | |
# </table> | |
# </body> | |
# </html>''' % html_code | |
# soup = bs(html_code) | |
# html_code = soup.prettify() | |
# return html_code | |
def main():
    """Streamlit entry point: upload an image, find its tables and show results.

    After 'Analyze image' is pressed, the first tab shows the page with the
    detected tables outlined. The second tab shows, for every table: the crop,
    the OCR boxes, the structure detections, the cells, a CSV download, and
    the table rendered as HTML.
    """
    st.set_page_config(layout='wide')
    st.title('Table Extraction Demo')
    st.write('\n')
    filename = st.file_uploader('Upload image', type=['png', 'jpeg', 'jpg'])
    if not st.button('Analyze image'):
        return
    if filename is None:
        st.write('Please upload an image')
        return

    tabs = st.tabs(
        ['Table Detection', 'Table Structure Recognition']
    )
    # Normalise to 3-channel RGB: palette ('P') or grayscale ('L') uploads
    # would otherwise break the RGB<->BGR conversions used downstream.
    pil_img = PIL.Image.open(filename).convert('RGB')
    detection_result = table_detection(pil_img)
    crop_images, vis_det_img = crop_image(pil_img, detection_result)

    with tabs[0]:
        st.image(vis_det_img)

    with tabs[1]:
        if not crop_images:
            st.write('No tables detected')
            return
        # Five equal-width columns. The previous weights were all
        # len(crop_images), which is 0 (rejected by st.columns) when no
        # table is found.
        str_cols = st.columns(5)
        str_cols[0].subheader('Table image')
        str_cols[1].subheader('OCR result')
        str_cols[2].subheader('Structure result')
        str_cols[3].subheader('Cells result')
        str_cols[4].subheader('CSV result')
        for i, img in enumerate(crop_images):
            ocr_result = ocr(img)
            structure_result = table_structure(img)
            _, cells, _ = convert_stucture(ocr_result, img, structure_result)
            cells = extract_text_from_cells(cells)
            html_result = cells_to_html(cells)
            csv_output = cells_to_csv(cells)
            vis_ocr_img = visualize_ocr(img, ocr_result)
            vis_str_img = visualize_structure(img, structure_result)
            vis_cells_img = visualize_cells(img, cells)
            str_cols[0].image(img)
            str_cols[1].image(vis_ocr_img)
            str_cols[2].image(vis_str_img)
            str_cols[3].image(vis_cells_img)
            # Guard against cells_to_csv returning None (no cells found)
            # instead of a (df, csv) pair.
            if csv_output is not None:
                _, csv_result = csv_output
                str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
            st.write('\n')
            st.markdown(html_result, unsafe_allow_html=True)
if __name__ == '__main__': | |
main() | |