import io
import os
from typing import Dict, Optional

from project_settings import project_path

# Point the Hugging Face hub cache into the project tree before `transformers`
# is imported, so model weights are downloaded to a known location.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
import torch
from transformers.models.auto.processing_auto import AutoImageProcessor
from transformers.models.auto.modeling_auto import AutoModelForObjectDetection
import validators


# RGB palette (matplotlib-style floats) cycled over the detected boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def get_original_image(url_input):
    """Fetch the image at `url_input`; returns None if the URL is not valid."""
    if validators.url(url_input):
        image = Image.open(requests.get(url_input, stream=True).raw)
        return image
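
# Hypothetical quick check (the URL is illustrative, not from this project):
#   get_original_image("https://example.com/photo.jpg")  # -> PIL.Image, or None for a bad URL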


def figure2image(fig):
    """Render a matplotlib figure to a PIL image, resized to a fixed 750 px width."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    pil_image = Image.open(buf)
    # Scale to the base width while preserving the aspect ratio.
    base_width = 750
    width_percent = base_width / float(pil_image.size[0])
    height_size = int(float(pil_image.size[1]) * width_percent)
    pil_image = pil_image.resize((base_width, height_size), Image.Resampling.LANCZOS)
    # Close the source figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return pil_image
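
# For example, a 1500x1000 px figure comes out at 750x500 px:
# width_percent = 750 / 1500 = 0.5, so the height becomes int(1000 * 0.5) = 500.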


def non_max_suppression(boxes, scores, threshold):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        boxes: array of [xmin, ymin, xmax, ymax]
        scores: array of scores associated with each box.
        threshold: IoU threshold
    Return:
        keep: indices of the boxes to keep
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    # The +1 follows the pixel-inclusive box convention.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Visit boxes in descending score order.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)

        # Intersection of the current box with all remaining boxes.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        # Keep only boxes whose IoU with the current box is at or below the threshold.
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        inds = np.where(ovr <= threshold)[0]
        order = order[inds + 1]

    return keep
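
# A small worked example (values are illustrative, not from this project):
# boxes 0 and 1 overlap heavily (IoU ~= 0.83 with the +1 convention), box 2 is
# disjoint, so with threshold=0.5 only indices [0, 2] survive:
#   boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]], dtype=float)
#   scores = np.array([0.9, 0.8, 0.7])
#   non_max_suppression(boxes, scores, threshold=0.5)  # -> [0, 2]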


def draw_boxes(image, boxes, scores, labels, threshold: float,
               idx_to_label: Optional[Dict[int, str]] = None,
               labels_to_show: Optional[str] = None):
    # Parse the optional comma-separated label filter, e.g. "person, truck".
    if isinstance(labels_to_show, str):
        if len(labels_to_show.strip()) == 0:
            labels_to_show = None
        else:
            labels_to_show = [label.strip().lower() for label in labels_to_show.split(",")]
            labels_to_show = None if len(labels_to_show) == 0 else labels_to_show

    plt.figure(figsize=(50, 50))
    plt.imshow(image)

    if idx_to_label is not None:
        labels = [idx_to_label[x] for x in labels]

    axis = plt.gca()
    # Repeat the palette so there is a color for every box; zip stops at the shortest input.
    colors = COLORS * len(boxes)
    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        if labels_to_show is not None and label.lower() not in labels_to_show:
            continue
        if score < threshold:
            continue
        axis.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=10))
        axis.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=60, bbox=dict(facecolor="yellow", alpha=0.8))
    plt.axis("off")

    return figure2image(plt.gcf())
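
# Hypothetical call (the label names are assumed, not taken from the model config):
#   draw_boxes(image, boxes, scores, labels, threshold=0.7,
#              idx_to_label=model.config.id2label, labels_to_show="coverall, mask")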


def detr_object_detection(url_input: str,
                          image_input: Image.Image,
                          pretrained_model_name_or_path: str = "qgyd2021/detr_cppe5_object_detection",
                          threshold: float = 0.5,
                          iou_threshold: float = 0.5,
                          labels_to_show: Optional[str] = None,
                          ):
    # The model and processor are reloaded on every call; they are cached on
    # disk under HUGGINGFACE_HUB_CACHE, so only the first call downloads them.
    model = AutoModelForObjectDetection.from_pretrained(pretrained_model_name_or_path)
    image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)

    if validators.url(url_input):
        image = get_original_image(url_input)
    elif image_input:
        image = image_input
    else:
        raise ValueError("expected at least one of `url_input` or `image_input`")

    # post_process_object_detection expects (height, width); PIL reports (width, height).
    image_size = torch.tensor([tuple(reversed(image.size))])

    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    processed_outputs = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=image_size)
    processed_outputs = processed_outputs[0]

    boxes = processed_outputs["boxes"].detach().numpy()
    scores = processed_outputs["scores"].detach().numpy()
    labels = processed_outputs["labels"].detach().numpy()

    # Drop overlapping detections of the same object.
    keep = non_max_suppression(boxes, scores, threshold=iou_threshold)
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    viz_image: Image.Image = draw_boxes(
        image, boxes, scores, labels,
        threshold=threshold,
        idx_to_label=model.config.id2label,
        labels_to_show=labels_to_show
    )
    return viz_image
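
# Sketch of a direct call outside the UI (the URL matches the examples below;
# the model name is one of the dropdown choices in `main`):
#   img = detr_object_detection(
#       url_input="https://huggingface.co/datasets/intelli-zen/cppe-5/resolve/main/data/images/1001.png",
#       image_input=None,
#       pretrained_model_name_or_path="intelli-zen/detr_cppe5_object_detection",
#       threshold=0.5,
#       iou_threshold=0.5,
#   )
#   img.save("detections.png")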


def main():
    title = "## DETR CPPE-5 Object Detection"

    description = """
    Reference: https://huggingface.co/docs/transformers/tasks/object_detection
    """

    example_urls = [
        [
            "https://huggingface.co/datasets/intelli-zen/cppe-5/resolve/main/data/images/{}.png".format(idx),
            "intelli-zen/detr_cppe5_object_detection",
            0.5, 0.6, None
        ] for idx in range(1001, 1030)
    ]

    example_images = [
        [
            "data/2lnWoly.jpg",
            "intelli-zen/detr_cppe5_object_detection",
            0.5, 0.6, None
        ]
    ]

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        gr.Markdown(value=description)

        model_name = gr.components.Dropdown(
            choices=[
                "intelli-zen/detr_cppe5_object_detection",
            ],
            value="intelli-zen/detr_cppe5_object_detection",
            label="model_name",
        )
        threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.01, value=0.5,
            label="Threshold"
        )
        iou_threshold_slider = gr.components.Slider(
            minimum=0, maximum=1.0,
            step=0.1, value=0.5,
            label="IOU Threshold"
        )
        classes_to_detect = gr.Textbox(placeholder="e.g. person, truck (comma-separated)",
                                       label="labels to show")

        with gr.Tabs():
            with gr.TabItem("Image URL"):
                with gr.Row():
                    with gr.Column():
                        url_input = gr.Textbox(lines=1, label="Enter a valid image URL here.")
                        original_image = gr.Image()
                        url_input.change(get_original_image, url_input, original_image)
                    with gr.Column():
                        img_output_from_url = gr.Image()

                url_but = gr.Button("Detect")

                with gr.Row():
                    # Each example row carries five values, so the inputs list
                    # must include `classes_to_detect` as well.
                    gr.Examples(examples=example_urls,
                                inputs=[url_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect],
                                examples_per_page=5,
                                )

            with gr.TabItem("Image Upload"):
                with gr.Row():
                    img_input = gr.Image(type="pil")
                    img_output_from_upload = gr.Image()

                img_but = gr.Button("Detect")

                with gr.Row():
                    gr.Examples(examples=example_images,
                                inputs=[img_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect],
                                examples_per_page=5,
                                )

        inputs = [url_input, img_input, model_name, threshold_slider, iou_threshold_slider, classes_to_detect]
        url_but.click(detr_object_detection, inputs=inputs, outputs=[img_output_from_url], queue=True)
        img_but.click(detr_object_detection, inputs=inputs, outputs=[img_output_from_upload], queue=True)

    blocks.queue().launch()
    return


if __name__ == "__main__":
    main()
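
# To try the app locally (assumes `project_settings` is importable from the
# working directory), run this script with Python and open the local URL that
# Gradio prints to the console.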