"""
Hilbert Order
Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve

Author: Xiaoyang Wu ([email protected]), Kaixin Xu
Please cite our work if the code is helpful to you.
"""

import torch


def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    k: The number of bits to shift. Default 1.

    axis: The axis along which to shift. Default -1.

    Returns:
    --------
    Returns an ndarray with zeros prepended and the end truncated, along
    whatever axis was specified."""

    # If we're shifting away the whole axis, just return zeros.
    if binary.shape[axis] <= k:
        return torch.zeros_like(binary)

    # Slice off the last k entries along the axis, then pad k zeros in front.
    # Note that F.pad pads the last dimension, which matches the default
    # axis=-1 used throughout this module.
    slicing = [slice(None)] * len(binary.shape)
    slicing[axis] = slice(None, -k)
    shifted = torch.nn.functional.pad(
        binary[tuple(slicing)], (k, 0), mode="constant", value=0
    )

    return shifted
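
# For instance (a small sanity sketch on the default axis), shifting the bit
# vector [1, 0, 1, 1] by the default k=1 drops the trailing bit and prepends
# a zero:
#   right_shift(torch.tensor([1, 0, 1, 1]))  ->  tensor([0, 1, 0, 1])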


def binary2gray(binary, axis=-1):
    """Convert an array of binary values into Gray codes.

    This uses the classic X ^ (X >> 1) trick to compute the Gray code.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    axis: The axis along which to compute the gray code. Default=-1.

    Returns:
    --------
    Returns an ndarray of Gray codes.
    """
    shifted = right_shift(binary, axis=axis)

    # Do the X ^ (X >> 1) trick.
    gray = torch.logical_xor(binary, shifted)

    return gray
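
# For instance, the 3-bit pattern for 3 maps to the Gray code for 2 (note
# that the result is a boolean tensor):
#   binary2gray(torch.tensor([0, 1, 1]))  ->  [False, True, False]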


def gray2binary(gray, axis=-1):
    """Convert an array of Gray codes back into binary values.

    Parameters:
    -----------
    gray: An ndarray of gray codes.

    axis: The axis along which to perform Gray decoding. Default=-1.

    Returns:
    --------
    Returns an ndarray of binary values.
    """

    # Loop the log2(bits) number of times necessary, with shift taking the
    # values 2**(ceil(log2(bits)) - 1), ..., 4, 2, 1.
    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
    while shift > 0:
        gray = torch.logical_xor(gray, right_shift(gray, shift))
        shift = torch.div(shift, 2, rounding_mode="floor")
    return gray
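
# gray2binary inverts binary2gray, e.g. recovering the bits of 3 from the
# Gray code of 3 (which is 2):
#   gray2binary(torch.tensor([0, 1, 0]))  ->  [False, True, True]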


def encode(locs, num_dims, num_bits):
    """Encode an array of locations in a hypercube into Hilbert integers.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
           which each dimension runs from 0 to 2**num_bits-1. The shape can
           be arbitrary, as long as the last dimension has size num_dims.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of int64 integers with the same shape as the
    input, excluding the last dimension, which needs to be num_dims.
    """

    orig_shape = locs.shape
    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    if orig_shape[-1] != num_dims:
        raise ValueError(
            """
            The shape of locs was surprising in that the last dimension was of size
            %d, but num_dims=%d. These need to be equal.
            """
            % (orig_shape[-1], num_dims)
        )

    if num_dims * num_bits > 63:
        raise ValueError(
            """
            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
            into an int64. Are you sure you need that many points on your Hilbert
            curve?
            """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Treat each location as a 64-bit integer, view it as big-endian bytes, and
    # unpack those into bits, keeping only the low num_bits per dimension.
    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
    gray = (
        locs_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[..., -num_bits:]
    )

    # Iterate forwards through the bits.
    for bit in range(0, num_bits):
        # Iterate forwards through the dimensions.
        for dim in range(0, num_dims):
            # Identify which entries have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Now flatten out so the bits of all dimensions are interleaved.
    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))

    # Convert from Gray code back to binary.
    hh_bin = gray2binary(gray)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits * num_dims
    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

    # Pack the bits into bytes, then view those bytes as int64 Hilbert codes.
    hh_uint8 = (
        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
        .sum(2)
        .squeeze()
        .type(torch.uint8)
    )

    hh_uint64 = hh_uint8.view(torch.int64).squeeze()

    return hh_uint64
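
# Shape sketch: for a tensor of N points with shape (N, num_dims), encode
# returns N int64 Hilbert codes; the all-zero location always maps to code 0.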


def decode(hilberts, num_dims, num_bits):
    """Decode an array of Hilbert integers into locations in a hypercube.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
               cannot have fewer bits than num_dims * num_bits.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of int64 integers with the same shape as hilberts
    but with an additional dimension of size num_dims.
    """

    if num_dims * num_bits > 64:
        raise ValueError(
            """
            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
            into a uint64. Are you sure you need that many points on your Hilbert
            curve?
            """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Handle the case where we got handed a naked integer.
    hilberts = torch.atleast_1d(hilberts)

    # Keep around the shape for later.
    orig_shape = hilberts.shape
    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    # Treat each Hilbert integer as a sequence of eight uint8s, big-endian first.
    hh_uint8 = (
        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
    )

    # Turn these bytes into bits, then truncate to the size we actually need
    # for Skilling's procedure.
    hh_bits = (
        hh_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[:, -num_dims * num_bits :]
    )

    # Take the sequence of bits and Gray-code it.
    gray = binary2gray(hh_bits)

    # Rearrange so each dimension's bits sit along the last axis.
    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)

    # Iterate backwards through the bits.
    for bit in range(num_bits - 1, -1, -1):
        # Iterate backwards through the dimensions.
        for dim in range(num_dims - 1, -1, -1):
            # Identify which entries have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Pad back out to 64 bits per dimension.
    extra_dims = 64 - num_bits
    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)

    # Chop the bits into bytes.
    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))

    # Pack those bytes and view them as int64 coordinates.
    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)

    flat_locs = locs_uint8.view(torch.int64)

    # Return them in the expected shape.
    return flat_locs.reshape((*orig_shape, num_dims))
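

if __name__ == "__main__":
    # A minimal round-trip sketch (not part of the upstream module): encode a
    # few 3-D grid coordinates as Hilbert codes and check that decode inverts
    # the mapping. The specific coordinates are arbitrary.
    locs = torch.tensor([[0, 0, 0], [1, 2, 3], [7, 7, 7]], dtype=torch.int64)
    codes = encode(locs, num_dims=3, num_bits=4)
    recovered = decode(codes, num_dims=3, num_bits=4)
    assert torch.equal(recovered, locs)
    print("round trip ok:", codes.tolist())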