Alesteba commited on
Commit
b369cdc
1 Parent(s): c5dcb67

Update rendering.py

Browse files
Files changed (1) hide show
  1. rendering.py +160 -0
rendering.py CHANGED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+
5
+
6
+ def encode_position(x):
7
+ """Encodes the position into its corresponding Fourier feature.
8
+ Args:
9
+ x: The input coordinate.
10
+ Returns:
11
+ Fourier features tensors of the position.
12
+ """
13
+ positions = [x]
14
+ for i in range(POS_ENCODE_DIMS):
15
+ for fn in [tf.sin, tf.cos]:
16
+ positions.append(fn(2.0 ** i * x))
17
+ return tf.concat(positions, axis=-1)
18
+
19
+
20
+ def get_rays(height, width, focal, pose):
21
+ """Computes origin point and direction vector of rays.
22
+ Args:
23
+ height: Height of the image.
24
+ width: Width of the image.
25
+ focal: The focal length between the images and the camera.
26
+ pose: The pose matrix of the camera.
27
+ Returns:
28
+ Tuple of origin point and direction vector for rays.
29
+ """
30
+ # Build a meshgrid for the rays.
31
+ i, j = tf.meshgrid(
32
+ tf.range(width, dtype=tf.float32),
33
+ tf.range(height, dtype=tf.float32),
34
+ indexing="xy",
35
+ )
36
+
37
+ # Normalize the x axis coordinates.
38
+ transformed_i = (i - width * 0.5) / focal
39
+
40
+ # Normalize the y axis coordinates.
41
+ transformed_j = (j - height * 0.5) / focal
42
+
43
+ # Create the direction unit vectors.
44
+ directions = tf.stack([transformed_i, -transformed_j, -tf.ones_like(i)], axis=-1)
45
+
46
+ # Get the camera matrix.
47
+ camera_matrix = pose[:3, :3]
48
+ height_width_focal = pose[:3, -1]
49
+
50
+ # Get origins and directions for the rays.
51
+ transformed_dirs = directions[..., None, :]
52
+ camera_dirs = transformed_dirs * camera_matrix
53
+ ray_directions = tf.reduce_sum(camera_dirs, axis=-1)
54
+ ray_origins = tf.broadcast_to(height_width_focal, tf.shape(ray_directions))
55
+
56
+ # Return the origins and directions.
57
+ return (ray_origins, ray_directions)
58
+
59
+
60
+ def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
61
+ """Renders the rays and flattens it.
62
+ Args:
63
+ ray_origins: The origin points for rays.
64
+ ray_directions: The direction unit vectors for the rays.
65
+ near: The near bound of the volumetric scene.
66
+ far: The far bound of the volumetric scene.
67
+ num_samples: Number of sample points in a ray.
68
+ rand: Choice for randomising the sampling strategy.
69
+ Returns:
70
+ Tuple of flattened rays and sample points on each rays.
71
+ """
72
+ # Compute 3D query points.
73
+ # Equation: r(t) = o+td -> Building the "t" here.
74
+ t_vals = tf.linspace(near, far, num_samples)
75
+ if rand:
76
+ # Inject uniform noise into sample space to make the sampling
77
+ # continuous.
78
+ shape = list(ray_origins.shape[:-1]) + [num_samples]
79
+ noise = tf.random.uniform(shape=shape) * (far - near) / num_samples
80
+ t_vals = t_vals + noise
81
+
82
+ # Equation: r(t) = o + td -> Building the "r" here.
83
+ rays = ray_origins[..., None, :] + (
84
+ ray_directions[..., None, :] * t_vals[..., None]
85
+ )
86
+ rays_flat = tf.reshape(rays, [-1, 3])
87
+ rays_flat = encode_position(rays_flat)
88
+ return (rays_flat, t_vals)
89
+
90
+
91
+ def map_fn(pose):
92
+ """Maps individual pose to flattened rays and sample points.
93
+ Args:
94
+ pose: The pose matrix of the camera.
95
+ Returns:
96
+ Tuple of flattened rays and sample points corresponding to the
97
+ camera pose.
98
+ """
99
+ (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
100
+ (rays_flat, t_vals) = render_flat_rays(
101
+ ray_origins=ray_origins,
102
+ ray_directions=ray_directions,
103
+ near=2.0,
104
+ far=6.0,
105
+ num_samples=NUM_SAMPLES,
106
+ rand=True,
107
+ )
108
+ return (rays_flat, t_vals)
109
+
110
+
111
+ def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
112
+ """Generates the RGB image and depth map from model prediction.
113
+ Args:
114
+ model: The MLP model that is trained to predict the rgb and
115
+ volume density of the volumetric scene.
116
+ rays_flat: The flattened rays that serve as the input to
117
+ the NeRF model.
118
+ t_vals: The sample points for the rays.
119
+ rand: Choice to randomise the sampling strategy.
120
+ train: Whether the model is in the training or testing phase.
121
+ Returns:
122
+ Tuple of rgb image and depth map.
123
+ """
124
+ # Get the predictions from the nerf model and reshape it.
125
+ if train:
126
+ predictions = model(rays_flat)
127
+ else:
128
+ predictions = model.predict(rays_flat)
129
+ predictions = tf.reshape(predictions, shape=(BATCH_SIZE, H, W, NUM_SAMPLES, 4))
130
+
131
+ # Slice the predictions into rgb and sigma.
132
+ rgb = tf.sigmoid(predictions[..., :-1])
133
+ sigma_a = tf.nn.relu(predictions[..., -1])
134
+
135
+ # Get the distance of adjacent intervals.
136
+ delta = t_vals[..., 1:] - t_vals[..., :-1]
137
+ # delta shape = (num_samples)
138
+ if rand:
139
+ delta = tf.concat(
140
+ [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
141
+ )
142
+ alpha = 1.0 - tf.exp(-sigma_a * delta)
143
+ else:
144
+ delta = tf.concat(
145
+ [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, 1))], axis=-1
146
+ )
147
+ alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])
148
+
149
+ # Get transmittance.
150
+ exp_term = 1.0 - alpha
151
+ epsilon = 1e-10
152
+ transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
153
+ weights = alpha * transmittance
154
+ rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
155
+
156
+ if rand:
157
+ depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
158
+ else:
159
+ depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
160
+ return (rgb, depth_map)