nguyenp99's picture
Upload 17 files
45099b6 verified
raw
history blame
1.23 kB
import torch
def iou(table_box, stamp_box):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Both boxes are indexed as ``[x1, y1, x2, y2, ...]``; only the first four
    elements are read, so trailing fields (e.g. score, class) are ignored.
    The area formula assumes ``x1 <= x2`` and ``y1 <= y2``.

    Args:
        table_box: First box, ``[x1, y1, x2, y2, ...]``.
        stamp_box: Second box, ``[x1, y1, x2, y2, ...]``.

    Returns:
        ``intersection / union``, or ``0.0`` when the union area is not
        positive (e.g. both boxes are degenerate) instead of dividing by zero.
    """
    # Corners of the overlapping rectangle (may be "inverted" when disjoint).
    x1 = max(table_box[0], stamp_box[0])
    y1 = max(table_box[1], stamp_box[1])
    x2 = min(table_box[2], stamp_box[2])
    y2 = min(table_box[3], stamp_box[3])
    # Clamp to 0 so disjoint boxes contribute no intersection area.
    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    table_area = (table_box[2] - table_box[0]) * (table_box[3] - table_box[1])
    stamp_area = (stamp_box[2] - stamp_box[0]) * (stamp_box[3] - stamp_box[1])
    union = table_area + stamp_area - intersection
    if union <= 0:
        # Degenerate boxes: no meaningful overlap, avoid ZeroDivisionError / NaN.
        return 0.0
    return intersection / union
def remove_potiential_table_fp(stamp_detector, image, table_preds, iou_threshold=0.6):
    """Return indices of table predictions that overlap a detected stamp.

    A table box whose IoU with any stamp box is at least ``iou_threshold`` is
    flagged as a potential false positive.

    Args:
        stamp_detector: Callable taking a list of images; only the first
            element of its result is used, and that element is iterated to
            obtain stamp boxes (each indexed as ``[x1, y1, x2, y2, ...]``).
        image: The image passed (wrapped in a one-element list) to the detector.
        table_preds: Iterable of table boxes ``[x1, y1, x2, y2, ...]``.
        iou_threshold: Inclusive IoU threshold for flagging a table.

    Returns:
        Ascending list of unique indices into ``table_preds``. Each table
        appears at most once, however many stamps overlap it.
    """
    # Materialize once: the stamps are re-scanned for every table below.
    stamps = list(stamp_detector([image])[0])
    remove_idc = []
    for i, table in enumerate(table_preds):
        if any(iou(table, stamp) >= iou_threshold for stamp in stamps):
            remove_idc.append(i)
    return remove_idc
def torch_delete_by_idc(tensor, indices):
    """Return ``tensor`` with the entries at ``indices`` along dim 0 removed.

    Args:
        tensor: Tensor of rank >= 1; entries are selected along dimension 0.
        indices: Sequence (or 1-D tensor) of dim-0 indices to drop. Duplicate
            indices are harmless, and negative indices count from the end.

    Returns:
        The same ``tensor`` object when ``indices`` is empty; otherwise a new
        tensor containing only the rows that were not dropped.
    """
    if len(indices) == 0:
        # Nothing to drop: skip allocating a mask and copying the tensor.
        return tensor
    # Allocate the mask on the tensor's device so GPU inputs need no transfer.
    mask = torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device)
    mask[indices] = False
    # ``tensor[mask]`` equals ``tensor[mask, :]`` for rank >= 2 and also works for 1-D.
    return tensor[mask]
def remove_box_by_idc(boxes, indices):
    """Drop the boxes at ``indices`` from ``boxes.data`` in place and return ``boxes``.

    Args:
        boxes: Object exposing a tensor ``data`` attribute whose dim 0 indexes
            boxes (presumably an ultralytics ``Boxes`` result — confirm
            against callers).
        indices: Dim-0 indices of the boxes to remove.

    Returns:
        The same ``boxes`` object, after ``boxes.data`` has been replaced.
    """
    # Plain attribute access: getattr/setattr with a constant name adds nothing.
    boxes.data = torch_delete_by_idc(boxes.data, indices)
    return boxes