Spaces:
Running
Running
import os | |
import os.path as osp | |
import cv2 | |
import torch | |
import numpy as np | |
from isegm.inference.clicker import Clicker | |
from isegm.inference import utils | |
def inference(image, gt_mask, predictor, threshold=0.5, min_clicks=1, max_clicks=20): | |
clicker = Clicker(gt_mask=gt_mask) | |
pred_mask = np.zeros_like(gt_mask) | |
ious_list = [] | |
probs_list = [] | |
masks_list = [] | |
with torch.no_grad(): | |
predictor.set_input_image(image) | |
for click_indx in range(max_clicks): | |
clicker.make_next_click(pred_mask) | |
pred_probs = predictor.get_prediction(clicker) | |
pred_mask = pred_probs > threshold | |
iou = utils.get_iou(gt_mask, pred_mask) | |
ious_list.append(iou) | |
probs_list.append(pred_probs.copy()) | |
masks_list.append(pred_mask.copy()) | |
return clicker.clicks_list, np.array(ious_list, dtype=np.float32), probs_list, masks_list | |
def visualization(sample, predictor, mask=True, score=False, contour=True, click=True, threshold=0.5, min_clicks=1, max_clicks=20, out_dir=None): | |
mask = False if score else mask | |
if out_dir is not None: | |
out_dir = osp.join(out_dir, str(sample.sample_id)) | |
os.makedirs(out_dir, exist_ok=True) | |
clicks, ious, probs, masks = inference(sample.image, sample.gt_mask(sample.objects_ids[0]), | |
predictor, threshold=threshold, | |
min_clicks=min_clicks, max_clicks=max_clicks) | |
outputs = [] | |
show = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR) | |
gt_mask = sample.gt_mask(sample.objects_ids[0]).astype(np.bool8) | |
if mask: | |
show[~gt_mask] = (show[~gt_mask] * 0.4).astype(np.uint8) | |
if contour: | |
contours, _ = cv2.findContours(gt_mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) | |
show = cv2.drawContours(show, contours, -1, (106, 211, 253), 2) | |
if out_dir is not None: | |
cv2.imwrite(osp.join(out_dir, f'gt_mask.jpg'), show) | |
outputs.append(show) | |
for i in range(len(clicks)): | |
show = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR) | |
if score: | |
score_map = cv2.applyColorMap((probs[i] * 255).astype(np.uint8), cv2.COLORMAP_JET) | |
show = cv2.addWeighted(show, 0.5, score_map, 0.5, 0) | |
if mask: | |
show[~masks[i]] = (show[~masks[i]] * 0.4).astype(np.uint8) | |
if contour: | |
contours, _ = cv2.findContours(masks[i].astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) | |
min_area_threshold = 10 | |
for cur_contour in contours: | |
area = cv2.contourArea(cur_contour) | |
if area > min_area_threshold: | |
show = cv2.drawContours(show, [cur_contour], -1, (106, 211, 253), 2) | |
# show = cv2.drawContours(show, contours, -1, (106, 211, 253), 4) | |
if click: | |
for j in range(i + 1): | |
color = (80, 208, 146) if clicks[j].is_positive else (0, 0, 192) | |
coords = (clicks[j].coords[1], clicks[j].coords[0]) | |
show = cv2.circle(show, coords, 7, (0, 0, 0), -1) | |
show = cv2.circle(show, coords, 5, color, -1) | |
outputs.append(show) | |
if out_dir is not None: | |
cv2.imwrite(osp.join(out_dir, f'{i+1}_{ious[i]:.2f}.jpg'), show) | |
return outputs, ious | |