import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from Utilities.transforms import test_transforms
# from Utilities.config import S
from Utilities.utils import cells_to_bboxes, non_max_suppression, plot_image


def plot_bboxes(
        input_img,
        model,
        thresh=0.6,
        iou_thresh=0.5,
        anchors=None,
):
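    """Run the detector on a single image, decode the predictions from all
    three output scales, apply non-max suppression, and plot the surviving
    boxes. Returns the matplotlib figure and the transformed input tensor."""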
    input_img = test_transforms(image=input_img)["image"]
    input_img = input_img.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        out = model(input_img)

        # Decode predictions at each of the three scales and pool them for NMS.
        bboxes = []
        for i in range(3):
            batch_size, A, S, _, _ = out[i].shape
            anchor = anchors[i]
            boxes_scale_i = cells_to_bboxes(out[i], anchor, S=S, is_preds=True)
            bboxes += boxes_scale_i[0]

    nms_boxes = non_max_suppression(
        bboxes,
        iou_threshold=iou_thresh,
        threshold=thresh,
        box_format="midpoint",
    )
    fig = plot_image(input_img[0].permute(1, 2, 0).detach().cpu(), nms_boxes)
    return fig, input_img

def return_top_objectness_class_preds(model, input_img, gradcam_output_stream):
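    """Return the class scores of the single prediction cell with the highest
    objectness score in the chosen output stream, for use as the Grad-CAM
    target output."""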
    out = model(input_img)[gradcam_output_stream]

    # 1. get objectness score
    objectness_scores = out[..., 0]

    # 2. get index of highest objectness score
    max_obj_arg = torch.argmax(objectness_scores)

    max_obj_arg_onehot = torch.zeros(objectness_scores.flatten().shape[0])
    max_obj_arg_onehot[max_obj_arg] = 1

    max_obj_arg_onehot = max_obj_arg_onehot.reshape_as(objectness_scores).int()

    selected_elements = out[max_obj_arg_onehot == 1]
    selected_elements = selected_elements[:, 5:]

    return selected_elements

class TopObjectnessClassPreds(pl.LightningModule):
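    """Wrapper whose forward pass returns only the class predictions of the
    most confident cell in one YOLO output stream, so Grad-CAM can attribute
    a single tensor of class scores back through the network."""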
    def __init__(self, model, gradcam_output_stream):
        super().__init__()
        self.model = model
        self.gradcam_output_stream = gradcam_output_stream

    def forward(self, x):
        return return_top_objectness_class_preds(self.model, x, self.gradcam_output_stream)

def generate_gradcam_output(org_img, model, input_img, gradcam_output_stream: int = 0):
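    """Compute a Grad-CAM heatmap for the chosen output stream (target layer
    taken from model.layers at index 15, 22, or 29) and overlay it on org_img."""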
    TopObjectnessClassPredsObj = TopObjectnessClassPreds(model, gradcam_output_stream)
    gradcam_model_layer = [15, 22, 29]
    cam = GradCAM(
        model=TopObjectnessClassPredsObj,
        target_layers=[
            TopObjectnessClassPredsObj.model.layers[
                gradcam_model_layer[gradcam_output_stream]
            ]
        ],
    )
    # GradCAM returns a (batch, H, W) array; keep the map for the single image.
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(
        org_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=0.5,
    )
    return visualization
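

if __name__ == "__main__":
    # Minimal usage sketch (assumptions, not part of this module): the model,
    # `scaled_anchors`, and the image path below are placeholders; build them
    # the same way they are built elsewhere in this project.
    import cv2

    model = ...  # a trained YOLOv3 LightningModule
    scaled_anchors = ...  # per-scale anchors in the format cells_to_bboxes expects

    org_img = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)

    # Detect objects and plot the boxes that survive NMS.
    fig, input_img = plot_bboxes(
        org_img, model, thresh=0.6, iou_thresh=0.5, anchors=scaled_anchors
    )

    # Overlay a Grad-CAM heatmap for output stream 0; org_img is assumed to be
    # resized to the network input resolution so it matches the CAM's shape.
    cam_overlay = generate_gradcam_output(
        org_img, model, input_img, gradcam_output_stream=0
    )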