|
import PIL |
|
import numpy as np |
|
import copy |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
from torchvision.transforms.functional import to_pil_image |
|
import torch |
|
from PIL import Image |
|
import matplotlib |
|
# Select the non-interactive Agg backend so figures can be rendered to
# buffers/files without a display.
# NOTE(review): this runs after ``matplotlib.pyplot`` is imported above.
# Recent matplotlib versions switch the backend anyway. Older versions warned
# that the call had no effect. Consider moving it before the pyplot import.
matplotlib.use('Agg')
|
|
|
def _mask_centroid(mask):
    """Return the integer ``(x, y)`` centroid of ``mask``'s largest external contour.

    Returns ``None`` when no contour is found or the largest one is
    degenerate (zero ``m00`` moment).
    """
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 4.x; indexing [-2] selects contours in both.
    contours = cv2.findContours((mask * 255).astype(np.uint8),
                                cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None
    moments = cv2.moments(max(contours, key=cv2.contourArea))
    if moments["m00"] == 0:
        return None
    return (int(moments["m10"] / moments["m00"]),
            int(moments["m01"] / moments["m00"]))


def show_anns(anns, ax=None, *, shape=(512, 512)):
    """Overlay annotation masks on an axis and label each with its index.

    Each mask is brought to ``shape`` (via ``cv2.resize`` when its shape
    differs), filled with a random colour at 35% opacity, and tagged with its
    position in ``anns`` at the centroid of its largest external contour.
    Annotations are drawn in descending ``'area'`` order, so smaller masks are
    painted over larger ones and stay visible.

    Args:
        anns: Sequence of dicts, each with a ``'segmentation'`` mask (2-D
            array) and a sortable ``'area'`` value.
        ax: Matplotlib axis to draw on. Defaults to ``plt.gca()``.
        shape: ``(height, width)`` of the drawing canvas that masks are
            resized to. Defaults to ``(512, 512)``.

    Returns:
        None. Returns immediately when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    if ax is None:
        ax = plt.gca()

    height, width = shape
    # Keep each annotation's original index so labels refer to positions in
    # ``anns`` rather than to the sorted drawing order.
    sorted_anns = sorted(enumerate(anns), key=lambda item: item[1]['area'], reverse=True)

    for original_idx, ann in sorted_anns:
        m = ann['segmentation']
        if m.shape != (height, width):
            # cv2.resize takes the target size as (width, height).
            m = cv2.resize(m.astype(float), (width, height))

        # RGBA overlay: one random colour everywhere, alpha follows the mask.
        overlay = np.empty((height, width, 4))
        overlay[..., :3] = np.random.random(3)
        overlay[..., 3] = m * 0.35
        ax.imshow(overlay)

        centroid = _mask_centroid(m)
        if centroid is not None:
            cx, cy = centroid
            ax.text(cx, cy, str(original_idx),
                    color='white',
                    fontsize=16,
                    ha='center',
                    va='center',
                    fontweight='bold',
                    bbox=dict(facecolor='black',
                              alpha=0.5,
                              edgecolor='none',
                              pad=1))
|
|
|
|
|
def create_image_grid(original_image, images, names, rows, columns, *, figsize=(20, 20)):
    """Plot an original image followed by titled images on a subplot grid.

    Image/name pairs whose name is empty or whitespace-only are skipped.
    ``original_image`` always occupies the first cell, titled ``'Original'``.
    The remaining pairs follow it in their given order.

    Args:
        original_image: Image for the first cell; anything ``Axes.imshow``
            accepts.
        images: Images, paired positionally with ``names``.
        names: Subplot titles. Pairing uses ``zip``, so surplus entries in
            the longer of ``images``/``names`` are ignored.
        rows: Number of grid rows passed to ``fig.add_subplot``.
        columns: Number of grid columns passed to ``fig.add_subplot``.
            ``rows * columns`` must be at least the number of kept images
            plus one, otherwise matplotlib raises when adding a subplot.
        figsize: Figure size in inches. Defaults to ``(20, 20)``.

    Returns:
        The created ``matplotlib.figure.Figure``. It is not closed here, so
        the caller should save/close it when done.
    """
    # Building new lists means the caller's sequences are never mutated.
    cells = [(original_image, 'Original')]
    cells += [(img, name) for img, name in zip(images, names) if name.strip()]

    fig = plt.figure(figsize=figsize)
    for position, (img, name) in enumerate(cells, start=1):
        ax = fig.add_subplot(rows, columns, position)
        # imshow accepts PIL images and arrays alike; no type dispatch needed.
        ax.imshow(img)
        ax.set_title(name, fontsize=12, pad=10)
        ax.axis('off')

    # Lay out this specific figure instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
|
|