Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: [email protected] | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5):
    """
    Show a 2D mask prediction and, optionally, the points where PointRend is applied.

    Args:
        mask (Tensor): mask prediction of shape HxW. It may live on any device;
            it is detached and copied to the CPU before plotting.
        title (str): title prefix for the plot; the resolution is appended.
        point_coords ((Tensor, Tensor)): x and y point coordinates, in pixel
            units of `mask` (the axes span [-0.5, W - 0.5] x [-0.5, H - 0.5]).
        figsize (int): side length of the square matplotlib figure (inches).
        point_marker_size (int): marker size for points.
    """
    H, W = mask.shape
    # matplotlib converts through numpy, which only works for host tensors.
    mask = mask.detach().cpu()
    plt.figure(figsize=(figsize, figsize))
    if title:
        title += ", "
    plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
    plt.ylabel(H, fontsize=30)
    plt.xlabel(W, fontsize=30)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(mask, interpolation="nearest", cmap=plt.get_cmap('gray'))
    if point_coords is not None:
        xs = torch.as_tensor(point_coords[0]).detach().cpu()
        ys = torch.as_tensor(point_coords[1]).detach().cpu()
        plt.scatter(x=xs, y=ys, color="red", s=point_marker_size, clip_on=True)
    # Pin the view to pixel centres so scattered points line up with the mask
    # (y axis flipped to match image row order).
    plt.xlim(-0.5, W - 0.5)
    plt.ylim(H - 0.5, -0.5)
    plt.show()
def plot_mask3D(
    mask=None, title="", point_coords=None, figsize=1500, point_marker_size=8, interactive=True
):
    """
    Render a 3D mask prediction as a surface and, optionally, the points where
    PointRend is applied.

    Args:
        mask (Tensor): mask prediction of shape DxHxW; its 0.5 iso-surface is drawn.
        title (str): title of the render window.
        point_coords ((Tensor, Tensor, Tensor)): x, y and z point coordinates.
        figsize (int): side length of the square render window, passed to
            `vtkplotter.Plotter` as `size`.
        point_marker_size (int): marker radius for points.
        interactive (bool): forwarded to `Plotter.show`.
    """
    import inspect

    import trimesh
    import vtkplotter
    from skimage import measure

    vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
    vis_list = []
    if mask is not None:
        mask = mask.detach().to("cpu").numpy()
        # DxHxW -> WxHxD so mesh vertices come out ordered (x, y, z), the same
        # order used by point_coords.
        mask = mask.transpose(2, 1, 0)
        # Marching cubes to find the 0.5 iso-surface. `marching_cubes_lewiner`
        # was removed in scikit-image 0.19; since 0.17 the same algorithm is
        # `marching_cubes(..., method='lewiner')`. Fall back for old versions.
        marching_cubes = getattr(measure, "marching_cubes", None)
        if marching_cubes is not None and "method" in inspect.signature(marching_cubes).parameters:
            verts, faces, _, _ = marching_cubes(
                mask, 0.5, gradient_direction='ascent', method='lewiner'
            )
        else:
            verts, faces, _, _ = measure.marching_cubes_lewiner(
                mask, 0.5, gradient_direction='ascent'
            )
        mesh = trimesh.Trimesh(verts, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]
        vis_list.append(mesh)
    if point_coords is not None:
        # Detach so coordinates that require grad can still be converted to numpy.
        point_coords = torch.stack(point_coords, 1).detach().to("cpu").numpy()
        pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
        vis_list.append(pc)
    vp.show(*vis_list, bg="white", axes=1, interactive=interactive, azimuth=30, elevation=30)
def create_grid3D(min, max, steps):
    """
    Build a regular 3D grid of integer (x, y, z) coordinates.

    Args:
        min (number or (x, y, z) sequence): lower bound per axis. A scalar is
            used for all three axes.
        max (number or (x, y, z) sequence): upper bound (inclusive) per axis.
        steps (int or (x, y, z) sequence of int): number of samples per axis.

    Returns:
        Tensor: LongTensor of shape [steps_z * steps_y * steps_x, 3] holding
        (x, y, z) triples, with x varying fastest, then y, then z. Sample
        positions come from `torch.linspace` and are truncated by `.long()`.
    """
    import numbers

    # `min`/`max` shadow builtins but are kept for keyword-argument callers.
    if isinstance(min, numbers.Real):
        min = (min, min, min)  # (x, y, z)
    if isinstance(max, numbers.Real):
        max = (max, max, max)  # (x, y, z)
    if isinstance(steps, numbers.Integral):
        steps = (steps, steps, steps)  # (x, y, z)
    xs = torch.linspace(min[0], max[0], steps[0]).long()
    ys = torch.linspace(min[1], max[1], steps[1]).long()
    zs = torch.linspace(min[2], max[2], steps[2]).long()
    grid_d, grid_h, grid_w = torch.meshgrid([zs, ys, xs], indexing='ij')
    coords = torch.stack([grid_w, grid_h, grid_d])  # [3, steps_z, steps_y, steps_x]
    return coords.view(3, -1).t()  # [N, 3]
def create_grid2D(min, max, steps):
    """
    Build a regular 2D grid of integer (x, y) coordinates.

    Args:
        min (number or (x, y) sequence): lower bound per axis. A scalar is
            used for both axes.
        max (number or (x, y) sequence): upper bound (inclusive) per axis.
        steps (int or (x, y) sequence of int): number of samples per axis.

    Returns:
        Tensor: LongTensor of shape [steps_y * steps_x, 2] holding (x, y)
        pairs, with x varying fastest. Sample positions come from
        `torch.linspace` and are truncated by `.long()`.
    """
    import numbers

    # `min`/`max` shadow builtins but are kept for keyword-argument callers.
    if isinstance(min, numbers.Real):
        min = (min, min)  # (x, y)
    if isinstance(max, numbers.Real):
        max = (max, max)  # (x, y)
    if isinstance(steps, numbers.Integral):
        steps = (steps, steps)  # (x, y)
    xs = torch.linspace(min[0], max[0], steps[0]).long()
    ys = torch.linspace(min[1], max[1], steps[1]).long()
    grid_h, grid_w = torch.meshgrid([ys, xs], indexing='ij')
    coords = torch.stack([grid_w, grid_h])  # [2, steps_y, steps_x]
    return coords.view(2, -1).t()  # [N, 2]
class SmoothConv2D(nn.Module):
    """
    Fixed (non-trainable) box-filter 2D convolution with "same" padding.

    Every output channel equals the sum over input channels of the
    kernel_size x kernel_size neighbourhood mean. With in_channels > 1 the
    response therefore scales with in_channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2
        # F.conv2d expects weights laid out as (out_channels, in_channels, kH, kW).
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size), dtype=torch.float32
        ) / (kernel_size**2)
        # A buffer rather than a Parameter: it follows .to()/.cuda() and is
        # stored in state_dict, but is never optimized.
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv2d(input, self.weight, padding=self.padding)
class SmoothConv3D(nn.Module):
    """
    Fixed (non-trainable) box-filter 3D convolution with "same" padding.

    Every output channel equals the sum over input channels of the
    kernel_size^3 neighbourhood mean. With in_channels > 1 the response
    therefore scales with in_channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2
        # F.conv3d expects weights laid out as (out_channels, in_channels, kD, kH, kW).
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
            dtype=torch.float32,
        ) / (kernel_size**3)
        # A buffer rather than a Parameter: it follows .to()/.cuda() and is
        # stored in state_dict, but is never optimized.
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv3d(input, self.weight, padding=self.padding)
def build_smooth_conv3D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
    """
    Create a `torch.nn.Conv3d` initialised as a box (mean) filter.

    Every weight is set to 1 / kernel_size**3 and the bias to zero. The
    weights and bias remain ordinary trainable `nn.Parameter`s; freeze them if
    the filter must stay fixed. Note that `padding` defaults to 1 regardless of
    `kernel_size`.

    Returns:
        torch.nn.Conv3d: the initialised layer.
    """
    smooth_conv = torch.nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding
    )
    # Fill the existing parameters in place. They already have the
    # (out_channels, in_channels, k, k, k) layout Conv3d requires, so the
    # filter works for any in/out channel combination.
    with torch.no_grad():
        smooth_conv.weight.fill_(1.0 / kernel_size**3)
        smooth_conv.bias.zero_()
    return smooth_conv
def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
    """
    Create a `torch.nn.Conv2d` initialised as a box (mean) filter.

    Every weight is set to 1 / kernel_size**2 and the bias to zero. The
    weights and bias remain ordinary trainable `nn.Parameter`s; freeze them if
    the filter must stay fixed. Note that `padding` defaults to 1 regardless of
    `kernel_size`.

    Returns:
        torch.nn.Conv2d: the initialised layer.
    """
    smooth_conv = torch.nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding
    )
    # Fill the existing parameters in place. They already have the
    # (out_channels, in_channels, k, k) layout Conv2d requires, so the
    # filter works for any in/out channel combination.
    with torch.no_grad():
        smooth_conv.weight.fill_(1.0 / kernel_size**2)
        smooth_conv.bias.zero_()
    return smooth_conv
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, **kwargs):
    """
    Find the `num_points` most uncertain voxels of a dense uncertainty grid.

    Args:
        uncertainty_map (Tensor): tensor of shape (R, 1, D, H, W) holding
            uncertainty values on a regular D x H x W grid.
        num_points (int): number of points P to select; clipped to D * H * W.
        **kwargs: ignored. Presumably accepted so callers can switch to
            `get_uncertain_point_coords_on_grid3D_faster` (which takes
            `clip_min`) without changing the call. Confirm against callers.

    Returns:
        point_indices (Tensor): LongTensor of shape (R, P) with flat indices
            into [0, D * H * W), ordered by decreasing uncertainty.
        point_coords (Tensor): float tensor of shape (R, P, 3) holding integer
            voxel coordinates (x, y, z) = (w, h, d). They are NOT normalized.
    """
    R, _, D, H, W = uncertainty_map.shape
    num_points = min(D * H * W, num_points)
    # reshape (not view) so non-contiguous maps are accepted as well.
    _, point_indices = torch.topk(uncertainty_map.reshape(R, D * H * W), k=num_points, dim=1)
    point_coords = torch.stack(
        [
            point_indices % W,  # x
            point_indices % (H * W) // W,  # y
            point_indices // (H * W),  # z
        ],
        dim=2,
    ).to(torch.float)
    return point_indices, point_coords
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, clip_min):
    """
    Find up to `num_points` most uncertain voxels whose uncertainty is >= `clip_min`.

    Only voxels passing the `clip_min` threshold are ranked. This is presumably
    the source of the speed-up over a top-k on the full grid when few voxels
    pass. Only a single map (R == 1) is supported.

    Args:
        uncertainty_map (Tensor): tensor of shape (1, 1, D, H, W) holding
            uncertainty values on a regular D x H x W grid.
        num_points (int): maximum number of points P to select.
        clip_min (float): voxels with uncertainty below this are never
            selected, so fewer than `num_points` (possibly zero) may be returned.

    Returns:
        point_indices (Tensor): LongTensor of shape (1, P) with flat indices
            into [0, D * H * W), ordered by decreasing uncertainty.
        point_coords (Tensor): float tensor of shape (1, P, 3) holding integer
            voxel coordinates (x, y, z) = (w, h, d). They are NOT normalized.

    Raises:
        NotImplementedError: if the batch dimension R is not 1.
    """
    R, _, D, H, W = uncertainty_map.shape
    if R != 1:
        raise NotImplementedError("batchsize > 1 is not implemented!")
    flat = uncertainty_map.reshape(D * H * W)
    candidates = torch.nonzero(flat >= clip_min, as_tuple=False).squeeze(1)
    num_points = min(num_points, candidates.numel())
    _, order = torch.topk(flat[candidates], k=num_points, dim=0)
    # Map positions within `candidates` back to flat grid indices.
    point_indices = candidates[order].unsqueeze(0)
    point_coords = torch.stack(
        [
            point_indices % W,  # x
            point_indices % (H * W) // W,  # y
            point_indices // (H * W),  # z
        ],
        dim=2,
    ).to(torch.float)
    return point_indices, point_coords
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, **kwargs):
    """
    Find the `num_points` most uncertain pixels of a dense uncertainty grid.

    Args:
        uncertainty_map (Tensor): tensor of shape (R, 1, H, W) holding
            uncertainty values on a regular H x W grid.
        num_points (int): number of points P to select; clipped to H * W.
        **kwargs: ignored. Presumably accepted so callers can switch to
            `get_uncertain_point_coords_on_grid2D_faster` (which takes
            `clip_min`) without changing the call. Confirm against callers.

    Returns:
        point_indices (Tensor): LongTensor of shape (R, P) with flat indices
            into [0, H * W), ordered by decreasing uncertainty.
        point_coords (Tensor): LongTensor of shape (R, P, 2) holding integer
            pixel coordinates (x, y) = (column, row). They are NOT normalized.
    """
    R, _, H, W = uncertainty_map.shape
    num_points = min(H * W, num_points)
    # reshape (not view) so non-contiguous maps are accepted as well.
    _, point_indices = torch.topk(uncertainty_map.reshape(R, H * W), k=num_points, dim=1)
    point_coords = torch.stack([point_indices % W, point_indices // W], dim=2)
    return point_indices, point_coords
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, clip_min):
    """
    Find up to `num_points` most uncertain pixels whose uncertainty is >= `clip_min`.

    Only pixels passing the `clip_min` threshold are ranked. This is presumably
    the source of the speed-up over a top-k on the full grid when few pixels
    pass. Only a single map (R == 1) is supported.

    Args:
        uncertainty_map (Tensor): tensor of shape (1, 1, H, W) holding
            uncertainty values on a regular H x W grid.
        num_points (int): maximum number of points P to select.
        clip_min (float): pixels with uncertainty below this are never
            selected, so fewer than `num_points` (possibly zero) may be returned.

    Returns:
        point_indices (Tensor): LongTensor of shape (1, P) with flat indices
            into [0, H * W), ordered by decreasing uncertainty.
        point_coords (Tensor): LongTensor of shape (1, P, 2) holding integer
            pixel coordinates (x, y) = (column, row). They are NOT normalized.

    Raises:
        NotImplementedError: if the batch dimension R is not 1.
    """
    R, _, H, W = uncertainty_map.shape
    if R != 1:
        raise NotImplementedError("batchsize > 1 is not implemented!")
    flat = uncertainty_map.reshape(H * W)
    candidates = torch.nonzero(flat >= clip_min, as_tuple=False).squeeze(1)
    num_points = min(num_points, candidates.numel())
    _, order = torch.topk(flat[candidates], k=num_points, dim=0)
    # Map positions within `candidates` back to flat grid indices.
    point_indices = candidates[order].unsqueeze(0)
    point_coords = torch.stack([point_indices % W, point_indices // W], dim=2)
    return point_indices, point_coords
def calculate_uncertainty(logits, classes=None, balance_value=0.5):
    """
    Estimate per-location uncertainty of a mask prediction.

    The score is the negated L1 distance between the prediction for the
    foreground class selected by `classes` and `balance_value`. Locations whose
    prediction is closest to `balance_value` therefore get the highest score
    (at most 0). With the default of 0.5 the inputs are presumably
    probabilities (e.g. sigmoid outputs) rather than raw logits; confirm
    against callers.

    Args:
        logits (Tensor): tensor of shape (R, C, ...) or (R, 1, ...) for
            class-specific or class-agnostic predictions, where R is the total
            number of predicted masks in all images and C is the number of
            foreground classes.
        classes (list or Tensor): length-R sequence with the predicted or
            ground-truth class of each mask. Required when C > 1; ignored when
            the class dimension has size 1.
        balance_value (float): the prediction value treated as maximally uncertain.

    Returns:
        scores (Tensor): tensor of shape (R, 1, ...) with uncertainty scores,
            highest at the most uncertain locations.

    Raises:
        ValueError: if C > 1 and `classes` is None.
    """
    if logits.shape[1] == 1:
        gt_class_logits = logits
    else:
        if classes is None:
            # Indexing with None would silently insert a new axis instead of
            # selecting a class, producing a wrong-shaped result.
            raise ValueError("`classes` is required when logits has more than one class channel")
        gt_class_logits = logits[
            torch.arange(logits.shape[0], device=logits.device), classes
        ].unsqueeze(1)
    return -torch.abs(gt_class_logits - balance_value)