Spaces:
Build error
Build error
Add extract to excel
Browse files- app.py +110 -32
- requirements.txt +1 -0
app.py
CHANGED
@@ -4,24 +4,32 @@ import cv2
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
|
|
7 |
import io
|
8 |
# import sys
|
9 |
# import json
|
10 |
from collections import OrderedDict, defaultdict
|
11 |
import xml.etree.ElementTree as ET
|
|
|
|
|
12 |
import matplotlib.pyplot as plt
|
13 |
import matplotlib.patches as patches
|
14 |
|
15 |
from paddleocr import PaddleOCR
|
16 |
-
import pytesseract
|
17 |
-
from pytesseract import Output
|
18 |
|
19 |
import postprocess
|
20 |
|
21 |
|
22 |
ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
|
|
|
|
|
23 |
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
|
|
|
|
|
24 |
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
|
|
|
25 |
imgsz = 640
|
26 |
|
27 |
detection_class_names = ['table', 'table rotated']
|
@@ -285,36 +293,36 @@ def visualize_cells(pil_img, cells):
|
|
285 |
return PIL.Image.open(img_buf)
|
286 |
|
287 |
|
288 |
-
def pytess(cell_pil_img):
|
289 |
-
|
290 |
|
291 |
|
292 |
-
def resize(pil_img, size=1800):
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
|
299 |
|
300 |
-
def image_smoothening(img):
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
|
307 |
|
308 |
-
def remove_noise_and_smooth(pil_img):
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
|
319 |
|
320 |
# def extract_text_from_cells(pil_img, cells):
|
@@ -438,6 +446,53 @@ def cells_to_html(cells):
|
|
438 |
# return html_code
|
439 |
|
440 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
def main():
|
442 |
|
443 |
st.set_page_config(layout='wide')
|
@@ -453,7 +508,7 @@ def main():
|
|
453 |
|
454 |
else:
|
455 |
tabs = st.tabs(
|
456 |
-
['Table Detection', 'Table Structure Recognition']
|
457 |
)
|
458 |
|
459 |
print(filename)
|
@@ -462,24 +517,29 @@ def main():
|
|
462 |
detection_result = table_detection(pil_img)
|
463 |
crop_images, vis_det_img = crop_image(pil_img, detection_result)
|
464 |
|
|
|
|
|
465 |
with tabs[0]:
|
|
|
466 |
st.image(vis_det_img)
|
467 |
|
468 |
with tabs[1]:
|
469 |
-
|
|
|
|
|
470 |
str_cols[0].subheader('Table image')
|
471 |
str_cols[1].subheader('OCR result')
|
472 |
str_cols[2].subheader('Structure result')
|
473 |
str_cols[3].subheader('Cells result')
|
474 |
-
str_cols[4].subheader('CSV result')
|
475 |
|
476 |
for i, img in enumerate(crop_images):
|
477 |
ocr_result = ocr(img)
|
478 |
structure_result = table_structure(img)
|
479 |
table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
|
480 |
cells = extract_text_from_cells(cells)
|
|
|
481 |
html_result = cells_to_html(cells)
|
482 |
-
df, csv_result = cells_to_csv(cells)
|
483 |
#print(df)
|
484 |
|
485 |
vis_ocr_img = visualize_ocr(img, ocr_result)
|
@@ -490,12 +550,30 @@ def main():
|
|
490 |
str_cols[1].image(vis_ocr_img)
|
491 |
str_cols[2].image(vis_str_img)
|
492 |
str_cols[3].image(vis_cells_img)
|
493 |
-
#str_cols[4].dataframe(df)
|
494 |
-
str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
|
495 |
|
496 |
st.write('\n')
|
497 |
st.markdown(html_result, unsafe_allow_html=True)
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
500 |
if __name__ == '__main__':
|
501 |
main()
|
|
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
+
import os
|
8 |
import io
|
9 |
# import sys
|
10 |
# import json
|
11 |
from collections import OrderedDict, defaultdict
|
12 |
import xml.etree.ElementTree as ET
|
13 |
+
from tempfile import TemporaryDirectory
|
14 |
+
import xlsxwriter
|
15 |
import matplotlib.pyplot as plt
|
16 |
import matplotlib.patches as patches
|
17 |
|
18 |
from paddleocr import PaddleOCR
|
19 |
+
# import pytesseract
|
20 |
+
# from pytesseract import Output
|
21 |
|
22 |
import postprocess
|
23 |
|
24 |
|
25 |
ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
|
26 |
+
|
27 |
+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
28 |
detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
|
29 |
+
|
30 |
+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
31 |
structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
|
32 |
+
|
33 |
imgsz = 640
|
34 |
|
35 |
detection_class_names = ['table', 'table rotated']
|
|
|
293 |
return PIL.Image.open(img_buf)
|
294 |
|
295 |
|
296 |
+
# def pytess(cell_pil_img):
|
297 |
+
# return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()
|
298 |
|
299 |
|
300 |
+
# def resize(pil_img, size=1800):
|
301 |
+
# length_x, width_y = pil_img.size
|
302 |
+
# factor = max(1, size / length_x)
|
303 |
+
# size = int(factor * length_x), int(factor * width_y)
|
304 |
+
# pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
|
305 |
+
# return pil_img, factor
|
306 |
|
307 |
|
308 |
+
# def image_smoothening(img):
|
309 |
+
# ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
|
310 |
+
# ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
311 |
+
# blur = cv2.GaussianBlur(th2, (1, 1), 0)
|
312 |
+
# ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
313 |
+
# return th3
|
314 |
|
315 |
|
316 |
+
# def remove_noise_and_smooth(pil_img):
|
317 |
+
# img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
|
318 |
+
# filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
|
319 |
+
# kernel = np.ones((1, 1), np.uint8)
|
320 |
+
# opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
|
321 |
+
# closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
|
322 |
+
# img = image_smoothening(img)
|
323 |
+
# or_image = cv2.bitwise_or(img, closing)
|
324 |
+
# pil_img = PIL.Image.fromarray(or_image)
|
325 |
+
# return pil_img
|
326 |
|
327 |
|
328 |
# def extract_text_from_cells(pil_img, cells):
|
|
|
446 |
# return html_code
|
447 |
|
448 |
|
449 |
+
def cells_to_excel(cells, file_path):
    """Write recognized table cells to an .xlsx workbook at ``file_path``.

    Args:
        cells: Iterable of dicts. Each dict is read for the keys
            'row_nums' and 'column_nums' (non-empty collections of
            zero-based integer indices the cell covers) and 'cell_text'
            (the value written into the cell).
        file_path: Destination path handed to ``xlsxwriter.Workbook``.

    A cell covering exactly one row and one column is written directly;
    any other cell is written as a merged range spanning from its
    smallest to its largest row/column index. All cells are centered
    horizontally and vertically on a single worksheet named 'Table'.

    Raises:
        ValueError: If a cell has an empty 'row_nums' or 'column_nums'
            (raised by ``min``/``max``).
    """
    # Row offset of the table inside the worksheet (0 = start at the top).
    table_start_row = 0

    # Row-major order; equivalent to the stable column-then-row double sort.
    ordered_cells = sorted(
        cells,
        key=lambda c: (min(c['row_nums']), min(c['column_nums'])),
    )

    # The context manager closes (and thereby saves) the workbook even when
    # one of the writes below raises.
    with xlsxwriter.Workbook(file_path) as workbook:
        cell_format = workbook.add_format(
            {
                'align': 'center',
                'valign': 'vcenter',
            }
        )
        worksheet = workbook.add_worksheet(name='Table')

        for cell in ordered_cells:
            start_row = table_start_row + min(cell['row_nums'])
            end_row = table_start_row + max(cell['row_nums'])
            start_col = min(cell['column_nums'])
            end_col = max(cell['column_nums'])

            if start_row == end_row and start_col == end_col:
                worksheet.write(
                    start_row,
                    start_col,
                    cell['cell_text'],
                    cell_format,
                )
            else:
                # Numeric (row, col) addressing avoids hand-building A1
                # references, which produced invalid column names beyond
                # two letters (column index >= 702).
                worksheet.merge_range(
                    start_row,
                    start_col,
                    end_row,
                    end_col,
                    cell['cell_text'],
                    cell_format,
                )
|
494 |
+
|
495 |
+
|
496 |
def main():
|
497 |
|
498 |
st.set_page_config(layout='wide')
|
|
|
508 |
|
509 |
else:
|
510 |
tabs = st.tabs(
|
511 |
+
['Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
|
512 |
)
|
513 |
|
514 |
print(filename)
|
|
|
517 |
detection_result = table_detection(pil_img)
|
518 |
crop_images, vis_det_img = crop_image(pil_img, detection_result)
|
519 |
|
520 |
+
all_cells = []
|
521 |
+
|
522 |
with tabs[0]:
|
523 |
+
st.header('Table Detection')
|
524 |
st.image(vis_det_img)
|
525 |
|
526 |
with tabs[1]:
|
527 |
+
st.header('Table Structure Recognition')
|
528 |
+
|
529 |
+
str_cols = st.columns((len(crop_images), ) * 4)
|
530 |
str_cols[0].subheader('Table image')
|
531 |
str_cols[1].subheader('OCR result')
|
532 |
str_cols[2].subheader('Structure result')
|
533 |
str_cols[3].subheader('Cells result')
|
|
|
534 |
|
535 |
for i, img in enumerate(crop_images):
|
536 |
ocr_result = ocr(img)
|
537 |
structure_result = table_structure(img)
|
538 |
table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
|
539 |
cells = extract_text_from_cells(cells)
|
540 |
+
all_cells.append(cells)
|
541 |
html_result = cells_to_html(cells)
|
542 |
+
#df, csv_result = cells_to_csv(cells)
|
543 |
#print(df)
|
544 |
|
545 |
vis_ocr_img = visualize_ocr(img, ocr_result)
|
|
|
550 |
str_cols[1].image(vis_ocr_img)
|
551 |
str_cols[2].image(vis_str_img)
|
552 |
str_cols[3].image(vis_cells_img)
|
|
|
|
|
553 |
|
554 |
st.write('\n')
|
555 |
st.markdown(html_result, unsafe_allow_html=True)
|
556 |
|
557 |
+
with tabs[2]:
|
558 |
+
st.header('Extracted Table(s)')
|
559 |
+
for idx, col in enumerate(st.columns(len(all_cells))):
|
560 |
+
with col:
|
561 |
+
if len(all_cells) > 1:
|
562 |
+
st.header(f'Table {idx + 1}')
|
563 |
+
|
564 |
+
with TemporaryDirectory() as temp_dir_path:
|
565 |
+
df = None
|
566 |
+
xlsx_path = os.path.join(temp_dir_path, f'debug_{idx}.xlsx')
|
567 |
+
cells_to_excel(all_cells[idx], xlsx_path)
|
568 |
+
with open(xlsx_path, 'rb') as ref:
|
569 |
+
df = pd.read_excel(ref)
|
570 |
+
st.dataframe(df)
|
571 |
+
st.download_button(
|
572 |
+
'Download Excel File',
|
573 |
+
ref,
|
574 |
+
file_name=f'output_{idx}.xlsx',
|
575 |
+
)
|
576 |
+
|
577 |
|
578 |
if __name__ == '__main__':
|
579 |
main()
|
requirements.txt
CHANGED
@@ -76,3 +76,4 @@ setuptools>=65.5.1 # Snyk vulnerability fix
|
|
76 |
# Other
|
77 |
pytesseract==0.3.10
|
78 |
# beautifulsoup4==4.11.1
|
|
|
|
76 |
# Other
|
77 |
pytesseract==0.3.10
|
78 |
# beautifulsoup4==4.11.1
|
79 |
+
xlsxwriter
|