|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. |
|
Based on the implementation in MipNeRF (although this version does not perform any cone tracing).
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from pdb import set_trace as st |
|
|
|
|
|
class MipRayMarcher2(nn.Module):
    """Composite per-sample colors and densities along rays.

    Implements the discrete volume-rendering quadrature: interval opacities
    are computed from midpoint densities, accumulated into transmittance with
    a cumulative product, and used to weight midpoint colors and depths.
    No cone tracing is performed.
    """

    def __init__(self):
        super().__init__()

    def run_forward(self, colors, densities, depths, rendering_options):
        """Volume-render ``colors``/``densities`` sampled at ``depths``.

        Args:
            colors: Per-sample colors. NOTE(review): presumably shaped
                (batch, rays, samples, channels). Samples are sliced along
                dim 2 and reduced along dim -2, and those two coincide only
                for 4-D tensors, so confirm against the caller.
            densities: Raw (pre-activation) per-sample densities, laid out
                like ``colors`` along the sample dimension.
            depths: Per-sample depths, laid out like ``densities``.
            rendering_options: Option dict. ``'clamp_mode'`` is required and
                must be ``'softplus'``. ``'white_back'`` (default True) fills
                the un-accumulated weight with white.

        Returns:
            A tuple ``(composite_rgb, composite_depth, visibility, weights)``:

            * ``composite_rgb``: the weighted color sum, rescaled via ``x * 2 - 1``.
            * ``composite_depth``: the weighted midpoint depth (not normalized
              by the total weight), clamped to ``[depths.min(), depths.max()]``.
            * ``visibility``: the transmittance remaining after the last interval.
            * ``weights``: the per-interval compositing weights, with one fewer
              entry along the sample dim than the inputs.

        Raises:
            KeyError: If ``'clamp_mode'`` is missing from ``rendering_options``.
            ValueError: If ``rendering_options['clamp_mode']`` is not
                ``'softplus'``.
        """
        # Validate with an explicit raise rather than ``assert``. Assertions
        # are stripped under ``python -O``, which would silently skip the
        # density activation below and render with raw densities.
        if rendering_options['clamp_mode'] != 'softplus':
            raise ValueError(
                "MipRayMarcher only supports `clamp_mode`=`softplus`!")

        # Interval lengths and midpoint quantities: S samples -> S-1 intervals.
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2

        # Softplus activation applied to the raw density shifted by -1.
        densities_mid = F.softplus(densities_mid - 1)

        # Opacity of each interval: alpha_i = 1 - exp(-sigma_i * delta_i).
        density_delta = densities_mid * deltas
        alpha = 1 - torch.exp(-density_delta)

        # Transmittance: T_0 = 1 and T_i = prod_{j<i} (1 - alpha_j).
        # The 1e-10 keeps each factor strictly positive, so the product never
        # collapses to exactly zero even when an interval is fully opaque.
        alpha_shifted = torch.cat(
            [torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        transmittance = torch.cumprod(alpha_shifted, -2)
        weights = alpha * transmittance[:, :, :-1]
        # Transmittance left over after the final interval.
        visibility = transmittance[:, :, -1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)

        # NOTE: not divided by weight_total. Rays with little accumulated
        # opacity therefore tend toward 0 here, before the clamp below.
        composite_depth = torch.sum(weights * depths_mid, -2)

        # Map NaN to +inf, then clamp into the global range of the sample
        # depths. min/max are taken over the whole tensor, not per ray.
        composite_depth = torch.nan_to_num(composite_depth, nan=float('inf'))
        composite_depth = torch.clamp(
            composite_depth, torch.min(depths), torch.max(depths))

        if rendering_options.get('white_back', True):
            # Fill the remaining (1 - accumulated weight) with white.
            composite_rgb = composite_rgb + 1 - weight_total

        # Affine rescale; maps [0, 1] to [-1, 1] if colors lie in [0, 1].
        composite_rgb = composite_rgb * 2 - 1

        return composite_rgb, composite_depth, visibility, weights

    def forward(self, colors, densities, depths, rendering_options):
        """Render rays; see ``run_forward`` for arguments and return values."""
        return self.run_forward(colors, densities, depths, rendering_options)
|
|