tokenid commited on
Commit
ad06aed
β€’
1 Parent(s): 6af576b
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. README.md +1 -1
  3. app.py +349 -0
  4. configs/instant-mesh-base.yaml +22 -0
  5. configs/instant-mesh-large.yaml +22 -0
  6. configs/instant-nerf-base.yaml +21 -0
  7. configs/instant-nerf-large.yaml +21 -0
  8. examples/bird.jpg +0 -0
  9. examples/bubble_mart_blue.png +0 -0
  10. examples/cake.jpg +0 -0
  11. examples/cartoon_dinosaur.png +0 -0
  12. examples/cartoon_girl.jpg +0 -0
  13. examples/chair_comfort.jpg +0 -0
  14. examples/chair_wood.jpg +0 -0
  15. examples/chest.jpg +0 -0
  16. examples/cube.png +0 -0
  17. examples/extinguisher.png +0 -0
  18. examples/fruit_bycycle.jpg +0 -0
  19. examples/fruit_elephant.jpg +0 -0
  20. examples/genshin_building.png +0 -0
  21. examples/house2.jpg +0 -0
  22. examples/kunkun.png +0 -0
  23. examples/mushroom_teapot.jpg +0 -0
  24. examples/pikachu.png +0 -0
  25. examples/pistol.png +0 -0
  26. examples/plant.jpg +0 -0
  27. examples/robot.jpg +0 -0
  28. examples/sea_turtle.png +0 -0
  29. examples/skating_shoe.jpg +0 -0
  30. examples/sorting_board.png +0 -0
  31. examples/sword.png +0 -0
  32. examples/toy_car.jpg +0 -0
  33. examples/toyduck.png +0 -0
  34. examples/watermelon.png +0 -0
  35. examples/whitedog.png +0 -0
  36. examples/x_cube.jpg +0 -0
  37. examples/x_teapot.jpg +0 -0
  38. examples/x_toyduck.jpg +0 -0
  39. requirements.txt +21 -0
  40. src/__init__.py +0 -0
  41. src/data/__init__.py +0 -0
  42. src/data/objaverse.py +329 -0
  43. src/model.py +310 -0
  44. src/model_mesh.py +325 -0
  45. src/models/__init__.py +0 -0
  46. src/models/decoder/__init__.py +0 -0
  47. src/models/decoder/transformer.py +123 -0
  48. src/models/encoder/__init__.py +0 -0
  49. src/models/encoder/dino.py +550 -0
  50. src/models/encoder/dino_wrapper.py +80 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
7
  sdk_version: 4.25.0
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
7
  sdk_version: 4.25.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ import torch
5
+ import rembg
6
+ from PIL import Image
7
+ from torchvision.transforms import v2
8
+ from pytorch_lightning import seed_everything
9
+ from omegaconf import OmegaConf
10
+ from einops import rearrange, repeat
11
+ from tqdm import tqdm
12
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
13
+
14
+ from src.utils.train_util import instantiate_from_config
15
+ from src.utils.camera_util import (
16
+ FOV_to_intrinsics,
17
+ get_zero123plus_input_cameras,
18
+ get_circular_camera_poses,
19
+ )
20
+ from src.utils.mesh_util import save_obj
21
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
22
+
23
+ import tempfile
24
+ from functools import partial
25
+
26
+ from huggingface_hub import hf_hub_download
27
+ import spaces
28
+
29
+
30
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
31
+ """
32
+ Get the rendering camera parameters.
33
+ """
34
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
35
+ if is_flexicubes:
36
+ cameras = torch.linalg.inv(c2ws)
37
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
38
+ else:
39
+ extrinsics = c2ws.flatten(-2)
40
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
41
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
42
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
43
+ return cameras
44
+
45
+
46
+ def images_to_video(images, output_path, fps=30):
47
+ # images: (N, C, H, W)
48
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
49
+ frames = []
50
+ for i in range(images.shape[0]):
51
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
52
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
53
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
54
+ assert frame.min() >= 0 and frame.max() <= 255, \
55
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
56
+ frames.append(frame)
57
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
58
+
59
+
60
+ ###############################################################################
61
+ # Configuration.
62
+ ###############################################################################
63
+
64
+ config_path = 'configs/instant-mesh-large-eval.yaml'
65
+ config = OmegaConf.load(config_path)
66
+ config_name = os.path.basename(config_path).replace('.yaml', '')
67
+ model_config = config.model_config
68
+ infer_config = config.infer_config
69
+
70
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
71
+
72
+ device = torch.device('cuda')
73
+
74
+ # load diffusion model
75
+ print('Loading diffusion model ...')
76
+ pipeline = DiffusionPipeline.from_pretrained(
77
+ "sudo-ai/zero123plus-v1.2",
78
+ custom_pipeline="zero123plus",
79
+ torch_dtype=torch.float16,
80
+ )
81
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
82
+ pipeline.scheduler.config, timestep_spacing='trailing'
83
+ )
84
+
85
+ # load custom white-background UNet
86
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
87
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
88
+ pipeline.unet.load_state_dict(state_dict, strict=True)
89
+
90
+ pipeline = pipeline.to(device)
91
+
92
+ # load reconstruction model
93
+ print('Loading reconstruction model ...')
94
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
95
+ model = instantiate_from_config(model_config)
96
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
97
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
98
+ model.load_state_dict(state_dict, strict=True)
99
+
100
+ model = model.to(device)
101
+ if IS_FLEXICUBES:
102
+ model.init_flexicubes_geometry(device)
103
+ model = model.eval()
104
+
105
+ print('Loading Finished!')
106
+
107
+
108
+ def check_input_image(input_image):
109
+ if input_image is None:
110
+ raise gr.Error("No image uploaded!")
111
+
112
+
113
+ def preprocess(input_image, do_remove_background):
114
+
115
+ rembg_session = rembg.new_session() if do_remove_background else None
116
+
117
+ if do_remove_background:
118
+ input_image = remove_background(input_image, rembg_session)
119
+ input_image = resize_foreground(input_image, 0.85)
120
+
121
+ return input_image
122
+
123
+
124
+ @spaces.GPU
125
+ def generate_mvs(input_image, sample_steps, sample_seed):
126
+
127
+ seed_everything(sample_seed)
128
+
129
+ # sampling
130
+ z123_image = pipeline(
131
+ input_image,
132
+ num_inference_steps=sample_steps
133
+ ).images[0]
134
+
135
+ show_image = np.asarray(z123_image, dtype=np.uint8)
136
+ show_image = torch.from_numpy(show_image) # (960, 640, 3)
137
+ show_image = rearrange(show_image, '(n h) (m w) c -> (m h) (n w) c', n=3, m=2)
138
+ show_image = Image.fromarray(show_image.numpy())
139
+
140
+ return z123_image, show_image
141
+
142
+
143
+ @spaces.GPU
144
+ def make_mesh(mesh_fpath, planes):
145
+
146
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
147
+ mesh_dirname = os.path.dirname(mesh_fpath)
148
+ mesh_vis_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
149
+
150
+ with torch.no_grad():
151
+
152
+ # get mesh
153
+ mesh_out = model.extract_mesh(
154
+ planes,
155
+ use_texture_map=False,
156
+ **infer_config,
157
+ )
158
+
159
+ vertices, faces, vertex_colors = mesh_out
160
+ vertices = vertices[:, [0, 2, 1]]
161
+ vertices[:, -1] *= -1
162
+
163
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
164
+
165
+ print(f"Mesh saved to {mesh_fpath}")
166
+
167
+ return mesh_fpath
168
+
169
+
170
+ @spaces.GPU
171
+ def make3d(images):
172
+
173
+ images = np.asarray(images, dtype=np.float32) / 255.0
174
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
175
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
176
+
177
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device)
178
+ render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
179
+
180
+ images = images.unsqueeze(0).to(device)
181
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
182
+
183
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
184
+ print(mesh_fpath)
185
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
186
+ mesh_dirname = os.path.dirname(mesh_fpath)
187
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
188
+
189
+ with torch.no_grad():
190
+ # get triplane
191
+ planes = model.forward_planes(images, input_cameras)
192
+
193
+ # get video
194
+ chunk_size = 20 if IS_FLEXICUBES else 1
195
+ render_size = 384
196
+
197
+ frames = []
198
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
199
+ if IS_FLEXICUBES:
200
+ frame = model.forward_geometry(
201
+ planes,
202
+ render_cameras[:, i:i+chunk_size],
203
+ render_size=render_size,
204
+ )['img']
205
+ else:
206
+ frame = model.synthesizer(
207
+ planes,
208
+ cameras=render_cameras[:, i:i+chunk_size],
209
+ render_size=render_size,
210
+ )['images_rgb']
211
+ frames.append(frame)
212
+ frames = torch.cat(frames, dim=1)
213
+
214
+ images_to_video(
215
+ frames[0],
216
+ video_fpath,
217
+ fps=30,
218
+ )
219
+
220
+ print(f"Video saved to {video_fpath}")
221
+
222
+ mesh_fpath = make_mesh(mesh_fpath, planes)
223
+
224
+ return video_fpath, mesh_fpath
225
+
226
+
227
+ import gradio as gr
228
+
229
+ _HEADER_ = '''
230
+ <h2><b>Official πŸ€— Gradio demo for</b>
231
+ <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>
232
+ <b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b>
233
+ </a>.
234
+ </h2>
235
+ '''
236
+
237
+ _LINKS_ = '''
238
+ <h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
239
+ <h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
240
+ '''
241
+
242
+ _CITE_ = r"""
243
+ ```bibtex
244
+ @article{xu2024instantmesh,
245
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
246
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
247
+ journal={arXiv preprint arXiv:2404.07191},
248
+ year={2024}
249
+ }
250
+ ```
251
+ """
252
+
253
+
254
+ with gr.Blocks() as demo:
255
+ gr.Markdown(_HEADER_)
256
+ with gr.Row(variant="panel"):
257
+ with gr.Column():
258
+ with gr.Row():
259
+ input_image = gr.Image(
260
+ label="Input Image",
261
+ image_mode="RGBA",
262
+ sources="upload",
263
+ width=256,
264
+ height=256,
265
+ type="pil",
266
+ elem_id="content_image",
267
+ )
268
+ processed_image = gr.Image(
269
+ label="Processed Image",
270
+ image_mode="RGBA",
271
+ width=256,
272
+ height=256,
273
+ type="pil",
274
+ interactive=False
275
+ )
276
+ with gr.Row():
277
+ with gr.Group():
278
+ do_remove_background = gr.Checkbox(
279
+ label="Remove Background", value=True
280
+ )
281
+ sample_seed = gr.Number(value=42, label="Seed (Try a different value if the result is unsatisfying)", precision=0)
282
+
283
+ sample_steps = gr.Slider(
284
+ label="Sample Steps",
285
+ minimum=30,
286
+ maximum=75,
287
+ value=75,
288
+ step=5
289
+ )
290
+
291
+ with gr.Row():
292
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
293
+
294
+ with gr.Row(variant="panel"):
295
+ gr.Examples(
296
+ examples=[
297
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
298
+ ],
299
+ inputs=[input_image],
300
+ label="Examples",
301
+ examples_per_page=20
302
+ )
303
+
304
+ with gr.Column():
305
+
306
+ with gr.Row():
307
+
308
+ with gr.Column():
309
+ mv_show_images = gr.Image(
310
+ label="Generated Multi-views",
311
+ type="pil",
312
+ width=379,
313
+ interactive=False
314
+ )
315
+
316
+ with gr.Column():
317
+ output_video = gr.Video(
318
+ label="video", format="mp4",
319
+ width=379,
320
+ autoplay=True,
321
+ interactive=False
322
+ )
323
+
324
+ with gr.Row():
325
+ output_model_obj = gr.Model3D(
326
+ label="Output Model (OBJ Format)",
327
+ width=768,
328
+ interactive=False,
329
+ )
330
+ gr.Markdown(_LINKS_)
331
+ gr.Markdown(_CITE_)
332
+
333
+ mv_images = gr.State()
334
+
335
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
336
+ fn=preprocess,
337
+ inputs=[input_image, do_remove_background],
338
+ outputs=[processed_image],
339
+ ).success(
340
+ fn=generate_mvs,
341
+ inputs=[processed_image, sample_steps, sample_seed],
342
+ outputs=[mv_images, mv_show_images],
343
+ ).success(
344
+ fn=make3d,
345
+ inputs=[mv_images],
346
+ outputs=[output_video, output_model_obj]
347
+ )
348
+
349
+ demo.launch()
configs/instant-mesh-base.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_base.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-mesh-large.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_large.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-nerf-base.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_base.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
configs/instant-nerf-large.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_large.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
examples/bird.jpg ADDED
examples/bubble_mart_blue.png ADDED
examples/cake.jpg ADDED
examples/cartoon_dinosaur.png ADDED
examples/cartoon_girl.jpg ADDED
examples/chair_comfort.jpg ADDED
examples/chair_wood.jpg ADDED
examples/chest.jpg ADDED
examples/cube.png ADDED
examples/extinguisher.png ADDED
examples/fruit_bycycle.jpg ADDED
examples/fruit_elephant.jpg ADDED
examples/genshin_building.png ADDED
examples/house2.jpg ADDED
examples/kunkun.png ADDED
examples/mushroom_teapot.jpg ADDED
examples/pikachu.png ADDED
examples/pistol.png ADDED
examples/plant.jpg ADDED
examples/robot.jpg ADDED
examples/sea_turtle.png ADDED
examples/skating_shoe.jpg ADDED
examples/sorting_board.png ADDED
examples/sword.png ADDED
examples/toy_car.jpg ADDED
examples/toyduck.png ADDED
examples/watermelon.png ADDED
examples/whitedog.png ADDED
examples/x_cube.jpg ADDED
examples/x_teapot.jpg ADDED
examples/x_toyduck.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pytorch-lightning==2.1.2
2
+ einops
3
+ omegaconf
4
+ deepspeed
5
+ torchmetrics
6
+ webdataset
7
+ accelerate
8
+ tensorboard
9
+ PyMCubes
10
+ trimesh
11
+ rembg
12
+ transformers==4.34.1
13
+ diffusers==0.19.3
14
+ bitsandbytes
15
+ imageio[ffmpeg]
16
+ xatlas
17
+ plyfile
18
+ xformers==0.0.22.post7
19
+ git+https://github.com/NVlabs/nvdiffrast/
20
+ torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
21
+ huggingface-hub
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/objaverse.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import math
3
+ import json
4
+ import importlib
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import random
9
+ import numpy as np
10
+ from PIL import Image
11
+ import webdataset as wds
12
+ import pytorch_lightning as pl
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch.utils.data import Dataset
17
+ from torch.utils.data import DataLoader
18
+ from torch.utils.data.distributed import DistributedSampler
19
+ from torchvision import transforms
20
+
21
+ from src.utils.train_util import instantiate_from_config
22
+ from src.utils.camera_util import (
23
+ FOV_to_intrinsics,
24
+ center_looking_at_camera_pose,
25
+ get_surrounding_views,
26
+ )
27
+
28
+
29
+ class DataModuleFromConfig(pl.LightningDataModule):
30
+ def __init__(
31
+ self,
32
+ batch_size=8,
33
+ num_workers=4,
34
+ train=None,
35
+ validation=None,
36
+ test=None,
37
+ **kwargs,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.batch_size = batch_size
42
+ self.num_workers = num_workers
43
+
44
+ self.dataset_configs = dict()
45
+ if train is not None:
46
+ self.dataset_configs['train'] = train
47
+ if validation is not None:
48
+ self.dataset_configs['validation'] = validation
49
+ if test is not None:
50
+ self.dataset_configs['test'] = test
51
+
52
+ def setup(self, stage):
53
+
54
+ if stage in ['fit']:
55
+ self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ def train_dataloader(self):
60
+
61
+ sampler = DistributedSampler(self.datasets['train'])
62
+ return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
63
+
64
+ def val_dataloader(self):
65
+
66
+ sampler = DistributedSampler(self.datasets['validation'])
67
+ return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
68
+
69
+ def test_dataloader(self):
70
+
71
+ return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
72
+
73
+
74
+ class ObjaverseData(Dataset):
75
+ def __init__(self,
76
+ root_dir='objaverse/',
77
+ meta_fname='valid_paths.json',
78
+ input_image_dir='rendering_random_32views',
79
+ target_image_dir='rendering_random_32views',
80
+ input_view_num=6,
81
+ target_view_num=2,
82
+ total_view_n=32,
83
+ fov=50,
84
+ camera_rotation=True,
85
+ validation=False,
86
+ ):
87
+ self.root_dir = Path(root_dir)
88
+ self.input_image_dir = input_image_dir
89
+ self.target_image_dir = target_image_dir
90
+
91
+ self.input_view_num = input_view_num
92
+ self.target_view_num = target_view_num
93
+ self.total_view_n = total_view_n
94
+ self.fov = fov
95
+ self.camera_rotation = camera_rotation
96
+
97
+ with open(os.path.join(root_dir, meta_fname)) as f:
98
+ filtered_dict = json.load(f)
99
+ paths = filtered_dict['good_objs']
100
+ self.paths = paths
101
+
102
+ self.depth_scale = 4.0
103
+
104
+ total_objects = len(self.paths)
105
+ print('============= length of dataset %d =============' % len(self.paths))
106
+
107
+ def __len__(self):
108
+ return len(self.paths)
109
+
110
+ def load_im(self, path, color):
111
+ '''
112
+ replace background pixel with random color in rendering
113
+ '''
114
+ pil_img = Image.open(path)
115
+
116
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
117
+ alpha = image[:, :, 3:]
118
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
119
+
120
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
121
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
122
+ return image, alpha
123
+
124
+ def __getitem__(self, index):
125
+ # load data
126
+ while True:
127
+ input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
128
+ target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
129
+
130
+ indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
131
+ input_indices = indices[:self.input_view_num]
132
+ target_indices = indices[self.input_view_num:]
133
+
134
+ '''background color, default: white'''
135
+ bg_white = [1., 1., 1.]
136
+ bg_black = [0., 0., 0.]
137
+
138
+ image_list = []
139
+ alpha_list = []
140
+ depth_list = []
141
+ normal_list = []
142
+ pose_list = []
143
+
144
+ try:
145
+ input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
146
+ for idx in input_indices:
147
+ image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
148
+ normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
149
+ depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
150
+ depth = torch.from_numpy(depth).unsqueeze(0)
151
+ pose = input_cameras[idx]
152
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
153
+
154
+ image_list.append(image)
155
+ alpha_list.append(alpha)
156
+ depth_list.append(depth)
157
+ normal_list.append(normal)
158
+ pose_list.append(pose)
159
+
160
+ target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
161
+ for idx in target_indices:
162
+ image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
163
+ normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
164
+ depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
165
+ depth = torch.from_numpy(depth).unsqueeze(0)
166
+ pose = target_cameras[idx]
167
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
168
+
169
+ image_list.append(image)
170
+ alpha_list.append(alpha)
171
+ depth_list.append(depth)
172
+ normal_list.append(normal)
173
+ pose_list.append(pose)
174
+
175
+ except Exception as e:
176
+ print(e)
177
+ index = np.random.randint(0, len(self.paths))
178
+ continue
179
+
180
+ break
181
+
182
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
183
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
184
+ depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
185
+ normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
186
+ w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
187
+ c2ws = torch.linalg.inv(w2cs).float()
188
+
189
+ normals = normals * 2.0 - 1.0
190
+ normals = F.normalize(normals, dim=1)
191
+ normals = (normals + 1.0) / 2.0
192
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
193
+
194
+ # random rotation along z axis
195
+ if self.camera_rotation:
196
+ degree = np.random.uniform(0, math.pi * 2)
197
+ rot = torch.tensor([
198
+ [np.cos(degree), -np.sin(degree), 0, 0],
199
+ [np.sin(degree), np.cos(degree), 0, 0],
200
+ [0, 0, 1, 0],
201
+ [0, 0, 0, 1],
202
+ ]).unsqueeze(0).float()
203
+ c2ws = torch.matmul(rot, c2ws)
204
+
205
+ # rotate normals
206
+ N, _, H, W = normals.shape
207
+ normals = normals * 2.0 - 1.0
208
+ normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
209
+ normals = F.normalize(normals, dim=1)
210
+ normals = (normals + 1.0) / 2.0
211
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
212
+
213
+ # random scaling
214
+ if np.random.rand() < 0.5:
215
+ scale = np.random.uniform(0.8, 1.0)
216
+ c2ws[:, :3, 3] *= scale
217
+ depths *= scale
218
+
219
+ # instrinsics of perspective cameras
220
+ K = FOV_to_intrinsics(self.fov)
221
+ Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
222
+
223
+ data = {
224
+ 'input_images': images[:self.input_view_num], # (6, 3, H, W)
225
+ 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
226
+ 'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
227
+ 'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
228
+ 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
229
+ 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
230
+
231
+ # lrm generator input and supervision
232
+ 'target_images': images[self.input_view_num:], # (V, 3, H, W)
233
+ 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
234
+ 'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
235
+ 'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
236
+ 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
237
+ 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
238
+
239
+ 'depth_available': 1,
240
+ }
241
+ return data
242
+
243
+
244
+ class ValidationData(Dataset):
245
+ def __init__(self,
246
+ root_dir='objaverse/',
247
+ input_view_num=6,
248
+ input_image_size=256,
249
+ fov=50,
250
+ ):
251
+ self.root_dir = Path(root_dir)
252
+ self.input_view_num = input_view_num
253
+ self.input_image_size = input_image_size
254
+ self.fov = fov
255
+
256
+ self.paths = sorted(os.listdir(self.root_dir))
257
+ print('============= length of dataset %d =============' % len(self.paths))
258
+
259
+ cam_distance = 2.5
260
+ azimuths = np.array([30, 90, 150, 210, 270, 330])
261
+ elevations = np.array([30, -20, 30, -20, 30, -20])
262
+ azimuths = np.deg2rad(azimuths)
263
+ elevations = np.deg2rad(elevations)
264
+
265
+ x = cam_distance * np.cos(elevations) * np.cos(azimuths)
266
+ y = cam_distance * np.cos(elevations) * np.sin(azimuths)
267
+ z = cam_distance * np.sin(elevations)
268
+
269
+ cam_locations = np.stack([x, y, z], axis=-1)
270
+ cam_locations = torch.from_numpy(cam_locations).float()
271
+ c2ws = center_looking_at_camera_pose(cam_locations)
272
+ self.c2ws = c2ws.float()
273
+ self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
274
+
275
+ render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
276
+ render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
277
+ self.render_c2ws = render_c2ws.float()
278
+ self.render_Ks = render_Ks.float()
279
+
280
+ def __len__(self):
281
+ return len(self.paths)
282
+
283
+ def load_im(self, path, color):
284
+ '''
285
+ replace background pixel with random color in rendering
286
+ '''
287
+ pil_img = Image.open(path)
288
+ pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
289
+
290
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
291
+ if image.shape[-1] == 4:
292
+ alpha = image[:, :, 3:]
293
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
294
+ else:
295
+ alpha = np.ones_like(image[:, :, :1])
296
+
297
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
298
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
299
+ return image, alpha
300
+
301
+ def __getitem__(self, index):
302
+ # load data
303
+ input_image_path = os.path.join(self.root_dir, self.paths[index])
304
+
305
+ '''background color, default: white'''
306
+ # color = np.random.uniform(0.48, 0.52)
307
+ bkg_color = [1.0, 1.0, 1.0]
308
+
309
+ image_list = []
310
+ alpha_list = []
311
+
312
+ for idx in range(self.input_view_num):
313
+ image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
314
+ image_list.append(image)
315
+ alpha_list.append(alpha)
316
+
317
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
318
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
319
+
320
+ data = {
321
+ 'input_images': images, # (6, 3, H, W)
322
+ 'input_alphas': alphas, # (6, 1, H, W)
323
+ 'input_c2ws': self.c2ws, # (6, 4, 4)
324
+ 'input_Ks': self.Ks, # (6, 3, 3)
325
+
326
+ 'render_c2ws': self.render_c2ws,
327
+ 'render_Ks': self.render_Ks,
328
+ }
329
+ return data
src/model.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import v2
6
+ from torchvision.utils import make_grid, save_image
7
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
+ import pytorch_lightning as pl
9
+ from einops import rearrange, repeat
10
+
11
+ from src.utils.train_util import instantiate_from_config
12
+
13
+
14
+ class MVRecon(pl.LightningModule):
15
+ def __init__(
16
+ self,
17
+ lrm_generator_config,
18
+ lrm_path=None,
19
+ input_size=256,
20
+ render_size=192,
21
+ ):
22
+ super(MVRecon, self).__init__()
23
+
24
+ self.input_size = input_size
25
+ self.render_size = render_size
26
+
27
+ # init modules
28
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
29
+ if lrm_path is not None:
30
+ lrm_ckpt = torch.load(lrm_path)
31
+ self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
32
+
33
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
34
+
35
+ self.validation_step_outputs = []
36
+
37
+ def on_fit_start(self):
38
+ if self.global_rank == 0:
39
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
40
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
41
+
42
+ def prepare_batch_data(self, batch):
43
+ lrm_generator_input = {}
44
+ render_gt = {} # for supervision
45
+
46
+ # input images
47
+ images = batch['input_images']
48
+ images = v2.functional.resize(
49
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
50
+
51
+ lrm_generator_input['images'] = images.to(self.device)
52
+
53
+ # input cameras and render cameras
54
+ input_c2ws = batch['input_c2ws'].flatten(-2)
55
+ input_Ks = batch['input_Ks'].flatten(-2)
56
+ target_c2ws = batch['target_c2ws'].flatten(-2)
57
+ target_Ks = batch['target_Ks'].flatten(-2)
58
+ render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
59
+ render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
60
+ render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
61
+
62
+ input_extrinsics = input_c2ws[:, :, :12]
63
+ input_intrinsics = torch.stack([
64
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
65
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
66
+ ], dim=-1)
67
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
68
+
69
+ # add noise to input cameras
70
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
71
+
72
+ lrm_generator_input['cameras'] = cameras.to(self.device)
73
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
74
+
75
+ # target images
76
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
77
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
78
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
79
+
80
+ # random crop
81
+ render_size = np.random.randint(self.render_size, 513)
82
+ target_images = v2.functional.resize(
83
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
84
+ target_depths = v2.functional.resize(
85
+ target_depths, render_size, interpolation=0, antialias=True)
86
+ target_alphas = v2.functional.resize(
87
+ target_alphas, render_size, interpolation=0, antialias=True)
88
+
89
+ crop_params = v2.RandomCrop.get_params(
90
+ target_images, output_size=(self.render_size, self.render_size))
91
+ target_images = v2.functional.crop(target_images, *crop_params)
92
+ target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
93
+ target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
94
+
95
+ lrm_generator_input['render_size'] = render_size
96
+ lrm_generator_input['crop_params'] = crop_params
97
+
98
+ render_gt['target_images'] = target_images.to(self.device)
99
+ render_gt['target_depths'] = target_depths.to(self.device)
100
+ render_gt['target_alphas'] = target_alphas.to(self.device)
101
+
102
+ return lrm_generator_input, render_gt
103
+
104
+ def prepare_validation_batch_data(self, batch):
105
+ lrm_generator_input = {}
106
+
107
+ # input images
108
+ images = batch['input_images']
109
+ images = v2.functional.resize(
110
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
111
+
112
+ lrm_generator_input['images'] = images.to(self.device)
113
+
114
+ input_c2ws = batch['input_c2ws'].flatten(-2)
115
+ input_Ks = batch['input_Ks'].flatten(-2)
116
+
117
+ input_extrinsics = input_c2ws[:, :, :12]
118
+ input_intrinsics = torch.stack([
119
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
120
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
121
+ ], dim=-1)
122
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
123
+
124
+ lrm_generator_input['cameras'] = cameras.to(self.device)
125
+
126
+ render_c2ws = batch['render_c2ws'].flatten(-2)
127
+ render_Ks = batch['render_Ks'].flatten(-2)
128
+ render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
129
+
130
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
131
+ lrm_generator_input['render_size'] = 384
132
+ lrm_generator_input['crop_params'] = None
133
+
134
+ return lrm_generator_input
135
+
136
+ def forward_lrm_generator(
137
+ self,
138
+ images,
139
+ cameras,
140
+ render_cameras,
141
+ render_size=192,
142
+ crop_params=None,
143
+ chunk_size=1,
144
+ ):
145
+ planes = torch.utils.checkpoint.checkpoint(
146
+ self.lrm_generator.forward_planes,
147
+ images,
148
+ cameras,
149
+ use_reentrant=False,
150
+ )
151
+ frames = []
152
+ for i in range(0, render_cameras.shape[1], chunk_size):
153
+ frames.append(
154
+ torch.utils.checkpoint.checkpoint(
155
+ self.lrm_generator.synthesizer,
156
+ planes,
157
+ cameras=render_cameras[:, i:i+chunk_size],
158
+ render_size=render_size,
159
+ crop_params=crop_params,
160
+ use_reentrant=False
161
+ )
162
+ )
163
+ frames = {
164
+ k: torch.cat([r[k] for r in frames], dim=1)
165
+ for k in frames[0].keys()
166
+ }
167
+ return frames
168
+
169
+ def forward(self, lrm_generator_input):
170
+ images = lrm_generator_input['images']
171
+ cameras = lrm_generator_input['cameras']
172
+ render_cameras = lrm_generator_input['render_cameras']
173
+ render_size = lrm_generator_input['render_size']
174
+ crop_params = lrm_generator_input['crop_params']
175
+
176
+ out = self.forward_lrm_generator(
177
+ images,
178
+ cameras,
179
+ render_cameras,
180
+ render_size=render_size,
181
+ crop_params=crop_params,
182
+ chunk_size=1,
183
+ )
184
+ render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
185
+ render_depths = out['images_depth']
186
+ render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
187
+
188
+ out = {
189
+ 'render_images': render_images,
190
+ 'render_depths': render_depths,
191
+ 'render_alphas': render_alphas,
192
+ }
193
+ return out
194
+
195
+ def training_step(self, batch, batch_idx):
196
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
197
+
198
+ render_out = self.forward(lrm_generator_input)
199
+
200
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
201
+
202
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
203
+
204
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
205
+ B, N, C, H, W = render_gt['target_images'].shape
206
+ N_in = lrm_generator_input['images'].shape[1]
207
+
208
+ input_images = v2.functional.resize(
209
+ lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
210
+ input_images = torch.cat(
211
+ [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
212
+
213
+ input_images = rearrange(
214
+ input_images, 'b n c h w -> b c h (n w)')
215
+ target_images = rearrange(
216
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
217
+ render_images = rearrange(
218
+ render_out['render_images'], 'b n c h w -> b c h (n w)')
219
+ target_alphas = rearrange(
220
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
+ render_alphas = rearrange(
222
+ repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
223
+ target_depths = rearrange(
224
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
225
+ render_depths = rearrange(
226
+ repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
227
+ MAX_DEPTH = torch.max(target_depths)
228
+ target_depths = target_depths / MAX_DEPTH * target_alphas
229
+ render_depths = render_depths / MAX_DEPTH
230
+
231
+ grid = torch.cat([
232
+ input_images,
233
+ target_images, render_images,
234
+ target_alphas, render_alphas,
235
+ target_depths, render_depths,
236
+ ], dim=-2)
237
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
238
+
239
+ save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
240
+
241
+ return loss
242
+
243
+ def compute_loss(self, render_out, render_gt):
244
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
245
+ render_images = render_out['render_images']
246
+ target_images = render_gt['target_images'].to(render_images)
247
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
+
250
+ loss_mse = F.mse_loss(render_images, target_images)
251
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
252
+
253
+ render_alphas = render_out['render_alphas']
254
+ target_alphas = render_gt['target_alphas']
255
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
256
+
257
+ loss = loss_mse + loss_lpips + loss_mask
258
+
259
+ prefix = 'train'
260
+ loss_dict = {}
261
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
262
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
263
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
264
+ loss_dict.update({f'{prefix}/loss': loss})
265
+
266
+ return loss, loss_dict
267
+
268
+ @torch.no_grad()
269
+ def validation_step(self, batch, batch_idx):
270
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
271
+
272
+ render_out = self.forward(lrm_generator_input)
273
+ render_images = render_out['render_images']
274
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
275
+
276
+ self.validation_step_outputs.append(render_images)
277
+
278
+ def on_validation_epoch_end(self):
279
+ images = torch.cat(self.validation_step_outputs, dim=-1)
280
+
281
+ all_images = self.all_gather(images)
282
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
283
+
284
+ if self.global_rank == 0:
285
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
286
+
287
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
288
+ save_image(grid, image_path)
289
+ print(f"Saved image to {image_path}")
290
+
291
+ self.validation_step_outputs.clear()
292
+
293
+ def configure_optimizers(self):
294
+ lr = self.learning_rate
295
+
296
+ params = []
297
+
298
+ lrm_params_fast, lrm_params_slow = [], []
299
+ for n, p in self.lrm_generator.named_parameters():
300
+ if 'adaLN_modulation' in n or 'camera_embedder' in n:
301
+ lrm_params_fast.append(p)
302
+ else:
303
+ lrm_params_slow.append(p)
304
+ params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
305
+ params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
306
+
307
+ optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
308
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
309
+
310
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/model_mesh.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import v2
6
+ from torchvision.utils import make_grid, save_image
7
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
+ import pytorch_lightning as pl
9
+ from einops import rearrange, repeat
10
+
11
+ from src.utils.train_util import instantiate_from_config
12
+
13
+
14
+ # Regulrarization loss for FlexiCubes
15
+ def sdf_reg_loss_batch(sdf, all_edges):
16
+ sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
17
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
18
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
19
+ sdf_diff = F.binary_cross_entropy_with_logits(
20
+ sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
21
+ F.binary_cross_entropy_with_logits(
22
+ sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
23
+ return sdf_diff
24
+
25
+
26
+ class MVRecon(pl.LightningModule):
27
+ def __init__(
28
+ self,
29
+ lrm_generator_config,
30
+ input_size=256,
31
+ render_size=512,
32
+ init_ckpt=None,
33
+ ):
34
+ super(MVRecon, self).__init__()
35
+
36
+ self.input_size = input_size
37
+ self.render_size = render_size
38
+
39
+ # init modules
40
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
41
+
42
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
43
+
44
+ # Load weights from pretrained MVRecon model, and use the mlp
45
+ # weights to initialize the weights of sdf and rgb mlps.
46
+ if init_ckpt is not None:
47
+ sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
48
+ sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
49
+ sd_fc = {}
50
+ for k, v in sd.items():
51
+ if k.startswith('lrm_generator.synthesizer.decoder.net.'):
52
+ if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
53
+ # Here we assume the density filed's isosurface threshold is t,
54
+ # we reverse the sign of density filed to initialize SDF field.
55
+ # -(w*x + b - t) = (-w)*x + (t - b)
56
+ if 'weight' in k:
57
+ sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
58
+ else:
59
+ sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
60
+ sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
61
+ else:
62
+ sd_fc[k.replace('net.', 'net_sdf.')] = v
63
+ sd_fc[k.replace('net.', 'net_rgb.')] = v
64
+ else:
65
+ sd_fc[k] = v
66
+ sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
67
+ # missing `net_deformation` and `net_weight` parameters
68
+ self.lrm_generator.load_state_dict(sd_fc, strict=False)
69
+ print(f'Loaded weights from {init_ckpt}')
70
+
71
+ self.validation_step_outputs = []
72
+
73
+ def on_fit_start(self):
74
+ device = torch.device(f'cuda:{self.global_rank}')
75
+ self.lrm_generator.init_flexicubes_geometry(device)
76
+ if self.global_rank == 0:
77
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
78
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
79
+
80
+ def prepare_batch_data(self, batch):
81
+ lrm_generator_input = {}
82
+ render_gt = {}
83
+
84
+ # input images
85
+ images = batch['input_images']
86
+ images = v2.functional.resize(
87
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
88
+
89
+ lrm_generator_input['images'] = images.to(self.device)
90
+
91
+ # input cameras and render cameras
92
+ input_c2ws = batch['input_c2ws']
93
+ input_Ks = batch['input_Ks']
94
+ target_c2ws = batch['target_c2ws']
95
+
96
+ render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
97
+ render_w2cs = torch.linalg.inv(render_c2ws)
98
+
99
+ input_extrinsics = input_c2ws.flatten(-2)
100
+ input_extrinsics = input_extrinsics[:, :, :12]
101
+ input_intrinsics = input_Ks.flatten(-2)
102
+ input_intrinsics = torch.stack([
103
+ input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
104
+ input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
105
+ ], dim=-1)
106
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
107
+
108
+ # add noise to input_cameras
109
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
110
+
111
+ lrm_generator_input['cameras'] = cameras.to(self.device)
112
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
113
+
114
+ # target images
115
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
116
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
117
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
118
+ target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
119
+
120
+ render_size = self.render_size
121
+ target_images = v2.functional.resize(
122
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
123
+ target_depths = v2.functional.resize(
124
+ target_depths, render_size, interpolation=0, antialias=True)
125
+ target_alphas = v2.functional.resize(
126
+ target_alphas, render_size, interpolation=0, antialias=True)
127
+ target_normals = v2.functional.resize(
128
+ target_normals, render_size, interpolation=3, antialias=True)
129
+
130
+ lrm_generator_input['render_size'] = render_size
131
+
132
+ render_gt['target_images'] = target_images.to(self.device)
133
+ render_gt['target_depths'] = target_depths.to(self.device)
134
+ render_gt['target_alphas'] = target_alphas.to(self.device)
135
+ render_gt['target_normals'] = target_normals.to(self.device)
136
+
137
+ return lrm_generator_input, render_gt
138
+
139
+ def prepare_validation_batch_data(self, batch):
140
+ lrm_generator_input = {}
141
+
142
+ # input images
143
+ images = batch['input_images']
144
+ images = v2.functional.resize(
145
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
146
+
147
+ lrm_generator_input['images'] = images.to(self.device)
148
+
149
+ # input cameras
150
+ input_c2ws = batch['input_c2ws'].flatten(-2)
151
+ input_Ks = batch['input_Ks'].flatten(-2)
152
+
153
+ input_extrinsics = input_c2ws[:, :, :12]
154
+ input_intrinsics = torch.stack([
155
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
156
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
157
+ ], dim=-1)
158
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
159
+
160
+ lrm_generator_input['cameras'] = cameras.to(self.device)
161
+
162
+ # render cameras
163
+ render_c2ws = batch['render_c2ws']
164
+ render_w2cs = torch.linalg.inv(render_c2ws)
165
+
166
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
167
+ lrm_generator_input['render_size'] = 384
168
+
169
+ return lrm_generator_input
170
+
171
+ def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
172
+ planes = torch.utils.checkpoint.checkpoint(
173
+ self.lrm_generator.forward_planes,
174
+ images,
175
+ cameras,
176
+ use_reentrant=False,
177
+ )
178
+ out = self.lrm_generator.forward_geometry(
179
+ planes,
180
+ render_cameras,
181
+ render_size,
182
+ )
183
+ return out
184
+
185
+ def forward(self, lrm_generator_input):
186
+ images = lrm_generator_input['images']
187
+ cameras = lrm_generator_input['cameras']
188
+ render_cameras = lrm_generator_input['render_cameras']
189
+ render_size = lrm_generator_input['render_size']
190
+
191
+ out = self.forward_lrm_generator(
192
+ images, cameras, render_cameras, render_size=render_size)
193
+
194
+ return out
195
+
196
+ def training_step(self, batch, batch_idx):
197
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
198
+
199
+ render_out = self.forward(lrm_generator_input)
200
+
201
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
202
+
203
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
204
+
205
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
206
+ B, N, C, H, W = render_gt['target_images'].shape
207
+ N_in = lrm_generator_input['images'].shape[1]
208
+
209
+ target_images = rearrange(
210
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
211
+ render_images = rearrange(
212
+ render_out['img'], 'b n c h w -> b c h (n w)')
213
+ target_alphas = rearrange(
214
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
215
+ render_alphas = rearrange(
216
+ repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
217
+ target_depths = rearrange(
218
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
219
+ render_depths = rearrange(
220
+ repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
+ target_normals = rearrange(
222
+ render_gt['target_normals'], 'b n c h w -> b c h (n w)')
223
+ render_normals = rearrange(
224
+ render_out['normal'], 'b n c h w -> b c h (n w)')
225
+ MAX_DEPTH = torch.max(target_depths)
226
+ target_depths = target_depths / MAX_DEPTH * target_alphas
227
+ render_depths = render_depths / MAX_DEPTH
228
+
229
+ grid = torch.cat([
230
+ target_images, render_images,
231
+ target_alphas, render_alphas,
232
+ target_depths, render_depths,
233
+ target_normals, render_normals,
234
+ ], dim=-2)
235
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
236
+
237
+ image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
238
+ save_image(grid, image_path)
239
+ print(f"Saved image to {image_path}")
240
+
241
+ return loss
242
+
243
+ def compute_loss(self, render_out, render_gt):
244
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
245
+ render_images = render_out['img']
246
+ target_images = render_gt['target_images'].to(render_images)
247
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
+ loss_mse = F.mse_loss(render_images, target_images)
250
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
251
+
252
+ render_alphas = render_out['mask']
253
+ target_alphas = render_gt['target_alphas']
254
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
255
+
256
+ render_depths = render_out['depth']
257
+ target_depths = render_gt['target_depths']
258
+ loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
259
+
260
+ render_normals = render_out['normal'] * 2.0 - 1.0
261
+ target_normals = render_gt['target_normals'] * 2.0 - 1.0
262
+ similarity = (render_normals * target_normals).sum(dim=-3).abs()
263
+ normal_mask = target_alphas.squeeze(-3)
264
+ loss_normal = 1 - similarity[normal_mask>0].mean()
265
+ loss_normal = 0.2 * loss_normal
266
+
267
+ # flexicubes regularization loss
268
+ sdf = render_out['sdf']
269
+ sdf_reg_loss = render_out['sdf_reg_loss']
270
+ sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
271
+ _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
272
+ flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
273
+ flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
274
+
275
+ loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
276
+
277
+ loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
278
+
279
+ prefix = 'train'
280
+ loss_dict = {}
281
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
282
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
283
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
284
+ loss_dict.update({f'{prefix}/loss_normal': loss_normal})
285
+ loss_dict.update({f'{prefix}/loss_depth': loss_depth})
286
+ loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
287
+ loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
288
+ loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
289
+ loss_dict.update({f'{prefix}/loss': loss})
290
+
291
+ return loss, loss_dict
292
+
293
+ @torch.no_grad()
294
+ def validation_step(self, batch, batch_idx):
295
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
296
+
297
+ render_out = self.forward(lrm_generator_input)
298
+ render_images = render_out['img']
299
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
300
+
301
+ self.validation_step_outputs.append(render_images)
302
+
303
+ def on_validation_epoch_end(self):
304
+ images = torch.cat(self.validation_step_outputs, dim=-1)
305
+
306
+ all_images = self.all_gather(images)
307
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
308
+
309
+ if self.global_rank == 0:
310
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
311
+
312
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
313
+ save_image(grid, image_path)
314
+ print(f"Saved image to {image_path}")
315
+
316
+ self.validation_step_outputs.clear()
317
+
318
+ def configure_optimizers(self):
319
+ lr = self.learning_rate
320
+
321
+ optimizer = torch.optim.AdamW(
322
+ self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
323
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
324
+
325
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/models/__init__.py ADDED
File without changes
src/models/decoder/__init__.py ADDED
File without changes
src/models/decoder/transformer.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class BasicTransformerBlock(nn.Module):
21
+ """
22
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
23
+ """
24
+ # use attention from torch.nn.MultiHeadAttention
25
+ # Block contains a cross-attention layer, a self-attention layer, and a MLP
26
+ def __init__(
27
+ self,
28
+ inner_dim: int,
29
+ cond_dim: int,
30
+ num_heads: int,
31
+ eps: float,
32
+ attn_drop: float = 0.,
33
+ attn_bias: bool = False,
34
+ mlp_ratio: float = 4.,
35
+ mlp_drop: float = 0.,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.norm1 = nn.LayerNorm(inner_dim)
40
+ self.cross_attn = nn.MultiheadAttention(
41
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
42
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
43
+ self.norm2 = nn.LayerNorm(inner_dim)
44
+ self.self_attn = nn.MultiheadAttention(
45
+ embed_dim=inner_dim, num_heads=num_heads,
46
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
47
+ self.norm3 = nn.LayerNorm(inner_dim)
48
+ self.mlp = nn.Sequential(
49
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
50
+ nn.GELU(),
51
+ nn.Dropout(mlp_drop),
52
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
53
+ nn.Dropout(mlp_drop),
54
+ )
55
+
56
+ def forward(self, x, cond):
57
+ # x: [N, L, D]
58
+ # cond: [N, L_cond, D_cond]
59
+ x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
60
+ before_sa = self.norm2(x)
61
+ x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
62
+ x = x + self.mlp(self.norm3(x))
63
+ return x
64
+
65
+
66
+ class TriplaneTransformer(nn.Module):
67
+ """
68
+ Transformer with condition that generates a triplane representation.
69
+
70
+ Reference:
71
+ Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
72
+ """
73
+ def __init__(
74
+ self,
75
+ inner_dim: int,
76
+ image_feat_dim: int,
77
+ triplane_low_res: int,
78
+ triplane_high_res: int,
79
+ triplane_dim: int,
80
+ num_layers: int,
81
+ num_heads: int,
82
+ eps: float = 1e-6,
83
+ ):
84
+ super().__init__()
85
+
86
+ # attributes
87
+ self.triplane_low_res = triplane_low_res
88
+ self.triplane_high_res = triplane_high_res
89
+ self.triplane_dim = triplane_dim
90
+
91
+ # modules
92
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
93
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
94
+ self.layers = nn.ModuleList([
95
+ BasicTransformerBlock(
96
+ inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
97
+ for _ in range(num_layers)
98
+ ])
99
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
100
+ self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
101
+
102
+ def forward(self, image_feats):
103
+ # image_feats: [N, L_cond, D_cond]
104
+
105
+ N = image_feats.shape[0]
106
+ H = W = self.triplane_low_res
107
+ L = 3 * H * W
108
+
109
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
110
+ for layer in self.layers:
111
+ x = layer(x, image_feats)
112
+ x = self.norm(x)
113
+
114
+ # separate each plane and apply deconv
115
+ x = x.view(N, 3, H, W, -1)
116
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
117
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
118
+ x = self.deconv(x) # [3*N, D', H', W']
119
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
120
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
121
+ x = x.contiguous()
122
+
123
+ return x
src/models/encoder/__init__.py ADDED
File without changes
src/models/encoder/dino.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch ViT model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPooling,
29
+ )
30
+ from transformers import PreTrainedModel, ViTConfig
31
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
32
+
33
+
34
+ class ViTEmbeddings(nn.Module):
35
+ """
36
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
37
+ """
38
+
39
+ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
40
+ super().__init__()
41
+
42
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
43
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
44
+ self.patch_embeddings = ViTPatchEmbeddings(config)
45
+ num_patches = self.patch_embeddings.num_patches
46
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
47
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
48
+ self.config = config
49
+
50
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
51
+ """
52
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
53
+ resolution images.
54
+
55
+ Source:
56
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
57
+ """
58
+
59
+ num_patches = embeddings.shape[1] - 1
60
+ num_positions = self.position_embeddings.shape[1] - 1
61
+ if num_patches == num_positions and height == width:
62
+ return self.position_embeddings
63
+ class_pos_embed = self.position_embeddings[:, 0]
64
+ patch_pos_embed = self.position_embeddings[:, 1:]
65
+ dim = embeddings.shape[-1]
66
+ h0 = height // self.config.patch_size
67
+ w0 = width // self.config.patch_size
68
+ # we add a small number to avoid floating point error in the interpolation
69
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
70
+ h0, w0 = h0 + 0.1, w0 + 0.1
71
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
72
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
73
+ patch_pos_embed = nn.functional.interpolate(
74
+ patch_pos_embed,
75
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
76
+ mode="bicubic",
77
+ align_corners=False,
78
+ )
79
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
80
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
81
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
82
+
83
+ def forward(
84
+ self,
85
+ pixel_values: torch.Tensor,
86
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
87
+ interpolate_pos_encoding: bool = False,
88
+ ) -> torch.Tensor:
89
+ batch_size, num_channels, height, width = pixel_values.shape
90
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
91
+
92
+ if bool_masked_pos is not None:
93
+ seq_length = embeddings.shape[1]
94
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
95
+ # replace the masked visual tokens by mask_tokens
96
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
97
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
98
+
99
+ # add the [CLS] token to the embedded patch tokens
100
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
101
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
102
+
103
+ # add positional encoding to each token
104
+ if interpolate_pos_encoding:
105
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
106
+ else:
107
+ embeddings = embeddings + self.position_embeddings
108
+
109
+ embeddings = self.dropout(embeddings)
110
+
111
+ return embeddings
112
+
113
+
114
+ class ViTPatchEmbeddings(nn.Module):
115
+ """
116
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
117
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
118
+ Transformer.
119
+ """
120
+
121
+ def __init__(self, config):
122
+ super().__init__()
123
+ image_size, patch_size = config.image_size, config.patch_size
124
+ num_channels, hidden_size = config.num_channels, config.hidden_size
125
+
126
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
127
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
128
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.num_patches = num_patches
133
+
134
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
135
+
136
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
137
+ batch_size, num_channels, height, width = pixel_values.shape
138
+ if num_channels != self.num_channels:
139
+ raise ValueError(
140
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
141
+ f" Expected {self.num_channels} but got {num_channels}."
142
+ )
143
+ if not interpolate_pos_encoding:
144
+ if height != self.image_size[0] or width != self.image_size[1]:
145
+ raise ValueError(
146
+ f"Input image size ({height}*{width}) doesn't match model"
147
+ f" ({self.image_size[0]}*{self.image_size[1]})."
148
+ )
149
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
150
+ return embeddings
151
+
152
+
153
+ class ViTSelfAttention(nn.Module):
154
+ def __init__(self, config: ViTConfig) -> None:
155
+ super().__init__()
156
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
+ raise ValueError(
158
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
159
+ f"heads {config.num_attention_heads}."
160
+ )
161
+
162
+ self.num_attention_heads = config.num_attention_heads
163
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
164
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
165
+
166
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
167
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
168
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
169
+
170
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
171
+
172
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
173
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
174
+ x = x.view(new_x_shape)
175
+ return x.permute(0, 2, 1, 3)
176
+
177
+ def forward(
178
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
179
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
180
+ mixed_query_layer = self.query(hidden_states)
181
+
182
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
183
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
184
+ query_layer = self.transpose_for_scores(mixed_query_layer)
185
+
186
+ # Take the dot product between "query" and "key" to get the raw attention scores.
187
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
188
+
189
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
190
+
191
+ # Normalize the attention scores to probabilities.
192
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
193
+
194
+ # This is actually dropping out entire tokens to attend to, which might
195
+ # seem a bit unusual, but is taken from the original Transformer paper.
196
+ attention_probs = self.dropout(attention_probs)
197
+
198
+ # Mask heads if we want to
199
+ if head_mask is not None:
200
+ attention_probs = attention_probs * head_mask
201
+
202
+ context_layer = torch.matmul(attention_probs, value_layer)
203
+
204
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206
+ context_layer = context_layer.view(new_context_layer_shape)
207
+
208
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
209
+
210
+ return outputs
211
+
212
+
213
+ class ViTSelfOutput(nn.Module):
214
+ """
215
+ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
216
+ layernorm applied before each block.
217
+ """
218
+
219
+ def __init__(self, config: ViTConfig) -> None:
220
+ super().__init__()
221
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+
224
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
+ hidden_states = self.dense(hidden_states)
226
+ hidden_states = self.dropout(hidden_states)
227
+
228
+ return hidden_states
229
+
230
+
231
+ class ViTAttention(nn.Module):
232
+ def __init__(self, config: ViTConfig) -> None:
233
+ super().__init__()
234
+ self.attention = ViTSelfAttention(config)
235
+ self.output = ViTSelfOutput(config)
236
+ self.pruned_heads = set()
237
+
238
+ def prune_heads(self, heads: Set[int]) -> None:
239
+ if len(heads) == 0:
240
+ return
241
+ heads, index = find_pruneable_heads_and_indices(
242
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
243
+ )
244
+
245
+ # Prune linear layers
246
+ self.attention.query = prune_linear_layer(self.attention.query, index)
247
+ self.attention.key = prune_linear_layer(self.attention.key, index)
248
+ self.attention.value = prune_linear_layer(self.attention.value, index)
249
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
250
+
251
+ # Update hyper params and store pruned heads
252
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
253
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
254
+ self.pruned_heads = self.pruned_heads.union(heads)
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.Tensor,
259
+ head_mask: Optional[torch.Tensor] = None,
260
+ output_attentions: bool = False,
261
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
262
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
263
+
264
+ attention_output = self.output(self_outputs[0], hidden_states)
265
+
266
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
267
+ return outputs
268
+
269
+
270
+ class ViTIntermediate(nn.Module):
271
+ def __init__(self, config: ViTConfig) -> None:
272
+ super().__init__()
273
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
274
+ if isinstance(config.hidden_act, str):
275
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
276
+ else:
277
+ self.intermediate_act_fn = config.hidden_act
278
+
279
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
280
+ hidden_states = self.dense(hidden_states)
281
+ hidden_states = self.intermediate_act_fn(hidden_states)
282
+
283
+ return hidden_states
284
+
285
+
286
+ class ViTOutput(nn.Module):
287
+ def __init__(self, config: ViTConfig) -> None:
288
+ super().__init__()
289
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
290
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
291
+
292
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.dropout(hidden_states)
295
+
296
+ hidden_states = hidden_states + input_tensor
297
+
298
+ return hidden_states
299
+
300
+
301
+ def modulate(x, shift, scale):
302
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
303
+
304
+
305
+ class ViTLayer(nn.Module):
306
+ """This corresponds to the Block class in the timm implementation."""
307
+
308
+ def __init__(self, config: ViTConfig) -> None:
309
+ super().__init__()
310
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
311
+ self.seq_len_dim = 1
312
+ self.attention = ViTAttention(config)
313
+ self.intermediate = ViTIntermediate(config)
314
+ self.output = ViTOutput(config)
315
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
316
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
+
318
+ self.adaLN_modulation = nn.Sequential(
319
+ nn.SiLU(),
320
+ nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
321
+ )
322
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
323
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states: torch.Tensor,
328
+ adaln_input: torch.Tensor = None,
329
+ head_mask: Optional[torch.Tensor] = None,
330
+ output_attentions: bool = False,
331
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
332
+ shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
333
+
334
+ self_attention_outputs = self.attention(
335
+ modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
336
+ head_mask,
337
+ output_attentions=output_attentions,
338
+ )
339
+ attention_output = self_attention_outputs[0]
340
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
341
+
342
+ # first residual connection
343
+ hidden_states = attention_output + hidden_states
344
+
345
+ # in ViT, layernorm is also applied after self-attention
346
+ layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
347
+ layer_output = self.intermediate(layer_output)
348
+
349
+ # second residual connection is done here
350
+ layer_output = self.output(layer_output, hidden_states)
351
+
352
+ outputs = (layer_output,) + outputs
353
+
354
+ return outputs
355
+
356
+
357
+ class ViTEncoder(nn.Module):
358
+ def __init__(self, config: ViTConfig) -> None:
359
+ super().__init__()
360
+ self.config = config
361
+ self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
362
+ self.gradient_checkpointing = False
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ adaln_input: torch.Tensor = None,
368
+ head_mask: Optional[torch.Tensor] = None,
369
+ output_attentions: bool = False,
370
+ output_hidden_states: bool = False,
371
+ return_dict: bool = True,
372
+ ) -> Union[tuple, BaseModelOutput]:
373
+ all_hidden_states = () if output_hidden_states else None
374
+ all_self_attentions = () if output_attentions else None
375
+
376
+ for i, layer_module in enumerate(self.layer):
377
+ if output_hidden_states:
378
+ all_hidden_states = all_hidden_states + (hidden_states,)
379
+
380
+ layer_head_mask = head_mask[i] if head_mask is not None else None
381
+
382
+ if self.gradient_checkpointing and self.training:
383
+ layer_outputs = self._gradient_checkpointing_func(
384
+ layer_module.__call__,
385
+ hidden_states,
386
+ adaln_input,
387
+ layer_head_mask,
388
+ output_attentions,
389
+ )
390
+ else:
391
+ layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
392
+
393
+ hidden_states = layer_outputs[0]
394
+
395
+ if output_attentions:
396
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
397
+
398
+ if output_hidden_states:
399
+ all_hidden_states = all_hidden_states + (hidden_states,)
400
+
401
+ if not return_dict:
402
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
403
+ return BaseModelOutput(
404
+ last_hidden_state=hidden_states,
405
+ hidden_states=all_hidden_states,
406
+ attentions=all_self_attentions,
407
+ )
408
+
409
+
410
+ class ViTPreTrainedModel(PreTrainedModel):
411
+ """
412
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
+ models.
414
+ """
415
+
416
+ config_class = ViTConfig
417
+ base_model_prefix = "vit"
418
+ main_input_name = "pixel_values"
419
+ supports_gradient_checkpointing = True
420
+ _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
421
+
422
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
423
+ """Initialize the weights"""
424
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
425
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
426
+ # `trunc_normal_cpu` not implemented in `half` issues
427
+ module.weight.data = nn.init.trunc_normal_(
428
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
429
+ ).to(module.weight.dtype)
430
+ if module.bias is not None:
431
+ module.bias.data.zero_()
432
+ elif isinstance(module, nn.LayerNorm):
433
+ module.bias.data.zero_()
434
+ module.weight.data.fill_(1.0)
435
+ elif isinstance(module, ViTEmbeddings):
436
+ module.position_embeddings.data = nn.init.trunc_normal_(
437
+ module.position_embeddings.data.to(torch.float32),
438
+ mean=0.0,
439
+ std=self.config.initializer_range,
440
+ ).to(module.position_embeddings.dtype)
441
+
442
+ module.cls_token.data = nn.init.trunc_normal_(
443
+ module.cls_token.data.to(torch.float32),
444
+ mean=0.0,
445
+ std=self.config.initializer_range,
446
+ ).to(module.cls_token.dtype)
447
+
448
+
449
+ class ViTModel(ViTPreTrainedModel):
450
+ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
451
+ super().__init__(config)
452
+ self.config = config
453
+
454
+ self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
455
+ self.encoder = ViTEncoder(config)
456
+
457
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
+ self.pooler = ViTPooler(config) if add_pooling_layer else None
459
+
460
+ # Initialize weights and apply final processing
461
+ self.post_init()
462
+
463
+ def get_input_embeddings(self) -> ViTPatchEmbeddings:
464
+ return self.embeddings.patch_embeddings
465
+
466
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
467
+ """
468
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
469
+ class PreTrainedModel
470
+ """
471
+ for layer, heads in heads_to_prune.items():
472
+ self.encoder.layer[layer].attention.prune_heads(heads)
473
+
474
+ def forward(
475
+ self,
476
+ pixel_values: Optional[torch.Tensor] = None,
477
+ adaln_input: Optional[torch.Tensor] = None,
478
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
479
+ head_mask: Optional[torch.Tensor] = None,
480
+ output_attentions: Optional[bool] = None,
481
+ output_hidden_states: Optional[bool] = None,
482
+ interpolate_pos_encoding: Optional[bool] = None,
483
+ return_dict: Optional[bool] = None,
484
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
485
+ r"""
486
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
487
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
488
+ """
489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
+ output_hidden_states = (
491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
+ )
493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
+
495
+ if pixel_values is None:
496
+ raise ValueError("You have to specify pixel_values")
497
+
498
+ # Prepare head mask if needed
499
+ # 1.0 in head_mask indicate we keep the head
500
+ # attention_probs has shape bsz x n_heads x N x N
501
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
502
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
503
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
504
+
505
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
506
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
507
+ if pixel_values.dtype != expected_dtype:
508
+ pixel_values = pixel_values.to(expected_dtype)
509
+
510
+ embedding_output = self.embeddings(
511
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
512
+ )
513
+
514
+ encoder_outputs = self.encoder(
515
+ embedding_output,
516
+ adaln_input=adaln_input,
517
+ head_mask=head_mask,
518
+ output_attentions=output_attentions,
519
+ output_hidden_states=output_hidden_states,
520
+ return_dict=return_dict,
521
+ )
522
+ sequence_output = encoder_outputs[0]
523
+ sequence_output = self.layernorm(sequence_output)
524
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
525
+
526
+ if not return_dict:
527
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
528
+ return head_outputs + encoder_outputs[1:]
529
+
530
+ return BaseModelOutputWithPooling(
531
+ last_hidden_state=sequence_output,
532
+ pooler_output=pooled_output,
533
+ hidden_states=encoder_outputs.hidden_states,
534
+ attentions=encoder_outputs.attentions,
535
+ )
536
+
537
+
538
+ class ViTPooler(nn.Module):
539
+ def __init__(self, config: ViTConfig):
540
+ super().__init__()
541
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
542
+ self.activation = nn.Tanh()
543
+
544
+ def forward(self, hidden_states):
545
+ # We "pool" the model by simply taking the hidden state corresponding
546
+ # to the first token.
547
+ first_token_tensor = hidden_states[:, 0]
548
+ pooled_output = self.dense(first_token_tensor)
549
+ pooled_output = self.activation(pooled_output)
550
+ return pooled_output
src/models/encoder/dino_wrapper.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch.nn as nn
17
+ from transformers import ViTImageProcessor
18
+ from einops import rearrange, repeat
19
+ from .dino import ViTModel
20
+
21
+
22
+ class DinoWrapper(nn.Module):
23
+ """
24
+ Dino v1 wrapper using huggingface transformer implementation.
25
+ """
26
+ def __init__(self, model_name: str, freeze: bool = True):
27
+ super().__init__()
28
+ self.model, self.processor = self._build_dino(model_name)
29
+ self.camera_embedder = nn.Sequential(
30
+ nn.Linear(16, self.model.config.hidden_size, bias=True),
31
+ nn.SiLU(),
32
+ nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
33
+ )
34
+ if freeze:
35
+ self._freeze()
36
+
37
+ def forward(self, image, camera):
38
+ # image: [B, N, C, H, W]
39
+ # camera: [B, N, D]
40
+ # RGB image with [0,1] scale and properly sized
41
+ if image.ndim == 5:
42
+ image = rearrange(image, 'b n c h w -> (b n) c h w')
43
+ dtype = image.dtype
44
+ inputs = self.processor(
45
+ images=image.float(),
46
+ return_tensors="pt",
47
+ do_rescale=False,
48
+ do_resize=False,
49
+ ).to(self.model.device).to(dtype)
50
+ # embed camera
51
+ N = camera.shape[1]
52
+ camera_embeddings = self.camera_embedder(camera)
53
+ camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
54
+ embeddings = camera_embeddings
55
+ # This resampling of positional embedding uses bicubic interpolation
56
+ outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
57
+ last_hidden_states = outputs.last_hidden_state
58
+ return last_hidden_states
59
+
60
+ def _freeze(self):
61
+ print(f"======== Freezing DinoWrapper ========")
62
+ self.model.eval()
63
+ for name, param in self.model.named_parameters():
64
+ param.requires_grad = False
65
+
66
+ @staticmethod
67
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
68
+ import requests
69
+ try:
70
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
71
+ processor = ViTImageProcessor.from_pretrained(model_name)
72
+ return model, processor
73
+ except requests.exceptions.ProxyError as err:
74
+ if proxy_error_retries > 0:
75
+ print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
76
+ import time
77
+ time.sleep(proxy_error_cooldown)
78
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
79
+ else:
80
+ raise err