File size: 3,366 Bytes
6d1366a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import os.path as osp
import cv2
import torch
import numpy as np

from isegm.inference.clicker import Clicker
from isegm.inference import utils


def inference(image, gt_mask, predictor, threshold=0.5, min_clicks=1, max_clicks=20):
    """Simulate ``max_clicks`` rounds of interactive segmentation on one image.

    Each round asks the ``Clicker`` for the next click given the previous
    binary prediction. It then queries ``predictor`` with the clicker and
    thresholds the returned probability map into a new binary mask.

    Args:
        image: input image handed to ``predictor.set_input_image``.
        gt_mask: ground-truth mask used to build the ``Clicker`` and passed
            to ``utils.get_iou``.
        predictor: object exposing ``set_input_image`` and ``get_prediction``.
        threshold: probabilities strictly greater than this value are
            treated as foreground.
        min_clicks: accepted for signature compatibility; currently unused.
            Every call runs exactly ``max_clicks`` rounds.
        max_clicks: number of clicks to simulate.

    Returns:
        tuple: ``(clicks, ious, probs, masks)``.
            - ``clicks``: the clicker's ``clicks_list``.
            - ``ious``: float32 array of the per-round ``utils.get_iou`` values.
            - ``probs``: per-round copies of the probability map.
            - ``masks``: per-round copies of the thresholded mask.
    """
    clicker = Clicker(gt_mask=gt_mask)
    iou_history, probs_history, masks_history = [], [], []
    # The first click is chosen against an all-background prediction.
    current_mask = np.zeros_like(gt_mask)

    with torch.no_grad():
        predictor.set_input_image(image)

        for _ in range(max_clicks):
            clicker.make_next_click(current_mask)
            probs = predictor.get_prediction(clicker)
            current_mask = probs > threshold

            iou_history.append(utils.get_iou(gt_mask, current_mask))
            # Copies guard against the predictor reusing its output buffers.
            probs_history.append(probs.copy())
            masks_history.append(current_mask.copy())

        ious = np.array(iou_history, dtype=np.float32)
        return clicker.clicks_list, ious, probs_history, masks_history


# BGR colours used by the drawing helpers below.
_CONTOUR_COLOR = (106, 211, 253)
_POSITIVE_CLICK_COLOR = (80, 208, 146)
_NEGATIVE_CLICK_COLOR = (0, 0, 192)


def _dim_background(image, fg_mask, factor=0.4):
    """Darken, in place, every pixel of ``image`` where ``fg_mask`` is False."""
    image[~fg_mask] = (image[~fg_mask] * factor).astype(np.uint8)


def _find_contours(binary_mask):
    """Return all contours of ``binary_mask`` found by ``cv2.findContours``."""
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); indexing with [-2] works for both.
    return cv2.findContours(binary_mask.astype(np.uint8),
                            cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]


def _draw_clicks(image, clicks):
    """Draw each click as a filled dot with a black outline; return the image."""
    for click in clicks:
        color = _POSITIVE_CLICK_COLOR if click.is_positive else _NEGATIVE_CLICK_COLOR
        # Swap coords into OpenCV's (x, y) centre order (coords are presumably
        # (row, col)). cv2.circle needs plain Python ints, not NumPy scalars.
        center = (int(click.coords[1]), int(click.coords[0]))
        image = cv2.circle(image, center, 7, (0, 0, 0), -1)
        image = cv2.circle(image, center, 5, color, -1)
    return image


def visualization(sample, predictor, mask=True, score=False, contour=True, click=True,
                  threshold=0.5, min_clicks=1, max_clicks=20, out_dir=None,
                  min_contour_area=10):
    """Run ``inference`` on ``sample`` and render one image per simulated click.

    Args:
        sample: dataset sample exposing ``image`` (RGB array), ``gt_mask(obj_id)``,
            ``objects_ids`` and ``sample_id``. Only the first object is visualised.
        predictor: predictor forwarded to ``inference``.
        mask: darken pixels outside the (ground-truth / predicted) mask.
            Forced to False when ``score`` is True.
        score: blend a JET colour map of the predicted probabilities 50/50
            with the image.
        contour: draw mask outlines.
        click: draw every click made up to and including the current one.
        threshold: probability threshold forwarded to ``inference``.
        min_clicks: forwarded to ``inference``.
        max_clicks: forwarded to ``inference``.
        out_dir: if given, JPEGs are written to ``out_dir/<sample_id>/``.
        min_contour_area: predicted-mask contours with area <= this value are
            not drawn. The ground-truth contour is always drawn in full.

    Returns:
        tuple: ``(outputs, ious)``.
            - ``outputs``: list of BGR images, the ground-truth view followed
              by one image per click.
            - ``ious``: the per-click IoU array returned by ``inference``.
    """
    # The probability overlay replaces background dimming.
    mask = False if score else mask
    if out_dir is not None:
        out_dir = osp.join(out_dir, str(sample.sample_id))
        os.makedirs(out_dir, exist_ok=True)

    # Fetch the mask once. The raw (un-cast) mask goes to inference so that
    # any non-binary labels keep their original meaning for the Clicker.
    raw_gt_mask = sample.gt_mask(sample.objects_ids[0])
    clicks, ious, probs, masks = inference(sample.image, raw_gt_mask, predictor,
                                           threshold=threshold,
                                           min_clicks=min_clicks, max_clicks=max_clicks)

    # cvtColor returns a new array, so ``base`` never aliases sample.image.
    base = cv2.cvtColor(sample.image, cv2.COLOR_RGB2BGR)
    outputs = []

    # Ground-truth view. ``bool`` replaces np.bool8, which was removed in NumPy 2.0.
    gt_mask = raw_gt_mask.astype(bool)
    show = base.copy()
    if mask:
        _dim_background(show, gt_mask)
    if contour:
        show = cv2.drawContours(show, _find_contours(gt_mask), -1, _CONTOUR_COLOR, 2)
    if out_dir is not None:
        cv2.imwrite(osp.join(out_dir, 'gt_mask.jpg'), show)
    outputs.append(show)

    # One view per click. inference appends one entry per click to each list.
    for i, (pred_probs, pred_mask, iou) in enumerate(zip(probs, masks, ious)):
        show = base.copy()

        if score:
            # Clip so out-of-range probabilities cannot wrap around in uint8.
            heat = (np.clip(pred_probs, 0.0, 1.0) * 255).astype(np.uint8)
            score_map = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
            show = cv2.addWeighted(show, 0.5, score_map, 0.5, 0)

        if mask:
            _dim_background(show, pred_mask)

        if contour:
            # Skip tiny fragments so speckle noise does not clutter the view.
            kept = [c for c in _find_contours(pred_mask)
                    if cv2.contourArea(c) > min_contour_area]
            if kept:
                show = cv2.drawContours(show, kept, -1, _CONTOUR_COLOR, 2)

        if click:
            show = _draw_clicks(show, clicks[:i + 1])

        outputs.append(show)

        if out_dir is not None:
            cv2.imwrite(osp.join(out_dir, f'{i+1}_{iou:.2f}.jpg'), show)

    return outputs, ious