File size: 1,232 Bytes
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch


def iou(table_box, stamp_box):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are indexable as ``[x1, y1, x2, y2]`` (first corner, then the
    opposite corner), matching the layout used throughout this module.

    Args:
        table_box: First box, ``[x1, y1, x2, y2]``.
        stamp_box: Second box, ``[x1, y1, x2, y2]``.

    Returns:
        Intersection area divided by union area. When the union area is not
        positive (e.g. both boxes have zero area), ``0.0`` is returned instead
        of dividing by zero.
    """
    # Corners of the overlapping rectangle; inverted when the boxes are disjoint.
    x1 = max(table_box[0], stamp_box[0])
    y1 = max(table_box[1], stamp_box[1])
    x2 = min(table_box[2], stamp_box[2])
    y2 = min(table_box[3], stamp_box[3])

    # Clamp to 0 so disjoint boxes contribute no intersection area.
    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    table_area = (table_box[2] - table_box[0]) * (table_box[3] - table_box[1])
    stamp_area = (stamp_box[2] - stamp_box[0]) * (stamp_box[3] - stamp_box[1])
    union = table_area + stamp_area - intersection
    # Degenerate boxes would otherwise raise ZeroDivisionError.
    if union <= 0:
        return 0.0
    return intersection / union


def remove_potiential_table_fp(stamp_detector, image, table_preds, iou_threshold=0.6):
    """Return indices of table predictions that overlap a detected stamp.

    A table whose IoU with any detected stamp is at least ``iou_threshold`` is
    flagged (presumably as a false-positive table, per the function name).

    Args:
        stamp_detector: Callable invoked with a one-element list ``[image]``;
            element ``[0]`` of its result is iterated as the stamp boxes.
        image: Image passed to ``stamp_detector``.
        table_preds: Iterable of table boxes, each ``[x1, y1, x2, y2]``.
        iou_threshold: Minimum IoU (inclusive) for a table to be flagged.

    Returns:
        List of distinct indices into ``table_preds``, in the order they were
        first matched. Each index appears at most once, even when several
        stamps overlap the same table.
    """
    # The detector is called on a single-image batch; take that image's result.
    stamps = stamp_detector([image])[0]
    remove_idc = []
    flagged = set()
    for stamp in stamps:
        for i, table in enumerate(table_preds):
            # Already marked: skip to avoid duplicate indices and redundant IoU work.
            if i in flagged:
                continue
            if iou(table, stamp) >= iou_threshold:
                flagged.add(i)
                remove_idc.append(i)
    return remove_idc


def torch_delete_by_idc(tensor, indices):
    """Return ``tensor`` with the entries at ``indices`` along dim 0 removed.

    Args:
        tensor: Tensor with at least one dimension; entries are removed along
            dim 0. Works for 1-D as well as higher-rank tensors.
        indices: Sequence of dim-0 indices to drop. Duplicates and negative
            indices are accepted (standard tensor indexing semantics).

    Returns:
        The same ``tensor`` object when ``indices`` is empty; otherwise a new
        tensor containing only the kept entries, on the same device.
    """
    if len(indices) == 0:
        return tensor
    # Allocate the keep-mask on the tensor's device so indexing never mixes devices.
    keep = torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device)
    keep[indices] = False
    # ``tensor[keep]`` selects along dim 0 for any rank >= 1; ``tensor[keep, :]``
    # would fail on 1-D input.
    return tensor[keep]


def remove_box_by_idc(boxes, indices):
    """Drop the dim-0 entries at ``indices`` from ``boxes.data`` in place.

    Args:
        boxes: Object exposing a reassignable tensor attribute ``data``
            (presumably one row per box — verify against callers).
        indices: Indices to remove; forwarded to ``torch_delete_by_idc``.

    Returns:
        The same ``boxes`` object, whose ``data`` attribute now holds the
        filtered tensor (unchanged if ``indices`` is empty).
    """
    # Direct attribute access; getattr/setattr with a constant name add nothing.
    boxes.data = torch_delete_by_idc(boxes.data, indices)
    return boxes