# Spaces: Running on Zero
import torch.nn.functional as F | |
import torch.nn as nn | |
import torch | |
def weights_init(m):
    """Re-initialise an ``nn.Linear`` layer in place; ignore any other module.

    Meant to be passed to ``nn.Module.apply``. The weight matrix gets
    Kaiming-normal initialisation and the bias, when the layer has one, is
    set to zero. Modules that are not ``nn.Linear`` are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight.data)
    bias = m.bias
    if bias is None:
        return
    nn.init.zeros_(bias.data)
class NeRF(nn.Module):
    """Predict per-sample density and a view-blended colour.

    Per-view image features are first pooled by ``Agg`` into a 16-d feature.
    That feature is combined with ``vox_feat`` to predict ``sigma``. A second
    head scores every source view, and the final colour is the softmax-weighted
    blend of the source views' RGB values. Views whose mask entry is 0 get
    zero weight.

    Args:
        vol_n: channel width of ``vox_feat`` (default 8 + 8).
        feat_ch: per-view feature width excluding the 4 trailing direction
            channels (default 8 + 16 + 32 + 3). The last 3 of these channels
            are the colour that gets blended.
        hid_n: hidden width of the MLP heads.
    """

    def __init__(self, vol_n=8+8, feat_ch=8+16+32+3, hid_n=64):
        super(NeRF, self).__init__()
        self.hid_n = hid_n
        self.agg = Agg(feat_ch)
        # 16 == output width of Agg.
        self.lr0 = nn.Sequential(nn.Linear(vol_n + 16, hid_n), nn.ReLU())
        # Softplus keeps the predicted density non-negative.
        self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
        self.color = nn.Sequential(
            nn.Linear(16 + vol_n + feat_ch + hid_n + 4, hid_n),  # agg_feats+vox_feat+img_feat+lr0_feats+dir
            nn.ReLU(),
            nn.Linear(hid_n, 1),
        )
        self.lr0.apply(weights_init)
        self.sigma.apply(weights_init)
        self.color.apply(weights_init)

    def forward(self, vox_feat, img_feat_rgb_dir, source_img_mask):
        """Return ``(color, sigma)`` for every sample.

        Args:
            vox_feat: ``(b, d, vol_n)`` volume features.
            img_feat_rgb_dir: ``(b, d, n, feat_ch + 4)`` per-view features.
                The channels ``[-7:-4]`` are the RGB values that get blended
                and the last 4 channels are the view direction.
            source_img_mask: ``(b, n)``; entries equal to 0 mark views to
                ignore. (A disabled check in earlier code required at least
                two valid views per batch item.)

        Returns:
            color: ``(b, d, 3)`` softmax-weighted blend of source-view RGB.
            sigma: ``(b, d, 1)`` non-negative density.
        """
        b, d, n, _ = img_feat_rgb_dir.shape
        agg_feat = self.agg(img_feat_rgb_dir, source_img_mask)  # b,d,16
        x = self.lr0(torch.cat((vox_feat, agg_feat), dim=-1))  # b,d,hid_n
        sigma = self.sigma(x)  # b,d,1

        x = torch.cat((x, vox_feat, agg_feat), dim=-1)  # b,d,hid_n+vol_n+16
        # Broadcast the per-sample feature to every view; `cat` materialises it.
        x = x.reshape(b, d, 1, x.shape[-1]).expand(b, d, n, x.shape[-1])
        x = torch.cat((x, img_feat_rgb_dir), dim=-1)
        logits = self.color(x)  # b,d,n,1

        # Mask invalid views with the dtype's most negative finite value.
        # A fixed -1e7 overflows float16 (max magnitude 65504). In fp32 both
        # give an exact 0 weight after softmax.
        invalid = source_img_mask.reshape(b, 1, n, 1) == 0
        logits = logits.masked_fill(invalid, torch.finfo(logits.dtype).min)
        color_weight = F.softmax(logits, dim=-2)
        color = torch.sum(img_feat_rgb_dir[..., -7:-4] * color_weight, dim=-2)
        return color, sigma
class Agg(nn.Module):
    """Pool per-view features into a single 16-d feature per sample.

    Each view's feature is augmented with an embedding of its 4 direction
    channels, concatenated with the masked cross-view variance and mean, and
    projected to 32 channels. A softmax over valid views then weights and sums
    these projections, and a final linear layer maps the result to 16
    channels.

    Args:
        feat_ch: per-view feature width excluding the 4 trailing direction
            channels.
    """

    def __init__(self, feat_ch):
        super(Agg, self).__init__()
        self.feat_ch = feat_ch
        # Embeds the 4 direction channels to feat_ch so they can be added.
        self.view_fc = nn.Sequential(nn.Linear(4, feat_ch), nn.ReLU())
        self.view_fc.apply(weights_init)
        # Input: [per-view feature, variance, mean] -> 3 * feat_ch channels.
        self.global_fc = nn.Sequential(nn.Linear(feat_ch*3, 32), nn.ReLU())
        self.agg_w_fc = nn.Linear(32, 1)
        self.fc = nn.Linear(32, 16)
        self.global_fc.apply(weights_init)
        self.agg_w_fc.apply(weights_init)
        self.fc.apply(weights_init)

    def masked_mean_var(self, img_feat_rgb, source_img_mask):
        """Return the mean and (biased) variance over valid views.

        Args:
            img_feat_rgb: ``(b, d, n, f)`` per-view features.
            source_img_mask: ``(b, n)`` weights; 0 excludes a view.

        Returns:
            ``(mean, var)``, each ``(b, d, f)``. The 1e-5 in the denominator
            avoids 0/0 when a batch item has no valid view (both are then 0).
        """
        b, n = source_img_mask.shape
        # reshape (not view) so non-contiguous masks are accepted too.
        mask = source_img_mask.reshape(b, 1, n, 1)
        denom = torch.sum(mask, dim=-2) + 1e-5
        mean = torch.sum(mask * img_feat_rgb, dim=-2) / denom
        var = torch.sum((img_feat_rgb - mean.unsqueeze(-2)) ** 2 * mask, dim=-2) / denom
        return mean, var

    def forward(self, img_feat_rgb_dir, source_img_mask):
        """Aggregate ``(b, d, n, feat_ch + 4)`` features into ``(b, d, 16)``.

        Views whose ``source_img_mask`` entry (shape ``(b, n)``) equals 0 are
        excluded from the statistics and receive zero aggregation weight.
        """
        b, d, n, _ = img_feat_rgb_dir.shape
        view_feat = self.view_fc(img_feat_rgb_dir[..., -4:])  # b,d,n,feat_ch
        img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat
        mean_feat, var_feat = self.masked_mean_var(img_feat_rgb, source_img_mask)
        # Broadcast the cross-view statistics to every view (no copy until cat).
        var_feat = var_feat.view(b, -1, 1, self.feat_ch).expand(-1, -1, n, -1)
        avg_feat = mean_feat.view(b, -1, 1, self.feat_ch).expand(-1, -1, n, -1)
        feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1)  # b,d,n,3*feat_ch
        global_feat = self.global_fc(feat)  # b,d,n,32
        logits = self.agg_w_fc(global_feat)  # b,d,n,1

        # Mask invalid views with the dtype's most negative finite value.
        # A fixed -1e7 overflows float16. In fp32 the resulting weights are
        # identical (exactly 0 for masked views).
        invalid = source_img_mask.reshape(b, 1, n, 1) == 0
        logits = logits.masked_fill(invalid, torch.finfo(logits.dtype).min)
        agg_w = F.softmax(logits, dim=-2)
        im_feat = (global_feat * agg_w).sum(dim=-2)  # b,d,32
        return self.fc(im_feat)