|
from typing import Optional |
|
|
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
def export_mask(
    masks,
    autogenerated: Optional[bool] = False,
    random_color: Optional[bool] = True,
    smoothen_contours: Optional[bool] = True,
) -> Image.Image:
    """Render segmentation masks as a single colored PIL image.

    Args:
        masks: When ``autogenerated`` is falsy, an array with
            ``masks.shape == (num_masks, 1, h, w)``; pixels of slice ``i`` that
            are non-zero after a uint8 cast belong to mask ``i``. When
            ``autogenerated`` is truthy, a sequence of dicts each holding an
            ``"area"`` value and a ``"segmentation"`` array used as an index
            into an (h, w, 4) canvas (presumably a boolean (h, w) mask —
            TODO confirm against the producer).
        autogenerated: Selects the dict-based input format described above.
        random_color: Prompted masks only. If truthy, each mask gets a random
            color; otherwise every mask is drawn in (30, 144, 255). Autogenerated
            annotations always get random colors.
        smoothen_contours: If truthy, mask outlines are drawn via ``smoothen``.

    Returns:
        An RGB image for prompted masks, or an RGBA image for autogenerated
        annotations (transparent where no annotation covers a pixel).

    Raises:
        ValueError: If ``autogenerated`` is truthy and ``masks`` is empty.
    """
    if autogenerated:
        return _export_autogenerated_masks(masks, smoothen_contours)
    return _export_prompted_masks(masks, random_color, smoothen_contours)


def _export_prompted_masks(masks, random_color, smoothen_contours) -> Image.Image:
    """Colorize an (N, 1, H, W) mask stack into one RGB image."""
    num_masks, _, h, w = masks.shape
    masks = masks.squeeze(axis=1)

    # Label image: 0 = background, i + 1 = mask i. Later masks overwrite
    # earlier ones where they overlap. The dtype grows with num_masks so that
    # labels above 255 cannot overflow, as a fixed uint8 buffer would.
    combined_mask = np.zeros((h, w), dtype=np.min_scalar_type(num_masks))
    for label, mask in enumerate(masks, start=1):
        combined_mask[mask.astype(np.uint8) > 0] = label

    # Palette row 0 is the black background; row ``label`` is that mask's
    # color. Building it directly in 0-255 (and rounding random colors)
    # avoids the float->uint8 truncation that could turn 30 into 29.
    palette = np.zeros((num_masks + 1, 3), dtype=np.uint8)
    if random_color:
        palette[1:] = np.rint(np.random.rand(num_masks, 3) * 255)
    else:
        palette[1:] = (30, 144, 255)
    # One lookup colors every pixel instead of one boolean pass per mask.
    pil_image = Image.fromarray(palette[combined_mask])

    if smoothen_contours:
        # Float canvas starting at zero; smoothen() draws each mask's outline.
        contours_image = np.zeros((h, w, 4), dtype=np.float32)
        for label in range(1, num_masks + 1):
            region = (combined_mask == label).astype(np.uint8)
            contours_image = smoothen(region, contours_image)

        contours_rgb = (contours_image[:, :, :3] * 255).astype(np.uint8)
        # NOTE(review): Image.blend mixes the *whole* frame 60% toward the
        # contour layer, which is black wherever no outline was drawn, so the
        # mask colors are dimmed as well. Kept to preserve existing output.
        pil_image = Image.blend(pil_image, Image.fromarray(contours_rgb), alpha=0.6)

    return pil_image


def _export_autogenerated_masks(masks, smoothen_contours) -> Image.Image:
    """Paint annotations largest-first onto a transparent RGBA canvas."""
    sorted_anns = sorted(masks, key=lambda ann: ann["area"], reverse=True)
    if not sorted_anns:
        # The canvas size comes from the first annotation, so none means no size.
        raise ValueError("export_mask(autogenerated=True) requires at least one annotation")

    img_shape = sorted_anns[0]["segmentation"].shape
    # White canvas with alpha 0, i.e. fully transparent until painted.
    img = np.ones((img_shape[0], img_shape[1], 4))
    img[:, :, 3] = 0

    for ann in sorted_anns:
        m = ann["segmentation"]
        # Random RGB at 0.5 alpha; smaller annotations are painted later,
        # so they stay visible on top of larger ones.
        img[m] = np.concatenate([np.random.random(3), [0.5]])
        if smoothen_contours:
            img = smoothen(m, img)

    return Image.fromarray((img * 255).astype(np.uint8))
|
|
|
|
|
def smoothen(
    mask: np.ndarray,
    image: np.ndarray,
    *,
    epsilon: float = 0.01,
    color: tuple = (0, 0, 1, 0.4),
    thickness: int = 1,
) -> np.ndarray:
    """Draw the polygon-approximated outer contours of ``mask`` onto ``image``.

    Args:
        mask: Region to outline; it is cast to uint8 and its non-zero pixels
            are treated as foreground by ``cv2.findContours``.
        image: Canvas handed to ``cv2.drawContours``, which draws into it; the
            drawn canvas is returned.
        epsilon: Maximum distance, in pixels, between a contour and its
            approximation in ``cv2.approxPolyDP``. The default of 0.01 removes
            little more than collinear points.
        color: Per-channel drawing value. The default sets channel 2 to 1.0 and
            channel 3 to 0.4, i.e. blue at 0.4 alpha on a float RGBA canvas
            with values in [0, 1].
        thickness: Outline thickness in pixels.

    Returns:
        ``image`` with the contours drawn, or ``image`` unchanged if ``mask``
        has no foreground contour.
    """
    # Only outer boundaries (RETR_EXTERNAL), keeping every boundary point
    # (CHAIN_APPROX_NONE) so approxPolyDP alone decides the simplification.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if not contours:
        # Empty mask: nothing to outline, leave the canvas untouched.
        return image

    contours = [
        cv2.approxPolyDP(contour, epsilon=epsilon, closed=True) for contour in contours
    ]
    # contourIdx=-1 draws every contour in the list.
    return cv2.drawContours(image, contours, -1, color, thickness=thickness)
|
|