Spaces · Running on L4

FrozenBurning committed · Commit 81ecb2b · Parent(s): 06ea84f

single view to 3D init release

Note: this view is limited to 50 files because it contains too many changes.
- .gitignore +4 -0
- README.md +2 -1
- app.py +209 -0
- assets/examples/blue_cat.png +0 -0
- assets/examples/bubble_mart_blue.png +0 -0
- assets/examples/bulldog.png +0 -0
- assets/examples/ceramic.png +0 -0
- assets/examples/chair_watermelon.png +0 -0
- assets/examples/cup_rgba.png +0 -0
- assets/examples/cute_horse.jpg +0 -0
- assets/examples/earphone.jpg +0 -0
- assets/examples/firedragon.png +0 -0
- assets/examples/fox.jpg +0 -0
- assets/examples/fruit_elephant.jpg +0 -0
- assets/examples/hatsune_miku.png +0 -0
- assets/examples/ikun_rgba.png +0 -0
- assets/examples/mailbox.png +0 -0
- assets/examples/mario.png +0 -0
- assets/examples/mei_ling_panda.png +0 -0
- assets/examples/mushroom_teapot.jpg +0 -0
- assets/examples/pikachu.png +0 -0
- assets/examples/potplant_rgba.png +0 -0
- assets/examples/seed_frog.png +0 -0
- assets/examples/shuai_panda_notail.png +0 -0
- assets/examples/yellow_duck.png +0 -0
- configs/inference_dit.yml +97 -0
- dva/__init__.py +5 -0
- dva/attr_dict.py +66 -0
- dva/geom.py +653 -0
- dva/io.py +56 -0
- dva/layers.py +157 -0
- dva/losses.py +239 -0
- dva/mvp/extensions/mvpraymarch/bvh.cu +292 -0
- dva/mvp/extensions/mvpraymarch/cudadispatch.h +104 -0
- dva/mvp/extensions/mvpraymarch/helper_math.h +1453 -0
- dva/mvp/extensions/mvpraymarch/makefile +2 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp +405 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch.py +559 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu +208 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h +218 -0
- dva/mvp/extensions/mvpraymarch/primaccum.h +101 -0
- dva/mvp/extensions/mvpraymarch/primsampler.h +94 -0
- dva/mvp/extensions/mvpraymarch/primtransf.h +182 -0
- dva/mvp/extensions/mvpraymarch/setup.py +30 -0
- dva/mvp/extensions/mvpraymarch/utils.h +847 -0
- dva/mvp/extensions/utils/helper_math.h +1453 -0
- dva/mvp/extensions/utils/makefile +2 -0
- dva/mvp/extensions/utils/setup.py +29 -0
- dva/mvp/extensions/utils/utils.cpp +137 -0
- dva/mvp/extensions/utils/utils.py +211 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__
build
*.so
runs
README.md
CHANGED
@@ -1,10 +1,11 @@
 ---
-title: 3DTopia
+title: 3DTopia-XL
 emoji: 🌖
 colorFrom: green
 colorTo: pink
 sdk: gradio
 sdk_version: 4.41.0
+python_version: 3.9
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,209 @@
import os
import imageio
import numpy as np

os.system("bash install.sh")

from omegaconf import OmegaConf
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import rembg
import gradio as gr
from dva.io import load_from_config
from dva.ray_marcher import RayMarcher
from dva.visualize import visualize_primvolume, visualize_video_primvolume
from inference import remove_background, resize_foreground, extract_texmesh
from models.diffusion import create_diffusion
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")

GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
GRADIO_GLB_PATH = 'pbr_mesh.glb'
CONFIG_PATH = "./configs/inference_dit.yml"

config = OmegaConf.load(CONFIG_PATH)
config.checkpoint_path = ckpt_path
config.model.vae_checkpoint_path = vae_ckpt_path
# model
model = load_from_config(config.model.generator)
state_dict = torch.load(config.checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict['ema'])
vae = load_from_config(config.model.vae)
vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
vae.load_state_dict(vae_state_dict['model_state_dict'])
conditioner = load_from_config(config.model.conditioner)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = vae.to(device)
conditioner = conditioner.to(device)
model = model.to(device)
model.eval()

amp = True
precision_dtype = torch.float16

rm = RayMarcher(
    config.image_height,
    config.image_width,
    **config.rm,
).to(device)

perchannel_norm = False
if "latent_mean" in config.model:
    latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
    latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
    assert latent_mean.shape[-1] == config.model.generator.in_channels
    perchannel_norm = True

config.diffusion.pop("timestep_respacing")
config.model.pop("vae")
config.model.pop("vae_checkpoint_path")
config.model.pop("conditioner")
config.model.pop("generator")
config.model.pop("latent_nf")
config.model.pop("latent_mean")
config.model.pop("latent_std")
model_primx = load_from_config(config.model)
# load rembg
rembg_session = rembg.new_session()

# process function
def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
    # seed
    torch.manual_seed(input_seed)

    os.makedirs(config.output_dir, exist_ok=True)
    output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
    output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
    output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
    output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)

    # NOTE: `respacing` was undefined in the committed code; deriving it from the
    # requested step count is an assumption consistent with `inference.ddim` in the config.
    respacing = "ddim{}".format(input_num_steps)
    diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
    sample_fn = diffusion.ddim_sample_loop_progressive
    fwd_fn = model.forward_with_cfg

    # text-conditioned
    if input_image is None:
        raise NotImplementedError
    # image-conditioned (a text prompt may also be given, but omitting it usually works too)
    else:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)
        raw_image = np.array(input_image)
        mask = (raw_image[..., -1][..., None] > 0) * 1
        raw_image = raw_image[..., :3] * mask
        input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)

    with torch.no_grad():
        latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
        batch = {}
        inf_bs = 1
        inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
        y = conditioner.encoder(input_cond)
        model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
        if input_cfg >= 0:
            model_kwargs['cfg_scale'] = input_cfg
        for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
            final_samples = samples
        recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
        if perchannel_norm:
            recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
        recon_srt_param = recon_param[:, :, 0:4]
        recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
        recon_feat_param_list = []
        # one-by-one to avoid oom
        for inf_bidx in range(inf_bs):
            if not perchannel_norm:
                decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
            else:
                decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
            recon_feat_param_list.append(decoded.detach())
        recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
        # invert normalization
        if not perchannel_norm:
            recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
            recon_feat_param[:, 0:1, ...] /= 5.
            recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
        recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
        recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
        visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device)
        prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
        torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir))

    # exporting GLB mesh
    denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')
    primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
    # NOTE: the committed code passed an undefined name `ckpt_weight` here.
    model_primx.load_state_dict(primx_ckpt_weight)
    model_primx.to(device)
    model_primx.eval()
    with torch.no_grad():
        model_primx.srt_param[:, 1:4] *= 0.85
        extract_texmesh(config.inference, model_primx, output_glb_path, device)

    return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path

# gradio UI
_TITLE = '''3DTopia-XL'''

_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://frozenburning.github.io/projects/3DTopia-XL/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/3DTopia/3DTopia-XL"><img src='https://img.shields.io/github/stars/3DTopia/3DTopia-XL?style=social'/></a>
</div>

* For now we offer 1) the single-image-conditioned model; we will release 2) the multi-view-image-conditioned model and 3) the pure-text-conditioned model in the future!
* If you find the output unsatisfying, try using different seeds!
'''

block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            # input image
            input_image = gr.Image(label="image", type='pil')
            # inference steps
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25)
            # CFG scale
            input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42)
            # gen button
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            with gr.Tab("Video"):
                # final video results
                output_rgb_video = gr.Video(label="video")
                output_prim_video = gr.Video(label="video")
                output_mat_video = gr.Video(label="video")
            with gr.Tab("GLB"):
                # glb file
                output_glb = gr.File(label="glb")

    button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb])

    gr.Examples(
        examples=[
            "assets/examples/fruit_elephant.jpg",
            "assets/examples/mei_ling_panda.png",
            "assets/examples/shuai_panda_notail.png",
        ],
        inputs=[input_image],
        outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
        fn=lambda x: process(input_image=x),
        cache_examples=False,
        label='Single Image to 3D PBR Asset'
    )

block.launch(server_name="0.0.0.0", share=True)
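Note: the Gradio click handler above is a thin wrapper around `process`, so the same pipeline can be driven headlessly. A minimal usage sketch, assuming this Space is checked out locally with a CUDA device and that this file is importable as `app` (importing it runs the top-level setup, including `install.sh` and the checkpoint downloads):

from PIL import Image
import app  # hypothetical module name for the file above; model loading happens on import

rgb_video, prim_video, mat_video, glb_path = app.process(
    Image.open("assets/examples/fox.jpg"),
    input_num_steps=25,  # DDIM steps, mirrors inference.ddim in configs/inference_dit.yml
    input_seed=42,
    input_cfg=6.0,       # classifier-free guidance scale
)
print(glb_path)  # exported PBR mesh, e.g. ./runs/inference/3dtopia-xl-sview/pbr_mesh.glb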
assets/examples/blue_cat.png
ADDED
assets/examples/bubble_mart_blue.png
ADDED
assets/examples/bulldog.png
ADDED
assets/examples/ceramic.png
ADDED
assets/examples/chair_watermelon.png
ADDED
assets/examples/cup_rgba.png
ADDED
assets/examples/cute_horse.jpg
ADDED
assets/examples/earphone.jpg
ADDED
assets/examples/firedragon.png
ADDED
assets/examples/fox.jpg
ADDED
assets/examples/fruit_elephant.jpg
ADDED
assets/examples/hatsune_miku.png
ADDED
assets/examples/ikun_rgba.png
ADDED
assets/examples/mailbox.png
ADDED
assets/examples/mario.png
ADDED
assets/examples/mei_ling_panda.png
ADDED
assets/examples/mushroom_teapot.jpg
ADDED
assets/examples/pikachu.png
ADDED
assets/examples/potplant_rgba.png
ADDED
assets/examples/seed_frog.png
ADDED
assets/examples/shuai_panda_notail.png
ADDED
assets/examples/yellow_duck.png
ADDED
configs/inference_dit.yml
ADDED
@@ -0,0 +1,97 @@
debug: False
root_data_dir: ./runs
checkpoint_path:
global_seed: 42

inference:
  input_dir:
  ddim: 25
  cfg: 6
  seed: ${global_seed}
  precision: fp16
  export_glb: True
  decimate: 100000
  mc_resolution: 256
  batch_size: 4096
  remesh: False

image_height: 518
image_width: 518

model:
  class_name: models.primsdf.PrimSDF
  num_prims: 2048
  dim_feat: 6
  prim_shape: 8
  init_scale: 0.05 # useless if auto_scale_init == True
  sdf2alpha_var: 0.005
  auto_scale_init: True
  init_sampling: uniform
  vae:
    class_name: models.vae3d_dib.VAE
    in_channels: ${model.dim_feat}
    latent_channels: 1
    out_channels: ${model.vae.in_channels}
    down_channels: [32, 256]
    mid_attention: True
    up_channels: [256, 32]
    layers_per_block: 2
    gradient_checkpointing: False
  vae_checkpoint_path:
  conditioner:
    class_name: models.conditioner.image.ImageConditioner
    num_prims: ${model.num_prims}
    dim_feat: ${model.dim_feat}
    prim_shape: ${model.prim_shape}
    sample_view: False
    encoder_config:
      class_name: models.conditioner.image_dinov2.Dinov2Wrapper
      model_name: dinov2_vitb14_reg
      freeze: True
  generator:
    class_name: models.dit_crossattn.DiT
    seq_length: ${model.num_prims}
    in_channels: 68 # 4 SRT params + model.vae.latent_channels * latent_dim^3 (4 + 1 * 4^3 = 68)
    condition_channels: 768
    hidden_size: 1152
    depth: 28
    num_heads: 16
    attn_proj_bias: True
    cond_drop_prob: 0.1
    gradient_checkpointing: False
  latent_nf: 1.0
  latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166]
  latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916]

diffusion:
  timestep_respacing:
  noise_schedule: squaredcos_cap_v2
  diffusion_steps: 1000
  parameterization: v

rm:
  volradius: 10000.0
  dt: 1.0

optimizer:
  class_name: torch.optim.AdamW
  lr: 0.0001
  weight_decay: 0

scheduler:
  class_name: dva.scheduler.CosineWarmupScheduler
  warmup_iters: 3000
  max_iters: 200000

train:
  batch_size: 8
  n_workers: 4
  n_epochs: 1000
  log_every_n_steps: 50
  summary_every_n_steps: 10000
  ckpt_every_n_steps: 10000
  amp: False
  precision: tf32

tag: 3dtopia-xl-sview
output_dir: ${root_data_dir}/inference/${tag}
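Note on the `${...}` values above: they are OmegaConf interpolations, resolved lazily on access, which is why app.py can freely pop sub-configs after the components have been built. A minimal sketch of the behavior (only OmegaConf required):

from omegaconf import OmegaConf

config = OmegaConf.load("configs/inference_dit.yml")
# ${model.num_prims} and ${model.dim_feat} resolve when read:
assert config.model.generator.seq_length == config.model.num_prims == 2048
assert config.model.vae.in_channels == config.model.dim_feat == 6
print(OmegaConf.to_yaml(config.diffusion))  # noise_schedule, diffusion_steps, ...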
dva/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
dva/attr_dict.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json


class AttrDict:
    def __init__(self, entries):
        self.add_entries_(entries)

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __delitem__(self, key):
        return self.__dict__.__delitem__(key)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()

    def __getattr__(self, attr):
        if attr.startswith("__"):
            return self.__getattribute__(attr)
        return self.__dict__[attr]

    def items(self):
        return self.__dict__.items()

    def __iter__(self):
        return iter(self.items())

    def add_entries_(self, entries, overwrite=True):
        for key, value in entries.items():
            if key not in self.__dict__:
                if isinstance(value, dict):
                    self.__dict__[key] = AttrDict(value)
                else:
                    self.__dict__[key] = value
            else:
                if isinstance(value, dict):
                    self.__dict__[key].add_entries_(entries=value, overwrite=overwrite)
                elif overwrite or self.__dict__[key] is None:
                    self.__dict__[key] = value

    def serialize(self):
        return json.dumps(self, default=self.obj_to_dict, indent=4)

    def obj_to_dict(self, obj):
        return obj.__dict__

    def get(self, key, default=None):
        return self.__dict__.get(key, default)
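Note: AttrDict wraps nested dicts with attribute-style access and non-destructive merging. A quick illustrative sketch (the config values are made up):

from dva.attr_dict import AttrDict

cfg = AttrDict({"train": {"lr": 1e-4, "sched": None}, "tag": "demo"})
assert cfg.train.lr == cfg["train"]["lr"] == 1e-4
cfg.train.add_entries_({"sched": "cosine", "lr": 3e-4}, overwrite=False)
assert cfg.train.sched == "cosine"  # None values are filled even with overwrite=False
assert cfg.train.lr == 1e-4         # existing non-None values are kept
print(cfg.serialize())              # nested JSON via obj_to_dict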
dva/geom.py
ADDED
@@ -0,0 +1,653 @@
from typing import Optional
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn as nn

from sklearn.neighbors import KDTree

import logging

logger = logging.getLogger(__name__)

# NOTE: we need pytorch3d primarily for UV rasterization things
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
from pytorch3d.structures import Meshes
from typing import Union, Optional, Tuple
import trimesh
from trimesh import Trimesh
from trimesh.triangles import points_to_barycentric

try:
    # pyre-fixme[21]: Could not find module `igl`.
    from igl import point_mesh_squared_distance  # @manual

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def closest_point(mesh, points):
        """Helper function that mimics trimesh.proximity.closest_point but uses
        IGL for faster queries."""
        v = mesh.vertices
        vi = mesh.faces
        dist, face_idxs, p = point_mesh_squared_distance(points, v, vi)
        return p, dist, face_idxs

except ImportError:
    from trimesh.proximity import closest_point


def closest_point_barycentrics(v, vi, points):
    """Given a 3D mesh and a set of query points, return closest point barycentrics
    Args:
        v: np.array (float)
        [N, 3] mesh vertices

        vi: np.array (int)
        [N, 3] mesh triangle indices

        points: np.array (float)
        [M, 3] query points

    Returns:
        Tuple[approx, barys, interp_idxs, face_idxs]
            approx:       [M, 3] approximated (closest) points on the mesh
            barys:        [M, 3] barycentric weights that produce "approx"
            interp_idxs:  [M, 3] vertex indices for barycentric interpolation
            face_idxs:    [M] face indices for barycentric interpolation. interp_idxs = vi[face_idxs]
    """
    mesh = Trimesh(vertices=v, faces=vi, process=False)
    p, _, face_idxs = closest_point(mesh, points)
    p = p.reshape((points.shape[0], 3))
    face_idxs = face_idxs.reshape((points.shape[0],))
    barys = points_to_barycentric(mesh.triangles[face_idxs], p)
    b0, b1, b2 = np.split(barys, 3, axis=1)

    interp_idxs = vi[face_idxs]
    v0 = v[interp_idxs[:, 0]]
    v1 = v[interp_idxs[:, 1]]
    v2 = v[interp_idxs[:, 2]]
    approx = b0 * v0 + b1 * v1 + b2 * v2
    return approx, barys, interp_idxs, face_idxs

def make_uv_face_index(
    vt: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
    device: Optional[Union[str, th.device]] = None,
):
    """Compute a UV-space face index map identifying which mesh face contains each
    texel. For texels with no assigned triangle, the index will be -1."""

    if isinstance(uv_shape, int):
        uv_shape = (uv_shape, uv_shape)

    uv_max_shape_ind = uv_shape.index(max(uv_shape))
    uv_min_shape_ind = uv_shape.index(min(uv_shape))
    uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]

    if device is not None:
        if isinstance(device, str):
            dev = th.device(device)
        else:
            dev = device
        assert dev.type == "cuda"
    else:
        dev = th.device("cuda")

    vt = 1.0 - vt.clone()

    if flip_uv:
        vt = vt.clone()
        vt[:, 1] = 1 - vt[:, 1]
    vt_pix = 2.0 * vt.to(dev) - 1.0
    vt_pix = th.cat([vt_pix, th.ones_like(vt_pix[:, 0:1])], dim=1)

    vt_pix[:, uv_min_shape_ind] *= uv_ratio
    meshes = Meshes(vt_pix[np.newaxis], vti[np.newaxis].to(dev))
    with th.no_grad():
        face_index, _, _, _ = rasterize_meshes(
            meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0
        )
        face_index = face_index[0, ..., 0]
    return face_index


def make_uv_vert_index(
    vt: th.Tensor,
    vi: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
):
    """Compute a UV-space vertex index map identifying which mesh vertices
    comprise the triangle containing each texel. For texels with no assigned
    triangle, all indices will be -1.
    """
    face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv)
    vert_index_map = vi[face_index_map.clamp(min=0)]
    vert_index_map[face_index_map < 0] = -1
    return vert_index_map.long()


def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6):
    """Computes barycentric coordinates for a set of 2D query points given
    coordinates for the 3 vertices of the enclosing triangle for each point."""
    x = points[:, 0] - triangles[2, :, 0]
    x1 = triangles[0, :, 0] - triangles[2, :, 0]
    x2 = triangles[1, :, 0] - triangles[2, :, 0]
    y = points[:, 1] - triangles[2, :, 1]
    y1 = triangles[0, :, 1] - triangles[2, :, 1]
    y2 = triangles[1, :, 1] - triangles[2, :, 1]
    denom = y2 * x1 - y1 * x2
    n0 = y2 * x - x2 * y
    n1 = x1 * y - y1 * x

    # Small epsilon to prevent divide-by-zero error.
    denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))

    bary_0 = n0 / denom
    bary_1 = n1 / denom
    bary_2 = 1.0 - bary_0 - bary_1

    return th.stack((bary_0, bary_1, bary_2))


def make_uv_barys(
    vt: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
):
    """Compute a UV-space barycentric map where each texel contains barycentric
    coordinates for that texel within its enclosing UV triangle. For texels
    with no assigned triangle, all 3 barycentric coordinates will be 0.
    """
    if isinstance(uv_shape, int):
        uv_shape = (uv_shape, uv_shape)

    if flip_uv:
        # Flip here because texture coordinates in some of our topo files are
        # stored in OpenGL convention with Y=0 on the bottom of the texture
        # unlike numpy/torch arrays/tensors.
        vt = vt.clone()
        vt[:, 1] = 1 - vt[:, 1]

    face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False)
    vti_map = vti.long()[face_index_map.clamp(min=0)]

    uv_max_shape_ind = uv_shape.index(max(uv_shape))
    uv_min_shape_ind = uv_shape.index(min(uv_shape))
    uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
    vt = vt.clone()
    vt = vt * 2 - 1
    vt[:, uv_min_shape_ind] *= uv_ratio
    uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3)

    uv_grid = th.meshgrid(
        th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0],
        th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1],
    )
    uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs)
    uv_grid = uv_grid * 2 - 1
    uv_grid[..., uv_min_shape_ind] *= uv_ratio

    bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2))
    bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3)
    bary_map[face_index_map < 0] = 0
    return face_index_map, bary_map


def index_image_impaint(
    index_image: th.Tensor,
    bary_image: Optional[th.Tensor] = None,
    distance_threshold=100.0,
):
    # getting the mask around the indexes?
    if len(index_image.shape) == 3:
        valid_index = (index_image != -1).any(dim=-1)
    elif len(index_image.shape) == 2:
        valid_index = index_image != -1
    else:
        raise ValueError("`index_image` should be a [H,W] or [H,W,C] image")

    invalid_index = ~valid_index

    device = index_image.device

    valid_ij = th.stack(th.where(valid_index), dim=-1)
    invalid_ij = th.stack(th.where(invalid_index), dim=-1)
    lookup_valid = KDTree(valid_ij.cpu().numpy())

    dists, idxs = lookup_valid.query(invalid_ij.cpu())

    # TODO: try average?
    idxs = th.as_tensor(idxs, device=device)[..., 0]
    dists = th.as_tensor(dists, device=device)[..., 0]

    dist_mask = dists < distance_threshold

    invalid_border = th.zeros_like(invalid_index)
    invalid_border[invalid_index] = dist_mask

    invalid_src_ij = valid_ij[idxs][dist_mask]
    invalid_dst_ij = invalid_ij[dist_mask]

    index_image_imp = index_image.clone()

    index_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = index_image[
        invalid_src_ij[:, 0], invalid_src_ij[:, 1]
    ]

    if bary_image is not None:
        bary_image_imp = bary_image.clone()

        bary_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = bary_image[
            invalid_src_ij[:, 0], invalid_src_ij[:, 1]
        ]

        return index_image_imp, bary_image_imp
    return index_image_imp


class GeometryModule(nn.Module):
    def __init__(
        self,
        v,
        vi,
        vt,
        vti,
        uv_size,
        v2uv: Optional[th.Tensor] = None,
        flip_uv=False,
        impaint=False,
        impaint_threshold=100.0,
    ):
        super().__init__()

        self.register_buffer("v", th.as_tensor(v))
        self.register_buffer("vi", th.as_tensor(vi))
        self.register_buffer("vt", th.as_tensor(vt))
        self.register_buffer("vti", th.as_tensor(vti))
        if v2uv is not None:
            self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64))

        # TODO: should we just pass topology here?
        # self.n_verts = v2uv.shape[0]
        self.n_verts = vi.max() + 1

        self.uv_size = uv_size

        # TODO: can't we just index face_index?
        index_image = make_uv_vert_index(
            self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        ).cpu()
        face_index, bary_image = make_uv_barys(
            self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        if impaint:
            if min(uv_size) >= 1024:
                logger.info(
                    "impainting index image might take a while for sizes >= 1024"
                )

            index_image, bary_image = index_image_impaint(
                index_image, bary_image, impaint_threshold
            )
            # TODO: we can avoid doing this 2x
            face_index = index_image_impaint(
                face_index, distance_threshold=impaint_threshold
            )

        self.register_buffer("index_image", index_image.cpu())
        self.register_buffer("bary_image", bary_image.cpu())
        self.register_buffer("face_index_image", face_index.cpu())

    def render_index_images(self, uv_size, flip_uv=False, impaint=False):
        index_image = make_uv_vert_index(
            self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        face_image, bary_image = make_uv_barys(
            self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )

        if impaint:
            index_image, bary_image = index_image_impaint(
                index_image,
                bary_image,
            )

        return index_image, face_image, bary_image

    def vn(self, verts):
        return vert_normals(verts, self.vi[np.newaxis].to(th.long))

    def to_uv(self, values):
        return values_to_uv(values, self.index_image, self.bary_image)

    def from_uv(self, values_uv):
        # TODO: we need to sample this
        return sample_uv(values_uv, self.vt, self.v2uv.to(th.long))

    def rand_sample_3d_uv(self, count, uv_img):
        """
        Sample a set of 3D points on the surface of the mesh; return the corresponding interpolated values in UV space.

        Args:
            count - num of 3D points to be sampled

            uv_img - the image in uv space to be sampled, e.g., texture
        """
        _mesh = Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.vi.detach().cpu().numpy(), process=False)
        points, _ = trimesh.sample.sample_surface(_mesh, count)
        return self.sample_uv_from_3dpts(points, uv_img)

    def sample_uv_from_3dpts(self, points, uv_img):
        num_pts = points.shape[0]
        approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points)
        interp_uv_coords = self.vt[interp_idxs, :]  # [N, 3, 2]
        # do bary interp first to get interp_uv_coord in high-reso uv space
        target_uv_coords = th.sum(interp_uv_coords * th.from_numpy(barys)[..., None], dim=1).float()
        # then directly sample from uv space
        sampled_values = sample_uv(values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords)  # [1, count, c]
        approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2])
        return approx_values.numpy(), points

    def vert_sample_uv(self, uv_img):
        count = self.v.shape[0]
        points = self.v.detach().cpu().numpy()
        approx_values, _ = self.sample_uv_from_3dpts(points, uv_img)
        return approx_values


def sample_uv(
    values_uv,
    uv_coords,
    v2uv: Optional[th.Tensor] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
    flip_uvs: bool = False,
):
    batch_size = values_uv.shape[0]

    if flip_uvs:
        uv_coords = uv_coords.clone()
        uv_coords[:, 1] = 1.0 - uv_coords[:, 1]

    # uv_coords_norm is [1, N, 1, 2] afterwards
    uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand(
        batch_size, -1, -1, -1
    )
    # uv_shape = values_uv.shape[-2:]
    # uv_max_shape_ind = uv_shape.index(max(uv_shape))
    # uv_min_shape_ind = uv_shape.index(min(uv_shape))
    # uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
    # uv_coords_norm[..., uv_min_shape_ind] *= uv_ratio

    values = (
        F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode)
        .squeeze(-1)
        .permute((0, 2, 1))
    )

    if v2uv is not None:
        values_duplicate = values[:, v2uv]
        values = values_duplicate.mean(2)

    return values


def values_to_uv(values, index_img, bary_img):
    uv_size = index_img.shape
    index_mask = th.all(index_img != -1, dim=-1)
    idxs_flat = index_img[index_mask].to(th.int64)
    bary_flat = bary_img[index_mask].to(th.float32)
    # NOTE: here we assume
    values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1)
    values_uv = th.zeros(
        values.shape[0],
        values.shape[-1],
        uv_size[0],
        uv_size[1],
        dtype=values.dtype,
        device=values.device,
    )
    values_uv[:, :, index_mask] = values_flat
    return values_uv


def face_normals(v, vi, eps: float = 1e-5):
    pts = v[:, vi]
    v0 = pts[:, :, 1] - pts[:, :, 0]
    v1 = pts[:, :, 2] - pts[:, :, 0]
    n = th.cross(v0, v1, dim=-1)
    norm = th.norm(n, dim=-1, keepdim=True)
    norm[norm < eps] = 1
    n /= norm
    return n


def vert_normals(v, vi, eps: float = 1.0e-5):
    fnorms = face_normals(v, vi)
    fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3)
    vi_flat = vi.view(1, -1).expand(v.shape[0], -1)
    vnorms = th.zeros_like(v)
    for j in range(3):
        vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j])
    norm = th.norm(vnorms, dim=-1, keepdim=True)
    norm[norm < eps] = 1
    vnorms /= norm
    return vnorms


def compute_view_cos(verts, faces, camera_pos):
    vn = F.normalize(vert_normals(verts, faces), dim=-1)
    v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1)
    return th.einsum("bnd,bnd->bn", vn, v2c)


def compute_tbn(geom, vt, vi, vti):
    """Computes tangent, bitangent, and normal vectors given a mesh.
    Args:
        geom: [N, n_verts, 3] th.Tensor
        Vertex positions.
        vt: [n_uv_coords, 2] th.Tensor
        UV coordinates.
        vi: [..., 3] th.Tensor
        Face vertex indices.
        vti: [..., 3] th.Tensor
        Face UV indices.
    Returns:
        [..., 3] th.Tensors for T, B, N.
    """

    v0 = geom[:, vi[..., 0]]
    v1 = geom[:, vi[..., 1]]
    v2 = geom[:, vi[..., 2]]
    vt0 = vt[vti[..., 0]]
    vt1 = vt[vti[..., 1]]
    vt2 = vt[vti[..., 2]]

    v01 = v1 - v0
    v02 = v2 - v0
    vt01 = vt1 - vt0
    vt02 = vt2 - vt0
    f = 1.0 / (
        vt01[None, ..., 0] * vt02[None, ..., 1]
        - vt01[None, ..., 1] * vt02[None, ..., 0]
    )
    tangent = f[..., None] * th.stack(
        [
            v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1],
            v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1],
            v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1],
        ],
        dim=-1,
    )
    tangent = F.normalize(tangent, dim=-1)
    normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1)
    bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1)

    return tangent, bitangent, normal


def compute_v2uv(n_verts, vi, vti, n_max=4):
    """Computes mapping from vertex indices to texture indices.

    Args:
        vi: [F, 3], triangles
        vti: [F, 3], texture triangles
        n_max: int, max number of texture locations

    Returns:
        [n_verts, n_max], texture indices
    """
    v2uv_dict = {}
    for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)):
        v2uv_dict.setdefault(i_v, set()).add(i_uv)
    assert len(v2uv_dict) == n_verts
    v2uv = np.zeros((n_verts, n_max), dtype=np.int32)
    for i in range(n_verts):
        vals = sorted(list(v2uv_dict[i]))
        v2uv[i, :] = vals[0]
        v2uv[i, : len(vals)] = np.array(vals)
    return v2uv


def compute_neighbours(n_verts, vi, n_max_values=10):
    """Computes first-ring neighbours given vertices and faces."""
    n_vi = vi.shape[0]

    adj = {i: set() for i in range(n_verts)}
    for i in range(n_vi):
        for idx in vi[i]:
            adj[idx] |= set(vi[i]) - set([idx])

    nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values))
    nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32)

    for idx in range(n_verts):
        n_values = min(len(adj[idx]), n_max_values)
        nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values]
        nbs_weights[idx, :n_values] = -1.0 / n_values

    return nbs_idxs, nbs_weights


def make_postex(v, idxim, barim):
    return (
        barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]]
        + barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]]
        + barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]]
    ).permute(0, 3, 1, 2)


def matrix_to_axisangle(r):
    # NOTE: the committed code assigned this angle to a local named `th`,
    # shadowing the torch module and breaking at runtime; renamed to `angle`.
    angle = th.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0))[..., None]
    vec = (
        0.5
        * th.stack(
            [
                r[..., 2, 1] - r[..., 1, 2],
                r[..., 0, 2] - r[..., 2, 0],
                r[..., 1, 0] - r[..., 0, 1],
            ],
            dim=-1,
        )
        / th.sin(angle)
    )
    return angle, vec


def axisangle_to_matrix(rvec):
    theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1))
    rvec = rvec / theta[..., None]
    costh = th.cos(theta)
    sinth = th.sin(theta)
    return th.stack(
        (
            th.stack(
                (
                    rvec[..., 0] ** 2 + (1.0 - rvec[..., 0] ** 2) * costh,
                    rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth,
                    rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth,
                ),
                dim=-1,
            ),
            th.stack(
                (
                    rvec[..., 0] * rvec[..., 1] * (1.0 - costh) + rvec[..., 2] * sinth,
                    rvec[..., 1] ** 2 + (1.0 - rvec[..., 1] ** 2) * costh,
                    rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth,
                ),
                dim=-1,
            ),
            th.stack(
                (
                    rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth,
                    rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth,
                    rvec[..., 2] ** 2 + (1.0 - rvec[..., 2] ** 2) * costh,
                ),
                dim=-1,
            ),
        ),
        dim=-2,
    )


def rotation_interp(r0, r1, alpha):
    r0a = r0.view(-1, 3, 3)
    r1a = r1.view(-1, 3, 3)
    r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)

    # NOTE: the committed code unpacked into a local named `th` here as well;
    # renamed to `angle` for the same reason as in matrix_to_axisangle.
    angle, rvec = matrix_to_axisangle(r)
    rvec = rvec * (alpha * angle)

    r = axisangle_to_matrix(rvec)
    return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0)


def convert_camera_parameters(Rt, K):
    R = Rt[:, :3, :3]
    t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2)
    return dict(
        campos=t,
        camrot=R,
        focal=K[:, :2, :2],
        princpt=K[:, :2, 2],
    )


def project_points_multi(p, Rt, K, normalize=False, size=None):
    """Project a set of 3D points into multiple cameras with a pinhole model.
    Args:
        p: [B, N, 3], input 3D points in world coordinates
        Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to)
        K: [B, NC, 3, 3], intrinsics
        normalize: bool, whether to normalize coordinates to [-1.0, 1.0]
    Returns:
        tuple:
        - [B, NC, N, 2] - projected points
        - [B, NC, N] - their depths
    """
    B, N = p.shape[:2]
    NC = Rt.shape[1]

    Rt = Rt.reshape(B * NC, 3, 4)
    K = K.reshape(B * NC, 3, 3)

    # [B, N, 3] -> [B * NC, N, 3]
    p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3)
    p_cam = p @ Rt[:, :3, :3].transpose(-2, -1) + Rt[:, :3, 3][:, np.newaxis]
    p_pix = p_cam @ K.transpose(-2, -1)
    p_depth = p_pix[:, :, 2:]
    p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2)
    p_depth = p_depth.reshape(B, NC, N)

    if normalize:
        assert size is not None
        h, w = size
        p_pix = (
            2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0
        )
    return p_pix, p_depth
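Note: a small sanity check for the barycentric helpers above; one right triangle and an on-surface query point, so the closest-point reconstruction must return the query itself. This sketch assumes trimesh (and pytorch3d, which dva.geom imports at module load) are installed; igl is optional:

import numpy as np
from dva.geom import closest_point_barycentrics

v = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # one triangle
vi = np.array([[0, 1, 2]])
query = np.array([[0.25, 0.25, 0.0]])  # lies inside the triangle

approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(v, vi, query)
assert np.allclose(approx, query)           # point is already on the mesh
assert np.allclose(barys.sum(axis=1), 1.0)  # weights are a partition of unity
assert face_idxs[0] == 0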
dva/io.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import cv2
import numpy as np
import copy
import importlib
from typing import Any, Dict

def load_module(module_name, class_name=None, silent: bool = False):
    module = importlib.import_module(module_name)
    return getattr(module, class_name) if class_name else module


def load_class(class_name):
    return load_module(*class_name.rsplit(".", 1))


def load_from_config(config, **kwargs):
    """Instantiate an object given a config and arguments."""
    assert "class_name" in config and "module_name" not in config
    config = copy.deepcopy(config)
    class_name = config.pop("class_name")
    object_class = load_class(class_name)
    return object_class(**config, **kwargs)


def load_opencv_calib(extrin_path, intrin_path):
    cameras = {}

    fse = cv2.FileStorage()
    fse.open(extrin_path, cv2.FileStorage_READ)

    fsi = cv2.FileStorage()
    fsi.open(intrin_path, cv2.FileStorage_READ)

    names = [
        fse.getNode("names").at(c).string() for c in range(fse.getNode("names").size())
    ]

    for camera in names:
        rot = fse.getNode(f"R_{camera}").mat()
        R = fse.getNode(f"Rot_{camera}").mat()
        T = fse.getNode(f"T_{camera}").mat()
        R_pred = cv2.Rodrigues(rot)[0]
        assert np.all(np.isclose(R_pred, R))
        K = fsi.getNode(f"K_{camera}").mat()
        cameras[camera] = {
            "Rt": np.concatenate([R, T], axis=1).astype(np.float32),
            "K": K.astype(np.float32),
        }
    return cameras
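Note: load_from_config is the factory used throughout app.py (generator, vae, conditioner, and the PrimSDF model are all built with it). A minimal sketch with a hypothetical config dict, showing that class_name is popped and everything else becomes constructor kwargs:

from dva.io import load_from_config

layer = load_from_config(
    {"class_name": "torch.nn.Linear", "in_features": 68, "out_features": 1152},
    bias=True,  # extra **kwargs are merged into the constructor call
)
print(layer)  # Linear(in_features=68, out_features=1152, bias=True)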
dva/layers.py
ADDED
@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch as th
import torch.nn as nn

import numpy as np

from dva.mvp.models.utils import Conv2dWN, Conv2dWNUB, ConvTranspose2dWNUB, initmod


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        size,
        lrelu_slope=0.2,
        kernel_size=3,
        padding=1,
        wnorm_dim=0,
    ):
        super().__init__()

        self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1)
        self.conv1 = Conv2dWNUB(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            height=size,
            width=size,
        )

        self.lrelu1 = nn.LeakyReLU(lrelu_slope)
        self.conv2 = Conv2dWNUB(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            height=size,
            width=size,
        )
        self.lrelu2 = nn.LeakyReLU(lrelu_slope)

    def forward(self, x):
        x_skip = self.conv_resize(x)
        x = self.conv1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.lrelu2(x)
        return x + x_skip


def tile2d(x, size: int):
    """Tile a given set of features into a convolutional map.

    Args:
        x: float tensor of shape [N, F]
        size: int or a tuple

    Returns:
        a feature map [N, F, size[0], size[1]]
    """
    # size = size if isinstance(size, tuple) else (size, size)
    # NOTE: expecting only int here (!!!)
    return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size)


def weights_initializer(m, alpha: float = 1.0):
    return initmod(m, nn.init.calculate_gain("leaky_relu", alpha))


class UNetWB(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        size,
        n_init_ftrs=8,
        out_scale=0.1,
    ):
        # super().__init__(*args, **kwargs)
        super().__init__()

        self.out_scale = 0.1

        F = n_init_ftrs

        # TODO: allow changing the size?
        self.size = size

        self.down1 = nn.Sequential(
            Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down2 = nn.Sequential(
            Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down3 = nn.Sequential(
            Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down4 = nn.Sequential(
            Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.down5 = nn.Sequential(
            Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up1 = nn.Sequential(
            ConvTranspose2dWNUB(
                16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
            ),
            nn.LeakyReLU(0.2),
        )
        self.up2 = nn.Sequential(
            ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up3 = nn.Sequential(
            ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up4 = nn.Sequential(
            ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.up5 = nn.Sequential(
            ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2)
        )
        self.out = Conv2dWNUB(
            F + in_channels, out_channels, self.size, self.size, kernel_size=1
        )
        self.apply(lambda x: initmod(x, 0.2))
        initmod(self.out, 1.0)

    def forward(self, x):
        x1 = x
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        # TODO: switch to concat?
        x = self.up1(x6) + x5
        x = self.up2(x) + x4
        x = self.up3(x) + x3
        x = self.up4(x) + x2
        x = self.up5(x)
        x = th.cat([x, x1], dim=1)
        return self.out(x) * self.out_scale
ADDED
@@ -0,0 +1,239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch as th
import numpy as np

import logging

from .vgg import VGGLossMasked

logger = logging.getLogger(f"dva.{__name__}")


class DCTLoss(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        loss_dict = {"loss_total": 0.0}
        target = inputs['gt']
        recon = preds['recon']
        posterior = preds['posterior']
        fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1)))
        fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1)))
        loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon))
        loss_recon_l1 = th.mean(th.abs(target - recon))
        loss_kl = posterior.kl().mean()
        loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl)
        # only the frequency-domain term enters the total; plain L1 is logged for monitoring
        loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class VAESepL2Loss(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        loss_dict = {"loss_total": 0.0}
        target = inputs['gt']
        recon = preds['recon']
        posterior = preds['posterior']
        # squared error, despite the *_l1 names kept for logging consistency
        recon_diff = (target - recon) ** 2
        loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
        loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
        loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
        loss_kl = posterior.kl().mean()
        loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
        loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
        if "kl" in self.weights:
            loss_total += self.weights.kl * loss_kl

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class VAESepLoss(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        loss_dict = {"loss_total": 0.0}
        target = inputs['gt']
        recon = preds['recon']
        posterior = preds['posterior']
        recon_diff = th.abs(target - recon)
        loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
        loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
        loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
        loss_kl = posterior.kl().mean()
        loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
        loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
        if "kl" in self.weights:
            loss_total += self.weights.kl * loss_kl

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class VAELoss(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, inputs, preds, iteration=None):
        loss_dict = {"loss_total": 0.0}
        target = inputs['gt']
        recon = preds['recon']
        posterior = preds['posterior']
        loss_recon_l1 = th.mean(th.abs(target - recon))
        loss_kl = posterior.kl().mean()
        loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl)
        loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class PrimSDFLoss(nn.Module):
    def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000):
        super().__init__()
        self.weights = weights
        self.shape_opt_steps = shape_opt_steps
        self.tex_opt_steps = tex_opt_steps

    def forward(self, inputs, preds, iteration=None):
        loss_dict = {"loss_total": 0.0}

        # staged schedule: geometry first, then appearance
        if iteration < self.shape_opt_steps:
            target_sdf = inputs['sdf']
            sdf = preds['sdf']
            loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf))
            loss_dict.update(loss_sdf_l1=loss_sdf_l1)
            loss_total = self.weights.sdf_l1 * loss_sdf_l1

            prim_scale = preds["prim_scale"]
            # we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube
            if "vol_sum" in self.weights:
                loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1))
                loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
                loss_total += self.weights.vol_sum * loss_prim_vol_sum

        if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps:
            target_tex = inputs['tex']
            tex = preds['tex']
            loss_tex_l1 = th.mean(th.abs(tex - target_tex))
            loss_dict.update(loss_tex_l1=loss_tex_l1)

            loss_total = self.weights.rgb_l1 * loss_tex_l1
            if "mat_l1" in self.weights:
                target_mat = inputs['mat']
                mat = preds['mat']
                loss_mat_l1 = th.mean(th.abs(mat - target_mat))
                loss_dict.update(loss_mat_l1=loss_mat_l1)
                loss_total += self.weights.mat_l1 * loss_mat_l1

        if "grad_l2" in self.weights:
            loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2)
            loss_total += self.weights.grad_l2 * loss_grad_l2
            loss_dict.update(loss_grad_l2=loss_grad_l2)

        loss_dict["loss_total"] = loss_total
        return loss_total, loss_dict


class TotalMVPLoss(nn.Module):
    def __init__(self, weights, assets=None):
        super().__init__()

        self.weights = weights

        if "vgg" in self.weights:
            self.vgg_loss = VGGLossMasked()

    def forward(self, inputs, preds, iteration=None):
        loss_dict = {"loss_total": 0.0}

        B = inputs["image"].shape[0]

        # rgb
        target_rgb = inputs["image"].permute(0, 2, 3, 1)
        # removing the mask
        target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis]

        rgb = preds["rgb"]
        loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0)
        loss_dict.update(loss_rgb_mse=loss_rgb_mse)

        alpha = preds["alpha"]

        # mask loss
        target_mask = inputs["image_mask"][:, 0].to(th.float32)
        loss_mask_mae = th.mean((target_mask - alpha).abs())
        loss_dict.update(loss_mask_mae=loss_mask_mae)

        B = alpha.shape[0]

        # beta prior on opacity; log(0.1) + log(1.1) ~= -2.20727 is the minimum of
        # the two log terms (reached at alpha = 0 or 1), so the loss is shifted to
        # be non-negative and pushes opacities towards fully on or fully off
        loss_alpha_prior = th.mean(
            th.log(0.1 + alpha.reshape(B, -1))
            + th.log(0.1 + 1.0 - alpha.reshape(B, -1))
            - -2.20727
        )
        loss_dict.update(loss_alpha_prior=loss_alpha_prior)

        prim_scale = preds["prim_scale"]
        loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1))
        loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)

        loss_total = (
            self.weights.rgb_mse * loss_rgb_mse
            + self.weights.mask_mae * loss_mask_mae
            + self.weights.alpha_prior * loss_alpha_prior
            + self.weights.prim_vol_sum * loss_prim_vol_sum
        )

        if "embs_l2" in self.weights:
            loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1))
            loss_total += self.weights.embs_l2 * loss_embs_l2
            loss_dict.update(loss_embs_l2=loss_embs_l2)

        if "vgg" in self.weights:
            loss_vgg = self.vgg_loss(
                rgb.permute(0, 3, 1, 2),
                target_rgb.permute(0, 3, 1, 2),
                inputs["image_mask"],
            )
            loss_total += self.weights.vgg * loss_vgg
            loss_dict.update(loss_vgg=loss_vgg)

        if "prim_scale_var" in self.weights:
            log_prim_scale = th.log(prim_scale)
            # NOTE: should we detach this?
            log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True)
            loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0)
            loss_total += self.weights.prim_scale_var * loss_prim_scale_var
            loss_dict.update(loss_prim_scale_var=loss_prim_scale_var)

        loss_dict["loss_total"] = loss_total

        return loss_total, loss_dict


def process_losses(loss_dict, reduce=True, detach=True):
    """Preprocess the dict of losses outputs."""
    result = {
        k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")
    }
    if detach:
        result = {k: v.detach() for k, v in result.items()}
    if reduce:
        result = {k: float(v.mean().item()) for k, v in result.items()}
    return result
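For orientation, a hypothetical usage sketch (not part of the commit), assuming the dva package and its .vgg dependency import cleanly: the weights object only needs attribute access plus membership tests, and the posterior only needs a .kl() method, so small stand-ins are enough to exercise VAESepL2Loss and process_losses.

import torch as th
from dva.losses import VAESepL2Loss, process_losses

class W(dict):                        # minimal attr-dict stand-in
    __getattr__ = dict.__getitem__

class StubPosterior:                  # stands in for the VAE posterior object
    def kl(self):
        return th.zeros(2)

loss_fn = VAESepL2Loss(W(sdf=1.0, rgb=0.5, mat=0.5))   # no "kl" key -> KL term skipped
gt = th.rand(2, 6, 32, 32)            # channel layout: [sdf | rgb | material]
recon = th.rand(2, 6, 32, 32, requires_grad=True)
total, ldict = loss_fn({"gt": gt}, {"recon": recon, "posterior": StubPosterior()})
total.backward()
print(process_losses(ldict))          # plain floats, keyed without the "loss_" prefix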
dva/mvp/extensions/mvpraymarch/bvh.cu
ADDED
@@ -0,0 +1,292 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <cmath>
#include <cstdio>
#include <functional>
#include <map>

#include "helper_math.h"

#include "cudadispatch.h"

#include "primtransf.h"

// Expands a 10-bit integer into 30 bits
// by inserting 2 zeros after each bit.
__device__ unsigned int expand_bits(unsigned int v) {
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}

// Calculates a 30-bit Morton code for the
// given 3D point located within the unit cube [0,1].
__device__ unsigned int morton3D(float x, float y, float z) {
    x = fminf(fmaxf(x * 1024.0f, 0.0f), 1023.0f);
    y = fminf(fmaxf(y * 1024.0f, 0.0f), 1023.0f);
    z = fminf(fmaxf(z * 1024.0f, 0.0f), 1023.0f);
    unsigned int xx = expand_bits((unsigned int)x);
    unsigned int yy = expand_bits((unsigned int)y);
    unsigned int zz = expand_bits((unsigned int)z);
    return xx * 4 + yy * 2 + zz;
}

template<typename PrimTransfT>
__global__ void compute_morton_kernel(
        int N, int K,
        typename PrimTransfT::Data data,
        int * code) {
    const int count = N * K;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
        const int k = index % K;
        const int n = index / K;

        //float4 c = center[n * K + k];
        float3 c = data.get_center(n, k);
        code[n * K + k] = morton3D(c.x, c.y, c.z);
    }
}

// length of the common prefix of the Morton codes at positions x and y, with
// the index appended as a tiebreaker for equal codes; -1 when out of range
__forceinline__ __device__ int delta(int* sortedcodes, int x, int y, int K) {
    if (x >= 0 && x <= K - 1 && y >= 0 && y <= K - 1) {
        return sortedcodes[x] == sortedcodes[y] ?
            32 + __clz(x ^ y) :
            __clz(sortedcodes[x] ^ sortedcodes[y]);
    }
    return -1;
}

__forceinline__ __device__ int sign(int x) {
    return (int)(x > 0) - (int)(x < 0);
}

__device__ int find_split(
        int* sortedcodes,
        int first,
        int last,
        int K) {
    float commonPrefix = delta(sortedcodes, first, last, K);
    int split = first;
    int step = last - first;

    do {
        step = (step + 1) >> 1; // exponential decrease
        int newSplit = split + step; // proposed new position

        if (newSplit < last) {
            int splitPrefix = delta(sortedcodes, first, newSplit, K);
            if (splitPrefix > commonPrefix) {
                split = newSplit; // accept proposal
            }
        }
    } while (step > 1);

    return split;
}

__device__ int2 determine_range(int* sortedcodes, int K, int idx) {
    int d = sign(delta(sortedcodes, idx, idx + 1, K) - delta(sortedcodes, idx, idx - 1, K));
    int dmin = delta(sortedcodes, idx, idx - d, K);
    int lmax = 2;
    while (delta(sortedcodes, idx, idx + lmax * d, K) > dmin) {
        lmax = lmax * 2;
    }

    int l = 0;
    for (int t = lmax / 2; t >= 1; t /= 2) {
        if (delta(sortedcodes, idx, idx + (l + t) * d, K) > dmin) {
            l += t;
        }
    }

    int j = idx + l * d;
    int2 range;
    range.x = min(idx, j);
    range.y = max(idx, j);

    return range;
}

__global__ void build_tree_kernel(
        int N, int K,
        int * sortedcodes,
        int2 * nodechildren,
        int * nodeparent) {
    const int count = N * (K + K - 1);
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
        const int k = index % (K + K - 1);
        const int n = index / (K + K - 1);

        if (k >= K - 1) {
            // leaf
            nodechildren[n * (K + K - 1) + k] = make_int2(-(k - (K - 1)) - 1, -(k - (K - 1)) - 2);
        } else {
            // internal node

            // find out which range of objects the node corresponds to
            int2 range = determine_range(sortedcodes + n * K, K, k);
            int first = range.x;
            int last = range.y;

            // determine where to split the range
            int split = find_split(sortedcodes + n * K, first, last, K);

            // select childA
            int childa = split == first ? (K - 1) + split : split;

            // select childB
            int childb = split + 1 == last ? (K - 1) + split + 1 : split + 1;

            // record parent-child relationships
            nodechildren[n * (K + K - 1) + k] = make_int2(childa, childb);
            nodeparent[n * (K + K - 1) + childa] = k;
            nodeparent[n * (K + K - 1) + childb] = k;
        }
    }
}

template<typename PrimTransfT>
__global__ void compute_aabb_kernel(
        int N, int K,
        typename PrimTransfT::Data data,
        int * sortedobjid,
        int2 * nodechildren,
        int * nodeparent,
        float3 * nodeaabb,
        int * atom) {
    const int count = N * K;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
        const int k = index % K;
        const int n = index / K;

        // compute BBOX for leaf
        int kk = sortedobjid[n * K + k];

        float3 pmin;
        float3 pmax;
        data.compute_aabb(n, kk, pmin, pmax);

        nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 0] = pmin;
        nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 1] = pmax;

        int node = nodeparent[n * (K + K - 1) + ((K - 1) + k)];

        // bottom-up refit: atomicCAS returns the old flag, so the first thread
        // to reach an internal node stops here, and the second thread (whose
        // other child box is then ready) merges both children and moves upward
        while (node != -1 && atomicCAS(&atom[n * (K - 1) + node], 0, 1) == 1) {
            int2 children = nodechildren[n * (K + K - 1) + node];
            float3 laabbmin = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 0];
            float3 laabbmax = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 1];
            float3 raabbmin = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 0];
            float3 raabbmax = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 1];

            float3 aabbmin = fminf(laabbmin, raabbmin);
            float3 aabbmax = fmaxf(laabbmax, raabbmax);

            nodeaabb[n * (K + K - 1) * 2 + node * 2 + 0] = aabbmin;
            nodeaabb[n * (K + K - 1) * 2 + node * 2 + 1] = aabbmax;

            node = nodeparent[n * (K + K - 1) + node];

            __threadfence();
        }
    }
}

void compute_morton_cuda(
        int N, int K,
        float * primpos,
        int * code,
        int algorithm,
        cudaStream_t stream) {
    int count = N * K;
    int blocksize = 512;
    int gridsize = (count + blocksize - 1) / blocksize;

    std::shared_ptr<PrimTransfDataBase> primtransf_data;
    primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
            PrimTransfDataBase{},
            K, (float3*)primpos, nullptr,
            K * 3, nullptr, nullptr,
            K, nullptr, nullptr});

    std::map<int, std::function<void(dim3, dim3, cudaStream_t, int, int, std::shared_ptr<PrimTransfDataBase>, int*)>> dispatcher = {
        { 0, make_cudacall(compute_morton_kernel<PrimTransfSRT>) }
    };

    auto iter = dispatcher.find(min(0, algorithm));
    if (iter != dispatcher.end()) {
        (iter->second)(
            dim3(gridsize), dim3(blocksize), stream,
            N, K,
            primtransf_data,
            code);
    }
}

void build_tree_cuda(
        int N, int K,
        int * sortedcode,
        int * nodechildren,
        int * nodeparent,
        cudaStream_t stream) {
    int count = N * (K + K - 1);
    int nthreads = 512;
    int nblocks = (count + nthreads - 1) / nthreads;
    build_tree_kernel<<<nblocks, nthreads, 0, stream>>>(
        N, K,
        sortedcode,
        reinterpret_cast<int2 *>(nodechildren),
        nodeparent);
}

void compute_aabb_cuda(
        int N, int K,
        float * primpos,
        float * primrot,
        float * primscale,
        int * sortedobjid,
        int * nodechildren,
        int * nodeparent,
        float * nodeaabb,
        int algorithm,
        cudaStream_t stream) {
    // one ready-flag per internal node (4 == sizeof(int))
    int * atom;
    cudaMalloc(&atom, N * (K - 1) * 4);
    cudaMemset(atom, 0, N * (K - 1) * 4);

    int count = N * K;
    int blocksize = 512;
    int gridsize = (count + blocksize - 1) / blocksize;

    std::shared_ptr<PrimTransfDataBase> primtransf_data;
    primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
            PrimTransfDataBase{},
            K, (float3*)primpos, nullptr,
            K * 3, (float3*)primrot, nullptr,
            K, (float3*)primscale, nullptr});

    std::map<int, std::function<void(dim3, dim3, cudaStream_t, int, int, std::shared_ptr<PrimTransfDataBase>, int*, int2*, int*, float3*, int*)>> dispatcher = {
        { 0, make_cudacall(compute_aabb_kernel<PrimTransfSRT>) }
    };

    auto iter = dispatcher.find(min(0, algorithm));
    if (iter != dispatcher.end()) {
        (iter->second)(
            dim3(gridsize), dim3(blocksize), stream,
            N, K,
            primtransf_data,
            sortedobjid,
            reinterpret_cast<int2 *>(nodechildren),
            nodeparent,
            reinterpret_cast<float3 *>(nodeaabb),
            atom);
    }

    cudaFree(atom);
}
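Together these kernels follow the classic Karras-style linear-BVH recipe: Morton-code the primitive centers (compute_morton_cuda), sort the codes, build the radix tree over them (build_tree_cuda), then refit AABBs bottom-up (compute_aabb_cuda). As a sanity reference (not part of the commit), here is the bit-interleaving trick from expand_bits/morton3D in plain Python; the 32-bit masks make Python's arbitrary-precision arithmetic agree with the CUDA version:

def expand_bits(v: int) -> int:
    # spread the low 10 bits of v so that two zero bits follow each original bit
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3d(x: float, y: float, z: float) -> int:
    # quantize each coordinate in [0, 1] to 10 bits, then interleave as xyzxyz...
    q = lambda t: min(max(int(t * 1024.0), 0), 1023)
    return expand_bits(q(x)) * 4 + expand_bits(q(y)) * 2 + expand_bits(q(z))

assert expand_bits(0b1111111111) == 0b1001001001001001001001001001
print(f"{morton3d(0.5, 0.25, 0.75):030b}")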
dva/mvp/extensions/mvpraymarch/cudadispatch.h
ADDED
@@ -0,0 +1,104 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#ifndef cudadispatch_h_
#define cudadispatch_h_

#include <functional>
#include <memory>
#include <type_traits>

// get_base<T> maps a kernel argument type to the type used in the type-erased
// wrapper signature: if T derives from its nested T::base typedef, it is
// stored as shared_ptr<T::base>; otherwise it passes through unchanged.
template<typename T, typename = void>
struct get_base {
    typedef T type;
};

template<typename T>
struct get_base<T, typename std::enable_if<std::is_base_of<typename T::base, T>::value>::type> {
    typedef std::shared_ptr<typename T::base> type;
};

template<typename T> struct is_shared_ptr : std::false_type {};
template<typename T> struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {};

// convert_shptr<OutT> recovers the concrete argument: a shared_ptr<base> is
// downcast and dereferenced, anything already of the target type is forwarded.
template<typename OutT, typename T>
auto convert_shptr_impl2(std::shared_ptr<T> t) {
    return *static_cast<OutT*>(t.get());
}

template<typename OutT, typename T>
auto convert_shptr_impl(T&& t, std::false_type) {
    return convert_shptr_impl2<OutT>(t);
}

template<typename OutT, typename T>
auto convert_shptr_impl(T&& t, std::true_type) {
    return std::forward<T>(t);
}

template<typename OutT, typename T>
auto convert_shptr(T&& t) {
    return convert_shptr_impl<OutT>(std::forward<T>(t), std::is_same<OutT, T>{});
}

// Type-erased holder for a kernel pointer: call() restores the concrete
// argument types (ArgsOut) from the erased ones (ArgsIn) and launches.
template<typename... ArgsIn>
struct cudacall {
    struct functbase {
        virtual ~functbase() {}
        virtual void call(dim3, dim3, cudaStream_t, ArgsIn...) const = 0;
    };

    template<typename... ArgsOut>
    struct funct : public functbase {
        std::function<void(ArgsOut...)> fn;
        funct(void(*fn_)(ArgsOut...)) : fn(fn_) { }
        void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsIn... args) const {
            void (*const*kfunc)(ArgsOut...) = fn.template target<void (*)(ArgsOut...)>();
            (*kfunc)<<<gridsize, blocksize, 0, stream>>>(
                    std::forward<ArgsOut>(convert_shptr<ArgsOut>(std::forward<ArgsIn>(args)))...);
        }
    };

    std::shared_ptr<functbase> fn;

    template<typename... ArgsOut>
    cudacall(void(*fn_)(ArgsOut...)) : fn(std::make_shared<funct<ArgsOut...>>(fn_)) { }

    template<typename... ArgsTmp>
    void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsTmp&&... args) const {
        fn->call(gridsize, blocksize, stream, std::forward<ArgsIn>(args)...);
    }
};

template <typename F, typename T>
struct binder {
    F f; T t;
    template <typename... Args>
    auto operator()(Args&&... args) const
            -> decltype(f(t, std::forward<Args>(args)...)) {
        return f(t, std::forward<Args>(args)...);
    }
};

template <typename F, typename T>
binder<typename std::decay<F>::type,
       typename std::decay<T>::type> BindFirst(F&& f, T&& t) {
    return { std::forward<F>(f), std::forward<T>(t) };
}

// make_cudacall wraps a kernel pointer into a std::function whose parameter
// list uses the erased get_base<...>::type types, so heterogeneous kernel
// instantiations fit into a single homogeneous dispatcher map.
template<typename... ArgsOut>
auto make_cudacall_(void(*fn)(ArgsOut...)) {
    return BindFirst(
            std::mem_fn(&cudacall<typename get_base<ArgsOut>::type...>::template call<typename get_base<ArgsOut>::type...>),
            cudacall<typename get_base<ArgsOut>::type...>(fn));
}

template<typename... ArgsOut>
std::function<void(dim3, dim3, cudaStream_t, typename get_base<ArgsOut>::type...)> make_cudacall(void(*fn)(ArgsOut...)) {
    return std::function<void(dim3, dim3, cudaStream_t, typename get_base<ArgsOut>::type...)>(make_cudacall_(fn));
}

#endif
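cudadispatch.h exists so the host wrappers in bvh.cu can keep kernels with different concrete Data types in one std::map<int, std::function<...>>. A loose Python analogy of the same integer-keyed, type-erased dispatch (hypothetical names, not part of the commit; duck typing collapses the whole template machinery into a dict of callables):

class PrimTransfDataBase:              # erased base type
    pass

class SRTData(PrimTransfDataBase):     # hypothetical concrete payload
    def __init__(self, centers):
        self.centers = centers         # one (x, y, z) center per primitive

def morton_kernel(data, codes):
    # stand-in "kernel": any deterministic per-primitive code works for the demo
    for k, (x, y, z) in enumerate(data.centers):
        codes[k] = (int(x * 1024) << 20) | (int(y * 1024) << 10) | int(z * 1024)

dispatcher = {0: morton_kernel}        # keyed by algorithm id, as in compute_morton_cuda
algorithm = 0
codes = [0, 0]
dispatcher[min(0, algorithm)](SRTData([(0.1, 0.2, 0.3), (0.9, 0.8, 0.7)]), codes)
print(codes)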
dva/mvp/extensions/mvpraymarch/helper_math.h
ADDED
@@ -0,0 +1,1453 @@
/**
 * Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
 *
 * Please refer to the NVIDIA end user license agreement (EULA) associated
 * with this source code for terms and conditions that govern your use of
 * this software. Any use, reproduction, disclosure, or distribution of
 * this software and related documentation outside the terms of the EULA
 * is strictly prohibited.
 *
 */

/*
 * This file implements common mathematical operations on vector types
 * (float3, float4 etc.) since these are not provided as standard by CUDA.
 *
 * The syntax is modeled on the Cg standard library.
 *
 * This is part of the Helper library includes
 *
 * Thanks to Linh Hah for additions and fixes.
 */

#ifndef HELPER_MATH_H
#define HELPER_MATH_H

#include "cuda_runtime.h"

typedef unsigned int uint;
typedef unsigned short ushort;

#ifndef EXIT_WAIVED
#define EXIT_WAIVED 2
#endif

#ifndef __CUDACC__
#include <math.h>

////////////////////////////////////////////////////////////////////////////////
// host implementations of CUDA functions
////////////////////////////////////////////////////////////////////////////////

inline float fminf(float a, float b)
{
    return a < b ? a : b;
}

inline float fmaxf(float a, float b)
{
    return a > b ? a : b;
}

inline int max(int a, int b)
{
    return a > b ? a : b;
}

inline int min(int a, int b)
{
    return a < b ? a : b;
}

inline float rsqrtf(float x)
{
    return 1.0f / sqrtf(x);
}
#endif

////////////////////////////////////////////////////////////////////////////////
// constructors
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 make_float2(float s)
{
    return make_float2(s, s);
}
inline __host__ __device__ float2 make_float2(float3 a)
{
    return make_float2(a.x, a.y);
}
inline __host__ __device__ float2 make_float2(int2 a)
{
    return make_float2(float(a.x), float(a.y));
}
inline __host__ __device__ float2 make_float2(uint2 a)
{
    return make_float2(float(a.x), float(a.y));
}

inline __host__ __device__ int2 make_int2(int s)
{
    return make_int2(s, s);
}
inline __host__ __device__ int2 make_int2(int3 a)
{
    return make_int2(a.x, a.y);
}
inline __host__ __device__ int2 make_int2(uint2 a)
{
    return make_int2(int(a.x), int(a.y));
}
inline __host__ __device__ int2 make_int2(float2 a)
{
    return make_int2(int(a.x), int(a.y));
}

inline __host__ __device__ uint2 make_uint2(uint s)
{
    return make_uint2(s, s);
}
inline __host__ __device__ uint2 make_uint2(uint3 a)
{
    return make_uint2(a.x, a.y);
}
inline __host__ __device__ uint2 make_uint2(int2 a)
{
    return make_uint2(uint(a.x), uint(a.y));
}

inline __host__ __device__ float3 make_float3(float s)
{
    return make_float3(s, s, s);
}
inline __host__ __device__ float3 make_float3(float2 a)
{
    return make_float3(a.x, a.y, 0.0f);
}
inline __host__ __device__ float3 make_float3(float2 a, float s)
{
    return make_float3(a.x, a.y, s);
}
inline __host__ __device__ float3 make_float3(float4 a)
{
    return make_float3(a.x, a.y, a.z);
}
inline __host__ __device__ float3 make_float3(int3 a)
{
    return make_float3(float(a.x), float(a.y), float(a.z));
}
inline __host__ __device__ float3 make_float3(uint3 a)
{
    return make_float3(float(a.x), float(a.y), float(a.z));
}

inline __host__ __device__ int3 make_int3(int s)
{
    return make_int3(s, s, s);
}
inline __host__ __device__ int3 make_int3(int2 a)
{
    return make_int3(a.x, a.y, 0);
}
inline __host__ __device__ int3 make_int3(int2 a, int s)
{
    return make_int3(a.x, a.y, s);
}
inline __host__ __device__ int3 make_int3(uint3 a)
{
    return make_int3(int(a.x), int(a.y), int(a.z));
}
inline __host__ __device__ int3 make_int3(float3 a)
{
    return make_int3(int(a.x), int(a.y), int(a.z));
}

inline __host__ __device__ uint3 make_uint3(uint s)
{
    return make_uint3(s, s, s);
}
inline __host__ __device__ uint3 make_uint3(uint2 a)
{
    return make_uint3(a.x, a.y, 0);
}
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
{
    return make_uint3(a.x, a.y, s);
}
inline __host__ __device__ uint3 make_uint3(uint4 a)
{
    return make_uint3(a.x, a.y, a.z);
}
inline __host__ __device__ uint3 make_uint3(int3 a)
{
    return make_uint3(uint(a.x), uint(a.y), uint(a.z));
}

inline __host__ __device__ float4 make_float4(float s)
{
    return make_float4(s, s, s, s);
}
inline __host__ __device__ float4 make_float4(float3 a)
{
    return make_float4(a.x, a.y, a.z, 0.0f);
}
inline __host__ __device__ float4 make_float4(float3 a, float w)
{
    return make_float4(a.x, a.y, a.z, w);
}
inline __host__ __device__ float4 make_float4(int4 a)
{
    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}
inline __host__ __device__ float4 make_float4(uint4 a)
{
    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}

inline __host__ __device__ int4 make_int4(int s)
{
    return make_int4(s, s, s, s);
}
inline __host__ __device__ int4 make_int4(int3 a)
{
    return make_int4(a.x, a.y, a.z, 0);
}
inline __host__ __device__ int4 make_int4(int3 a, int w)
{
    return make_int4(a.x, a.y, a.z, w);
}
inline __host__ __device__ int4 make_int4(uint4 a)
{
    return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}
inline __host__ __device__ int4 make_int4(float4 a)
{
    return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}


inline __host__ __device__ uint4 make_uint4(uint s)
{
    return make_uint4(s, s, s, s);
}
inline __host__ __device__ uint4 make_uint4(uint3 a)
{
    return make_uint4(a.x, a.y, a.z, 0);
}
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
{
    return make_uint4(a.x, a.y, a.z, w);
}
inline __host__ __device__ uint4 make_uint4(int4 a)
{
    return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
}

////////////////////////////////////////////////////////////////////////////////
// negate
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator-(float2 &a)
{
    return make_float2(-a.x, -a.y);
}
inline __host__ __device__ int2 operator-(int2 &a)
{
    return make_int2(-a.x, -a.y);
}
inline __host__ __device__ float3 operator-(float3 &a)
{
    return make_float3(-a.x, -a.y, -a.z);
}
inline __host__ __device__ int3 operator-(int3 &a)
{
    return make_int3(-a.x, -a.y, -a.z);
}
inline __host__ __device__ float4 operator-(float4 &a)
{
    return make_float4(-a.x, -a.y, -a.z, -a.w);
}
inline __host__ __device__ int4 operator-(int4 &a)
{
    return make_int4(-a.x, -a.y, -a.z, -a.w);
}

////////////////////////////////////////////////////////////////////////////////
// addition
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator+(float2 a, float2 b)
{
    return make_float2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(float2 &a, float2 b)
{
    a.x += b.x;
    a.y += b.y;
}
inline __host__ __device__ float2 operator+(float2 a, float b)
{
    return make_float2(a.x + b, a.y + b);
}
inline __host__ __device__ float2 operator+(float b, float2 a)
{
    return make_float2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(float2 &a, float b)
{
    a.x += b;
    a.y += b;
}

inline __host__ __device__ int2 operator+(int2 a, int2 b)
{
    return make_int2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(int2 &a, int2 b)
{
    a.x += b.x;
    a.y += b.y;
}
inline __host__ __device__ int2 operator+(int2 a, int b)
{
    return make_int2(a.x + b, a.y + b);
}
inline __host__ __device__ int2 operator+(int b, int2 a)
{
    return make_int2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(int2 &a, int b)
{
    a.x += b;
    a.y += b;
}

inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
{
    return make_uint2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
{
    a.x += b.x;
    a.y += b.y;
}
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
{
    return make_uint2(a.x + b, a.y + b);
}
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
{
    return make_uint2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(uint2 &a, uint b)
{
    a.x += b;
    a.y += b;
}


inline __host__ __device__ float3 operator+(float3 a, float3 b)
{
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(float3 &a, float3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
inline __host__ __device__ float3 operator+(float3 a, float b)
{
    return make_float3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(float3 &a, float b)
{
    a.x += b;
    a.y += b;
    a.z += b;
}

inline __host__ __device__ int3 operator+(int3 a, int3 b)
{
    return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(int3 &a, int3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
inline __host__ __device__ int3 operator+(int3 a, int b)
{
    return make_int3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(int3 &a, int b)
{
    a.x += b;
    a.y += b;
    a.z += b;
}

inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
{
    return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
{
    return make_uint3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(uint3 &a, uint b)
{
    a.x += b;
    a.y += b;
    a.z += b;
}

inline __host__ __device__ int3 operator+(int b, int3 a)
{
    return make_int3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
{
    return make_uint3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ float3 operator+(float b, float3 a)
{
    return make_float3(a.x + b, a.y + b, a.z + b);
}

inline __host__ __device__ float4 operator+(float4 a, float4 b)
{
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(float4 &a, float4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
inline __host__ __device__ float4 operator+(float4 a, float b)
{
    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ float4 operator+(float b, float4 a)
{
    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(float4 &a, float b)
{
    a.x += b;
    a.y += b;
    a.z += b;
    a.w += b;
}

inline __host__ __device__ int4 operator+(int4 a, int4 b)
{
    return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(int4 &a, int4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
inline __host__ __device__ int4 operator+(int4 a, int b)
{
    return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ int4 operator+(int b, int4 a)
{
    return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(int4 &a, int b)
{
    a.x += b;
    a.y += b;
    a.z += b;
    a.w += b;
}

inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
{
    return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
{
    return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
{
    return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(uint4 &a, uint b)
{
    a.x += b;
    a.y += b;
    a.z += b;
    a.w += b;
}

////////////////////////////////////////////////////////////////////////////////
// subtract
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator-(float2 a, float2 b)
{
    return make_float2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(float2 &a, float2 b)
{
    a.x -= b.x;
    a.y -= b.y;
}
inline __host__ __device__ float2 operator-(float2 a, float b)
{
    return make_float2(a.x - b, a.y - b);
}
inline __host__ __device__ float2 operator-(float b, float2 a)
{
    return make_float2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(float2 &a, float b)
{
    a.x -= b;
    a.y -= b;
}

inline __host__ __device__ int2 operator-(int2 a, int2 b)
{
    return make_int2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(int2 &a, int2 b)
{
    a.x -= b.x;
    a.y -= b.y;
}
inline __host__ __device__ int2 operator-(int2 a, int b)
{
    return make_int2(a.x - b, a.y - b);
}
inline __host__ __device__ int2 operator-(int b, int2 a)
{
    return make_int2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(int2 &a, int b)
{
    a.x -= b;
    a.y -= b;
}

inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
{
    return make_uint2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
{
    a.x -= b.x;
    a.y -= b.y;
}
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
{
    return make_uint2(a.x - b, a.y - b);
}
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
{
    return make_uint2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(uint2 &a, uint b)
{
    a.x -= b;
    a.y -= b;
}

inline __host__ __device__ float3 operator-(float3 a, float3 b)
{
    return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(float3 &a, float3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
inline __host__ __device__ float3 operator-(float3 a, float b)
{
    return make_float3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ float3 operator-(float b, float3 a)
{
    return make_float3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(float3 &a, float b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
}

inline __host__ __device__ int3 operator-(int3 a, int3 b)
{
    return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(int3 &a, int3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
inline __host__ __device__ int3 operator-(int3 a, int b)
{
    return make_int3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ int3 operator-(int b, int3 a)
{
    return make_int3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(int3 &a, int b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
}

inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
{
    return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
{
    return make_uint3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
{
    return make_uint3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
}

inline __host__ __device__ float4 operator-(float4 a, float4 b)
{
    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(float4 &a, float4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
inline __host__ __device__ float4 operator-(float4 a, float b)
{
    return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ void operator-=(float4 &a, float b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
    a.w -= b;
}

inline __host__ __device__ int4 operator-(int4 a, int4 b)
{
    return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(int4 &a, int4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
inline __host__ __device__ int4 operator-(int4 a, int b)
{
    return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ int4 operator-(int b, int4 a)
{
    return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
}
inline __host__ __device__ void operator-=(int4 &a, int b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
    a.w -= b;
}

inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
{
    return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
{
    return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
{
    return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
    a.w -= b;
}

////////////////////////////////////////////////////////////////////////////////
// multiply
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator*(float2 a, float2 b)
{
    return make_float2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ void operator*=(float2 &a, float2 b)
{
    a.x *= b.x;
    a.y *= b.y;
}
inline __host__ __device__ float2 operator*(float2 a, float b)
{
    return make_float2(a.x * b, a.y * b);
}
inline __host__ __device__ float2 operator*(float b, float2 a)
{
    return make_float2(b * a.x, b * a.y);
}
inline __host__ __device__ void operator*=(float2 &a, float b)
{
    a.x *= b;
    a.y *= b;
}

inline __host__ __device__ int2 operator*(int2 a, int2 b)
{
    return make_int2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ void operator*=(int2 &a, int2 b)
{
    a.x *= b.x;
    a.y *= b.y;
}
inline __host__ __device__ int2 operator*(int2 a, int b)
{
    return make_int2(a.x * b, a.y * b);
}
inline __host__ __device__ int2 operator*(int b, int2 a)
{
    return make_int2(b * a.x, b * a.y);
}
inline __host__ __device__ void operator*=(int2 &a, int b)
{
    a.x *= b;
    a.y *= b;
}

inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
{
    return make_uint2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
{
    a.x *= b.x;
    a.y *= b.y;
}
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
{
    return make_uint2(a.x * b, a.y * b);
}
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
{
    return make_uint2(b * a.x, b * a.y);
}
inline __host__ __device__ void operator*=(uint2 &a, uint b)
{
    a.x *= b;
    a.y *= b;
}

inline __host__ __device__ float3 operator*(float3 a, float3 b)
{
    return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ void operator*=(float3 &a, float3 b)
{
    a.x *= b.x;
    a.y *= b.y;
    a.z *= b.z;
}
inline __host__ __device__ float3 operator*(float3 a, float b)
{
    return make_float3(a.x * b, a.y * b, a.z * b);
}
inline __host__ __device__ float3 operator*(float b, float3 a)
{
    return make_float3(b * a.x, b * a.y, b * a.z);
}
inline __host__ __device__ void operator*=(float3 &a, float b)
{
    a.x *= b;
    a.y *= b;
    a.z *= b;
}

inline __host__ __device__ int3 operator*(int3 a, int3 b)
{
    return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ void operator*=(int3 &a, int3 b)
{
    a.x *= b.x;
    a.y *= b.y;
    a.z *= b.z;
}
inline __host__ __device__ int3 operator*(int3 a, int b)
{
    return make_int3(a.x * b, a.y * b, a.z * b);
}
inline __host__ __device__ int3 operator*(int b, int3 a)
{
    return make_int3(b * a.x, b * a.y, b * a.z);
}
inline __host__ __device__ void operator*=(int3 &a, int b)
{
    a.x *= b;
    a.y *= b;
    a.z *= b;
}

inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
{
    return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
{
    a.x *= b.x;
    a.y *= b.y;
    a.z *= b.z;
}
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
{
    return make_uint3(a.x * b, a.y * b, a.z * b);
}
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
{
    return make_uint3(b * a.x, b * a.y, b * a.z);
}
inline __host__ __device__ void operator*=(uint3 &a, uint b)
{
    a.x *= b;
    a.y *= b;
    a.z *= b;
}

inline __host__ __device__ float4 operator*(float4 a, float4 b)
{
    return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
882 |
+
}
|
883 |
+
inline __host__ __device__ void operator*=(float4 &a, float4 b)
|
884 |
+
{
|
885 |
+
a.x *= b.x;
|
886 |
+
a.y *= b.y;
|
887 |
+
a.z *= b.z;
|
888 |
+
a.w *= b.w;
|
889 |
+
}
|
890 |
+
inline __host__ __device__ float4 operator*(float4 a, float b)
|
891 |
+
{
|
892 |
+
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
|
893 |
+
}
|
894 |
+
inline __host__ __device__ float4 operator*(float b, float4 a)
|
895 |
+
{
|
896 |
+
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
|
897 |
+
}
|
898 |
+
inline __host__ __device__ void operator*=(float4 &a, float b)
|
899 |
+
{
|
900 |
+
a.x *= b;
|
901 |
+
a.y *= b;
|
902 |
+
a.z *= b;
|
903 |
+
a.w *= b;
|
904 |
+
}
|
905 |
+
|
906 |
+
inline __host__ __device__ int4 operator*(int4 a, int4 b)
|
907 |
+
{
|
908 |
+
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
909 |
+
}
|
910 |
+
inline __host__ __device__ void operator*=(int4 &a, int4 b)
|
911 |
+
{
|
912 |
+
a.x *= b.x;
|
913 |
+
a.y *= b.y;
|
914 |
+
a.z *= b.z;
|
915 |
+
a.w *= b.w;
|
916 |
+
}
|
917 |
+
inline __host__ __device__ int4 operator*(int4 a, int b)
|
918 |
+
{
|
919 |
+
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
|
920 |
+
}
|
921 |
+
inline __host__ __device__ int4 operator*(int b, int4 a)
|
922 |
+
{
|
923 |
+
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
|
924 |
+
}
|
925 |
+
inline __host__ __device__ void operator*=(int4 &a, int b)
|
926 |
+
{
|
927 |
+
a.x *= b;
|
928 |
+
a.y *= b;
|
929 |
+
a.z *= b;
|
930 |
+
a.w *= b;
|
931 |
+
}
|
932 |
+
|
933 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
|
934 |
+
{
|
935 |
+
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
936 |
+
}
|
937 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
|
938 |
+
{
|
939 |
+
a.x *= b.x;
|
940 |
+
a.y *= b.y;
|
941 |
+
a.z *= b.z;
|
942 |
+
a.w *= b.w;
|
943 |
+
}
|
944 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
|
945 |
+
{
|
946 |
+
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
|
947 |
+
}
|
948 |
+
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
|
949 |
+
{
|
950 |
+
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
|
951 |
+
}
|
952 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint b)
|
953 |
+
{
|
954 |
+
a.x *= b;
|
955 |
+
a.y *= b;
|
956 |
+
a.z *= b;
|
957 |
+
a.w *= b;
|
958 |
+
}
|
959 |
+
|
960 |
+
////////////////////////////////////////////////////////////////////////////////
|
961 |
+
// divide
|
962 |
+
////////////////////////////////////////////////////////////////////////////////
|
963 |
+
|
964 |
+
inline __host__ __device__ float2 operator/(float2 a, float2 b)
|
965 |
+
{
|
966 |
+
return make_float2(a.x / b.x, a.y / b.y);
|
967 |
+
}
|
968 |
+
inline __host__ __device__ void operator/=(float2 &a, float2 b)
|
969 |
+
{
|
970 |
+
a.x /= b.x;
|
971 |
+
a.y /= b.y;
|
972 |
+
}
|
973 |
+
inline __host__ __device__ float2 operator/(float2 a, float b)
|
974 |
+
{
|
975 |
+
return make_float2(a.x / b, a.y / b);
|
976 |
+
}
|
977 |
+
inline __host__ __device__ void operator/=(float2 &a, float b)
|
978 |
+
{
|
979 |
+
a.x /= b;
|
980 |
+
a.y /= b;
|
981 |
+
}
|
982 |
+
inline __host__ __device__ float2 operator/(float b, float2 a)
|
983 |
+
{
|
984 |
+
return make_float2(b / a.x, b / a.y);
|
985 |
+
}
|
986 |
+
|
987 |
+
inline __host__ __device__ float3 operator/(float3 a, float3 b)
|
988 |
+
{
|
989 |
+
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
|
990 |
+
}
|
991 |
+
inline __host__ __device__ void operator/=(float3 &a, float3 b)
|
992 |
+
{
|
993 |
+
a.x /= b.x;
|
994 |
+
a.y /= b.y;
|
995 |
+
a.z /= b.z;
|
996 |
+
}
|
997 |
+
inline __host__ __device__ float3 operator/(float3 a, float b)
|
998 |
+
{
|
999 |
+
return make_float3(a.x / b, a.y / b, a.z / b);
|
1000 |
+
}
|
1001 |
+
inline __host__ __device__ void operator/=(float3 &a, float b)
|
1002 |
+
{
|
1003 |
+
a.x /= b;
|
1004 |
+
a.y /= b;
|
1005 |
+
a.z /= b;
|
1006 |
+
}
|
1007 |
+
inline __host__ __device__ float3 operator/(float b, float3 a)
|
1008 |
+
{
|
1009 |
+
return make_float3(b / a.x, b / a.y, b / a.z);
|
1010 |
+
}
|
1011 |
+
|
1012 |
+
inline __host__ __device__ float4 operator/(float4 a, float4 b)
|
1013 |
+
{
|
1014 |
+
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
|
1015 |
+
}
|
1016 |
+
inline __host__ __device__ void operator/=(float4 &a, float4 b)
|
1017 |
+
{
|
1018 |
+
a.x /= b.x;
|
1019 |
+
a.y /= b.y;
|
1020 |
+
a.z /= b.z;
|
1021 |
+
a.w /= b.w;
|
1022 |
+
}
|
1023 |
+
inline __host__ __device__ float4 operator/(float4 a, float b)
|
1024 |
+
{
|
1025 |
+
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
|
1026 |
+
}
|
1027 |
+
inline __host__ __device__ void operator/=(float4 &a, float b)
|
1028 |
+
{
|
1029 |
+
a.x /= b;
|
1030 |
+
a.y /= b;
|
1031 |
+
a.z /= b;
|
1032 |
+
a.w /= b;
|
1033 |
+
}
|
1034 |
+
inline __host__ __device__ float4 operator/(float b, float4 a)
|
1035 |
+
{
|
1036 |
+
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
|
1037 |
+
}
|
1038 |
+
|
1039 |
+
////////////////////////////////////////////////////////////////////////////////
|
1040 |
+
// min
|
1041 |
+
////////////////////////////////////////////////////////////////////////////////
|
1042 |
+
|
1043 |
+
inline __host__ __device__ float2 fminf(float2 a, float2 b)
|
1044 |
+
{
|
1045 |
+
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
|
1046 |
+
}
|
1047 |
+
inline __host__ __device__ float3 fminf(float3 a, float3 b)
|
1048 |
+
{
|
1049 |
+
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
|
1050 |
+
}
|
1051 |
+
inline __host__ __device__ float4 fminf(float4 a, float4 b)
|
1052 |
+
{
|
1053 |
+
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
|
1054 |
+
}
|
1055 |
+
|
1056 |
+
inline __host__ __device__ int2 min(int2 a, int2 b)
|
1057 |
+
{
|
1058 |
+
return make_int2(min(a.x,b.x), min(a.y,b.y));
|
1059 |
+
}
|
1060 |
+
inline __host__ __device__ int3 min(int3 a, int3 b)
|
1061 |
+
{
|
1062 |
+
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
1063 |
+
}
|
1064 |
+
inline __host__ __device__ int4 min(int4 a, int4 b)
|
1065 |
+
{
|
1066 |
+
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
1067 |
+
}
|
1068 |
+
|
1069 |
+
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
|
1070 |
+
{
|
1071 |
+
return make_uint2(min(a.x,b.x), min(a.y,b.y));
|
1072 |
+
}
|
1073 |
+
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
|
1074 |
+
{
|
1075 |
+
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
1076 |
+
}
|
1077 |
+
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
|
1078 |
+
{
|
1079 |
+
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
1080 |
+
}
|
1081 |
+
|
1082 |
+
////////////////////////////////////////////////////////////////////////////////
|
1083 |
+
// max
|
1084 |
+
////////////////////////////////////////////////////////////////////////////////
|
1085 |
+
|
1086 |
+
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
|
1087 |
+
{
|
1088 |
+
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
|
1089 |
+
}
|
1090 |
+
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
|
1091 |
+
{
|
1092 |
+
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
|
1093 |
+
}
|
1094 |
+
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
|
1095 |
+
{
|
1096 |
+
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
|
1097 |
+
}
|
1098 |
+
|
1099 |
+
inline __host__ __device__ int2 max(int2 a, int2 b)
|
1100 |
+
{
|
1101 |
+
return make_int2(max(a.x,b.x), max(a.y,b.y));
|
1102 |
+
}
|
1103 |
+
inline __host__ __device__ int3 max(int3 a, int3 b)
|
1104 |
+
{
|
1105 |
+
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
1106 |
+
}
|
1107 |
+
inline __host__ __device__ int4 max(int4 a, int4 b)
|
1108 |
+
{
|
1109 |
+
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
1110 |
+
}
|
1111 |
+
|
1112 |
+
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
|
1113 |
+
{
|
1114 |
+
return make_uint2(max(a.x,b.x), max(a.y,b.y));
|
1115 |
+
}
|
1116 |
+
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
|
1117 |
+
{
|
1118 |
+
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
1119 |
+
}
|
1120 |
+
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
|
1121 |
+
{
|
1122 |
+
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
1123 |
+
}
|
1124 |
+
|
1125 |
+
////////////////////////////////////////////////////////////////////////////////
|
1126 |
+
// lerp
|
1127 |
+
// - linear interpolation between a and b, based on value t in [0, 1] range
|
1128 |
+
////////////////////////////////////////////////////////////////////////////////
|
1129 |
+
|
1130 |
+
inline __device__ __host__ float lerp(float a, float b, float t)
|
1131 |
+
{
|
1132 |
+
return a + t*(b-a);
|
1133 |
+
}
|
1134 |
+
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
|
1135 |
+
{
|
1136 |
+
return a + t*(b-a);
|
1137 |
+
}
|
1138 |
+
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
|
1139 |
+
{
|
1140 |
+
return a + t*(b-a);
|
1141 |
+
}
|
1142 |
+
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
|
1143 |
+
{
|
1144 |
+
return a + t*(b-a);
|
1145 |
+
}
|
1146 |
+
|
1147 |
+
////////////////////////////////////////////////////////////////////////////////
|
1148 |
+
// clamp
|
1149 |
+
// - clamp the value v to be in the range [a, b]
|
1150 |
+
////////////////////////////////////////////////////////////////////////////////
|
1151 |
+
|
1152 |
+
inline __device__ __host__ float clamp(float f, float a, float b)
|
1153 |
+
{
|
1154 |
+
return fmaxf(a, fminf(f, b));
|
1155 |
+
}
|
1156 |
+
inline __device__ __host__ int clamp(int f, int a, int b)
|
1157 |
+
{
|
1158 |
+
return max(a, min(f, b));
|
1159 |
+
}
|
1160 |
+
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
|
1161 |
+
{
|
1162 |
+
return max(a, min(f, b));
|
1163 |
+
}
|
1164 |
+
|
1165 |
+
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
|
1166 |
+
{
|
1167 |
+
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
|
1168 |
+
}
|
1169 |
+
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
|
1170 |
+
{
|
1171 |
+
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
1172 |
+
}
|
1173 |
+
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
|
1174 |
+
{
|
1175 |
+
return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
1176 |
+
}
|
1177 |
+
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
|
1178 |
+
{
|
1179 |
+
return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
1180 |
+
}
|
1181 |
+
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
|
1182 |
+
{
|
1183 |
+
return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
1184 |
+
}
|
1185 |
+
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
|
1186 |
+
{
|
1187 |
+
return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
1188 |
+
}
|
1189 |
+
|
1190 |
+
inline __device__ __host__ int2 clamp(int2 v, int a, int b)
|
1191 |
+
{
|
1192 |
+
return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
|
1193 |
+
}
|
1194 |
+
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
|
1195 |
+
{
|
1196 |
+
return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
1197 |
+
}
|
1198 |
+
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
|
1199 |
+
{
|
1200 |
+
return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
1201 |
+
}
|
1202 |
+
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
|
1203 |
+
{
|
1204 |
+
return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
1205 |
+
}
|
1206 |
+
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
|
1207 |
+
{
|
1208 |
+
return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
1209 |
+
}
|
1210 |
+
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
|
1211 |
+
{
|
1212 |
+
return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
1213 |
+
}
|
1214 |
+
|
1215 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
|
1216 |
+
{
|
1217 |
+
return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
|
1218 |
+
}
|
1219 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
|
1220 |
+
{
|
1221 |
+
return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
1222 |
+
}
|
1223 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
|
1224 |
+
{
|
1225 |
+
return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
1226 |
+
}
|
1227 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
|
1228 |
+
{
|
1229 |
+
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
1230 |
+
}
|
1231 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
|
1232 |
+
{
|
1233 |
+
return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
1234 |
+
}
|
1235 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
|
1236 |
+
{
|
1237 |
+
return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
1238 |
+
}
|
1239 |
+
|
1240 |
+
////////////////////////////////////////////////////////////////////////////////
|
1241 |
+
// dot product
|
1242 |
+
////////////////////////////////////////////////////////////////////////////////
|
1243 |
+
|
1244 |
+
inline __host__ __device__ float dot(float2 a, float2 b)
|
1245 |
+
{
|
1246 |
+
return a.x * b.x + a.y * b.y;
|
1247 |
+
}
|
1248 |
+
inline __host__ __device__ float dot(float3 a, float3 b)
|
1249 |
+
{
|
1250 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
1251 |
+
}
|
1252 |
+
inline __host__ __device__ float dot(float4 a, float4 b)
|
1253 |
+
{
|
1254 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
1255 |
+
}
|
1256 |
+
|
1257 |
+
inline __host__ __device__ int dot(int2 a, int2 b)
|
1258 |
+
{
|
1259 |
+
return a.x * b.x + a.y * b.y;
|
1260 |
+
}
|
1261 |
+
inline __host__ __device__ int dot(int3 a, int3 b)
|
1262 |
+
{
|
1263 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
1264 |
+
}
|
1265 |
+
inline __host__ __device__ int dot(int4 a, int4 b)
|
1266 |
+
{
|
1267 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
1268 |
+
}
|
1269 |
+
|
1270 |
+
inline __host__ __device__ uint dot(uint2 a, uint2 b)
|
1271 |
+
{
|
1272 |
+
return a.x * b.x + a.y * b.y;
|
1273 |
+
}
|
1274 |
+
inline __host__ __device__ uint dot(uint3 a, uint3 b)
|
1275 |
+
{
|
1276 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
1277 |
+
}
|
1278 |
+
inline __host__ __device__ uint dot(uint4 a, uint4 b)
|
1279 |
+
{
|
1280 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
1281 |
+
}
|
1282 |
+
|
1283 |
+
////////////////////////////////////////////////////////////////////////////////
|
1284 |
+
// length
|
1285 |
+
////////////////////////////////////////////////////////////////////////////////
|
1286 |
+
|
1287 |
+
inline __host__ __device__ float length(float2 v)
|
1288 |
+
{
|
1289 |
+
return sqrtf(dot(v, v));
|
1290 |
+
}
|
1291 |
+
inline __host__ __device__ float length(float3 v)
|
1292 |
+
{
|
1293 |
+
return sqrtf(dot(v, v));
|
1294 |
+
}
|
1295 |
+
inline __host__ __device__ float length(float4 v)
|
1296 |
+
{
|
1297 |
+
return sqrtf(dot(v, v));
|
1298 |
+
}
|
1299 |
+
|
1300 |
+
////////////////////////////////////////////////////////////////////////////////
|
1301 |
+
// normalize
|
1302 |
+
////////////////////////////////////////////////////////////////////////////////
|
1303 |
+
|
1304 |
+
inline __host__ __device__ float2 normalize(float2 v)
|
1305 |
+
{
|
1306 |
+
float invLen = rsqrtf(dot(v, v));
|
1307 |
+
return v * invLen;
|
1308 |
+
}
|
1309 |
+
inline __host__ __device__ float3 normalize(float3 v)
|
1310 |
+
{
|
1311 |
+
float invLen = rsqrtf(dot(v, v));
|
1312 |
+
return v * invLen;
|
1313 |
+
}
|
1314 |
+
inline __host__ __device__ float4 normalize(float4 v)
|
1315 |
+
{
|
1316 |
+
float invLen = rsqrtf(dot(v, v));
|
1317 |
+
return v * invLen;
|
1318 |
+
}
|
1319 |
+
|
1320 |
+
////////////////////////////////////////////////////////////////////////////////
|
1321 |
+
// floor
|
1322 |
+
////////////////////////////////////////////////////////////////////////////////
|
1323 |
+
|
1324 |
+
inline __host__ __device__ float2 floorf(float2 v)
|
1325 |
+
{
|
1326 |
+
return make_float2(floorf(v.x), floorf(v.y));
|
1327 |
+
}
|
1328 |
+
inline __host__ __device__ float3 floorf(float3 v)
|
1329 |
+
{
|
1330 |
+
return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
|
1331 |
+
}
|
1332 |
+
inline __host__ __device__ float4 floorf(float4 v)
|
1333 |
+
{
|
1334 |
+
return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
|
1335 |
+
}
|
1336 |
+
|
1337 |
+
////////////////////////////////////////////////////////////////////////////////
|
1338 |
+
// frac - returns the fractional portion of a scalar or each vector component
|
1339 |
+
////////////////////////////////////////////////////////////////////////////////
|
1340 |
+
|
1341 |
+
inline __host__ __device__ float fracf(float v)
|
1342 |
+
{
|
1343 |
+
return v - floorf(v);
|
1344 |
+
}
|
1345 |
+
inline __host__ __device__ float2 fracf(float2 v)
|
1346 |
+
{
|
1347 |
+
return make_float2(fracf(v.x), fracf(v.y));
|
1348 |
+
}
|
1349 |
+
inline __host__ __device__ float3 fracf(float3 v)
|
1350 |
+
{
|
1351 |
+
return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
|
1352 |
+
}
|
1353 |
+
inline __host__ __device__ float4 fracf(float4 v)
|
1354 |
+
{
|
1355 |
+
return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
|
1356 |
+
}
|
1357 |
+
|
1358 |
+
////////////////////////////////////////////////////////////////////////////////
|
1359 |
+
// fmod
|
1360 |
+
////////////////////////////////////////////////////////////////////////////////
|
1361 |
+
|
1362 |
+
inline __host__ __device__ float2 fmodf(float2 a, float2 b)
|
1363 |
+
{
|
1364 |
+
return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
|
1365 |
+
}
|
1366 |
+
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
|
1367 |
+
{
|
1368 |
+
return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
|
1369 |
+
}
|
1370 |
+
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
|
1371 |
+
{
|
1372 |
+
return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
|
1373 |
+
}
|
1374 |
+
|
1375 |
+
////////////////////////////////////////////////////////////////////////////////
|
1376 |
+
// absolute value
|
1377 |
+
////////////////////////////////////////////////////////////////////////////////
|
1378 |
+
|
1379 |
+
inline __host__ __device__ float2 fabs(float2 v)
|
1380 |
+
{
|
1381 |
+
return make_float2(fabs(v.x), fabs(v.y));
|
1382 |
+
}
|
1383 |
+
inline __host__ __device__ float3 fabs(float3 v)
|
1384 |
+
{
|
1385 |
+
return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
|
1386 |
+
}
|
1387 |
+
inline __host__ __device__ float4 fabs(float4 v)
|
1388 |
+
{
|
1389 |
+
return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
|
1390 |
+
}
|
1391 |
+
|
1392 |
+
inline __host__ __device__ int2 abs(int2 v)
|
1393 |
+
{
|
1394 |
+
return make_int2(abs(v.x), abs(v.y));
|
1395 |
+
}
|
1396 |
+
inline __host__ __device__ int3 abs(int3 v)
|
1397 |
+
{
|
1398 |
+
return make_int3(abs(v.x), abs(v.y), abs(v.z));
|
1399 |
+
}
|
1400 |
+
inline __host__ __device__ int4 abs(int4 v)
|
1401 |
+
{
|
1402 |
+
return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
|
1403 |
+
}
|
1404 |
+
|
1405 |
+
////////////////////////////////////////////////////////////////////////////////
|
1406 |
+
// reflect
|
1407 |
+
// - returns reflection of incident ray I around surface normal N
|
1408 |
+
// - N should be normalized, reflected vector's length is equal to length of I
|
1409 |
+
////////////////////////////////////////////////////////////////////////////////
|
1410 |
+
|
1411 |
+
inline __host__ __device__ float3 reflect(float3 i, float3 n)
|
1412 |
+
{
|
1413 |
+
return i - 2.0f * n * dot(n,i);
|
1414 |
+
}
|
1415 |
+
|
1416 |
+
////////////////////////////////////////////////////////////////////////////////
|
1417 |
+
// cross product
|
1418 |
+
////////////////////////////////////////////////////////////////////////////////
|
1419 |
+
|
1420 |
+
inline __host__ __device__ float3 cross(float3 a, float3 b)
|
1421 |
+
{
|
1422 |
+
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
|
1423 |
+
}
|
1424 |
+
|
1425 |
+
////////////////////////////////////////////////////////////////////////////////
|
1426 |
+
// smoothstep
|
1427 |
+
// - returns 0 if x < a
|
1428 |
+
// - returns 1 if x > b
|
1429 |
+
// - otherwise returns smooth interpolation between 0 and 1 based on x
|
1430 |
+
////////////////////////////////////////////////////////////////////////////////
|
1431 |
+
|
1432 |
+
inline __device__ __host__ float smoothstep(float a, float b, float x)
|
1433 |
+
{
|
1434 |
+
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1435 |
+
return (y*y*(3.0f - (2.0f*y)));
|
1436 |
+
}
|
1437 |
+
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
|
1438 |
+
{
|
1439 |
+
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1440 |
+
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
|
1441 |
+
}
|
1442 |
+
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
|
1443 |
+
{
|
1444 |
+
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1445 |
+
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
|
1446 |
+
}
|
1447 |
+
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
|
1448 |
+
{
|
1449 |
+
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1450 |
+
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
|
1451 |
+
}
|
1452 |
+
|
1453 |
+
#endif
|
dva/mvp/extensions/mvpraymarch/makefile
ADDED
@@ -0,0 +1,2 @@
all:
	python setup.py build_ext --inplace
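The makefile simply runs the setuptools build in place, so the compiled `mvpraymarchlib` module lands next to the Python sources. A minimal sketch of building and sanity-checking the extension from Python (the subprocess call is just a stand-in for running `make` in this directory, and the attribute check is illustrative):

import subprocess
import sys

# Equivalent of `make all`; assumes a CUDA toolchain and PyTorch are available.
subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"])

import mvpraymarchlib

# These names are the ones registered by the pybind11 module in mvpraymarch.cpp.
for name in ("compute_morton", "build_tree", "compute_aabb",
             "raymarch_forward", "raymarch_backward"):
    assert hasattr(mvpraymarchlib, name), name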
dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp
ADDED
@@ -0,0 +1,405 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include <vector>

void compute_morton_cuda(
    int N, int K,
    float * primpos,
    int * code,
    int algorithm,
    cudaStream_t stream);

void build_tree_cuda(
    int N, int K,
    int * sortedcode,
    int * nodechildren,
    int * nodeparent,
    cudaStream_t stream);

void compute_aabb_cuda(
    int N, int K,
    float * primpos,
    float * primrot,
    float * primscale,
    int * sortedobjid,
    int * nodechildren,
    int * nodeparent,
    float * nodeaabb,
    int algorithm,
    cudaStream_t stream);

void raymarch_forward_cuda(
    int N, int H, int W, int K,
    float * rayposim,
    float * raydirim,
    float stepsize,
    float * tminmaxim,

    int * sortedobjid,
    int * nodechildren,
    float * nodeaabb,

    float * primpos,
    float * primrot,
    float * primscale,

    int TD, int TH, int TW,
    float * tplate,
    int WD, int WH, int WW,
    float * warp,

    float * rayrgbaim,
    float * raysatim,
    int * raytermim,

    int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
    bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
    int griddim, int blocksizex, int blocksizey,
    cudaStream_t stream);

void raymarch_backward_cuda(
    int N, int H, int W, int K,
    float * rayposim,
    float * raydirim,
    float stepsize,
    float * tminmaxim,

    int * sortedobjid,
    int * nodechildren,
    float * nodeaabb,

    float * primpos,
    float * grad_primpos,
    float * primrot,
    float * grad_primrot,
    float * primscale,
    float * grad_primscale,

    int TD, int TH, int TW,
    float * tplate,
    float * grad_tplate,
    int WD, int WH, int WW,
    float * warp,
    float * grad_warp,

    float * rayrgbaim,
    float * grad_rayrgba,
    float * raysatim,
    int * raytermim,

    int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
    bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
    int griddim, int blocksizex, int blocksizey,
    cudaStream_t stream);

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x))

std::vector<torch::Tensor> compute_morton(
    torch::Tensor primpos,
    torch::Tensor code,
    int algorithm) {
    CHECK_INPUT(primpos);
    CHECK_INPUT(code);

    int N = primpos.size(0);
    int K = primpos.size(1);

    compute_morton_cuda(
        N, K,
        reinterpret_cast<float *>(primpos.data_ptr()),
        reinterpret_cast<int *>(code.data_ptr()),
        algorithm,
        0);

    return {};
}

std::vector<torch::Tensor> build_tree(
    torch::Tensor sortedcode,
    torch::Tensor nodechildren,
    torch::Tensor nodeparent) {
    CHECK_INPUT(sortedcode);
    CHECK_INPUT(nodechildren);
    CHECK_INPUT(nodeparent);

    int N = sortedcode.size(0);
    int K = sortedcode.size(1);

    build_tree_cuda(N, K,
        reinterpret_cast<int *>(sortedcode.data_ptr()),
        reinterpret_cast<int *>(nodechildren.data_ptr()),
        reinterpret_cast<int *>(nodeparent.data_ptr()),
        0);

    return {};
}

std::vector<torch::Tensor> compute_aabb(
    torch::Tensor primpos,
    torch::optional<torch::Tensor> primrot,
    torch::optional<torch::Tensor> primscale,
    torch::Tensor sortedobjid,
    torch::Tensor nodechildren,
    torch::Tensor nodeparent,
    torch::Tensor nodeaabb,
    int algorithm) {
    CHECK_INPUT(sortedobjid);
    CHECK_INPUT(primpos);
    if (primrot) { CHECK_INPUT(*primrot); }
    if (primscale) { CHECK_INPUT(*primscale); }
    CHECK_INPUT(nodechildren);
    CHECK_INPUT(nodeparent);
    CHECK_INPUT(nodeaabb);

    int N = primpos.size(0);
    int K = primpos.size(1);

    compute_aabb_cuda(N, K,
        reinterpret_cast<float *>(primpos.data_ptr()),
        primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
        primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
        reinterpret_cast<int *>(sortedobjid.data_ptr()),
        reinterpret_cast<int *>(nodechildren.data_ptr()),
        reinterpret_cast<int *>(nodeparent.data_ptr()),
        reinterpret_cast<float *>(nodeaabb.data_ptr()),
        algorithm,
        0);

    return {};
}

std::vector<torch::Tensor> raymarch_forward(
    torch::Tensor rayposim,
    torch::Tensor raydirim,
    float stepsize,
    torch::Tensor tminmaxim,

    torch::optional<torch::Tensor> sortedobjid,
    torch::optional<torch::Tensor> nodechildren,
    torch::optional<torch::Tensor> nodeaabb,

    torch::Tensor primpos,
    torch::optional<torch::Tensor> primrot,
    torch::optional<torch::Tensor> primscale,

    torch::Tensor tplate,
    torch::optional<torch::Tensor> warp,

    torch::Tensor rayrgbaim,
    torch::optional<torch::Tensor> raysatim,
    torch::optional<torch::Tensor> raytermim,

    int algorithm=0,
    bool sortboxes=true,
    int maxhitboxes=512,
    bool synchitboxes=false,
    bool chlast=false,
    float fadescale=8.f,
    float fadeexp=8.f,
    int accum=0,
    float termthresh=0.f,
    int griddim=3,
    int blocksizex=8,
    int blocksizey=16) {
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);
    if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
    if (nodechildren) { CHECK_INPUT(*nodechildren); }
    if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
    CHECK_INPUT(tplate);
    if (warp) { CHECK_INPUT(*warp); }
    CHECK_INPUT(primpos);
    if (primrot) { CHECK_INPUT(*primrot); }
    if (primscale) { CHECK_INPUT(*primscale); }
    CHECK_INPUT(rayrgbaim);
    if (raysatim) { CHECK_INPUT(*raysatim); }
    if (raytermim) { CHECK_INPUT(*raytermim); }

    int N = rayposim.size(0);
    int H = rayposim.size(1);
    int W = rayposim.size(2);
    int K = primpos.size(1);

    int TD, TH, TW;
    if (chlast) {
        TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4);
    } else {
        TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5);
    }

    int WD = 0, WH = 0, WW = 0;
    if (warp) {
        if (chlast) {
            WD = warp->size(2); WH = warp->size(3); WW = warp->size(4);
        } else {
            WD = warp->size(3); WH = warp->size(4); WW = warp->size(5);
        }
    }

    raymarch_forward_cuda(N, H, W, K,
        reinterpret_cast<float *>(rayposim.data_ptr()),
        reinterpret_cast<float *>(raydirim.data_ptr()),
        stepsize,
        reinterpret_cast<float *>(tminmaxim.data_ptr()),
        sortedobjid ? reinterpret_cast<int *>(sortedobjid->data_ptr()) : nullptr,
        nodechildren ? reinterpret_cast<int *>(nodechildren->data_ptr()) : nullptr,
        nodeaabb ? reinterpret_cast<float *>(nodeaabb->data_ptr()) : nullptr,

        // prim transforms
        reinterpret_cast<float *>(primpos.data_ptr()),
        primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
        primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,

        // prim sampler
        TD, TH, TW,
        reinterpret_cast<float *>(tplate.data_ptr()),
        WD, WH, WW,
        warp ? reinterpret_cast<float *>(warp->data_ptr()) : nullptr,

        // prim accumulator
        reinterpret_cast<float *>(rayrgbaim.data_ptr()),
        raysatim ? reinterpret_cast<float *>(raysatim->data_ptr()) : nullptr,
        raytermim ? reinterpret_cast<int *>(raytermim->data_ptr()) : nullptr,

        // options
        algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
        griddim, blocksizex, blocksizey,
        0);

    return {};
}

std::vector<torch::Tensor> raymarch_backward(
    torch::Tensor rayposim,
    torch::Tensor raydirim,
    float stepsize,
    torch::Tensor tminmaxim,

    torch::optional<torch::Tensor> sortedobjid,
    torch::optional<torch::Tensor> nodechildren,
    torch::optional<torch::Tensor> nodeaabb,

    torch::Tensor primpos,
    torch::Tensor grad_primpos,
    torch::optional<torch::Tensor> primrot,
    torch::optional<torch::Tensor> grad_primrot,
    torch::optional<torch::Tensor> primscale,
    torch::optional<torch::Tensor> grad_primscale,

    torch::Tensor tplate,
    torch::Tensor grad_tplate,
    torch::optional<torch::Tensor> warp,
    torch::optional<torch::Tensor> grad_warp,

    torch::Tensor rayrgbaim,
    torch::Tensor grad_rayrgba,
    torch::optional<torch::Tensor> raysatim,
    torch::optional<torch::Tensor> raytermim,

    int algorithm=0,
    bool sortboxes=true,
    int maxhitboxes=512,
    bool synchitboxes=false,
    bool chlast=false,
    float fadescale=8.f,
    float fadeexp=8.f,
    int accum=0,
    float termthresh=0.f,
    int griddim=3,
    int blocksizex=8,
    int blocksizey=16) {
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);
    if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
    if (nodechildren) { CHECK_INPUT(*nodechildren); }
    if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
    CHECK_INPUT(tplate);
    if (warp) { CHECK_INPUT(*warp); }
    CHECK_INPUT(primpos);
    if (primrot) { CHECK_INPUT(*primrot); }
    if (primscale) { CHECK_INPUT(*primscale); }
    CHECK_INPUT(rayrgbaim);
    if (raysatim) { CHECK_INPUT(*raysatim); }
    if (raytermim) { CHECK_INPUT(*raytermim); }
    CHECK_INPUT(grad_rayrgba);
    CHECK_INPUT(grad_tplate);
    if (grad_warp) { CHECK_INPUT(*grad_warp); }
    CHECK_INPUT(grad_primpos);
    if (grad_primrot) { CHECK_INPUT(*grad_primrot); }
    if (grad_primscale) { CHECK_INPUT(*grad_primscale); }

    int N = rayposim.size(0);
    int H = rayposim.size(1);
    int W = rayposim.size(2);
    int K = primpos.size(1);

    int TD, TH, TW;
    if (chlast) {
        TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4);
    } else {
        TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5);
    }

    int WD = 0, WH = 0, WW = 0;
    if (warp) {
        if (chlast) {
            WD = warp->size(2); WH = warp->size(3); WW = warp->size(4);
        } else {
            WD = warp->size(3); WH = warp->size(4); WW = warp->size(5);
        }
    }

    raymarch_backward_cuda(N, H, W, K,
        reinterpret_cast<float *>(rayposim.data_ptr()),
        reinterpret_cast<float *>(raydirim.data_ptr()),
        stepsize,
        reinterpret_cast<float *>(tminmaxim.data_ptr()),
        sortedobjid ? reinterpret_cast<int *>(sortedobjid->data_ptr()) : nullptr,
        nodechildren ? reinterpret_cast<int *>(nodechildren->data_ptr()) : nullptr,
        nodeaabb ? reinterpret_cast<float *>(nodeaabb->data_ptr()) : nullptr,

        reinterpret_cast<float *>(primpos.data_ptr()),
        reinterpret_cast<float *>(grad_primpos.data_ptr()),
        primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
        grad_primrot ? reinterpret_cast<float *>(grad_primrot->data_ptr()) : nullptr,
        primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
        grad_primscale ? reinterpret_cast<float *>(grad_primscale->data_ptr()) : nullptr,

        TD, TH, TW,
        reinterpret_cast<float *>(tplate.data_ptr()),
        reinterpret_cast<float *>(grad_tplate.data_ptr()),
        WD, WH, WW,
        warp ? reinterpret_cast<float *>(warp->data_ptr()) : nullptr,
        grad_warp ? reinterpret_cast<float *>(grad_warp->data_ptr()) : nullptr,

        reinterpret_cast<float *>(rayrgbaim.data_ptr()),
        reinterpret_cast<float *>(grad_rayrgba.data_ptr()),
        raysatim ? reinterpret_cast<float *>(raysatim->data_ptr()) : nullptr,
        raytermim ? reinterpret_cast<int *>(raytermim->data_ptr()) : nullptr,

        algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
        griddim, blocksizex, blocksizey,
        0);

    return {};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("compute_morton", &compute_morton, "compute morton codes (CUDA)");
    m.def("build_tree", &build_tree, "build BVH tree (CUDA)");
    m.def("compute_aabb", &compute_aabb, "compute AABB sizes (CUDA)");

    m.def("raymarch_forward", &raymarch_forward, "raymarch forward (CUDA)");
    m.def("raymarch_backward", &raymarch_backward, "raymarch backward (CUDA)");
}
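Every tensor passed into these bindings goes through CHECK_INPUT, so it must be a CUDA tensor and contiguous; the optional arguments may be None on the Python side. A minimal sketch of driving the lowest-level binding directly, with shapes mirroring the build_accel caller in mvpraymarch.py below (the values are placeholders):

import torch
import mvpraymarchlib

N, K = 1, 8
# Normalized primitive centers; must be CUDA and contiguous, or the
# CHECK_INPUT assertions above will fire.
centers = torch.rand(N, K, 3, device="cuda").contiguous()
codes = torch.empty(N, K, dtype=torch.int32, device="cuda")

mvpraymarchlib.compute_morton(centers, codes, 0)  # algorithm 0; codes filled in place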
dva/mvp/extensions/mvpraymarch/mvpraymarch.py
ADDED
@@ -0,0 +1,559 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import time

import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F

try:
    from . import mvpraymarchlib
except ImportError:
    import mvpraymarchlib

def build_accel(primtransfin, algo, fixedorder=False):
    """build bvh structure given primitive centers and sizes

    Parameters:
    ----------
    primtransfin : tuple[tensor, tensor, tensor]
        primitive transform tensors
    algo : int
        raymarching algorithm
    fixedorder : bool, optional
        True means the bvh builder will not reorder primitives and will
        use a trivial tree structure. Likely to be slow for arbitrary
        configurations of primitives.
    """
    primpos, primrot, primscale = primtransfin

    N = primpos.size(0)
    K = primpos.size(1)

    dev = primpos.device

    # compute and sort morton codes
    if fixedorder:
        sortedobjid = (torch.arange(N*K, dtype=torch.int32, device=dev) % K).view(N, K)
    else:
        cmax = primpos.max(dim=1, keepdim=True)[0]
        cmin = primpos.min(dim=1, keepdim=True)[0]

        centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8)

        mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev)
        mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo)
        sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1)
        sortedobjid = sortedobjid_long.int()

    if fixedorder:
        nodechildren = torch.cat([
            torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev),
            torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev), 2, rounding_mode="floor")],
            dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1)
        nodeparent = (
            torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev), 2, rounding_mode="floor")
            .view(1, -1).repeat(N, 1))
    else:
        nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev)
        nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev)
        mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent)

    nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev)
    mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren, nodeparent, nodeaabb, algo)

    return sortedobjid, nodechildren, nodeaabb

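# Minimal usage sketch for build_accel (the identity rotations and unit
# scales are arbitrary placeholder values):
#
#   primpos = torch.randn(1, 8, 3, device="cuda") * 0.1
#   primrot = torch.eye(3, device="cuda").expand(1, 8, 3, 3).contiguous()
#   primscale = torch.ones(1, 8, 3, device="cuda")
#   sortedobjid, nodechildren, nodeaabb = build_accel(
#       (primpos, primrot, primscale), algo=0)
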
class MVPRaymarch(Function):
    """Custom Function for raymarching Mixture of Volumetric Primitives."""
    @staticmethod
    def forward(self, raypos, raydir, stepsize, tminmax,
            primpos, primrot, primscale,
            template, warp,
            rayterm, gradmode, options):
        algo = options["algo"]
        usebvh = options["usebvh"]
        sortprims = options["sortprims"]
        randomorder = options["randomorder"]
        maxhitboxes = options["maxhitboxes"]
        synchitboxes = options["synchitboxes"]
        chlast = options["chlast"]
        fadescale = options["fadescale"]
        fadeexp = options["fadeexp"]
        accum = options["accum"]
        termthresh = options["termthresh"]
        griddim = options["griddim"]
        if isinstance(options["blocksize"], tuple):
            blocksizex, blocksizey = options["blocksize"]
        else:
            blocksizex = options["blocksize"]
            blocksizey = 1

        assert raypos.is_contiguous() and raypos.size(3) == 3
        assert raydir.is_contiguous() and raydir.size(3) == 3
        assert tminmax.is_contiguous() and tminmax.size(3) == 2

        assert primpos is None or primpos.is_contiguous() and primpos.size(2) == 3
        assert primrot is None or primrot.is_contiguous() and primrot.size(2) == 3
        assert primscale is None or primscale.is_contiguous() and primscale.size(2) == 3

        if chlast:
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3)
        else:
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(2) == 3)

        primtransfin = (primpos, primrot, primscale)

        # Build bvh
        if usebvh is not False:
            # compute radius of primitives
            sortedobjid, nodechildren, nodeaabb = build_accel(primtransfin,
                    algo, fixedorder=usebvh=="fixedorder")
            assert sortedobjid.is_contiguous()
            assert nodechildren.is_contiguous()
            assert nodeaabb.is_contiguous()

            if randomorder:
                sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
        else:
            _, sortedobjid, nodechildren, nodeaabb = None, None, None, None

        # march through boxes
        N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
        rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
        if gradmode:
            raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device)
            rayterm = None
        else:
            raysat = None
            rayterm = None

        mvpraymarchlib.raymarch_forward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, warp,
            rayrgba, raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)

        self.save_for_backward(
            raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm)
        self.options = options
        self.stepsize = stepsize

        return rayrgba

    @staticmethod
    def backward(self, grad_rayrgba):
        (raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm) = self.saved_tensors
        algo = self.options["algo"]
        usebvh = self.options["usebvh"]
        sortprims = self.options["sortprims"]
        maxhitboxes = self.options["maxhitboxes"]
        synchitboxes = self.options["synchitboxes"]
        chlast = self.options["chlast"]
        fadescale = self.options["fadescale"]
        fadeexp = self.options["fadeexp"]
        accum = self.options["accum"]
        termthresh = self.options["termthresh"]
        griddim = self.options["griddim"]
        if isinstance(self.options["bwdblocksize"], tuple):
            blocksizex, blocksizey = self.options["bwdblocksize"]
        else:
            blocksizex = self.options["bwdblocksize"]
            blocksizey = 1

        stepsize = self.stepsize

        grad_primpos = torch.zeros_like(primpos)
        grad_primrot = torch.zeros_like(primrot)
        grad_primscale = torch.zeros_like(primscale)
        primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale)

        grad_template = torch.zeros_like(template)
        grad_warp = torch.zeros_like(warp) if warp is not None else None

        mvpraymarchlib.raymarch_backward(raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,

            *primtransfin,

            template, grad_template, warp, grad_warp,

            rayrgba, grad_rayrgba.contiguous(), raysat, rayterm,

            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)

        return (None, None, None, None,
            grad_primpos, grad_primrot, grad_primscale,
            grad_template, grad_warp,
            None, None, None)

def mvpraymarch(raypos, raydir, stepsize, tminmax,
        primtransf,
        template, warp,
        rayterm=None,
        algo=0, usebvh="fixedorder",
        sortprims=False, randomorder=False,
        maxhitboxes=512, synchitboxes=True,
        chlast=True, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0.,
        griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)):
    """Main entry point for raymarching MVP.

    Parameters:
    ----------
    raypos: N x H x W x 3 tensor of ray origins
    raydir: N x H x W x 3 tensor of ray directions
    stepsize: raymarching step size
    tminmax: N x H x W x 2 tensor of raymarching min/max bounds
    template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives
    warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional)
    primpos: N x K x 3 tensor of primitive centers
    primrot: N x K x 3 x 3 tensor of primitive orientations
    primscale: N x K x 3 tensor of primitive inverse dimension lengths
    algo: algorithm for raymarching (valid values: 0, 1). algo=0 is the fastest.
        Currently algo=0 has a limit of 512 primitives per ray, so problems can
        occur if there are many more boxes. All sortprims=True options have
        this limitation, but you can use (algo=1, sortprims=False,
        usebvh="fixedorder"), which works correctly and has no limit on the
        number of primitives (but is slightly slower).
    usebvh: True to use a bvh, "fixedorder" for a simple BVH, False for no bvh
    sortprims: True to sort overlapping primitives at a sample point. Must
        be True for gradients to match the PyTorch gradients. Seems unstable
        if False, but it is also not a big performance bottleneck.
    chlast: whether template is provided as channels last or not. True tends
        to be faster.
    fadescale: Opacity is faded at the borders of the primitives by the equation
        exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of
        the primitive.
    fadeexp: Opacity is faded at the borders of the primitives by the equation
        exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of
        the primitive.
    griddim: CUDA grid dimensionality.
    blocksize: blocksize of CUDA kernels. Should be a 2-element tuple if
        griddim>1, or an integer if griddim==1."""
    if isinstance(primtransf, tuple):
        primpos, primrot, primscale = primtransf
    else:
        primpos, primrot, primscale = (
            primtransf[:, :, 0, :].contiguous(),
            primtransf[:, :, 1:4, :].contiguous(),
            primtransf[:, :, 4, :].contiguous())
    primtransfin = (primpos, primrot, primscale)

    out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax,
        *primtransfin,
        template, warp,
        rayterm, torch.is_grad_enabled(),
        {"algo": algo, "usebvh": usebvh, "sortprims": sortprims, "randomorder": randomorder,
         "maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes,
         "chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp,
         "accum": accum, "termthresh": termthresh,
         "griddim": griddim, "blocksize": blocksize, "bwdblocksize": bwdblocksize})
    return out

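# Minimal usage sketch for mvpraymarch; shapes follow the docstring above.
# The camera setup is a placeholder (all rays from the origin down +z), and
# the template here is channels-first, hence chlast=False.
def _mvpraymarch_example(N=1, H=64, W=64, K=8, M=16, dev="cuda"):
    raypos = torch.zeros(N, H, W, 3, device=dev)
    raydir = torch.zeros(N, H, W, 3, device=dev)
    raydir[..., 2] = 1.
    tminmax = torch.stack([torch.zeros(N, H, W, device=dev),
                           torch.full((N, H, W), 6., device=dev)], dim=-1).contiguous()

    primpos = torch.randn(N, K, 3, device=dev) * 0.1
    primrot = torch.eye(3, device=dev).expand(N, K, 3, 3).contiguous()
    primscale = torch.ones(N, K, 3, device=dev)
    template = torch.randn(N, K, 4, M, M, M, device=dev)

    rayrgba = mvpraymarch(raypos, raydir, 6. / 128, tminmax,
                          (primpos, primrot, primscale),
                          template, warp=None, chlast=False)
    return rayrgba  # N x H x W x 4 RGBA image
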
class Rodrigues(nn.Module):
    def __init__(self):
        super(Rodrigues, self).__init__()

    def forward(self, rvec):
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,

            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,

            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)

301 |
+
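For reference, the forward pass above is an elementwise expansion of the standard Rodrigues rotation formula (the 1e-5 inside the square root is only a regularizer so the axis normalization stays well-defined near zero rotations): with \theta = \lVert r \rVert and unit axis \hat{r} = r/\theta,

    R = \cos\theta\, I + (1-\cos\theta)\,\hat{r}\hat{r}^{\top}
        + \sin\theta \begin{pmatrix} 0 & -\hat{r}_z & \hat{r}_y \\ \hat{r}_z & 0 & -\hat{r}_x \\ -\hat{r}_y & \hat{r}_x & 0 \end{pmatrix}

whose nine entries match the nine stacked expressions in the code, row by row.
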
def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
        dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
    N = 2
    H = 65
    W = 65
    k3 = 4
    K = k3*k3*k3

    M = 32

    print("=================================================================")
    print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
          "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
          "blocksize={}, bwdblocksize={}".format(
              usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
              fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
              bwdblocksize))

    # generate random inputs
    torch.manual_seed(1112)

    coherent_rays = True
    if not coherent_rays:
        _raypos = torch.randn(N, H, W, 3).to("cuda")
        _raydir = torch.randn(N, H, W, 3).to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    else:
        # pinhole camera rays: unproject pixel coordinates through focal/principal point
        focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)])
        princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)])
        pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
        pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

        _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
        _raydir = raydir.to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))

    max_len = 6.0
    _stepsize = max_len / 15.386928
    _tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
            torch.rand(N, H, W, 2, device="cuda") * 1.

    _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
    _template.data[:, :, -1, :, :, :] -= 3.5
    _template = _template.contiguous().detach().clone()
    _template.requires_grad = True
    gridxyz = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2),
        torch.linspace(-1., 1., M//2))[::-1], dim=0).contiguous()
    _warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
    _warp.requires_grad = True
    _primpos = torch.randn(N, K, 3, requires_grad=True)

    coherent_centers = True
    if coherent_centers:
        ns = k3
        #assert ns*ns*ns==K
        grid3d = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., K//(ns*ns)))[::-1], dim=0)[None]
        _primpos = ((
            grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
            0.1 * torch.randn(N, K, 3, requires_grad=True)
            )).contiguous().detach().clone()
        _primpos.requires_grad = True
    scale_ws = 1.
    _primrot = torch.randn(N, K, 3)
    rodrigues = Rodrigues()
    _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
    _primrot.requires_grad = True

    _primscale = torch.randn(N, K, 3, requires_grad=True)
    _primscale.data *= 0.0

    if dowarp:
        params = [_template, _warp, _primscale, _primrot, _primpos]
        paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
    else:
        params = [_template, _primscale, _primrot, _primpos]
        paramnames = ["template", "primscale", "primrot", "primpos"]

    termthreshorig = termthresh

    ########################### run pytorch version ###########################

    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    # python raymarching implementation
    rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
    raypos = raypos + raydir * tminmax[:, :, :, 0, None]
    t = tminmax[:, :, :, 0]

    step = 0
    t0 = t.detach().clone()
    raypos0 = raypos.detach().clone()

    torch.cuda.synchronize()
    time0 = time.time()

    while (t < tminmax[:, :, :, 1]).any():
        valid2 = torch.ones_like(rayrgba[:, :, :, 3:4])

        for k in range(K):
            y0 = torch.bmm(
                (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
                primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]

            fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))

            if dowarp:
                y1 = F.grid_sample(
                    warp[:, k, :, :, :, :],
                    y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            else:
                y1 = y0

            sample = F.grid_sample(
                template[:, k, :, :, :, :],
                y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)

            valid1 = (
                torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
                torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))

            valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]

            alpha0 = sample[:, :, :, 3:4]

            rgb = sample[:, :, :, 0:3] * valid * valid1
            alpha = alpha0 * fade * stepsize * valid * valid1

            if accum == 0:
                newalpha = rayrgba[:, :, :, 3:4] + alpha
                contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
                rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
            else:
                raise NotImplementedError("only additive accumulation (accum=0) is implemented here")

        step += 1
        t = t0 + stepsize * step
        raypos = raypos0 + raydir * stepsize * step

    print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())

    sample0 = rayrgba

    torch.cuda.synchronize()
    time1 = time.time()

    sample0.backward(torch.ones_like(sample0))

    torch.cuda.synchronize()
    time2 = time.time()

    print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))

    grads0 = [p.grad.detach().clone() for p in params]

    for p in params:
        p.grad.detach_()
        p.grad.zero_()

    ############################## run cuda version ###########################

    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    if chlast:
        template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
        warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    niter = 1

    tf, tb = 0., 0.
    for i in range(niter):
        for p in params:
            try:
                p.grad.detach_()
                p.grad.zero_()
            except AttributeError:
                # grads may not exist before the first backward pass
                pass
        t0 = time.time()
        torch.cuda.synchronize()
        sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
                (primpos, primrot, primscale),
                template, warp if dowarp else None,
                algo=algo, usebvh=usebvh, sortprims=sortprims,
                maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
                chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
                accum=accum, termthresh=termthreshorig,
                griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)
        t1 = time.time()
        torch.cuda.synchronize()
        sample1.backward(torch.ones_like(sample1), retain_graph=True)
        torch.cuda.synchronize()
        t2 = time.time()
        tf += t1 - t0
        tb += t2 - t1

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() for p in params]

    ############# compare results #############

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        torch.sqrt(torch.sum(sample0 * sample0)).item(),
        torch.sqrt(torch.sum(sample1 * sample1)).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            p,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            torch.sqrt(torch.sum(g0 * g0)).item(),
            torch.sqrt(torch.sum(g1 * g1)).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))

if __name__ == "__main__":
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
            dowarp=False, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=0, griddim=3)
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
            dowarp=True, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=1, griddim=3)

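A note on the accumulation model shared by the PyTorch reference loop above and the CUDA accumulator in primaccum.h below (this restates what the code computes rather than adding anything new): with running alpha A_{i-1}, sample color c_i, sample opacity \alpha_i, and step size \Delta t,

    A_i = \min\!\big(A_{i-1} + \alpha_i\,\Delta t,\ 1\big), \qquad
    w_i = A_i - A_{i-1}, \qquad
    C = \sum_i w_i\, c_i

so each sample contributes its color with weight w_i until the ray saturates at A = 1, after which w_i = 0 and the ray can terminate early. The "dp" column printed by gradcheck is the normalized correlation \langle g_0, g_1 \rangle / (\lVert g_0 \rVert\, \lVert g_1 \rVert) between the PyTorch and CUDA results.
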
dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu
ADDED
@@ -0,0 +1,208 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <chrono>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <tuple>
#include <vector>

#include "helper_math.h"

#include "cudadispatch.h"

#include "utils.h"

#include "primtransf.h"
#include "primsampler.h"
#include "primaccum.h"

#include "mvpraymarch_subset_kernel.h"

typedef std::shared_ptr<PrimTransfDataBase> PrimTransfDataBase_ptr;
typedef std::shared_ptr<PrimSamplerDataBase> PrimSamplerDataBase_ptr;
typedef std::shared_ptr<PrimAccumDataBase> PrimAccumDataBase_ptr;
typedef std::function<void(dim3, dim3, cudaStream_t, int, int, int, int,
        float3*, float3*, float, float2*, int*, int2*, float3*,
        PrimTransfDataBase_ptr, PrimSamplerDataBase_ptr,
        PrimAccumDataBase_ptr)> mapfn_t;
typedef RaySubsetFixedBVH<false, 512, true, PrimTransfSRT> raysubset_t;

void raymarch_forward_cuda(
        int N, int H, int W, int K,
        float * rayposim,
        float * raydirim,
        float stepsize,
        float * tminmaxim,

        int * sortedobjid,
        int * nodechildren,
        float * nodeaabb,
        float * primpos,
        float * primrot,
        float * primscale,

        int TD, int TH, int TW,
        float * tplate,
        int WD, int WH, int WW,
        float * warp,

        float * rayrgbaim,
        float * raysatim,
        int * raytermim,

        int algorithm,
        bool sortboxes,
        int maxhitboxes,
        bool synchitboxes,
        bool chlast,
        float fadescale,
        float fadeexp,
        int accum,
        float termthresh,
        int griddim, int blocksizex, int blocksizey,
        cudaStream_t stream) {
    dim3 blocksize(blocksizex, blocksizey);
    dim3 gridsize;
    gridsize = dim3(
            (W + blocksize.x - 1) / blocksize.x,
            (H + blocksize.y - 1) / blocksize.y,
            N);

    std::shared_ptr<PrimTransfDataBase> primtransf_data;
    primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
            PrimTransfDataBase{},
            K, (float3*)primpos, nullptr,
            K * 3, (float3*)primrot, nullptr,
            K, (float3*)primscale, nullptr});
    std::shared_ptr<PrimSamplerDataBase> primsampler_data;
    if (algorithm == 1) {
        primsampler_data = std::make_shared<PrimSamplerTW<true>::Data>(PrimSamplerTW<true>::Data{
                PrimSamplerDataBase{},
                fadescale, fadeexp,
                K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr,
                K * WD * WH * WW * 3, WD, WH, WW, warp, nullptr});
    } else {
        primsampler_data = std::make_shared<PrimSamplerTW<false>::Data>(PrimSamplerTW<false>::Data{
                PrimSamplerDataBase{},
                fadescale, fadeexp,
                K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr,
                0, 0, 0, 0, nullptr, nullptr});
    }
    std::shared_ptr<PrimAccumDataBase> primaccum_data = std::make_shared<PrimAccumAdditive::Data>(PrimAccumAdditive::Data{
            PrimAccumDataBase{},
            termthresh, H * W, W, 1, (float4*)rayrgbaim, nullptr, (float3*)raysatim});

    std::map<int, mapfn_t> dispatcher = {
        {0, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<false>, PrimAccumAdditive>)},
        {1, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<true>, PrimAccumAdditive>)}};

    auto iter = dispatcher.find(algorithm);
    if (iter != dispatcher.end()) {
        (iter->second)(
                gridsize, blocksize, stream,
                N, H, W, K,
                reinterpret_cast<float3 *>(rayposim),
                reinterpret_cast<float3 *>(raydirim),
                stepsize,
                reinterpret_cast<float2 *>(tminmaxim),
                reinterpret_cast<int *>(sortedobjid),
                reinterpret_cast<int2 *>(nodechildren),
                reinterpret_cast<float3 *>(nodeaabb),
                primtransf_data,
                primsampler_data,
                primaccum_data);
    }
}

void raymarch_backward_cuda(
        int N, int H, int W, int K,
        float * rayposim,
        float * raydirim,
        float stepsize,
        float * tminmaxim,
        int * sortedobjid,
        int * nodechildren,
        float * nodeaabb,

        float * primpos,
        float * grad_primpos,
        float * primrot,
        float * grad_primrot,
        float * primscale,
        float * grad_primscale,

        int TD, int TH, int TW,
        float * tplate,
        float * grad_tplate,
        int WD, int WH, int WW,
        float * warp,
        float * grad_warp,

        float * rayrgbaim,
        float * grad_rayrgba,
        float * raysatim,
        int * raytermim,

        int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
        bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
        int griddim, int blocksizex, int blocksizey,

        cudaStream_t stream) {
    dim3 blocksize(blocksizex, blocksizey);
    dim3 gridsize;
    gridsize = dim3(
            (W + blocksize.x - 1) / blocksize.x,
            (H + blocksize.y - 1) / blocksize.y,
            N);

    std::shared_ptr<PrimTransfDataBase> primtransf_data;
    primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
            PrimTransfDataBase{},
            K, (float3*)primpos, (float3*)grad_primpos,
            K * 3, (float3*)primrot, (float3*)grad_primrot,
            K, (float3*)primscale, (float3*)grad_primscale});
    std::shared_ptr<PrimSamplerDataBase> primsampler_data;
    if (algorithm == 1) {
        primsampler_data = std::make_shared<PrimSamplerTW<true>::Data>(PrimSamplerTW<true>::Data{
                PrimSamplerDataBase{},
                fadescale, fadeexp,
                K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate,
                K * WD * WH * WW * 3, WD, WH, WW, warp, grad_warp});
    } else {
        primsampler_data = std::make_shared<PrimSamplerTW<false>::Data>(PrimSamplerTW<false>::Data{
                PrimSamplerDataBase{},
                fadescale, fadeexp,
                K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate,
                0, 0, 0, 0, nullptr, nullptr});
    }
    std::shared_ptr<PrimAccumDataBase> primaccum_data = std::make_shared<PrimAccumAdditive::Data>(PrimAccumAdditive::Data{
            PrimAccumDataBase{},
            termthresh, H * W, W, 1, (float4*)rayrgbaim, (float4*)grad_rayrgba, (float3*)raysatim});

    std::map<int, mapfn_t> dispatcher = {
        {0, make_cudacall(raymarch_subset_backward_kernel<true, 512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<false>, PrimAccumAdditive>)},
        {1, make_cudacall(raymarch_subset_backward_kernel<true, 512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<true>, PrimAccumAdditive>)}};

    auto iter = dispatcher.find(algorithm);
    if (iter != dispatcher.end()) {
        (iter->second)(
                gridsize, blocksize, stream,
                N, H, W, K,
                reinterpret_cast<float3 *>(rayposim),
                reinterpret_cast<float3 *>(raydirim),
                stepsize,
                reinterpret_cast<float2 *>(tminmaxim),
                reinterpret_cast<int *>(sortedobjid),
                reinterpret_cast<int2 *>(nodechildren),
                reinterpret_cast<float3 *>(nodeaabb),
                primtransf_data,
                primsampler_data,
                primaccum_data);
    }
}

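Both launchers map one CUDA thread to one pixel and one grid z-slice to one batch element, so the grid dimensions above are the standard ceiling divisions (restated here for clarity; the integer expression (W + b_x - 1) / b_x in the code is the usual idiom for the ceiling):

    \text{gridsize} = \Big(\big\lceil W / b_x \big\rceil,\ \big\lceil H / b_y \big\rceil,\ N\Big)

The algorithm index selects only whether the sampler applies the warp field: algo=0 dispatches PrimSamplerTW<false> (no warp) and algo=1 dispatches PrimSamplerTW<true> (with warp), matching the two gradcheck configurations in mvpraymarch.py above.
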
dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h
ADDED
@@ -0,0 +1,218 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

template<
    int maxhitboxes,
    int nwarps,
    class RaySubsetT=RaySubsetFixedBVH<false, 512, true, PrimTransfSRT>,
    class PrimTransfT=PrimTransfSRT,
    class PrimSamplerT=PrimSamplerTW<false>,
    class PrimAccumT=PrimAccumAdditive>
__global__ void raymarch_subset_forward_kernel(
        int N, int H, int W, int K,
        float3 * rayposim,
        float3 * raydirim,
        float stepsize,
        float2 * tminmaxim,
        int * sortedobjid,
        int2 * nodechildren,
        float3 * nodeaabb,
        typename PrimTransfT::Data primtransf_data,
        typename PrimSamplerT::Data primsampler_data,
        typename PrimAccumT::Data primaccum_data
        ) {
    int w = blockIdx.x * blockDim.x + threadIdx.x;
    int h = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.z;
    bool validthread = (w < W) && (h < H) && (n < N);

    assert(nwarps == 0 || blockDim.x * blockDim.y / 32 <= nwarps);
    const int warpid = __shfl_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32, 0);
    assert(__match_any_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32) == 0xffffffff);

    // warpmask contains the valid threads in the warp
    unsigned warpmask = 0xffffffff;
    n = min(N - 1, n);
    h = min(H - 1, h);
    w = min(W - 1, w);

    sortedobjid += n * K;
    nodechildren += n * (K + K - 1);
    nodeaabb += n * (K + K - 1) * 2;

    primtransf_data.n_stride(n);
    primsampler_data.n_stride(n);
    primaccum_data.n_stride(n, h, w);

    float3 raypos = rayposim[n * H * W + h * W + w];
    float3 raydir = raydirim[n * H * W + h * W + w];
    float2 tminmax = tminmaxim[n * H * W + h * W + w];

    int hitboxes[nwarps > 0 ? 1 : maxhitboxes];
    __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
    int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
    int nhitboxes = 0;

    // find raytminmax
    float2 rtminmax = make_float2(std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity());
    RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
            sortedobjid, nodechildren, nodeaabb,
            primtransf_data, hitboxes_ptr, nhitboxes);
    rtminmax.x = max(rtminmax.x, tminmax.x);
    rtminmax.y = min(rtminmax.y, tminmax.y);
    __syncwarp(warpmask);

    float t = tminmax.x;
    raypos = raypos + raydir * tminmax.x;

    int incs = floor((rtminmax.x - t) / stepsize);
    t += incs * stepsize;
    raypos += raydir * incs * stepsize;

    PrimAccumT pa;

    while (!__all_sync(warpmask, t > rtminmax.y + 1e-5f || pa.is_done())) {
        for (int ks = 0; ks < nhitboxes; ++ks) {
            int k = hitboxes_ptr[ks];

            // compute primitive-relative coordinate
            PrimTransfT pt;
            float3 samplepos = pt.forward(primtransf_data, k, raypos);

            if (pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f) {
                // sample
                PrimSamplerT ps;
                float4 sample = ps.forward(primsampler_data, k, samplepos);

                // accumulate
                pa.forward_prim(primaccum_data, sample, stepsize);
            }
        }

        // update position
        t += stepsize;
        raypos += raydir * stepsize;
    }

    pa.write(primaccum_data);
}

template <
    bool forwarddir,
    int maxhitboxes,
    int nwarps,
    class RaySubsetT=RaySubsetFixedBVH<false, 512, true, PrimTransfSRT>,
    class PrimTransfT=PrimTransfSRT,
    class PrimSamplerT=PrimSamplerTW<false>,
    class PrimAccumT=PrimAccumAdditive>
__global__ void raymarch_subset_backward_kernel(
        int N, int H, int W, int K,
        float3 * rayposim,
        float3 * raydirim,
        float stepsize,
        float2 * tminmaxim,
        int * sortedobjid,
        int2 * nodechildren,
        float3 * nodeaabb,
        typename PrimTransfT::Data primtransf_data,
        typename PrimSamplerT::Data primsampler_data,
        typename PrimAccumT::Data primaccum_data
        ) {
    int w = blockIdx.x * blockDim.x + threadIdx.x;
    int h = blockIdx.y * blockDim.y + threadIdx.y;
    int n = blockIdx.z;
    bool validthread = (w < W) && (h < H) && (n < N);

    assert(nwarps == 0 || blockDim.x * blockDim.y / 32 <= nwarps);
    const int warpid = __shfl_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32, 0);
    assert(__match_any_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32) == 0xffffffff);

    // warpmask contains the valid threads in the warp
    unsigned warpmask = 0xffffffff;
    n = min(N - 1, n);
    h = min(H - 1, h);
    w = min(W - 1, w);

    sortedobjid += n * K;
    nodechildren += n * (K + K - 1);
    nodeaabb += n * (K + K - 1) * 2;

    primtransf_data.n_stride(n);
    primsampler_data.n_stride(n);
    primaccum_data.n_stride(n, h, w);

    float3 raypos = rayposim[n * H * W + h * W + w];
    float3 raydir = raydirim[n * H * W + h * W + w];
    float2 tminmax = tminmaxim[n * H * W + h * W + w];

    PrimAccumT pa;
    pa.read(primaccum_data);

    int hitboxes[nwarps > 0 ? 1 : maxhitboxes];
    __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
    int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
    int nhitboxes = 0;

    // find raytminmax
    float2 rtminmax = make_float2(std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity());
    RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
            sortedobjid, nodechildren, nodeaabb,
            primtransf_data, hitboxes_ptr, nhitboxes);
    rtminmax.x = max(rtminmax.x, tminmax.x);
    rtminmax.y = min(rtminmax.y, tminmax.y);
    __syncwarp(warpmask);

    // set up raymarching position
    float t = tminmax.x;
    raypos = raypos + raydir * tminmax.x;

    int incs = floor((rtminmax.x - t) / stepsize);
    t += incs * stepsize;
    raypos += raydir * incs * stepsize;

    if (!forwarddir) {
        int nsteps = pa.get_nsteps();
        t += nsteps * stepsize;
        raypos += raydir * nsteps * stepsize;
    }

    while (__any_sync(warpmask, (
            (forwarddir && t < rtminmax.y + 1e-5f ||
            !forwarddir && t > rtminmax.x - 1e-5f) &&
            !pa.is_done()))) {
        for (int ks = 0; ks < nhitboxes; ++ks) {
            int k = hitboxes_ptr[forwarddir ? ks : nhitboxes - ks - 1];

            PrimTransfT pt;
            float3 samplepos = pt.forward(primtransf_data, k, raypos);

            bool evalprim = pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f;

            float3 dL_samplepos = make_float3(0.f);
            if (evalprim) {
                PrimSamplerT ps;
                float4 sample = ps.forward(primsampler_data, k, samplepos);

                float4 dL_sample = pa.forwardbackward_prim(primaccum_data, sample, stepsize);

                dL_samplepos = ps.backward(primsampler_data, k, samplepos, sample, dL_sample, validthread);
            }

            if (__any_sync(warpmask, evalprim)) {
                pt.backward(primtransf_data, k, samplepos, dL_samplepos, validthread && evalprim);
            }
        }

        if (forwarddir) {
            t += stepsize;
            raypos += raydir * stepsize;
        } else {
            t -= stepsize;
            raypos -= raydir * stepsize;
        }
    }
}

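In both kernels, after BVH traversal narrows the ray's active interval to [rtminmax.x, rtminmax.y], the march is fast-forwarded in whole steps rather than jumping directly to the entry point:

    t \leftarrow t + \left\lfloor \frac{t_{\mathrm{enter}} - t}{\Delta t} \right\rfloor \Delta t

which keeps every sample on the lattice t = t_{\min} + k\,\Delta t anchored at the ray's tminmax.x. This presumably matches the PyTorch reference in mvpraymarch.py, which steps t = t0 + step * stepsize from the same origin, so forward results and gradients can be compared sample-for-sample.
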
dva/mvp/extensions/mvpraymarch/primaccum.h
ADDED
@@ -0,0 +1,101 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#ifndef MVPRAYMARCHER_PRIMACCUM_H_
#define MVPRAYMARCHER_PRIMACCUM_H_

struct PrimAccumDataBase {
    typedef PrimAccumDataBase base;
};

struct PrimAccumAdditive {
    struct Data : public PrimAccumDataBase {
        float termthresh;

        int nstride, hstride, wstride;
        float4 * rayrgbaim;
        float4 * grad_rayrgbaim;
        float3 * raysatim;

        __forceinline__ __device__ void n_stride(int n, int h, int w) {
            rayrgbaim += n * nstride + h * hstride + w * wstride;
            grad_rayrgbaim += n * nstride + h * hstride + w * wstride;
            if (raysatim) {
                raysatim += n * nstride + h * hstride + w * wstride;
            }
        }
    };

    float4 rayrgba;
    float3 raysat;
    bool sat;
    float4 dL_rayrgba;

    __forceinline__ __device__ PrimAccumAdditive() :
        rayrgba(make_float4(0.f)),
        raysat(make_float3(-1.f)),
        sat(false) {
    }

    __forceinline__ __device__ bool is_done() const {
        return sat;
    }

    __forceinline__ __device__ int get_nsteps() const {
        return 0;
    }

    __forceinline__ __device__ void write(const Data & data) {
        *data.rayrgbaim = rayrgba;
        if (data.raysatim) {
            *data.raysatim = raysat;
        }
    }

    __forceinline__ __device__ void read(const Data & data) {
        dL_rayrgba = *data.grad_rayrgbaim;
        raysat = *data.raysatim;
    }

    __forceinline__ __device__ void forward_prim(const Data & data, float4 sample, float stepsize) {
        // accumulate
        float3 rgb = make_float3(sample);
        float alpha = sample.w;
        float newalpha = rayrgba.w + alpha * stepsize;
        float contrib = fminf(newalpha, 1.f) - rayrgba.w;

        rayrgba += make_float4(rgb, 1.f) * contrib;

        if (newalpha >= 1.f) {
            // save saturation point
            if (!sat) {
                raysat = rgb;
            }
            sat = true;
        }
    }

    __forceinline__ __device__ float4 forwardbackward_prim(const Data & data, float4 sample, float stepsize) {
        float3 rgb = make_float3(sample);
        float4 rgb1 = make_float4(rgb, 1.f);
        sample.w *= stepsize;

        bool thissat = rayrgba.w + sample.w >= 1.f;
        sat = sat || thissat;

        float weight = sat ? (1.f - rayrgba.w) : sample.w;

        float3 dL_rgb = weight * make_float3(dL_rayrgba);
        float dL_alpha = sat ? 0.f :
            stepsize * dot(rgb1 - (raysat.x > -1.f ? make_float4(raysat, 1.f) : make_float4(0.f)), dL_rayrgba);

        rayrgba += make_float4(rgb, 1.f) * weight;

        return make_float4(dL_rgb, dL_alpha);
    }
};

#endif

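A compact restatement of forwardbackward_prim above (this is what the code computes, written out as math): with per-sample weight w = \alpha\,\Delta t before saturation and w = 1 - A at the saturating sample,

    \frac{\partial L}{\partial c} = w\,\frac{\partial L}{\partial C_{rgb}}, \qquad
    \frac{\partial L}{\partial \alpha} =
    \begin{cases}
        \Delta t\,\big\langle (c, 1) - (c_{\mathrm{sat}}, 1),\ \frac{\partial L}{\partial C} \big\rangle & \text{before saturation} \\
        0 & \text{at or after saturation}
    \end{cases}

where (c_{\mathrm{sat}}, 1) is read from raysatim if the ray saturates further along (raysat.x > -1) and is zero otherwise: increasing \alpha at an earlier sample shifts contribution weight away from the saturation point, which is exactly the subtracted raysat term.
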
dva/mvp/extensions/mvpraymarch/primsampler.h
ADDED
@@ -0,0 +1,94 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#ifndef MVPRAYMARCHER_PRIMSAMPLER_H_
#define MVPRAYMARCHER_PRIMSAMPLER_H_

struct PrimSamplerDataBase {
    typedef PrimSamplerDataBase base;
};

template<
    bool dowarp,
    template<typename> class GridSamplerT=GridSamplerChlast>
struct PrimSamplerTW {
    struct Data : public PrimSamplerDataBase {
        float fadescale, fadeexp;

        int tplate_nstride;
        int TD, TH, TW;
        float * tplate;
        float * grad_tplate;

        int warp_nstride;
        int WD, WH, WW;
        float * warp;
        float * grad_warp;

        __forceinline__ __device__ void n_stride(int n) {
            tplate += n * tplate_nstride;
            grad_tplate += n * tplate_nstride;
            warp += n * warp_nstride;
            grad_warp += n * warp_nstride;
        }
    };

    float fade;
    float * tplate_ptr;
    float * warp_ptr;
    float3 yy1;

    __forceinline__ __device__ float4 forward(
            const Data & data,
            int k,
            float3 y0) {
        fade = __expf(-data.fadescale * (
                __powf(abs(y0.x), data.fadeexp) +
                __powf(abs(y0.y), data.fadeexp) +
                __powf(abs(y0.z), data.fadeexp)));

        if (dowarp) {
            warp_ptr = data.warp + (k * 3 * data.WD * data.WH * data.WW);
            yy1 = GridSamplerT<float3>::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false);
        } else {
            yy1 = y0;
        }

        tplate_ptr = data.tplate + (k * 4 * data.TD * data.TH * data.TW);
        float4 sample = GridSamplerT<float4>::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false);

        sample.w *= fade;

        return sample;
    }

    __forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0,
            float4 sample, float4 dL_sample, bool validthread) {
        float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3(
                __powf(abs(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f),
                __powf(abs(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f),
                __powf(abs(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f));
        float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w;

        dL_sample.w *= fade;

        float * grad_tplate_ptr = data.grad_tplate + (k * 4 * data.TD * data.TH * data.TW);
        float3 dL_y1 = GridSamplerT<float4>::backward(4, data.TD, data.TH, data.TW,
                tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false);

        if (dowarp) {
            float * grad_warp_ptr = data.grad_warp + (k * 3 * data.WD * data.WH * data.WW);
            dL_y0 += GridSamplerT<float3>::backward(3, data.WD, data.WH, data.WW,
                    warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false);
        } else {
            dL_y0 += dL_y1;
        }

        return dL_y0;
    }
};

#endif

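The border fade computed here is the same one described in the mvpraymarch docstring and used by the PyTorch reference: with s = fadescale and e = fadeexp,

    f(y) = \exp\!\Big(-s \sum_{i \in \{x,y,z\}} |y_i|^{e}\Big), \qquad
    \frac{\partial f}{\partial y_i} = -s\,e\,|y_i|^{e-1}\,\mathrm{sign}(y_i)\, f(y)

Note that backward multiplies dfade_y0 (which omits the f(y) factor) by the already-faded sample.w = \alpha f(y), so the product is \alpha\,\partial f/\partial y_i, i.e. exactly \partial(\alpha f)/\partial y_i.
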
dva/mvp/extensions/mvpraymarch/primtransf.h
ADDED
@@ -0,0 +1,182 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#ifndef MVPRAYMARCHER_PRIMTRANSF_H_
#define MVPRAYMARCHER_PRIMTRANSF_H_

#include "utils.h"

__forceinline__ __device__ void compute_aabb_srt(
        float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps,
        float3 & pmin, float3 & pmax) {
    float3 p;
    p = make_float3(-1.f, -1.f, -1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = p;
    pmax = p;

    p = make_float3(1.f, -1.f, -1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);

    p = make_float3(-1.f, 1.f, -1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);

    p = make_float3(1.f, 1.f, -1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);

    p = make_float3(-1.f, -1.f, 1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);

    p = make_float3(1.f, -1.f, 1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);

    p = make_float3(-1.f, 1.f, 1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);

    p = make_float3(1.f, 1.f, 1.f) / ps;
    p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;

    pmin = fminf(pmin, p);
    pmax = fmaxf(pmax, p);
}

struct PrimTransfDataBase {
    typedef PrimTransfDataBase base;
};

struct PrimTransfSRT {
    struct Data : public PrimTransfDataBase {
        int primpos_nstride;
        float3 * primpos;
        float3 * grad_primpos;
        int primrot_nstride;
        float3 * primrot;
        float3 * grad_primrot;
        int primscale_nstride;
        float3 * primscale;
        float3 * grad_primscale;

        __forceinline__ __device__ void n_stride(int n) {
            primpos += n * primpos_nstride;
            grad_primpos += n * primpos_nstride;
            primrot += n * primrot_nstride;
            grad_primrot += n * primrot_nstride;
            primscale += n * primscale_nstride;
            grad_primscale += n * primscale_nstride;
        }

        __forceinline__ __device__ float3 get_center(int n, int k) {
            return primpos[n * primpos_nstride + k];
        }

        __forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) {
            float3 pt = primpos[n * primpos_nstride + k];
            float3 pr0 = primrot[n * primrot_nstride + k * 3 + 0];
            float3 pr1 = primrot[n * primrot_nstride + k * 3 + 1];
            float3 pr2 = primrot[n * primrot_nstride + k * 3 + 2];
            float3 ps = primscale[n * primscale_nstride + k];

            compute_aabb_srt(pt, pr0, pr1, pr2, ps, pmin, pmax);
        }
    };

    float3 xmt;
    float3 pr0;
    float3 pr1;
    float3 pr2;
    float3 rxmt;
    float3 ps;

    static __forceinline__ __device__ bool valid(float3 pos) {
        return (
            pos.x > -1.f && pos.x < 1.f &&
            pos.y > -1.f && pos.y < 1.f &&
            pos.z > -1.f && pos.z < 1.f);
    }

    __forceinline__ __device__ float3 forward(
            const Data & data,
            int k,
            float3 x) {
        float3 pt = data.primpos[k];
        pr0 = data.primrot[(k) * 3 + 0];
        pr1 = data.primrot[(k) * 3 + 1];
        pr2 = data.primrot[(k) * 3 + 2];
        ps = data.primscale[k];
        xmt = x - pt;
        rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z;
        float3 y0 = rxmt * ps;
        return y0;
    }

    static __forceinline__ __device__ void forward2(
            const Data & data,
            int k,
            float3 r, float3 d, float3 & rout, float3 & dout) {
        float3 pt = data.primpos[k];
        float3 pr0 = data.primrot[k * 3 + 0];
        float3 pr1 = data.primrot[k * 3 + 1];
        float3 pr2 = data.primrot[k * 3 + 2];
        float3 ps = data.primscale[k];
        float3 xmt = r - pt;
        float3 dmt = d;
        float3 rxmt = pr0 * xmt.x;
        float3 rdmt = pr0 * dmt.x;
        rxmt += pr1 * xmt.y;
        rdmt += pr1 * dmt.y;
        rxmt += pr2 * xmt.z;
        rdmt += pr2 * dmt.z;
        rout = rxmt * ps;
        dout = rdmt * ps;
    }

    __forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) {
        fastAtomicAdd((float*)data.grad_primscale + k * 3 + 0, validthread ? rxmt.x * dL_y0.x : 0.f);
        fastAtomicAdd((float*)data.grad_primscale + k * 3 + 1, validthread ? rxmt.y * dL_y0.y : 0.f);
        fastAtomicAdd((float*)data.grad_primscale + k * 3 + 2, validthread ? rxmt.z * dL_y0.z : 0.f);

        dL_y0 *= ps;
        float3 gpr0 = xmt.x * dL_y0;
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 0, validthread ? gpr0.x : 0.f);
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 1, validthread ? gpr0.y : 0.f);
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 2, validthread ? gpr0.z : 0.f);

        float3 gpr1 = xmt.y * dL_y0;
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 0, validthread ? gpr1.x : 0.f);
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 1, validthread ? gpr1.y : 0.f);
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 2, validthread ? gpr1.z : 0.f);

        float3 gpr2 = xmt.z * dL_y0;
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 0, validthread ? gpr2.x : 0.f);
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 1, validthread ? gpr2.y : 0.f);
        fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 2, validthread ? gpr2.z : 0.f);

        fastAtomicAdd((float*)data.grad_primpos + k * 3 + 0, validthread ? -dot(pr0, dL_y0) : 0.f);
        fastAtomicAdd((float*)data.grad_primpos + k * 3 + 1, validthread ? -dot(pr1, dL_y0) : 0.f);
        fastAtomicAdd((float*)data.grad_primpos + k * 3 + 2, validthread ? -dot(pr2, dL_y0) : 0.f);
    }
};

#endif

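In math form, the per-primitive transform implemented by forward/backward above is (with t the primitive center, R the rotation whose rows are pr0..pr2, and s the inverse side lengths, matching the torch.bmm reference in mvpraymarch.py):

    y_0 = s \odot \big(R^{\top}(x - t)\big), \qquad
    \frac{\partial L}{\partial t} = -R\left(s \odot \frac{\partial L}{\partial y_0}\right)

and the gradients to R and s follow the same chain rule; each contribution is reduced across the warp by fastAtomicAdd (defined in utils.h below) before touching global memory.
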
dva/mvp/extensions/mvpraymarch/setup.py
ADDED
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup

from torch.utils.cpp_extension import CUDAExtension, BuildExtension

if __name__ == "__main__":
    import torch
    setup(
        name="mvpraymarch",
        ext_modules=[
            CUDAExtension(
                "mvpraymarchlib",
                sources=["mvpraymarch.cpp", "mvpraymarch_kernel.cu", "bvh.cu"],
                extra_compile_args={
                    "nvcc": [
                        "-use_fast_math",
                        "-arch=sm_70",
                        "-std=c++17",
                        "-lineinfo",
                    ]
                }
            )
        ],
        cmdclass={"build_ext": BuildExtension}
    )

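The extension is presumably built like the sibling utils extension, either via the accompanying makefile or directly with setuptools (e.g. python setup.py install, or python setup.py build_ext --inplace for an in-tree build). Note that -arch=sm_70 targets Volta-class GPUs; pre-Volta devices are not covered by this flag, and other architectures may require adjusting it.
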
dva/mvp/extensions/mvpraymarch/utils.h
ADDED
@@ -0,0 +1,847 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#ifndef MVPRAYMARCHER_UTILS_H_
#define MVPRAYMARCHER_UTILS_H_

#include <cassert>
#include <cmath>

#include <limits>

#include "helper_math.h"

static __forceinline__ __device__ float clock_diff(long long int end, long long int start) {
    long long int max_clock = std::numeric_limits<long long int>::max();
    return (end < start ? (end + float(max_clock - start)) : float(end - start));
}

static __forceinline__ __device__
bool allgt(float3 a, float3 b) {
    return a.x >= b.x && a.y >= b.y && a.z >= b.z;
}

static __forceinline__ __device__
bool alllt(float3 a, float3 b) {
    return a.x <= b.x && a.y <= b.y && a.z <= b.z;
}

static __forceinline__ __device__
float4 softplus(float4 x) {
    return make_float4(
        x.x > 20.f ? x.x : logf(1.f + expf(x.x)),
        x.y > 20.f ? x.y : logf(1.f + expf(x.y)),
        x.z > 20.f ? x.z : logf(1.f + expf(x.z)),
        x.w > 20.f ? x.w : logf(1.f + expf(x.w)));
}

static __forceinline__ __device__
float softplus(float x) {
    // that's a neat trick
    return __logf(1.f + __expf(-abs(x))) + max(x, 0.f);
}
static __forceinline__ __device__
float softplus_grad(float x) {
    // that's a neat trick
    float expnabsx = __expf(-abs(x));
    return (0.5f - expnabsx / (1.f + expnabsx)) * copysign(1.f, x) + 0.5f;
}

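The "neat trick" in the scalar softplus and softplus_grad above is the overflow-safe rewriting of both expressions: for large positive x, e^{x} overflows in single precision while e^{-|x|} never does,

    \log(1 + e^{x}) = \max(x, 0) + \log\!\big(1 + e^{-|x|}\big), \qquad
    \frac{d}{dx}\log(1 + e^{x}) = \sigma(x)
        = \Big(\tfrac{1}{2} - \tfrac{e^{-|x|}}{1 + e^{-|x|}}\Big)\mathrm{sign}(x) + \tfrac{1}{2}

both identities being easy to verify separately for x \ge 0 and x < 0.
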
static __forceinline__ __device__
float4 sigmoid(float4 x) {
    return make_float4(
        1.f / (1.f + expf(-x.x)),
        1.f / (1.f + expf(-x.y)),
        1.f / (1.f + expf(-x.z)),
        1.f / (1.f + expf(-x.w)));
}

// perform reduction on warp, then call atomicAdd for only one lane
static __forceinline__ __device__ void fastAtomicAdd(float * ptr, float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }

    const int laneid = (threadIdx.y * blockDim.x + threadIdx.x) % 32;
    if (laneid == 0) {
        atomicAdd(ptr, val);
    }
}


static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
    return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}

static __forceinline__ __device__
void safe_add_3d(float *data, int d, int h, int w,
        int sD, int sH, int sW, int D, int H, int W,
        float delta) {
    if (within_bounds_3d(d, h, w, D, H, W)) {
        atomicAdd(data + d * sD + h * sH + w * sW, delta);
    }
}

static __forceinline__ __device__
void safe_add_3d(float3 *data, int d, int h, int w,
        int sD, int sH, int sW, int D, int H, int W,
        float3 delta) {
    if (within_bounds_3d(d, h, w, D, H, W)) {
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 0, delta.x);
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 1, delta.y);
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 2, delta.z);
    }
}

static __forceinline__ __device__
void safe_add_3d(float4 *data, int d, int h, int w,
        int sD, int sH, int sW, int D, int H, int W,
        float4 delta) {
    if (within_bounds_3d(d, h, w, D, H, W)) {
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 0, delta.x);
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 1, delta.y);
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 2, delta.z);
        atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 3, delta.w);
    }
}

static __forceinline__ __device__
float clip_coordinates(float in, int clip_limit) {
    return ::min(static_cast<float>(clip_limit - 1), ::max(in, 0.f));
}

template <typename scalar_t>
static __forceinline__ __device__
float clip_coordinates_set_grad(float in, int clip_limit, scalar_t *grad_in) {
    if (in < 0.f) {
        *grad_in = static_cast<scalar_t>(0);
        return 0.f;
    } else {
        float max = static_cast<float>(clip_limit - 1);
        if (in > max) {
            *grad_in = static_cast<scalar_t>(0);
            return max;
        } else {
            *grad_in = static_cast<scalar_t>(1);
            return in;
        }
    }
}

template<typename out_t>
static __device__ out_t grid_sample_forward(int C, int inp_D, int inp_H,
        int inp_W, float* vals, float3 pos, bool border) {
    int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
    int out_sC = 1;

    // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
    float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
    float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
    float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

    if (border) {
        // clip coordinates to image borders
        ix = clip_coordinates(ix, inp_W);
        iy = clip_coordinates(iy, inp_H);
        iz = clip_coordinates(iz, inp_D);
    }

    // get corner pixel values from (x, y, z)
    // for 4d, we used north-east-south-west
    // for 5d, we add top-bottom
    int ix_tnw = static_cast<int>(::floor(ix));
    int iy_tnw = static_cast<int>(::floor(iy));
    int iz_tnw = static_cast<int>(::floor(iz));

    int ix_tne = ix_tnw + 1;
    int iy_tne = iy_tnw;
    int iz_tne = iz_tnw;

    int ix_tsw = ix_tnw;
    int iy_tsw = iy_tnw + 1;
    int iz_tsw = iz_tnw;

    int ix_tse = ix_tnw + 1;
    int iy_tse = iy_tnw + 1;
    int iz_tse = iz_tnw;

    int ix_bnw = ix_tnw;
    int iy_bnw = iy_tnw;
    int iz_bnw = iz_tnw + 1;

    int ix_bne = ix_tnw + 1;
    int iy_bne = iy_tnw;
    int iz_bne = iz_tnw + 1;

    int ix_bsw = ix_tnw;
    int iy_bsw = iy_tnw + 1;
    int iz_bsw = iz_tnw + 1;

    int ix_bse = ix_tnw + 1;
    int iy_bse = iy_tnw + 1;
    int iz_bse = iz_tnw + 1;

    // get surfaces to each neighbor:
    float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

    out_t result;
    //auto inp_ptr_NC = input.data + n * inp_sN;
    //auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
    float * inp_ptr_NC = vals;
    float * out_ptr_NCDHW = &result.x;
    for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
        // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
        // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
        // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
        // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
        *out_ptr_NCDHW = static_cast<float>(0);
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
        }
    }
    return result;
}

238 |
+
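// Note: tnw..bse above are the standard trilinear interpolation weights of
// the eight voxel corners around (ix, iy, iz); each weight is the volume of
// the sub-cell opposite its corner, so the eight weights sum to 1 for any
// sample point inside the cell. pos maps (x, y, z) onto the (W, H, D) axes,
// matching PyTorch's grid_sample convention.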
template<typename out_t>
static __device__ float3 grid_sample_backward(int C, int inp_D, int inp_H,
        int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
        bool border) {
    int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
    int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H, gInp_sC = inp_W * inp_H * inp_D;
    int gOut_sC = 1;

    // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
    float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
    float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
    float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);

    float gix_mult = (inp_W - 1.f) / 2;
    float giy_mult = (inp_H - 1.f) / 2;
    float giz_mult = (inp_D - 1.f) / 2;

    if (border) {
        // clip coordinates to image borders
        ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
        iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
        iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
    }

    // get corner pixel values from (x, y, z)
    // for 4d, we used north-east-south-west
    // for 5d, we add top-bottom
    int ix_tnw = static_cast<int>(::floor(ix));
    int iy_tnw = static_cast<int>(::floor(iy));
    int iz_tnw = static_cast<int>(::floor(iz));

    int ix_tne = ix_tnw + 1;
    int iy_tne = iy_tnw;
    int iz_tne = iz_tnw;

    int ix_tsw = ix_tnw;
    int iy_tsw = iy_tnw + 1;
    int iz_tsw = iz_tnw;

    int ix_tse = ix_tnw + 1;
    int iy_tse = iy_tnw + 1;
    int iz_tse = iz_tnw;

    int ix_bnw = ix_tnw;
    int iy_bnw = iy_tnw;
    int iz_bnw = iz_tnw + 1;

    int ix_bne = ix_tnw + 1;
    int iy_bne = iy_tnw;
    int iz_bne = iz_tnw + 1;

    int ix_bsw = ix_tnw;
    int iy_bsw = iy_tnw + 1;
    int iz_bsw = iz_tnw + 1;

    int ix_bse = ix_tnw + 1;
    int iy_bse = iy_tnw + 1;
    int iz_bse = iz_tnw + 1;

    // get surfaces to each neighbor:
    float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
    float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
    float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
    float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
    float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
    float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
    float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
    float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

    float gix = static_cast<float>(0), giy = static_cast<float>(0), giz = static_cast<float>(0);
    //float *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
    //float *gInp_ptr_NC = grad_input.data + n * gInp_sN;
    //float *inp_ptr_NC = input.data + n * inp_sN;
    float *gOut_ptr_NCDHW = &grad_out.x;
    float *gInp_ptr_NC = grad_vals;
    float *inp_ptr_NC = vals;
    // calculate bilinear weighted pixel value and set output pixel
    for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
        float gOut = *gOut_ptr_NCDHW;

        // calculate and set grad_input
        safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
        safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);

        // calculate grad_grid
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
            float tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
            gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
            giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
            giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
            float tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
            gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
            giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
            giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
            float tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
            gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
            giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
            giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
            float tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
            gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
            giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
            giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
            float bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
            gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
            giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
            giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
            float bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
            gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
            giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
            giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
            float bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
            gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
            giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
            giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
            float bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
            gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
            giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
            giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
        }
    }

    return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
}

// this dummy struct necessary because c++ is dumb
template<typename out_t>
struct GridSampler {
    static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W,
            float* vals, float3 pos, bool border) {
        return grid_sample_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
    }

    static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W,
            float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
        return grid_sample_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
    }
};

|
397 |
+
//__device__ void cswap ( T& a, T& b ) {
|
398 |
+
// T c(a); a=b; b=c;
|
399 |
+
//}
|
400 |
+
|
401 |
+
static __forceinline__ __device__
|
402 |
+
int within_bounds_3d_ind(int d, int h, int w, int D, int H, int W) {
|
403 |
+
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W ? ((d * H) + h) * W + w : -1;
|
404 |
+
}
|
405 |
+
|
406 |
+
template<class out_t>
|
407 |
+
static __device__ out_t grid_sample_chlast_forward(int, int inp_D, int inp_H,
|
408 |
+
int inp_W, float * vals, float3 pos, bool border) {
|
409 |
+
int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
|
410 |
+
|
411 |
+
// normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
|
412 |
+
float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
|
413 |
+
float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
|
414 |
+
float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);
|
415 |
+
|
416 |
+
if (border) {
|
417 |
+
// clip coordinates to image borders
|
418 |
+
ix = clip_coordinates(ix, inp_W);
|
419 |
+
iy = clip_coordinates(iy, inp_H);
|
420 |
+
iz = clip_coordinates(iz, inp_D);
|
421 |
+
}
|
422 |
+
|
423 |
+
// get corner pixel values from (x, y, z)
|
424 |
+
// for 4d, we used north-east-south-west
|
425 |
+
// for 5d, we add top-bottom
|
426 |
+
int ix_tnw = static_cast<int>(::floor(ix));
|
427 |
+
int iy_tnw = static_cast<int>(::floor(iy));
|
428 |
+
int iz_tnw = static_cast<int>(::floor(iz));
|
429 |
+
|
430 |
+
int ix_tne = ix_tnw + 1;
|
431 |
+
int iy_tne = iy_tnw;
|
432 |
+
int iz_tne = iz_tnw;
|
433 |
+
|
434 |
+
int ix_tsw = ix_tnw;
|
435 |
+
int iy_tsw = iy_tnw + 1;
|
436 |
+
int iz_tsw = iz_tnw;
|
437 |
+
|
438 |
+
int ix_tse = ix_tnw + 1;
|
439 |
+
int iy_tse = iy_tnw + 1;
|
440 |
+
int iz_tse = iz_tnw;
|
441 |
+
|
442 |
+
int ix_bnw = ix_tnw;
|
443 |
+
int iy_bnw = iy_tnw;
|
444 |
+
int iz_bnw = iz_tnw + 1;
|
445 |
+
|
446 |
+
int ix_bne = ix_tnw + 1;
|
447 |
+
int iy_bne = iy_tnw;
|
448 |
+
int iz_bne = iz_tnw + 1;
|
449 |
+
|
450 |
+
int ix_bsw = ix_tnw;
|
451 |
+
int iy_bsw = iy_tnw + 1;
|
452 |
+
int iz_bsw = iz_tnw + 1;
|
453 |
+
|
454 |
+
int ix_bse = ix_tnw + 1;
|
455 |
+
int iy_bse = iy_tnw + 1;
|
456 |
+
int iz_bse = iz_tnw + 1;
|
457 |
+
|
458 |
+
// get surfaces to each neighbor:
|
459 |
+
float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
460 |
+
float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
461 |
+
float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
462 |
+
float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
463 |
+
float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
464 |
+
float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
465 |
+
float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
466 |
+
float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
467 |
+
|
468 |
+
out_t result;
|
469 |
+
memset(&result, 0, sizeof(out_t));
|
470 |
+
out_t * inp_ptr_NC = (out_t*)vals;
|
471 |
+
out_t * out_ptr_NCDHW = &result;
|
472 |
+
{
|
473 |
+
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
474 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
|
475 |
+
}
|
476 |
+
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
477 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
|
478 |
+
}
|
479 |
+
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
480 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
|
481 |
+
}
|
482 |
+
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
483 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
|
484 |
+
}
|
485 |
+
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
486 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
|
487 |
+
}
|
488 |
+
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
489 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
|
490 |
+
}
|
491 |
+
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
492 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
|
493 |
+
}
|
494 |
+
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
495 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
|
496 |
+
}
|
497 |
+
}
|
498 |
+
|
499 |
+
return result;
|
500 |
+
}
|
501 |
+
|
502 |
+
template<typename out_t>
|
503 |
+
static __device__ float3 grid_sample_chlast_backward(int, int inp_D, int inp_H,
|
504 |
+
int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
|
505 |
+
bool border) {
|
506 |
+
int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
|
507 |
+
int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H;
|
508 |
+
|
509 |
+
// normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
|
510 |
+
float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
|
511 |
+
float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
|
512 |
+
float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);
|
513 |
+
|
514 |
+
float gix_mult = (inp_W - 1.f) / 2;
|
515 |
+
float giy_mult = (inp_H - 1.f) / 2;
|
516 |
+
float giz_mult = (inp_D - 1.f) / 2;
|
517 |
+
|
518 |
+
if (border) {
|
519 |
+
// clip coordinates to image borders
|
520 |
+
ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
|
521 |
+
iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
|
522 |
+
iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
|
523 |
+
}
|
524 |
+
|
525 |
+
// get corner pixel values from (x, y, z)
|
526 |
+
// for 4d, we used north-east-south-west
|
527 |
+
// for 5d, we add top-bottom
|
528 |
+
int ix_tnw = static_cast<int>(::floor(ix));
|
529 |
+
int iy_tnw = static_cast<int>(::floor(iy));
|
530 |
+
int iz_tnw = static_cast<int>(::floor(iz));
|
531 |
+
|
532 |
+
int ix_tne = ix_tnw + 1;
|
533 |
+
int iy_tne = iy_tnw;
|
534 |
+
int iz_tne = iz_tnw;
|
535 |
+
|
536 |
+
int ix_tsw = ix_tnw;
|
537 |
+
int iy_tsw = iy_tnw + 1;
|
538 |
+
int iz_tsw = iz_tnw;
|
539 |
+
|
540 |
+
int ix_tse = ix_tnw + 1;
|
541 |
+
int iy_tse = iy_tnw + 1;
|
542 |
+
int iz_tse = iz_tnw;
|
543 |
+
|
544 |
+
int ix_bnw = ix_tnw;
|
545 |
+
int iy_bnw = iy_tnw;
|
546 |
+
int iz_bnw = iz_tnw + 1;
|
547 |
+
|
548 |
+
int ix_bne = ix_tnw + 1;
|
549 |
+
int iy_bne = iy_tnw;
|
550 |
+
int iz_bne = iz_tnw + 1;
|
551 |
+
|
552 |
+
int ix_bsw = ix_tnw;
|
553 |
+
int iy_bsw = iy_tnw + 1;
|
554 |
+
int iz_bsw = iz_tnw + 1;
|
555 |
+
|
556 |
+
int ix_bse = ix_tnw + 1;
|
557 |
+
int iy_bse = iy_tnw + 1;
|
558 |
+
int iz_bse = iz_tnw + 1;
|
559 |
+
|
560 |
+
// get surfaces to each neighbor:
|
561 |
+
float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
562 |
+
float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
563 |
+
float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
564 |
+
float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
565 |
+
float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
566 |
+
float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
567 |
+
float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
568 |
+
float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
569 |
+
|
570 |
+
float gix = static_cast<float>(0), giy = static_cast<float>(0), giz = static_cast<float>(0);
|
571 |
+
out_t *gOut_ptr_NCDHW = &grad_out;
|
572 |
+
out_t *gInp_ptr_NC = (out_t*)grad_vals;
|
573 |
+
out_t *inp_ptr_NC = (out_t*)vals;
|
574 |
+
|
575 |
+
// calculate bilinear weighted pixel value and set output pixel
|
576 |
+
{
|
577 |
+
out_t gOut = *gOut_ptr_NCDHW;
|
578 |
+
|
579 |
+
// calculate and set grad_input
|
580 |
+
safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
|
581 |
+
safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
|
582 |
+
safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
|
583 |
+
safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
|
584 |
+
safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
|
585 |
+
safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
|
586 |
+
safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
|
587 |
+
safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
|
588 |
+
|
589 |
+
// calculate grad_grid
|
590 |
+
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
591 |
+
out_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
|
592 |
+
gix -= (iy_bse - iy) * (iz_bse - iz) * dot(tnw_val, gOut);
|
593 |
+
giy -= (ix_bse - ix) * (iz_bse - iz) * dot(tnw_val, gOut);
|
594 |
+
giz -= (ix_bse - ix) * (iy_bse - iy) * dot(tnw_val, gOut);
|
595 |
+
}
|
596 |
+
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
597 |
+
out_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
|
598 |
+
gix += (iy_bsw - iy) * (iz_bsw - iz) * dot(tne_val, gOut);
|
599 |
+
giy -= (ix - ix_bsw) * (iz_bsw - iz) * dot(tne_val, gOut);
|
600 |
+
giz -= (ix - ix_bsw) * (iy_bsw - iy) * dot(tne_val, gOut);
|
601 |
+
}
|
602 |
+
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
603 |
+
out_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
|
604 |
+
gix -= (iy - iy_bne) * (iz_bne - iz) * dot(tsw_val, gOut);
|
605 |
+
giy += (ix_bne - ix) * (iz_bne - iz) * dot(tsw_val, gOut);
|
606 |
+
giz -= (ix_bne - ix) * (iy - iy_bne) * dot(tsw_val, gOut);
|
607 |
+
}
|
608 |
+
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
609 |
+
out_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
|
610 |
+
gix += (iy - iy_bnw) * (iz_bnw - iz) * dot(tse_val, gOut);
|
611 |
+
giy += (ix - ix_bnw) * (iz_bnw - iz) * dot(tse_val, gOut);
|
612 |
+
giz -= (ix - ix_bnw) * (iy - iy_bnw) * dot(tse_val, gOut);
|
613 |
+
}
|
614 |
+
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
615 |
+
out_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
|
616 |
+
gix -= (iy_tse - iy) * (iz - iz_tse) * dot(bnw_val, gOut);
|
617 |
+
giy -= (ix_tse - ix) * (iz - iz_tse) * dot(bnw_val, gOut);
|
618 |
+
giz += (ix_tse - ix) * (iy_tse - iy) * dot(bnw_val, gOut);
|
619 |
+
}
|
620 |
+
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
621 |
+
out_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
|
622 |
+
gix += (iy_tsw - iy) * (iz - iz_tsw) * dot(bne_val, gOut);
|
623 |
+
giy -= (ix - ix_tsw) * (iz - iz_tsw) * dot(bne_val, gOut);
|
624 |
+
giz += (ix - ix_tsw) * (iy_tsw - iy) * dot(bne_val, gOut);
|
625 |
+
}
|
626 |
+
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
627 |
+
out_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
|
628 |
+
gix -= (iy - iy_tne) * (iz - iz_tne) * dot(bsw_val, gOut);
|
629 |
+
giy += (ix_tne - ix) * (iz - iz_tne) * dot(bsw_val, gOut);
|
630 |
+
giz += (ix_tne - ix) * (iy - iy_tne) * dot(bsw_val, gOut);
|
631 |
+
}
|
632 |
+
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
633 |
+
out_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
|
634 |
+
gix += (iy - iy_tnw) * (iz - iz_tnw) * dot(bse_val, gOut);
|
635 |
+
giy += (ix - ix_tnw) * (iz - iz_tnw) * dot(bse_val, gOut);
|
636 |
+
giz += (ix - ix_tnw) * (iy - iy_tnw) * dot(bse_val, gOut);
|
637 |
+
}
|
638 |
+
}
|
639 |
+
|
640 |
+
return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
|
641 |
+
}
|
642 |
+
|
643 |
+
template<typename out_t>
|
644 |
+
struct GridSamplerChlast {
|
645 |
+
static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W,
|
646 |
+
float* vals, float3 pos, bool border) {
|
647 |
+
return grid_sample_chlast_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
|
648 |
+
}
|
649 |
+
|
650 |
+
static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W,
|
651 |
+
float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
|
652 |
+
return grid_sample_chlast_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
|
653 |
+
}
|
654 |
+
};
|
655 |
+
|
656 |
+
|
657 |
+
inline __host__ __device__ float min_component(float3 a) {
|
658 |
+
return fminf(fminf(a.x,a.y),a.z);
|
659 |
+
}
|
660 |
+
|
661 |
+
inline __host__ __device__ float max_component(float3 a) {
|
662 |
+
return fmaxf(fmaxf(a.x,a.y),a.z);
|
663 |
+
}
|
664 |
+
|
665 |
+
inline __host__ __device__ float3 abs(float3 a) {
|
666 |
+
return make_float3(abs(a.x), abs(a.y), abs(a.z));
|
667 |
+
}
|
668 |
+
|
669 |
+
__forceinline__ __device__ bool ray_aabb_hit(float3 p0, float3 p1, float3 raypos, float3 raydir) {
|
670 |
+
float3 t0 = (p0 - raypos) / raydir;
|
671 |
+
float3 t1 = (p1 - raypos) / raydir;
|
672 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
673 |
+
|
674 |
+
return max_component(tmin) <= min_component(tmax);
|
675 |
+
}
|
676 |
+
|
677 |
+
__forceinline__ __device__ bool ray_aabb_hit_ird(float3 p0, float3 p1, float3 raypos, float3 ird) {
|
678 |
+
float3 t0 = (p0 - raypos) * ird;
|
679 |
+
float3 t1 = (p1 - raypos) * ird;
|
680 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
681 |
+
|
682 |
+
return max_component(tmin) <= min_component(tmax);
|
683 |
+
|
684 |
+
}
|
685 |
+
__forceinline__ __device__ void ray_aabb_hit_ird_tminmax(float3 p0, float3 p1,
|
686 |
+
float3 raypos, float3 ird, float &otmin, float &otmax) {
|
687 |
+
float3 t0 = (p0 - raypos) * ird;
|
688 |
+
float3 t1 = (p1 - raypos) * ird;
|
689 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
690 |
+
tmin = fminf(t0,t1);
|
691 |
+
tmax = fmaxf(t0,t1);
|
692 |
+
otmin = max_component(tmin);
|
693 |
+
otmax = min_component(tmax);
|
694 |
+
}
|
695 |
+
|
696 |
+
inline __device__ bool aabb_intersect(float3 p0, float3 p1, float3 r0, float3 rd, float &tmin, float &tmax) {
|
697 |
+
float tymin, tymax, tzmin, tzmax;
|
698 |
+
const float3 bounds[2] = {p0, p1};
|
699 |
+
float3 ird = 1.0f/rd;
|
700 |
+
int sx = (ird.x<0) ? 1 : 0;
|
701 |
+
int sy = (ird.y<0) ? 1 : 0;
|
702 |
+
int sz = (ird.z<0) ? 1 : 0;
|
703 |
+
tmin = (bounds[sx].x - r0.x) * ird.x;
|
704 |
+
tmax = (bounds[1-sx].x - r0.x) * ird.x;
|
705 |
+
tymin = (bounds[sy].y - r0.y) * ird.y;
|
706 |
+
tymax = (bounds[1-sy].y - r0.y) * ird.y;
|
707 |
+
|
708 |
+
if ((tmin > tymax) || (tymin > tmax))
|
709 |
+
return false;
|
710 |
+
if (tymin > tmin)
|
711 |
+
tmin = tymin;
|
712 |
+
if (tymax < tmax)
|
713 |
+
tmax = tymax;
|
714 |
+
|
715 |
+
tzmin = (bounds[sz].z - r0.z) * ird.z;
|
716 |
+
tzmax = (bounds[1-sz].z - r0.z) * ird.z;
|
717 |
+
|
718 |
+
if ((tmin > tzmax) || (tzmin > tmax))
|
719 |
+
return false;
|
720 |
+
if (tzmin > tmin)
|
721 |
+
tmin = tzmin;
|
722 |
+
if (tzmax < tmax)
|
723 |
+
tmax = tzmax;
|
724 |
+
|
725 |
+
return true;
|
726 |
+
}
|
727 |
+
|
728 |
+
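// ray_subset_fixedbvh below walks an implicit (pointer-free) binary BVH: node
// i has children 2i+1 and 2i+2, and node indices >= K-1 are leaves, so leaf j
// maps to primitive k = j - (K - 1). Traversal is iterative with a small
// local stack. When sync is set, the warp votes with __any_sync and descends
// together so the threads stay in lockstep; when sortboxes is set, hit
// primitives are kept sorted by index with an insertion sort into hitboxes.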
template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
static __forceinline__ __device__ void ray_subset_fixedbvh(
        unsigned warpmask,
        int K,
        float3 raypos,
        float3 raydir,
        float2 tminmax,
        float2 &rtminmax,
        int * sortedobjid,
        int2 * nodechildren,
        float3 * nodeaabb,
        const typename PrimTransfT::Data & primtransf_data,
        int *hitboxes,
        int & num) {
    float3 iraydir = 1.0f / raydir;
    int stack[64];
    int* stack_ptr = stack;
    *stack_ptr++ = -1;
    int node = 0;
    do {
        // check if we're in a leaf
        if (node >= (K - 1)) {
            {
                int k = node - (K - 1);

                float3 r0, rd;
                PrimTransfT::forward2(primtransf_data, k, raypos, raydir, r0, rd);

                float3 ird = 1.0f / rd;
                float3 t0 = (-1.f - r0) * ird;
                float3 t1 = (1.f - r0) * ird;
                float3 tmin = fminf(t0, t1), tmax = fmaxf(t0, t1);

                float trmin = max_component(tmin);
                float trmax = min_component(tmax);

                bool intersection = trmin <= trmax;

                if (intersection) {
                    // hit
                    rtminmax.x = fminf(rtminmax.x, trmin);
                    rtminmax.y = fmaxf(rtminmax.y, trmax);
                }

                if (sync) {
                    intersection = __any_sync(warpmask, intersection);
                }

                if (intersection) {
                    if (sortboxes) {
                        if (num < maxhitboxes) {
                            int j = num - 1;
                            while (j >= 0 && hitboxes[j] > k) {
                                hitboxes[j + 1] = hitboxes[j];
                                j = j - 1;
                            }
                            hitboxes[j + 1] = k;
                            num++;
                        }
                    } else {
                        if (num < maxhitboxes) {
                            hitboxes[num++] = k;
                        }
                    }
                }
            }

            node = *--stack_ptr;
        } else {
            int2 children = make_int2(node * 2 + 1, node * 2 + 2);

            // check if we're in each child's bbox
            float3 * nodeaabb_ptr = nodeaabb + children.x * 2;
            bool traverse_l = ray_aabb_hit_ird(nodeaabb_ptr[0], nodeaabb_ptr[1], raypos, iraydir);
            bool traverse_r = ray_aabb_hit_ird(nodeaabb_ptr[2], nodeaabb_ptr[3], raypos, iraydir);

            if (sync) {
                traverse_l = __any_sync(warpmask, traverse_l);
                traverse_r = __any_sync(warpmask, traverse_r);
            }

            // update stack
            if (!traverse_l && !traverse_r) {
                node = *--stack_ptr;
            } else {
                node = traverse_l ? children.x : children.y;
                if (traverse_l && traverse_r) {
                    *stack_ptr++ = children.y;
                }
            }

            if (sync) {
                __syncwarp(warpmask);
            }
        }
    } while (node != -1);
}

template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
struct RaySubsetFixedBVH {
    static __forceinline__ __device__ void forward(
            unsigned warpmask,
            int K,
            float3 raypos,
            float3 raydir,
            float2 tminmax,
            float2 &rtminmax,
            int * sortedobjid,
            int2 * nodechildren,
            float3 * nodeaabb,
            const typename PrimTransfT::Data & primtransf_data,
            int *hitboxes,
            int & num) {
        ray_subset_fixedbvh<sortboxes, maxhitboxes, sync, PrimTransfT>(
                warpmask, K, raypos, raydir, tminmax, rtminmax,
                sortedobjid, nodechildren, nodeaabb, primtransf_data, hitboxes, num);
    }
};

#endif
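As a quick sanity check of the trilinear weighting used by grid_sample_forward above, here is a minimal host-side C++ sketch (not part of the commit; the grid extents and the sample point are made-up values) that reproduces the [-1, 1] -> [0, extent-1] normalization and verifies that the eight corner weights partition the cell:

#include <cmath>
#include <cstdio>

int main() {
    const int W = 8, H = 6, D = 4;                  // assumed grid extents
    const float px = 0.3f, py = -0.5f, pz = 0.1f;   // normalized position in [-1, 1]

    // same normalization as the kernel: [-1, 1] -> [0, extent - 1]
    const float ix = (px + 1.f) * 0.5f * (W - 1);
    const float iy = (py + 1.f) * 0.5f * (H - 1);
    const float iz = (pz + 1.f) * 0.5f * (D - 1);

    // fractional offsets inside the cell whose "tnw" corner is floor(i*)
    const float fx = ix - std::floor(ix);
    const float fy = iy - std::floor(iy);
    const float fz = iz - std::floor(iz);

    // eight corner weights, in the kernel's tnw..bse order
    const float w[8] = {
        (1 - fx) * (1 - fy) * (1 - fz), fx * (1 - fy) * (1 - fz),
        (1 - fx) * fy * (1 - fz),       fx * fy * (1 - fz),
        (1 - fx) * (1 - fy) * fz,       fx * (1 - fy) * fz,
        (1 - fx) * fy * fz,             fx * fy * fz,
    };
    float sum = 0.f;
    for (float wi : w) sum += wi;
    std::printf("weight sum = %f (expect 1.0)\n", sum);  // the weights partition the cell
    return 0;
}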
dva/mvp/extensions/utils/helper_math.h
ADDED
@@ -0,0 +1,1453 @@
/**
 * Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
 *
 * Please refer to the NVIDIA end user license agreement (EULA) associated
 * with this source code for terms and conditions that govern your use of
 * this software. Any use, reproduction, disclosure, or distribution of
 * this software and related documentation outside the terms of the EULA
 * is strictly prohibited.
 *
 */

/*
 * This file implements common mathematical operations on vector types
 * (float3, float4 etc.) since these are not provided as standard by CUDA.
 *
 * The syntax is modeled on the Cg standard library.
 *
 * This is part of the Helper library includes
 *
 * Thanks to Linh Hah for additions and fixes.
 */

#ifndef HELPER_MATH_H
#define HELPER_MATH_H

#include "cuda_runtime.h"

typedef unsigned int uint;
typedef unsigned short ushort;

#ifndef EXIT_WAIVED
#define EXIT_WAIVED 2
#endif

#ifndef __CUDACC__
#include <math.h>

////////////////////////////////////////////////////////////////////////////////
// host implementations of CUDA functions
////////////////////////////////////////////////////////////////////////////////

inline float fminf(float a, float b)
{
    return a < b ? a : b;
}

inline float fmaxf(float a, float b)
{
    return a > b ? a : b;
}

inline int max(int a, int b)
{
    return a > b ? a : b;
}

inline int min(int a, int b)
{
    return a < b ? a : b;
}

inline float rsqrtf(float x)
{
    return 1.0f / sqrtf(x);
}
#endif

////////////////////////////////////////////////////////////////////////////////
// constructors
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 make_float2(float s)
{
    return make_float2(s, s);
}
inline __host__ __device__ float2 make_float2(float3 a)
{
    return make_float2(a.x, a.y);
}
inline __host__ __device__ float2 make_float2(int2 a)
{
    return make_float2(float(a.x), float(a.y));
}
inline __host__ __device__ float2 make_float2(uint2 a)
{
    return make_float2(float(a.x), float(a.y));
}

inline __host__ __device__ int2 make_int2(int s)
{
    return make_int2(s, s);
}
inline __host__ __device__ int2 make_int2(int3 a)
{
    return make_int2(a.x, a.y);
}
inline __host__ __device__ int2 make_int2(uint2 a)
{
    return make_int2(int(a.x), int(a.y));
}
inline __host__ __device__ int2 make_int2(float2 a)
{
    return make_int2(int(a.x), int(a.y));
}

inline __host__ __device__ uint2 make_uint2(uint s)
{
    return make_uint2(s, s);
}
inline __host__ __device__ uint2 make_uint2(uint3 a)
{
    return make_uint2(a.x, a.y);
}
inline __host__ __device__ uint2 make_uint2(int2 a)
{
    return make_uint2(uint(a.x), uint(a.y));
}

inline __host__ __device__ float3 make_float3(float s)
{
    return make_float3(s, s, s);
}
inline __host__ __device__ float3 make_float3(float2 a)
{
    return make_float3(a.x, a.y, 0.0f);
}
inline __host__ __device__ float3 make_float3(float2 a, float s)
{
    return make_float3(a.x, a.y, s);
}
inline __host__ __device__ float3 make_float3(float4 a)
{
    return make_float3(a.x, a.y, a.z);
}
inline __host__ __device__ float3 make_float3(int3 a)
{
    return make_float3(float(a.x), float(a.y), float(a.z));
}
inline __host__ __device__ float3 make_float3(uint3 a)
{
    return make_float3(float(a.x), float(a.y), float(a.z));
}

inline __host__ __device__ int3 make_int3(int s)
{
    return make_int3(s, s, s);
}
inline __host__ __device__ int3 make_int3(int2 a)
{
    return make_int3(a.x, a.y, 0);
}
inline __host__ __device__ int3 make_int3(int2 a, int s)
{
    return make_int3(a.x, a.y, s);
}
inline __host__ __device__ int3 make_int3(uint3 a)
{
    return make_int3(int(a.x), int(a.y), int(a.z));
}
inline __host__ __device__ int3 make_int3(float3 a)
{
    return make_int3(int(a.x), int(a.y), int(a.z));
}

inline __host__ __device__ uint3 make_uint3(uint s)
{
    return make_uint3(s, s, s);
}
inline __host__ __device__ uint3 make_uint3(uint2 a)
{
    return make_uint3(a.x, a.y, 0);
}
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
{
    return make_uint3(a.x, a.y, s);
}
inline __host__ __device__ uint3 make_uint3(uint4 a)
{
    return make_uint3(a.x, a.y, a.z);
}
inline __host__ __device__ uint3 make_uint3(int3 a)
{
    return make_uint3(uint(a.x), uint(a.y), uint(a.z));
}

inline __host__ __device__ float4 make_float4(float s)
{
    return make_float4(s, s, s, s);
}
inline __host__ __device__ float4 make_float4(float3 a)
{
    return make_float4(a.x, a.y, a.z, 0.0f);
}
inline __host__ __device__ float4 make_float4(float3 a, float w)
{
    return make_float4(a.x, a.y, a.z, w);
}
inline __host__ __device__ float4 make_float4(int4 a)
{
    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}
inline __host__ __device__ float4 make_float4(uint4 a)
{
    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}

inline __host__ __device__ int4 make_int4(int s)
{
    return make_int4(s, s, s, s);
}
inline __host__ __device__ int4 make_int4(int3 a)
{
    return make_int4(a.x, a.y, a.z, 0);
}
inline __host__ __device__ int4 make_int4(int3 a, int w)
{
    return make_int4(a.x, a.y, a.z, w);
}
inline __host__ __device__ int4 make_int4(uint4 a)
{
    return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}
inline __host__ __device__ int4 make_int4(float4 a)
{
    return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}


inline __host__ __device__ uint4 make_uint4(uint s)
{
    return make_uint4(s, s, s, s);
}
inline __host__ __device__ uint4 make_uint4(uint3 a)
{
    return make_uint4(a.x, a.y, a.z, 0);
}
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
{
    return make_uint4(a.x, a.y, a.z, w);
}
inline __host__ __device__ uint4 make_uint4(int4 a)
{
    return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
}

////////////////////////////////////////////////////////////////////////////////
// negate
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator-(float2 &a)
{
    return make_float2(-a.x, -a.y);
}
inline __host__ __device__ int2 operator-(int2 &a)
{
    return make_int2(-a.x, -a.y);
}
inline __host__ __device__ float3 operator-(float3 &a)
{
    return make_float3(-a.x, -a.y, -a.z);
}
inline __host__ __device__ int3 operator-(int3 &a)
{
    return make_int3(-a.x, -a.y, -a.z);
}
inline __host__ __device__ float4 operator-(float4 &a)
{
    return make_float4(-a.x, -a.y, -a.z, -a.w);
}
inline __host__ __device__ int4 operator-(int4 &a)
{
    return make_int4(-a.x, -a.y, -a.z, -a.w);
}

////////////////////////////////////////////////////////////////////////////////
// addition
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator+(float2 a, float2 b)
{
    return make_float2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(float2 &a, float2 b)
{
    a.x += b.x;
    a.y += b.y;
}
inline __host__ __device__ float2 operator+(float2 a, float b)
{
    return make_float2(a.x + b, a.y + b);
}
inline __host__ __device__ float2 operator+(float b, float2 a)
{
    return make_float2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(float2 &a, float b)
{
    a.x += b;
    a.y += b;
}

inline __host__ __device__ int2 operator+(int2 a, int2 b)
{
    return make_int2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(int2 &a, int2 b)
{
    a.x += b.x;
    a.y += b.y;
}
inline __host__ __device__ int2 operator+(int2 a, int b)
{
    return make_int2(a.x + b, a.y + b);
}
inline __host__ __device__ int2 operator+(int b, int2 a)
{
    return make_int2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(int2 &a, int b)
{
    a.x += b;
    a.y += b;
}

inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
{
    return make_uint2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
{
    a.x += b.x;
    a.y += b.y;
}
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
{
    return make_uint2(a.x + b, a.y + b);
}
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
{
    return make_uint2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(uint2 &a, uint b)
{
    a.x += b;
    a.y += b;
}


inline __host__ __device__ float3 operator+(float3 a, float3 b)
{
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(float3 &a, float3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
inline __host__ __device__ float3 operator+(float3 a, float b)
{
    return make_float3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(float3 &a, float b)
{
    a.x += b;
    a.y += b;
    a.z += b;
}

inline __host__ __device__ int3 operator+(int3 a, int3 b)
{
    return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(int3 &a, int3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
inline __host__ __device__ int3 operator+(int3 a, int b)
{
    return make_int3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(int3 &a, int b)
{
    a.x += b;
    a.y += b;
    a.z += b;
}

inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
{
    return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
{
    return make_uint3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(uint3 &a, uint b)
{
    a.x += b;
    a.y += b;
    a.z += b;
}

inline __host__ __device__ int3 operator+(int b, int3 a)
{
    return make_int3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
{
    return make_uint3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ float3 operator+(float b, float3 a)
{
    return make_float3(a.x + b, a.y + b, a.z + b);
}

inline __host__ __device__ float4 operator+(float4 a, float4 b)
{
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(float4 &a, float4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
inline __host__ __device__ float4 operator+(float4 a, float b)
{
    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ float4 operator+(float b, float4 a)
{
    return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(float4 &a, float b)
{
    a.x += b;
    a.y += b;
    a.z += b;
    a.w += b;
}

inline __host__ __device__ int4 operator+(int4 a, int4 b)
{
    return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(int4 &a, int4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
inline __host__ __device__ int4 operator+(int4 a, int b)
{
    return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ int4 operator+(int b, int4 a)
{
    return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(int4 &a, int b)
{
    a.x += b;
    a.y += b;
    a.z += b;
    a.w += b;
}

inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
{
    return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
{
    return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
{
    return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(uint4 &a, uint b)
{
    a.x += b;
    a.y += b;
    a.z += b;
    a.w += b;
}

////////////////////////////////////////////////////////////////////////////////
// subtract
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 operator-(float2 a, float2 b)
{
    return make_float2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(float2 &a, float2 b)
{
    a.x -= b.x;
    a.y -= b.y;
}
inline __host__ __device__ float2 operator-(float2 a, float b)
{
    return make_float2(a.x - b, a.y - b);
}
inline __host__ __device__ float2 operator-(float b, float2 a)
{
    return make_float2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(float2 &a, float b)
{
    a.x -= b;
    a.y -= b;
}

inline __host__ __device__ int2 operator-(int2 a, int2 b)
{
    return make_int2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(int2 &a, int2 b)
{
    a.x -= b.x;
    a.y -= b.y;
}
inline __host__ __device__ int2 operator-(int2 a, int b)
{
    return make_int2(a.x - b, a.y - b);
}
inline __host__ __device__ int2 operator-(int b, int2 a)
{
    return make_int2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(int2 &a, int b)
{
    a.x -= b;
    a.y -= b;
}

inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
{
    return make_uint2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
{
    a.x -= b.x;
    a.y -= b.y;
}
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
{
    return make_uint2(a.x - b, a.y - b);
}
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
{
    return make_uint2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(uint2 &a, uint b)
{
    a.x -= b;
    a.y -= b;
}

inline __host__ __device__ float3 operator-(float3 a, float3 b)
{
    return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(float3 &a, float3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
inline __host__ __device__ float3 operator-(float3 a, float b)
{
    return make_float3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ float3 operator-(float b, float3 a)
{
    return make_float3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(float3 &a, float b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
}

inline __host__ __device__ int3 operator-(int3 a, int3 b)
{
    return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(int3 &a, int3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
inline __host__ __device__ int3 operator-(int3 a, int b)
{
    return make_int3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ int3 operator-(int b, int3 a)
{
    return make_int3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(int3 &a, int b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
}

inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
{
    return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
{
    return make_uint3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
{
    return make_uint3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
}

inline __host__ __device__ float4 operator-(float4 a, float4 b)
{
    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(float4 &a, float4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
inline __host__ __device__ float4 operator-(float4 a, float b)
{
    return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ void operator-=(float4 &a, float b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
    a.w -= b;
}

inline __host__ __device__ int4 operator-(int4 a, int4 b)
{
    return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(int4 &a, int4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
inline __host__ __device__ int4 operator-(int4 a, int b)
{
    return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ int4 operator-(int b, int4 a)
{
    return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
}
inline __host__ __device__ void operator-=(int4 &a, int b)
{
    a.x -= b;
    a.y -= b;
    a.z -= b;
    a.w -= b;
}

inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
{
    return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
{
    return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
{
    return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint b)
{
    a.x -= b;
|
726 |
+
a.y -= b;
|
727 |
+
a.z -= b;
|
728 |
+
a.w -= b;
|
729 |
+
}
|
730 |
+
|
731 |
+
////////////////////////////////////////////////////////////////////////////////
|
732 |
+
// multiply
|
733 |
+
////////////////////////////////////////////////////////////////////////////////
|
734 |
+
|
735 |
+
inline __host__ __device__ float2 operator*(float2 a, float2 b)
|
736 |
+
{
|
737 |
+
return make_float2(a.x * b.x, a.y * b.y);
|
738 |
+
}
|
739 |
+
inline __host__ __device__ void operator*=(float2 &a, float2 b)
|
740 |
+
{
|
741 |
+
a.x *= b.x;
|
742 |
+
a.y *= b.y;
|
743 |
+
}
|
744 |
+
inline __host__ __device__ float2 operator*(float2 a, float b)
|
745 |
+
{
|
746 |
+
return make_float2(a.x * b, a.y * b);
|
747 |
+
}
|
748 |
+
inline __host__ __device__ float2 operator*(float b, float2 a)
|
749 |
+
{
|
750 |
+
return make_float2(b * a.x, b * a.y);
|
751 |
+
}
|
752 |
+
inline __host__ __device__ void operator*=(float2 &a, float b)
|
753 |
+
{
|
754 |
+
a.x *= b;
|
755 |
+
a.y *= b;
|
756 |
+
}
|
757 |
+
|
758 |
+
inline __host__ __device__ int2 operator*(int2 a, int2 b)
|
759 |
+
{
|
760 |
+
return make_int2(a.x * b.x, a.y * b.y);
|
761 |
+
}
|
762 |
+
inline __host__ __device__ void operator*=(int2 &a, int2 b)
|
763 |
+
{
|
764 |
+
a.x *= b.x;
|
765 |
+
a.y *= b.y;
|
766 |
+
}
|
767 |
+
inline __host__ __device__ int2 operator*(int2 a, int b)
|
768 |
+
{
|
769 |
+
return make_int2(a.x * b, a.y * b);
|
770 |
+
}
|
771 |
+
inline __host__ __device__ int2 operator*(int b, int2 a)
|
772 |
+
{
|
773 |
+
return make_int2(b * a.x, b * a.y);
|
774 |
+
}
|
775 |
+
inline __host__ __device__ void operator*=(int2 &a, int b)
|
776 |
+
{
|
777 |
+
a.x *= b;
|
778 |
+
a.y *= b;
|
779 |
+
}
|
780 |
+
|
781 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
|
782 |
+
{
|
783 |
+
return make_uint2(a.x * b.x, a.y * b.y);
|
784 |
+
}
|
785 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
|
786 |
+
{
|
787 |
+
a.x *= b.x;
|
788 |
+
a.y *= b.y;
|
789 |
+
}
|
790 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
|
791 |
+
{
|
792 |
+
return make_uint2(a.x * b, a.y * b);
|
793 |
+
}
|
794 |
+
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
|
795 |
+
{
|
796 |
+
return make_uint2(b * a.x, b * a.y);
|
797 |
+
}
|
798 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint b)
|
799 |
+
{
|
800 |
+
a.x *= b;
|
801 |
+
a.y *= b;
|
802 |
+
}
|
803 |
+
|
804 |
+
inline __host__ __device__ float3 operator*(float3 a, float3 b)
|
805 |
+
{
|
806 |
+
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
|
807 |
+
}
|
808 |
+
inline __host__ __device__ void operator*=(float3 &a, float3 b)
|
809 |
+
{
|
810 |
+
a.x *= b.x;
|
811 |
+
a.y *= b.y;
|
812 |
+
a.z *= b.z;
|
813 |
+
}
|
814 |
+
inline __host__ __device__ float3 operator*(float3 a, float b)
|
815 |
+
{
|
816 |
+
return make_float3(a.x * b, a.y * b, a.z * b);
|
817 |
+
}
|
818 |
+
inline __host__ __device__ float3 operator*(float b, float3 a)
|
819 |
+
{
|
820 |
+
return make_float3(b * a.x, b * a.y, b * a.z);
|
821 |
+
}
|
822 |
+
inline __host__ __device__ void operator*=(float3 &a, float b)
|
823 |
+
{
|
824 |
+
a.x *= b;
|
825 |
+
a.y *= b;
|
826 |
+
a.z *= b;
|
827 |
+
}
|
828 |
+
|
829 |
+
inline __host__ __device__ int3 operator*(int3 a, int3 b)
|
830 |
+
{
|
831 |
+
return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
|
832 |
+
}
|
833 |
+
inline __host__ __device__ void operator*=(int3 &a, int3 b)
|
834 |
+
{
|
835 |
+
a.x *= b.x;
|
836 |
+
a.y *= b.y;
|
837 |
+
a.z *= b.z;
|
838 |
+
}
|
839 |
+
inline __host__ __device__ int3 operator*(int3 a, int b)
|
840 |
+
{
|
841 |
+
return make_int3(a.x * b, a.y * b, a.z * b);
|
842 |
+
}
|
843 |
+
inline __host__ __device__ int3 operator*(int b, int3 a)
|
844 |
+
{
|
845 |
+
return make_int3(b * a.x, b * a.y, b * a.z);
|
846 |
+
}
|
847 |
+
inline __host__ __device__ void operator*=(int3 &a, int b)
|
848 |
+
{
|
849 |
+
a.x *= b;
|
850 |
+
a.y *= b;
|
851 |
+
a.z *= b;
|
852 |
+
}
|
853 |
+
|
854 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
|
855 |
+
{
|
856 |
+
return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
|
857 |
+
}
|
858 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
|
859 |
+
{
|
860 |
+
a.x *= b.x;
|
861 |
+
a.y *= b.y;
|
862 |
+
a.z *= b.z;
|
863 |
+
}
|
864 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
|
865 |
+
{
|
866 |
+
return make_uint3(a.x * b, a.y * b, a.z * b);
|
867 |
+
}
|
868 |
+
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
|
869 |
+
{
|
870 |
+
return make_uint3(b * a.x, b * a.y, b * a.z);
|
871 |
+
}
|
872 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint b)
|
873 |
+
{
|
874 |
+
a.x *= b;
|
875 |
+
a.y *= b;
|
876 |
+
a.z *= b;
|
877 |
+
}
|
878 |
+
|
879 |
+
inline __host__ __device__ float4 operator*(float4 a, float4 b)
|
880 |
+
{
|
881 |
+
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
882 |
+
}
|
883 |
+
inline __host__ __device__ void operator*=(float4 &a, float4 b)
|
884 |
+
{
|
885 |
+
a.x *= b.x;
|
886 |
+
a.y *= b.y;
|
887 |
+
a.z *= b.z;
|
888 |
+
a.w *= b.w;
|
889 |
+
}
|
890 |
+
inline __host__ __device__ float4 operator*(float4 a, float b)
|
891 |
+
{
|
892 |
+
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
|
893 |
+
}
|
894 |
+
inline __host__ __device__ float4 operator*(float b, float4 a)
|
895 |
+
{
|
896 |
+
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
|
897 |
+
}
|
898 |
+
inline __host__ __device__ void operator*=(float4 &a, float b)
|
899 |
+
{
|
900 |
+
a.x *= b;
|
901 |
+
a.y *= b;
|
902 |
+
a.z *= b;
|
903 |
+
a.w *= b;
|
904 |
+
}
|
905 |
+
|
906 |
+
inline __host__ __device__ int4 operator*(int4 a, int4 b)
|
907 |
+
{
|
908 |
+
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
909 |
+
}
|
910 |
+
inline __host__ __device__ void operator*=(int4 &a, int4 b)
|
911 |
+
{
|
912 |
+
a.x *= b.x;
|
913 |
+
a.y *= b.y;
|
914 |
+
a.z *= b.z;
|
915 |
+
a.w *= b.w;
|
916 |
+
}
|
917 |
+
inline __host__ __device__ int4 operator*(int4 a, int b)
|
918 |
+
{
|
919 |
+
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
|
920 |
+
}
|
921 |
+
inline __host__ __device__ int4 operator*(int b, int4 a)
|
922 |
+
{
|
923 |
+
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
|
924 |
+
}
|
925 |
+
inline __host__ __device__ void operator*=(int4 &a, int b)
|
926 |
+
{
|
927 |
+
a.x *= b;
|
928 |
+
a.y *= b;
|
929 |
+
a.z *= b;
|
930 |
+
a.w *= b;
|
931 |
+
}
|
932 |
+
|
933 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
|
934 |
+
{
|
935 |
+
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
936 |
+
}
|
937 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
|
938 |
+
{
|
939 |
+
a.x *= b.x;
|
940 |
+
a.y *= b.y;
|
941 |
+
a.z *= b.z;
|
942 |
+
a.w *= b.w;
|
943 |
+
}
|
944 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
|
945 |
+
{
|
946 |
+
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
|
947 |
+
}
|
948 |
+
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
|
949 |
+
{
|
950 |
+
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
|
951 |
+
}
|
952 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint b)
|
953 |
+
{
|
954 |
+
a.x *= b;
|
955 |
+
a.y *= b;
|
956 |
+
a.z *= b;
|
957 |
+
a.w *= b;
|
958 |
+
}
|
959 |
+
|
960 |
+
////////////////////////////////////////////////////////////////////////////////
|
961 |
+
// divide
|
962 |
+
////////////////////////////////////////////////////////////////////////////////
|
963 |
+
|
964 |
+
inline __host__ __device__ float2 operator/(float2 a, float2 b)
|
965 |
+
{
|
966 |
+
return make_float2(a.x / b.x, a.y / b.y);
|
967 |
+
}
|
968 |
+
inline __host__ __device__ void operator/=(float2 &a, float2 b)
|
969 |
+
{
|
970 |
+
a.x /= b.x;
|
971 |
+
a.y /= b.y;
|
972 |
+
}
|
973 |
+
inline __host__ __device__ float2 operator/(float2 a, float b)
|
974 |
+
{
|
975 |
+
return make_float2(a.x / b, a.y / b);
|
976 |
+
}
|
977 |
+
inline __host__ __device__ void operator/=(float2 &a, float b)
|
978 |
+
{
|
979 |
+
a.x /= b;
|
980 |
+
a.y /= b;
|
981 |
+
}
|
982 |
+
inline __host__ __device__ float2 operator/(float b, float2 a)
|
983 |
+
{
|
984 |
+
return make_float2(b / a.x, b / a.y);
|
985 |
+
}
|
986 |
+
|
987 |
+
inline __host__ __device__ float3 operator/(float3 a, float3 b)
|
988 |
+
{
|
989 |
+
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
|
990 |
+
}
|
991 |
+
inline __host__ __device__ void operator/=(float3 &a, float3 b)
|
992 |
+
{
|
993 |
+
a.x /= b.x;
|
994 |
+
a.y /= b.y;
|
995 |
+
a.z /= b.z;
|
996 |
+
}
|
997 |
+
inline __host__ __device__ float3 operator/(float3 a, float b)
|
998 |
+
{
|
999 |
+
return make_float3(a.x / b, a.y / b, a.z / b);
|
1000 |
+
}
|
1001 |
+
inline __host__ __device__ void operator/=(float3 &a, float b)
|
1002 |
+
{
|
1003 |
+
a.x /= b;
|
1004 |
+
a.y /= b;
|
1005 |
+
a.z /= b;
|
1006 |
+
}
|
1007 |
+
inline __host__ __device__ float3 operator/(float b, float3 a)
|
1008 |
+
{
|
1009 |
+
return make_float3(b / a.x, b / a.y, b / a.z);
|
1010 |
+
}
|
1011 |
+
|
1012 |
+
inline __host__ __device__ float4 operator/(float4 a, float4 b)
|
1013 |
+
{
|
1014 |
+
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
|
1015 |
+
}
|
1016 |
+
inline __host__ __device__ void operator/=(float4 &a, float4 b)
|
1017 |
+
{
|
1018 |
+
a.x /= b.x;
|
1019 |
+
a.y /= b.y;
|
1020 |
+
a.z /= b.z;
|
1021 |
+
a.w /= b.w;
|
1022 |
+
}
|
1023 |
+
inline __host__ __device__ float4 operator/(float4 a, float b)
|
1024 |
+
{
|
1025 |
+
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
|
1026 |
+
}
|
1027 |
+
inline __host__ __device__ void operator/=(float4 &a, float b)
|
1028 |
+
{
|
1029 |
+
a.x /= b;
|
1030 |
+
a.y /= b;
|
1031 |
+
a.z /= b;
|
1032 |
+
a.w /= b;
|
1033 |
+
}
|
1034 |
+
inline __host__ __device__ float4 operator/(float b, float4 a)
|
1035 |
+
{
|
1036 |
+
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
|
1037 |
+
}
|
1038 |
+
|
1039 |
+
////////////////////////////////////////////////////////////////////////////////
|
1040 |
+
// min
|
1041 |
+
////////////////////////////////////////////////////////////////////////////////
|
1042 |
+
|
1043 |
+
inline __host__ __device__ float2 fminf(float2 a, float2 b)
|
1044 |
+
{
|
1045 |
+
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
|
1046 |
+
}
|
1047 |
+
inline __host__ __device__ float3 fminf(float3 a, float3 b)
|
1048 |
+
{
|
1049 |
+
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
|
1050 |
+
}
|
1051 |
+
inline __host__ __device__ float4 fminf(float4 a, float4 b)
|
1052 |
+
{
|
1053 |
+
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
|
1054 |
+
}
|
1055 |
+
|
1056 |
+
inline __host__ __device__ int2 min(int2 a, int2 b)
|
1057 |
+
{
|
1058 |
+
return make_int2(min(a.x,b.x), min(a.y,b.y));
|
1059 |
+
}
|
1060 |
+
inline __host__ __device__ int3 min(int3 a, int3 b)
|
1061 |
+
{
|
1062 |
+
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
1063 |
+
}
|
1064 |
+
inline __host__ __device__ int4 min(int4 a, int4 b)
|
1065 |
+
{
|
1066 |
+
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
1067 |
+
}
|
1068 |
+
|
1069 |
+
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
|
1070 |
+
{
|
1071 |
+
return make_uint2(min(a.x,b.x), min(a.y,b.y));
|
1072 |
+
}
|
1073 |
+
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
|
1074 |
+
{
|
1075 |
+
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
1076 |
+
}
|
1077 |
+
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
|
1078 |
+
{
|
1079 |
+
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
1080 |
+
}
|
1081 |
+
|
1082 |
+
////////////////////////////////////////////////////////////////////////////////
|
1083 |
+
// max
|
1084 |
+
////////////////////////////////////////////////////////////////////////////////
|
1085 |
+
|
1086 |
+
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
|
1087 |
+
{
|
1088 |
+
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
|
1089 |
+
}
|
1090 |
+
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
|
1091 |
+
{
|
1092 |
+
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
|
1093 |
+
}
|
1094 |
+
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
|
1095 |
+
{
|
1096 |
+
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
|
1097 |
+
}
|
1098 |
+
|
1099 |
+
inline __host__ __device__ int2 max(int2 a, int2 b)
|
1100 |
+
{
|
1101 |
+
return make_int2(max(a.x,b.x), max(a.y,b.y));
|
1102 |
+
}
|
1103 |
+
inline __host__ __device__ int3 max(int3 a, int3 b)
|
1104 |
+
{
|
1105 |
+
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
1106 |
+
}
|
1107 |
+
inline __host__ __device__ int4 max(int4 a, int4 b)
|
1108 |
+
{
|
1109 |
+
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
1110 |
+
}
|
1111 |
+
|
1112 |
+
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
|
1113 |
+
{
|
1114 |
+
return make_uint2(max(a.x,b.x), max(a.y,b.y));
|
1115 |
+
}
|
1116 |
+
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
|
1117 |
+
{
|
1118 |
+
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
1119 |
+
}
|
1120 |
+
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
|
1121 |
+
{
|
1122 |
+
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
1123 |
+
}
|
1124 |
+
|
1125 |
+
////////////////////////////////////////////////////////////////////////////////
|
1126 |
+
// lerp
|
1127 |
+
// - linear interpolation between a and b, based on value t in [0, 1] range
|
1128 |
+
////////////////////////////////////////////////////////////////////////////////
|
1129 |
+
|
1130 |
+
inline __device__ __host__ float lerp(float a, float b, float t)
|
1131 |
+
{
|
1132 |
+
return a + t*(b-a);
|
1133 |
+
}
|
1134 |
+
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
|
1135 |
+
{
|
1136 |
+
return a + t*(b-a);
|
1137 |
+
}
|
1138 |
+
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
|
1139 |
+
{
|
1140 |
+
return a + t*(b-a);
|
1141 |
+
}
|
1142 |
+
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
|
1143 |
+
{
|
1144 |
+
return a + t*(b-a);
|
1145 |
+
}
|
1146 |
+
|
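All four lerp overloads use the identity a + t*(b - a) = (1 - t)*a + t*b; the form used here saves one multiply, though unlike the convex-combination form it is not guaranteed to return exactly b at t = 1 in floating point. A quick NumPy check of the identity (illustrative only, not part of the commit):

import numpy as np

# lerp identity check in float32, matching the header's precision
a = np.random.rand(3).astype(np.float32)
b = np.random.rand(3).astype(np.float32)
t = np.float32(0.25)
assert np.allclose(a + t * (b - a), (1.0 - t) * a + t * b)
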
////////////////////////////////////////////////////////////////////////////////
// clamp
// - clamp the value v to be in the range [a, b]
////////////////////////////////////////////////////////////////////////////////

inline __device__ __host__ float clamp(float f, float a, float b)
{
    return fmaxf(a, fminf(f, b));
}
inline __device__ __host__ int clamp(int f, int a, int b)
{
    return max(a, min(f, b));
}
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
{
    return max(a, min(f, b));
}

inline __device__ __host__ float2 clamp(float2 v, float a, float b)
{
    return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
{
    return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
{
    return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
{
    return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
{
    return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
{
    return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}

inline __device__ __host__ int2 clamp(int2 v, int a, int b)
{
    return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
{
    return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
{
    return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
{
    return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
{
    return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
{
    return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}

inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
{
    return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
{
    return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
{
    return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
{
    return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
{
    return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
{
    return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}

////////////////////////////////////////////////////////////////////////////////
// dot product
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float dot(float2 a, float2 b)
{
    return a.x * b.x + a.y * b.y;
}
inline __host__ __device__ float dot(float3 a, float3 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}
inline __host__ __device__ float dot(float4 a, float4 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

inline __host__ __device__ int dot(int2 a, int2 b)
{
    return a.x * b.x + a.y * b.y;
}
inline __host__ __device__ int dot(int3 a, int3 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}
inline __host__ __device__ int dot(int4 a, int4 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

inline __host__ __device__ uint dot(uint2 a, uint2 b)
{
    return a.x * b.x + a.y * b.y;
}
inline __host__ __device__ uint dot(uint3 a, uint3 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}
inline __host__ __device__ uint dot(uint4 a, uint4 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

////////////////////////////////////////////////////////////////////////////////
// length
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float length(float2 v)
{
    return sqrtf(dot(v, v));
}
inline __host__ __device__ float length(float3 v)
{
    return sqrtf(dot(v, v));
}
inline __host__ __device__ float length(float4 v)
{
    return sqrtf(dot(v, v));
}

////////////////////////////////////////////////////////////////////////////////
// normalize
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 normalize(float2 v)
{
    float invLen = rsqrtf(dot(v, v));
    return v * invLen;
}
inline __host__ __device__ float3 normalize(float3 v)
{
    float invLen = rsqrtf(dot(v, v));
    return v * invLen;
}
inline __host__ __device__ float4 normalize(float4 v)
{
    float invLen = rsqrtf(dot(v, v));
    return v * invLen;
}

////////////////////////////////////////////////////////////////////////////////
// floor
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 floorf(float2 v)
{
    return make_float2(floorf(v.x), floorf(v.y));
}
inline __host__ __device__ float3 floorf(float3 v)
{
    return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
}
inline __host__ __device__ float4 floorf(float4 v)
{
    return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
}

////////////////////////////////////////////////////////////////////////////////
// frac - returns the fractional portion of a scalar or each vector component
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float fracf(float v)
{
    return v - floorf(v);
}
inline __host__ __device__ float2 fracf(float2 v)
{
    return make_float2(fracf(v.x), fracf(v.y));
}
inline __host__ __device__ float3 fracf(float3 v)
{
    return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
}
inline __host__ __device__ float4 fracf(float4 v)
{
    return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
}

////////////////////////////////////////////////////////////////////////////////
// fmod
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 fmodf(float2 a, float2 b)
{
    return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
}
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
{
    return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
}
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
{
    return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
}

////////////////////////////////////////////////////////////////////////////////
// absolute value
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float2 fabs(float2 v)
{
    return make_float2(fabs(v.x), fabs(v.y));
}
inline __host__ __device__ float3 fabs(float3 v)
{
    return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
}
inline __host__ __device__ float4 fabs(float4 v)
{
    return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
}

inline __host__ __device__ int2 abs(int2 v)
{
    return make_int2(abs(v.x), abs(v.y));
}
inline __host__ __device__ int3 abs(int3 v)
{
    return make_int3(abs(v.x), abs(v.y), abs(v.z));
}
inline __host__ __device__ int4 abs(int4 v)
{
    return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
}

////////////////////////////////////////////////////////////////////////////////
// reflect
// - returns reflection of incident ray I around surface normal N
// - N should be normalized, reflected vector's length is equal to length of I
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float3 reflect(float3 i, float3 n)
{
    return i - 2.0f * n * dot(n,i);
}

////////////////////////////////////////////////////////////////////////////////
// cross product
////////////////////////////////////////////////////////////////////////////////

inline __host__ __device__ float3 cross(float3 a, float3 b)
{
    return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
}

////////////////////////////////////////////////////////////////////////////////
// smoothstep
// - returns 0 if x < a
// - returns 1 if x > b
// - otherwise returns smooth interpolation between 0 and 1 based on x
////////////////////////////////////////////////////////////////////////////////

inline __device__ __host__ float smoothstep(float a, float b, float x)
{
    float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
    return (y*y*(3.0f - (2.0f*y)));
}
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
{
    float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
    return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
}
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
{
    float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
    return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
}
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
{
    float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
    return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
}

#endif
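utils.py at the end of this diff validates the CUDA ray kernels against a PyTorch reference, and the same cross-checking pattern applies to the helpers above. A small illustrative sketch (not part of the commit) mirroring reflect and smoothstep in PyTorch and checking that reflection about a unit normal preserves length:

import torch
import torch.nn.functional as F

def reflect(i, n):
    # i - 2.0f * n * dot(n, i); n is assumed to be normalized
    return i - 2.0 * n * (n * i).sum(dim=-1, keepdim=True)

def smoothstep(a, b, x):
    # cubic Hermite 3y^2 - 2y^3 after clamping y = (x - a) / (b - a) to [0, 1]
    y = torch.clamp((x - a) / (b - a), 0.0, 1.0)
    return y * y * (3.0 - 2.0 * y)

i = torch.randn(8, 3)
n = F.normalize(torch.randn(8, 3), dim=-1)
assert torch.allclose(reflect(i, n).norm(dim=-1), i.norm(dim=-1), atol=1e-5)
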
dva/mvp/extensions/utils/makefile
ADDED
@@ -0,0 +1,2 @@
all:
	python setup.py build_ext --inplace
dva/mvp/extensions/utils/setup.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup

from torch.utils.cpp_extension import CUDAExtension, BuildExtension

if __name__ == "__main__":
    import torch
    setup(
        name="utils",
        ext_modules=[
            CUDAExtension(
                "utilslib",
                sources=["utils.cpp", "utils_kernel.cu"],
                extra_compile_args={
                    "nvcc": [
                        "-arch=sm_70",
                        "-std=c++14",
                        "-lineinfo",
                    ]
                }
            )
        ],
        cmdclass={"build_ext": BuildExtension}
    )
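For quick iteration, the same extension can also be JIT-compiled rather than built ahead of time via the makefile. A minimal sketch, assuming utils.cpp and utils_kernel.cu sit in the working directory; torch.utils.cpp_extension.load is standard PyTorch and the flags mirror the setup.py above:

from torch.utils.cpp_extension import load

# JIT-build the extension into the torch extensions cache and import it
utilslib = load(
    name="utilslib",
    sources=["utils.cpp", "utils_kernel.cu"],
    extra_cuda_cflags=["-arch=sm_70", "-std=c++14", "-lineinfo"],
    verbose=True,
)
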
dva/mvp/extensions/utils/utils.cpp
ADDED
@@ -0,0 +1,137 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include <vector>

void compute_raydirs_forward_cuda(
        int N, int H, int W,
        float * viewposim,
        float * viewrotim,
        float * focalim,
        float * princptim,
        float * pixelcoordsim,
        float volradius,
        float * raypos,
        float * raydir,
        float * tminmax,
        cudaStream_t stream);

void compute_raydirs_backward_cuda(
        int N, int H, int W,
        float * viewposim,
        float * viewrotim,
        float * focalim,
        float * princptim,
        float * pixelcoordsim,
        float volradius,
        float * raypos,
        float * raydir,
        float * tminmax,
        float * grad_viewposim,
        float * grad_viewrotim,
        float * grad_focalim,
        float * grad_princptim,
        cudaStream_t stream);

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x))

std::vector<torch::Tensor> compute_raydirs_forward(
        torch::Tensor viewposim,
        torch::Tensor viewrotim,
        torch::Tensor focalim,
        torch::Tensor princptim,
        torch::optional<torch::Tensor> pixelcoordsim,
        int W, int H,
        float volradius,
        torch::Tensor rayposim,
        torch::Tensor raydirim,
        torch::Tensor tminmaxim) {
    CHECK_INPUT(viewposim);
    CHECK_INPUT(viewrotim);
    CHECK_INPUT(focalim);
    CHECK_INPUT(princptim);
    if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);

    int N = viewposim.size(0);
    // pixelcoordsim is a torch::optional, so it must be dereferenced before size().
    assert(!pixelcoordsim || (pixelcoordsim->size(1) == H && pixelcoordsim->size(2) == W));

    compute_raydirs_forward_cuda(N, H, W,
            reinterpret_cast<float *>(viewposim.data_ptr()),
            reinterpret_cast<float *>(viewrotim.data_ptr()),
            reinterpret_cast<float *>(focalim.data_ptr()),
            reinterpret_cast<float *>(princptim.data_ptr()),
            pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr,
            volradius,
            reinterpret_cast<float *>(rayposim.data_ptr()),
            reinterpret_cast<float *>(raydirim.data_ptr()),
            reinterpret_cast<float *>(tminmaxim.data_ptr()),
            0);

    return {};
}

std::vector<torch::Tensor> compute_raydirs_backward(
        torch::Tensor viewposim,
        torch::Tensor viewrotim,
        torch::Tensor focalim,
        torch::Tensor princptim,
        torch::optional<torch::Tensor> pixelcoordsim,
        int W, int H,
        float volradius,
        torch::Tensor rayposim,
        torch::Tensor raydirim,
        torch::Tensor tminmaxim,
        torch::Tensor grad_viewpos,
        torch::Tensor grad_viewrot,
        torch::Tensor grad_focal,
        torch::Tensor grad_princpt) {
    CHECK_INPUT(viewposim);
    CHECK_INPUT(viewrotim);
    CHECK_INPUT(focalim);
    CHECK_INPUT(princptim);
    if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
    CHECK_INPUT(rayposim);
    CHECK_INPUT(raydirim);
    CHECK_INPUT(tminmaxim);
    CHECK_INPUT(grad_viewpos);
    CHECK_INPUT(grad_viewrot);
    CHECK_INPUT(grad_focal);
    CHECK_INPUT(grad_princpt);

    int N = viewposim.size(0);
    // same optional-tensor dereference fix as in the forward path
    assert(!pixelcoordsim || (pixelcoordsim->size(1) == H && pixelcoordsim->size(2) == W));

    compute_raydirs_backward_cuda(N, H, W,
            reinterpret_cast<float *>(viewposim.data_ptr()),
            reinterpret_cast<float *>(viewrotim.data_ptr()),
            reinterpret_cast<float *>(focalim.data_ptr()),
            reinterpret_cast<float *>(princptim.data_ptr()),
            pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr,
            volradius,
            reinterpret_cast<float *>(rayposim.data_ptr()),
            reinterpret_cast<float *>(raydirim.data_ptr()),
            reinterpret_cast<float *>(tminmaxim.data_ptr()),
            reinterpret_cast<float *>(grad_viewpos.data_ptr()),
            reinterpret_cast<float *>(grad_viewrot.data_ptr()),
            reinterpret_cast<float *>(grad_focal.data_ptr()),
            reinterpret_cast<float *>(grad_princpt.data_ptr()),
            0);

    return {};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("compute_raydirs_forward", &compute_raydirs_forward, "raydirs forward (CUDA)");
    m.def("compute_raydirs_backward", &compute_raydirs_backward, "raydirs backward (CUDA)");
}
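utils.py (next in this diff) wraps these bindings in an autograd Function, but they can also be called directly. A sketch, assuming the built utilslib module is importable and that a None pixelcoords makes the kernel synthesize the full pixel grid (this is how the Python wrapper below uses it); the output tensors must be preallocated, CUDA, and contiguous to satisfy the CHECK_INPUT macros:

import torch
import utilslib

N, H, W = 1, 64, 64
viewpos = torch.zeros(N, 3, device="cuda"); viewpos[:, 2] = -4.0    # camera at z = -4
viewrot = torch.eye(3, device="cuda").expand(N, 3, 3).contiguous()  # identity rotation
focal = torch.full((N, 2), W * 4.0, device="cuda")                  # fx, fy in pixels
princpt = torch.tensor([[W * 0.5, H * 0.5]], device="cuda").expand(N, 2).contiguous()

raypos = torch.empty(N, H, W, 3, device="cuda")
raydir = torch.empty(N, H, W, 3, device="cuda")
tminmax = torch.empty(N, H, W, 2, device="cuda")

utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt, None,
                                 W, H, 1.0, raypos, raydir, tminmax)
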
dva/mvp/extensions/utils/utils.py
ADDED
@@ -0,0 +1,211 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import time

import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F

try:
    from . import utilslib
except ImportError:
    import utilslib

class ComputeRaydirs(Function):
    @staticmethod
    def forward(self, viewpos, viewrot, focal, princpt, pixelcoords, volradius):
        # pixelcoords may be a (W, H) tuple rather than a tensor, so only
        # check contiguity for actual tensors
        for tensor in [viewpos, viewrot, focal, princpt, pixelcoords]:
            if isinstance(tensor, torch.Tensor):
                assert tensor.is_contiguous()

        N = viewpos.size(0)
        if isinstance(pixelcoords, tuple):
            W, H = pixelcoords
            pixelcoords = None
        else:
            H = pixelcoords.size(1)
            W = pixelcoords.size(2)

        raypos = torch.empty((N, H, W, 3), device=viewpos.device)
        raydirs = torch.empty((N, H, W, 3), device=viewpos.device)
        tminmax = torch.empty((N, H, W, 2), device=viewpos.device)
        utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt,
                pixelcoords, W, H, volradius, raypos, raydirs, tminmax)

        return raypos, raydirs, tminmax

    @staticmethod
    def backward(self, grad_raypos, grad_raydirs, grad_tminmax):
        # one incoming gradient per forward output; ray generation is treated
        # as non-differentiable w.r.t. all six inputs
        return None, None, None, None, None, None

def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius):
    raypos, raydirs, tminmax = ComputeRaydirs.apply(viewpos, viewrot, focal, princpt, pixelcoords, volradius)
    return raypos, raydirs, tminmax

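A usage sketch for the wrapper above (hypothetical camera values): passing a (W, H) tuple instead of a pixelcoords tensor lets the kernel generate the full pixel grid itself.

import torch

viewpos = torch.tensor([[0.0, 0.0, -4.0]], device="cuda")     # camera center
viewrot = torch.eye(3, device="cuda")[None].contiguous()      # camera rotation
focal = torch.tensor([[256.0, 256.0]], device="cuda")         # fx, fy in pixels
princpt = torch.tensor([[32.0, 32.0]], device="cuda")         # principal point
raypos, raydirs, tminmax = compute_raydirs(viewpos, viewrot, focal, princpt, (64, 64), 1.0)
# raydirs: (1, 64, 64, 3) unit directions; tminmax: (1, 64, 64, 2) volume entry/exit
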
class Rodrigues(nn.Module):
    def __init__(self):
        super(Rodrigues, self).__init__()

    def forward(self, rvec):
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,

            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,

            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)

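The stacked expression above is the Rodrigues rotation formula written out entry by entry (the 1e-5 inside the square root is only a numerical guard against theta = 0). In matrix form, for a unit axis r and angle \theta:

R = \cos\theta\, I + (1 - \cos\theta)\, r r^{\top} + \sin\theta\, [r]_{\times},
\qquad
[r]_{\times} = \begin{pmatrix} 0 & -r_z & r_y \\ r_z & 0 & -r_x \\ -r_y & r_x & 0 \end{pmatrix}

e.g. the (0, 0) entry r_x^2 + (1 - r_x^2)\cos\theta is just \cos\theta + (1 - \cos\theta) r_x^2.
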
def gradcheck():
    N = 2
    H = 64
    W = 64
    k3 = 4
    K = k3*k3*k3

    M = 32
    volradius = 1.

    # generate random inputs
    torch.manual_seed(1113)

    rodrigues = Rodrigues()

    _viewpos = torch.tensor([[-0.0, 0.0, -4.] for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1
    viewrvec = torch.randn(N, 3, device="cuda") * 0.01
    _viewrot = rodrigues(viewrvec)

    _focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda")
    _princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda")
    pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float())
    _pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)

    _viewpos = _viewpos.contiguous().detach().clone()
    _viewpos.requires_grad = True
    _viewrot = _viewrot.contiguous().detach().clone()
    _viewrot.requires_grad = True
    _focal = _focal.contiguous().detach().clone()
    _focal.requires_grad = True
    _princpt = _princpt.contiguous().detach().clone()
    _princpt.requires_grad = True
    _pixelcoords = _pixelcoords.contiguous().detach().clone()
    _pixelcoords.requires_grad = True

    # kept from the raymarch gradcheck; unused in this script
    max_len = 6.0
    _stepsize = max_len / 15.5

    params = [_viewpos, _viewrot, _focal, _princpt]
    paramnames = ["viewpos", "viewrot", "focal", "princpt"]

    ########################### run pytorch version ###########################

    viewpos = _viewpos
    viewrot = _viewrot
    focal = _focal
    princpt = _princpt
    pixelcoords = _pixelcoords

    raypos = viewpos[:, None, None, :].repeat(1, H, W, 1)

    raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
    raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
    raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
    raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))

    # slab intersection with the [-1, 1]^3 volume
    t1 = (-1. - viewpos[:, None, None, :]) / raydir
    t2 = ( 1. - viewpos[:, None, None, :]) / raydir
    tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]),
                     torch.max(torch.min(t1[..., 1], t2[..., 1]),
                               torch.min(t1[..., 2], t2[..., 2]))).clamp(min=0.)
    tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]),
                     torch.min(torch.max(t1[..., 1], t2[..., 1]),
                               torch.max(t1[..., 2], t2[..., 2])))

    tminmax = torch.stack([tmin, tmax], dim=-1)

    sample0 = raydir

    torch.cuda.synchronize()
    time1 = time.time()

    sample0.backward(torch.ones_like(sample0))

    torch.cuda.synchronize()
    time2 = time.time()

    grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

    ############################## run cuda version ###########################

    viewpos = _viewpos
    viewrot = _viewrot
    focal = _focal
    princpt = _princpt
    pixelcoords = _pixelcoords

    niter = 1

    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()
    # synchronize before each timestamp so the timings bracket the kernels
    torch.cuda.synchronize()
    t0 = time.time()

    sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]

    torch.cuda.synchronize()
    t1 = time.time()

    print("-----------------------------------------------------------------")
    print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))

    sample1.backward(torch.ones_like(sample1), retain_graph=True)

    torch.cuda.synchronize()
    t2 = time.time()

    # forward/backward wall-clock times for the CUDA path; tf and tb were
    # referenced below but never defined in the original
    tf = t1 - t0
    tb = t2 - t1

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]

    ############# compare results #############

    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            p,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))

if __name__ == "__main__":
    gradcheck()