from time import time import numpy as np import torch import cv2 from isegm.inference import utils from isegm.inference.clicker import Click, Clicker try: get_ipython() from tqdm import tqdm_notebook as tqdm except NameError: from tqdm import tqdm def evaluate_dataset(dataset, predictor, sam_type=None, oracle=False, gra_oracle=False, **kwargs): all_ious = [] start_time = time() all_gras = {} for index in tqdm(range(len(dataset)), leave=False): sample = dataset.get_sample(index) for object_id in sample.objects_ids: if gra_oracle: sample_ious, gra_idx = evaluate_sample_oracle(sample.image, sample.gt_mask(object_id), predictor, sample_id=index, sam_type=sam_type, oracle=oracle, **kwargs) all_gras[gra_idx] = all_gras.get(gra_idx, 0) + 1 else: _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask(object_id), predictor, sample_id=index, sam_type=sam_type, oracle=oracle, **kwargs) all_ious.append(sample_ious) end_time = time() elapsed_time = end_time - start_time if len(all_gras) > 0: print(all_gras) return all_ious, elapsed_time def evaluate_sample(image, gt_mask, predictor, max_iou_thr, pred_thr=0.49, min_clicks=1, max_clicks=20, sample_id=None, sam_type=False, oracle=False, callback=None): clicker = Clicker(gt_mask=gt_mask) pred_mask = np.zeros_like(gt_mask) ious_list = [] with torch.no_grad(): predictor.set_input_image(image) if sam_type == 'SAM': for click_indx in range(max_clicks): clicker.make_next_click(pred_mask) point_coords, point_labels = get_sam_input(clicker) if oracle: ious = [] pred_masks = [] pred_probs, _, _ = predictor.predict(point_coords, point_labels, multimask_output=True, return_logits=True) for idx in range(pred_probs.shape[0]): pred_masks.append(pred_probs[idx] > predictor.model.mask_threshold) ious.append(utils.get_iou(gt_mask, pred_masks[-1])) tgt_idx = np.argmax(np.array(ious)) iou = ious[tgt_idx] pred_mask = pred_masks[tgt_idx] else: pred_probs, _, _ = predictor.predict(point_coords, point_labels, multimask_output=False, return_logits=True) pred_probs = pred_probs[0] pred_mask = pred_probs > predictor.model.mask_threshold iou = utils.get_iou(gt_mask, pred_mask) if callback is not None: callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) ious_list.append(iou) if iou >= max_iou_thr and click_indx + 1 >= min_clicks: break return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs else: for click_indx in range(max_clicks): clicker.make_next_click(pred_mask) pred_probs = predictor.get_prediction(clicker) pred_mask = pred_probs > pred_thr iou = utils.get_iou(gt_mask, pred_mask) if callback is not None: callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) ious_list.append(iou) if iou >= max_iou_thr and click_indx + 1 >= min_clicks: break return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs def evaluate_sample_oracle(image, gt_mask, predictor, max_iou_thr, pred_thr=0.49, min_clicks=1, max_clicks=20, sample_id=None, sam_type=False, oracle=False, callback=None): clicker = Clicker(gt_mask=gt_mask) ious_lists = [] click_indxs = [] with torch.no_grad(): predictor.set_input_image(image) min_num = 100 for gra in range(1, 11): cur_gra = round(gra * 0.1, 1) ious_list = [] clicker.reset_clicks() pred_mask = np.zeros_like(gt_mask) if sam_type == 'SAM_GraCo': for click_indx in range(max_clicks): clicker.make_next_click(pred_mask) point_coords, point_labels = get_sam_input(clicker) if oracle: ious = [] pred_masks = [] pred_probs, _, _ = predictor.predict(point_coords, point_labels, gra=cur_gra, multimask_output=True, return_logits=True) for idx in range(pred_probs.shape[0]): pred_masks.append(pred_probs[idx] > predictor.model.mask_threshold) ious.append(utils.get_iou(gt_mask, pred_masks[-1])) tgt_idx = np.argmax(np.array(ious)) iou = ious[tgt_idx] pred_mask = pred_masks[tgt_idx] else: pred_probs, _, _ = predictor.predict(point_coords, point_labels, gra=cur_gra, multimask_output=False, return_logits=True) pred_probs = pred_probs[0] pred_mask = pred_probs > predictor.model.mask_threshold iou = utils.get_iou(gt_mask, pred_mask) if callback is not None: callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) ious_list.append(iou) if iou >= max_iou_thr and click_indx + 1 >= min_clicks: min_num = min(min_num, click_indx + 1) break if min_num <= max_clicks and click_indx + 1 > min_num: break else: predictor.prev_prediction = torch.zeros_like(predictor.original_image[:, :1, :, :]) for click_indx in range(max_clicks): clicker.make_next_click(pred_mask) pred_probs = predictor.get_prediction(clicker, gra=cur_gra) pred_mask = pred_probs > pred_thr iou = utils.get_iou(gt_mask, pred_mask) if callback is not None: callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) ious_list.append(iou) if iou >= max_iou_thr and click_indx + 1 >= min_clicks: min_num = min(min_num, click_indx + 1) break if min_num <= max_clicks and click_indx + 1 > min_num: break ious_lists.append(np.array(ious_list, dtype=np.float32)) click_indxs.append(click_indx) click_indxs = np.array(click_indxs) tgt_idxs = np.squeeze(np.argwhere(click_indxs == np.min(click_indxs)), axis=1) selected_ious = [ious_lists[i] for i in tgt_idxs] max_index = np.argmax([ious[0] for ious in selected_ious]) ious = selected_ious[max_index] tgt_idx = tgt_idxs[max_index] return ious, tgt_idx def get_sam_input(clicker, reverse=True): clicks_list = clicker.get_clicks() points_nd = get_points_nd([clicks_list]) point_length = len(points_nd[0]) // 2 point_coords = [] point_labels = [] for i, point in enumerate(points_nd[0]): if point[0] == -1: continue if i < point_length: point_labels.append(1) else: point_labels.append(0) if reverse: point_coords.append([point[1], point[0]]) # for SAM return np.array(point_coords), np.array(point_labels) def get_points_nd(clicks_lists): total_clicks = [] num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] num_max_points = max(num_pos_clicks + num_neg_clicks) num_max_points = max(1, num_max_points) for clicks_list in clicks_lists: pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] total_clicks.append(pos_clicks + neg_clicks) return total_clicks