|
from dataclasses import dataclass, field |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import craftsman |
|
from .utils import ( |
|
Mesh, |
|
IsosurfaceHelper, |
|
MarchingCubeCPUHelper, |
|
MarchingTetrahedraHelper, |
|
) |
|
|
|
from craftsman.utils.base import BaseModule |
|
from craftsman.utils.ops import chunk_batch, scale_tensor |
|
from craftsman.utils.typing import * |
|
|
|
class BaseGeometry(BaseModule):
    """Common root of the geometry classes defined in this module.

    It contributes only a conversion hook (:meth:`create_from`) and an
    export hook (:meth:`export`); both have trivial default behavior here.
    """

    @dataclass
    class Config(BaseModule.Config):
        """Configuration of :class:`BaseGeometry`; declares no extra fields."""

    cfg: Config

    @staticmethod
    def create_from(
        other: "BaseGeometry",
        cfg: Optional[Union[dict, DictConfig]] = None,
        **kwargs,
    ) -> "BaseGeometry":
        """Build a geometry of this type from another geometry instance.

        The base implementation supports no conversion at all and
        unconditionally raises.
        NOTE(review): presumably subclasses override this to support
        specific conversions -- confirm against the subclasses.

        Args:
            other: Geometry to convert from.
            cfg: Optional configuration; unused by the base implementation.
            **kwargs: Extra options; unused by the base implementation.

        Raises:
            TypeError: Always.
        """
        target_name = BaseGeometry.__name__
        source_name = other.__class__.__name__
        raise TypeError(f"Cannot create {target_name} from {source_name}")

    def export(self, *args, **kwargs):
        """Return exported data; the base implementation returns an empty dict."""
        return dict()
|
|
|
|
|
class BaseImplicitGeometry(BaseGeometry):
    """Geometry described by a scalar field, with isosurface mesh extraction.

    Subclasses must implement :meth:`forward`, :meth:`forward_field` and
    :meth:`forward_level`. This class supplies the bounding box buffer and
    the extraction driver (:meth:`isosurface`).
    """

    @dataclass
    class Config(BaseGeometry.Config):
        # Half side length of the axis-aligned cube used as the bounding box.
        radius: float = 1.0
        # Master switch: isosurface() raises when this is False.
        isosurface: bool = True
        # Extraction backend: "mt" -> MarchingTetrahedraHelper,
        # "mc-cpu" -> MarchingCubeCPUHelper.
        isosurface_method: str = "mt"
        # Resolution handed to the isosurface helper. For "mt" it also selects
        # the tetrahedra file "load/tets/<resolution>_tets.npz".
        isosurface_resolution: int = 128
        # Either a number, or "auto" to derive the level from the field values.
        isosurface_threshold: Union[float, str] = 0.0
        # Chunk size forwarded to chunk_batch when evaluating the grid.
        isosurface_chunk: int = 0
        # Extract once on the full bbox, then again on a tightened bbox.
        isosurface_coarse_to_fine: bool = True
        # Not read anywhere in this class.
        isosurface_deformable_grid: bool = False
        # Whether to call Mesh.remove_outlier on the extracted mesh.
        isosurface_remove_outliers: bool = True
        # Argument forwarded to Mesh.remove_outlier.
        isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01

    cfg: Config

    def configure(self) -> None:
        """Register the ``bbox`` buffer and reset the extraction state."""
        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )
        # The helper is created lazily by _initilize_isosurface_helper().
        self.isosurface_helper: Optional[IsosurfaceHelper] = None
        self.unbounded: bool = False

    def _initilize_isosurface_helper(self):
        """Create the isosurface helper on first use.

        Does nothing when isosurface extraction is disabled or a helper
        already exists. (The misspelled method name is kept on purpose so
        existing callers keep working.)

        Raises:
            AttributeError: If ``cfg.isosurface_method`` is not a known method.
        """
        if not self.cfg.isosurface or self.isosurface_helper is not None:
            return

        method = self.cfg.isosurface_method
        resolution = self.cfg.isosurface_resolution
        if method == "mc-cpu":
            self.isosurface_helper = MarchingCubeCPUHelper(resolution).to(
                self.device
            )
        elif method == "mt":
            self.isosurface_helper = MarchingTetrahedraHelper(
                resolution,
                f"load/tets/{resolution}_tets.npz",
            ).to(self.device)
        else:
            # This message was previously a plain string, so the placeholder
            # was printed literally instead of the offending method name.
            raise AttributeError(f"Unknown isosurface method {method}")

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Evaluate the geometry at ``points``; must be provided by subclasses."""
        raise NotImplementedError

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        """Return ``(field, deformation)`` at ``points``.

        ``deformation`` may be ``None``. Must be provided by subclasses.
        """
        raise NotImplementedError

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        """Turn field values into level values for the given ``threshold``.

        Must be provided by subclasses.
        """
        raise NotImplementedError

    def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
        """Extract a mesh from the field evaluated inside ``bbox``.

        Args:
            bbox: Bounds used to rescale the helper's grid vertices and the
                resulting mesh vertices.
            fine_stage: Not used by this implementation; kept so the
                signature stays compatible with callers.

        Returns:
            The extracted mesh, with ``bbox`` attached as an extra and, when
            configured, outliers removed.

        Raises:
            TypeError: If ``cfg.isosurface_threshold`` is neither a number
                nor the string ``"auto"``.
        """
        # Checked up front so a missing helper fails here rather than inside
        # the batched evaluation below.
        assert self.isosurface_helper is not None

        def batch_func(x):
            # Rescale the helper's grid vertices into bbox before evaluation.
            # NOTE(review): scale_tensor presumably maps from points_range to
            # bbox -- confirm against its definition.
            field, deformation = self.forward_field(
                scale_tensor(
                    x.to(bbox.device), self.isosurface_helper.points_range, bbox
                ),
            )
            # Bring the results back to the device the grid vertices live on.
            field = field.to(x.device)
            if deformation is not None:
                deformation = deformation.to(x.device)
            return field, deformation

        field, deformation = chunk_batch(
            batch_func,
            self.cfg.isosurface_chunk,
            self.isosurface_helper.grid_vertices,
        )

        threshold: float
        raw_threshold = self.cfg.isosurface_threshold
        # Accept any real number: an integer such as 0 used to fall through
        # to the TypeError below. bool is still rejected, as it was before.
        if isinstance(raw_threshold, (int, float)) and not isinstance(
            raw_threshold, bool
        ):
            threshold = float(raw_threshold)
        elif raw_threshold == "auto":
            eps = 1.0e-5
            # Mean of the field values that are strictly greater than eps.
            threshold = field[field > eps].mean().item()
            craftsman.info(
                f"Automatically determined isosurface threshold: {threshold}"
            )
        else:
            raise TypeError(
                f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
            )

        level = self.forward_level(field, threshold)
        mesh: Mesh = self.isosurface_helper(level, deformation=deformation)
        # Rescale mesh vertices the same way the grid vertices were rescaled.
        mesh.v_pos = scale_tensor(
            mesh.v_pos, self.isosurface_helper.points_range, bbox
        )
        mesh.add_extra("bbox", bbox)

        if self.cfg.isosurface_remove_outliers:
            mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)

        return mesh

    def isosurface(self) -> Mesh:
        """Extract the isosurface mesh according to the configuration.

        With ``cfg.isosurface_coarse_to_fine`` a first extraction on the full
        bounding box (run without gradients) yields the vertex extent; a
        second extraction then runs on that extent padded by 10% per axis and
        clamped to the full bounding box.

        Raises:
            NotImplementedError: If isosurface extraction is disabled.
        """
        if not self.cfg.isosurface:
            raise NotImplementedError(
                "Isosurface is not enabled in the current configuration"
            )
        self._initilize_isosurface_helper()

        if not self.cfg.isosurface_coarse_to_fine:
            return self._isosurface(self.bbox)

        craftsman.debug("First run isosurface to get a tight bounding box ...")
        with torch.no_grad():
            mesh_coarse = self._isosurface(self.bbox)
        vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
        # Pad by 10% of the extent, but never grow beyond the full bbox.
        vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0])
        vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1])
        craftsman.debug("Run isosurface again with the tight bounding box ...")
        return self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)
|
|
|
|
|
class BaseExplicitGeometry(BaseGeometry):
    """Base class for explicit geometry; it only sets up a bounding box."""

    @dataclass
    class Config(BaseGeometry.Config):
        # Half side length of the axis-aligned cube used as the bounding box.
        radius: float = 1.0

    cfg: Config

    def configure(self) -> None:
        """Register ``bbox``: row 0 is ``-radius`` and row 1 is ``+radius`` per axis."""
        self.bbox: Float[Tensor, "2 3"]
        half_extent = self.cfg.radius
        lower_corner = [-half_extent, -half_extent, -half_extent]
        upper_corner = [half_extent, half_extent, half_extent]
        bounds = torch.as_tensor(
            [lower_corner, upper_corner],
            dtype=torch.float32,
        )
        self.register_buffer("bbox", bounds)