GraCo / isegm /utils /visualization.py
zhaoyian01's picture
Add application file
6d1366a
raw
history blame
3.37 kB
import os
import os.path as osp
import cv2
import torch
import numpy as np
from isegm.inference.clicker import Clicker
from isegm.inference import utils
def inference(image, gt_mask, predictor, threshold=0.5, min_clicks=1, max_clicks=20):
    """Run ``max_clicks`` rounds of click-driven prediction and record each one.

    Every round asks the ``Clicker`` for its next click given the current
    binary prediction, queries ``predictor`` for a new probability map and
    thresholds it into the next binary mask.

    Args:
        image: Input image handed to ``predictor.set_input_image``.
        gt_mask: Ground-truth mask used to build the ``Clicker`` and to score
            each prediction with ``utils.get_iou``.
        predictor: Object exposing ``set_input_image`` and ``get_prediction``.
        threshold: Probabilities strictly greater than this become foreground.
        min_clicks: Accepted for interface compatibility; not used by this
            function (there is no early stopping here).
        max_clicks: Number of click/prediction rounds to run.

    Returns:
        A 4-tuple ``(clicks, ious, probs, masks)``: the clicker's
        ``clicks_list``, a ``float32`` array with one IoU per round, and two
        lists holding a copy of each round's probability map and binary mask.
    """
    simulator = Clicker(gt_mask=gt_mask)
    # Start from an all-zero prediction shaped like the ground truth.
    current_mask = np.zeros_like(gt_mask)

    iou_history, prob_history, mask_history = [], [], []

    with torch.no_grad():
        predictor.set_input_image(image)

        for _ in range(max_clicks):
            simulator.make_next_click(current_mask)
            probabilities = predictor.get_prediction(simulator)
            current_mask = probabilities > threshold

            iou_history.append(utils.get_iou(gt_mask, current_mask))
            # Store copies so later rounds cannot alias the recorded arrays.
            prob_history.append(probabilities.copy())
            mask_history.append(current_mask.copy())

    iou_array = np.array(iou_history, dtype=np.float32)
    return simulator.clicks_list, iou_array, prob_history, mask_history
def visualization(sample, predictor, mask=True, score=False, contour=True, click=True, threshold=0.5, min_clicks=1, max_clicks=20, out_dir=None):
    """Render the ground truth and every click step of an interactive segmentation.

    Runs ``inference`` on the first object of ``sample`` (``objects_ids[0]``)
    and produces one image for the ground truth followed by one image per
    click. ``sample.image`` is treated as RGB and converted to BGR for OpenCV.

    Args:
        sample: Object exposing ``image``, ``sample_id``, ``objects_ids`` and
            ``gt_mask(object_id)``.
        predictor: Predictor forwarded to ``inference``.
        mask: Dim pixels outside the (ground-truth / predicted) mask to 40 %.
            Forced to ``False`` when ``score`` is true.
        score: Blend a JET colour map of the predicted probabilities into each
            per-click image.
        contour: Draw mask contours. For predictions only contours whose area
            exceeds 10 are drawn; for the ground truth all contours are drawn.
        click: Draw all clicks made up to and including the current step.
        threshold: Probability threshold forwarded to ``inference``.
        min_clicks: Forwarded to ``inference``.
        max_clicks: Forwarded to ``inference``.
        out_dir: If not ``None``, images are written to
            ``<out_dir>/<sample_id>/`` as ``gt_mask.jpg`` and
            ``<step>_<iou>.jpg``.

    Returns:
        ``(outputs, ious)`` where ``outputs`` is the list of rendered BGR
        images (ground truth first) and ``ious`` is the per-click IoU array
        returned by ``inference``.
    """
    contour_color = (106, 211, 253)
    positive_click_color = (80, 208, 146)
    negative_click_color = (0, 0, 192)
    min_contour_area = 10
    dim_factor = 0.4

    # The score heat map and background dimming are mutually exclusive.
    mask = False if score else mask

    if out_dir is not None:
        out_dir = osp.join(out_dir, str(sample.sample_id))
        os.makedirs(out_dir, exist_ok=True)

    clicks, ious, probs, masks = inference(
        sample.image, sample.gt_mask(sample.objects_ids[0]), predictor,
        threshold=threshold, min_clicks=min_clicks, max_clicks=max_clicks,
    )

    outputs = []

    # --- Ground-truth panel -------------------------------------------------
    show = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR)
    # `np.bool8` was removed in NumPy 2.0; the builtin `bool` works on every
    # NumPy version and yields the same boolean array needed for `~gt_mask`.
    gt_mask = sample.gt_mask(sample.objects_ids[0]).astype(bool)
    if mask:
        show[~gt_mask] = (show[~gt_mask] * dim_factor).astype(np.uint8)
    if contour:
        contours, _ = cv2.findContours(gt_mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
        show = cv2.drawContours(show, contours, -1, contour_color, 2)
    if out_dir is not None:
        cv2.imwrite(osp.join(out_dir, 'gt_mask.jpg'), show)
    outputs.append(show)

    # --- One panel per click ------------------------------------------------
    for i in range(len(clicks)):
        step_mask = masks[i]
        show = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR)

        if score:
            score_map = cv2.applyColorMap((probs[i] * 255).astype(np.uint8), cv2.COLORMAP_JET)
            show = cv2.addWeighted(show, 0.5, score_map, 0.5, 0)

        if mask:
            show[~step_mask] = (show[~step_mask] * dim_factor).astype(np.uint8)

        if contour:
            contours, _ = cv2.findContours(step_mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
            # Skip tiny contours so speckle in the prediction is not outlined.
            for cur_contour in contours:
                if cv2.contourArea(cur_contour) > min_contour_area:
                    show = cv2.drawContours(show, [cur_contour], -1, contour_color, 2)

        if click:
            # Draw every click made so far, not only the latest one.
            for j in range(i + 1):
                color = positive_click_color if clicks[j].is_positive else negative_click_color
                # Click coords are indexed [0], [1] and swapped here for
                # OpenCV's point order (presumably (row, col) -> (x, y)).
                # Cast to plain ints in case the coords are NumPy scalars,
                # which some OpenCV builds refuse as point coordinates.
                coords = (int(clicks[j].coords[1]), int(clicks[j].coords[0]))
                # Black disc underneath gives the coloured dot an outline.
                show = cv2.circle(show, coords, 7, (0, 0, 0), -1)
                show = cv2.circle(show, coords, 5, color, -1)

        outputs.append(show)
        if out_dir is not None:
            cv2.imwrite(osp.join(out_dir, f'{i+1}_{ious[i]:.2f}.jpg'), show)

    return outputs, ious