# dreamgaussian4d/scene/deformation.py
# jiaweir
# init
# 21c4e64
import functools
import math
import os
import time
from tkinter import W
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
import torch.nn.init as init
from scene.hexplane import HexPlaneField
class Linear_Res(nn.Module):
    """Residual block: ReLU, then a square linear layer added back onto its input.

    Args:
        W: Feature width; the inner linear layer maps ``W -> W`` so that its
            output can be summed with the activated input.
    """

    def __init__(self, W):
        super().__init__()
        self.main_stream = nn.Linear(W, W)

    def forward(self, x):
        """Return ``relu(x) + main_stream(relu(x))``."""
        # The skip connection carries the *activated* input, not the raw one.
        activated = F.relu(x)
        projected = self.main_stream(activated)
        return activated + projected
class Head_Res_Net(nn.Module):
    """Output head made of one ``Linear_Res`` block followed by a linear projection.

    Args:
        W: Width of the incoming feature vector (and of the residual block).
        H: Number of output channels produced by the final linear layer.
    """

    def __init__(self, W, H):
        super().__init__()
        self.W = W
        self.H = H
        # Residual block first, then the W -> H projection.
        self.feature_out = nn.Sequential(
            Linear_Res(self.W),
            nn.Linear(W, self.H),
        )

    def initialize_weights(self):
        """Set the weight and bias of every ``nn.Linear`` inside the head to zero."""
        linear_layers = (
            layer
            for layer in self.feature_out.modules()
            if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            init.constant_(layer.weight, 0)
            if layer.bias is None:
                continue
            init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run ``x`` through the residual block and the output projection."""
        return self.feature_out(x)
class Deformation(nn.Module):
    """Deformation network with a shared trunk and four output heads.

    The trunk (``self.feature_out``) consumes either features returned by
    ``self.grid`` or, when ``args.no_grid`` is set, the raw 4-vector made of
    the first three point columns and the first time column.  Four heads
    (``pos_deform``, ``scales_deform``, ``rotations_deform``,
    ``opacity_deform``) map the trunk output to 3, 3, 4 and 1 channels.

    Args:
        D: Number of linear layers in the trunk.
        W: Hidden width of the trunk and the heads.
        input_ch: Stored on the instance; not otherwise used by this class.
        input_ch_time: Stored on the instance; not otherwise used by this class.
        skips: Stored on the instance; not otherwise used by this class.
            ``None`` (the default) yields a fresh empty list.
        args: Configuration object.  Required in practice despite the ``None``
            default: ``no_grid``, ``bounds``, ``kplanes_config`` and
            ``multires`` are read at construction, and ``no_mlp``, ``no_ds``,
            ``no_dr`` and ``no_do`` at forward time.
        use_res: If true, build ``Head_Res_Net`` heads instead of plain MLP heads.
    """

    # Output channels of the heads, in the order
    # (pos_deform, scales_deform, rotations_deform, opacity_deform).
    _HEAD_DIMS = (3, 3, 4, 1)

    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=None, args=None, use_res=False):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_time = input_ch_time
        # A ``[]`` default would be one list object shared by every instance.
        self.skips = [] if skips is None else skips
        self.no_grid = args.no_grid
        # The grid is always constructed, even when ``no_grid`` bypasses it.
        self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires)
        self.use_res = use_res
        build_heads = self.create_res_net if self.use_res else self.create_net
        (
            self.pos_deform,
            self.scales_deform,
            self.rotations_deform,
            self.opacity_deform,
        ) = build_heads()
        self.args = args

    def _build_trunk(self):
        """Create the shared trunk MLP and store it as ``self.feature_out``."""
        # Without the grid the trunk sees 3 point columns + 1 time column.
        in_dim = 4 if self.no_grid else self.grid.feat_dim
        layers = [nn.Linear(in_dim, self.W)]
        for _ in range(self.D - 1):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(self.W, self.W))
        self.feature_out = nn.Sequential(*layers)

    def _make_head(self, out_dim):
        """Return a plain two-layer MLP head mapping ``W`` features to ``out_dim``."""
        return nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.W, self.W),
            nn.ReLU(),
            nn.Linear(self.W, out_dim),
        )

    def create_net(self):
        """Build the trunk and return the four plain MLP heads.

        Returns:
            Tuple ``(pos, scales, rotations, opacity)`` of ``nn.Sequential`` heads.
        """
        self._build_trunk()
        return tuple(self._make_head(dim) for dim in self._HEAD_DIMS)

    def create_res_net(self):
        """Build the trunk and return the four residual heads.

        Returns:
            Tuple ``(pos, scales, rotations, opacity)`` of ``Head_Res_Net`` heads.
        """
        self._build_trunk()
        return tuple(Head_Res_Net(self.W, dim) for dim in self._HEAD_DIMS)

    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):
        """Compute the hidden feature shared by all heads.

        Only ``rays_pts_emb[:, :3]`` and ``time_emb[:, :1]`` are read;
        ``scales_emb`` and ``rotations_emb`` are accepted but unused.

        Returns:
            The raw grid output when ``args.no_mlp`` is set, otherwise the
            trunk output.
        """
        xyz = rays_pts_emb[:, :3]
        t = time_emb[:, :1]
        if self.args.no_mlp:
            # Without the trunk, the grid is the only source of features.
            assert not self.no_grid
            return self.grid(xyz, t)
        if self.no_grid:
            trunk_in = torch.cat([xyz, t], -1)
        else:
            trunk_in = self.grid(xyz, t)
        return self.feature_out(trunk_in)

    def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity=None, time_emb=None):
        """Dispatch to the static path when ``time_emb`` is None, else the dynamic path."""
        if time_emb is None:
            return self.forward_static(rays_pts_emb[:, :3])
        return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb)

    def forward_static(self, rays_pts_emb):
        """Offset the first three point columns by ``static_mlp(grid(points))``.

        Raises:
            AttributeError: If no ``static_mlp`` module is attached.  This
                class never creates one, so the path only works when a caller
                has assigned ``self.static_mlp`` beforehand.
        """
        static_mlp = getattr(self, "static_mlp", None)
        if static_mlp is None:
            # Fail early with an explicit message instead of partway through.
            raise AttributeError(
                "Deformation has no 'static_mlp'; the static (time-free) path "
                "requires a module assigned to self.static_mlp"
            )
        grid_feature = self.grid(rays_pts_emb[:, :3])
        dx = static_mlp(grid_feature)
        return rays_pts_emb[:, :3] + dx

    def forward_dynamic(self, rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb):
        """Evaluate the four heads for the given points and times.

        Head outputs are returned as-is; they are not added to the inputs
        here.  When a head is disabled through ``args.no_ds`` / ``no_dr`` /
        ``no_do``, the leading columns of the matching input are passed
        through unchanged instead.

        Returns:
            Tuple ``(pts, scales, rotations, opacity)``.
        """
        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float()
        if self.args.no_mlp:
            # Slice the grid feature directly into 3 / 3 / 4 / 1 channels.
            # NOTE(review): presumably the grid emits at least 11 channels here - confirm.
            return hidden[:, :3], hidden[:, 3:6], hidden[:, 6:10], hidden[:, 10:11]

        pts = self.pos_deform(hidden)

        if self.args.no_ds:
            scales = scales_emb[:, :3]
        else:
            scales = self.scales_deform(hidden)

        if self.args.no_dr:
            rotations = rotations_emb[:, :4]
        else:
            rotations = self.rotations_deform(hidden)

        if self.args.no_do:
            opacity = opacity_emb[:, :1]
        else:
            opacity = self.opacity_deform(hidden)

        return pts, scales, rotations, opacity

    def get_mlp_parameters(self):
        """Return every parameter whose qualified name does not contain ``"grid"``."""
        return [param for name, param in self.named_parameters() if "grid" not in name]

    def get_grid_parameters(self):
        """Return the parameters of ``self.grid`` as a list."""
        return list(self.grid.parameters())
class deform_network(nn.Module):
    """Top-level wrapper that owns a ``Deformation`` network and a small time MLP.

    All sizes are read from ``args`` (``net_width``, ``timebase_pe``,
    ``defor_depth``, ``posebase_pe``, ``scale_rotation_pe``, ``opacity_pe``,
    ``timenet_width``, ``timenet_output``, ``use_res``).  The module-level
    ``initialize_weights`` is applied to every submodule; with ``use_res`` the
    four output heads are then re-initialised through their own
    ``initialize_weights`` method.
    """

    def __init__(self, args):
        super().__init__()
        width = args.net_width
        n_time_freqs = args.timebase_pe
        depth = args.defor_depth
        n_pos_freqs = args.posebase_pe
        n_sr_freqs = args.scale_rotation_pe
        n_opacity_freqs = args.opacity_pe
        time_hidden = args.timenet_width
        time_out = args.timenet_output

        # Built and exposed through get_mlp_parameters(); not called in the
        # forward methods of this class.
        self.timenet = nn.Sequential(
            nn.Linear(2 * n_time_freqs + 1, time_hidden),
            nn.ReLU(),
            nn.Linear(time_hidden, time_out),
        )

        self.use_res = args.use_res
        if self.use_res:
            print("Using zero-init and residual")

        self.deformation_net = Deformation(
            W=width,
            D=depth,
            input_ch=(4 + 3) + ((4 + 3) * n_sr_freqs) * 2,
            input_ch_time=time_out,
            args=args,
            use_res=self.use_res,
        )

        # Power-of-two tables, one buffer per encoding, registered in this order.
        frequency_tables = (
            ('time_poc', n_time_freqs),
            ('pos_poc', n_pos_freqs),
            ('rotation_scaling_poc', n_sr_freqs),
            ('opacity_poc', n_opacity_freqs),
        )
        for buffer_name, count in frequency_tables:
            self.register_buffer(
                buffer_name,
                torch.FloatTensor([(2 ** k) for k in range(count)]),
            )

        self.apply(initialize_weights)

        if self.use_res:
            # Must run after the generic pass so the heads keep their own init.
            for head_name in ("pos_deform", "scales_deform", "rotations_deform", "opacity_deform"):
                getattr(self.deformation_net, head_name).initialize_weights()

    def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        """Use the static path when ``times_sel`` is None, otherwise the dynamic path."""
        if times_sel is None:
            return self.forward_static(point)
        return self.forward_dynamic(point, scales, rotations, opacity, times_sel)

    def forward_static(self, points):
        """Call the deformation network with points only."""
        return self.deformation_net(points)

    def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        """Call the deformation network with all inputs and return its four outputs."""
        outputs = self.deformation_net(point, scales, rotations, opacity, times_sel)
        means3D, scales, rotations, opacity = outputs
        return means3D, scales, rotations, opacity

    def get_mlp_parameters(self):
        """Return the deformation network's MLP parameters followed by the time MLP's."""
        time_params = list(self.timenet.parameters())
        return self.deformation_net.get_mlp_parameters() + time_params

    def get_grid_parameters(self):
        """Return the grid parameters of the deformation network."""
        return self.deformation_net.get_grid_parameters()
def initialize_weights(m):
    """Initialise an ``nn.Linear`` layer; leave every other module untouched.

    Written to be passed to ``nn.Module.apply``.  The weight gets Xavier
    uniform initialisation with gain 1.  The bias, if present, is set to zero:
    Xavier initialisation needs a tensor with at least two dimensions, so it
    cannot be applied to a bias vector.  The previous code re-initialised the
    weight a second time in that branch and never touched the bias.

    Args:
        m: Any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    init.xavier_uniform_(m.weight, gain=1)
    if m.bias is not None:
        init.zeros_(m.bias)
def initialize_zeros_weights(m):
    """Zero the weight and bias of an ``nn.Linear`` layer; ignore other modules.

    Args:
        m: Any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    # Weight first, then the bias when the layer has one.
    targets = [m.weight]
    if m.bias is not None:
        targets.append(m.bias)
    for tensor in targets:
        init.constant_(tensor, 0)