bachpc commited on
Commit
c59b7a8
1 Parent(s): 3479c98

Update postprocess.py

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. postprocess.py +25 -23
app.py CHANGED
@@ -141,7 +141,7 @@ def crop_image(pil_img, detection_result):
141
  fontScale = lw / 3
142
  thickness = max(lw - 1, 1)
143
  w_label, h_label = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=thickness)[0]
144
- cv2.rectangle(image, (x1, y1), (x1 + w_label, y1 - h_label - 3), (255, 0, 0), -1, cv2.LINE_AA)
145
  cv2.putText(image, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, fontScale, (255, 255, 255), thickness=thickness, lineType=cv2.LINE_AA)
146
 
147
  return crop_images, cv_to_PIL(image)
 
141
  fontScale = lw / 3
142
  thickness = max(lw - 1, 1)
143
  w_label, h_label = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=thickness)[0]
144
+ cv2.rectangle(image, (x1, y1), (x1 + w_label, y1 - h_label - 3), (0, 0, 255), -1, cv2.LINE_AA)
145
  cv2.putText(image, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, fontScale, (255, 255, 255), thickness=thickness, lineType=cv2.LINE_AA)
146
 
147
  return crop_images, cv_to_PIL(image)
postprocess.py CHANGED
@@ -16,7 +16,7 @@ def apply_threshold(objects, threshold):
16
  def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
17
  """
18
  Filter out bounding boxes whose confidence is below the confidence threshold for
19
- its associated class label.
20
  """
21
  # Apply class-specific thresholds
22
  indices_above_threshold = [idx for idx, (score, label) in enumerate(zip(scores, labels))
@@ -37,11 +37,11 @@ def iou(bbox1, bbox2):
37
  """
38
  intersection = Rect(bbox1).intersect(bbox2)
39
  union = Rect(bbox1).include_rect(bbox2)
40
-
41
  union_area = union.get_area()
42
  if union_area > 0:
43
  return intersection.get_area() / union.get_area()
44
-
45
  return 0
46
 
47
 
@@ -50,11 +50,11 @@ def iob(bbox1, bbox2):
50
  Compute the intersection area over box area, for bbox1.
51
  """
52
  intersection = Rect(bbox1).intersect(bbox2)
53
-
54
  bbox1_area = Rect(bbox1).get_area()
55
  if bbox1_area > 0:
56
  return intersection.get_area() / bbox1_area
57
-
58
  return 0
59
 
60
 
@@ -123,7 +123,7 @@ def objects_to_table_structures(table_object, objects_in_table, tokens_in_table,
123
  row_rect = Rect()
124
  for obj in rows:
125
  row_rect.include_rect(obj['bbox'])
126
- column_rect = Rect()
127
  for obj in columns:
128
  column_rect.include_rect(obj['bbox'])
129
  table_object['row_column_bbox'] = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]]
@@ -189,7 +189,7 @@ def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5
189
  suppression = [False for obj in container_objects]
190
 
191
  packages_by_container, _, _ = slot_into_containers(container_objects, package_objects, overlap_threshold=overlap_threshold,
192
- unique_assignment=True, forced_assignment=False)
193
 
194
  for object2_num in range(1, num_objects):
195
  object2_packages = set(packages_by_container[object2_num])
@@ -198,7 +198,9 @@ def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5
198
  for object1_num in range(object2_num):
199
  if not suppression[object1_num]:
200
  object1_packages = set(packages_by_container[object1_num])
201
- if len(object2_packages.intersection(object1_packages)) > 0:
 
 
202
  suppression[object2_num] = True
203
 
204
  final_objects = [obj for idx, obj in enumerate(container_objects) if not suppression[idx]]
@@ -222,7 +224,7 @@ def slot_into_containers(container_objects, package_objects, overlap_threshold=0
222
  for package_num, package in enumerate(package_objects):
223
  match_scores = []
224
  package_rect = Rect(package['bbox'])
225
- package_area = package_rect.get_area()
226
  for container_num, container in enumerate(container_objects):
227
  container_rect = Rect(container['bbox'])
228
  intersect_area = container_rect.intersect(package['bbox']).get_area()
@@ -244,7 +246,7 @@ def slot_into_containers(container_objects, package_objects, overlap_threshold=0
244
  package_assignments[package_num].append(match_score['container_num'])
245
  else:
246
  break
247
-
248
  return container_assignments, package_assignments, best_match_scores
249
 
250
 
@@ -268,8 +270,8 @@ def remove_objects_without_content(page_spans, objects):
268
  object_text, _ = extract_text_inside_bbox(page_spans, obj['bbox'])
269
  if len(object_text.strip()) == 0:
270
  objects.remove(obj)
271
-
272
-
273
  def extract_text_inside_bbox(spans, bbox):
274
  """
275
  Extract the text inside a bounding box.
@@ -314,7 +316,7 @@ def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscr
314
  else:
315
  join_char = ""
316
  spans_copy = spans[:]
317
-
318
  if remove_integer_superscripts:
319
  for span in spans:
320
  if not 'flags' in span:
@@ -328,11 +330,11 @@ def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscr
328
 
329
  if len(spans_copy) == 0:
330
  return ""
331
-
332
  spans_copy.sort(key=lambda span: span['span_num'])
333
  spans_copy.sort(key=lambda span: span['line_num'])
334
  spans_copy.sort(key=lambda span: span['block_num'])
335
-
336
  # Force the span at the end of every line within a block to have exactly one space
337
  # unless the line ends with a space or ends with a non-space followed by a hyphen
338
  line_texts = []
@@ -351,7 +353,7 @@ def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscr
351
  line_span_texts.append(span2['text'])
352
  line_text = join_char.join(line_span_texts)
353
  line_texts.append(line_text)
354
-
355
  return join_char.join(line_texts).strip()
356
 
357
 
@@ -443,7 +445,7 @@ def refine_table_structures(table_bbox, table_structures, page_spans, class_thre
443
  def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_higher=True):
444
  """
445
  A customizable version of non-maxima suppression (NMS).
446
-
447
  Default behavior: If a lower-confidence object overlaps more than 5% of its area
448
  with a higher-confidence object, remove the lower-confidence object.
449
 
@@ -493,7 +495,7 @@ def align_headers(headers, rows):
493
  For now, we are not supporting tables with multiple headers, so we need to
494
  eliminate anything besides the top-most header.
495
  """
496
-
497
  aligned_headers = []
498
 
499
  for row in rows:
@@ -672,7 +674,7 @@ def header_supercell_tree(supercells):
672
  """
673
  header_supercells = [supercell for supercell in supercells if 'header' in supercell and supercell['header']]
674
  header_supercells = sort_objects_by_score(header_supercells)
675
-
676
  for header_supercell in header_supercells[:]:
677
  ancestors_by_row = defaultdict(int)
678
  min_row = min(header_supercell['row_numbers'])
@@ -687,8 +689,8 @@ def header_supercell_tree(supercells):
687
  if not ancestors_by_row[row] == 1:
688
  supercells.remove(header_supercell)
689
  break
690
-
691
-
692
  def table_structure_to_cells(table_structures, table_spans, table_bbox):
693
  """
694
  Assuming the row, column, supercell, and header bounding boxes have
@@ -787,10 +789,10 @@ def table_structure_to_cells(table_structures, table_spans, table_bbox):
787
  for cell, cell_span_nums in zip(cells, span_nums_by_cell):
788
  cell_spans = [table_spans[num] for num in cell_span_nums]
789
  # TODO: Refine how text is extracted; should be character-based, not span-based;
790
- # but need to associate
791
  # cell['cell_text'] = extract_text_from_spans(cell_spans, remove_integer_superscripts=False) # TODO
792
  cell['spans'] = cell_spans
793
-
794
  # Adjust the row, column, and cell bounding boxes to reflect the extracted text
795
  num_rows = len(rows)
796
  rows = sort_objects_top_to_bottom(rows)
 
16
  def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
17
  """
18
  Filter out bounding boxes whose confidence is below the confidence threshold for
19
+ its associated class label.
20
  """
21
  # Apply class-specific thresholds
22
  indices_above_threshold = [idx for idx, (score, label) in enumerate(zip(scores, labels))
 
37
  """
38
  intersection = Rect(bbox1).intersect(bbox2)
39
  union = Rect(bbox1).include_rect(bbox2)
40
+
41
  union_area = union.get_area()
42
  if union_area > 0:
43
  return intersection.get_area() / union.get_area()
44
+
45
  return 0
46
 
47
 
 
50
  Compute the intersection area over box area, for bbox1.
51
  """
52
  intersection = Rect(bbox1).intersect(bbox2)
53
+
54
  bbox1_area = Rect(bbox1).get_area()
55
  if bbox1_area > 0:
56
  return intersection.get_area() / bbox1_area
57
+
58
  return 0
59
 
60
 
 
123
  row_rect = Rect()
124
  for obj in rows:
125
  row_rect.include_rect(obj['bbox'])
126
+ column_rect = Rect()
127
  for obj in columns:
128
  column_rect.include_rect(obj['bbox'])
129
  table_object['row_column_bbox'] = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]]
 
189
  suppression = [False for obj in container_objects]
190
 
191
  packages_by_container, _, _ = slot_into_containers(container_objects, package_objects, overlap_threshold=overlap_threshold,
192
+ unique_assignment=False, forced_assignment=False)
193
 
194
  for object2_num in range(1, num_objects):
195
  object2_packages = set(packages_by_container[object2_num])
 
198
  for object1_num in range(object2_num):
199
  if not suppression[object1_num]:
200
  object1_packages = set(packages_by_container[object1_num])
201
+ if len(object2_packages.intersection(object1_packages)) > 0 \
202
+ and (iob(container_objects[object2_num]['bbox'], container_objects[object1_num]['bbox']) > 0.5 \
203
+ or iob(container_objects[object1_num]['bbox'], container_objects[object2_num]['bbox']) > 0.5):
204
  suppression[object2_num] = True
205
 
206
  final_objects = [obj for idx, obj in enumerate(container_objects) if not suppression[idx]]
 
224
  for package_num, package in enumerate(package_objects):
225
  match_scores = []
226
  package_rect = Rect(package['bbox'])
227
+ package_area = package_rect.get_area()
228
  for container_num, container in enumerate(container_objects):
229
  container_rect = Rect(container['bbox'])
230
  intersect_area = container_rect.intersect(package['bbox']).get_area()
 
246
  package_assignments[package_num].append(match_score['container_num'])
247
  else:
248
  break
249
+
250
  return container_assignments, package_assignments, best_match_scores
251
 
252
 
 
270
  object_text, _ = extract_text_inside_bbox(page_spans, obj['bbox'])
271
  if len(object_text.strip()) == 0:
272
  objects.remove(obj)
273
+
274
+
275
  def extract_text_inside_bbox(spans, bbox):
276
  """
277
  Extract the text inside a bounding box.
 
316
  else:
317
  join_char = ""
318
  spans_copy = spans[:]
319
+
320
  if remove_integer_superscripts:
321
  for span in spans:
322
  if not 'flags' in span:
 
330
 
331
  if len(spans_copy) == 0:
332
  return ""
333
+
334
  spans_copy.sort(key=lambda span: span['span_num'])
335
  spans_copy.sort(key=lambda span: span['line_num'])
336
  spans_copy.sort(key=lambda span: span['block_num'])
337
+
338
  # Force the span at the end of every line within a block to have exactly one space
339
  # unless the line ends with a space or ends with a non-space followed by a hyphen
340
  line_texts = []
 
353
  line_span_texts.append(span2['text'])
354
  line_text = join_char.join(line_span_texts)
355
  line_texts.append(line_text)
356
+
357
  return join_char.join(line_texts).strip()
358
 
359
 
 
445
  def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_higher=True):
446
  """
447
  A customizable version of non-maxima suppression (NMS).
448
+
449
  Default behavior: If a lower-confidence object overlaps more than 5% of its area
450
  with a higher-confidence object, remove the lower-confidence object.
451
 
 
495
  For now, we are not supporting tables with multiple headers, so we need to
496
  eliminate anything besides the top-most header.
497
  """
498
+
499
  aligned_headers = []
500
 
501
  for row in rows:
 
674
  """
675
  header_supercells = [supercell for supercell in supercells if 'header' in supercell and supercell['header']]
676
  header_supercells = sort_objects_by_score(header_supercells)
677
+
678
  for header_supercell in header_supercells[:]:
679
  ancestors_by_row = defaultdict(int)
680
  min_row = min(header_supercell['row_numbers'])
 
689
  if not ancestors_by_row[row] == 1:
690
  supercells.remove(header_supercell)
691
  break
692
+
693
+
694
  def table_structure_to_cells(table_structures, table_spans, table_bbox):
695
  """
696
  Assuming the row, column, supercell, and header bounding boxes have
 
789
  for cell, cell_span_nums in zip(cells, span_nums_by_cell):
790
  cell_spans = [table_spans[num] for num in cell_span_nums]
791
  # TODO: Refine how text is extracted; should be character-based, not span-based;
792
+ # but need to associate
793
  # cell['cell_text'] = extract_text_from_spans(cell_spans, remove_integer_superscripts=False) # TODO
794
  cell['spans'] = cell_spans
795
+
796
  # Adjust the row, column, and cell bounding boxes to reflect the extracted text
797
  num_rows = len(rows)
798
  rows = sort_objects_top_to_bottom(rows)