# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import torch
from ..render import mesh
from ..render import render
from ..networks import MLPWithPositionalEncoding, MLPWithPositionalEncoding_Style
###############################################################################
# Marching tetrahedrons implementation (differentiable), adapted from
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
#
# Note this only supports batch size = 1.
###############################################################################
class DMTet:
def __init__(self):
self.triangle_table = torch.tensor([
[-1, -1, -1, -1, -1, -1],
[ 1, 0, 2, -1, -1, -1],
[ 4, 0, 3, -1, -1, -1],
[ 1, 4, 2, 1, 3, 4],
[ 3, 1, 5, -1, -1, -1],
[ 2, 3, 0, 2, 5, 3],
[ 1, 4, 0, 1, 5, 4],
[ 4, 2, 5, -1, -1, -1],
[ 4, 5, 2, -1, -1, -1],
[ 4, 1, 0, 4, 5, 1],
[ 3, 2, 0, 3, 5, 2],
[ 1, 3, 5, -1, -1, -1],
[ 4, 1, 2, 4, 3, 1],
[ 3, 0, 4, -1, -1, -1],
[ 2, 0, 1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]
], dtype=torch.long, device='cuda')
self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
###############################################################################
# Utility functions
###############################################################################
    def sort_edges(self, edges_ex2):
        # Sort each edge so the smaller vertex index comes first, giving a
        # canonical key for torch.unique.
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)
def map_uv(self, faces, face_gidx, max_idx):
N = int(np.ceil(np.sqrt((max_idx+1)//2)))
tex_y, tex_x = torch.meshgrid(
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
indexing='ij'
)
pad = 0.9 / N
uvs = torch.stack([
tex_x , tex_y,
tex_x + pad, tex_y,
tex_x + pad, tex_y + pad,
tex_x , tex_y + pad
], dim=-1).view(-1, 2)
def _idx(tet_idx, N):
x = tet_idx % N
y = torch.div(tet_idx, N, rounding_mode='trunc')
return y * N + x
tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
tri_idx = face_gidx % 2
        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim=-1).view(-1, 3)
return uvs, uv_idx
###############################################################################
# Marching tets implementation
###############################################################################
def __call__(self, pos_nx3, sdf_n, tet_fx4):
with torch.no_grad():
occ_n = sdf_n > 0
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
occ_sum = torch.sum(occ_fx4, -1)
valid_tets = (occ_sum>0) & (occ_sum<4)
occ_sum = occ_sum[valid_tets]
# find all vertices
all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
all_edges = self.sort_edges(all_edges)
unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
unique_edges = unique_edges.long()
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
idx_map = mapping[idx_map] # map edges to verts
interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # Linear interpolation to the zero crossing: negate one endpoint's SDF and
        # normalize, so each endpoint's weight is the opposite endpoint's |SDF| fraction.
        edges_to_interp_sdf[:, -1] *= -1
        denominator = edges_to_interp_sdf.sum(1, keepdim=True)
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
idx_map = idx_map.reshape(-1,6)
        # Encode each tet's occupancy pattern as a 4-bit index into the lookup tables.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
num_triangles = self.num_triangles_table[tetindex]
# Generate triangle indices
faces = torch.cat((
torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
), dim=0)
# Get global face index (static, does not depend on topology)
num_tets = tet_fx4.shape[0]
tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
face_gidx = torch.cat((
tet_gidx[num_triangles == 1]*2,
torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
), dim=0)
uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
return verts, faces, uvs, uv_idx
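# A minimal usage sketch (not part of the original module): DMTet maps grid
# vertex positions, per-vertex SDF values and tet indices to a triangle mesh
# with per-face UVs. The single-tet tensors below are illustrative assumptions;
# a CUDA device is required because the lookup tables live on 'cuda'.
def _example_marching_tets():
    dmtet = DMTet()
    # One tetrahedron; the surface crosses the three edges incident to vertex 3.
    pos_nx3 = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], device='cuda')
    sdf_n = torch.tensor([-1., -1., -1., 1.], device='cuda')
    tet_fx4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device='cuda')
    verts, faces, uvs, uv_idx = dmtet(pos_nx3, sdf_n, tet_fx4)
    return verts, faces, uvs, uv_idx  # verts: (3, 3) edge crossings, faces: (1, 3)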
###############################################################################
# Regularizer
###############################################################################
def sdf_bce_reg_loss(sdf, all_edges):
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
sdf_f1x6x2 = sdf_f1x6x2[mask]
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
    if torch.isnan(sdf_diff).any():
        import ipdb; ipdb.set_trace()  # drop into the debugger if the regularizer goes NaN
return sdf_diff
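# A toy sketch (assumed values, not from the original file) of the edge-based BCE
# regularizer: only sign-crossing edges contribute, and each endpoint's logit is
# penalized against the sign of the opposite endpoint. Both edges below cross zero.
def _example_sdf_bce_reg():
    sdf = torch.tensor([-0.5, 0.2, 0.7, -0.1], device='cuda')
    all_edges = torch.tensor([[0, 1], [2, 3]], dtype=torch.long, device='cuda')
    return sdf_bce_reg_loss(sdf, all_edges)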
###############################################################################
# Geometry interface
###############################################################################
class DMTetGeometry(torch.nn.Module):
    def __init__(self, grid_res, scale, sdf_mode, num_layers=None, hidden_size=None,
                 embedder_freq=None, embed_concat_pts=True, init_sdf=None, jitter_grid=0.,
                 perturb_sdf_iter=10000, sym_prior_shape=False, dim_of_classes=0,
                 condition_choice='concat'):
super(DMTetGeometry, self).__init__()
self.sdf_mode = sdf_mode
self.grid_res = grid_res
self.marching_tets = DMTet()
self.grid_scale = scale
self.init_sdf = init_sdf
self.jitter_grid = jitter_grid
self.perturb_sdf_iter = perturb_sdf_iter
self.sym_prior_shape = sym_prior_shape
self.load_tets(self.grid_res, self.grid_scale)
if sdf_mode == "param":
sdf = torch.rand_like(self.verts[:,0]) - 0.1 # Random init.
self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
self.register_parameter('sdf', self.sdf)
self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
self.register_parameter('deform', self.deform)
else:
            embedder_scaler = 2 * np.pi / self.grid_scale * 0.9  # rescale inputs from (-0.5*scale, 0.5*scale) to roughly (-0.9*pi, 0.9*pi)
            if dim_of_classes == 0 or condition_choice == 'concat':
self.mlp = MLPWithPositionalEncoding(
3,
1,
num_layers,
nf=hidden_size,
extra_dim=dim_of_classes,
dropout=0,
activation=None,
n_harmonic_functions=embedder_freq,
omega0=embedder_scaler,
embed_concat_pts=embed_concat_pts)
            elif condition_choice in ('film', 'mod'):
self.mlp = MLPWithPositionalEncoding_Style(
3,
1,
num_layers,
nf=hidden_size,
extra_dim=dim_of_classes,
dropout=0,
activation=None,
n_harmonic_functions=embedder_freq,
omega0=embedder_scaler,
embed_concat_pts=embed_concat_pts,
style_choice=condition_choice)
else:
raise NotImplementedError
def load_tets(self, grid_res=None, scale=None):
if grid_res is None:
grid_res = self.grid_res
else:
self.grid_res = grid_res
if scale is None:
scale = self.grid_scale
else:
self.grid_scale = scale
tets = np.load('./data/tets/{}_tets.npz'.format(grid_res))
self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale # verts original scale (-0.5, 0.5)
self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
self.generate_edges()
def get_sdf(self, pts=None, perturb_sdf=False, total_iter=0, class_vector=None):
if self.sdf_mode == 'param':
sdf = self.sdf
else:
if pts is None:
pts = self.verts
if self.sym_prior_shape:
xs, ys, zs = pts.unbind(-1)
pts = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
feat = None
if class_vector is not None:
feat = class_vector.unsqueeze(0).repeat(pts.shape[0], 1)
sdf = self.mlp(pts, feat=feat)
if self.init_sdf is None:
pass
elif type(self.init_sdf) in [float, int]:
sdf = sdf + self.init_sdf
elif self.init_sdf == 'sphere':
init_radius = self.grid_scale * 0.25
init_sdf = init_radius - pts.norm(dim=-1, keepdim=True) # init sdf is a sphere centered at origin
sdf = sdf + init_sdf
elif self.init_sdf == 'ellipsoid':
rxy = self.grid_scale * 0.15
            xs, ys, zs = pts.unbind(-1)
init_sdf = rxy - torch.stack([xs, ys, zs/2], -1).norm(dim=-1, keepdim=True) # init sdf is approximately an ellipsoid centered at origin
sdf = sdf + init_sdf
else:
raise NotImplementedError
if perturb_sdf:
sdf = sdf + torch.randn_like(sdf) * 0.1 * max(0, 1-total_iter/self.perturb_sdf_iter)
return sdf
def get_sdf_gradient(self, class_vector=None):
assert self.sdf_mode == 'mlp', "Only MLP supports gradient computation."
num_samples = 5000
sample_points = (torch.rand(num_samples, 3, device=self.verts.device) - 0.5) * self.grid_scale
        mesh_verts = self.mesh_verts.detach() + (torch.rand_like(self.mesh_verts) - 0.5) * 0.1 * self.grid_scale
        rand_idx = torch.randperm(len(mesh_verts), device=mesh_verts.device)[:num_samples]
mesh_verts = mesh_verts[rand_idx]
sample_points = torch.cat([sample_points, mesh_verts], 0)
sample_points.requires_grad = True
y = self.get_sdf(pts=sample_points, perturb_sdf=False, class_vector=class_vector)
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
try:
gradients = torch.autograd.grad(
outputs=[y],
inputs=sample_points,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
except RuntimeError: # For validation, we have disabled gradient calculation.
return torch.zeros_like(sample_points)
return gradients
def get_sdf_reg_loss(self, class_vector=None):
reg_loss = {"sdf_bce_reg_loss": sdf_bce_reg_loss(self.current_sdf, self.all_edges).mean()}
if self.sdf_mode == 'mlp':
reg_loss["sdf_gradient_reg_loss"] = ((self.get_sdf_gradient(class_vector=class_vector).norm(dim=-1) - 1) ** 2).mean()
reg_loss['sdf_inflate_reg_loss'] = -self.current_sdf.mean()
return reg_loss
def generate_edges(self):
with torch.no_grad():
            edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device="cuda")
all_edges = self.indices[:,edges].reshape(-1,2)
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
self.all_edges = torch.unique(all_edges_sorted, dim=0)
@torch.no_grad()
def getAABB(self):
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
def getMesh(self, material=None, perturb_sdf=False, total_iter=0, jitter_grid=True, class_vector=None):
# Run DM tet to get a base mesh
v_deformed = self.verts
# if self.FLAGS.deform_grid:
# v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)
# else:
# v_deformed = self.verts
if jitter_grid and self.jitter_grid > 0:
jitter = (torch.rand(1, device=v_deformed.device)*2-1) * self.jitter_grid * self.grid_scale
v_deformed = v_deformed + jitter
self.current_sdf = self.get_sdf(v_deformed, perturb_sdf=perturb_sdf, total_iter=total_iter, class_vector=class_vector)
verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.current_sdf, self.indices)
self.mesh_verts = verts
return mesh.make_mesh(verts[None], faces[None], uvs[None], uv_idx[None], material)
def render(self, glctx, target, lgt, opt_material, bsdf=None):
opt_mesh = self.getMesh(opt_material)
return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf)
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
# ==============================================================================================
# Render optimizable object with identical conditions
# ==============================================================================================
buffers = self.render(glctx, target, lgt, opt_material)
# ==============================================================================================
# Compute loss
# ==============================================================================================
t_iter = iteration / 20000
# Image-space loss, split into a coverage component and a color component
color_ref = target['img']
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
# SDF regularizer
# sdf_weight = self.sdf_regularizer - (self.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter) # Dropoff to 0.01
        reg_loss = sum(self.get_sdf_reg_loss().values())
        # Albedo (k_d) smoothness regularizer
reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
# Visibility regularizer
reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500)
# Light white balance regularizer
reg_loss = reg_loss + lgt.regularizer() * 0.005
return img_loss, reg_loss
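# A hedged end-to-end sketch (not part of the original file): the hyperparameters
# below are illustrative assumptions, not the training configuration. Loading the
# tet grid requires the ./data/tets/{res}_tets.npz files plus a CUDA device.
def _example_geometry_usage():
    geometry = DMTetGeometry(grid_res=64, scale=1.0, sdf_mode='mlp',
                             num_layers=4, hidden_size=128, embedder_freq=6,
                             init_sdf='sphere')
    out_mesh = geometry.getMesh(perturb_sdf=False)  # runs get_sdf + marching tets
    reg_losses = geometry.get_sdf_reg_loss()        # uses the SDF cached by getMesh
    return out_mesh, reg_losses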