import os import slangtorch import torch import torch.nn as nn from jaxtyping import Bool, Float from torch import Tensor class TextureBaker(nn.Module): def __init__(self): super().__init__() self.baker = slangtorch.loadModule( os.path.join(os.path.dirname(__file__), "texture_baker.slang") ) def rasterize( self, uv: Float[Tensor, "Nv 2"], face_indices: Float[Tensor, "Nf 3"], bake_resolution: int, ) -> Float[Tensor, "bake_resolution bake_resolution 4"]: if not face_indices.is_cuda or not uv.is_cuda: raise ValueError("All input tensors must be on cuda") face_indices = face_indices.to(torch.int32) uv = uv.to(torch.float32) rast_result = torch.empty( bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32 ) block_size = 16 grid_size = bake_resolution // block_size self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw( blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1) ) return rast_result def get_mask( self, rast: Float[Tensor, "bake_resolution bake_resolution 4"] ) -> Bool[Tensor, "bake_resolution bake_resolution"]: return rast[..., -1] >= 0 def interpolate( self, attr: Float[Tensor, "Nv 3"], rast: Float[Tensor, "bake_resolution bake_resolution 4"], face_indices: Float[Tensor, "Nf 3"], uv: Float[Tensor, "Nv 2"], ) -> Float[Tensor, "bake_resolution bake_resolution 3"]: # Make sure all input tensors are on torch if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda: raise ValueError("All input tensors must be on cuda") attr = attr.to(torch.float32) face_indices = face_indices.to(torch.int32) uv = uv.to(torch.float32) pos_bake = torch.zeros( rast.shape[0], rast.shape[1], 3, device=attr.device, dtype=attr.dtype, ) block_size = 16 grid_size = rast.shape[0] // block_size self.baker.interpolate( attr=attr, indices=face_indices, rast=rast, output=pos_bake ).launchRaw( blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1) ) return pos_bake def forward( self, attr: Float[Tensor, "Nv 3"], uv: Float[Tensor, "Nv 2"], face_indices: Float[Tensor, "Nf 3"], bake_resolution: int, ) -> Float[Tensor, "bake_resolution bake_resolution 3"]: rast = self.rasterize(uv, face_indices, bake_resolution) return self.interpolate(attr, rast, face_indices, uv)