Delete scene/.ipynb_checkpoints
scene/.ipynb_checkpoints/__init__-checkpoint.py
DELETED
@@ -1,41 +0,0 @@
-###
-# Copyright (C) 2023, Computer Vision Lab, Seoul National University, https://cv.snu.ac.kr
-# For permission requests, please contact [email protected], [email protected], [email protected], [email protected].
-# All rights reserved.
-###
-import os
-import random
-
-from arguments import GSParams
-from utils.system import searchForMaxIteration
-from scene.dataset_readers import readDataInfo
-from scene.gaussian_model import GaussianModel
-
-
-class Scene:
-    gaussians: GaussianModel
-
-    def __init__(self, traindata, gaussians: GaussianModel, opt: GSParams):
-        self.traindata = traindata
-        self.gaussians = gaussians
-
-        info = readDataInfo(traindata, opt.white_background)
-        random.shuffle(info.train_cameras)  # Multi-res consistent random shuffling
-        self.cameras_extent = info.nerf_normalization["radius"]
-
-        print("Loading Training Cameras")
-        self.train_cameras = info.train_cameras
-        print("Loading Preset Cameras")
-        self.preset_cameras = {}
-        for campath in info.preset_cameras.keys():
-            self.preset_cameras[campath] = info.preset_cameras[campath]
-
-        self.gaussians.create_from_pcd(info.point_cloud, self.cameras_extent)
-        self.gaussians.training_setup(opt)
-
-    def getTrainCameras(self):
-        return self.train_cameras
-
-    def getPresetCameras(self, preset):
-        assert preset in self.preset_cameras
-        return self.preset_cameras[preset]
scene/.ipynb_checkpoints/cameras-checkpoint.py
DELETED
@@ -1,76 +0,0 @@
-#
-# Copyright (C) 2023, Inria
-# GRAPHDECO research group, https://team.inria.fr/graphdeco
-# All rights reserved.
-#
-# This software is free for non-commercial, research and evaluation use
-# under the terms of the LICENSE.md file.
-#
-# For inquiries contact [email protected]
-#
-import numpy as np
-
-import torch
-from torch import nn
-
-from utils.graphics import getWorld2View2, getProjectionMatrix
-from utils.loss import image2canny
-
-
-class Camera(nn.Module):
-    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
-                 image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
-                 ):
-        super(Camera, self).__init__()
-
-        self.uid = uid
-        self.colmap_id = colmap_id
-        self.R = R
-        self.T = T
-        self.FoVx = FoVx
-        self.FoVy = FoVy
-        self.image_name = image_name
-
-        try:
-            self.data_device = torch.device(data_device)
-        except Exception as e:
-            print(e)
-            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
-            self.data_device = torch.device("cuda")
-
-        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
-        self.canny_mask = image2canny(self.original_image.permute(1,2,0), 50, 150, isEdge1=False).detach().to(self.data_device)
-        self.image_width = self.original_image.shape[2]
-        self.image_height = self.original_image.shape[1]
-
-        if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.to(self.data_device)
-        else:
-            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
-
-        self.zfar = 100.0
-        self.znear = 0.01
-
-        self.trans = trans
-        self.scale = scale
-
-        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
-        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
-        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
-        self.camera_center = self.world_view_transform.inverse()[3, :3]
-
-
-class MiniCam:
-    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
-        self.image_width = width
-        self.image_height = height
-        self.FoVy = fovy
-        self.FoVx = fovx
-        self.znear = znear
-        self.zfar = zfar
-        self.world_view_transform = world_view_transform
-        self.full_proj_transform = full_proj_transform
-        view_inv = torch.inverse(self.world_view_transform)
-        self.camera_center = view_inv[3][:3]
-
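
A note on the camera_center computation in the deleted Camera and MiniCam classes above: world_view_transform is stored transposed (row-major, for the CUDA rasterizer), so the camera position in world space sits in row 3 of its inverse. A minimal numpy sketch, not part of the repo, using a made-up pose:

import numpy as np

R = np.eye(3)                          # world-to-camera rotation (identity for the sketch)
t = np.array([1.0, 2.0, 3.0])          # world-to-camera translation
w2c = np.eye(4)
w2c[:3, :3], w2c[:3, 3] = R, t
wvt = w2c.T                            # transposed storage, as in Camera above
center = np.linalg.inv(wvt)[3, :3]     # row 3 of the inverse, as in camera_center
print(np.allclose(center, -R.T @ t))   # True: the camera origin in world coordinates

With R = I the recovered center is simply -t, the camera origin expressed in world space.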
scene/.ipynb_checkpoints/colmap_loader-checkpoint.py
DELETED
@@ -1,301 +0,0 @@
-#
-# Copyright (C) 2023, Inria
-# GRAPHDECO research group, https://team.inria.fr/graphdeco
-# All rights reserved.
-#
-# This software is free for non-commercial, research and evaluation use
-# under the terms of the LICENSE.md file.
-#
-# For inquiries contact [email protected]
-#
-import numpy as np
-import collections
-import struct
-
-
-CameraModel = collections.namedtuple(
-    "CameraModel", ["model_id", "model_name", "num_params"])
-Camera = collections.namedtuple(
-    "Camera", ["id", "model", "width", "height", "params"])
-BaseImage = collections.namedtuple(
-    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
-Point3D = collections.namedtuple(
-    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
-CAMERA_MODELS = {
-    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
-    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
-    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
-    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
-    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
-    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
-    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
-    CameraModel(model_id=7, model_name="FOV", num_params=5),
-    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
-    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
-    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
-}
-CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
-                         for camera_model in CAMERA_MODELS])
-CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
-                           for camera_model in CAMERA_MODELS])
-
-
-def qvec2rotmat(qvec):
-    return np.array([
-        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
-         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
-         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
-        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
-         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
-         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
-        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
-         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
-         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
-
-
-def rotmat2qvec(R):
-    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
-    K = np.array([
-        [Rxx - Ryy - Rzz, 0, 0, 0],
-        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
-        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
-        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
-    eigvals, eigvecs = np.linalg.eigh(K)
-    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
-    if qvec[0] < 0:
-        qvec *= -1
-    return qvec
-
-
-class Image(BaseImage):
-    def qvec2rotmat(self):
-        return qvec2rotmat(self.qvec)
-
-
-def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
-    """Read and unpack the next bytes from a binary file.
-    :param fid:
-    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
-    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
-    :param endian_character: Any of {@, =, <, >, !}
-    :return: Tuple of read and unpacked values.
-    """
-    data = fid.read(num_bytes)
-    return struct.unpack(endian_character + format_char_sequence, data)
-
-
-def read_points3D_text(path):
-    """
-    see: src/base/reconstruction.cc
-        void Reconstruction::ReadPoints3DText(const std::string& path)
-        void Reconstruction::WritePoints3DText(const std::string& path)
-    """
-    xyzs = None
-    rgbs = None
-    errors = None
-    num_points = 0
-    with open(path, "r") as fid:
-        while True:
-            line = fid.readline()
-            if not line:
-                break
-            line = line.strip()
-            if len(line) > 0 and line[0] != "#":
-                num_points += 1
-
-
-    xyzs = np.empty((num_points, 3))
-    rgbs = np.empty((num_points, 3))
-    errors = np.empty((num_points, 1))
-    count = 0
-    with open(path, "r") as fid:
-        while True:
-            line = fid.readline()
-            if not line:
-                break
-            line = line.strip()
-            if len(line) > 0 and line[0] != "#":
-                elems = line.split()
-                xyz = np.array(tuple(map(float, elems[1:4])))
-                rgb = np.array(tuple(map(int, elems[4:7])))
-                error = np.array(float(elems[7]))
-                xyzs[count] = xyz
-                rgbs[count] = rgb
-                errors[count] = error
-                count += 1
-
-    return xyzs, rgbs, errors
-
-
-def read_points3D_binary(path_to_model_file):
-    """
-    see: src/base/reconstruction.cc
-        void Reconstruction::ReadPoints3DBinary(const std::string& path)
-        void Reconstruction::WritePoints3DBinary(const std::string& path)
-    """
-
-
-    with open(path_to_model_file, "rb") as fid:
-        num_points = read_next_bytes(fid, 8, "Q")[0]
-
-        xyzs = np.empty((num_points, 3))
-        rgbs = np.empty((num_points, 3))
-        errors = np.empty((num_points, 1))
-
-        for p_id in range(num_points):
-            binary_point_line_properties = read_next_bytes(
-                fid, num_bytes=43, format_char_sequence="QdddBBBd")
-            xyz = np.array(binary_point_line_properties[1:4])
-            rgb = np.array(binary_point_line_properties[4:7])
-            error = np.array(binary_point_line_properties[7])
-            track_length = read_next_bytes(
-                fid, num_bytes=8, format_char_sequence="Q")[0]
-            track_elems = read_next_bytes(
-                fid, num_bytes=8*track_length,
-                format_char_sequence="ii"*track_length)
-            xyzs[p_id] = xyz
-            rgbs[p_id] = rgb
-            errors[p_id] = error
-    return xyzs, rgbs, errors
-
-
-def read_intrinsics_text(path):
-    """
-    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
-    """
-    cameras = {}
-    with open(path, "r") as fid:
-        while True:
-            line = fid.readline()
-            if not line:
-                break
-            line = line.strip()
-            if len(line) > 0 and line[0] != "#":
-                elems = line.split()
-                camera_id = int(elems[0])
-                model = elems[1]
-                assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
-                width = int(elems[2])
-                height = int(elems[3])
-                params = np.array(tuple(map(float, elems[4:])))
-                cameras[camera_id] = Camera(id=camera_id, model=model,
-                                            width=width, height=height,
-                                            params=params)
-    return cameras
-
-
-def read_extrinsics_binary(path_to_model_file):
-    """
-    see: src/base/reconstruction.cc
-        void Reconstruction::ReadImagesBinary(const std::string& path)
-        void Reconstruction::WriteImagesBinary(const std::string& path)
-    """
-    images = {}
-    with open(path_to_model_file, "rb") as fid:
-        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
-        for _ in range(num_reg_images):
-            binary_image_properties = read_next_bytes(
-                fid, num_bytes=64, format_char_sequence="idddddddi")
-            image_id = binary_image_properties[0]
-            qvec = np.array(binary_image_properties[1:5])
-            tvec = np.array(binary_image_properties[5:8])
-            camera_id = binary_image_properties[8]
-            image_name = ""
-            current_char = read_next_bytes(fid, 1, "c")[0]
-            while current_char != b"\x00":   # look for the ASCII 0 entry
-                image_name += current_char.decode("utf-8")
-                current_char = read_next_bytes(fid, 1, "c")[0]
-            num_points2D = read_next_bytes(fid, num_bytes=8,
-                                           format_char_sequence="Q")[0]
-            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
-                                       format_char_sequence="ddq"*num_points2D)
-            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
-                                   tuple(map(float, x_y_id_s[1::3]))])
-            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
-            images[image_id] = Image(
-                id=image_id, qvec=qvec, tvec=tvec,
-                camera_id=camera_id, name=image_name,
-                xys=xys, point3D_ids=point3D_ids)
-    return images
-
-
-def read_intrinsics_binary(path_to_model_file):
-    """
-    see: src/base/reconstruction.cc
-        void Reconstruction::WriteCamerasBinary(const std::string& path)
-        void Reconstruction::ReadCamerasBinary(const std::string& path)
-    """
-    cameras = {}
-    with open(path_to_model_file, "rb") as fid:
-        num_cameras = read_next_bytes(fid, 8, "Q")[0]
-        for _ in range(num_cameras):
-            camera_properties = read_next_bytes(
-                fid, num_bytes=24, format_char_sequence="iiQQ")
-            camera_id = camera_properties[0]
-            model_id = camera_properties[1]
-            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
-            width = camera_properties[2]
-            height = camera_properties[3]
-            num_params = CAMERA_MODEL_IDS[model_id].num_params
-            params = read_next_bytes(fid, num_bytes=8*num_params,
-                                     format_char_sequence="d"*num_params)
-            cameras[camera_id] = Camera(id=camera_id,
-                                        model=model_name,
-                                        width=width,
-                                        height=height,
-                                        params=np.array(params))
-        assert len(cameras) == num_cameras
-    return cameras
-
-
-def read_extrinsics_text(path):
-    """
-    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
-    """
-    images = {}
-    with open(path, "r") as fid:
-        while True:
-            line = fid.readline()
-            if not line:
-                break
-            line = line.strip()
-            if len(line) > 0 and line[0] != "#":
-                elems = line.split()
-                image_id = int(elems[0])
-                qvec = np.array(tuple(map(float, elems[1:5])))
-                tvec = np.array(tuple(map(float, elems[5:8])))
-                camera_id = int(elems[8])
-                image_name = elems[9]
-                elems = fid.readline().split()
-                xys = np.column_stack([tuple(map(float, elems[0::3])),
-                                       tuple(map(float, elems[1::3]))])
-                point3D_ids = np.array(tuple(map(int, elems[2::3])))
-                images[image_id] = Image(
-                    id=image_id, qvec=qvec, tvec=tvec,
-                    camera_id=camera_id, name=image_name,
-                    xys=xys, point3D_ids=point3D_ids)
-    return images
-
-
-def read_colmap_bin_array(path):
-    """
-    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
-
-    :param path: path to the colmap binary file.
-    :return: nd array with the floating point values in the value
-    """
-    with open(path, "rb") as fid:
-        width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
-                                                usecols=(0, 1, 2), dtype=int)
-        fid.seek(0)
-        num_delimiter = 0
-        byte = fid.read(1)
-        while True:
-            if byte == b"&":
-                num_delimiter += 1
-                if num_delimiter >= 3:
-                    break
-            byte = fid.read(1)
-        array = np.fromfile(fid, np.float32)
-    array = array.reshape((width, height, channels), order="F")
-    return np.transpose(array, (1, 0, 2)).squeeze()
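
qvec2rotmat above assumes COLMAP's w-first quaternion layout, [w, x, y, z]. As a quick self-contained sanity check (illustrative only, not repo code), a 90-degree rotation about +z should map the +x axis to +y:

import numpy as np

def qvec2rotmat(qvec):  # same formula as in the deleted file
    return np.array([
        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])

q = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])  # [w, x, y, z]: 90° about +z
R = qvec2rotmat(q)
print(np.allclose(R @ [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]))        # True: +x maps to +y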
scene/.ipynb_checkpoints/dataset_readers-checkpoint.py
DELETED
@@ -1,434 +0,0 @@
-#
-# Copyright (C) 2023, Inria
-# GRAPHDECO research group, https://team.inria.fr/graphdeco
-# All rights reserved.
-#
-# This software is free for non-commercial, research and evaluation use
-# under the terms of the LICENSE.md file.
-#
-# For inquiries contact [email protected]
-#
-import os
-import sys
-import json
-from typing import NamedTuple
-from pathlib import Path
-
-import imageio
-import torch
-import numpy as np
-from PIL import Image
-from plyfile import PlyData, PlyElement
-
-from scene.gaussian_model import BasicPointCloud
-from scene.cameras import MiniCam, Camera
-from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
-    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
-from utils.graphics import getWorld2View2, focal2fov, fov2focal
-from utils.graphics import getProjectionMatrix
-from utils.trajectory import get_camerapaths
-from utils.sh import SH2RGB
-
-
-class CameraInfo(NamedTuple):
-    uid: int
-    R: np.array
-    T: np.array
-    FovY: np.array
-    FovX: np.array
-    image: np.array
-    image_path: str
-    image_name: str
-    width: int
-    height: int
-
-
-class SceneInfo(NamedTuple):
-    point_cloud: BasicPointCloud
-    train_cameras: list
-    test_cameras: list
-    preset_cameras: list
-    nerf_normalization: dict
-    ply_path: str
-
-
-def getNerfppNorm(cam_info):
-    def get_center_and_diag(cam_centers):
-        cam_centers = np.hstack(cam_centers)
-        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
-        center = avg_cam_center
-        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
-        diagonal = np.max(dist)
-        return center.flatten(), diagonal
-
-    cam_centers = []
-
-    for cam in cam_info:
-        W2C = getWorld2View2(cam.R, cam.T)
-        C2W = np.linalg.inv(W2C)
-        cam_centers.append(C2W[:3, 3:4])
-
-    center, diagonal = get_center_and_diag(cam_centers)
-    radius = diagonal * 1.1
-
-    translate = -center
-
-    return {"translate": translate, "radius": radius}
-
-
-def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
-    cam_infos = []
-    for idx, key in enumerate(cam_extrinsics):
-        sys.stdout.write('\r')
-        # the exact output you're looking for:
-        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
-        sys.stdout.flush()
-
-        extr = cam_extrinsics[key]
-        intr = cam_intrinsics[extr.camera_id]
-        height = intr.height
-        width = intr.width
-
-        uid = intr.id
-        R = np.transpose(qvec2rotmat(extr.qvec))
-        T = np.array(extr.tvec)
-
-        if intr.model=="SIMPLE_PINHOLE":
-            focal_length_x = intr.params[0]
-            FovY = focal2fov(focal_length_x, height)
-            FovX = focal2fov(focal_length_x, width)
-        elif intr.model=="PINHOLE":
-            focal_length_x = intr.params[0]
-            focal_length_y = intr.params[1]
-            FovY = focal2fov(focal_length_y, height)
-            FovX = focal2fov(focal_length_x, width)
-        else:
-            assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
-
-        image_path = os.path.join(images_folder, os.path.basename(extr.name))
-        image_name = os.path.basename(image_path).split(".")[0]
-        image = Image.open(image_path)
-
-        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                              image_path=image_path, image_name=image_name, width=width, height=height)
-        cam_infos.append(cam_info)
-    sys.stdout.write('\n')
-    return cam_infos
-
-
-def fetchPly(path):
-    plydata = PlyData.read(path)
-    vertices = plydata['vertex']
-    idx = np.random.choice(len(vertices['x']),size=(min(len(vertices['x']), 100_000),),replace=False)
-    positions = np.vstack([vertices['x'][idx], vertices['y'][idx], vertices['z'][idx]]).T if 'x' in vertices else None
-    colors = np.vstack([vertices['red'][idx], vertices['green'][idx], vertices['blue'][idx]]).T / 255.0 if 'red' in vertices else None
-    normals = np.vstack([vertices['nx'][idx], vertices['ny'][idx], vertices['nz'][idx]]).T if 'nx' in vertices else None
-    return BasicPointCloud(points=positions, colors=colors, normals=normals)
-
-
-def storePly(path, xyz, rgb):
-    # Define the dtype for the structured array
-    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
-             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
-             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
-
-    normals = np.zeros_like(xyz)
-
-    elements = np.empty(xyz.shape[0], dtype=dtype)
-    attributes = np.concatenate((xyz, normals, rgb), axis=1)
-    elements[:] = list(map(tuple, attributes))
-
-    # Create the PlyData object and write to file
-    vertex_element = PlyElement.describe(elements, 'vertex')
-    ply_data = PlyData([vertex_element])
-    ply_data.write(path)
-
-
-def readColmapSceneInfo(path, images, eval, preset=None, llffhold=8):
-    try:
-        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
-        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
-        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
-        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
-    except:
-        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
-        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
-        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
-        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
-
-    reading_dir = "images" if images == None else images
-    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
-    cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
-
-    if eval:
-        # train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
-        # test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
-        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % 5 == 2 or idx % 5 == 0]
-        test_cam_infos = [c for idx, c in enumerate(cam_infos) if not (idx % 5 == 2 or idx % 5 == 0)]
-    else:
-        train_cam_infos = cam_infos
-        test_cam_infos = []
-
-    nerf_normalization = getNerfppNorm(train_cam_infos)
-
-    ply_path = os.path.join(path, "sparse/0/points3D.ply")
-    bin_path = os.path.join(path, "sparse/0/points3D.bin")
-    txt_path = os.path.join(path, "sparse/0/points3D.txt")
-    if not os.path.exists(ply_path):
-        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
-        try:
-            xyz, rgb, _ = read_points3D_binary(bin_path)
-        except:
-            xyz, rgb, _ = read_points3D_text(txt_path)
-        storePly(ply_path, xyz, rgb)
-    try:
-        pcd = fetchPly(ply_path)
-    except:
-        pcd = None
-
-    if preset:
-        preset_cam_infos = readCamerasFromPreset('/home/chung/workspace/gaussian-splatting/poses_supplementary', f"{preset}.json")
-    else:
-        preset_cam_infos = None
-
-    scene_info = SceneInfo(point_cloud=pcd,
-                           train_cameras=train_cam_infos,
-                           test_cameras=test_cam_infos,
-                           preset_cameras=preset_cam_infos,
-                           nerf_normalization=nerf_normalization,
-                           ply_path=ply_path)
-    return scene_info
-
-
-def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
-    cam_infos = []
-
-    with open(os.path.join(path, transformsfile)) as json_file:
-        contents = json.load(json_file)
-        fovx = contents["camera_angle_x"]
-
-        frames = contents["frames"]
-        for idx, frame in enumerate(frames):
-            cam_name = os.path.join(path, frame["file_path"] + extension)
-
-            # NeRF 'transform_matrix' is a camera-to-world transform
-            c2w = np.array(frame["transform_matrix"])
-            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-            c2w[:3, 1:3] *= -1
-
-            # get the world-to-camera transform and set R, T
-            w2c = np.linalg.inv(c2w)
-            R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
-            T = w2c[:3, 3]
-
-            image_path = os.path.join(path, cam_name)
-            image_name = Path(cam_name).stem
-            image = Image.open(image_path)
-
-            # if os.path.exists(os.path.join(path, frame["file_path"].replace("/train/", "/depths_train/")+'.npy')):
-            #     depth = np.load(os.path.join(path, frame["file_path"].replace("/train/", "/depths_train/")+'.npy'))
-            #     if os.path.exists(os.path.join(path, frame["file_path"].replace("/train/", "/masks_train/")+'.png')):
-            #         mask = imageio.v3.imread(os.path.join(path, frame["file_path"].replace("/train/", "/masks_train/")+'.png'))[:,:,0]/255.
-            #     else:
-            #         mask = np.ones_like(depth)
-            #     final_depth = depth*mask
-            # else:
-            #     final_depth = None
-
-            im_data = np.array(image.convert("RGBA"))
-
-            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
-
-            norm_data = im_data / 255.0
-            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
-            image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
-
-            fovy = focal2fov(fov2focal(fovx, image.size[1]), image.size[0])
-            FovY = fovy
-            FovX = fovx
-
-            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                                        image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
-
-    return cam_infos
-
-
-def readCamerasFromPreset(path, transformsfile):
-    cam_infos = []
-
-    with open(os.path.join(path, transformsfile)) as json_file:
-        contents = json.load(json_file)
-        FOV = contents["camera_angle_x"]*1.2
-
-        frames = contents["frames"]
-        for idx, frame in enumerate(frames):
-            # NeRF 'transform_matrix' is a camera-to-world transform
-            c2w = np.array(frame["transform_matrix"])
-            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-            c2w[:3, 1:3] *= -1
-
-            # get the world-to-camera transform and set R, T
-            w2c = np.linalg.inv(np.concatenate((c2w, np.array([0,0,0,1]).reshape(1,4)), axis=0))
-            R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
-            T = w2c[:3, 3]
-            # R = c2w[:3,:3]
-            # T = - np.transpose(R).dot(c2w[:3,3])
-
-            image = Image.fromarray(np.zeros((512,512)), "RGB")
-            FovY = focal2fov(fov2focal(FOV, 512), image.size[0])
-            FovX = focal2fov(fov2focal(FOV, 512), image.size[1])
-            # FovX, FovY = contents["camera_angle_x"], contents["camera_angle_x"]
-
-            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                                        image_path='None', image_name='None', width=image.size[1], height=image.size[0]))
-
-    return cam_infos
-
-
-def readNerfSyntheticInfo(path, white_background, eval, preset=None, extension=".png"):
-    print("Reading Training Transforms")
-    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
-    print("Reading Test Transforms")
-    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
-
-    if preset:
-        preset_cam_infos = readCamerasFromPreset('/home/chung/workspace/gaussian-splatting/poses_supplementary', f"{preset}.json")
-    else:
-        preset_cam_infos = None
-
-    if not eval:
-        train_cam_infos.extend(test_cam_infos)
-        test_cam_infos = []
-
-    nerf_normalization = getNerfppNorm(train_cam_infos)
-
-    ply_path = os.path.join(path, "points3d.ply")
-    if not os.path.exists(ply_path):
-        # Since this data set has no colmap data, we start with random points
-        num_pts = 100_000
-        print(f"Generating random point cloud ({num_pts})...")
-
-        # We create random points inside the bounds of the synthetic Blender scenes
-        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
-        shs = np.random.random((num_pts, 3)) / 255.0
-        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
-
-        storePly(ply_path, xyz, SH2RGB(shs) * 255)
-
-    try:
-        pcd = fetchPly(ply_path)
-    except:
-        pcd = None
-
-    scene_info = SceneInfo(point_cloud=pcd,
-                           train_cameras=train_cam_infos,
-                           test_cameras=test_cam_infos,
-                           preset_cameras=preset_cam_infos,
-                           nerf_normalization=nerf_normalization,
-                           ply_path=ply_path)
-    return scene_info
-
-
-def loadCamerasFromData(traindata, white_background):
-    cameras = []
-
-    fovx = traindata["camera_angle_x"]
-    frames = traindata["frames"]
-    for idx, frame in enumerate(frames):
-        # NeRF 'transform_matrix' is a camera-to-world transform
-        c2w = np.array(frame["transform_matrix"])
-        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-        c2w[:3, 1:3] *= -1
-
-        # get the world-to-camera transform and set R, T
-        w2c = np.linalg.inv(c2w)
-        R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
-        T = w2c[:3, 3]
-
-        image = frame["image"] if "image" in frame else None
-        im_data = np.array(image.convert("RGBA"))
-
-        bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
-
-        norm_data = im_data / 255.0
-        arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
-        image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
-        loaded_mask = np.ones_like(norm_data[:, :, 3:4])
-
-        fovy = focal2fov(fov2focal(fovx, image.size[1]), image.size[0])
-        FovY = fovy
-        FovX = fovx
-
-        image = torch.Tensor(arr).permute(2,0,1)
-        loaded_mask = None  #torch.Tensor(loaded_mask).permute(2,0,1)
-
-        ### needs to be converted to torch
-        cameras.append(Camera(colmap_id=idx, R=R, T=T, FoVx=FovX, FoVy=FovY, image=image,
-                              gt_alpha_mask=loaded_mask, image_name='', uid=idx, data_device='cuda'))
-
-    return cameras
-
-
-def loadCameraPreset(traindata, presetdata):
-    cam_infos = {}
-    ## camera setting (for H, W and focal)
-    fovx = traindata["camera_angle_x"] * 1.2
-    W, H = traindata["frames"][0]["image"].size
-    # W, H = traindata["W"], traindata["H"]
-
-    for camkey in presetdata:
-        cam_infos[camkey] = []
-        for idx, frame in enumerate(presetdata[camkey]["frames"]):
-            # NeRF 'transform_matrix' is a camera-to-world transform
-            c2w = np.array(frame["transform_matrix"])
-            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-            c2w[:3, 1:3] *= -1
-
-            # get the world-to-camera transform and set R, T
-            w2c = np.linalg.inv(c2w)
-            R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
-            T = w2c[:3, 3]
-
-            fovy = focal2fov(fov2focal(fovx, W), H)
-            FovY = fovy
-            FovX = fovx
-
-            znear, zfar = 0.01, 100
-            world_view_transform = torch.tensor(getWorld2View2(R, T, np.array([0.0, 0.0, 0.0]), 1.0)).transpose(0, 1).cuda()
-            projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, fovX=FovX, fovY=FovY).transpose(0,1).cuda()
-            full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
-
-            cam_infos[camkey].append(MiniCam(width=W, height=H, fovy=FovY, fovx=FovX, znear=znear, zfar=zfar,
-                                             world_view_transform=world_view_transform, full_proj_transform=full_proj_transform))
-
-    return cam_infos
-
-
-def readDataInfo(traindata, white_background):
-    print("Reading Training Transforms")
-
-    train_cameras = loadCamerasFromData(traindata, white_background)
-    preset_minicams = loadCameraPreset(traindata, presetdata=get_camerapaths())
-
-    # if not eval:
-    #     train_cam_infos.extend(test_cam_infos)
-    #     test_cam_infos = []
-
-    nerf_normalization = getNerfppNorm(train_cameras)
-
-    pcd = BasicPointCloud(points=traindata['pcd_points'].T, colors=traindata['pcd_colors'], normals=None)
-
-
-    scene_info = SceneInfo(point_cloud=pcd,
-                           train_cameras=train_cameras,
-                           test_cameras=[],
-                           preset_cameras=preset_minicams,
-                           nerf_normalization=nerf_normalization,
-                           ply_path='')
-    return scene_info
-
-
-sceneLoadTypeCallbacks = {
-    "Colmap": readColmapSceneInfo,
-    "Blender" : readNerfSyntheticInfo
-}
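
Each reader above repeats the same pose conversion: flip the Y/Z columns of NeRF's camera-to-world matrix (OpenGL/Blender axes) and invert it to get world-to-camera R and T, with R stored transposed for the CUDA code. A hedged numpy sketch of just that step, using a hypothetical pose:

import numpy as np

c2w = np.eye(4)
c2w[:3, 3] = [0.0, 0.0, 4.0]   # hypothetical camera 4 units along +z (OpenGL axes)
c2w[:3, 1:3] *= -1             # OpenGL/Blender (Y up, Z back) -> COLMAP (Y down, Z forward)
w2c = np.linalg.inv(c2w)
R = np.transpose(w2c[:3, :3])  # stored transposed, matching the readers above
T = w2c[:3, 3]
print(T)                       # ~[0. 0. 4.]: the translation in COLMAP convention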
scene/.ipynb_checkpoints/gaussian_model-checkpoint.py
DELETED
@@ -1,407 +0,0 @@
-#
-# Copyright (C) 2023, Inria
-# GRAPHDECO research group, https://team.inria.fr/graphdeco
-# All rights reserved.
-#
-# This software is free for non-commercial, research and evaluation use
-# under the terms of the LICENSE.md file.
-#
-# For inquiries contact [email protected]
-#
-import os
-
-import numpy as np
-from plyfile import PlyData, PlyElement
-
-import torch
-from torch import nn
-
-from simple_knn._C import distCUDA2
-from utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation
-from utils.system import mkdir_p
-from utils.sh import RGB2SH
-from utils.graphics import BasicPointCloud
-from utils.general import strip_symmetric, build_scaling_rotation
-
-
-class GaussianModel:
-    def setup_functions(self):
-        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
-            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
-            actual_covariance = L @ L.transpose(1, 2)
-            symm = strip_symmetric(actual_covariance)
-            return symm
-
-        self.scaling_activation = torch.exp
-        self.scaling_inverse_activation = torch.log
-
-        self.covariance_activation = build_covariance_from_scaling_rotation
-
-        self.opacity_activation = torch.sigmoid
-        self.inverse_opacity_activation = inverse_sigmoid
-
-        self.rotation_activation = torch.nn.functional.normalize
-
-
-    def __init__(self, sh_degree : int):
-        self.active_sh_degree = 0
-        self.max_sh_degree = sh_degree
-        self._xyz = torch.empty(0)
-        self._features_dc = torch.empty(0)
-        self._features_rest = torch.empty(0)
-        self._scaling = torch.empty(0)
-        self._rotation = torch.empty(0)
-        self._opacity = torch.empty(0)
-        self.max_radii2D = torch.empty(0)
-        self.xyz_gradient_accum = torch.empty(0)
-        self.denom = torch.empty(0)
-        self.optimizer = None
-        self.percent_dense = 0
-        self.spatial_lr_scale = 0
-        self.setup_functions()
-
-    def capture(self):
-        return (
-            self.active_sh_degree,
-            self._xyz,
-            self._features_dc,
-            self._features_rest,
-            self._scaling,
-            self._rotation,
-            self._opacity,
-            self.max_radii2D,
-            self.xyz_gradient_accum,
-            self.denom,
-            self.optimizer.state_dict(),
-            self.spatial_lr_scale,
-        )
-
-    def restore(self, model_args, training_args):
-        (self.active_sh_degree,
-         self._xyz,
-         self._features_dc,
-         self._features_rest,
-         self._scaling,
-         self._rotation,
-         self._opacity,
-         self.max_radii2D,
-         xyz_gradient_accum,
-         denom,
-         opt_dict,
-         self.spatial_lr_scale) = model_args
-        self.training_setup(training_args)
-        self.xyz_gradient_accum = xyz_gradient_accum
-        self.denom = denom
-        self.optimizer.load_state_dict(opt_dict)
-
-    @property
-    def get_scaling(self):
-        return self.scaling_activation(self._scaling)
-
-    @property
-    def get_rotation(self):
-        return self.rotation_activation(self._rotation)
-
-    @property
-    def get_xyz(self):
-        return self._xyz
-
-    @property
-    def get_features(self):
-        features_dc = self._features_dc
-        features_rest = self._features_rest
-        return torch.cat((features_dc, features_rest), dim=1)
-
-    @property
-    def get_opacity(self):
-        return self.opacity_activation(self._opacity)
-
-    def get_covariance(self, scaling_modifier = 1):
-        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
-
-    def oneupSHdegree(self):
-        if self.active_sh_degree < self.max_sh_degree:
-            self.active_sh_degree += 1
-
-    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
-        self.spatial_lr_scale = spatial_lr_scale
-        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
-        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
-        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
-        features[:, :3, 0 ] = fused_color
-        features[:, 3:, 1:] = 0.0
-
-        print("Number of points at initialisation : ", fused_point_cloud.shape[0])
-
-        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
-        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
-        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
-        rots[:, 0] = 1
-
-        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
-
-        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
-        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
-        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
-        self._scaling = nn.Parameter(scales.requires_grad_(True))
-        self._rotation = nn.Parameter(rots.requires_grad_(True))
-        self._opacity = nn.Parameter(opacities.requires_grad_(True))
-        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
-
-    def training_setup(self, training_args):
-        self.percent_dense = training_args.percent_dense
-        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-
-        l = [
-            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
-            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
-            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
-            {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
-            {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
-            {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
-        ]
-
-        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
-        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
-                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,
-                                                    lr_delay_mult=training_args.position_lr_delay_mult,
-                                                    max_steps=training_args.position_lr_max_steps)
-
-    def update_learning_rate(self, iteration):
-        ''' Learning rate scheduling per step '''
-        for param_group in self.optimizer.param_groups:
-            if param_group["name"] == "xyz":
-                lr = self.xyz_scheduler_args(iteration)
-                param_group['lr'] = lr
-                return lr
-
-    def construct_list_of_attributes(self):
-        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
-        # All channels except the 3 DC
-        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
-            l.append('f_dc_{}'.format(i))
-        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
-            l.append('f_rest_{}'.format(i))
-        l.append('opacity')
-        for i in range(self._scaling.shape[1]):
-            l.append('scale_{}'.format(i))
-        for i in range(self._rotation.shape[1]):
-            l.append('rot_{}'.format(i))
-        return l
-
-    def save_ply(self, filepath):
-        xyz = self._xyz.detach().cpu().numpy()
-        normals = np.zeros_like(xyz)
-        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
-        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
-        opacities = self._opacity.detach().cpu().numpy()
-        scale = self._scaling.detach().cpu().numpy()
-        rotation = self._rotation.detach().cpu().numpy()
-
-        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
-
-        elements = np.empty(xyz.shape[0], dtype=dtype_full)
-        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
-        elements[:] = list(map(tuple, attributes))
-        el = PlyElement.describe(elements, 'vertex')
-        PlyData([el]).write(filepath)
-
-    def reset_opacity(self):
-        opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
-        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
-        self._opacity = optimizable_tensors["opacity"]
-
-    def load_ply(self, path):
-        plydata = PlyData.read(path)
-
-        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
-                        np.asarray(plydata.elements[0]["y"]),
-                        np.asarray(plydata.elements[0]["z"])), axis=1)
-        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
-
-        features_dc = np.zeros((xyz.shape[0], 3, 1))
-        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
-        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
-        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
-
-        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
-        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
-        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
-        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
-        for idx, attr_name in enumerate(extra_f_names):
-            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
-        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
-        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
-
-        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
-        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
-        scales = np.zeros((xyz.shape[0], len(scale_names)))
-        for idx, attr_name in enumerate(scale_names):
-            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
-
-        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
-        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
-        rots = np.zeros((xyz.shape[0], len(rot_names)))
-        for idx, attr_name in enumerate(rot_names):
-            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
-
-        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
-
-        self.active_sh_degree = self.max_sh_degree
-
-    def replace_tensor_to_optimizer(self, tensor, name):
-        optimizable_tensors = {}
-        for group in self.optimizer.param_groups:
-            if group["name"] == name:
-                stored_state = self.optimizer.state.get(group['params'][0], None)
-                stored_state["exp_avg"] = torch.zeros_like(tensor)
-                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
-
-                del self.optimizer.state[group['params'][0]]
-                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
-                self.optimizer.state[group['params'][0]] = stored_state
-
-                optimizable_tensors[group["name"]] = group["params"][0]
-        return optimizable_tensors
-
-    def _prune_optimizer(self, mask):
-        optimizable_tensors = {}
-        for group in self.optimizer.param_groups:
-            stored_state = self.optimizer.state.get(group['params'][0], None)
-            if stored_state is not None:
-                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
-                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
-
-                del self.optimizer.state[group['params'][0]]
-                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
-                self.optimizer.state[group['params'][0]] = stored_state
-
-                optimizable_tensors[group["name"]] = group["params"][0]
-            else:
-                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
-                optimizable_tensors[group["name"]] = group["params"][0]
-        return optimizable_tensors
-
-    def prune_points(self, mask):
-        valid_points_mask = ~mask
-        optimizable_tensors = self._prune_optimizer(valid_points_mask)
-
-        self._xyz = optimizable_tensors["xyz"]
-        self._features_dc = optimizable_tensors["f_dc"]
-        self._features_rest = optimizable_tensors["f_rest"]
-        self._opacity = optimizable_tensors["opacity"]
-        self._scaling = optimizable_tensors["scaling"]
-        self._rotation = optimizable_tensors["rotation"]
-
-        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
-
-        self.denom = self.denom[valid_points_mask]
-        self.max_radii2D = self.max_radii2D[valid_points_mask]
-
-    def cat_tensors_to_optimizer(self, tensors_dict):
-        optimizable_tensors = {}
-        for group in self.optimizer.param_groups:
-            assert len(group["params"]) == 1
-            extension_tensor = tensors_dict[group["name"]]
-            stored_state = self.optimizer.state.get(group['params'][0], None)
-            if stored_state is not None:
-
-                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
-                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
-
-                del self.optimizer.state[group['params'][0]]
-                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
-                self.optimizer.state[group['params'][0]] = stored_state
-
-                optimizable_tensors[group["name"]] = group["params"][0]
-            else:
-                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
-                optimizable_tensors[group["name"]] = group["params"][0]
-
-        return optimizable_tensors
-
-    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
-        d = {"xyz": new_xyz,
-             "f_dc": new_features_dc,
-             "f_rest": new_features_rest,
-             "opacity": new_opacities,
-             "scaling" : new_scaling,
-             "rotation" : new_rotation}
-
-        optimizable_tensors = self.cat_tensors_to_optimizer(d)
-        self._xyz = optimizable_tensors["xyz"]
-        self._features_dc = optimizable_tensors["f_dc"]
-        self._features_rest = optimizable_tensors["f_rest"]
-        self._opacity = optimizable_tensors["opacity"]
-        self._scaling = optimizable_tensors["scaling"]
-        self._rotation = optimizable_tensors["rotation"]
-
-        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
-
-    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
-        n_init_points = self.get_xyz.shape[0]
-        # Extract points that satisfy the gradient condition
-        padded_grad = torch.zeros((n_init_points), device="cuda")
-        padded_grad[:grads.shape[0]] = grads.squeeze()
-        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
-        selected_pts_mask = torch.logical_and(selected_pts_mask,
-                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
-
-        stds = self.get_scaling[selected_pts_mask].repeat(N,1)
-        means =torch.zeros((stds.size(0), 3),device="cuda")
-        samples = torch.normal(mean=means, std=stds)
-        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
-        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
-        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
-        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
-        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
-        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
-        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
-
-        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
-
-        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
-        self.prune_points(prune_filter)
-
-    def densify_and_clone(self, grads, grad_threshold, scene_extent):
-        # Extract points that satisfy the gradient condition
-        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
-        selected_pts_mask = torch.logical_and(selected_pts_mask,
-                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
-
-        new_xyz = self._xyz[selected_pts_mask]
-        new_features_dc = self._features_dc[selected_pts_mask]
-        new_features_rest = self._features_rest[selected_pts_mask]
-        new_opacities = self._opacity[selected_pts_mask]
-        new_scaling = self._scaling[selected_pts_mask]
-        new_rotation = self._rotation[selected_pts_mask]
-
-        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
-
-    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
-        grads = self.xyz_gradient_accum / self.denom
-        grads[grads.isnan()] = 0.0
-
-        self.densify_and_clone(grads, max_grad, extent)
-        self.densify_and_split(grads, max_grad, extent)
-
-        prune_mask = (self.get_opacity < min_opacity).squeeze()
-        if max_screen_size:
-            big_points_vs = self.max_radii2D > max_screen_size
-            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
-            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
-        self.prune_points(prune_mask)
-
-        torch.cuda.empty_cache()
-
-    def add_densification_stats(self, viewspace_point_tensor, update_filter):
-        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
-        self.denom[update_filter] += 1
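
build_covariance_from_scaling_rotation above factors each Gaussian's covariance as Sigma = (R S)(R S)^T, which is symmetric positive semi-definite by construction; strip_symmetric then keeps the six unique entries of the symmetric matrix. A small numpy sketch of the same factorization, assuming an identity rotation for simplicity (illustrative only, not repo code):

import numpy as np

s = np.array([0.5, 1.0, 2.0])           # per-axis scales (after the exp activation)
R = np.eye(3)                           # identity rotation, for the sketch
L = R @ np.diag(s)                      # L = R S, as in build_scaling_rotation
cov = L @ L.T                           # full 3x3 covariance
print(np.allclose(cov, np.diag(s**2)))  # True: axis-aligned Gaussian with variances s^2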