import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.ops import box_convert

from groundingdino.util.inference import load_model, load_image, predict
from segment_anything import SamPredictor, sam_model_registry

# Model checkpoints and configuration.
model_type = "vit_b"
sam_checkpoint = "weights/sam_vit_b.pth"
config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
dino_checkpoint = "weights/groundingdino_swint_ogc.pth"

# Load SAM and wrap it in a predictor for point-prompted segmentation.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
predictor = SamPredictor(sam)

# Load GroundingDINO for text-prompted detection.
device = "cpu"
model = load_model(config, dino_checkpoint, device)
box_threshold = 0.35
text_threshold = 0.25


def show_mask(mask, ax, random_color=False):
    """Overlay a semi-transparent segmentation mask on a matplotlib axis."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label=None):
    """Draw an xyxy bounding box (and optional label) on a matplotlib axis."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))
    if label is not None:
        ax.text(x0, y0, label, fontsize=12, color='white', backgroundcolor='red', ha='left', va='top')


def extract_object_with_transparent_background(image, masks):
    """Return the masked object as an RGBA image with a transparent background."""
    mask_expanded = np.expand_dims(masks[0], axis=-1)
    mask_expanded = np.repeat(mask_expanded, 3, axis=-1)
    segment = image * mask_expanded
    rgba_segment = np.zeros((segment.shape[0], segment.shape[1], 4), dtype=np.uint8)
    rgba_segment[:, :, :3] = segment
    rgba_segment[:, :, 3] = masks[0] * 255
    return rgba_segment


def extract_remaining_image(image, masks):
    """Return the image with the masked object removed (inverse of the mask)."""
    inverse_mask = np.logical_not(masks[0])
    inverse_mask_expanded = np.expand_dims(inverse_mask, axis=-1)
    inverse_mask_expanded = np.repeat(inverse_mask_expanded, 3, axis=-1)
    remaining_image = image * inverse_mask_expanded
    return remaining_image


def overlay_masks_boxes_on_image(image, masks, boxes, labels, show_masks, show_boxes):
    """Render the image with optional masks and boxes and return it as an RGBA array."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    if show_masks:
        for mask in masks:
            show_mask(mask, ax, random_color=False)
    if show_boxes:
        for input_box, label in zip(boxes, labels):
            show_box(input_box, ax, label)
    ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    plt.margins(0, 0)
    fig.canvas.draw()
    output_image = np.array(fig.canvas.buffer_rgba())
    plt.close(fig)
    return output_image


def detect_objects(image_path, prompt, show_masks=True, show_boxes=True, crop_options="No crop"):
    """Detect objects matching the text prompt with GroundingDINO, then segment and crop them with SAM."""
    image_source, image = load_image(image_path)
    predictor.set_image(image_source)

    # Text-prompted detection with GroundingDINO.
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device,
    )

    # Convert normalized cxcywh boxes to absolute xyxy pixel coordinates.
    h, w, _ = image_source.shape
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") * torch.Tensor([w, h, w, h])
    boxes = np.round(boxes.numpy()).astype(int)
    labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]

    masks_list = []
    res_json = {"prompt": prompt, "objects": []}
    output_image_paths = []
    for i, (input_box, label, phrase, logit) in enumerate(zip(boxes, labels, phrases, logits.tolist())):
        x1, y1, x2, y2 = input_box
        width = x2 - x1
        height = y2 - y1
        avg_size = (width + height) / 2
        d = avg_size * 0.1

        # Build four foreground point prompts around the box center for SAM.
        center_point = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
        points = [
            [center_point[0], center_point[1] - d],
            [center_point[0], center_point[1] + d],
            [center_point[0] - d, center_point[1]],
            [center_point[0] + d, center_point[1]],
        ]
        input_point = np.array(points)
        input_label = np.array([1] * len(input_point))

        # First SAM pass with multimask output, then refine using the best low-resolution mask.
        masks, scores, sam_logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        mask_input = sam_logits[np.argmax(scores), :, :]
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )
        masks_list.append(masks)

        # Extract the segmented object, crop it to its bounding box, and save the result.
        composite_image = np.zeros_like(image_source)
        rgba_segment = extract_object_with_transparent_background(image_source, masks)
        composite_image = np.maximum(composite_image, rgba_segment[:, :, :3])
        cropped_image = composite_image[y1:y2, x1:x2, :]
        output_image = overlay_masks_boxes_on_image(cropped_image, [], [], [], False, False)
        output_image_path = f'output_image_{i}.jpeg'
        plt.imsave(output_image_path, output_image)
        output_image_paths.append(output_image_path)

        # Save object information in the JSON response.
        res_json["objects"].append({
            "label": phrase,
            "dino_score": logit,
            "sam_score": np.max(scores).item(),
            "box": input_box.tolist(),
            "center": center_point.tolist(),
            "avg_size": avg_size,
        })

    return [res_json, output_image_paths]


app = gr.Interface(
    detect_objects,
    inputs=[
        gr.Image(type='filepath', label="Upload Image"),
        gr.Textbox(
            label="Object to Detect",
            placeholder="Enter any text, comma separated if multiple objects needed",
            show_label=True,
            lines=1,
        ),
    ],
    outputs=[
        gr.JSON(label="Output JSON"),
        gr.Gallery(label="Result"),
    ],
    examples=[
        ["images/fish.jpg", "fish"],
        ["images/birds.png", "bird"],
        ["images/bear.png", "bear"],
        ["images/penguin.png", "penguin"],
        ["images/penn.jpg", "sign board"],
    ],
    title="Object Detection, Segmentation and Cropping",
    description="This app uses GroundingDINO to detect objects in an image and then uses SAM to segment and crop the objects.",
)

app.launch()