Spaces:
Running
Running
File size: 3,366 Bytes
6d1366a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import os.path as osp
import cv2
import torch
import numpy as np
from isegm.inference.clicker import Clicker
from isegm.inference import utils
def inference(image, gt_mask, predictor, threshold=0.5, min_clicks=1, max_clicks=20,
              max_iou_thr=None):
    """Simulate an interactive click session on one image and record every round.

    Starting from an all-zero prediction, each round asks the ``Clicker`` for the
    next click, re-runs the predictor, and thresholds its probability map into a
    binary mask that is scored against ``gt_mask``.

    Args:
        image: Image handed to ``predictor.set_input_image``.
        gt_mask: Ground-truth mask used by the clicker and for IoU scoring.
        predictor: Object exposing ``set_input_image(image)`` and
            ``get_prediction(clicker)``.
        threshold: Probabilities strictly greater than this are foreground.
        min_clicks: Minimum number of clicks before early stopping may trigger.
            Only consulted when ``max_iou_thr`` is given.
        max_clicks: Maximum number of click rounds to simulate.
        max_iou_thr: Optional IoU target. When set, the session stops as soon as
            the IoU reaches it and at least ``min_clicks`` clicks have been made.
            ``None`` (the default) always runs all ``max_clicks`` rounds.

    Returns:
        ``(clicks, ious, probs, masks)``: the clicker's ``clicks_list``, a
        float32 array with one IoU per round, and per-round copies of the
        probability maps and thresholded masks.
    """
    clicker = Clicker(gt_mask=gt_mask)
    # Empty initial prediction handed to the clicker for the first click.
    pred_mask = np.zeros_like(gt_mask)
    ious_list = []
    probs_list = []
    masks_list = []

    with torch.no_grad():
        predictor.set_input_image(image)
        for click_indx in range(max_clicks):
            clicker.make_next_click(pred_mask)
            pred_probs = predictor.get_prediction(clicker)
            pred_mask = pred_probs > threshold

            iou = utils.get_iou(gt_mask, pred_mask)
            ious_list.append(iou)
            # Defensive copies so stored results cannot alias later rounds
            # (assumes get_prediction returns a NumPy array — TODO confirm).
            probs_list.append(pred_probs.copy())
            masks_list.append(pred_mask.copy())

            # Opt-in early stop; makes ``min_clicks`` meaningful.
            if (max_iou_thr is not None and iou >= max_iou_thr
                    and click_indx + 1 >= min_clicks):
                break

    return clicker.clicks_list, np.array(ious_list, dtype=np.float32), probs_list, masks_list
def _dim_background(img, fg_mask, factor=0.4):
    """Scale down, in place, every pixel of ``img`` outside boolean ``fg_mask``."""
    bg = ~fg_mask
    img[bg] = (img[bg] * factor).astype(np.uint8)
    return img


def _find_contours(binary_mask):
    """Return the contours of ``binary_mask`` on both OpenCV 3 and OpenCV 4."""
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); the contours are second-to-last in both.
    return cv2.findContours(binary_mask.astype(np.uint8), cv2.RETR_LIST,
                            cv2.CHAIN_APPROX_NONE)[-2]


def _draw_clicks(img, clicks):
    """Draw each click as a coloured dot with a black outline."""
    for clk in clicks:
        color = (80, 208, 146) if clk.is_positive else (0, 0, 192)
        # Clicks store (row, col); OpenCV wants an (x, y) tuple of Python ints.
        center = (int(clk.coords[1]), int(clk.coords[0]))
        img = cv2.circle(img, center, 7, (0, 0, 0), -1)
        img = cv2.circle(img, center, 5, color, -1)
    return img


def visualization(sample, predictor, mask=True, score=False, contour=True, click=True,
                  threshold=0.5, min_clicks=1, max_clicks=20, out_dir=None):
    """Render the ground truth and every round of a simulated click session.

    Runs ``inference`` on the first object of ``sample`` and draws one BGR frame
    for the ground truth followed by one frame per click round.

    Args:
        sample: Sample providing ``image`` (converted RGB->BGR here, so presumably
            RGB), ``gt_mask(obj_id)``, ``objects_ids`` and ``sample_id``.
        predictor: Forwarded to ``inference``.
        mask: Darken pixels outside the GT / predicted mask. Forced off when
            ``score`` is set.
        score: Blend a JET heat-map of the predicted probabilities into each
            click frame.
        contour: Outline the mask; predicted contours with area of 10 or less
            are skipped.
        click: Draw every click made so far (positive and negative colours).
        threshold, min_clicks, max_clicks: Forwarded to ``inference``.
        out_dir: If given, frames are written as JPEGs to
            ``out_dir/<sample_id>/`` as ``gt_mask.jpg`` and ``<k>_<iou>.jpg``.

    Returns:
        ``(outputs, ious)``: the rendered frames (ground truth first) and the
        per-round IoU array from ``inference``.
    """
    contour_color = (106, 211, 253)
    min_area_threshold = 10

    # The probability heat-map replaces background dimming: score overrides mask.
    mask = False if score else mask
    if out_dir is not None:
        out_dir = osp.join(out_dir, str(sample.sample_id))
        os.makedirs(out_dir, exist_ok=True)

    gt_raw = sample.gt_mask(sample.objects_ids[0])
    # Builtin ``bool`` instead of ``np.bool8``, which NumPy 2.0 removed.
    # astype() copies, so this mask is independent of what inference does.
    gt_mask = gt_raw.astype(bool)
    clicks, ious, probs, masks = inference(sample.image, gt_raw, predictor,
                                           threshold=threshold, min_clicks=min_clicks,
                                           max_clicks=max_clicks)

    # Convert to OpenCV's BGR order once; every frame draws on its own copy.
    base = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR)
    outputs = []

    # Ground-truth frame (the score overlay is never applied here).
    show = base.copy()
    if mask:
        show = _dim_background(show, gt_mask)
    if contour:
        show = cv2.drawContours(show, _find_contours(gt_mask), -1, contour_color, 2)
    if out_dir is not None:
        cv2.imwrite(osp.join(out_dir, 'gt_mask.jpg'), show)
    outputs.append(show)

    # One frame per round: prediction after click i, with clicks 0..i drawn.
    for i, (prob, pred_mask, iou) in enumerate(zip(probs, masks, ious)):
        show = base.copy()
        if score:
            heat = cv2.applyColorMap((prob * 255).astype(np.uint8), cv2.COLORMAP_JET)
            show = cv2.addWeighted(show, 0.5, heat, 0.5, 0)
        if mask:
            show = _dim_background(show, pred_mask)
        if contour:
            # Skip tiny fragments so speckle noise does not clutter the frame.
            kept = [c for c in _find_contours(pred_mask)
                    if cv2.contourArea(c) > min_area_threshold]
            if kept:
                show = cv2.drawContours(show, kept, -1, contour_color, 2)
        if click:
            show = _draw_clicks(show, clicks[:i + 1])
        outputs.append(show)
        if out_dir is not None:
            cv2.imwrite(osp.join(out_dir, f'{i+1}_{iou:.2f}.jpg'), show)
    return outputs, ious
|