import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from .utils import image_to_np_ndarray
from PIL import Image

try:
    import clip  # CLIP is used for text-prompted mask selection
except (ImportError, AssertionError, AttributeError):
    from ultralytics.yolo.utils.checks import check_requirements

    check_requirements('git+https://github.com/openai/CLIP.git')  # install CLIP from source if it is missing
    import clip
class FastSAMPrompt:
    """Post-processes FastSAM results with box, point, text and 'everything' prompts."""

    def __init__(self, image, results, device='cuda'):
        if isinstance(image, (str, Image.Image)):
            image = image_to_np_ndarray(image)
        self.device = device
        self.results = results
        self.img = image
    def _segment_image(self, image, bbox):
        if isinstance(image, Image.Image):
            image_array = np.array(image)
        else:
            image_array = image
        segmented_image_array = np.zeros_like(image_array)
        x1, y1, x2, y2 = bbox
        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_image_array)
        background = Image.new('RGB', image.size, (255, 255, 255))  # white canvas the crop is pasted onto
        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        transparency_mask[y1:y2, x1:x2] = 255
        transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
        background.paste(segmented_image, mask=transparency_mask_image)
        return background
    def _format_results(self, result, filter=0):
        annotations = []
        n = len(result.masks.data)
        for i in range(n):
            annotation = {}
            mask = result.masks.data[i] == 1.0
            if torch.sum(mask) < filter:
                continue
            annotation['id'] = i
            annotation['segmentation'] = mask.cpu().numpy()
            annotation['bbox'] = result.boxes.data[i]
            annotation['score'] = result.boxes.conf[i]
            annotation['area'] = annotation['segmentation'].sum()
            annotations.append(annotation)
        return annotations
    @staticmethod
    def filter_masks(annotations):  # filter out masks that largely overlap a bigger one
        annotations.sort(key=lambda x: x['area'], reverse=True)
        to_remove = set()
        for i in range(len(annotations)):
            a = annotations[i]
            for j in range(i + 1, len(annotations)):
                b = annotations[j]
                if j not in to_remove and b['area'] < a['area']:
                    # Drop the smaller mask when more than 80% of it lies inside the larger one.
                    if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
                        to_remove.add(j)
        return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
    def _get_bbox_from_mask(self, mask):
        mask = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        x1, y1, w, h = cv2.boundingRect(contours[0])
        x2, y2 = x1 + w, y1 + h
        if len(contours) > 1:
            # Merge the bounding boxes of all contours into a single box.
            for b in contours:
                x_t, y_t, w_t, h_t = cv2.boundingRect(b)
                x1 = min(x1, x_t)
                y1 = min(y1, y_t)
                x2 = max(x2, x_t + w_t)
                y2 = max(y2, y_t + h_t)
        return [x1, y1, x2, y2]
    def plot_to_result(self,
                       annotations,
                       bboxes=None,
                       points=None,
                       point_label=None,
                       mask_random_color=True,
                       better_quality=True,
                       retina=False,
                       withContours=True) -> np.ndarray:
        if isinstance(annotations[0], dict):
            annotations = [annotation['segmentation'] for annotation in annotations]
        image = self.img
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_h = image.shape[0]
        original_w = image.shape[1]
        if sys.platform == "darwin":
            plt.switch_backend("TkAgg")
        plt.figure(figsize=(original_w / 100, original_h / 100))
        # Add a subplot with no margin.
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)
        if better_quality:
            if isinstance(annotations[0], torch.Tensor):
                annotations = np.array(annotations.cpu())
            for i, mask in enumerate(annotations):
                # Smooth each mask with a morphological close followed by an open.
                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
        if self.device == 'cpu':
            annotations = np.array(annotations)
            self.fast_show_mask(
                annotations,
                plt.gca(),
                random_color=mask_random_color,
                bboxes=bboxes,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )
        else:
            if isinstance(annotations[0], np.ndarray):
                annotations = torch.from_numpy(annotations)
            self.fast_show_mask_gpu(
                annotations,
                plt.gca(),
                random_color=mask_random_color,
                bboxes=bboxes,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )
        if isinstance(annotations, torch.Tensor):
            annotations = annotations.cpu().numpy()
        if withContours:
            contour_all = []
            temp = np.zeros((original_h, original_w, 1))
            for i, mask in enumerate(annotations):
                if isinstance(mask, dict):
                    mask = mask['segmentation']
                annotation = mask.astype(np.uint8)
                if not retina:
                    annotation = cv2.resize(
                        annotation,
                        (original_w, original_h),
                        interpolation=cv2.INTER_NEAREST,
                    )
                contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                for contour in contours:
                    contour_all.append(contour)
            cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
            color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
            contour_mask = temp / 255 * color.reshape(1, 1, -1)
            plt.imshow(contour_mask)
        plt.axis('off')
        fig = plt.gcf()
        plt.draw()
        try:
            buf = fig.canvas.tostring_rgb()
        except AttributeError:
            fig.canvas.draw()
            buf = fig.canvas.tostring_rgb()
        cols, rows = fig.canvas.get_width_height()
        img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
        result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        plt.close()
        return result
    # Remark for refactoring: IMO a function should do one thing only. Storing the image and plotting
    # should be separated, and neither necessarily needs to be a class method; standalone utility
    # functions would let users chain them in their own scripts for more fine-grained control
    # (see the sketch below).
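    # One possible decomposition (an illustrative sketch only; render_masks, draw_contours and
    # save_image are hypothetical helpers, not part of this module):
    #
    #   overlay = render_masks(image, annotations)       # pure function: image + masks -> ndarray
    #   overlay = draw_contours(overlay, annotations)    # pure function: adds contour outlines
    #   save_image(overlay, output_path)                 # I/O kept separate; the caller decides when to write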
    def plot(self,
             annotations,
             output_path,
             bboxes=None,
             points=None,
             point_label=None,
             mask_random_color=True,
             better_quality=True,
             retina=False,
             withContours=True):
        if len(annotations) == 0:
            return None
        result = self.plot_to_result(
            annotations,
            bboxes,
            points,
            point_label,
            mask_random_color,
            better_quality,
            retina,
            withContours,
        )
        path = os.path.dirname(os.path.abspath(output_path))
        if not os.path.exists(path):
            os.makedirs(path)
        result = result[:, :, ::-1]
        cv2.imwrite(output_path, result)
    # CPU post-processing.
    def fast_show_mask(
        self,
        annotation,
        ax,
        random_color=False,
        bboxes=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        n_masks = annotation.shape[0]
        height = annotation.shape[1]
        width = annotation.shape[2]
        # Sort annotations by area so the smallest mask covering each pixel determines its colour.
        areas = np.sum(annotation, axis=(1, 2))
        sorted_indices = np.argsort(areas)
        annotation = annotation[sorted_indices]
        index = (annotation != 0).argmax(axis=0)
        if random_color:
            color = np.random.random((n_masks, 1, 1, 3))
        else:
            color = np.ones((n_masks, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
        transparency = np.ones((n_masks, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual
        show = np.zeros((height, width, 4))
        h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Use vectorized indexing to update the values of 'show'.
        show[h_indices, w_indices, :] = mask_image[indices]
        if bboxes is not None:
            for bbox in bboxes:
                x1, y1, x2, y2 = bbox
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
        # Draw the prompt points: yellow for positive labels, magenta for negative ones.
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c='y',
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c='m',
            )
        if not retinamask:
            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show)
    # GPU post-processing.
    def fast_show_mask_gpu(
        self,
        annotation,
        ax,
        random_color=False,
        bboxes=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        n_masks = annotation.shape[0]
        height = annotation.shape[1]
        width = annotation.shape[2]
        # Sort annotations by area so the smallest mask covering each pixel determines its colour.
        areas = torch.sum(annotation, dim=(1, 2))
        sorted_indices = torch.argsort(areas, descending=False)
        annotation = annotation[sorted_indices]
        # Find the index of the first non-zero mask at each pixel.
        index = (annotation != 0).to(torch.long).argmax(dim=0)
        if random_color:
            color = torch.rand((n_masks, 1, 1, 3)).to(annotation.device)
        else:
            color = torch.ones((n_masks, 1, 1, 3)).to(annotation.device) * torch.tensor([
                30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
        transparency = torch.ones((n_masks, 1, 1, 1)).to(annotation.device) * 0.6
        visual = torch.cat([color, transparency], dim=-1)
        mask_image = torch.unsqueeze(annotation, -1) * visual
        # Select data according to the index, collapsing mask_image into a single image.
        show = torch.zeros((height, width, 4)).to(annotation.device)
        try:
            h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
        except TypeError:
            # Older torch versions do not accept the 'indexing' keyword and already default to 'ij'.
            h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width))
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Use vectorized indexing to update the values of 'show'.
        show[h_indices, w_indices, :] = mask_image[indices]
        show_cpu = show.cpu().numpy()
        if bboxes is not None:
            for bbox in bboxes:
                x1, y1, x2, y2 = bbox
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
        # Draw the prompt points: yellow for positive labels, magenta for negative ones.
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c='y',
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c='m',
            )
        if not retinamask:
            show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show_cpu)
    # CLIP-based text retrieval: score each cropped region against the search text.
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
        preprocessed_images = [preprocess(image).to(device) for image in elements]
        tokenized_text = clip.tokenize([search_text]).to(device)
        stacked_images = torch.stack(preprocessed_images)
        image_features = model.encode_image(stacked_images)
        text_features = model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = 100.0 * image_features @ text_features.T
        return probs[:, 0].softmax(dim=0)
    def _crop_image(self, format_results):
        image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
        mask_h, mask_w = annotations[0]['segmentation'].shape
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))
        cropped_boxes = []
        cropped_images = []
        not_crop = []
        filter_id = []
        # annotations, _ = filter_masks(annotations)
        # filter_id = list(_)
        for idx, mask in enumerate(annotations):
            if np.sum(mask['segmentation']) <= 100:
                filter_id.append(idx)
                continue
            bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
            cropped_boxes.append(self._segment_image(image, bbox))
            # cropped_boxes.append(segment_image(image, mask["segmentation"]))
            cropped_images.append(bbox)  # Save the bounding box of the cropped image.
        return cropped_boxes, cropped_images, not_crop, filter_id, annotations
    def box_prompt(self, bbox=None, bboxes=None):
        if self.results is None:
            return []
        assert bbox or bboxes
        if bboxes is None:
            bboxes = [bbox]
        max_iou_index = []
        for bbox in bboxes:
            assert (bbox[2] != 0 and bbox[3] != 0)
            masks = self.results[0].masks.data
            target_height = self.img.shape[0]
            target_width = self.img.shape[1]
            h = masks.shape[1]
            w = masks.shape[2]
            if h != target_height or w != target_width:
                # Rescale the box from image coordinates to mask coordinates.
                bbox = [
                    int(bbox[0] * w / target_width),
                    int(bbox[1] * h / target_height),
                    int(bbox[2] * w / target_width),
                    int(bbox[3] * h / target_height), ]
            bbox[0] = max(round(bbox[0]), 0)
            bbox[1] = max(round(bbox[1]), 0)
            bbox[2] = min(round(bbox[2]), w)
            bbox[3] = min(round(bbox[3]), h)
            # IoU between the prompt box and every predicted mask; keep the best match.
            bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
            masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
            orig_masks_area = torch.sum(masks, dim=(1, 2))
            union = bbox_area + orig_masks_area - masks_area
            IoUs = masks_area / union
            max_iou_index.append(int(torch.argmax(IoUs)))
        max_iou_index = list(set(max_iou_index))
        return np.array(masks[max_iou_index].cpu().numpy())
    def point_prompt(self, points, pointlabel):  # numpy
        if self.results is None:
            return []
        masks = self._format_results(self.results[0], 0)
        target_height = self.img.shape[0]
        target_width = self.img.shape[1]
        h = masks[0]['segmentation'].shape[0]
        w = masks[0]['segmentation'].shape[1]
        if h != target_height or w != target_width:
            # Rescale the prompt points from image coordinates to mask coordinates.
            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
        onemask = np.zeros((h, w))
        masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        for annotation in masks:
            mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
            for i, point in enumerate(points):
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                    onemask[mask] = 1
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                    onemask[mask] = 0
        onemask = onemask >= 1
        return np.array([onemask])
    def text_prompt(self, text):
        if self.results is None:
            return []
        format_results = self._format_results(self.results[0], 0)
        cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
        clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
        scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
        max_idx = scores.argsort()
        max_idx = max_idx[-1]
        # Shift the index to account for the small masks filtered out in _crop_image.
        max_idx += sum(np.array(filter_id) <= int(max_idx))
        return np.array([annotations[max_idx]['segmentation']])
    def everything_prompt(self):
        if self.results is None:
            return []
        return self.results[0].masks.data
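
# A minimal usage sketch, kept as a comment so this module stays import-only. It assumes the
# FastSAM model wrapper exported alongside this class and a locally available 'FastSAM-x.pt'
# checkpoint; adjust the import path, weights file, image paths and inference arguments to your setup:
#
#   from fastsam import FastSAM, FastSAMPrompt
#
#   model = FastSAM('FastSAM-x.pt')
#   results = model('images/dog.jpg', device='cuda', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
#   prompt_process = FastSAMPrompt('images/dog.jpg', results, device='cuda')
#
#   ann = prompt_process.everything_prompt()                      # all masks
#   # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])  # or a single box prompt
#   # ann = prompt_process.text_prompt(text='a photo of a dog')   # or a CLIP text prompt
#   prompt_process.plot(annotations=ann, output_path='output/dog.jpg')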