import gradio as gr
import spaces
import torch

from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
import matplotlib.pyplot as plt
import io


processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
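# Note: from_pretrained returns the model already in eval mode, so no explicit
# model.eval() call is needed for inference.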


# fixed color palette (as in the DETR demo notebook), cycled across detections
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def get_output_figure(pil_img, scores, labels, boxes):
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, 
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')

    return fig
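
# A small helper sketch (hypothetical, not part of the original app): detect()
# below inlines this figure-to-PIL conversion twice, and a helper like this
# captures the same pattern in one place.
def figure_to_pil(figure):
    buf = io.BytesIO()
    figure.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(figure)  # release the Matplotlib figure once rendered
    return img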

def get_output_attn_figure(image, encoding, results, outputs):
    # keep only predictions with confidence > 0.9 (excluding the no-object class);
    # this mirrors the threshold passed to post_process_object_detection in detect()
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9

    bboxes_scaled = results['boxes']
    # use a list to capture the backbone's conv feature maps via a forward hook
    conv_features = []

    hooks = [
        model.model.backbone.conv_encoder.register_forward_hook(
            lambda module, inputs, output: conv_features.append(output)
        )
    ]

    # propagate through the model again, now collecting attention weights
    with torch.no_grad():
        outputs = model(**encoding, output_attentions=True)

    for hook in hooks:
        hook.remove()

    # don't need the list anymore
    conv_features = conv_features[0]
    # get cross-attentions weights of last decoder layer - which is of shape (batch_size, num_heads, num_queries, width*height)
    dec_attn_weights = outputs.cross_attentions[-1]
    # average them over the 8 heads and detach from the graph
    dec_attn_weights = torch.mean(dec_attn_weights, dim=1).detach()

    # get the feature map shape
    h, w = conv_features[-1][0].shape[-2:]

    # one column per detection: attention heatmap on top, predicted box below
    # (assumes at least one detection survived the threshold)
    fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7))
    colors = COLORS * 100
    for idx, ax_i, box in zip(keep.nonzero(), axs.T, bboxes_scaled):
        xmin, ymin, xmax, ymax = box.detach().numpy()
        ax = ax_i[0]
        ax.imshow(dec_attn_weights[0, idx].view(h, w))
        ax.axis('off')
        ax.set_title(f'query id: {idx.item()}')
        ax = ax_i[1]
        ax.imshow(image)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(model.config.id2label[probas[idx].argmax().item()])
    fig.tight_layout()
    return fig
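
# Hedged illustration (unused by the app): a query's cross-attention is a flat
# vector over the h*w feature-map locations, and .view(h, w) above turns it
# back into a spatial heatmap. The shapes below are made-up examples.
def _attn_reshape_demo(h=25, w=34, num_queries=100):
    toy = torch.rand(1, num_queries, h * w)  # same layout as dec_attn_weights
    heatmap = toy[0, 0].view(h, w)           # one query as an (h, w) grid
    return heatmap.shape                     # torch.Size([25, 34])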


# @spaces.GPU requests a GPU worker on Hugging Face Spaces (ZeroGPU); the model
# here stays on CPU tensors, so the app also runs on machines without a GPU.
@spaces.GPU
def detect(image):
    encoding = processor(image, return_tensors='pt')
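    # encoding holds the model-ready tensors ('pixel_values' plus a 'pixel_mask')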
    
    with torch.no_grad():
        outputs = model(**encoding)

    width, height = image.size
    postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
    results = postprocessed_outputs[0]
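    # results holds 'scores', 'labels', and 'boxes' (pixel xyxy coordinates),
    # already filtered to detections above the 0.9 threshold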

    output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])

    buf = io.BytesIO()
    output_figure.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img = Image.open(buf)
    plt.close(output_figure)  # free the figure so repeated requests don't leak memory

    output_figure_attn = get_output_attn_figure(image, encoding, results, outputs)

    buf = io.BytesIO()
    output_figure_attn.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img_attn = Image.open(buf)
    plt.close(output_figure_attn)  # likewise for the attention figure

    return output_pil_img, output_pil_img_attn

demo = gr.Interface(
    fn=detect,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=[
        gr.Image(label="Output image predictions", type="pil"),
        gr.Image(label="Output attention weights", type="pil"),
    ],
)
demo.launch()
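
# Hedged local smoke test (kept commented out: demo.launch() above blocks, and
# "sample.jpg" is a hypothetical path, not shipped with the app):
#
#     img = Image.open("sample.jpg").convert("RGB")
#     preds_img, attn_img = detect(img)
#     preds_img.save("predictions.png")
#     attn_img.save("attention.png")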