from .model import FastSAM
import numpy as np
from PIL import Image
import clip
from typing import Optional, List, Tuple, Union
class FastSAMDecoder:
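    """Prompt-style decoder on top of a FastSAM model.

    run_encoder() runs the segmentation model once over the whole image and
    caches the result; run_decoder() then selects masks from that result
    using a box, point, or text prompt.
    """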
    def __init__(
        self,
        model: FastSAM,
        device: str = 'cpu',
        conf: float = 0.4,
        iou: float = 0.9,
        imgsz: int = 1024,
        retina_masks: bool = True,
    ):
self.model = model
self.device = device
self.retina_masks = retina_masks
self.imgsz = imgsz
self.conf = conf
self.iou = iou
self.image = None
self.image_embedding = None
    def run_encoder(self, image):
        # Accept either an image path or an already-loaded array.
        if isinstance(image, str):
            image = np.array(Image.open(image))
        self.image = image
image_embedding = self.model(
self.image,
device=self.device,
retina_masks=self.retina_masks,
imgsz=self.imgsz,
conf=self.conf,
iou=self.iou
)
        # The model returns a list of Results; keep the first one and convert
        # its tensors to numpy so the prompt methods below can use numpy ops.
        return image_embedding[0].numpy()
def run_decoder(
self,
image_embedding,
point_prompt: Optional[np.ndarray]=None,
point_label: Optional[np.ndarray]=None,
box_prompt: Optional[np.ndarray]=None,
text_prompt: Optional[str]=None,
    ) -> np.ndarray:
self.image_embedding = image_embedding
if point_prompt is not None:
ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
return ann
elif box_prompt is not None:
ann = self.box_prompt(bbox=box_prompt)
return ann
elif text_prompt is not None:
ann = self.text_prompt(text=text_prompt)
return ann
else:
return None
    def box_prompt(self, bbox):
        # Select the predicted mask with the highest IoU against a box given
        # as [x1, y1, x2, y2] in original-image coordinates.
        assert bbox[2] != 0 and bbox[3] != 0
masks = self.image_embedding.masks.data
target_height = self.image.shape[0]
target_width = self.image.shape[1]
h = masks.shape[1]
w = masks.shape[2]
        if h != target_height or w != target_width:
            # Rescale the box from image coordinates to mask-resolution coordinates.
            bbox = [
                int(bbox[0] * w / target_width),
                int(bbox[1] * h / target_height),
                int(bbox[2] * w / target_width),
                int(bbox[3] * h / target_height),
            ]
        # Clamp the box to the mask bounds.
        bbox[0] = max(round(bbox[0]), 0)
        bbox[1] = max(round(bbox[1]), 0)
        bbox[2] = min(round(bbox[2]), w)
        bbox[3] = min(round(bbox[3]), h)
        # IoU between the box and each mask: the intersection is the mask area
        # inside the box; union = box area + mask area - intersection.
        bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
        masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
        orig_masks_area = np.sum(masks, axis=(1, 2))
        union = bbox_area + orig_masks_area - masks_area
        IoUs = masks_area / union
max_iou_index = np.argmax(IoUs)
        # `masks` is already a numpy array (run_encoder converts the result with
        # .numpy()), so no .cpu()/.numpy() round-trip is needed.
        return np.array([masks[max_iou_index]])
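    # Worked example of the IoU selection above (illustrative numbers only):
    # with bbox = [1, 1, 3, 3] the box area is (3-1)*(3-1) = 4; a mask covering
    # 6 pixels, 3 of them inside the box, gives union = 4 + 6 - 3 = 7 and
    # IoU = 3/7 ≈ 0.43. The mask with the largest such IoU is returned.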
    def point_prompt(self, points, pointlabel):  # numpy
        # self.image_embedding is already a single (numpy) Results object, so it
        # is passed directly; indexing it would keep only the first mask.
        masks = self._format_results(self.image_embedding, 0)
target_height = self.image.shape[0]
target_width = self.image.shape[1]
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
        # Largest masks first, so smaller, more specific masks can overwrite them.
        masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        for annotation in masks:
            mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
            for i, point in enumerate(points):
                # Positive points (label 1) add the mask; negative points (label 0) erase it.
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                    onemask[mask] = 1
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                    onemask[mask] = 0
onemask = onemask >= 1
return np.array([onemask])
    def _format_results(self, result, filter=0):
        # Turn a Results object into a list of per-mask annotation dicts,
        # skipping masks with fewer than `filter` pixels.
        annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if np.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
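
# ----------------------------------------------------------------------
# Minimal usage sketch, kept as comments. It assumes the package exports
# FastSAM and FastSAMDecoder and that the weight/image paths below exist;
# all paths and coordinates are placeholders.
#
#   from fastsam import FastSAM, FastSAMDecoder
#   import numpy as np
#
#   model = FastSAM('./weights/FastSAM-x.pt')
#   decoder = FastSAMDecoder(model, device='cpu', imgsz=1024)
#   embedding = decoder.run_encoder('./images/dog.jpg')
#
#   # Mask that best overlaps a box given as [x1, y1, x2, y2].
#   box_mask = decoder.run_decoder(embedding, box_prompt=np.array([100, 100, 400, 400]))
#
#   # Union of masks containing positive points (label 1), minus masks
#   # under negative points (label 0).
#   point_mask = decoder.run_decoder(
#       embedding,
#       point_prompt=np.array([[250, 250]]),
#       point_label=np.array([1]),
#   )
# ----------------------------------------------------------------------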