Alesteba committed
Commit c5dcb67
1 Parent(s): b6e035a

Update app.py

Files changed (1):
  app.py +0 -174
app.py CHANGED
@@ -17,180 +17,6 @@ H = 25
 W = 25
 focal = 0.6911112070083618
 
-def encode_position(x):
-    """Encodes the position into its corresponding Fourier feature.
-    Args:
-        x: The input coordinate.
-    Returns:
-        Fourier features tensors of the position.
-    """
-    positions = [x]
-    for i in range(POS_ENCODE_DIMS):
-        for fn in [tf.sin, tf.cos]:
-            positions.append(fn(2.0 ** i * x))
-    return tf.concat(positions, axis=-1)
-
-
-def get_rays(height, width, focal, pose):
-    """Computes origin point and direction vector of rays.
-    Args:
-        height: Height of the image.
-        width: Width of the image.
-        focal: The focal length between the images and the camera.
-        pose: The pose matrix of the camera.
-    Returns:
-        Tuple of origin point and direction vector for rays.
-    """
-    # Build a meshgrid for the rays.
-    i, j = tf.meshgrid(
-        tf.range(width, dtype=tf.float32),
-        tf.range(height, dtype=tf.float32),
-        indexing="xy",
-    )
-
-    # Normalize the x axis coordinates.
-    transformed_i = (i - width * 0.5) / focal
-
-    # Normalize the y axis coordinates.
-    transformed_j = (j - height * 0.5) / focal
-
-    # Create the direction unit vectors.
-    directions = tf.stack([transformed_i, -transformed_j, -tf.ones_like(i)], axis=-1)
-
-    # Get the camera matrix.
-    camera_matrix = pose[:3, :3]
-    height_width_focal = pose[:3, -1]
-
-    # Get origins and directions for the rays.
-    transformed_dirs = directions[..., None, :]
-    camera_dirs = transformed_dirs * camera_matrix
-    ray_directions = tf.reduce_sum(camera_dirs, axis=-1)
-    ray_origins = tf.broadcast_to(height_width_focal, tf.shape(ray_directions))
-
-    # Return the origins and directions.
-    return (ray_origins, ray_directions)
-
-
-def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
-
-    """Renders the rays and flattens it.
-
-    Args:
-        ray_origins: The origin points for rays.
-        ray_directions: The direction unit vectors for the rays.
-        near: The near bound of the volumetric scene.
-        far: The far bound of the volumetric scene.
-        num_samples: Number of sample points in a ray.
-        rand: Choice for randomising the sampling strategy.
-
-    Returns:
-        Tuple of flattened rays and sample points on each rays.
-    """
-
-    # Compute 3D query points.
-    # Equation: r(t) = o+td -> Building the "t" here.
-
-    t_vals = tf.linspace(near, far, num_samples)
-
-    if rand:
-
-        # Inject uniform noise into sample space to make the sampling
-        # continuous.
-        shape = list(ray_origins.shape[:-1]) + [num_samples]
-        noise = tf.random.uniform(shape=shape) * (far - near) / num_samples
-        t_vals = t_vals + noise
-
-    # Equation: r(t) = o + td -> Building the "r" here.
-
-    rays = ray_origins[..., None, :] + (
-        ray_directions[..., None, :] * t_vals[..., None]
-    )
-    rays_flat = tf.reshape(rays, [-1, 3])
-    rays_flat = encode_position(rays_flat)
-    return (rays_flat, t_vals)
-
-
-def map_fn(pose):
-
-    """Maps individual pose to flattened rays and sample points.
-
-    Args:
-        pose: The pose matrix of the camera.
-
-    Returns:
-        Tuple of flattened rays and sample points corresponding to the
-        camera pose.
-    """
-
-    (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
-    (rays_flat, t_vals) = render_flat_rays(
-        ray_origins=ray_origins,
-        ray_directions=ray_directions,
-        near=2.0,
-        far=6.0,
-        num_samples=NUM_SAMPLES,
-        rand=True,
-    )
-
-    return (rays_flat, t_vals)
-
-
-def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
-
-    """Generates the RGB image and depth map from model prediction.
-
-    Args:
-        model: The MLP model that is trained to predict the rgb and
-            volume density of the volumetric scene.
-        rays_flat: The flattened rays that serve as the input to
-            the NeRF model.
-        t_vals: The sample points for the rays.
-        rand: Choice to randomise the sampling strategy.
-        train: Whether the model is in the training or testing phase.
-
-    Returns:
-        Tuple of rgb image and depth map.
-    """
-
-    # Get the predictions from the nerf model and reshape it.
-    if train:
-        predictions = model(rays_flat)
-    else:
-        predictions = model.predict(rays_flat)
-    predictions = tf.reshape(predictions, shape=(BATCH_SIZE, H, W, NUM_SAMPLES, 4))
-
-    # Slice the predictions into rgb and sigma.
-    rgb = tf.sigmoid(predictions[..., :-1])
-    sigma_a = tf.nn.relu(predictions[..., -1])
-
-    # Get the distance of adjacent intervals.
-    delta = t_vals[..., 1:] - t_vals[..., :-1]
-    # delta shape = (num_samples)
-
-    if rand:
-        delta = tf.concat(
-            [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
-        )
-        alpha = 1.0 - tf.exp(-sigma_a * delta)
-    else:
-        delta = tf.concat(
-            [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, 1))], axis=-1
-        )
-        alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])
-
-    # Get transmittance.
-    exp_term = 1.0 - alpha
-    epsilon = 1e-10
-    transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
-    weights = alpha * transmittance
-    rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
-
-    if rand:
-        depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
-    else:
-        depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
-    return (rgb, depth_map)
-
 def show_rendered_image(r,theta,phi):
     # Get the camera to world matrix.
     c2w = pose_spherical(theta, phi, r)
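
For readers skimming the removed code: encode_position is the standard NeRF positional encoding, mapping each coordinate x to (x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)) with L = POS_ENCODE_DIMS. A minimal standalone sketch of its behavior; POS_ENCODE_DIMS is defined above the hunk shown here, so the value 16 below is an assumption for illustration:

import tensorflow as tf

POS_ENCODE_DIMS = 16  # assumed; the real value is defined earlier in app.py

def encode_position(x):
    # One raw coordinate plus sin/cos at each octave frequency.
    positions = [x]
    for i in range(POS_ENCODE_DIMS):
        for fn in [tf.sin, tf.cos]:
            positions.append(fn(2.0 ** i * x))
    return tf.concat(positions, axis=-1)

# Each 3D point expands to 3 * (2 * POS_ENCODE_DIMS + 1) features.
points = tf.random.uniform(shape=(4, 3))
print(encode_position(points).shape)  # (4, 99) when POS_ENCODE_DIMS = 16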
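
The camera model in get_rays is a pinhole: pixel coordinates are shifted to the optical center, scaled by the focal length, and given a fixed -z forward axis; the pose's rotation then turns these camera-space directions into world space, while the translation column pose[:3, -1] (named height_width_focal in the removed code, though it holds the camera position) is the shared origin of every ray. A small sketch of the same computation, assuming an identity pose for illustration:

import tensorflow as tf

height, width, focal = 4, 4, 0.6911112070083618
pose = tf.eye(4)  # assumption: camera at the origin, axis-aligned

i, j = tf.meshgrid(
    tf.range(width, dtype=tf.float32),
    tf.range(height, dtype=tf.float32),
    indexing="xy",
)
# Camera-space directions: -z forward, per-pixel x/y offsets.
directions = tf.stack(
    [(i - width * 0.5) / focal, -(j - height * 0.5) / focal, -tf.ones_like(i)],
    axis=-1,
)

# Rotate into world space; equivalent to the broadcasted multiply plus
# reduce_sum in the removed get_rays.
ray_directions = tf.einsum("hwc,rc->hwr", directions, pose[:3, :3])
ray_origins = tf.broadcast_to(pose[:3, -1], tf.shape(ray_directions))
print(ray_origins.shape, ray_directions.shape)  # (4, 4, 3) (4, 4, 3)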
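
Finally, render_rgb_depth is the usual NeRF volume rendering quadrature: per-sample opacity alpha_i = 1 - exp(-sigma_i * delta_i), transmittance T_i = prod_{j<i}(1 - alpha_j), weights w_i = alpha_i * T_i, and the pixel color and depth are the weighted sums of sample colors and depths. A single-ray sketch of the same arithmetic with made-up densities and colors, independent of the batch shapes above:

import tensorflow as tf

t_vals = tf.linspace(2.0, 6.0, 8)                 # sample depths along one ray
sigma = tf.nn.relu(tf.random.normal(shape=(8,)))  # toy volume densities
rgb = tf.sigmoid(tf.random.normal(shape=(8, 3)))  # toy per-sample colors

# Interval widths, padding the last interval with a huge value as the
# removed code does.
delta = tf.concat([t_vals[1:] - t_vals[:-1], [1e10]], axis=-1)

# alpha_i: opacity contributed by interval i.
alpha = 1.0 - tf.exp(-sigma * delta)

# T_i: fraction of light surviving to sample i (exclusive cumulative product).
transmittance = tf.math.cumprod(1.0 - alpha + 1e-10, axis=-1, exclusive=True)
weights = alpha * transmittance

color = tf.reduce_sum(weights[:, None] * rgb, axis=0)  # composited RGB
depth = tf.reduce_sum(weights * t_vals, axis=-1)       # expected depth
print(color.numpy(), depth.numpy())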
 