Spaces:
Build error
Build error
File size: 2,843 Bytes
94ada0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// Please refer to original code: https://github.com/NVlabs/instant-ngp
// and the pytorch wrapper from https://github.com/ashawkey/torch-ngp
#include <stdint.h>
#include <torch/torch.h>
#include <torch/extension.h>
#include "hash_sample.h"
#include "utils.h"
// Host-side wrapper for the forward hash-encoding pass.
//
// Runs the CHECK_* validation macros on every tensor argument and then hands
// the raw data pointers, together with the scalar arguments (unchanged), to
// hash_encode_forward_cuda.
//
// Tensor arguments:
//   inputs, embeddings, outputs, dy_dx - checked with CHECK_IS_FLOAT and read
//                                        through data_ptr<float>().
//   offsets                            - checked with CHECK_IS_INT and read
//                                        through data_ptr<int>().
//
// NOTE(review): the CHECK_* macros and hash_encode_forward_cuda are not defined
// in this file (presumably utils.h / hash_sample.h). The expected tensor shapes
// and the meaning of beta, B, N, D, C, L, H and mode are determined there -
// confirm against those headers.
void hash_encode_forward(
    at::Tensor inputs,
    at::Tensor embeddings,
    at::Tensor offsets,
    at::Tensor outputs,
    const float beta,
    const uint32_t B,
    const uint32_t N,
    const uint32_t D,
    const uint32_t C,
    const uint32_t L,
    const uint32_t H,
    const bool calc_grad_inputs,
    at::Tensor dy_dx,
    const uint32_t mode) {
    // Device checks. dy_dx is validated unconditionally, i.e. even when
    // calc_grad_inputs is false, so callers must always pass a tensor that
    // satisfies these checks.
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    CHECK_CUDA(dy_dx);

    // Layout checks: the raw pointers taken below are passed on without any
    // stride information.
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(dy_dx);

    // Element-type checks, matching the data_ptr<T>() instantiations below.
    CHECK_IS_FLOAT(inputs);
    CHECK_IS_FLOAT(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOAT(outputs);
    CHECK_IS_FLOAT(dy_dx);

    // Raw pointers are only extracted after all checks have passed.
    auto* const inputs_ptr = inputs.data_ptr<float>();
    auto* const embeddings_ptr = embeddings.data_ptr<float>();
    auto* const offsets_ptr = offsets.data_ptr<int>();
    auto* const outputs_ptr = outputs.data_ptr<float>();
    auto* const dy_dx_ptr = dy_dx.data_ptr<float>();

    hash_encode_forward_cuda(
        inputs_ptr, embeddings_ptr, offsets_ptr, outputs_ptr,
        beta, B, N, D, C, L, H,
        calc_grad_inputs, dy_dx_ptr, mode);
}
// Host-side wrapper for the backward hash-encoding pass.
//
// Runs the CHECK_* validation macros on every tensor argument and then hands
// the raw data pointers, together with the scalar arguments (unchanged), to
// hash_encode_backward_cuda.
//
// Tensor arguments:
//   grad, inputs, embeddings, grad_embeddings, dy_dx, grad_inputs
//            - checked with CHECK_IS_FLOAT and read through data_ptr<float>().
//   offsets  - checked with CHECK_IS_INT and read through data_ptr<int>().
//
// NOTE(review): the CHECK_* macros and hash_encode_backward_cuda are not
// defined in this file (presumably utils.h / hash_sample.h). The expected
// tensor shapes, which tensors the kernel writes to, and the meaning of beta,
// B, N, D, C, L, H and mode are determined there - confirm against those
// headers.
void hash_encode_backward(
    at::Tensor grad,
    at::Tensor inputs,
    at::Tensor embeddings,
    at::Tensor offsets,
    at::Tensor grad_embeddings,
    const float beta,
    const uint32_t B,
    const uint32_t N,
    const uint32_t D,
    const uint32_t C,
    const uint32_t L,
    const uint32_t H,
    const bool calc_grad_inputs,
    at::Tensor dy_dx,
    at::Tensor grad_inputs,
    const uint32_t mode) {
    // Device checks. dy_dx and grad_inputs are validated unconditionally,
    // i.e. even when calc_grad_inputs is false, so callers must always pass
    // tensors that satisfy these checks.
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    CHECK_CUDA(dy_dx);
    CHECK_CUDA(grad_inputs);

    // Layout checks: the raw pointers taken below are passed on without any
    // stride information.
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    CHECK_CONTIGUOUS(dy_dx);
    CHECK_CONTIGUOUS(grad_inputs);

    // Element-type checks, matching the data_ptr<T>() instantiations below.
    CHECK_IS_FLOAT(grad);
    CHECK_IS_FLOAT(inputs);
    CHECK_IS_FLOAT(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOAT(grad_embeddings);
    CHECK_IS_FLOAT(dy_dx);
    CHECK_IS_FLOAT(grad_inputs);

    // Raw pointers are only extracted after all checks have passed.
    auto* const grad_ptr = grad.data_ptr<float>();
    auto* const inputs_ptr = inputs.data_ptr<float>();
    auto* const embeddings_ptr = embeddings.data_ptr<float>();
    auto* const offsets_ptr = offsets.data_ptr<int>();
    auto* const grad_embeddings_ptr = grad_embeddings.data_ptr<float>();
    auto* const dy_dx_ptr = dy_dx.data_ptr<float>();
    auto* const grad_inputs_ptr = grad_inputs.data_ptr<float>();

    hash_encode_backward_cuda(
        grad_ptr, inputs_ptr, embeddings_ptr, offsets_ptr, grad_embeddings_ptr,
        beta, B, N, D, C, L, H,
        calc_grad_inputs, dy_dx_ptr, grad_inputs_ptr, mode);
}
//------------------------------------------------------------------------
// Python bindings: registers the two host-side wrappers defined in this file
// on the extension module `m`, under the same names as the C++ functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Forward pass entry point.
    m.def(
        "hash_encode_forward",
        &hash_encode_forward,
        "hash encode forward (CUDA)");

    // Backward pass entry point.
    m.def(
        "hash_encode_backward",
        &hash_encode_backward,
        "hash encode backward (CUDA)");
}
|