from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from jaxtyping import Float, Integer
from torch import Tensor

from .mesh import Mesh


class IsosurfaceHelper(nn.Module):
    """Base class for modules that turn a level field sampled on a grid into a mesh.

    Subclasses supply the grid vertices (``grid_vertices``) on which the level
    values are defined.
    """

    # Assumed coordinate extent of the grid; its width scales grid deformations
    # in subclasses (see MarchingTetrahedraHelper.normalize_grid_deformation).
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Return the grid vertex positions. Must be implemented by subclasses."""
        raise NotImplementedError

    @property
    def requires_instance_per_batch(self) -> bool:
        """Flag queried by callers; always ``False`` for this base class.

        NOTE(review): presumably signals whether a separate helper instance is
        needed per batch element -- confirm against the callers.
        """
        return False


class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Extract a triangle mesh from per-vertex level values on a tetrahedral grid.

    The grid is read with ``np.load(tets_path)``; the loaded object must expose a
    ``"vertices"`` array (vertex positions) and an ``"indices"`` array (four
    vertex indices per tetrahedron), e.g. an ``.npz`` archive.
    """

    def __init__(self, resolution: int, tets_path: str):
        """
        Args:
            resolution: Grid resolution; used to scale vertex deformations.
            tets_path: Path to the tetrahedral grid file (see class docstring).
        """
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # Row = 4-bit occupancy code of a tet (bit i set when corner i is
        # occupied); entries are local edge ids (0..5, i.e. positions in
        # ``base_tet_edges``) forming up to two triangles, padded with -1.
        self.triangle_table: Integer[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles emitted for each 4-bit occupancy code.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The six tet edges as local corner pairs:
        # (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        tets = np.load(self.tets_path)
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            torch.from_numpy(tets["vertices"]).float(),
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer(
            "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
        )

        # Lazily computed by ``all_edges``; a plain attribute, not a buffer.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Find the vertex nearest the origin and the vertices on the grid boundary.

        Args:
            verts: Vertex positions; coordinates along the last dim (assumed
                ``(Nv, 3)``).

        Returns:
            ``(center_idx, boundary_idx)``: a 0-d tensor holding the index of the
            vertex with the smallest squared norm, and a 1-D tensor with the
            indices of vertices having at least one coordinate equal to the
            global minimum or maximum over all of ``verts``.
        """
        sq_norm = torch.sum(verts**2, dim=-1)
        center_idx = torch.argmin(sq_norm)
        # Global (not per-axis) extremes: any coordinate touching them marks
        # the vertex as lying on the boundary.
        at_max = verts == verts.max()
        at_min = verts == verts.min()
        boundary = torch.bitwise_or(at_min, at_max)
        boundary = torch.sum(boundary.float(), dim=-1)
        boundary_idx = torch.nonzero(boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with tanh and scale them by the grid cell size."""
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """The undeformed grid vertex positions loaded from ``tets_path``."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique tet edges as sorted vertex-index pairs, computed once and cached.

        The cache is a plain attribute (not a registered buffer), so
        ``Module.to()`` does not move it; it is moved here to the device of
        ``self.indices`` whenever the two disagree.
        """
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to
            # the unique operation)
            edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            edges_sorted = torch.sort(edges, dim=1)[0]
            self._all_edges = torch.unique(edges_sorted, dim=0)
        elif self._all_edges.device != self.indices.device:
            # The module was moved after the cache was filled.
            self._all_edges = self._all_edges.to(self.indices.device)
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge so the smaller vertex index comes first.

        Args:
            edges_ex2: Integer tensor of shape ``(E, 2)``.

        Returns:
            Tensor of shape ``(E, 1, 2)`` (note the singleton middle dim, produced
            by stacking two ``(E, 1)`` gathers); callers flatten it.
        """
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

            return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra on one grid.

        Args:
            pos_nx3: Vertex positions, ``(N, 3)``.
            sdf_n: Per-vertex level values; a vertex is occupied when its value
                is ``> 0``.
            tet_fx4: Four vertex indices per tetrahedron, ``(F, 4)``.

        Returns:
            ``(verts, faces)``: one vertex per edge whose endpoints differ in
            occupancy, placed at the zero crossing of the linearly interpolated
            level; and triangle indices into ``verts``, listing the
            single-triangle tets first, then the two-triangle tets.
        """
        # Topology (which edges are crossed, how they map to output vertices)
        # needs no gradient.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tets with mixed occupancy intersect the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # An edge is crossed when exactly one endpoint is occupied.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            # Crossed edges get consecutive output-vertex ids; others get -1.
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Interpolation stays outside no_grad so vertices are differentiable
        # w.r.t. both positions and level values.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # (s0, s1) -> (s0, -s1); the denominator s0 - s1 is nonzero because
        # exactly one endpoint has level > 0.
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        # Weights (-s1, s0) / (s0 - s1): the zero crossing of the linear
        # interpolation between the two endpoints.
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        # One row per valid tet: output-vertex id for each of its six edges.
        idx_map = idx_map.reshape(-1, 6)

        # Occupancy bitmask per tet, weights 1, 2, 4, 8 for corners 0..3.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract a mesh from per-grid-vertex level values.

        Args:
            level: Level value per grid vertex; values ``> 0`` count as occupied.
            deformation: Optional raw per-vertex offsets. They are passed through
                ``normalize_grid_deformation`` and added to the grid vertices.

        Returns:
            A ``Mesh`` of the extracted vertices and faces. The (possibly deformed)
            grid vertices, the unique tet edges, ``level`` and ``deformation`` are
            attached as extras.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh