GraCo / isegm /inference /clicker.py
zhaoyian01's picture
Add application file
6d1366a
raw
history blame
3.7 kB
import numpy as np
from copy import deepcopy
import cv2
class Clicker(object):
    """Ordered collection of clicks, optionally able to simulate the next one.

    Without a ground-truth mask the object is only a container for clicks
    supplied by the caller. With a ground-truth mask it can also pick a new
    click from the disagreement between a predicted mask and the ground
    truth (see ``make_next_click``).
    """

    def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
        """Create a clicker.

        Args:
            gt_mask: Optional 2-D label array. Pixels equal to 1 form the
                target region; pixels equal to ``ignore_label`` are excluded
                from error computation.
            init_clicks: Optional iterable of click objects added immediately
                through ``add_click``.
            ignore_label: Label value in ``gt_mask`` marking pixels to ignore.
            click_indx_offset: Value added to every index assigned by
                ``add_click``.
        """
        self.click_indx_offset = click_indx_offset
        if gt_mask is not None:
            self.gt_mask = gt_mask == 1
            self.not_ignore_mask = gt_mask != ignore_label
        else:
            self.gt_mask = None
            # Always define the attribute so that a clicker without ground
            # truth exposes the same attribute set as one with it.
            self.not_ignore_mask = None

        self.reset_clicks()

        if init_clicks is not None:
            for click in init_clicks:
                self.add_click(click)

    def make_next_click(self, pred_mask):
        """Simulate one click from ``pred_mask`` and record it.

        Requires a ground-truth mask to have been given at construction.
        """
        assert self.gt_mask is not None, 'make_next_click requires a ground-truth mask'
        self.add_click(self._get_next_click(pred_mask))

    def get_clicks(self, clicks_limit=None):
        """Return a new list with the first ``clicks_limit`` clicks (all if None)."""
        return self.clicks_list[:clicks_limit]

    def _unclicked_distance(self, error_mask, padding):
        """Return the distance transform of ``error_mask``, zeroed at clicked pixels.

        With ``padding`` the mask is surrounded by a one-pixel frame of zeros
        before the transform and the frame is cropped off afterwards, so the
        array border counts as a boundary of the error region.
        """
        if padding:
            error_mask = np.pad(error_mask, ((1, 1), (1, 1)), 'constant')
        # Third argument is the mask size; 0 is presumably
        # cv2.DIST_MASK_PRECISE -- confirm against the OpenCV documentation.
        distances = cv2.distanceTransform(error_mask.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            distances = distances[1:-1, 1:-1]
        return distances * self.not_clicked_map

    def _get_next_click(self, pred_mask, padding=True):
        """Pick the next click from the largest prediction error.

        False-negative pixels (target missed by ``pred_mask``) and
        false-positive pixels (predicted but not target) are computed inside
        the non-ignored area. The click is placed at the first pixel attaining
        the maximum distance value in whichever error map has the larger
        maximum; it is positive only when the false-negative maximum is
        strictly larger, so ties produce a negative click.

        Returns:
            A ``Click`` whose ``coords`` are ``(y, x)``.
        """
        fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
        fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)

        fn_mask_dt = self._unclicked_distance(fn_mask, padding)
        fp_mask_dt = self._unclicked_distance(fp_mask, padding)

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        is_positive = fn_max_dist > fp_max_dist
        if is_positive:
            coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist)  # coords is [y, x]
        else:
            coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist)  # coords is [y, x]

        return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))

    def add_click(self, click):
        """Append ``click`` and update the counters.

        ``click.indx`` is overwritten with ``click_indx_offset`` plus the
        number of clicks stored before this one. When a ground-truth mask is
        present, the clicked pixel is marked so it is not chosen again.
        """
        coords = click.coords

        click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1

        self.clicks_list.append(click)
        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = False

    def _remove_last_click(self):
        """Drop the most recent click and undo its bookkeeping."""
        click = self.clicks_list.pop()
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks -= 1
        else:
            self.num_neg_clicks -= 1

        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = True

    def reset_clicks(self):
        """Forget all clicks and mark every pixel as not yet clicked."""
        if self.gt_mask is not None:
            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)
        else:
            # No ground truth means no per-pixel bookkeeping; keep the
            # attribute defined so reading it never raises AttributeError.
            self.not_clicked_map = None

        self.num_pos_clicks = 0
        self.num_neg_clicks = 0

        self.clicks_list = []

    def get_state(self):
        """Return a deep copy of the stored clicks."""
        return deepcopy(self.clicks_list)

    def set_state(self, state):
        """Replace the stored clicks with those in ``state``.

        The clicks are re-added through ``add_click``, so their ``indx``
        attributes are reassigned in place.
        """
        self.reset_clicks()
        for click in state:
            self.add_click(click)

    def __len__(self):
        """Return the number of stored clicks."""
        return len(self.clicks_list)
class Click:
    """One click: its polarity, its coordinates and an optional ordinal index."""

    def __init__(self, is_positive, coords, indx=None):
        """Store the three fields exactly as given.

        Args:
            is_positive: Truthy for a positive click, falsy for a negative one.
            coords: Iterable of coordinates; the click simulation in this
                file produces ``(y, x)``.
            indx: Optional position of the click in its sequence.
        """
        self.is_positive = is_positive
        self.coords = coords
        self.indx = indx

    @property
    def coords_and_indx(self):
        """Return the coordinates followed by the index as one flat tuple."""
        return tuple(self.coords) + (self.indx,)

    def copy(self, **kwargs):
        """Return a deep copy with each keyword argument set as an attribute."""
        duplicate = deepcopy(self)
        for attr_name in kwargs:
            setattr(duplicate, attr_name, kwargs[attr_name])
        return duplicate