Spaces:
Runtime error
Runtime error
File size: 8,335 Bytes
9afcee2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import torch
import torch.nn
import numpy as np
import pdb
class VNL_Loss(torch.nn.Module):
    """
    Virtual Normal Loss Function.

    Randomly samples triplets of pixels, back-projects them to 3D with a
    pinhole model (focal lengths ``focal_x`` / ``focal_y``, principal point at
    the image centre), and penalises the L1 difference between the unit
    normals of the planes spanned by the ground-truth and the predicted
    triplets. Ground-truth triplets that are near-collinear, whose points are
    too close together, or that contain depth <= ``delta_z`` are discarded.

    :param focal_x: focal length along the image width (presumably in pixels,
        since it divides the pixel offset ``u - u0``)
    :param focal_y: focal length along the image height
    :param input_size: (height, width) of the depth maps given to ``forward``
    :param delta_cos: stored on the instance but not used by
        ``select_points_groups``, which passes its own hard-coded thresholds
        to ``filter_mask``
    :param delta_diff_x: stored only (see ``delta_cos``)
    :param delta_diff_y: stored only (see ``delta_cos``)
    :param delta_diff_z: stored only (see ``delta_cos``)
    :param delta_z: minimum ground-truth depth for a point to be valid
    :param sample_ratio: number of sampled triplets as a fraction of H * W
    """

    def __init__(self, focal_x, focal_y, input_size,
                 delta_cos=0.867, delta_diff_x=0.01,
                 delta_diff_y=0.01, delta_diff_z=0.01,
                 delta_z=0.0001, sample_ratio=0.15):
        super().__init__()
        self.fx = torch.tensor([focal_x], dtype=torch.float32)
        self.fy = torch.tensor([focal_y], dtype=torch.float32)
        self.input_size = input_size
        # Principal point is taken to be the image centre.
        self.u0 = torch.tensor(input_size[1] // 2, dtype=torch.float32)
        self.v0 = torch.tensor(input_size[0] // 2, dtype=torch.float32)
        self.init_image_coor()
        self.delta_cos = delta_cos
        self.delta_diff_x = delta_diff_x
        self.delta_diff_y = delta_diff_y
        self.delta_diff_z = delta_diff_z
        self.delta_z = delta_z
        self.sample_ratio = sample_ratio

    def init_image_coor(self):
        """Precompute per-pixel offsets from the principal point.

        Sets ``self.u_u0`` (column index - u0) and ``self.v_v0``
        (row index - v0), both float32 tensors of shape [1, H, W].
        """
        height, width = self.input_size[0], self.input_size[1]
        # cols[i, j] == j and rows[i, j] == i; meshgrid returns C-contiguous copies.
        cols, rows = np.meshgrid(np.arange(0, width), np.arange(0, height))
        self.u_u0 = torch.from_numpy(cols[np.newaxis].astype(np.float32)) - self.u0
        self.v_v0 = torch.from_numpy(rows[np.newaxis].astype(np.float32)) - self.v0

    def transfer_xyz(self, depth):
        """Back-project a depth map to 3D camera coordinates.

        :param depth: depth tensor of shape [B, 1, H, W] with (H, W) equal to
            ``input_size`` (the channel concat below assumes one channel)
        :return: point map of shape [B, H, W, 3] holding (x, y, z)
        """
        abs_depth = torch.abs(depth)
        x = self.u_u0 * abs_depth / self.fx
        y = self.v_v0 * abs_depth / self.fy
        return torch.cat([x, y, depth], 1).permute(0, 2, 3, 1)  # [b, h, w, c]

    def select_index(self):
        """Sample pixel triplets uniformly with replacement (global NumPy RNG).

        :return: dict with integer index arrays 'p1_x', 'p1_y', 'p2_x',
            'p2_y', 'p3_x', 'p3_y', each of length ``int(H * W * sample_ratio)``
        """
        width = self.input_size[1]
        num = width * self.input_size[0]
        count = int(num * self.sample_ratio)
        p123 = {}
        for name in ('p1', 'p2', 'p3'):
            flat = np.random.choice(num, count, replace=True)
            np.random.shuffle(flat)
            p123[name + '_x'] = flat % width
            # Floor division keeps an integer dtype (np.int was removed in NumPy 1.24).
            p123[name + '_y'] = flat // width
        return p123

    def form_pw_groups(self, p123, pw):
        """
        Form 3D point groups, with 3 points in each group.

        :param p123: point indices as returned by ``select_index``
        :param pw: 3D points of shape [B, H, W, 3]
        :return: tensor of shape [B, N, 3(x,y,z), 3(p1,p2,p3)]
        """
        pw1 = pw[:, p123['p1_y'], p123['p1_x'], :]
        pw2 = pw[:, p123['p2_y'], p123['p2_x'], :]
        pw3 = pw[:, p123['p3_y'], p123['p3_x'], :]
        return torch.stack([pw1, pw2, pw3], dim=3)

    def filter_mask(self, p123, gt_xyz, delta_cos=0.867,
                    delta_diff_x=0.005,
                    delta_diff_y=0.005,
                    delta_diff_z=0.005):
        """Flag the sampled ground-truth triplets that are usable.

        A triplet is kept when all three points have z > ``self.delta_z`` and
        it is neither near-collinear (some pair of edges has
        |cos| > ``delta_cos``) nor "near" (for each of x, y and z, some edge
        component is smaller than the matching ``delta_diff_*``).

        :return: (mask [B, N] bool, grouped points [B, N, 3, 3])
        """
        pw = self.form_pw_groups(p123, gt_xyz)
        # Edge vectors: [b, n, 3(xyz), 3(p12, p13, p23)]
        pw_diff = torch.stack([pw[:, :, :, 1] - pw[:, :, :, 0],
                               pw[:, :, :, 2] - pw[:, :, :, 0],
                               pw[:, :, :, 2] - pw[:, :, :, 1]], dim=3)
        batch, groups, coords, edges = pw_diff.shape

        # Pairwise cosine between the three edges of every triplet.
        edge_cols = pw_diff.view(batch * groups, coords, edges)  # [bn, xyz, edge]
        edge_rows = edge_cols.permute(0, 2, 1)                    # [bn, edge, xyz]
        lengths = edge_rows.norm(2, dim=2)                        # [bn, edge]
        length_prod = torch.bmm(lengths.unsqueeze(2), lengths.unsqueeze(1))
        cos = torch.bmm(edge_rows, edge_cols) / (length_prod + 1e-8)
        cos = cos.view(batch * groups, -1)
        # The 3 diagonal entries (an edge with itself) already account for a
        # count of 3 for non-zero edges; more means some edge pair is nearly parallel.
        mask_cos = ((cos > delta_cos) | (cos < -delta_cos)).sum(1) > 3
        mask_cos = mask_cos.view(batch, groups)

        # Ignore padding / invalid ground-truth depth.
        mask_pad = (pw[:, :, 2, :] > self.delta_z).all(dim=2)

        # Ignore triplets whose points are too close together.
        mask_x = (torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x).any(dim=2)
        mask_y = (torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y).any(dim=2)
        mask_z = (torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z).any(dim=2)

        mask_ignore = (mask_x & mask_y & mask_z) | mask_cos
        return mask_pad & ~mask_ignore, pw

    def select_points_groups(self, gt_depth, pred_depth):
        """Sample triplets and keep the ones that are valid in the ground truth.

        :return: (gt groups, pred groups), each of shape [1, K, 3(xyz), 3(pts)]
            with the batch flattened into K valid triplets
        """
        pw_gt = self.transfer_xyz(gt_depth)
        pw_pred = self.transfer_xyz(pred_depth)
        p123 = self.select_index()
        # NOTE: these thresholds are hard-coded; self.delta_cos / self.delta_diff_*
        # are not consulted here.
        mask, pw_groups_gt = self.filter_mask(p123, pw_gt,
                                              delta_cos=0.867,
                                              delta_diff_x=0.005,
                                              delta_diff_y=0.005,
                                              delta_diff_z=0.005)
        pw_groups_pred = self.form_pw_groups(p123, pw_pred)  # [b, n, 3, 3]
        # Replace zero predicted depth with a small epsilon, touching only the
        # z row. Indexing the 4-D tensor directly with a [b, n, 3] mask would
        # align the point axis with the coordinate axis and overwrite the
        # wrong entries.
        xy_pred = pw_groups_pred[:, :, :2, :]
        z_pred = pw_groups_pred[:, :, 2:, :]
        z_pred = z_pred.masked_fill(z_pred == 0, 0.0001)
        pw_groups_pred = torch.cat([xy_pred, z_pred], dim=2)

        # A [b, n] boolean mask keeps whole 3x3 groups in (batch, group) order.
        gt_valid = pw_groups_gt[mask].unsqueeze(0)
        pred_valid = pw_groups_pred[mask].unsqueeze(0)
        return gt_valid, pred_valid

    @staticmethod
    def _unit_normals(points):
        """Unit plane normals for point groups of shape [1, K, 3(xyz), 3(pts)]."""
        p12 = points[:, :, :, 1] - points[:, :, :, 0]
        p13 = points[:, :, :, 2] - points[:, :, :, 0]
        normal = torch.cross(p12, p13, dim=2)
        norm = torch.norm(normal, 2, dim=2, keepdim=True)
        # Zero-area triangles: divide by 0.01 instead of 0 (the normal stays 0).
        norm = norm + (norm == 0.0).to(torch.float32) * 0.01
        return normal / norm

    def forward(self, gt_depth, pred_depth, select=True):
        """
        Virtual normal loss.

        :param gt_depth: ground-truth depth map, [B, 1, H, W] with (H, W)
            equal to ``input_size``; points with depth <= ``delta_z`` never
            form a valid triplet
        :param pred_depth: predicted depth map, same shape as ``gt_depth``
        :param select: if True, drop the 25% of triplets with the smallest
            error before averaging
        :return: scalar loss tensor; a graph-connected zero when no valid
            triplet was sampled
        """
        device = gt_depth.device
        # Camera constants are plain attributes (not buffers): move them lazily.
        self.fx = self.fx.to(device)
        self.fy = self.fy.to(device)
        self.u0 = self.u0.to(device)
        self.v0 = self.v0.to(device)
        self.u_u0 = self.u_u0.to(device)
        self.v_v0 = self.v_v0.to(device)

        gt_points, dt_points = self.select_points_groups(gt_depth, pred_depth)
        gt_normal = self._unit_normals(gt_points)
        dt_normal = self._unit_normals(dt_points)

        # Per-triplet L1 distance between unit normals: shape [K].
        loss = torch.abs(gt_normal - dt_normal).sum(dim=2).sum(dim=0)
        if loss.numel() == 0:
            # mean() of an empty tensor is NaN; keep the graph but contribute 0.
            return loss.sum()
        if select:
            loss, _ = torch.sort(loss, dim=0, descending=False)
            loss = loss[int(loss.size(0) * 0.25):]
        return torch.mean(loss)
|