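"""Gradio app: text-prompted object detection with GroundingDINO, SAM segmentation,
and cropping of each detected object."""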

import gradio as gr
import torch
from matplotlib import pyplot as plt
import numpy as np
from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry
from torchvision.ops import box_convert

# Checkpoint and config paths for SAM and GroundingDINO.
model_type = "vit_b"
sam_checkpoint = "weights/sam_vit_b.pth"
config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
dino_checkpoint = "weights/groundingdino_swint_ogc.pth"

# Load SAM and wrap it in a predictor for point-prompted segmentation.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
predictor = SamPredictor(sam)

# Load GroundingDINO, the text-prompted detector.
device = "cpu"
model = load_model(config, dino_checkpoint, device)

# GroundingDINO detection thresholds.
box_threshold = 0.35
text_threshold = 0.25


def show_mask(mask, ax, random_color=False):
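    """Overlay a single segmentation mask on a matplotlib axis."""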
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label=None):
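    """Draw an xyxy bounding box, with an optional label, on a matplotlib axis."""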
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))
    if label is not None:
        ax.text(x0, y0, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')


def extract_object_with_transparent_background(image, masks):
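    """Return an RGBA image containing only the masked object; the background is transparent."""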
    mask_expanded = np.expand_dims(masks[0], axis=-1)
    mask_expanded = np.repeat(mask_expanded, 3, axis=-1)
    segment = image * mask_expanded
    rgba_segment = np.zeros((segment.shape[0], segment.shape[1], 4), dtype=np.uint8)
    rgba_segment[:, :, :3] = segment
    rgba_segment[:, :, 3] = masks[0] * 255
    return rgba_segment


def extract_remaining_image(image, masks):
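    """Return the image with the masked object blacked out (the complement of the mask)."""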
    inverse_mask = np.logical_not(masks[0])
    inverse_mask_expanded = np.expand_dims(inverse_mask, axis=-1)
    inverse_mask_expanded = np.repeat(inverse_mask_expanded, 3, axis=-1)
    remaining_image = image * inverse_mask_expanded
    return remaining_image


def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
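    """Render the image with optional masks and boxes and return the figure canvas as an RGBA array."""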
    fig, ax = plt.subplots()
    ax.imshow(image)

    if show_masks:
        for mask in masks:
            show_mask(mask, ax, random_color=False)

    if show_boxes:
        for input_box, label in zip(boxes, labels):
            show_box(input_box, ax, label)

    ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    plt.margins(0, 0)

    fig.canvas.draw()
    output_image = np.array(fig.canvas.buffer_rgba())

    plt.close(fig)
    return output_image


def detect_objects(image, prompt, show_masks=True, show_boxes=True, crop_options="No crop"):
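    """Detect objects matching `prompt` with GroundingDINO, segment each detection with SAM,
    and return a result dict plus file paths of the cropped objects for the Gradio gallery."""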
    image_source, image = load_image(image)
    predictor.set_image(image_source)

    # Text-prompted detection with GroundingDINO.
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device
    )

    # Convert normalized cxcywh boxes to absolute xyxy pixel coordinates.
    h, w, _ = image_source.shape
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") * torch.Tensor([w, h, w, h])
    boxes = np.round(boxes.numpy()).astype(int)

    labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]

    masks_list = []
    res_json = {"prompt": prompt, "objects": []}
    output_image_paths = []

    for i, (input_box, label, phrase, logit) in enumerate(zip(boxes, labels, phrases, logits.tolist())):
        x1, y1, x2, y2 = input_box
        width = x2 - x1
        height = y2 - y1
        avg_size = (width + height) / 2
        d = avg_size * 0.1

        # Prompt SAM with four foreground points offset from the box center.
        center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
        points = [
            [center_point[0], center_point[1] - d],
            [center_point[0], center_point[1] + d],
            [center_point[0] - d, center_point[1]],
            [center_point[0] + d, center_point[1]],
        ]
        input_point = np.array(points)
        input_label = np.array([1] * len(input_point))

        # First pass: ask SAM for multiple candidate masks and keep the highest-scoring one.
        masks, scores, sam_logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        mask_input = sam_logits[np.argmax(scores), :, :]

        # Second pass: refine by feeding the best low-resolution mask back in.
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False
        )
        masks_list.append(masks)

        # Cut the object out onto a black background and crop it to its box.
        composite_image = np.zeros_like(image_source)
        rgba_segment = extract_object_with_transparent_background(image_source, masks)
        composite_image = np.maximum(composite_image, rgba_segment[:, :, :3])
        cropped_image = composite_image[y1:y2, x1:x2, :]
        output_image = overlay_masks_boxes_on_image(cropped_image, [], [], [], False, False)

        output_image_path = f'output_image_{i}.jpeg'
        plt.imsave(output_image_path, output_image)
        output_image_paths.append(output_image_path)

        res_json["objects"].append({
            "label": phrase,
            "dino_score": logit,
            "sam_score": np.max(scores).item(),
            "box": input_box.tolist(),
            "center": center_point.tolist(),
            "avg_size": avg_size
        })

    return [res_json, output_image_paths]


app = gr.Interface(
    detect_objects,
    inputs=[
        gr.Image(type='filepath', label="Upload Image"),
        gr.Textbox(
            label="Object to Detect",
            placeholder="Enter any text, comma separated if multiple objects needed",
            show_label=True,
            lines=1,
        ),
    ],
    outputs=[
        gr.JSON(label="Output JSON"),
        gr.Gallery(label="Result"),
    ],
    examples=[
        ["images/fish.jpg", "fish"],
        ["images/birds.png", "bird"],
        ["images/bear.png", "bear"],
        ["images/penguin.png", "penguin"],
        ["images/penn.jpg", "sign board"],
    ],
    title="Object Detection, Segmentation and Cropping",
    description="This app uses GroundingDINO to detect objects in an image and then uses SAM to segment and crop them.",
)
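
# Start the Gradio UI; launch() serves the app locally (Gradio's default port is 7860).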
app.launch()