Commit b310dda by bachpc
Parent: 6353415

Add extract to excel

Files changed (2):
  1. app.py +110 -32
  2. requirements.txt +1 -0
app.py CHANGED
@@ -4,24 +4,32 @@ import cv2
 import numpy as np
 import pandas as pd
 import torch
+import os
 import io
 # import sys
 # import json
 from collections import OrderedDict, defaultdict
 import xml.etree.ElementTree as ET
+from tempfile import TemporaryDirectory
+import xlsxwriter
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
 
 from paddleocr import PaddleOCR
-import pytesseract
-from pytesseract import Output
+# import pytesseract
+# from pytesseract import Output
 
 import postprocess
 
 
 ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
+
+torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
+
+torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
+
 imgsz = 640
 
 detection_class_names = ['table', 'table rotated']
@@ -285,36 +293,36 @@ def visualize_cells(pil_img, cells):
     return PIL.Image.open(img_buf)
 
 
-def pytess(cell_pil_img):
-    return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()
+# def pytess(cell_pil_img):
+#     return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()
 
 
-def resize(pil_img, size=1800):
-    length_x, width_y = pil_img.size
-    factor = max(1, size / length_x)
-    size = int(factor * length_x), int(factor * width_y)
-    pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
-    return pil_img, factor
+# def resize(pil_img, size=1800):
+#     length_x, width_y = pil_img.size
+#     factor = max(1, size / length_x)
+#     size = int(factor * length_x), int(factor * width_y)
+#     pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
+#     return pil_img, factor
 
 
-def image_smoothening(img):
-    ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
-    ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-    blur = cv2.GaussianBlur(th2, (1, 1), 0)
-    ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-    return th3
+# def image_smoothening(img):
+#     ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
+#     ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+#     blur = cv2.GaussianBlur(th2, (1, 1), 0)
+#     ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+#     return th3
 
 
-def remove_noise_and_smooth(pil_img):
-    img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
-    filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
-    kernel = np.ones((1, 1), np.uint8)
-    opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
-    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
-    img = image_smoothening(img)
-    or_image = cv2.bitwise_or(img, closing)
-    pil_img = PIL.Image.fromarray(or_image)
-    return pil_img
+# def remove_noise_and_smooth(pil_img):
+#     img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
+#     filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
+#     kernel = np.ones((1, 1), np.uint8)
+#     opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
+#     closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
+#     img = image_smoothening(img)
+#     or_image = cv2.bitwise_or(img, closing)
+#     pil_img = PIL.Image.fromarray(or_image)
+#     return pil_img
 
 
 # def extract_text_from_cells(pil_img, cells):
@@ -438,6 +446,53 @@ def cells_to_html(cells):
 # return html_code
 
 
+def cells_to_excel(cells, file_path):
+
+    def int2xlsx(i):
+        if i < 26:
+            return chr(i + 65)
+        return f'{chr(i // 26 + 64)}{chr(i % 26 + 65)}'
+
+    cells = sorted(cells, key=lambda k: min(k['column_nums']))
+    cells = sorted(cells, key=lambda k: min(k['row_nums']))
+
+    workbook = xlsxwriter.Workbook(file_path)
+
+    cell_format = workbook.add_format(
+        {
+            'align': 'center',
+            'valign': 'vcenter',
+        }
+    )
+
+    worksheet = workbook.add_worksheet(name='Table')
+
+    table_start_index = 0
+
+    for cell in cells:
+        start_row = min(cell['row_nums'])
+        end_row = max(cell['row_nums'])
+        start_col = min(cell['column_nums'])
+        end_col = max(cell['column_nums'])
+        if start_row == end_row and start_col == end_col:
+            worksheet.write(
+                table_start_index + start_row,
+                start_col,
+                cell['cell_text'],
+                cell_format,
+            )
+        else:
+            if start_col == end_col and start_row == end_row:
+                excel_index = f'{int2xlsx(table_start_index + start_col)}{table_start_index + start_row + 1}'
+            else:
+                excel_index = f'{int2xlsx(table_start_index + start_col)}{table_start_index + start_row + 1}:{int2xlsx(table_start_index + end_col)}{table_start_index + end_row + 1}'
+            worksheet.merge_range(
+                excel_index, cell['cell_text'], cell_format
+            )
+
+    workbook.close()
+
+
 def main():
 
     st.set_page_config(layout='wide')
@@ -453,7 +508,7 @@ def main():
 
     else:
        tabs = st.tabs(
-            ['Table Detection', 'Table Structure Recognition']
+            ['Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
 
        print(filename)
@@ -462,24 +517,29 @@ def main():
        detection_result = table_detection(pil_img)
        crop_images, vis_det_img = crop_image(pil_img, detection_result)
 
+        all_cells = []
+
        with tabs[0]:
+            st.header('Table Detection')
            st.image(vis_det_img)
 
        with tabs[1]:
-            str_cols = st.columns((len(crop_images), ) * 5)
+            st.header('Table Structure Recognition')
+
+            str_cols = st.columns((len(crop_images), ) * 4)
            str_cols[0].subheader('Table image')
            str_cols[1].subheader('OCR result')
            str_cols[2].subheader('Structure result')
            str_cols[3].subheader('Cells result')
-            str_cols[4].subheader('CSV result')
 
            for i, img in enumerate(crop_images):
                ocr_result = ocr(img)
                structure_result = table_structure(img)
                table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
                cells = extract_text_from_cells(cells)
+                all_cells.append(cells)
                html_result = cells_to_html(cells)
-                df, csv_result = cells_to_csv(cells)
+                #df, csv_result = cells_to_csv(cells)
                #print(df)
 
                vis_ocr_img = visualize_ocr(img, ocr_result)
@@ -490,12 +550,30 @@ def main():
                str_cols[1].image(vis_ocr_img)
                str_cols[2].image(vis_str_img)
                str_cols[3].image(vis_cells_img)
-                #str_cols[4].dataframe(df)
-                str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
 
                st.write('\n')
                st.markdown(html_result, unsafe_allow_html=True)
 
+        with tabs[2]:
+            st.header('Extracted Table(s)')
+            for idx, col in enumerate(st.columns(len(all_cells))):
+                with col:
+                    if len(all_cells) > 1:
+                        st.header(f'Table {idx + 1}')
+
+                    with TemporaryDirectory() as temp_dir_path:
+                        df = None
+                        xlsx_path = os.path.join(temp_dir_path, f'debug_{idx}.xlsx')
+                        cells_to_excel(all_cells[idx], xlsx_path)
+                        with open(xlsx_path, 'rb') as ref:
+                            df = pd.read_excel(ref)
+                            st.dataframe(df)
+                            st.download_button(
+                                'Download Excel File',
+                                ref,
+                                file_name=f'output_{idx}.xlsx',
+                            )
+
 
 if __name__ == '__main__':
     main()
requirements.txt CHANGED
@@ -76,3 +76,4 @@ setuptools>=65.5.1 # Snyk vulnerability fix
76
  # Other
77
  pytesseract==0.3.10
78
  # beautifulsoup4==4.11.1
 
 
76
  # Other
77
  pytesseract==0.3.10
78
  # beautifulsoup4==4.11.1
79
+ xlsxwriter
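Likewise, a minimal sketch of the display-and-download flow used in the new "Extracted Table(s)" tab, not the app's exact code: write the workbook into a temporary directory, read it back with pandas for `st.dataframe`, and hand the raw bytes to `st.download_button`. It assumes only streamlit, pandas (with an Excel reader engine such as openpyxl, which this diff does not add), and xlsxwriter; the stand-in workbook contents and file names are hypothetical.

```python
import os
from tempfile import TemporaryDirectory

import pandas as pd
import streamlit as st
import xlsxwriter

with TemporaryDirectory() as temp_dir_path:
    xlsx_path = os.path.join(temp_dir_path, 'table_0.xlsx')  # hypothetical name

    # Stand-in for cells_to_excel(): a one-cell workbook keeps the sketch self-contained.
    workbook = xlsxwriter.Workbook(xlsx_path)
    workbook.add_worksheet(name='Table').write(0, 0, 'example')
    workbook.close()

    # Preview the extracted table in the app.
    df = pd.read_excel(xlsx_path)
    st.dataframe(df)

    # Read the bytes up front so the download outlives the temporary directory.
    with open(xlsx_path, 'rb') as ref:
        st.download_button('Download Excel File', ref.read(), file_name='output_0.xlsx')
```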