import io

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# Color palette cycled through when drawing the predicted bounding boxes.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]


def get_output_figure(pil_img, scores, labels, boxes):
    """Draw the predicted boxes and class scores on top of the input image."""
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, label, (xmin, ymin, xmax, ymax), c in zip(
            scores.tolist(), labels.tolist(), boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    return plt.gcf()


def get_output_attn_figure(image, encoding, results, outputs):
    """Visualize the decoder cross-attention map for each detected object."""
    # Keep only predictions of queries with confidence > 0.9 (excluding the no-object class).
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = results['boxes']

    # Use a list to capture the backbone feature maps via a forward hook.
    conv_features = []
    hooks = [
        model.model.backbone.conv_encoder.register_forward_hook(
            lambda self, input, output: conv_features.append(output)
        )
    ]

    # Propagate through the model again, this time requesting attention weights.
    outputs = model(**encoding, output_attentions=True)

    for hook in hooks:
        hook.remove()

    # Only the first (and only) captured output is needed.
    conv_features = conv_features[0]

    # Cross-attention weights of the last decoder layer, of shape
    # (batch_size, num_heads, num_queries, width * height).
    dec_attn_weights = outputs.cross_attentions[-1]
    # Average over the attention heads and detach from the graph.
    dec_attn_weights = torch.mean(dec_attn_weights, dim=1).detach()

    # Spatial shape of the last backbone feature map.
    h, w = conv_features[-1][0].shape[-2:]

    # squeeze=False keeps axs two-dimensional even when only one object is detected.
    fig, axs = plt.subplots(ncols=max(len(bboxes_scaled), 1), nrows=2,
                            figsize=(22, 7), squeeze=False)
    for idx, ax_i, box in zip(keep.nonzero(), axs.T, bboxes_scaled):
        xmin, ymin, xmax, ymax = box.detach().numpy()
        ax = ax_i[0]
        ax.imshow(dec_attn_weights[0, idx].view(h, w))
        ax.axis('off')
        ax.set_title(f'query id: {idx.item()}')
        ax = ax_i[1]
        ax.imshow(image)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(model.config.id2label[probas[idx].argmax().item()])
    fig.tight_layout()
    return fig


@spaces.GPU
def detect(image):
    encoding = processor(image, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**encoding)

    # Rescale the predicted boxes to the original image size and keep
    # detections with confidence above 0.9.
    width, height = image.size
    postprocessed_outputs = processor.post_process_object_detection(
        outputs, target_sizes=[(height, width)], threshold=0.9)
    results = postprocessed_outputs[0]

    # Render the box predictions to a PIL image.
    output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
    buf = io.BytesIO()
    output_figure.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img = Image.open(buf)

    # Render the cross-attention visualization to a PIL image.
    output_figure_attn = get_output_attn_figure(image, encoding, results, outputs)
    buf = io.BytesIO()
    output_figure_attn.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img_attn = Image.open(buf)

    return output_pil_img, output_pil_img_attn


with gr.Blocks() as demo:
    gr.Markdown("# Object detection with DETR")
    gr.Markdown(
        """
        This application uses DETR (DEtection TRansformer) to detect objects in images.
        Load an image to see the predicted bounding boxes for the detected objects
        along with the decoder's cross-attention weights.
        """
    )
    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil"),
            gr.Image(label="Attention weights", type="pil"),
        ],
    )

demo.launch(show_error=True)
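
# Optional headless smoke test (a sketch, not part of the Gradio app above).
# It assumes the `spaces.GPU` decorator no-ops outside a Hugging Face Space and
# uses the COCO sample image from the DETR docs; the URL and output filenames
# are illustrative. Uncomment and run in place of `demo.launch(...)` above.
#
# import requests
# img = Image.open(requests.get(
#     "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
# pred_img, attn_img = detect(img)
# pred_img.save("prediction.png")
# attn_img.save("attention.png")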