import torch

from typing import List
from isegm.inference.clicker import Click
from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
from .base import BaseTransform


class ZoomIn(BaseTransform):
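    """Adaptive "zoom-in" crop used at interactive-segmentation inference time.

    Once more than `skip_clicks` clicks are available, the image is cropped to
    an expanded bounding box around the currently predicted object and resized
    to `target_size`; `inv_transform` maps predictions back to the full image.
    """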
    def __init__(self,
                 target_size=400,
                 skip_clicks=1,
                 expansion_ratio=1.4,
                 min_crop_size=200,
                 recompute_thresh_iou=0.5,
                 prob_thresh=0.50):
        super().__init__()
        self.target_size = target_size
        self.min_crop_size = min_crop_size
        self.skip_clicks = skip_clicks
        self.expansion_ratio = expansion_ratio
        self.recompute_thresh_iou = recompute_thresh_iou
        self.prob_thresh = prob_thresh

        self._input_image_shape = None
        self._prev_probs = None
        self._object_roi = None
        self._roi_image = None
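
    # Batched front-end: zoom each sample in the batch independently, then
    # re-stack the crops and collect the remapped click lists.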
    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        transformed_image = []
        transformed_clicks_lists = []
        for bindx in range(len(clicks_lists)):
            new_image_nd, new_clicks_lists = self._transform(image_nd[bindx].unsqueeze(0),
                                                             [clicks_lists[bindx]])
            transformed_image.append(new_image_nd)
            transformed_clicks_lists.append(new_clicks_lists[0])
        return torch.cat(transformed_image, dim=0), transformed_clicks_lists
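
    # Single-sample core: choose or refresh the object ROI and produce the crop.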
    def _transform(self, image_nd, clicks_lists: List[List[Click]]):
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        self.image_changed = False

        clicks_list = clicks_lists[0]
        if len(clicks_list) <= self.skip_clicks:
            return image_nd, clicks_lists

        self._input_image_shape = image_nd.shape

        # Estimate the object ROI from the previous prediction, if there is one.
        current_object_roi = None
        if self._prev_probs is not None:
            current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
            if current_pred_mask.sum() > 0:
                current_object_roi = get_object_roi(current_pred_mask, clicks_list,
                                                    self.expansion_ratio, self.min_crop_size)

        if current_object_roi is None:
            if self.skip_clicks >= 0:
                return image_nd, clicks_lists
            else:
                # Fall back to the whole image as the ROI.
                current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1

        # Recompute the crop when there is none yet, a positive click left it,
        # or the new ROI overlaps the cached one too little.
        update_object_roi = False
        if self._object_roi is None:
            update_object_roi = True
        elif not check_object_roi(self._object_roi, clicks_list):
            update_object_roi = True
        elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
            update_object_roi = True

        if update_object_roi:
            self._object_roi = current_object_roi
            self.image_changed = True
        self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)

        tclicks_lists = [self._transform_clicks(clicks_list)]
        return self._roi_image.to(image_nd.device), tclicks_lists
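
    # Map batched ROI predictions back to full-image coordinates.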
    def inv_transform(self, prob_map):
        new_prob_maps = []
        for bindx in range(prob_map.shape[0]):
            new_prob_map = self._inv_transform(prob_map[bindx].unsqueeze(0))
            new_prob_maps.append(new_prob_map)
        return torch.cat(new_prob_maps, dim=0)
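
    # Undo the crop for one sample and cache the full-resolution probabilities.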
    def _inv_transform(self, prob_map):
        if self._object_roi is None:
            self._prev_probs = prob_map.cpu().numpy()
            return prob_map

        assert prob_map.shape[0] == 1
        rmin, rmax, cmin, cmax = self._object_roi
        # Resize the ROI prediction back to the size of the original crop.
        prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
                                                   mode='bilinear', align_corners=True)

        if self._prev_probs is not None:
            # Paste the crop prediction into a zero-filled full-resolution map.
            new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
            new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
        else:
            new_prob_map = prob_map

        self._prev_probs = new_prob_map.cpu().numpy()
        return new_prob_map
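
    # True when zoom-in is inactive but the current prediction would yield a
    # crop clearly smaller than the full image, so recomputing is worthwhile.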
    def check_possible_recalculation(self):
        if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
            return False

        pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
        if pred_mask.sum() > 0:
            possible_object_roi = get_object_roi(pred_mask, [],
                                                 self.expansion_ratio, self.min_crop_size)
            image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
            if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
                return True

        return False
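
    # get_state / set_state let the caller snapshot and restore the transform
    # (e.g. to undo an interaction step).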
    def get_state(self):
        roi_image = self._roi_image.cpu() if self._roi_image is not None else None
        return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
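
    # Restore a snapshot produced by get_state.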
    def set_state(self, state):
        self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
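
    # Forget all cached state before starting on a new image or object.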
    def reset(self):
        self._input_image_shape = None
        self._object_roi = None
        self._prev_probs = None
        self._roi_image = None
        self.image_changed = False
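
    # Convert click coordinates from full-image space into crop space.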
    def _transform_clicks(self, clicks_list):
        if self._object_roi is None:
            return clicks_list

        rmin, rmax, cmin, cmax = self._object_roi
        crop_height, crop_width = self._roi_image.shape[2:]

        transformed_clicks = []
        for click in clicks_list:
            new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
            new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
            transformed_clicks.append(click.copy(coords=(new_r, new_c)))
        return transformed_clicks
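

# Build a crop box around the predicted object: union the mask with every
# positive click, take its bounding box, expand it, and clamp to the image.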
def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
    pred_mask = pred_mask.copy()
    for click in clicks_list:
        if click.is_positive:
            pred_mask[int(click.coords[0]), int(click.coords[1])] = 1

    bbox = get_bbox_from_mask(pred_mask)
    bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
    h, w = pred_mask.shape[0], pred_mask.shape[1]
    bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)

    return bbox
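

# Crop the ROI out of the image tensor and bilinearly resize it, either to an
# explicit (height, width) tuple or so its longer side equals target_size.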
def get_roi_image_nd(image_nd, object_roi, target_size):
    rmin, rmax, cmin, cmax = object_roi

    height = rmax - rmin + 1
    width = cmax - cmin + 1

    if isinstance(target_size, tuple):
        new_height, new_width = target_size
    else:
        scale = target_size / max(height, width)
        new_height = int(round(height * scale))
        new_width = int(round(width * scale))

    with torch.no_grad():
        roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
        roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
                                                       mode='bilinear', align_corners=True)

    return roi_image_nd
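

# A cached ROI remains valid only while every positive click stays inside it.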
def check_object_roi(object_roi, clicks_list):
    for click in clicks_list:
        if click.is_positive:
            if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
                return False
            if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
                return False

    return True