sergiopaniego committed on
Commit
4580854
1 Parent(s): 6c27fbc

First iteration of the Gradio space

Files changed (2)
  1. app.py +73 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+ import spaces
+ import torch
+
+ from PIL import Image
+ import requests
+ from transformers import DetrImageProcessor
+ from transformers import DetrForObjectDetection
+ import matplotlib.pyplot as plt
+ import io
+
+
+ processor = DetrImageProcessor.from_pretrained("sergiopaniego/fashionpedia-finetuned_albumentations_coco")
+ model = DetrForObjectDetection.from_pretrained("sergiopaniego/fashionpedia-finetuned_albumentations_coco")
+
+
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+ def get_output_figure(pil_img, scores, labels, boxes):
+     plt.figure(figsize=(16, 10))
+     plt.imshow(pil_img)
+     ax = plt.gca()
+     colors = COLORS * 100
+     for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
+         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
+         text = f'{model.config.id2label[label]}: {score:0.2f}'
+         ax.text(xmin, ymin, text, fontsize=15,
+                 bbox=dict(facecolor='yellow', alpha=0.5))
+     plt.axis('off')
+
+     return plt.gcf()
+
+ @spaces.GPU
+ def detect(image):
+     encoding = processor(image, return_tensors='pt')
+     print(encoding.keys())
+
+     with torch.no_grad():
+         outputs = model(**encoding)
+
+     width, height = image.size
+     postprocessed_outputs = processor.post_process_object_detection(outputs, target_sizes=[(height, width)], threshold=0.9)
+     results = postprocessed_outputs[0]
+
+
+     output_figure = get_output_figure(image, results['scores'], results['labels'], results['boxes'])
+
+     buf = io.BytesIO()
+     output_figure.savefig(buf, bbox_inches='tight')
+     buf.seek(0)
+     output_pil_img = Image.open(buf)
+
+     return output_pil_img
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Object detection with DETR")
+     gr.Markdown(
+         """
+         This application uses DETR (DEtection TRansformer) to detect objects in images.
+         You can load an image and see the predictions for the detected objects.
+         """
+     )
+
+     gr.Interface(
+         fn=detect,
+         inputs=gr.Image(label="Input image", type="pil"),
+         outputs=[
+             gr.Image(label="Output prediction", type="pil")
+         ]
+     )#.launch()
+
+ demo.launch(show_error=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers
+ timm
+ torch
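
Not part of the commit: a minimal sketch for sanity-checking the same checkpoint locally, mirroring the processor/model calls in app.py above. The image path is a placeholder, not a file from this repository.

```python
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection

# Same fine-tuned checkpoint used by the Space
processor = DetrImageProcessor.from_pretrained("sergiopaniego/fashionpedia-finetuned_albumentations_coco")
model = DetrForObjectDetection.from_pretrained("sergiopaniego/fashionpedia-finetuned_albumentations_coco")

image = Image.open("sample.jpg")  # placeholder path, replace with a real image
inputs = processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Rescale boxes to the original image size and keep detections above 0.9,
# as in the detect() function above
width, height = image.size
results = processor.post_process_object_detection(
    outputs, target_sizes=[(height, width)], threshold=0.9
)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 2), box.tolist())
```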