|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
from typing import Optional, Union
|
|
|
|
|
|
class KeyLUT:
    r"""Lookup tables for encoding/decoding bit-interleaved 3D keys.

    The encode tables map an 8-bit coordinate value to its contribution to the
    key (one table per axis); the decode tables map a 9-bit key chunk back to
    three 3-bit coordinate values. Tables are built on the CPU at construction
    and copied to other devices on first request, then cached per device.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        coords = torch.arange(256, dtype=torch.int64)
        zeros = torch.zeros(256, dtype=torch.int64)
        chunks = torch.arange(512, dtype=torch.int64)

        # One table per axis: only that axis carries a non-zero coordinate.
        encode_x = self.xyz2key(coords, zeros, zeros, 8)
        encode_y = self.xyz2key(zeros, coords, zeros, 8)
        encode_z = self.xyz2key(zeros, zeros, coords, 8)

        self._encode = {cpu: (encode_x, encode_y, encode_z)}
        self._decode = {cpu: self.key2xyz(chunks, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        r"""Returns the ``(EX, EY, EZ)`` encode tables located on :attr:`device`."""
        tables = self._encode.get(device)
        if tables is None:
            # First request for this device: copy the CPU tables and cache them.
            cpu_tables = self._encode[torch.device("cpu")]
            tables = tuple(table.to(device) for table in cpu_tables)
            self._encode[device] = tables
        return tables

    def decode_lut(self, device=torch.device("cpu")):
        r"""Returns the ``(DX, DY, DZ)`` decode tables located on :attr:`device`."""
        tables = self._decode.get(device)
        if tables is None:
            # First request for this device: copy the CPU tables and cache them.
            cpu_tables = self._decode[torch.device("cpu")]
            tables = tuple(table.to(device) for table in cpu_tables)
            self._decode[device] = tables
        return tables

    def xyz2key(self, x, y, z, depth):
        r"""Interleaves the lowest :attr:`depth` bits of x, y and z bit by bit.

        Bit ``i`` of x, y, z lands at key bit ``3i + 2``, ``3i + 1``, ``3i``
        respectively.
        """
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            # A bit already sitting at position `level` needs 2 * level (+ axis
            # offset) more positions to reach its interleaved slot.
            x_part = (x & bit) << (2 * level + 2)
            y_part = (y & bit) << (2 * level + 1)
            z_part = (z & bit) << (2 * level + 0)
            key = key | x_part | y_part | z_part
        return key

    def key2xyz(self, key, depth):
        r"""Splits an interleaved key back into ``depth``-bit x, y, z values."""
        x = torch.zeros_like(key)
        y = torch.zeros_like(key)
        z = torch.zeros_like(key)
        for level in range(depth):
            src = 3 * level  # position of this level's z bit inside the key
            drop = 2 * level  # shift that brings a key bit back to `level`
            x = x | ((key & (1 << (src + 2))) >> (drop + 2))
            y = y | ((key & (1 << (src + 1))) >> (drop + 1))
            z = z | ((key & (1 << (src + 0))) >> (drop + 0))
        return x, y, z
|
|
|
|
|
|
# Module-level KeyLUT instance, constructed once when this module is loaded.
# The module-level xyz2key / key2xyz functions fetch their lookup tables from it.
_key_lut = KeyLUT()
|
|
|
|
|
|
def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    The lowest 8 bits of each coordinate are encoded into key bits 0-23; when
    :attr:`depth` is larger than 8, the next ``depth - 8`` bits are encoded the
    same way and placed from bit 24 upwards. The batch index, if given, is
    placed from bit 48 upwards.

    Args:
      x (torch.Tensor): The x coordinate.
      y (torch.Tensor): The y coordinate.
      z (torch.Tensor): The z coordinate.
      b (torch.Tensor or int): The batch index of the coordinates, and should be
          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      torch.Tensor: The int64 shuffled keys, one per input coordinate.
    """

    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # The lookup tables hold 256 entries, so coordinates are encoded in
    # chunks of at most 8 bits.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        # The low chunk occupies 3 * 8 = 24 bits; the high chunk goes above it.
        key = key16 << 24 | key

    if b is not None:
        # `b` may be a plain int (see the signature), which has no `.long()`;
        # only tensors need the dtype conversion.
        if isinstance(b, torch.Tensor):
            b = b.long()
        key = key | (b << 48)

    return key
|
|
|
|
|
|
def key2xyz(key: torch.Tensor, depth: int = 16):
    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
      key (torch.Tensor): The shuffled key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      A tuple ``(x, y, z, b)`` of tensors shaped like :attr:`key`, where ``b``
      is taken from the bits of :attr:`key` above bit 47.
    """

    DX, DY, DZ = _key_lut.decode_lut(key.device)

    # Separate the batch index (bit 48 and above) from the coordinate bits.
    b = key >> 48
    coord_bits = key & ((1 << 48) - 1)

    x = torch.zeros_like(key)
    y = torch.zeros_like(key)
    z = torch.zeros_like(key)

    # Every 9-bit chunk of the key yields 3 bits for each coordinate.
    num_chunks = (depth + 2) // 3
    for chunk in range(num_chunks):
        index = (coord_bits >> (chunk * 9)) & 511
        shift = chunk * 3
        x = x | (DX[index] << shift)
        y = y | (DY[index] << shift)
        z = z | (DZ[index] << shift)

    return x, y, z, b
|
|
|