import torch
import torch.nn
import numpy as np
class VNL_Loss(torch.nn.Module):
    """
    Virtual Normal Loss: compares the normals of virtual planes spanned by randomly
    sampled 3D point triplets, back-projected from the ground-truth and predicted
    depth maps (Yin et al., "Enforcing Geometric Constraints of Virtual Normal for
    Depth Prediction", ICCV 2019).
    """
    def __init__(self, focal_x, focal_y, input_size,
                 delta_cos=0.867, delta_diff_x=0.01,
                 delta_diff_y=0.01, delta_diff_z=0.01,
                 delta_z=0.0001, sample_ratio=0.15):
        super(VNL_Loss, self).__init__()
        # Camera intrinsics: focal lengths in pixels and the principal point,
        # taken as the centre of the (height, width) input resolution.
        self.fx = torch.tensor([focal_x], dtype=torch.float32)
        self.fy = torch.tensor([focal_y], dtype=torch.float32)
        self.input_size = input_size
        self.u0 = torch.tensor(input_size[1] // 2, dtype=torch.float32)
        self.v0 = torch.tensor(input_size[0] // 2, dtype=torch.float32)
        self.init_image_coor()
        self.delta_cos = delta_cos
        self.delta_diff_x = delta_diff_x
        self.delta_diff_y = delta_diff_y
        self.delta_diff_z = delta_diff_z
        self.delta_z = delta_z
        self.sample_ratio = sample_ratio
    def init_image_coor(self):
        # Pixel coordinate grids shifted by the principal point: u - u0 and v - v0.
        x_row = np.arange(0, self.input_size[1])
        x = np.tile(x_row, (self.input_size[0], 1))
        x = x[np.newaxis, :, :]
        x = x.astype(np.float32)
        x = torch.from_numpy(x.copy())
        self.u_u0 = x - self.u0
        y_col = np.arange(0, self.input_size[0])
        y = np.tile(y_col, (self.input_size[1], 1)).T
        y = y[np.newaxis, :, :]
        y = y.astype(np.float32)
        y = torch.from_numpy(y.copy())
        self.v_v0 = y - self.v0
    def transfer_xyz(self, depth):
        # Back-project every pixel to a 3D point using the pinhole camera model.
        x = self.u_u0 * torch.abs(depth) / self.fx
        y = self.v_v0 * torch.abs(depth) / self.fy
        z = depth
        pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1)  # [b, h, w, c]
        return pw
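    # Back-projection used by transfer_xyz, written out for reference (standard
    # pinhole model; u0/v0 are the image-centre principal point assumed in __init__):
    #   x = (u - u0) * z / fx
    #   y = (v - v0) * z / fy
    #   z = depth(u, v)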
    def select_index(self):
        valid_width = self.input_size[1]
        valid_height = self.input_size[0]
        num = valid_width * valid_height
        # Draw three independent sets of flat pixel indices (with replacement).
        p1 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
        np.random.shuffle(p1)
        p2 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
        np.random.shuffle(p2)
        p3 = np.random.choice(num, int(num * self.sample_ratio), replace=True)
        np.random.shuffle(p3)
        # Decompose each flat index into (x, y) pixel coordinates.
        p1_x = p1 % self.input_size[1]
        p1_y = p1 // self.input_size[1]
        p2_x = p2 % self.input_size[1]
        p2_y = p2 // self.input_size[1]
        p3_x = p3 % self.input_size[1]
        p3_y = p3 // self.input_size[1]
        p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y}
        return p123
    def form_pw_groups(self, p123, pw):
        """
        Form groups of 3D points, with 3 points in each group.
        :param p123: point indices
        :param pw: 3D points
        :return:
        """
        p1_x = p123['p1_x']
        p1_y = p123['p1_y']
        p2_x = p123['p2_x']
        p2_y = p123['p2_y']
        p3_x = p123['p3_x']
        p3_y = p123['p3_y']
        pw1 = pw[:, p1_y, p1_x, :]
        pw2 = pw[:, p2_y, p2_x, :]
        pw3 = pw[:, p3_y, p3_x, :]
        # [B, N, 3(x,y,z), 3(p1,p2,p3)]
        pw_groups = torch.cat([pw1[:, :, :, np.newaxis], pw2[:, :, :, np.newaxis], pw3[:, :, :, np.newaxis]], 3)
        return pw_groups
    def filter_mask(self, p123, gt_xyz, delta_cos=0.867,
                    delta_diff_x=0.005,
                    delta_diff_y=0.005,
                    delta_diff_z=0.005):
        pw = self.form_pw_groups(p123, gt_xyz)
        pw12 = pw[:, :, :, 1] - pw[:, :, :, 0]
        pw13 = pw[:, :, :, 2] - pw[:, :, :, 0]
        pw23 = pw[:, :, :, 2] - pw[:, :, :, 1]
        # Ignore (nearly) collinear triplets: compare the pairwise cosines of the
        # three difference vectors against delta_cos.
        pw_diff = torch.cat([pw12[:, :, :, np.newaxis], pw13[:, :, :, np.newaxis], pw23[:, :, :, np.newaxis]],
                            3)  # [b, n, 3, 3]
        m_batchsize, groups, coords, index = pw_diff.shape
        proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1)  # [bn, 3(pairs), 3(xyz)]
        proj_key = pw_diff.view(m_batchsize * groups, -1, index)  # [bn, 3(xyz), 3(pairs)]
        q_norm = proj_query.norm(2, dim=2)
        nm = torch.bmm(q_norm.view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index))
        energy = torch.bmm(proj_query, proj_key)  # Gram matrix [bn, 3(pairs), 3(pairs)]
        norm_energy = energy / (nm + 1e-8)
        norm_energy = norm_energy.view(m_batchsize * groups, -1)
        mask_cos = torch.sum((norm_energy > delta_cos) | (norm_energy < -delta_cos), 1) > 3  # ignore
        mask_cos = mask_cos.view(m_batchsize, groups)
        # Ignore padding and invalid depth.
        mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3
        # Ignore triplets whose points lie too close to one another.
        mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0
        mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0
        mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0
        mask_ignore = (mask_x & mask_y & mask_z) | mask_cos
        mask_near = ~mask_ignore
        mask = mask_pad & mask_near
        return mask, pw
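    # A sampled triplet is kept by filter_mask only if, in the ground-truth point
    # cloud, (a) all three depths exceed delta_z (valid, non-padded pixels),
    # (b) no pair of difference vectors is nearly parallel (|cos| <= delta_cos,
    # so the points span a well-defined plane), and (c) the points are not too
    # close to one another along x, y and z (delta_diff_* thresholds).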
    def select_points_groups(self, gt_depth, pred_depth):
        pw_gt = self.transfer_xyz(gt_depth)
        pw_pred = self.transfer_xyz(pred_depth)
        B, C, H, W = gt_depth.shape
        p123 = self.select_index()
        # mask: [b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)]
        mask, pw_groups_gt = self.filter_mask(p123, pw_gt,
                                              delta_cos=0.867,
                                              delta_diff_x=0.005,
                                              delta_diff_y=0.005,
                                              delta_diff_z=0.005)
        # [b, n, 3, 3]
        pw_groups_pred = self.form_pw_groups(p123, pw_pred)
        # Avoid exactly-zero entries so the later normalisation cannot divide by zero.
        pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] = 0.0001
        mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2)
        pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3)
        pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3)
        return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore
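    # Note on select_points_groups: the per-triplet boolean mask [B, N] is tiled to
    # [B, N, 3, 3] so it can index whole point groups, and the surviving ground-truth
    # and predicted triplets from all batch items are concatenated into single
    # tensors of shape [1, M, 3(x,y,z), 3(p1,p2,p3)].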
    def forward(self, gt_depth, pred_depth, select=True):
        """
        Virtual normal loss.
        :param gt_depth: ground-truth depth map, [B, 1, H, W]
        :param pred_depth: predicted depth map, [B, 1, H, W]
        :param select: if True, drop the 25% smallest per-group losses and average
                       only over the remaining (harder) groups
        :return: scalar loss
        """
        # Move the cached intrinsics and coordinate grids to the input's device.
        device = gt_depth.device
        self.fx = self.fx.to(device)
        self.fy = self.fy.to(device)
        self.u0 = self.u0.to(device)
        self.v0 = self.v0.to(device)
        self.u_u0 = self.u_u0.to(device)
        self.v_v0 = self.v_v0.to(device)
        gt_points, dt_points = self.select_points_groups(gt_depth, pred_depth)
        gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0]
        gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0]
        dt_p12 = dt_points[:, :, :, 1] - dt_points[:, :, :, 0]
        dt_p13 = dt_points[:, :, :, 2] - dt_points[:, :, :, 0]
        # Normals of the virtual planes spanned by each point triplet.
        gt_normal = torch.cross(gt_p12, gt_p13, dim=2)
        dt_normal = torch.cross(dt_p12, dt_p13, dim=2)
        dt_norm = torch.norm(dt_normal, 2, dim=2, keepdim=True)
        gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True)
        # Guard against zero-length normals before normalising.
        dt_mask = dt_norm == 0.0
        gt_mask = gt_norm == 0.0
        dt_mask = dt_mask.to(torch.float32)
        gt_mask = gt_mask.to(torch.float32)
        dt_mask *= 0.01
        gt_mask *= 0.01
        gt_norm = gt_norm + gt_mask
        dt_norm = dt_norm + dt_mask
        gt_normal = gt_normal / gt_norm
        dt_normal = dt_normal / dt_norm
        # L1 difference between the unit normals, summed per group.
        loss = torch.abs(gt_normal - dt_normal)
        loss = torch.sum(torch.sum(loss, dim=2), dim=0)
        if select:
            loss, indices = torch.sort(loss, dim=0, descending=False)
            loss = loss[int(loss.size(0) * 0.25):]
        loss = torch.mean(loss)
        return loss
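# Minimal usage sketch (illustrative only): the focal lengths, input size and depth
# tensors below are made-up placeholder values, not taken from the original code.
if __name__ == "__main__":
    vnl_loss = VNL_Loss(focal_x=519.0, focal_y=519.0, input_size=(384, 512))
    gt_depth = torch.rand(2, 1, 384, 512) * 10.0    # fake ground-truth depth, [B, 1, H, W]
    pred_depth = torch.rand(2, 1, 384, 512) * 10.0  # fake predicted depth, [B, 1, H, W]
    print(vnl_loss(gt_depth, pred_depth))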