Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -38,9 +38,10 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
|
|
38 |
boxes = output_dict["boxes"][keep].tolist()
|
39 |
scores = output_dict["scores"][keep].tolist()
|
40 |
labels = output_dict["labels"][keep].tolist()
|
|
|
41 |
if id2label is not None:
|
42 |
labels = [id2label[x] for x in labels]
|
43 |
-
|
44 |
plt.figure(figsize=(16, 10))
|
45 |
plt.imshow(pil_img)
|
46 |
ax = plt.gca()
|
@@ -49,7 +50,7 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
|
|
49 |
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
|
50 |
ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
|
51 |
plt.axis("off")
|
52 |
-
return fig2img(plt.gcf())
|
53 |
|
54 |
def detect_objects(model_name,image_input,threshold):
|
55 |
print(type(image_input))
|
@@ -71,9 +72,9 @@ def detect_objects(model_name,image_input,threshold):
|
|
71 |
processed_outputs = make_prediction(image, feature_extractor, model)
|
72 |
|
73 |
#Visualize prediction
|
74 |
-
viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
|
75 |
|
76 |
-
return viz_img
|
77 |
|
78 |
def set_example_image(example: list) -> dict:
|
79 |
return gr.Image.update(value=example[0])
|
@@ -116,11 +117,14 @@ with demo:
|
|
116 |
with gr.Row():
|
117 |
example_images = gr.Dataset(components=[img_input],
|
118 |
samples=[["airport.jpg"],['football-match.jpg']])
|
119 |
-
|
120 |
img_but = gr.Button('Detect')
|
|
|
|
|
|
|
|
|
121 |
|
122 |
-
url_but.click(detect_objects,inputs=[options,url_input,slider_input],outputs=img_output_from_url,queue=True)
|
123 |
-
img_but.click(detect_objects,inputs=[options,img_input,slider_input],outputs=img_output_from_upload,queue=True)
|
124 |
example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
|
125 |
example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
|
126 |
|
@@ -128,4 +132,4 @@ with demo:
|
|
128 |
#gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-object-detection-with-detr-and-yolos)")
|
129 |
|
130 |
|
131 |
-
demo.launch(enable_queue=True)
|
|
|
38 |
boxes = output_dict["boxes"][keep].tolist()
|
39 |
scores = output_dict["scores"][keep].tolist()
|
40 |
labels = output_dict["labels"][keep].tolist()
|
41 |
+
print(labels)
|
42 |
if id2label is not None:
|
43 |
labels = [id2label[x] for x in labels]
|
44 |
+
res = dict(zip(labels, scores))
|
45 |
plt.figure(figsize=(16, 10))
|
46 |
plt.imshow(pil_img)
|
47 |
ax = plt.gca()
|
|
|
50 |
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
|
51 |
ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
|
52 |
plt.axis("off")
|
53 |
+
return fig2img(plt.gcf()),res
|
54 |
|
55 |
def detect_objects(model_name,image_input,threshold):
|
56 |
print(type(image_input))
|
|
|
72 |
processed_outputs = make_prediction(image, feature_extractor, model)
|
73 |
|
74 |
#Visualize prediction
|
75 |
+
viz_img,labels = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
|
76 |
|
77 |
+
return viz_img,labels
|
78 |
|
79 |
def set_example_image(example: list) -> dict:
|
80 |
return gr.Image.update(value=example[0])
|
|
|
117 |
with gr.Row():
|
118 |
example_images = gr.Dataset(components=[img_input],
|
119 |
samples=[["airport.jpg"],['football-match.jpg']])
|
|
|
120 |
img_but = gr.Button('Detect')
|
121 |
+
|
122 |
+
with gr.TabItem('Labels'):
|
123 |
+
with gr.Row():
|
124 |
+
label = gr.Label(label = 'Labels')
|
125 |
|
126 |
+
url_but.click(detect_objects,inputs=[options,url_input,slider_input],outputs=[img_output_from_url,label],queue=True)
|
127 |
+
img_but.click(detect_objects,inputs=[options,img_input,slider_input],outputs=[img_output_from_upload,label],queue=True)
|
128 |
example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
|
129 |
example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
|
130 |
|
|
|
132 |
#gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-object-detection-with-detr-and-yolos)")
|
133 |
|
134 |
|
135 |
+
demo.launch(enable_queue=True,show_api=False)
|