|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from pdb import set_trace as st |
|
|
|
|
|
def transform_vectors(matrix: torch.Tensor,
                      vectors4: torch.Tensor) -> torch.Tensor:
    """
    Apply ``matrix`` to every row vector stored in ``vectors4``.

    For an MxM ``matrix`` and NxM ``vectors4`` this equals
    ``(matrix @ vectors4.T).T`` and returns an NxM tensor. It is computed
    directly as ``vectors4 @ matrix.T`` so no transposes of the (possibly
    large) vector tensor are needed.

    NOTE(review): the name ``vectors4`` suggests homogeneous 4-vectors, but
    nothing here enforces M == 4 — confirm with callers.
    """
    matrix_t = matrix.T
    return vectors4 @ matrix_t
|
|
|
|
|
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Scale each vector along the last dimension to unit Euclidean length.

    No epsilon is added: a zero-length vector produces 0 / 0 = NaN entries.
    """
    lengths = vectors.norm(dim=-1, keepdim=True)
    return vectors / lengths
|
|
|
|
|
def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of ``x`` and ``y`` taken over their last dimension.

    The inputs are broadcast against each other before the element-wise
    product, so the result has the broadcast shape minus the last axis.
    """
    elementwise = x * y
    return torch.sum(elementwise, dim=-1)
|
|
|
|
|
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor,
                       box_side_length):
    """
    Author: Petr Kellnhofer

    Intersect rays with the axis-aligned cube centred at the origin with edge
    length ``box_side_length``, i.e. [-box_side_length / 2,
    box_side_length / 2]^3 (the [-1, 1] NDC volume when
    ``box_side_length == 2``), using the slab method:
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection

    Args:
        rays_o: Ray origins, shape (..., 3).
        rays_d: Ray directions with the same number of rays as ``rays_o``.
            They need not be unit length; returned distances are in
            multiples of ``rays_d``.
        box_side_length: Scalar edge length of the cube.

    Returns:
        ``(t_min, t_max)``, each of shape ``(*rays_o.shape[:-1], 1)``: the ray
        parameters ``t`` (point = o + t * d) where the ray enters and exits
        the box. Rays that miss the box get ``t_min = -1`` and
        ``t_max = -2``, so ``t_min > t_max`` marks a miss. Hits lying behind
        the origin (negative ``t``) are NOT filtered out.

    Both inputs are detached, so no gradient flows through the result.

    NOTE: a ray with a zero direction component whose origin lies exactly on
    one of that axis' box planes yields 0 * inf = NaN in the result.
    """
    o_shape = rays_o.shape
    rays_o = rays_o.detach().reshape(-1, 3)
    rays_d = rays_d.detach().reshape(-1, 3)

    half = box_side_length / 2
    # Row 0 is the lower box corner, row 1 the upper one.
    bounds = torch.tensor([[-half, -half, -half], [half, half, half]],
                          dtype=rays_o.dtype,
                          device=rays_o.device)

    # Dividing by a zero component gives +-inf, which the slab comparisons
    # below handle without special-casing.
    invdir = 1 / rays_d
    # 1 where the ray travels towards -inf on an axis: the near plane is then
    # the upper corner (row 1) and the far plane the lower corner (row 0).
    sign = (invdir < 0).long()

    def _slab(axis: int):
        """Return the (near, far) ray parameters for the slab along ``axis``."""
        near = (bounds[sign[:, axis], axis] - rays_o[:, axis]) * invdir[:, axis]
        far = (bounds[1 - sign[:, axis], axis] - rays_o[:, axis]) * invdir[:, axis]
        return near, far

    tmin, tmax = _slab(0)
    is_valid = torch.ones(rays_o.shape[:-1], dtype=torch.bool,
                          device=rays_o.device)
    for axis in (1, 2):
        near, far = _slab(axis)
        # The ray hits the box only if this slab's interval overlaps the
        # running [tmin, tmax] interval of the previous slabs.
        is_valid &= ~torch.logical_or(tmin > far, near > tmax)
        tmin = torch.max(tmin, near)
        tmax = torch.min(tmax, far)

    # Sentinels for misses; tmin/tmax are fresh tensors from torch.max/min,
    # so writing into them in place is safe.
    miss = torch.logical_not(is_valid)
    tmin[miss] = -1
    tmax[miss] = -2

    return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
|
|
|
|
|
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Create a tensor of shape [num, *start.shape] whose values are evenly
    spaced from ``start`` to ``stop``, inclusive.

    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    Args:
        start: First value(s) of the sequence.
        stop: Last value(s) of the sequence.
        num: Number of samples. ``num == 1`` returns ``start[None]`` (as
            numpy does) instead of dividing 0 by 0; ``num == 0`` returns a
            tensor with an empty leading dimension.

    Returns:
        Tensor of shape ``[num, *start.shape]`` on ``start.device``. The
        interpolation weights are float64 when ``start``/``stop`` promote to
        float64 and float32 otherwise, so double-precision inputs are not
        quantised to float32 steps.
    """
    promoted = torch.result_type(start, stop)
    step_dtype = torch.float64 if promoted == torch.float64 else torch.float32
    # max(..., 1) avoids 0 / 0 = NaN when num == 1.
    steps = torch.arange(num, dtype=step_dtype,
                         device=start.device) / max(num - 1, 1)

    # One trailing singleton dim per dim of ``start`` so ``steps`` broadcasts
    # along a new leading axis.
    steps = steps.reshape(num, *([1] * start.ndim))

    return start[None] + steps * (stop - start)[None]
|
|