02alexander commited on
Commit
9a947d8
1 Parent(s): 23ae8d0

copy stuff

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +2 -0
  2. LICENSE +201 -0
  3. README.md +6 -4
  4. app.py +427 -0
  5. configs/instant-mesh-base.yaml +22 -0
  6. configs/instant-mesh-large.yaml +22 -0
  7. configs/instant-nerf-base.yaml +21 -0
  8. configs/instant-nerf-large.yaml +21 -0
  9. examples/bird.jpg +0 -0
  10. examples/bubble_mart_blue.png +0 -0
  11. examples/cake.jpg +0 -0
  12. examples/cartoon_dinosaur.png +0 -0
  13. examples/chair_armed.png +0 -0
  14. examples/chair_comfort.jpg +0 -0
  15. examples/chair_wood.jpg +0 -0
  16. examples/chest.jpg +0 -0
  17. examples/cute_horse.jpg +0 -0
  18. examples/cute_tiger.jpg +0 -0
  19. examples/earphone.jpg +0 -0
  20. examples/fox.jpg +0 -0
  21. examples/fruit.jpg +0 -0
  22. examples/fruit_elephant.jpg +0 -0
  23. examples/genshin_building.png +0 -0
  24. examples/genshin_teapot.png +0 -0
  25. examples/hatsune_miku.png +0 -0
  26. examples/house2.jpg +0 -0
  27. examples/mushroom_teapot.jpg +0 -0
  28. examples/pikachu.png +0 -0
  29. examples/plant.jpg +0 -0
  30. examples/robot.jpg +0 -0
  31. examples/sea_turtle.png +0 -0
  32. examples/skating_shoe.jpg +0 -0
  33. examples/sorting_board.png +0 -0
  34. examples/sword.png +0 -0
  35. examples/toy_car.jpg +0 -0
  36. examples/watermelon.png +0 -0
  37. examples/whitedog.png +0 -0
  38. examples/x_teapot.jpg +0 -0
  39. examples/x_toyduck.jpg +0 -0
  40. requirements.txt +27 -0
  41. src/__init__.py +0 -0
  42. src/__pycache__/__init__.cpython-311.pyc +0 -0
  43. src/data/__init__.py +0 -0
  44. src/data/objaverse.py +329 -0
  45. src/model.py +310 -0
  46. src/model_mesh.py +325 -0
  47. src/models/__init__.py +0 -0
  48. src/models/__pycache__/__init__.cpython-311.pyc +0 -0
  49. src/models/__pycache__/lrm_mesh.cpython-311.pyc +0 -0
  50. src/models/decoder/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.nix
2
+ venv/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
  title: InstantMeshRerun
3
- emoji: 💻
4
- colorFrom: blue
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: InstantMeshRerun
3
+ emoji: 📚
4
+ colorFrom: indigo
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.26.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Create a 3D model from an image in 10 seconds!
11
+ license: apache-2.0
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+
3
+ import os
4
+ import imageio
5
+ import numpy as np
6
+ import torch
7
+ import rembg
8
+ from PIL import Image
9
+ from torchvision.transforms import v2
10
+ from pytorch_lightning import seed_everything
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange, repeat
13
+ from tqdm import tqdm
14
+ from typing import Any
15
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
16
+ import rerun as rr
17
+ from gradio_rerun import Rerun
18
+
19
+ from src.utils.train_util import instantiate_from_config
20
+ from src.utils.camera_util import (
21
+ FOV_to_intrinsics,
22
+ get_zero123plus_input_cameras,
23
+ get_circular_camera_poses,
24
+ )
25
+ from src.utils.mesh_util import save_obj, save_glb
26
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
27
+
28
+ import tempfile
29
+ from functools import partial
30
+
31
+ from huggingface_hub import hf_hub_download
32
+
33
+ import gradio as gr
34
+
35
+
36
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
37
+ """
38
+ Get the rendering camera parameters.
39
+ """
40
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
41
+ if is_flexicubes:
42
+ cameras = torch.linalg.inv(c2ws)
43
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
44
+ else:
45
+ extrinsics = c2ws.flatten(-2)
46
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
47
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
48
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
49
+ return cameras
50
+
51
+
52
+ def images_to_video(images, output_path, fps=30):
53
+ # images: (N, C, H, W)
54
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
55
+ frames = []
56
+ for i in range(images.shape[0]):
57
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
58
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
59
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
60
+ assert frame.min() >= 0 and frame.max() <= 255, \
61
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
62
+ frames.append(frame)
63
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
64
+
65
+
66
+ ###############################################################################
67
+ # Configuration.
68
+ ###############################################################################
69
+
70
+ import shutil
71
+
72
+ def find_cuda():
73
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
74
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
75
+
76
+ if cuda_home and os.path.exists(cuda_home):
77
+ return cuda_home
78
+
79
+ # Search for the nvcc executable in the system's PATH
80
+ nvcc_path = shutil.which('nvcc')
81
+
82
+ if nvcc_path:
83
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
84
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
85
+ return cuda_path
86
+
87
+ return None
88
+
89
+ cuda_path = find_cuda()
90
+
91
+ if cuda_path:
92
+ print(f"CUDA installation found at: {cuda_path}")
93
+ else:
94
+ print("CUDA installation not found")
95
+
96
+ config_path = 'configs/instant-mesh-large.yaml'
97
+ config = OmegaConf.load(config_path)
98
+ config_name = os.path.basename(config_path).replace('.yaml', '')
99
+ model_config = config.model_config
100
+ infer_config = config.infer_config
101
+
102
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
103
+
104
+ device = torch.device('cuda')
105
+
106
+ # load diffusion model
107
+ print('Loading diffusion model ...')
108
+ pipeline = DiffusionPipeline.from_pretrained(
109
+ "sudo-ai/zero123plus-v1.2",
110
+ custom_pipeline="zero123plus",
111
+ torch_dtype=torch.float16,
112
+ )
113
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
114
+ pipeline.scheduler.config, timestep_spacing='trailing'
115
+ )
116
+
117
+ # load custom white-background UNet
118
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
119
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
120
+ pipeline.unet.load_state_dict(state_dict, strict=True)
121
+
122
+ pipeline = pipeline.to(device)
123
+ print(f'type(pipeline)={type(pipeline)}')
124
+
125
+ # load reconstruction model
126
+ print('Loading reconstruction model ...')
127
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
128
+ model = instantiate_from_config(model_config)
129
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
130
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
131
+ model.load_state_dict(state_dict, strict=True)
132
+
133
+ model = model.to(device)
134
+
135
+ print('Loading Finished!')
136
+
137
+
138
+ def check_input_image(input_image):
139
+ if input_image is None:
140
+ raise gr.Error("No image uploaded!")
141
+
142
+
143
+ def preprocess(input_image, do_remove_background):
144
+
145
+ rembg_session = rembg.new_session() if do_remove_background else None
146
+
147
+ if do_remove_background:
148
+ input_image = remove_background(input_image, rembg_session)
149
+ input_image = resize_foreground(input_image, 0.85)
150
+
151
+ return input_image
152
+
153
+
154
+ def pipeline_callback(pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]) -> dict[str, Any]:
155
+ rr.set_time_sequence("iteration", step_index)
156
+ rr.set_time_seconds("timestep", timestep)
157
+ latents = callback_kwargs["latents"]
158
+ image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] # type: ignore[attr-defined]
159
+ image = pipe.image_processor.postprocess(image, output_type="np").squeeze() # type: ignore[attr-defined]
160
+
161
+ rr.log("output", rr.Image(image))
162
+ rr.log("latents", rr.Tensor(latents.squeeze()))
163
+ return callback_kwargs
164
+
165
+ @spaces.GPU
166
+ def generate_mvs(input_image, sample_steps, sample_seed):
167
+
168
+ seed_everything(sample_seed)
169
+
170
+ return pipeline(
171
+ input_image,
172
+ num_inference_steps=sample_steps,
173
+ callback_on_step_end=pipeline_callback,
174
+ )
175
+
176
+ # sampling
177
+ # z123_image = pipeline(
178
+ # input_image,
179
+ # num_inference_steps=sample_steps
180
+ # ).images[0]
181
+
182
+ # show_image = np.asarray(z123_image, dtype=np.uint8)
183
+ # show_image = torch.from_numpy(show_image) # (960, 640, 3)
184
+ # show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
185
+ # show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
186
+ # show_image = Image.fromarray(show_image.numpy())
187
+
188
+ # return z123_image, show_image
189
+
190
+
191
+ @spaces.GPU
192
+ def make3d(images):
193
+
194
+ global model
195
+ if IS_FLEXICUBES:
196
+ model.init_flexicubes_geometry(device, use_renderer=False)
197
+ model = model.eval()
198
+
199
+ images = np.asarray(images, dtype=np.float32) / 255.0
200
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
201
+ images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
202
+
203
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
204
+ render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
205
+
206
+ images = images.unsqueeze(0).to(device)
207
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
208
+
209
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
210
+ print(mesh_fpath)
211
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
212
+ mesh_dirname = os.path.dirname(mesh_fpath)
213
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
214
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
215
+
216
+ with torch.no_grad():
217
+ # get triplane
218
+ planes = model.forward_planes(images, input_cameras)
219
+
220
+ # # get video
221
+ # chunk_size = 20 if IS_FLEXICUBES else 1
222
+ # render_size = 384
223
+
224
+ # frames = []
225
+ # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
226
+ # if IS_FLEXICUBES:
227
+ # frame = model.forward_geometry(
228
+ # planes,
229
+ # render_cameras[:, i:i+chunk_size],
230
+ # render_size=render_size,
231
+ # )['img']
232
+ # else:
233
+ # frame = model.synthesizer(
234
+ # planes,
235
+ # cameras=render_cameras[:, i:i+chunk_size],
236
+ # render_size=render_size,
237
+ # )['images_rgb']
238
+ # frames.append(frame)
239
+ # frames = torch.cat(frames, dim=1)
240
+
241
+ # images_to_video(
242
+ # frames[0],
243
+ # video_fpath,
244
+ # fps=30,
245
+ # )
246
+
247
+ # print(f"Video saved to {video_fpath}")
248
+
249
+ # get mesh
250
+ mesh_out = model.extract_mesh(
251
+ planes,
252
+ use_texture_map=False,
253
+ **infer_config,
254
+ )
255
+
256
+ vertices, faces, vertex_colors = mesh_out
257
+ vertices = vertices[:, [1, 2, 0]]
258
+
259
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
260
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
261
+
262
+ print(f"Mesh saved to {mesh_fpath}")
263
+
264
+ return mesh_fpath, mesh_glb_fpath
265
+
266
+ @rr.thread_local_stream("InstantMesh_visualization")
267
+ def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
268
+ preprocessed_image = preprocess(input_image, do_remove_background)
269
+
270
+ stream = rr.binary_stream()
271
+
272
+ rr.log("preprocessed_image", rr.Image(preprocessed_image))
273
+
274
+ yield stream.read()
275
+
276
+ z123_out = generate_mvs(input_image, sample_steps, sample_seed)
277
+ print(z123_out)
278
+ for image in z123_out.images:
279
+ rr.log("z123image", rr.Image(image))
280
+ yield stream.read()
281
+
282
+ pass
283
+
284
+ _HEADER_ = '''
285
+ <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
286
+
287
+ **InstantMesh** is a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
288
+
289
+ Code: <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
290
+
291
+ ❗️❗️❗️**Important Notes:**
292
+ - Our demo can export a .obj mesh with vertex colors or a .glb mesh now. If you prefer to export a .obj mesh with a **texture map**, please refer to our <a href='https://github.com/TencentARC/InstantMesh?tab=readme-ov-file#running-with-command-line' target='_blank'>Github Repo</a>.
293
+ - The 3D mesh generation results highly depend on the quality of generated multi-view images. Please try a different **seed value** if the result is unsatisfying (Default: 42).
294
+ '''
295
+
296
+ _CITE_ = r"""
297
+ If InstantMesh is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/InstantMesh?style=social)](https://github.com/TencentARC/InstantMesh)
298
+ ---
299
+ 📝 **Citation**
300
+
301
+ If you find our work useful for your research or applications, please cite using this bibtex:
302
+ ```bibtex
303
+ @article{xu2024instantmesh,
304
+ title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
305
+ author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
306
+ journal={arXiv preprint arXiv:2404.07191},
307
+ year={2024}
308
+ }
309
+ ```
310
+
311
+ 📋 **License**
312
+
313
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
314
+
315
+ 📧 **Contact**
316
+
317
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
318
+ """
319
+
320
+
321
+ with gr.Blocks() as demo:
322
+ gr.Markdown(_HEADER_)
323
+ with gr.Row(variant="panel"):
324
+ with gr.Column():
325
+ with gr.Row():
326
+ input_image = gr.Image(
327
+ label="Input Image",
328
+ image_mode="RGBA",
329
+ sources="upload",
330
+ #width=256,
331
+ #height=256,
332
+ type="pil",
333
+ elem_id="content_image",
334
+ )
335
+ processed_image = gr.Image(
336
+ label="Processed Image",
337
+ image_mode="RGBA",
338
+ #width=256,
339
+ #height=256,
340
+ type="pil",
341
+ interactive=False
342
+ )
343
+ with gr.Row():
344
+ with gr.Group():
345
+ do_remove_background = gr.Checkbox(
346
+ label="Remove Background", value=True
347
+ )
348
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
349
+
350
+ sample_steps = gr.Slider(
351
+ label="Sample Steps",
352
+ minimum=30,
353
+ maximum=75,
354
+ value=75,
355
+ step=5
356
+ )
357
+
358
+ with gr.Row():
359
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
360
+
361
+ with gr.Row(variant="panel"):
362
+ gr.Examples(
363
+ examples=[
364
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
365
+ ],
366
+ inputs=[input_image],
367
+ label="Examples",
368
+ cache_examples=False,
369
+ examples_per_page=16
370
+ )
371
+
372
+ with gr.Column():
373
+
374
+ viewer = Rerun(streaming=True, height=800)
375
+
376
+ # with gr.Row():
377
+
378
+ # with gr.Column():
379
+ # mv_show_images = gr.Image(
380
+ # label="Generated Multi-views",
381
+ # type="pil",
382
+ # width=379,
383
+ # interactive=False
384
+ # )
385
+
386
+ # with gr.Row():
387
+ # with gr.Tab("OBJ"):
388
+ # output_model_obj = gr.Model3D(
389
+ # label="Output Model (OBJ Format)",
390
+ # interactive=False,
391
+ # )
392
+ # gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
393
+ # with gr.Tab("GLB"):
394
+ # output_model_glb = gr.Model3D(
395
+ # label="Output Model (GLB Format)",
396
+ # interactive=False,
397
+ # )
398
+ # gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
399
+
400
+ with gr.Row():
401
+ gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
402
+
403
+ gr.Markdown(_CITE_)
404
+
405
+ mv_images = gr.State()
406
+
407
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
408
+ fn=log_to_rr,
409
+ inputs=[input_image, do_remove_background, sample_steps, sample_seed],
410
+ outputs=[viewer]
411
+ )
412
+ # submit.click(fn=check_input_image, inputs=[input_image]).success(
413
+ # fn=preprocess,
414
+ # inputs=[input_image, do_remove_background],
415
+ # outputs=[processed_image],
416
+ # ).success(
417
+ # fn=generate_mvs,
418
+ # inputs=[processed_image, sample_steps, sample_seed],
419
+ # outputs=[mv_images, mv_show_images]
420
+
421
+ # ).success(
422
+ # fn=make3d,
423
+ # inputs=[mv_images],
424
+ # outputs=[output_model_obj, output_model_glb]
425
+ # )
426
+
427
+ demo.launch()
configs/instant-mesh-base.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_base.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-mesh-large.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_large.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-nerf-base.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_base.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
configs/instant-nerf-large.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_large.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
examples/bird.jpg ADDED
examples/bubble_mart_blue.png ADDED
examples/cake.jpg ADDED
examples/cartoon_dinosaur.png ADDED
examples/chair_armed.png ADDED
examples/chair_comfort.jpg ADDED
examples/chair_wood.jpg ADDED
examples/chest.jpg ADDED
examples/cute_horse.jpg ADDED
examples/cute_tiger.jpg ADDED
examples/earphone.jpg ADDED
examples/fox.jpg ADDED
examples/fruit.jpg ADDED
examples/fruit_elephant.jpg ADDED
examples/genshin_building.png ADDED
examples/genshin_teapot.png ADDED
examples/hatsune_miku.png ADDED
examples/house2.jpg ADDED
examples/mushroom_teapot.jpg ADDED
examples/pikachu.png ADDED
examples/plant.jpg ADDED
examples/robot.jpg ADDED
examples/sea_turtle.png ADDED
examples/skating_shoe.jpg ADDED
examples/sorting_board.png ADDED
examples/sword.png ADDED
examples/toy_car.jpg ADDED
examples/watermelon.png ADDED
examples/whitedog.png ADDED
examples/x_teapot.jpg ADDED
examples/x_toyduck.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spaces
2
+ torch==2.1.0
3
+ torchvision==0.16.0
4
+ torchaudio==2.1.0
5
+ pytorch-lightning==2.1.2
6
+ einops
7
+ omegaconf
8
+ deepspeed
9
+ torchmetrics
10
+ webdataset
11
+ accelerate
12
+ tensorboard
13
+ PyMCubes
14
+ trimesh
15
+ rembg
16
+ transformers
17
+ diffusers==0.28.2
18
+ bitsandbytes
19
+ imageio[ffmpeg]
20
+ xatlas
21
+ plyfile
22
+ xformers==0.0.22.post7
23
+ git+https://github.com/NVlabs/nvdiffrast/
24
+ huggingface-hub
25
+ gradio_client >= 0.12
26
+ rerun-sdk>=0.16.0,<0.17.0
27
+ gradio_rerun
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (157 Bytes). View file
 
src/data/__init__.py ADDED
File without changes
src/data/objaverse.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import math
3
+ import json
4
+ import importlib
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import random
9
+ import numpy as np
10
+ from PIL import Image
11
+ import webdataset as wds
12
+ import pytorch_lightning as pl
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch.utils.data import Dataset
17
+ from torch.utils.data import DataLoader
18
+ from torch.utils.data.distributed import DistributedSampler
19
+ from torchvision import transforms
20
+
21
+ from src.utils.train_util import instantiate_from_config
22
+ from src.utils.camera_util import (
23
+ FOV_to_intrinsics,
24
+ center_looking_at_camera_pose,
25
+ get_surrounding_views,
26
+ )
27
+
28
+
29
+ class DataModuleFromConfig(pl.LightningDataModule):
30
+ def __init__(
31
+ self,
32
+ batch_size=8,
33
+ num_workers=4,
34
+ train=None,
35
+ validation=None,
36
+ test=None,
37
+ **kwargs,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.batch_size = batch_size
42
+ self.num_workers = num_workers
43
+
44
+ self.dataset_configs = dict()
45
+ if train is not None:
46
+ self.dataset_configs['train'] = train
47
+ if validation is not None:
48
+ self.dataset_configs['validation'] = validation
49
+ if test is not None:
50
+ self.dataset_configs['test'] = test
51
+
52
+ def setup(self, stage):
53
+
54
+ if stage in ['fit']:
55
+ self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ def train_dataloader(self):
60
+
61
+ sampler = DistributedSampler(self.datasets['train'])
62
+ return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
63
+
64
+ def val_dataloader(self):
65
+
66
+ sampler = DistributedSampler(self.datasets['validation'])
67
+ return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
68
+
69
+ def test_dataloader(self):
70
+
71
+ return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
72
+
73
+
74
+ class ObjaverseData(Dataset):
75
+ def __init__(self,
76
+ root_dir='objaverse/',
77
+ meta_fname='valid_paths.json',
78
+ input_image_dir='rendering_random_32views',
79
+ target_image_dir='rendering_random_32views',
80
+ input_view_num=6,
81
+ target_view_num=2,
82
+ total_view_n=32,
83
+ fov=50,
84
+ camera_rotation=True,
85
+ validation=False,
86
+ ):
87
+ self.root_dir = Path(root_dir)
88
+ self.input_image_dir = input_image_dir
89
+ self.target_image_dir = target_image_dir
90
+
91
+ self.input_view_num = input_view_num
92
+ self.target_view_num = target_view_num
93
+ self.total_view_n = total_view_n
94
+ self.fov = fov
95
+ self.camera_rotation = camera_rotation
96
+
97
+ with open(os.path.join(root_dir, meta_fname)) as f:
98
+ filtered_dict = json.load(f)
99
+ paths = filtered_dict['good_objs']
100
+ self.paths = paths
101
+
102
+ self.depth_scale = 4.0
103
+
104
+ total_objects = len(self.paths)
105
+ print('============= length of dataset %d =============' % len(self.paths))
106
+
107
+ def __len__(self):
108
+ return len(self.paths)
109
+
110
+ def load_im(self, path, color):
111
+ '''
112
+ replace background pixel with random color in rendering
113
+ '''
114
+ pil_img = Image.open(path)
115
+
116
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
117
+ alpha = image[:, :, 3:]
118
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
119
+
120
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
121
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
122
+ return image, alpha
123
+
124
+ def __getitem__(self, index):
125
+ # load data
126
+ while True:
127
+ input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
128
+ target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
129
+
130
+ indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
131
+ input_indices = indices[:self.input_view_num]
132
+ target_indices = indices[self.input_view_num:]
133
+
134
+ '''background color, default: white'''
135
+ bg_white = [1., 1., 1.]
136
+ bg_black = [0., 0., 0.]
137
+
138
+ image_list = []
139
+ alpha_list = []
140
+ depth_list = []
141
+ normal_list = []
142
+ pose_list = []
143
+
144
+ try:
145
+ input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
146
+ for idx in input_indices:
147
+ image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
148
+ normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
149
+ depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
150
+ depth = torch.from_numpy(depth).unsqueeze(0)
151
+ pose = input_cameras[idx]
152
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
153
+
154
+ image_list.append(image)
155
+ alpha_list.append(alpha)
156
+ depth_list.append(depth)
157
+ normal_list.append(normal)
158
+ pose_list.append(pose)
159
+
160
+ target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
161
+ for idx in target_indices:
162
+ image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
163
+ normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
164
+ depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
165
+ depth = torch.from_numpy(depth).unsqueeze(0)
166
+ pose = target_cameras[idx]
167
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
168
+
169
+ image_list.append(image)
170
+ alpha_list.append(alpha)
171
+ depth_list.append(depth)
172
+ normal_list.append(normal)
173
+ pose_list.append(pose)
174
+
175
+ except Exception as e:
176
+ print(e)
177
+ index = np.random.randint(0, len(self.paths))
178
+ continue
179
+
180
+ break
181
+
182
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
183
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
184
+ depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
185
+ normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
186
+ w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
187
+ c2ws = torch.linalg.inv(w2cs).float()
188
+
189
+ normals = normals * 2.0 - 1.0
190
+ normals = F.normalize(normals, dim=1)
191
+ normals = (normals + 1.0) / 2.0
192
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
193
+
194
+ # random rotation along z axis
195
+ if self.camera_rotation:
196
+ degree = np.random.uniform(0, math.pi * 2)
197
+ rot = torch.tensor([
198
+ [np.cos(degree), -np.sin(degree), 0, 0],
199
+ [np.sin(degree), np.cos(degree), 0, 0],
200
+ [0, 0, 1, 0],
201
+ [0, 0, 0, 1],
202
+ ]).unsqueeze(0).float()
203
+ c2ws = torch.matmul(rot, c2ws)
204
+
205
+ # rotate normals
206
+ N, _, H, W = normals.shape
207
+ normals = normals * 2.0 - 1.0
208
+ normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
209
+ normals = F.normalize(normals, dim=1)
210
+ normals = (normals + 1.0) / 2.0
211
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
212
+
213
+ # random scaling
214
+ if np.random.rand() < 0.5:
215
+ scale = np.random.uniform(0.8, 1.0)
216
+ c2ws[:, :3, 3] *= scale
217
+ depths *= scale
218
+
219
+ # instrinsics of perspective cameras
220
+ K = FOV_to_intrinsics(self.fov)
221
+ Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
222
+
223
+ data = {
224
+ 'input_images': images[:self.input_view_num], # (6, 3, H, W)
225
+ 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
226
+ 'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
227
+ 'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
228
+ 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
229
+ 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
230
+
231
+ # lrm generator input and supervision
232
+ 'target_images': images[self.input_view_num:], # (V, 3, H, W)
233
+ 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
234
+ 'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
235
+ 'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
236
+ 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
237
+ 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
238
+
239
+ 'depth_available': 1,
240
+ }
241
+ return data
242
+
243
+
244
+ class ValidationData(Dataset):
245
+ def __init__(self,
246
+ root_dir='objaverse/',
247
+ input_view_num=6,
248
+ input_image_size=256,
249
+ fov=50,
250
+ ):
251
+ self.root_dir = Path(root_dir)
252
+ self.input_view_num = input_view_num
253
+ self.input_image_size = input_image_size
254
+ self.fov = fov
255
+
256
+ self.paths = sorted(os.listdir(self.root_dir))
257
+ print('============= length of dataset %d =============' % len(self.paths))
258
+
259
+ cam_distance = 2.5
260
+ azimuths = np.array([30, 90, 150, 210, 270, 330])
261
+ elevations = np.array([30, -20, 30, -20, 30, -20])
262
+ azimuths = np.deg2rad(azimuths)
263
+ elevations = np.deg2rad(elevations)
264
+
265
+ x = cam_distance * np.cos(elevations) * np.cos(azimuths)
266
+ y = cam_distance * np.cos(elevations) * np.sin(azimuths)
267
+ z = cam_distance * np.sin(elevations)
268
+
269
+ cam_locations = np.stack([x, y, z], axis=-1)
270
+ cam_locations = torch.from_numpy(cam_locations).float()
271
+ c2ws = center_looking_at_camera_pose(cam_locations)
272
+ self.c2ws = c2ws.float()
273
+ self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
274
+
275
+ render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
276
+ render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
277
+ self.render_c2ws = render_c2ws.float()
278
+ self.render_Ks = render_Ks.float()
279
+
280
+ def __len__(self):
281
+ return len(self.paths)
282
+
283
+ def load_im(self, path, color):
284
+ '''
285
+ replace background pixel with random color in rendering
286
+ '''
287
+ pil_img = Image.open(path)
288
+ pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
289
+
290
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
291
+ if image.shape[-1] == 4:
292
+ alpha = image[:, :, 3:]
293
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
294
+ else:
295
+ alpha = np.ones_like(image[:, :, :1])
296
+
297
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
298
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
299
+ return image, alpha
300
+
301
+ def __getitem__(self, index):
302
+ # load data
303
+ input_image_path = os.path.join(self.root_dir, self.paths[index])
304
+
305
+ '''background color, default: white'''
306
+ # color = np.random.uniform(0.48, 0.52)
307
+ bkg_color = [1.0, 1.0, 1.0]
308
+
309
+ image_list = []
310
+ alpha_list = []
311
+
312
+ for idx in range(self.input_view_num):
313
+ image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
314
+ image_list.append(image)
315
+ alpha_list.append(alpha)
316
+
317
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
318
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
319
+
320
+ data = {
321
+ 'input_images': images, # (6, 3, H, W)
322
+ 'input_alphas': alphas, # (6, 1, H, W)
323
+ 'input_c2ws': self.c2ws, # (6, 4, 4)
324
+ 'input_Ks': self.Ks, # (6, 3, 3)
325
+
326
+ 'render_c2ws': self.render_c2ws,
327
+ 'render_Ks': self.render_Ks,
328
+ }
329
+ return data
src/model.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import v2
6
+ from torchvision.utils import make_grid, save_image
7
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
+ import pytorch_lightning as pl
9
+ from einops import rearrange, repeat
10
+
11
+ from src.utils.train_util import instantiate_from_config
12
+
13
+
14
+ class MVRecon(pl.LightningModule):
15
+ def __init__(
16
+ self,
17
+ lrm_generator_config,
18
+ lrm_path=None,
19
+ input_size=256,
20
+ render_size=192,
21
+ ):
22
+ super(MVRecon, self).__init__()
23
+
24
+ self.input_size = input_size
25
+ self.render_size = render_size
26
+
27
+ # init modules
28
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
29
+ if lrm_path is not None:
30
+ lrm_ckpt = torch.load(lrm_path)
31
+ self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
32
+
33
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
34
+
35
+ self.validation_step_outputs = []
36
+
37
+ def on_fit_start(self):
38
+ if self.global_rank == 0:
39
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
40
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
41
+
42
+ def prepare_batch_data(self, batch):
43
+ lrm_generator_input = {}
44
+ render_gt = {} # for supervision
45
+
46
+ # input images
47
+ images = batch['input_images']
48
+ images = v2.functional.resize(
49
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
50
+
51
+ lrm_generator_input['images'] = images.to(self.device)
52
+
53
+ # input cameras and render cameras
54
+ input_c2ws = batch['input_c2ws'].flatten(-2)
55
+ input_Ks = batch['input_Ks'].flatten(-2)
56
+ target_c2ws = batch['target_c2ws'].flatten(-2)
57
+ target_Ks = batch['target_Ks'].flatten(-2)
58
+ render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
59
+ render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
60
+ render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
61
+
62
+ input_extrinsics = input_c2ws[:, :, :12]
63
+ input_intrinsics = torch.stack([
64
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
65
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
66
+ ], dim=-1)
67
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
68
+
69
+ # add noise to input cameras
70
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
71
+
72
+ lrm_generator_input['cameras'] = cameras.to(self.device)
73
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
74
+
75
+ # target images
76
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
77
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
78
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
79
+
80
+ # random crop
81
+ render_size = np.random.randint(self.render_size, 513)
82
+ target_images = v2.functional.resize(
83
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
84
+ target_depths = v2.functional.resize(
85
+ target_depths, render_size, interpolation=0, antialias=True)
86
+ target_alphas = v2.functional.resize(
87
+ target_alphas, render_size, interpolation=0, antialias=True)
88
+
89
+ crop_params = v2.RandomCrop.get_params(
90
+ target_images, output_size=(self.render_size, self.render_size))
91
+ target_images = v2.functional.crop(target_images, *crop_params)
92
+ target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
93
+ target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
94
+
95
+ lrm_generator_input['render_size'] = render_size
96
+ lrm_generator_input['crop_params'] = crop_params
97
+
98
+ render_gt['target_images'] = target_images.to(self.device)
99
+ render_gt['target_depths'] = target_depths.to(self.device)
100
+ render_gt['target_alphas'] = target_alphas.to(self.device)
101
+
102
+ return lrm_generator_input, render_gt
103
+
104
+ def prepare_validation_batch_data(self, batch):
105
+ lrm_generator_input = {}
106
+
107
+ # input images
108
+ images = batch['input_images']
109
+ images = v2.functional.resize(
110
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
111
+
112
+ lrm_generator_input['images'] = images.to(self.device)
113
+
114
+ input_c2ws = batch['input_c2ws'].flatten(-2)
115
+ input_Ks = batch['input_Ks'].flatten(-2)
116
+
117
+ input_extrinsics = input_c2ws[:, :, :12]
118
+ input_intrinsics = torch.stack([
119
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
120
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
121
+ ], dim=-1)
122
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
123
+
124
+ lrm_generator_input['cameras'] = cameras.to(self.device)
125
+
126
+ render_c2ws = batch['render_c2ws'].flatten(-2)
127
+ render_Ks = batch['render_Ks'].flatten(-2)
128
+ render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
129
+
130
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
131
+ lrm_generator_input['render_size'] = 384
132
+ lrm_generator_input['crop_params'] = None
133
+
134
+ return lrm_generator_input
135
+
136
+ def forward_lrm_generator(
137
+ self,
138
+ images,
139
+ cameras,
140
+ render_cameras,
141
+ render_size=192,
142
+ crop_params=None,
143
+ chunk_size=1,
144
+ ):
145
+ planes = torch.utils.checkpoint.checkpoint(
146
+ self.lrm_generator.forward_planes,
147
+ images,
148
+ cameras,
149
+ use_reentrant=False,
150
+ )
151
+ frames = []
152
+ for i in range(0, render_cameras.shape[1], chunk_size):
153
+ frames.append(
154
+ torch.utils.checkpoint.checkpoint(
155
+ self.lrm_generator.synthesizer,
156
+ planes,
157
+ cameras=render_cameras[:, i:i+chunk_size],
158
+ render_size=render_size,
159
+ crop_params=crop_params,
160
+ use_reentrant=False
161
+ )
162
+ )
163
+ frames = {
164
+ k: torch.cat([r[k] for r in frames], dim=1)
165
+ for k in frames[0].keys()
166
+ }
167
+ return frames
168
+
169
+ def forward(self, lrm_generator_input):
170
+ images = lrm_generator_input['images']
171
+ cameras = lrm_generator_input['cameras']
172
+ render_cameras = lrm_generator_input['render_cameras']
173
+ render_size = lrm_generator_input['render_size']
174
+ crop_params = lrm_generator_input['crop_params']
175
+
176
+ out = self.forward_lrm_generator(
177
+ images,
178
+ cameras,
179
+ render_cameras,
180
+ render_size=render_size,
181
+ crop_params=crop_params,
182
+ chunk_size=1,
183
+ )
184
+ render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
185
+ render_depths = out['images_depth']
186
+ render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
187
+
188
+ out = {
189
+ 'render_images': render_images,
190
+ 'render_depths': render_depths,
191
+ 'render_alphas': render_alphas,
192
+ }
193
+ return out
194
+
195
+ def training_step(self, batch, batch_idx):
196
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
197
+
198
+ render_out = self.forward(lrm_generator_input)
199
+
200
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
201
+
202
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
203
+
204
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
205
+ B, N, C, H, W = render_gt['target_images'].shape
206
+ N_in = lrm_generator_input['images'].shape[1]
207
+
208
+ input_images = v2.functional.resize(
209
+ lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
210
+ input_images = torch.cat(
211
+ [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
212
+
213
+ input_images = rearrange(
214
+ input_images, 'b n c h w -> b c h (n w)')
215
+ target_images = rearrange(
216
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
217
+ render_images = rearrange(
218
+ render_out['render_images'], 'b n c h w -> b c h (n w)')
219
+ target_alphas = rearrange(
220
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
+ render_alphas = rearrange(
222
+ repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
223
+ target_depths = rearrange(
224
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
225
+ render_depths = rearrange(
226
+ repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
227
+ MAX_DEPTH = torch.max(target_depths)
228
+ target_depths = target_depths / MAX_DEPTH * target_alphas
229
+ render_depths = render_depths / MAX_DEPTH
230
+
231
+ grid = torch.cat([
232
+ input_images,
233
+ target_images, render_images,
234
+ target_alphas, render_alphas,
235
+ target_depths, render_depths,
236
+ ], dim=-2)
237
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
238
+
239
+ save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
240
+
241
+ return loss
242
+
243
+ def compute_loss(self, render_out, render_gt):
244
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
245
+ render_images = render_out['render_images']
246
+ target_images = render_gt['target_images'].to(render_images)
247
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
+
250
+ loss_mse = F.mse_loss(render_images, target_images)
251
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
252
+
253
+ render_alphas = render_out['render_alphas']
254
+ target_alphas = render_gt['target_alphas']
255
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
256
+
257
+ loss = loss_mse + loss_lpips + loss_mask
258
+
259
+ prefix = 'train'
260
+ loss_dict = {}
261
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
262
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
263
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
264
+ loss_dict.update({f'{prefix}/loss': loss})
265
+
266
+ return loss, loss_dict
267
+
268
+ @torch.no_grad()
269
+ def validation_step(self, batch, batch_idx):
270
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
271
+
272
+ render_out = self.forward(lrm_generator_input)
273
+ render_images = render_out['render_images']
274
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
275
+
276
+ self.validation_step_outputs.append(render_images)
277
+
278
+ def on_validation_epoch_end(self):
279
+ images = torch.cat(self.validation_step_outputs, dim=-1)
280
+
281
+ all_images = self.all_gather(images)
282
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
283
+
284
+ if self.global_rank == 0:
285
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
286
+
287
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
288
+ save_image(grid, image_path)
289
+ print(f"Saved image to {image_path}")
290
+
291
+ self.validation_step_outputs.clear()
292
+
293
+ def configure_optimizers(self):
294
+ lr = self.learning_rate
295
+
296
+ params = []
297
+
298
+ lrm_params_fast, lrm_params_slow = [], []
299
+ for n, p in self.lrm_generator.named_parameters():
300
+ if 'adaLN_modulation' in n or 'camera_embedder' in n:
301
+ lrm_params_fast.append(p)
302
+ else:
303
+ lrm_params_slow.append(p)
304
+ params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
305
+ params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
306
+
307
+ optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
308
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
309
+
310
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/model_mesh.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import v2
6
+ from torchvision.utils import make_grid, save_image
7
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
8
+ import pytorch_lightning as pl
9
+ from einops import rearrange, repeat
10
+
11
+ from src.utils.train_util import instantiate_from_config
12
+
13
+
14
+ # Regulrarization loss for FlexiCubes
15
+ def sdf_reg_loss_batch(sdf, all_edges):
16
+ sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
17
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
18
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
19
+ sdf_diff = F.binary_cross_entropy_with_logits(
20
+ sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
21
+ F.binary_cross_entropy_with_logits(
22
+ sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
23
+ return sdf_diff
24
+
25
+
26
+ class MVRecon(pl.LightningModule):
27
+ def __init__(
28
+ self,
29
+ lrm_generator_config,
30
+ input_size=256,
31
+ render_size=512,
32
+ init_ckpt=None,
33
+ ):
34
+ super(MVRecon, self).__init__()
35
+
36
+ self.input_size = input_size
37
+ self.render_size = render_size
38
+
39
+ # init modules
40
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
41
+
42
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
43
+
44
+ # Load weights from pretrained MVRecon model, and use the mlp
45
+ # weights to initialize the weights of sdf and rgb mlps.
46
+ if init_ckpt is not None:
47
+ sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
48
+ sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
49
+ sd_fc = {}
50
+ for k, v in sd.items():
51
+ if k.startswith('lrm_generator.synthesizer.decoder.net.'):
52
+ if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
53
+ # Here we assume the density filed's isosurface threshold is t,
54
+ # we reverse the sign of density filed to initialize SDF field.
55
+ # -(w*x + b - t) = (-w)*x + (t - b)
56
+ if 'weight' in k:
57
+ sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
58
+ else:
59
+ sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
60
+ sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
61
+ else:
62
+ sd_fc[k.replace('net.', 'net_sdf.')] = v
63
+ sd_fc[k.replace('net.', 'net_rgb.')] = v
64
+ else:
65
+ sd_fc[k] = v
66
+ sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
67
+ # missing `net_deformation` and `net_weight` parameters
68
+ self.lrm_generator.load_state_dict(sd_fc, strict=False)
69
+ print(f'Loaded weights from {init_ckpt}')
70
+
71
+ self.validation_step_outputs = []
72
+
73
+ def on_fit_start(self):
74
+ device = torch.device(f'cuda:{self.global_rank}')
75
+ self.lrm_generator.init_flexicubes_geometry(device)
76
+ if self.global_rank == 0:
77
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
78
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
79
+
80
+ def prepare_batch_data(self, batch):
81
+ lrm_generator_input = {}
82
+ render_gt = {}
83
+
84
+ # input images
85
+ images = batch['input_images']
86
+ images = v2.functional.resize(
87
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
88
+
89
+ lrm_generator_input['images'] = images.to(self.device)
90
+
91
+ # input cameras and render cameras
92
+ input_c2ws = batch['input_c2ws']
93
+ input_Ks = batch['input_Ks']
94
+ target_c2ws = batch['target_c2ws']
95
+
96
+ render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
97
+ render_w2cs = torch.linalg.inv(render_c2ws)
98
+
99
+ input_extrinsics = input_c2ws.flatten(-2)
100
+ input_extrinsics = input_extrinsics[:, :, :12]
101
+ input_intrinsics = input_Ks.flatten(-2)
102
+ input_intrinsics = torch.stack([
103
+ input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
104
+ input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
105
+ ], dim=-1)
106
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
107
+
108
+ # add noise to input_cameras
109
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
110
+
111
+ lrm_generator_input['cameras'] = cameras.to(self.device)
112
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
113
+
114
+ # target images
115
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
116
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
117
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
118
+ target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
119
+
120
+ render_size = self.render_size
121
+ target_images = v2.functional.resize(
122
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
123
+ target_depths = v2.functional.resize(
124
+ target_depths, render_size, interpolation=0, antialias=True)
125
+ target_alphas = v2.functional.resize(
126
+ target_alphas, render_size, interpolation=0, antialias=True)
127
+ target_normals = v2.functional.resize(
128
+ target_normals, render_size, interpolation=3, antialias=True)
129
+
130
+ lrm_generator_input['render_size'] = render_size
131
+
132
+ render_gt['target_images'] = target_images.to(self.device)
133
+ render_gt['target_depths'] = target_depths.to(self.device)
134
+ render_gt['target_alphas'] = target_alphas.to(self.device)
135
+ render_gt['target_normals'] = target_normals.to(self.device)
136
+
137
+ return lrm_generator_input, render_gt
138
+
139
+ def prepare_validation_batch_data(self, batch):
140
+ lrm_generator_input = {}
141
+
142
+ # input images
143
+ images = batch['input_images']
144
+ images = v2.functional.resize(
145
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
146
+
147
+ lrm_generator_input['images'] = images.to(self.device)
148
+
149
+ # input cameras
150
+ input_c2ws = batch['input_c2ws'].flatten(-2)
151
+ input_Ks = batch['input_Ks'].flatten(-2)
152
+
153
+ input_extrinsics = input_c2ws[:, :, :12]
154
+ input_intrinsics = torch.stack([
155
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
156
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
157
+ ], dim=-1)
158
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
159
+
160
+ lrm_generator_input['cameras'] = cameras.to(self.device)
161
+
162
+ # render cameras
163
+ render_c2ws = batch['render_c2ws']
164
+ render_w2cs = torch.linalg.inv(render_c2ws)
165
+
166
+ lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
167
+ lrm_generator_input['render_size'] = 384
168
+
169
+ return lrm_generator_input
170
+
171
+ def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
172
+ planes = torch.utils.checkpoint.checkpoint(
173
+ self.lrm_generator.forward_planes,
174
+ images,
175
+ cameras,
176
+ use_reentrant=False,
177
+ )
178
+ out = self.lrm_generator.forward_geometry(
179
+ planes,
180
+ render_cameras,
181
+ render_size,
182
+ )
183
+ return out
184
+
185
+ def forward(self, lrm_generator_input):
186
+ images = lrm_generator_input['images']
187
+ cameras = lrm_generator_input['cameras']
188
+ render_cameras = lrm_generator_input['render_cameras']
189
+ render_size = lrm_generator_input['render_size']
190
+
191
+ out = self.forward_lrm_generator(
192
+ images, cameras, render_cameras, render_size=render_size)
193
+
194
+ return out
195
+
196
+ def training_step(self, batch, batch_idx):
197
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
198
+
199
+ render_out = self.forward(lrm_generator_input)
200
+
201
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
202
+
203
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
204
+
205
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
206
+ B, N, C, H, W = render_gt['target_images'].shape
207
+ N_in = lrm_generator_input['images'].shape[1]
208
+
209
+ target_images = rearrange(
210
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
211
+ render_images = rearrange(
212
+ render_out['img'], 'b n c h w -> b c h (n w)')
213
+ target_alphas = rearrange(
214
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
215
+ render_alphas = rearrange(
216
+ repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
217
+ target_depths = rearrange(
218
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
219
+ render_depths = rearrange(
220
+ repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
221
+ target_normals = rearrange(
222
+ render_gt['target_normals'], 'b n c h w -> b c h (n w)')
223
+ render_normals = rearrange(
224
+ render_out['normal'], 'b n c h w -> b c h (n w)')
225
+ MAX_DEPTH = torch.max(target_depths)
226
+ target_depths = target_depths / MAX_DEPTH * target_alphas
227
+ render_depths = render_depths / MAX_DEPTH
228
+
229
+ grid = torch.cat([
230
+ target_images, render_images,
231
+ target_alphas, render_alphas,
232
+ target_depths, render_depths,
233
+ target_normals, render_normals,
234
+ ], dim=-2)
235
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
236
+
237
+ image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
238
+ save_image(grid, image_path)
239
+ print(f"Saved image to {image_path}")
240
+
241
+ return loss
242
+
243
+ def compute_loss(self, render_out, render_gt):
244
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
245
+ render_images = render_out['img']
246
+ target_images = render_gt['target_images'].to(render_images)
247
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
248
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
249
+ loss_mse = F.mse_loss(render_images, target_images)
250
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
251
+
252
+ render_alphas = render_out['mask']
253
+ target_alphas = render_gt['target_alphas']
254
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
255
+
256
+ render_depths = render_out['depth']
257
+ target_depths = render_gt['target_depths']
258
+ loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
259
+
260
+ render_normals = render_out['normal'] * 2.0 - 1.0
261
+ target_normals = render_gt['target_normals'] * 2.0 - 1.0
262
+ similarity = (render_normals * target_normals).sum(dim=-3).abs()
263
+ normal_mask = target_alphas.squeeze(-3)
264
+ loss_normal = 1 - similarity[normal_mask>0].mean()
265
+ loss_normal = 0.2 * loss_normal
266
+
267
+ # flexicubes regularization loss
268
+ sdf = render_out['sdf']
269
+ sdf_reg_loss = render_out['sdf_reg_loss']
270
+ sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
271
+ _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
272
+ flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
273
+ flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
274
+
275
+ loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
276
+
277
+ loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
278
+
279
+ prefix = 'train'
280
+ loss_dict = {}
281
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
282
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
283
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
284
+ loss_dict.update({f'{prefix}/loss_normal': loss_normal})
285
+ loss_dict.update({f'{prefix}/loss_depth': loss_depth})
286
+ loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
287
+ loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
288
+ loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
289
+ loss_dict.update({f'{prefix}/loss': loss})
290
+
291
+ return loss, loss_dict
292
+
293
+ @torch.no_grad()
294
+ def validation_step(self, batch, batch_idx):
295
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
296
+
297
+ render_out = self.forward(lrm_generator_input)
298
+ render_images = render_out['img']
299
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
300
+
301
+ self.validation_step_outputs.append(render_images)
302
+
303
+ def on_validation_epoch_end(self):
304
+ images = torch.cat(self.validation_step_outputs, dim=-1)
305
+
306
+ all_images = self.all_gather(images)
307
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
308
+
309
+ if self.global_rank == 0:
310
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
311
+
312
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
313
+ save_image(grid, image_path)
314
+ print(f"Saved image to {image_path}")
315
+
316
+ self.validation_step_outputs.clear()
317
+
318
+ def configure_optimizers(self):
319
+ lr = self.learning_rate
320
+
321
+ optimizer = torch.optim.AdamW(
322
+ self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
323
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
324
+
325
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}
src/models/__init__.py ADDED
File without changes
src/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (164 Bytes). View file
 
src/models/__pycache__/lrm_mesh.cpython-311.pyc ADDED
Binary file (21.8 kB). View file
 
src/models/decoder/__init__.py ADDED
File without changes