Grounding-DINO-1.5 / gdino /visualize.py
Mountchicken's picture
Upload 28 files
bf9dee2 verified
raw
history blame contribute delete
No virus
4.05 kB
from typing import Dict
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
import random
def draw_mask(mask, draw, random_color=True):
    """Paint every nonzero pixel of a mask onto an image in a single color.

    Args:
        mask (np.array): Mask as a 2-D NumPy array; every nonzero entry is
            painted.
        draw (ImageDraw.Draw): ImageDraw object to draw on the image.
        random_color (bool): If True, use a random RGB color; otherwise use
            a fixed blue. The alpha component is 153 in both cases.
    """
    if random_color:
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
            153,
        )
    else:
        color = (30, 144, 255, 153)

    # np.nonzero yields (row, col) indices; the drawing API wants
    # (x, y) = (col, row), hence the column reversal below.
    nonzero_coords = np.transpose(np.nonzero(mask))
    if nonzero_coords.size == 0:
        # Nothing to paint; skip the draw call entirely.
        return

    # Draw all pixels with one call instead of one Python-level call per
    # pixel. The points are passed as a flat [x0, y0, x1, y1, ...] list of
    # plain Python ints rather than NumPy views.
    flat_xy = nonzero_coords[:, ::-1].ravel().tolist()
    draw.point(flat_xy, fill=color)
def visualize(image_pil: Image.Image,
              result: Dict,
              draw_width: float = 6.0,
              return_mask=True,
              draw_score=True) -> Image.Image:
    """Plot bounding boxes, labels, and (optionally) masks on an image.

    Boxes and labels are drawn directly onto the image that is passed in.
    When masks are overlaid, a new RGB image is created and returned instead.

    Args:
        image_pil (PIL.Image): The input image as a PIL Image object. A NumPy
            array is also accepted and converted with ``Image.fromarray``.
        result (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing
            the bounding boxes and labels. The keys are:
            - boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format.
            - scores (List[float]): A list of scores for each bounding box. shape (N)
            - categorys (List[str]): A list of categorys for each object
            - masks (List[PIL.Image]): Optional. A list of masks in the format of PIL.Image.
              For multi-channel masks the last channel is used; 2-D masks are used as-is.
        draw_width (float): Outline width of the boxes in pixels; truncated
            to an int before drawing. Defaults to 6.0.
        return_mask (bool): Overlay the masks (if any) on the returned image.
            Defaults to True.
        draw_score (bool): Draw score on the image. Defaults to True.

    Returns:
        PIL.Image: The input image with plotted bounding boxes, labels, and masks.
    """
    # Get the bounding boxes and labels from the target dictionary
    boxes = result["boxes"]
    scores = result["scores"]
    categorys = result["categorys"]
    masks = result.get("masks", [])

    # One random RGB color per distinct category, so every box of the same
    # category shares a color.
    cate2color = {
        cate: tuple(np.random.randint(0, 255, size=3).tolist())
        for cate in set(categorys)
    }

    # Create a PIL ImageDraw object to draw on the input image
    if isinstance(image_pil, np.ndarray):
        image_pil = Image.fromarray(image_pil)
    draw = ImageDraw.Draw(image_pil)

    # The label font never changes, so load it once instead of per box.
    font = ImageFont.load_default()
    # Fonts without ``getbbox`` (older Pillow) are measured with the legacy
    # ``textsize`` API instead of ``textbbox``.
    use_textbbox = hasattr(font, "getbbox")
    outline_width = int(draw_width)

    # Draw the box and its label for each detection
    for box, score, category in zip(boxes, scores, categorys):
        x0, y0, x1, y1 = (int(v) for v in box)
        color = cate2color[category]

        # Draw the box outline on the input image
        draw.rectangle([x0, y0, x1, y1], outline=color, width=outline_width)

        # Build the label, then draw it on a filled background anchored at
        # the box's top-left corner.
        text = f"{category} {score:.2f}" if draw_score else f"{category}"
        if use_textbbox:
            bbox = draw.textbbox((x0, y0), text, font)
        else:
            w, h = draw.textsize(text, font)
            bbox = (x0, y0, w + x0, y0 + h)
        draw.rectangle(bbox, fill=color)
        # Render with the same font that was used to measure the background.
        draw.text((x0, y0), text, fill="white", font=font)

    # Draw the masks on the input image if masks are provided.
    # ``len(...)`` is used instead of truthiness because ``masks`` may be an
    # array-like whose truth value is ambiguous.
    if len(masks) > 0 and return_mask:
        mask_image = Image.new("RGBA", image_pil.size, color=(0, 0, 0, 0))
        mask_draw = ImageDraw.Draw(mask_image)
        for mask in masks:
            mask_array = np.array(mask)
            # Multi-channel masks: use the last channel (alpha for RGBA).
            # 2-D masks are already a single channel.
            if mask_array.ndim == 3:
                mask_array = mask_array[:, :, -1]
            draw_mask(mask_array, mask_draw)
        image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
    return image_pil