|
import torch |
|
|
|
|
|
def iou(table_box, stamp_box):
    """Compute the intersection-over-union of two axis-aligned boxes.

    Both boxes are indexed as ``(x1, y1, x2, y2)``: elements 0/1 are the
    minimum corner and elements 2/3 the maximum corner.

    Args:
        table_box: First box (a table prediction).
        stamp_box: Second box (a stamp detection).

    Returns:
        The IoU (in ``[0, 1]`` for well-formed boxes), or ``0.0`` when the
        union area is not positive (e.g. both boxes are degenerate), instead
        of raising ``ZeroDivisionError``.
    """
    # Corners of the overlap rectangle; inverted when the boxes are disjoint.
    x1 = max(table_box[0], stamp_box[0])
    y1 = max(table_box[1], stamp_box[1])
    x2 = min(table_box[2], stamp_box[2])
    y2 = min(table_box[3], stamp_box[3])

    # Clamp negative extents to 0 so disjoint boxes have zero overlap.
    intersection = max(0, x2 - x1) * max(0, y2 - y1)

    table_area = (table_box[2] - table_box[0]) * (table_box[3] - table_box[1])
    stamp_area = (stamp_box[2] - stamp_box[0]) * (stamp_box[3] - stamp_box[1])
    union = table_area + stamp_area - intersection

    # Degenerate inputs (zero-area boxes) would otherwise divide by zero.
    if union <= 0:
        return 0.0
    return intersection / union
|
|
|
|
|
def remove_potiential_table_fp(stamp_detector, image, table_preds, iou_threshold=0.6):
    """Return indices of table predictions that overlap a detected stamp.

    The stamp detector is run on ``image``, and every table prediction whose
    IoU with any detected stamp is at least ``iou_threshold`` is flagged as a
    likely false positive.

    Args:
        stamp_detector: Callable invoked as ``stamp_detector([image])``; the
            first element of its result is iterated as the stamp boxes
            (presumably in the same ``(x1, y1, x2, y2)`` layout as
            ``table_preds`` -- confirm against the detector).
        image: Image passed to the detector.
        table_preds: Iterable of table boxes, enumerated to produce indices.
        iou_threshold: Minimum IoU for a table to be flagged.

    Returns:
        A list of table indices to remove. Each index appears at most once,
        in the order it was first flagged.
    """
    stamps = stamp_detector([image])[0]

    remove_idc = []
    flagged = set()
    for stamp in stamps:
        for i, table in enumerate(table_preds):
            if i in flagged:
                # Already flagged by an earlier stamp; avoid a duplicate entry
                # and a redundant IoU computation.
                continue
            if iou(table, stamp) >= iou_threshold:
                flagged.add(i)
                remove_idc.append(i)

    return remove_idc
|
|
|
|
|
def torch_delete_by_idc(tensor, indices):
    """Return ``tensor`` with the entries at ``indices`` removed along dim 0.

    Args:
        tensor: Tensor with at least one dimension; rows are selected along
            dim 0, so both 1-D and ``(N, ...)`` tensors are supported.
        indices: Sequence of dim-0 indices to drop. Duplicate indices are
            harmless.

    Returns:
        ``tensor`` itself (not a copy) when ``indices`` is empty, otherwise a
        new tensor holding the remaining rows in their original order.
    """
    if len(indices) == 0:
        return tensor

    # Allocate the mask on the tensor's own device so boolean indexing never
    # mixes devices (e.g. a CPU mask against CUDA data).
    mask = torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device)
    mask[indices] = False
    # Index dim 0 only; ``tensor[mask, :]`` would fail on 1-D tensors.
    return tensor[mask]
|
|
|
|
|
def remove_box_by_idc(boxes, indices):
    """Drop the rows at ``indices`` from ``boxes.data`` and return ``boxes``.

    ``boxes`` is mutated in place: its ``data`` attribute is replaced with a
    filtered tensor. It is presumably a detection-results box container
    whose ``data`` is an ``(N, ...)`` tensor -- confirm against callers.

    Args:
        boxes: Object exposing a readable and writable ``data`` attribute.
        indices: Row indices to remove from ``boxes.data``.

    Returns:
        The same ``boxes`` object, for call chaining.
    """
    # Plain attribute access instead of getattr/setattr with constant names.
    boxes.data = torch_delete_by_idc(boxes.data, indices)
    return boxes
|
|