import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tinycudann as tcnn


# Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, input_dims=3):
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim


class SDFNetwork(nn.Module):
    def __init__(self,
                 d_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 scale=1,
                 geometric_init=True,
                 weight_norm=True,
                 inside_outside=False):
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            # Geometric initialization (SAL/IGR style): the network starts out
            # approximating the SDF of a sphere with radius `bias`.
            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, inputs):
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
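
# Illustrative sketch (not part of the original file): expected shapes from
# get_embedder and SDFNetwork. The helper name and argument values below are
# assumptions chosen for the example, not the project's training configuration.
def _demo_sdf_network():
    embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
    pts = torch.rand(8, 3)
    # include_input=True gives 3 + 3 * 2 * 6 = 39 output channels
    assert out_dim == 39 and embed_fn(pts).shape == (8, 39)

    # d_out = 1 SDF value + 256 appearance features, as in NeuS-style models
    net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8, multires=6)
    out = net(pts)
    return out[..., :1], out[..., 1:]  # SDF values, feature vectors
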
class SDFNetworkWithFeature(nn.Module):
    def __init__(self,
                 cube,
                 dp_in,
                 df_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 scale=1,
                 geometric_init=True,
                 weight_norm=True,
                 inside_outside=False,
                 cube_length=0.5):
        super().__init__()

        self.register_buffer("cube", cube)
        self.cube_length = cube_length

        dims = [dp_in + df_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + df_in

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, points):
        points = points * self.scale

        # note: the cube covers [-cube_length, cube_length] (default [-0.5, 0.5]),
        # so dividing by cube_length maps points into grid_sample's [-1, 1] range
        with torch.no_grad():
            feats = F.grid_sample(self.cube, points.view(1, -1, 1, 1, 3) / self.cube_length,
                                  mode='bilinear', align_corners=True, padding_mode='zeros').detach()
        feats = feats.view(self.cube.shape[1], -1).permute(1, 0).view(*points.shape[:-1], -1)

        if self.embed_fn_fine is not None:
            points = self.embed_fn_fine(points)

        x = torch.cat([points, feats], -1)
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, points, feats], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        # concat feats
        x = torch.cat([x, feats], -1)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
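
# Illustrative sketch (assumed channel count and resolution, not from the
# original file): the coordinate convention behind the feature-cube lookup
# above. grid_sample expects grid coordinates in [-1, 1], so points in
# [-cube_length, cube_length] are divided by cube_length before sampling.
def _demo_cube_lookup():
    channels, res = 4, 16  # assumed feature channels and cube resolution
    cube = torch.rand(1, channels, res, res, res)
    pts = torch.rand(8, 3) - 0.5  # points inside the [-0.5, 0.5]^3 cube
    feats = F.grid_sample(cube, pts.view(1, -1, 1, 1, 3) / 0.5,
                          mode='bilinear', align_corners=True,
                          padding_mode='zeros')
    return feats.view(channels, -1).permute(1, 0)  # (1, C, N, 1, 1) -> (N, C)
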
class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = 0.5
        self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
        for _ in range(self.n_hidden_layers - 1):
            self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
        self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
        self.layers = nn.Sequential(*self.layers)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x):
        x = self.layers(x.float())
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True)  # a network without bias will degrade quality
        if self.sphere_init:
            if is_last:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
            elif is_first:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
                torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
            else:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        if self.sphere_init:
            return nn.Softplus(beta=100)
        else:
            return nn.ReLU(inplace=True)


class SDFHashGridNetwork(nn.Module):
    def __init__(self, bound=0.5, feats_dim=13):
        super().__init__()
        self.bound = bound
        # max_resolution = 32
        # base_resolution = 16
        # n_levels = 4
        # log2_hashmap_size = 16
        # n_features_per_level = 8
        max_resolution = 2048
        base_resolution = 16
        n_levels = 16
        log2_hashmap_size = 19
        n_features_per_level = 2

        # max_res = base_res * t^(k-1)
        per_level_scale = (max_resolution / base_resolution) ** (1 / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )
        self.sdf_mlp = VanillaMLP(n_levels * n_features_per_level + 3, feats_dim, 64, 1)

    def forward(self, x):
        shape = x.shape[:-1]
        x = x.reshape(-1, 3)
        # map [-bound, bound] to the [0, 1] range expected by the hash encoder
        x_ = (x + self.bound) / (2 * self.bound)
        feats = self.encoder(x_)
        feats = torch.cat([x, feats], 1)
        feats = self.sdf_mlp(feats)
        feats = feats.reshape(*shape, -1)
        return feats

    def sdf(self, x):
        return self(x)[..., :1]

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
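
# Illustrative check (not part of the original file): per_level_scale above
# follows the Instant-NGP geometric progression res_k = base * scale^k, so the
# finest of the n_levels grids reaches max_resolution at k = n_levels - 1.
def _demo_per_level_scale(base=16, max_res=2048, n_levels=16):
    scale = (max_res / base) ** (1 / (n_levels - 1))  # ~1.382 for these values
    assert abs(base * scale ** (n_levels - 1) - max_res) < 1e-6
    return scale
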
class RenderingFFNetwork(nn.Module):
    def __init__(self, in_feats_dim=12):
        super().__init__()
        self.dir_encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )
        self.color_mlp = tcnn.Network(
            n_input_dims=in_feats_dim + 3 + self.dir_encoder.n_output_dims,
            n_output_dims=3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 2,
            },
        )

    def forward(self, points, normals, view_dirs, feature_vectors):
        normals = F.normalize(normals, dim=-1)
        view_dirs = F.normalize(view_dirs, dim=-1)
        reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs

        x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
        colors = self.color_mlp(x).float()
        colors = torch.sigmoid(colors)
        return colors


# This implementation is borrowed from IDR: https://github.com/lioryariv/idr
class RenderingNetwork(nn.Module):
    def __init__(self,
                 d_feature,
                 d_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 weight_norm=True,
                 multires_view=0,
                 squeeze_out=True,
                 use_view_dir=True):
        super().__init__()

        self.squeeze_out = squeeze_out
        self.rgb_act = torch.sigmoid
        self.use_view_dir = use_view_dir

        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        if self.use_view_dir:
            view_dirs = F.normalize(view_dirs, dim=-1)
            normals = F.normalize(normals, dim=-1)
            reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs
            if self.embedview_fn is not None:
                reflective = self.embedview_fn(reflective)
            rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
        else:
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)

        x = rendering_input
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = self.rgb_act(x)
        return x


class SingleVarianceNetwork(nn.Module):
    def __init__(self, init_val, activation='exp'):
        super(SingleVarianceNetwork, self).__init__()
        self.act = activation
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        device = x.device
        if self.act == 'exp':
            return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
        else:
            raise NotImplementedError

    def warp(self, x, inv_s):
        device = x.device
        return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
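
# Illustrative sketch (not part of the original file): the reflection
# direction both rendering networks feed to the view encoder,
# r = 2 (v . n) n - v, with v the unit view direction toward the camera and
# n the unit surface normal. The helper name and vectors are assumptions.
def _demo_reflective_dir():
    n = torch.tensor([[0.0, 0.0, 1.0]])
    v = F.normalize(torch.tensor([[1.0, 0.0, 1.0]]), dim=-1)
    r = torch.sum(v * n, -1, keepdim=True) * n * 2 - v
    return r  # ~[[-0.7071, 0.0, 0.7071]]: the tangential component flips sign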