Spaces:
Running
Running
File size: 8,080 Bytes
6d1366a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
class BasePredictor(object):
    """Run an interactive-segmentation network on one image, click by click.

    The predictor stores the current image and the last predicted mask. For
    each query it:
      1. pushes the image and the clicks through a chain of test-time
         transforms;
      2. calls the network;
      3. maps the logits back through the transforms' inverses.

    Args:
        model: A network, or a ``(net, click_models)`` tuple. With a tuple,
            ``click_models`` is a sequence of networks selected by the number
            of clicks seen so far (see ``get_prediction``).
        device: Device onto which the image, masks and click tensors are moved.
        gra: Optional scalar forwarded to the network as an extra tensor
            argument. ``None`` or values ``<= 0`` disable it.
        sam_type: When ``'SAM'``, the network is called with SAM-style batched
            input dicts instead of ``(image, points)``. Any non-``None`` value
            also makes the previous mask get concatenated to the image.
        net_clicks_limit: Optional cap on how many clicks per list are encoded
            for the network.
        with_flip: If true, append an ``AddHorizontalFlip`` transform.
        zoom_in: Optional zoom-in transform, placed first in the chain.
        max_size: If given, append ``LimitLongestSide(max_size=max_size)``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model, device, gra=None, sam_type=None,
                 net_clicks_limit=None,
                 with_flip=False,
                 zoom_in=None,
                 max_size=None,
                 **kwargs):
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.device = device
        # Non-positive values are treated as "no gra".
        self.gra = gra if gra is not None and gra > 0 else None
        self.sam_type = sam_type
        self.zoom_in = zoom_in
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None
        if isinstance(model, tuple):
            self.net, self.click_models = model
        else:
            self.net = model
        self.to_tensor = transforms.ToTensor()
        # Transform order matters: they are applied front-to-back and
        # inverted back-to-front.
        self.transforms = [zoom_in] if zoom_in is not None else []
        if max_size is not None:
            self.transforms.append(LimitLongestSide(max_size=max_size))
        self.transforms.append(SigmoidForPred())
        if with_flip:
            self.transforms.append(AddHorizontalFlip())

    def set_input_image(self, image):
        """Set the image to segment and reset all per-image state.

        ``image`` is converted with ``transforms.ToTensor()`` unless it is
        already a ``torch.Tensor``. A 3-D tensor gets a leading batch axis.
        The previous prediction is reset to zeros with one channel and the
        image's spatial size.
        """
        if not isinstance(image, torch.Tensor):
            image_nd = self.to_tensor(image)
        else:
            image_nd = image
        for transform in self.transforms:
            transform.reset()
        self.original_image = image_nd.to(self.device)
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])

    def _uses_prev_mask(self):
        """Return True when the previous mask must be concatenated to the image."""
        return bool(getattr(self.net, 'with_prev_mask', False)) or self.sam_type is not None

    def get_prediction(self, clicker, prev_mask=None, gra=None):
        """Predict a mask for the clicks currently held by ``clicker``.

        Args:
            clicker: Object exposing ``get_clicks()`` and ``click_indx_offset``.
            prev_mask: Mask concatenated to the image when the net uses one.
                Defaults to the last stored prediction.
            gra: Overrides ``self.gra`` for this call when not ``None``.

        Returns:
            The ``[0, 0]`` slice of the inverse-transformed prediction, as a
            numpy array. The full tensor is stored in ``self.prev_prediction``.
        """
        clicks_list = clicker.get_clicks()

        if self.click_models is not None:
            # Clamp at 0: with no clicks the raw index would be -1, which
            # would silently select the *last* click model.
            model_indx = max(0, min(clicker.click_indx_offset + len(clicks_list),
                                    len(self.click_models)) - 1)
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        if self._uses_prev_mask():
            input_image = torch.cat((input_image, prev_mask), dim=1)

        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )
        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed, gra=gra)
        prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
                                   size=image_nd.size()[2:])

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
            # Re-run with the same arguments so the caller's gra/prev_mask
            # are not lost on recalculation.
            return self.get_prediction(clicker, prev_mask=prev_mask, gra=gra)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed, gra=None):
        """Call the network on transformed inputs and return its logits.

        ``is_image_changed`` is accepted for interface compatibility and not
        used here. When ``gra`` is ``None`` the instance-level ``self.gra`` is
        used. For ``sam_type == 'SAM'`` the per-sample ``'masks'`` outputs are
        concatenated along dim 0.
        """
        points_nd = self.get_points_nd(clicks_lists)
        if gra is None:
            gra = self.gra
        if self.sam_type == 'SAM':
            batched_input = self.get_sam_batched_input(image_nd, points_nd)
            batched_output = self.net(batched_input, multimask_output=False, return_logits=True)
            return torch.cat([batch['masks'] for batch in batched_output], dim=0)

        if gra is not None:
            return self.net(image_nd, points_nd, torch.Tensor([gra]).to(self.device))['instances']
        else:
            return self.net(image_nd, points_nd)['instances']

    def _batch_infer(self, batch_image_tensor, batch_clickers, prev_mask=None):
        """Predict masks for a batch of images, one clicker per image.

        Args:
            batch_image_tensor: Image batch; ``prev_mask`` is concatenated on
                dim 1 when the network uses a previous mask.
            batch_clickers: One clicker per batch element.
            prev_mask: Defaults to the last stored prediction.

        Returns:
            The inverse-transformed predictions indexed ``[:, 0]``, as a numpy
            array. The full tensor is stored in ``self.prev_prediction``.
        """
        if prev_mask is None:
            prev_mask = self.prev_prediction

        # Default to the bare image; previously this name was unbound (and
        # raised NameError) for networks without a previous-mask input.
        input_image = batch_image_tensor
        if self._uses_prev_mask():
            input_image = torch.cat((input_image, prev_mask), dim=1)

        clicks_lists = [clicker.get_clicks() for clicker in batch_clickers]
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, clicks_lists
        )
        # Share the single-image inference path so gra / SAM handling match.
        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
        prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
                                   size=image_nd.size()[2:])

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[:, 0]

    def _get_transform_states(self):
        """Return the ``get_state()`` of every transform, in chain order."""
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        """Restore transform states produced by ``_get_transform_states``."""
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)

    def apply_transforms(self, image_nd, clicks_lists):
        """Apply every transform in order.

        Returns:
            ``(image_nd, clicks_lists, is_image_changed)``. The flag is True if
            any transform reported ``image_changed``.
        """
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed

        return image_nd, clicks_lists, is_image_changed

    def get_points_nd(self, clicks_lists):
        """Encode click lists as a padded tensor of ``coords_and_indx`` rows.

        Each clicks list produces ``2 * num_max_points`` rows:
          * first its positive clicks, then its negative clicks;
          * each group padded with ``(-1, -1, -1)`` up to ``num_max_points``.
        ``num_max_points`` is the largest per-list count of positive or
        negative clicks, capped by ``net_clicks_limit`` and at least 1.
        """
        total_clicks = []
        num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
        num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
        # default=0 keeps an empty batch from raising on max() of nothing.
        num_max_points = max(num_pos_clicks + num_neg_clicks, default=0)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            clicks_list = clicks_list[:self.net_clicks_limit]
            pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]

            neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)

    def get_sam_batched_input(self, image_nd, points_nd):
        """Build one SAM-style input dict per batch element.

        Rows of ``points_nd[b]`` whose first value is -1 (padding) are
        skipped. Rows in the first half are labelled 1, the rest 0, following
        the positive-then-negative layout of ``get_points_nd``. The first two
        values of each row are swapped when written to ``point_coords``
        (presumably (y, x) -> (x, y); TODO confirm against click convention).
        Channels ``[:3]`` form ``'image'``. Channel 3 (the concatenated
        previous mask) is passed as ``'mask_inputs'``.
        """
        batched_input = []
        for b in range(image_nd.shape[0]):
            image = image_nd[b]
            # One host copy per sample instead of one per point.
            points = points_nd[b].cpu().numpy()
            num_pos_slots = points.shape[0] // 2
            point_coords = []
            point_labels = []
            for j, point in enumerate(points):
                if point[0] == -1:
                    continue
                point_labels.append(1 if j < num_pos_slots else 0)
                point_coords.append([point[1], point[0]])
            batched_input.append({
                'image': image[:3, :, :],
                'point_coords': torch.as_tensor(np.array(point_coords), dtype=torch.float, device=self.device)[None, :],
                'point_labels': torch.as_tensor(np.array(point_labels), dtype=torch.float, device=self.device)[None, :],
                # Read the shape directly; no need to copy the image to host.
                'original_size': tuple(image.shape[1:]),
                'mask_inputs': image[3, :, :][None, None, :],
            })
        return batched_input

    def get_states(self):
        """Snapshot transform states and a copy of the previous prediction.

        ``'prev_prediction'`` is ``None`` if no image has been set yet.
        """
        prev = self.prev_prediction
        return {
            'transform_states': self._get_transform_states(),
            'prev_prediction': None if prev is None else prev.clone()
        }

    def set_states(self, states):
        """Restore a snapshot produced by ``get_states``."""
        self._set_transform_states(states['transform_states'])
        self.prev_prediction = states['prev_prediction']
|