import torch


def iou(table_box, stamp_box):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        table_box: Box as ``[x1, y1, x2, y2]``.
        stamp_box: Box as ``[x1, y1, x2, y2]``.

    Returns:
        ``intersection / union``. Returns ``0.0`` when the union area is
        zero or negative (degenerate or inverted boxes), which would otherwise
        divide by zero.
    """
    # Corners of the overlapping rectangle (may be "inverted" if no overlap).
    x1 = max(table_box[0], stamp_box[0])
    y1 = max(table_box[1], stamp_box[1])
    x2 = min(table_box[2], stamp_box[2])
    y2 = min(table_box[3], stamp_box[3])
    intersection = max(0, x2 - x1) * max(0, y2 - y1)

    table_area = (table_box[2] - table_box[0]) * (table_box[3] - table_box[1])
    stamp_area = (stamp_box[2] - stamp_box[0]) * (stamp_box[3] - stamp_box[1])
    union = table_area + stamp_area - intersection
    if union <= 0:
        # Both boxes have no area: there is nothing to overlap.
        return 0.0
    return intersection / union


def remove_potiential_table_fp(stamp_detector, image, table_preds, iou_threshold=0.6):
    """Return indices of table predictions that overlap a detected stamp.

    Args:
        stamp_detector: Callable invoked as ``stamp_detector([image])``; the
            first element of its result is iterated as the stamp boxes.
            Presumably each stamp is ``[x1, y1, x2, y2]`` — TODO confirm
            against the detector in use.
        image: Image passed (wrapped in a one-element list) to the detector.
        table_preds: Sequence of table boxes ``[x1, y1, x2, y2]``.
        iou_threshold: A table is flagged when its IoU with any stamp is
            ``>= iou_threshold``.

    Returns:
        Unique indices into ``table_preds``, in ascending order.
    """
    stamps = stamp_detector([image])[0]
    flagged = set()
    for stamp in stamps:
        for i, table in enumerate(table_preds):
            # Skip tables already matched to an earlier stamp.
            if i not in flagged and iou(table, stamp) >= iou_threshold:
                flagged.add(i)
    return sorted(flagged)


def torch_delete_by_idc(tensor, indices):
    """Return ``tensor`` without the rows (first-dim entries) at ``indices``.

    Args:
        tensor: Tensor whose first dimension is indexed.
        indices: Sequence (or tensor) of row indices to drop.

    Returns:
        ``tensor`` itself (not a copy) when ``indices`` is empty; otherwise a
        new tensor with the selected rows removed.
    """
    if len(indices) == 0:
        return tensor
    # Build the mask on the same device as the data to be indexed.
    mask = torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device)
    mask[indices] = False
    # ``tensor[mask]`` equals ``tensor[mask, :]`` for rank >= 2 and also
    # works for 1-D tensors.
    return tensor[mask]


def remove_box_by_idc(boxes, indices):
    """Drop rows at ``indices`` from ``boxes.data`` in place and return ``boxes``.

    Args:
        boxes: Object with an assignable ``data`` tensor attribute.
        indices: Row indices to remove (see ``torch_delete_by_idc``).

    Returns:
        The same ``boxes`` object, with ``data`` replaced.
    """
    boxes.data = torch_delete_by_idc(boxes.data, indices)
    return boxes