import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms

from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide


class BasePredictor(object):
    def __init__(self, model, device,
                 gra=None,
                 sam_type=None,
                 net_clicks_limit=None,
                 with_flip=False,
                 zoom_in=None,
                 max_size=None,
                 **kwargs):
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.device = device
        # A granularity value is only kept when it is positive; otherwise it is disabled.
        self.gra = gra if gra is not None and gra > 0 else None
        self.sam_type = sam_type
        self.zoom_in = zoom_in
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None

        # `model` may be a single network or a (network, per-click-count models) tuple.
        if isinstance(model, tuple):
            self.net, self.click_models = model
        else:
            self.net = model

        self.to_tensor = transforms.ToTensor()

        self.transforms = [zoom_in] if zoom_in is not None else []
        if max_size is not None:
            self.transforms.append(LimitLongestSide(max_size=max_size))
        self.transforms.append(SigmoidForPred())
        if with_flip:
            self.transforms.append(AddHorizontalFlip())

    def set_input_image(self, image):
        if not isinstance(image, torch.Tensor):
            image_nd = self.to_tensor(image)
        else:
            image_nd = image
        for transform in self.transforms:
            transform.reset()
        self.original_image = image_nd.to(self.device)
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)
        # Start each new image from an empty (all-zero) previous mask.
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])

    def get_prediction(self, clicker, prev_mask=None, gra=None):
        clicks_list = clicker.get_clicks()

        # With a cascade of click models, pick the model matching the current click count.
        if self.click_models is not None:
            model_indx = min(clicker.click_indx_offset + len(clicks_list),
                             len(self.click_models)) - 1
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        # Networks trained with a previous mask (and SAM) take it as an extra input channel.
        if (hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask) or self.sam_type is not None:
            input_image = torch.cat((input_image, prev_mask), dim=1)
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )

        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed, gra=gra)
        prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
                                   size=image_nd.size()[2:])

        # Undo the transforms in reverse order (e.g. flip averaging, sigmoid, zoom-in crop).
        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
            return self.get_prediction(clicker)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed, gra=None):
        points_nd = self.get_points_nd(clicks_lists)
        if gra is None:
            gra = self.gra
        if self.sam_type == 'SAM':
            batched_input = self.get_sam_batched_input(image_nd, points_nd)
            batched_output = self.net(batched_input, multimask_output=False, return_logits=True)
            return torch.cat([batch['masks'] for batch in batched_output], dim=0)
        if gra is not None:
            # Granularity-conditioned networks take an extra scalar input.
            return self.net(image_nd, points_nd, torch.Tensor([gra]).to(self.device))['instances']
        else:
            return self.net(image_nd, points_nd)['instances']

    def _batch_infer(self, batch_image_tensor, batch_clickers, prev_mask=None):
        if prev_mask is None:
            prev_mask = self.prev_prediction

        input_image = batch_image_tensor
        if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask:
            input_image = torch.cat((batch_image_tensor, prev_mask), dim=1)

        clicks_lists = [clicker.get_clicks() for clicker in batch_clickers]
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, clicks_lists
        )
        points_nd = self.get_points_nd(clicks_lists)
        pred_logits = self.net(image_nd, points_nd)['instances']
        prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
                                   size=image_nd.size()[2:])
        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)
        self.prev_prediction = prediction
        return prediction.cpu().numpy()[:, 0]

    def _get_transform_states(self):
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)

    def apply_transforms(self, image_nd, clicks_lists):
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed
        return image_nd, clicks_lists, is_image_changed

    def get_points_nd(self, clicks_lists):
        total_clicks = []
        num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
        num_neg_clicks = [len(clicks_list) - num_pos
                          for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            clicks_list = clicks_list[:self.net_clicks_limit]
            # Pad positive and negative clicks to a common length with (-1, -1, -1);
            # the first half of each row is positive clicks, the second half negative.
            pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]

            neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)

    def get_sam_batched_input(self, image_nd, points_nd):
        batched_output = []
        for i in range(image_nd.shape[0]):
            image = image_nd[i]
            point_length = points_nd[i].shape[0] // 2
            point_coords = []
            point_labels = []
            for j, point in enumerate(points_nd[i]):
                point_np = point.cpu().numpy()
                if point_np[0] == -1:
                    # Skip padding points.
                    continue
                # The first half of the row holds positive clicks, the rest negative.
                if j < point_length:
                    point_labels.append(1)
                else:
                    point_labels.append(0)
                # Clicks are stored as (y, x); SAM expects (x, y).
                point_coords.append([point_np[1], point_np[0]])
            res = {
                'image': image[:3, :, :],
                'point_coords': torch.as_tensor(np.array(point_coords), dtype=torch.float,
                                                device=self.device)[None, :],
                'point_labels': torch.as_tensor(np.array(point_labels), dtype=torch.float,
                                                device=self.device)[None, :],
                'original_size': image.cpu().numpy().shape[1:],
                # The fourth channel is the previous mask, passed as SAM's mask prompt.
                'mask_inputs': image[3, :, :][None, None, :]
            }
            batched_output.append(res)
        return batched_output

    def get_states(self):
        return {
            'transform_states': self._get_transform_states(),
            'prev_prediction': self.prev_prediction.clone()
        }

    def set_states(self, states):
        self._set_transform_states(states['transform_states'])
        self.prev_prediction = states['prev_prediction']
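

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the predictor): a minimal smoke
# test with stand-in objects. The dummy network below only mimics the
# interface BasePredictor expects (a `with_prev_mask` flag and a forward
# returning {'instances': logits}); real models and clicks come from the
# rest of the isegm pipeline, e.g. isegm.inference.clicker.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    class _DummyNet(torch.nn.Module):
        with_prev_mask = True  # predictor will concatenate prev_mask as a 4th channel

        def forward(self, image, points):
            # Zero logits at input resolution, shape (B, 1, H, W).
            return {'instances': torch.zeros(image.shape[0], 1, *image.shape[2:],
                                             device=image.device)}

    predictor = BasePredictor(_DummyNet(), device=torch.device('cpu'))
    predictor.set_input_image(np.zeros((64, 64, 3), dtype=np.uint8))

    # A click exposes `is_positive` and `coords_and_indx` == (y, x, click_index);
    # SimpleNamespace stands in for isegm's Click/Clicker objects here.
    click = SimpleNamespace(is_positive=True, coords_and_indx=(32, 32, 0))
    clicker = SimpleNamespace(get_clicks=lambda: [click], click_indx_offset=0)

    prob_map = predictor.get_prediction(clicker)
    print(prob_map.shape, prob_map.max())  # (64, 64), 0.5 == sigmoid(0)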