Maikou committed
Commit
9c3a994
1 Parent(s): e83e789

all files first commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. checkpoints/aligned_shape_latents/shapevae-256.ckpt +3 -0
  2. checkpoints/clip/clip-vit-large-patch14 +1 -0
  3. checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt +3 -0
  4. checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt +3 -0
  5. configs/aligned_shape_latents/shapevae-256.yaml +46 -0
  6. configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml +181 -0
  7. configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml +180 -0
  8. configs/image_cond_diffuser_asl/image-ASLDM-256.yaml +97 -0
  9. configs/text_cond_diffuser_asl/text-ASLDM-256.yaml +98 -0
  10. example_data/image/car.jpg +0 -0
  11. example_data/surface/surface.npz +3 -0
  12. gradio_app.py +372 -0
  13. gradio_cached_dir/example/img_example/airplane.jpg +0 -0
  14. gradio_cached_dir/example/img_example/alita.jpg +0 -0
  15. gradio_cached_dir/example/img_example/bag.jpg +0 -0
  16. gradio_cached_dir/example/img_example/bench.jpg +0 -0
  17. gradio_cached_dir/example/img_example/building.jpg +0 -0
  18. gradio_cached_dir/example/img_example/burger.jpg +0 -0
  19. gradio_cached_dir/example/img_example/car.jpg +0 -0
  20. gradio_cached_dir/example/img_example/loopy.jpg +0 -0
  21. gradio_cached_dir/example/img_example/mario.jpg +0 -0
  22. gradio_cached_dir/example/img_example/ship.jpg +0 -0
  23. inference.py +181 -0
  24. michelangelo/__init__.py +1 -0
  25. michelangelo/__pycache__/__init__.cpython-39.pyc +0 -0
  26. michelangelo/data/__init__.py +1 -0
  27. michelangelo/data/__pycache__/__init__.cpython-39.pyc +0 -0
  28. michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc +0 -0
  29. michelangelo/data/__pycache__/tokenizer.cpython-39.pyc +0 -0
  30. michelangelo/data/__pycache__/transforms.cpython-39.pyc +0 -0
  31. michelangelo/data/__pycache__/utils.cpython-39.pyc +0 -0
  32. michelangelo/data/templates.json +69 -0
  33. michelangelo/data/transforms.py +407 -0
  34. michelangelo/data/utils.py +59 -0
  35. michelangelo/graphics/__init__.py +1 -0
  36. michelangelo/graphics/__pycache__/__init__.cpython-39.pyc +0 -0
  37. michelangelo/graphics/primitives/__init__.py +9 -0
  38. michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc +0 -0
  39. michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc +0 -0
  40. michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc +0 -0
  41. michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc +0 -0
  42. michelangelo/graphics/primitives/mesh.py +114 -0
  43. michelangelo/graphics/primitives/volume.py +21 -0
  44. michelangelo/models/__init__.py +1 -0
  45. michelangelo/models/__pycache__/__init__.cpython-39.pyc +0 -0
  46. michelangelo/models/asl_diffusion/__init__.py +1 -0
  47. michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  48. michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc +0 -0
  49. michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc +0 -0
  50. michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc +0 -0
checkpoints/aligned_shape_latents/shapevae-256.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0391b81c36240e8f766fedf4265df599884193a5ef65354525074b9a00887454
+ size 3934164973
checkpoints/clip/clip-vit-large-patch14 ADDED
@@ -0,0 +1 @@
+ Subproject commit 8d052a0f05efbaefbc9e8786ba291cfdf93e5bff
checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83eda8e4f81034dee7674b3ce1ff03a4900181f0f0d7bc461e1a8692fb379b0f
+ size 1999253985
checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af546b1f877a41d71f63c3a11394779e77c954002c50dc8e75359338224f615b
+ size 4076140813
configs/aligned_shape_latents/shapevae-256.yaml ADDED
@@ -0,0 +1,46 @@
+ model:
+   target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+   params:
+     shape_module_cfg:
+       target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+       params:
+         num_latents: 256
+         embed_dim: 64
+         point_feats: 3   # normal
+         num_freqs: 8
+         include_pi: false
+         heads: 12
+         width: 768
+         num_encoder_layers: 8
+         num_decoder_layers: 16
+         use_ln_post: true
+         init_scale: 0.25
+         qkv_bias: false
+         use_checkpoint: true
+     aligned_module_cfg:
+       target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+       params:
+         clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+     loss_cfg:
+       target: michelangelo.models.tsal.loss.ContrastKLNearFar
+       params:
+         contrast_weight: 0.1
+         near_weight: 0.1
+         kl_weight: 0.001
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
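
The config above follows the repo-wide target/params convention. As a minimal sketch (not part of this commit; it mirrors how inference.py and gradio_app.py below load their configs, and the checkpoint path is the one added in this commit), the ShapeVAE can be built like this:

# Sketch only: resolve a target/params YAML into a module, as inference.py does below.
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config

config = get_config_from_file("./configs/aligned_shape_latents/shapevae-256.yaml")
if hasattr(config, "model"):
    config = config.model  # everything in the YAML sits under the top-level `model:` key

# instantiate_from_config imports `target` and passes the `params` block as keyword arguments.
vae = instantiate_from_config(config, ckpt_path="./checkpoints/aligned_shape_latents/shapevae-256.ckpt")
vae = vae.eval()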
configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml ADDED
@@ -0,0 +1,181 @@
1
+ name: "0630_clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000"
2
+ #wandb:
3
+ # project: "image_diffuser"
4
+ # offline: false
5
+
6
+
7
+ training:
8
+ steps: 500000
9
+ use_amp: true
10
+ ckpt_path: ""
11
+ base_lr: 1.e-4
12
+ gradient_clip_val: 5.0
13
+ gradient_clip_algorithm: "norm"
14
+ every_n_train_steps: 5000
15
+ val_check_interval: 1024
16
+ limit_val_batches: 16
17
+
18
+ dataset:
19
+ target: michelangelo.data.asl_webdataset.MultiAlignedShapeLatentModule
20
+ params:
21
+ batch_size: 38
22
+ num_workers: 4
23
+ val_num_workers: 4
24
+ buffer_size: 256
25
+ return_normal: true
26
+ random_crop: false
27
+ surface_sampling: true
28
+ pc_size: &pc_size 4096
29
+ image_size: 384
30
+ mean: &mean [0.5, 0.5, 0.5]
31
+ std: &std [0.5, 0.5, 0.5]
32
+ cond_stage_key: "image"
33
+
34
+ meta_info:
35
+ 3D-FUTURE:
36
+ render_folder: "/root/workspace/cq_workspace/datasets/3D-FUTURE/renders"
37
+ tar_folder: "/root/workspace/datasets/make_tars/3D-FUTURE"
38
+
39
+ ABO:
40
+ render_folder: "/root/workspace/cq_workspace/datasets/ABO/renders"
41
+ tar_folder: "/root/workspace/datasets/make_tars/ABO"
42
+
43
+ GSO:
44
+ render_folder: "/root/workspace/cq_workspace/datasets/GSO/renders"
45
+ tar_folder: "/root/workspace/datasets/make_tars/GSO"
46
+
47
+ TOYS4K:
48
+ render_folder: "/root/workspace/cq_workspace/datasets/TOYS4K/TOYS4K/renders"
49
+ tar_folder: "/root/workspace/datasets/make_tars/TOYS4K"
50
+
51
+ 3DCaricShop:
52
+ render_folder: "/root/workspace/cq_workspace/datasets/3DCaricShop/renders"
53
+ tar_folder: "/root/workspace/datasets/make_tars/3DCaricShop"
54
+
55
+ Thingi10K:
56
+ render_folder: "/root/workspace/cq_workspace/datasets/Thingi10K/renders"
57
+ tar_folder: "/root/workspace/datasets/make_tars/Thingi10K"
58
+
59
+ shapenet:
60
+ render_folder: "/root/workspace/cq_workspace/datasets/shapenet/renders"
61
+ tar_folder: "/root/workspace/datasets/make_tars/shapenet"
62
+
63
+ pokemon:
64
+ render_folder: "/root/workspace/cq_workspace/datasets/pokemon/renders"
65
+ tar_folder: "/root/workspace/datasets/make_tars/pokemon"
66
+
67
+ objaverse:
68
+ render_folder: "/root/workspace/cq_workspace/datasets/objaverse/renders"
69
+ tar_folder: "/root/workspace/datasets/make_tars/objaverse"
70
+
71
+ model:
72
+ target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
73
+ params:
74
+ first_stage_config:
75
+ target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
76
+ params:
77
+ shape_module_cfg:
78
+ target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
79
+ params:
80
+ num_latents: &num_latents 256
81
+ embed_dim: &embed_dim 64
82
+ point_feats: 3 # normal
83
+ num_freqs: 8
84
+ include_pi: false
85
+ heads: 12
86
+ width: 768
87
+ num_encoder_layers: 8
88
+ num_decoder_layers: 16
89
+ use_ln_post: true
90
+ init_scale: 0.25
91
+ qkv_bias: false
92
+ use_checkpoint: false
93
+ aligned_module_cfg:
94
+ target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
95
+ params:
96
+ clip_model_version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
97
+ # clip_model_version: "/root/workspace/checkpoints/clip/clip-vit-large-patch14"
98
+
99
+ loss_cfg:
100
+ target: torch.nn.Identity
101
+
102
+ cond_stage_config:
103
+ target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
104
+ params:
105
+ version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
106
+ # version: "/root/workspace/checkpoints/clip/clip-vit-large-patch14"
107
+ zero_embedding_radio: 0.1
108
+
109
+ first_stage_key: "surface"
110
+ cond_stage_key: "image"
111
+ scale_by_std: false
112
+
113
+ denoiser_cfg:
114
+ target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
115
+ params:
116
+ input_channels: *embed_dim
117
+ output_channels: *embed_dim
118
+ n_ctx: *num_latents
119
+ width: 768
120
+ layers: 6 # 2 * 6 + 1 = 13
121
+ heads: 12
122
+ context_dim: 1024
123
+ init_scale: 1.0
124
+ skip_ln: true
125
+ use_checkpoint: true
126
+
127
+ scheduler_cfg:
128
+ guidance_scale: 7.5
129
+ num_inference_steps: 50
130
+ eta: 0.0
131
+
132
+ noise:
133
+ target: diffusers.schedulers.DDPMScheduler
134
+ params:
135
+ num_train_timesteps: 1000
136
+ beta_start: 0.00085
137
+ beta_end: 0.012
138
+ beta_schedule: "scaled_linear"
139
+ variance_type: "fixed_small"
140
+ clip_sample: false
141
+ denoise:
142
+ target: diffusers.schedulers.DDIMScheduler
143
+ params:
144
+ num_train_timesteps: 1000
145
+ beta_start: 0.00085
146
+ beta_end: 0.012
147
+ beta_schedule: "scaled_linear"
148
+ clip_sample: false # clip sample to -1~1
149
+ set_alpha_to_one: false
150
+ steps_offset: 1
151
+
152
+ optimizer_cfg:
153
+ optimizer:
154
+ target: torch.optim.AdamW
155
+ params:
156
+ betas: [0.9, 0.99]
157
+ eps: 1.e-6
158
+ weight_decay: 1.e-2
159
+
160
+ scheduler:
161
+ target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
162
+ params:
163
+ warm_up_steps: 5000
164
+ f_start: 1.e-6
165
+ f_min: 1.e-3
166
+ f_max: 1.0
167
+
168
+ loss_cfg:
169
+ loss_type: "mse"
170
+
171
+ logger:
172
+ target: michelangelo.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
173
+ params:
174
+ step_frequency: 2000
175
+ num_samples: 4
176
+ sample_times: 4
177
+ mean: *mean
178
+ std: *std
179
+ bounds: [-1.1, -1.1, -1.1, 1.1, 1.1, 1.1]
180
+ octree_depth: 7
181
+ num_chunks: 10000
configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml ADDED
@@ -0,0 +1,180 @@
1
+ name: "0428_clip_subsp+pk_sal_perceiver=256_01_4096_8_udt=03"
2
+ #wandb:
3
+ # project: "image_diffuser"
4
+ # offline: false
5
+
6
+ training:
7
+ steps: 500000
8
+ use_amp: true
9
+ ckpt_path: ""
10
+ base_lr: 1.e-4
11
+ gradient_clip_val: 5.0
12
+ gradient_clip_algorithm: "norm"
13
+ every_n_train_steps: 5000
14
+ val_check_interval: 1024
15
+ limit_val_batches: 16
16
+
17
+ # dataset
18
+ dataset:
19
+ target: michelangelo.data.asl_torch_dataset.MultiAlignedShapeImageTextModule
20
+ params:
21
+ batch_size: 38
22
+ num_workers: 4
23
+ val_num_workers: 4
24
+ buffer_size: 256
25
+ return_normal: true
26
+ random_crop: false
27
+ surface_sampling: true
28
+ pc_size: &pc_size 4096
29
+ image_size: 384
30
+ mean: &mean [0.5, 0.5, 0.5]
31
+ std: &std [0.5, 0.5, 0.5]
32
+
33
+ cond_stage_key: "text"
34
+
35
+ meta_info:
36
+ 3D-FUTURE:
37
+ render_folder: "/root/workspace/cq_workspace/datasets/3D-FUTURE/renders"
38
+ tar_folder: "/root/workspace/datasets/make_tars/3D-FUTURE"
39
+
40
+ ABO:
41
+ render_folder: "/root/workspace/cq_workspace/datasets/ABO/renders"
42
+ tar_folder: "/root/workspace/datasets/make_tars/ABO"
43
+
44
+ GSO:
45
+ render_folder: "/root/workspace/cq_workspace/datasets/GSO/renders"
46
+ tar_folder: "/root/workspace/datasets/make_tars/GSO"
47
+
48
+ TOYS4K:
49
+ render_folder: "/root/workspace/cq_workspace/datasets/TOYS4K/TOYS4K/renders"
50
+ tar_folder: "/root/workspace/datasets/make_tars/TOYS4K"
51
+
52
+ 3DCaricShop:
53
+ render_folder: "/root/workspace/cq_workspace/datasets/3DCaricShop/renders"
54
+ tar_folder: "/root/workspace/datasets/make_tars/3DCaricShop"
55
+
56
+ Thingi10K:
57
+ render_folder: "/root/workspace/cq_workspace/datasets/Thingi10K/renders"
58
+ tar_folder: "/root/workspace/datasets/make_tars/Thingi10K"
59
+
60
+ shapenet:
61
+ render_folder: "/root/workspace/cq_workspace/datasets/shapenet/renders"
62
+ tar_folder: "/root/workspace/datasets/make_tars/shapenet"
63
+
64
+ pokemon:
65
+ render_folder: "/root/workspace/cq_workspace/datasets/pokemon/renders"
66
+ tar_folder: "/root/workspace/datasets/make_tars/pokemon"
67
+
68
+ objaverse:
69
+ render_folder: "/root/workspace/cq_workspace/datasets/objaverse/renders"
70
+ tar_folder: "/root/workspace/datasets/make_tars/objaverse"
71
+
72
+ model:
73
+ target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
74
+ params:
75
+ first_stage_config:
76
+ target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
77
+ params:
78
+ # ckpt_path: "/root/workspace/cq_workspace/michelangelo/experiments/aligned_shape_latents/clip_aslperceiver_sp+pk_01_01/ckpt/ckpt-step=00230000.ckpt"
79
+ shape_module_cfg:
80
+ target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
81
+ params:
82
+ num_latents: &num_latents 256
83
+ embed_dim: &embed_dim 64
84
+ point_feats: 3 # normal
85
+ num_freqs: 8
86
+ include_pi: false
87
+ heads: 12
88
+ width: 768
89
+ num_encoder_layers: 8
90
+ num_decoder_layers: 16
91
+ use_ln_post: true
92
+ init_scale: 0.25
93
+ qkv_bias: false
94
+ use_checkpoint: true
95
+ aligned_module_cfg:
96
+ target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
97
+ params:
98
+ clip_model_version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
99
+
100
+ loss_cfg:
101
+ target: torch.nn.Identity
102
+
103
+ cond_stage_config:
104
+ target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
105
+ params:
106
+ version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
107
+ zero_embedding_radio: 0.1
108
+ max_length: 77
109
+
110
+ first_stage_key: "surface"
111
+ cond_stage_key: "text"
112
+ scale_by_std: false
113
+
114
+ denoiser_cfg:
115
+ target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
116
+ params:
117
+ input_channels: *embed_dim
118
+ output_channels: *embed_dim
119
+ n_ctx: *num_latents
120
+ width: 768
121
+ layers: 8 # 2 * 6 + 1 = 13
122
+ heads: 12
123
+ context_dim: 768
124
+ init_scale: 1.0
125
+ skip_ln: true
126
+ use_checkpoint: true
127
+
128
+ scheduler_cfg:
129
+ guidance_scale: 7.5
130
+ num_inference_steps: 50
131
+ eta: 0.0
132
+
133
+ noise:
134
+ target: diffusers.schedulers.DDPMScheduler
135
+ params:
136
+ num_train_timesteps: 1000
137
+ beta_start: 0.00085
138
+ beta_end: 0.012
139
+ beta_schedule: "scaled_linear"
140
+ variance_type: "fixed_small"
141
+ clip_sample: false
142
+ denoise:
143
+ target: diffusers.schedulers.DDIMScheduler
144
+ params:
145
+ num_train_timesteps: 1000
146
+ beta_start: 0.00085
147
+ beta_end: 0.012
148
+ beta_schedule: "scaled_linear"
149
+ clip_sample: false # clip sample to -1~1
150
+ set_alpha_to_one: false
151
+ steps_offset: 1
152
+
153
+ optimizer_cfg:
154
+ optimizer:
155
+ target: torch.optim.AdamW
156
+ params:
157
+ betas: [0.9, 0.99]
158
+ eps: 1.e-6
159
+ weight_decay: 1.e-2
160
+
161
+ scheduler:
162
+ target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
163
+ params:
164
+ warm_up_steps: 5000
165
+ f_start: 1.e-6
166
+ f_min: 1.e-3
167
+ f_max: 1.0
168
+
169
+ loss_cfg:
170
+ loss_type: "mse"
171
+
172
+ logger:
173
+ target: michelangelo.utils.trainings.mesh_log_callback.TextConditionalASLDiffuserLogger
174
+ params:
175
+ step_frequency: 1000
176
+ num_samples: 4
177
+ sample_times: 4
178
+ bounds: [-1.1, -1.1, -1.1, 1.1, 1.1, 1.1]
179
+ octree_depth: 7
180
+ num_chunks: 10000
configs/image_cond_diffuser_asl/image-ASLDM-256.yaml ADDED
@@ -0,0 +1,97 @@
1
+ model:
2
+ target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
3
+ params:
4
+ first_stage_config:
5
+ target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
6
+ params:
7
+ shape_module_cfg:
8
+ target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
9
+ params:
10
+ num_latents: &num_latents 256
11
+ embed_dim: &embed_dim 64
12
+ point_feats: 3 # normal
13
+ num_freqs: 8
14
+ include_pi: false
15
+ heads: 12
16
+ width: 768
17
+ num_encoder_layers: 8
18
+ num_decoder_layers: 16
19
+ use_ln_post: true
20
+ init_scale: 0.25
21
+ qkv_bias: false
22
+ use_checkpoint: false
23
+ aligned_module_cfg:
24
+ target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
25
+ params:
26
+ clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
27
+
28
+ loss_cfg:
29
+ target: torch.nn.Identity
30
+
31
+ cond_stage_config:
32
+ target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
33
+ params:
34
+ version: "./checkpoints/clip/clip-vit-large-patch14"
35
+ zero_embedding_radio: 0.1
36
+
37
+ first_stage_key: "surface"
38
+ cond_stage_key: "image"
39
+ scale_by_std: false
40
+
41
+ denoiser_cfg:
42
+ target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
43
+ params:
44
+ input_channels: *embed_dim
45
+ output_channels: *embed_dim
46
+ n_ctx: *num_latents
47
+ width: 768
48
+ layers: 6 # 2 * 6 + 1 = 13
49
+ heads: 12
50
+ context_dim: 1024
51
+ init_scale: 1.0
52
+ skip_ln: true
53
+ use_checkpoint: true
54
+
55
+ scheduler_cfg:
56
+ guidance_scale: 7.5
57
+ num_inference_steps: 50
58
+ eta: 0.0
59
+
60
+ noise:
61
+ target: diffusers.schedulers.DDPMScheduler
62
+ params:
63
+ num_train_timesteps: 1000
64
+ beta_start: 0.00085
65
+ beta_end: 0.012
66
+ beta_schedule: "scaled_linear"
67
+ variance_type: "fixed_small"
68
+ clip_sample: false
69
+ denoise:
70
+ target: diffusers.schedulers.DDIMScheduler
71
+ params:
72
+ num_train_timesteps: 1000
73
+ beta_start: 0.00085
74
+ beta_end: 0.012
75
+ beta_schedule: "scaled_linear"
76
+ clip_sample: false # clip sample to -1~1
77
+ set_alpha_to_one: false
78
+ steps_offset: 1
79
+
80
+ optimizer_cfg:
81
+ optimizer:
82
+ target: torch.optim.AdamW
83
+ params:
84
+ betas: [0.9, 0.99]
85
+ eps: 1.e-6
86
+ weight_decay: 1.e-2
87
+
88
+ scheduler:
89
+ target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
90
+ params:
91
+ warm_up_steps: 5000
92
+ f_start: 1.e-6
93
+ f_min: 1.e-3
94
+ f_max: 1.0
95
+
96
+ loss_cfg:
97
+ loss_type: "mse"
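
This is the config that gradio_app.py and inference.py below pair with checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt. A minimal sketch of image-conditioned sampling with it (not part of this commit; it condenses image2mesh() from gradio_app.py, and the example image is the one under example_data/image/):

# Sketch only: sample meshes from the image-conditioned ASLDM described by this config.
import cv2
import torch
from einops import rearrange, repeat
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config

config = get_config_from_file("./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml")
if hasattr(config, "model"):
    config = config.model
model = instantiate_from_config(config, ckpt_path="./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt")
model = model.cuda().eval()

image = cv2.cvtColor(cv2.imread("./example_data/image/car.jpg"), cv2.COLOR_BGR2RGB)
image_pt = rearrange(torch.tensor(image).float() / 255 * 2 - 1, "h w c -> c h w")
image_pt = repeat(image_pt, "c h w -> b c h w", b=4)  # 4 samples for one conditioning image

mesh_outputs = model.sample(
    {"image": image_pt},
    sample_times=1,
    guidance_scale=7.5,
    return_intermediates=False,
    bounds=[-1.1, -1.1, -1.1, 1.1, 1.1, 1.1],
    octree_depth=7,
)[0]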
configs/text_cond_diffuser_asl/text-ASLDM-256.yaml ADDED
@@ -0,0 +1,98 @@
1
+ model:
2
+ target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
3
+ params:
4
+ first_stage_config:
5
+ target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
6
+ params:
7
+ shape_module_cfg:
8
+ target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
9
+ params:
10
+ num_latents: &num_latents 256
11
+ embed_dim: &embed_dim 64
12
+ point_feats: 3 # normal
13
+ num_freqs: 8
14
+ include_pi: false
15
+ heads: 12
16
+ width: 768
17
+ num_encoder_layers: 8
18
+ num_decoder_layers: 16
19
+ use_ln_post: true
20
+ init_scale: 0.25
21
+ qkv_bias: false
22
+ use_checkpoint: true
23
+ aligned_module_cfg:
24
+ target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
25
+ params:
26
+ clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
27
+
28
+ loss_cfg:
29
+ target: torch.nn.Identity
30
+
31
+ cond_stage_config:
32
+ target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
33
+ params:
34
+ version: "./checkpoints/clip/clip-vit-large-patch14"
35
+ zero_embedding_radio: 0.1
36
+ max_length: 77
37
+
38
+ first_stage_key: "surface"
39
+ cond_stage_key: "text"
40
+ scale_by_std: false
41
+
42
+ denoiser_cfg:
43
+ target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
44
+ params:
45
+ input_channels: *embed_dim
46
+ output_channels: *embed_dim
47
+ n_ctx: *num_latents
48
+ width: 768
49
+ layers: 8 # 2 * 6 + 1 = 13
50
+ heads: 12
51
+ context_dim: 768
52
+ init_scale: 1.0
53
+ skip_ln: true
54
+ use_checkpoint: true
55
+
56
+ scheduler_cfg:
57
+ guidance_scale: 7.5
58
+ num_inference_steps: 50
59
+ eta: 0.0
60
+
61
+ noise:
62
+ target: diffusers.schedulers.DDPMScheduler
63
+ params:
64
+ num_train_timesteps: 1000
65
+ beta_start: 0.00085
66
+ beta_end: 0.012
67
+ beta_schedule: "scaled_linear"
68
+ variance_type: "fixed_small"
69
+ clip_sample: false
70
+ denoise:
71
+ target: diffusers.schedulers.DDIMScheduler
72
+ params:
73
+ num_train_timesteps: 1000
74
+ beta_start: 0.00085
75
+ beta_end: 0.012
76
+ beta_schedule: "scaled_linear"
77
+ clip_sample: false # clip sample to -1~1
78
+ set_alpha_to_one: false
79
+ steps_offset: 1
80
+
81
+ optimizer_cfg:
82
+ optimizer:
83
+ target: torch.optim.AdamW
84
+ params:
85
+ betas: [0.9, 0.99]
86
+ eps: 1.e-6
87
+ weight_decay: 1.e-2
88
+
89
+ scheduler:
90
+ target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
91
+ params:
92
+ warm_up_steps: 5000
93
+ f_start: 1.e-6
94
+ f_min: 1.e-3
95
+ f_max: 1.0
96
+
97
+ loss_cfg:
98
+ loss_type: "mse"
example_data/image/car.jpg ADDED
example_data/surface/surface.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0893e44d82ada683baa656a718beaf6ec19fc28b6816b451f56645530d5bb962
+ size 1201024
gradio_app.py ADDED
@@ -0,0 +1,372 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ from collections import OrderedDict
5
+ from PIL import Image
6
+ import torch
7
+ import trimesh
8
+ from typing import Optional, List
9
+ from einops import repeat, rearrange
10
+ import numpy as np
11
+ from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
12
+ from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
13
+ from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
14
+ from michelangelo.utils.visualizers import html_util
15
+
16
+ import gradio as gr
17
+
18
+
19
+ gradio_cached_dir = "./gradio_cached_dir"
20
+ os.makedirs(gradio_cached_dir, exist_ok=True)
21
+
22
+ save_mesh = False
23
+
24
+ state = ""
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+
27
+ box_v = 1.1
28
+ viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
29
+
30
+ image_model_config_dict = OrderedDict({
31
+ "ASLDM-256-obj": {
32
+ "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
33
+ "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
34
+ },
35
+ })
36
+
37
+ text_model_config_dict = OrderedDict({
38
+ "ASLDM-256": {
39
+ "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
40
+ "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
41
+ },
42
+ })
43
+
44
+
45
+ class InferenceModel(object):
46
+ model = None
47
+ name = ""
48
+
49
+
50
+ text2mesh_model = InferenceModel()
51
+ image2mesh_model = InferenceModel()
52
+
53
+
54
+ def set_state(s):
55
+ global state
56
+ state = s
57
+ print(s)
58
+
59
+
60
+ def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
61
+ image: Optional[np.ndarray] = None,
62
+ html_frame: bool = False):
63
+ global viewer
64
+
65
+ for i in range(len(mesh_outputs)):
66
+ mesh = mesh_outputs[i]
67
+ if mesh is None:
68
+ continue
69
+
70
+ mesh_v = mesh.mesh_v.copy()
71
+ mesh_v[:, 0] += i * np.max(bbox_size)
72
+ mesh_v[:, 2] += np.max(bbox_size)
73
+ viewer.add_mesh(mesh_v, mesh.mesh_f)
74
+
75
+ mesh_tag = viewer.to_html(html_frame=False)
76
+
77
+ if image is not None:
78
+ image_tag = html_util.to_image_embed_tag(image)
79
+ frame = f"""
80
+ <table border = "1">
81
+ <tr>
82
+ <td>{image_tag}</td>
83
+ <td>{mesh_tag}</td>
84
+ </tr>
85
+ </table>
86
+ """
87
+ else:
88
+ frame = mesh_tag
89
+
90
+ if html_frame:
91
+ frame = html_util.to_html_frame(frame)
92
+
93
+ viewer.reset()
94
+
95
+ return frame
96
+
97
+
98
+ def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
99
+ global device
100
+
101
+ if inference_model.name == model_name:
102
+ model = inference_model.model
103
+ else:
104
+ assert model_name in model_config_dict
105
+
106
+ if inference_model.model is not None:
107
+ del inference_model.model
108
+
109
+ config_ckpt_path = model_config_dict[model_name]
110
+
111
+ model_config = get_config_from_file(config_ckpt_path["config"])
112
+ if hasattr(model_config, "model"):
113
+ model_config = model_config.model
114
+
115
+ model = instantiate_from_config(model_config, ckpt_path=config_ckpt_path["ckpt_path"])
116
+ model = model.to(device)
117
+ model = model.eval()
118
+
119
+ inference_model.model = model
120
+ inference_model.name = model_name
121
+
122
+ return model
123
+
124
+
125
+ def prepare_img(image: np.ndarray):
126
+ image_pt = torch.tensor(image).float()
127
+ image_pt = image_pt / 255 * 2 - 1
128
+ image_pt = rearrange(image_pt, "h w c -> c h w")
129
+
130
+ return image_pt
131
+
132
+ def prepare_model_viewer(fp):
133
+ content = f"""
134
+ <head>
135
+ <script
136
+ type="module" src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.1.1/model-viewer.min.js">
137
+ </script>
138
+ </head>
139
+ <body>
140
+ <model-viewer
141
+ style="height: 150px; width: 150px;"
142
+ rotation-per-second="10deg"
143
+ id="t1"
144
+ src="file/gradio_cached_dir/{fp}"
145
+ environment-image="neutral"
146
+ camera-target="0m 0m 0m"
147
+ orientation="0deg 90deg 170deg"
148
+ shadow-intensity="1"
149
+ ar:true
150
+ auto-rotate
151
+ camera-controls>
152
+ </model-viewer>
153
+ </body>
154
+ """
155
+ return content
156
+
157
+ def prepare_html_frame(content):
158
+ frame = f"""
159
+ <html>
160
+ <body>
161
+ {content}
162
+ </body>
163
+ </html>
164
+ """
165
+ return frame
166
+
167
+ def prepare_html_body(content):
168
+ frame = f"""
169
+ <body>
170
+ {content}
171
+ </body>
172
+ """
173
+ return frame
174
+
175
+ def post_process_mesh_outputs(mesh_outputs):
176
+ # html_frame = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=True)
177
+ html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
178
+ html_frame = prepare_html_frame(html_content)
179
+
180
+ # filename = f"{time.time()}.html"
181
+ filename = f"text-256-{time.time()}.html"
182
+ html_filepath = os.path.join(gradio_cached_dir, filename)
183
+ with open(html_filepath, "w") as writer:
184
+ writer.write(html_frame)
185
+
186
+ '''
187
+ Bug: The iframe tag does not work in Gradio.
188
+ Chrome returns "No resource with given URL found".
189
+ Solutions:
190
+ https://github.com/gradio-app/gradio/issues/884
191
+ Due to security restrictions, the server can only serve files located alongside gradio_app.py.
192
+ The path has format "file/TARGET_FILE_PATH"
193
+ '''
194
+
195
+ iframe_tag = f'<iframe src="file/gradio_cached_dir/{filename}" width="600%" height="400" frameborder="0"></iframe>'
196
+
197
+ filelist = []
198
+ filenames = []
199
+ for i, mesh in enumerate(mesh_outputs):
200
+ mesh.mesh_f = mesh.mesh_f[:, ::-1]
201
+ mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
202
+
203
+ name = str(i) + "_out_mesh.obj"
204
+ filepath = gradio_cached_dir + "/" + name
205
+ mesh_output.export(filepath, include_normals=True)
206
+ filelist.append(filepath)
207
+ filenames.append(name)
208
+
209
+ filelist.append(html_filepath)
210
+ return iframe_tag, filelist
211
+
212
+ def image2mesh(image: np.ndarray,
213
+ model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
214
+ num_samples: int = 4,
215
+ guidance_scale: int = 7.5,
216
+ octree_depth: int = 7):
217
+ global device, gradio_cached_dir, image_model_config_dict, box_v
218
+
219
+ # load model
220
+ model = load_model(model_name, image_model_config_dict, image2mesh_model)
221
+
222
+ # prepare image inputs
223
+ image_pt = prepare_img(image)
224
+ image_pt = repeat(image_pt, "c h w -> b c h w", b=num_samples)
225
+
226
+ sample_inputs = {
227
+ "image": image_pt
228
+ }
229
+ mesh_outputs = model.sample(
230
+ sample_inputs,
231
+ sample_times=1,
232
+ guidance_scale=guidance_scale,
233
+ return_intermediates=False,
234
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
235
+ octree_depth=octree_depth,
236
+ )[0]
237
+
238
+ iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
239
+
240
+ return iframe_tag, gr.update(value=filelist, visible=True)
241
+
242
+
243
+ def text2mesh(text: str,
244
+ model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
245
+ num_samples: int = 4,
246
+ guidance_scale: int = 7.5,
247
+ octree_depth: int = 7):
248
+ global device, gradio_cached_dir, text_model_config_dict, text2mesh_model, box_v
249
+
250
+ # load model
251
+ model = load_model(model_name, text_model_config_dict, text2mesh_model)
252
+
253
+ # prepare text inputs
254
+ sample_inputs = {
255
+ "text": [text] * num_samples
256
+ }
257
+ mesh_outputs = model.sample(
258
+ sample_inputs,
259
+ sample_times=1,
260
+ guidance_scale=guidance_scale,
261
+ return_intermediates=False,
262
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
263
+ octree_depth=octree_depth,
264
+ )[0]
265
+
266
+ iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
267
+
268
+ return iframe_tag, gr.update(value=filelist, visible=True)
269
+
270
+ example_dir = './gradio_cached_dir/example/img_example'
271
+
272
+ first_page_items = [
273
+ 'alita.jpg',
274
+ 'burger.jpg',
275
+ 'loopy.jpg',
276
+ 'building.jpg',
277
+ 'mario.jpg',
278
+ 'car.jpg',
279
+ 'airplane.jpg',
280
+ 'bag.jpg',
281
+ 'bench.jpg',
282
+ 'ship.jpg'
283
+ ]
284
+ raw_example_items = [
285
+ # (os.path.join(example_dir, x), x)
286
+ os.path.join(example_dir, x)
287
+ for x in os.listdir(example_dir)
288
+ if x.endswith(('.jpg', '.png'))
289
+ ]
290
+ example_items = [x for x in raw_example_items if os.path.basename(x) in first_page_items] + [x for x in raw_example_items if os.path.basename(x) not in first_page_items]
291
+
292
+ example_text = [
293
+ ["A 3D model of a car; Audi A6."],
294
+ ["A 3D model of police car; Highway Patrol Charger"]
295
+ ]
296
+
297
+ def set_cache(data: gr.SelectData):
298
+ img_name = os.path.basename(example_items[data.index])
299
+ return os.path.join(example_dir, img_name), os.path.join(img_name)
300
+
301
+ def disable_cache():
302
+ return ""
303
+
304
+ with gr.Blocks() as app:
305
+ gr.Markdown("# Michelangelo")
306
+ gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
307
+ gr.Markdown("Michelangelo is a conditional 3D shape generation system that trains based on the shape-image-text aligned latent representation.")
308
+ gr.Markdown("### Hint:")
309
+ gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
310
+ gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
311
+ gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
312
+ gr.Markdown("4. Welcome to share your amazing results with us, and thanks for your interest in our work!")
313
+
314
+ with gr.Row():
315
+ with gr.Column():
316
+
317
+ with gr.Tab("Image to 3D"):
318
+ img = gr.Image(label="Image")
319
+ gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
320
+ btn_generate_img2obj = gr.Button(value="Generate")
321
+
322
+ with gr.Accordion("Advanced settings", open=False):
323
+ image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj",choices=list(image_model_config_dict.keys()))
324
+ num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
325
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
326
+ octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
327
+
328
+
329
+ cache_dir = gr.Textbox(value="", visible=False)
330
+ examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")
331
+
332
+ with gr.Tab("Text to 3D"):
333
+ prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porsche Cayenne Turbo.")
334
+ gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porche Cayenne Turbo.")
335
+ btn_generate_txt2obj = gr.Button(value="Generate")
336
+
337
+ with gr.Accordion("Advanced settings", open=False):
338
+ text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256",choices=list(text_model_config_dict.keys()))
339
+ num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
340
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
341
+ octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
342
+
343
+ gr.Markdown("#### Examples:")
344
+ gr.Markdown("1. A 3D model of a coupe; Audi A6.")
345
+ gr.Markdown("2. A 3D model of a motorcar; Hummer H2 SUT.")
346
+ gr.Markdown("3. A 3D model of an airplane; Airbus.")
347
+ gr.Markdown("4. A 3D model of a fighter aircraft; Attack Fighter.")
348
+ gr.Markdown("5. A 3D model of a chair; Simple Wooden Chair.")
349
+ gr.Markdown("6. A 3D model of a laptop computer; Dell Laptop.")
350
+ gr.Markdown("7. A 3D model of a lamp; ceiling light.")
351
+ gr.Markdown("8. A 3D model of a rifle; AK47.")
352
+ gr.Markdown("9. A 3D model of a knife; Sword.")
353
+ gr.Markdown("10. A 3D model of a vase; Plant in pot.")
354
+
355
+ with gr.Column():
356
+ model_3d = gr.HTML()
357
+ file_out = gr.File(label="Files", visible=False)
358
+
359
+ outputs = [model_3d, file_out]
360
+
361
+ img.upload(disable_cache, outputs=cache_dir)
362
+ examples.select(set_cache, outputs=[img, cache_dir])
363
+ print(f'line:404: {cache_dir}')
364
+ btn_generate_img2obj.click(image2mesh, inputs=[img, image_dropdown_models, num_samples,
365
+ guidance_scale, octree_depth],
366
+ outputs=outputs, api_name="generate_img2obj")
367
+
368
+ btn_generate_txt2obj.click(text2mesh, inputs=[prompt, text_dropdown_models, num_samples,
369
+ guidance_scale, octree_depth],
370
+ outputs=outputs, api_name="generate_txt2obj")
371
+
372
+ app.launch(server_name="0.0.0.0", server_port=8008, share=False)
gradio_cached_dir/example/img_example/airplane.jpg ADDED
gradio_cached_dir/example/img_example/alita.jpg ADDED
gradio_cached_dir/example/img_example/bag.jpg ADDED
gradio_cached_dir/example/img_example/bench.jpg ADDED
gradio_cached_dir/example/img_example/building.jpg ADDED
gradio_cached_dir/example/img_example/burger.jpg ADDED
gradio_cached_dir/example/img_example/car.jpg ADDED
gradio_cached_dir/example/img_example/loopy.jpg ADDED
gradio_cached_dir/example/img_example/mario.jpg ADDED
gradio_cached_dir/example/img_example/ship.jpg ADDED
inference.py ADDED
@@ -0,0 +1,181 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ from collections import OrderedDict
5
+ from typing import Optional, List
6
+ import argparse
7
+ from functools import partial
8
+
9
+ from einops import repeat, rearrange
10
+ import numpy as np
11
+ from PIL import Image
12
+ import trimesh
13
+ import cv2
14
+
15
+ import torch
16
+ import pytorch_lightning as pl
17
+
18
+ from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
19
+ from michelangelo.models.tsal.inference_utils import extract_geometry
20
+ from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
21
+ from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
22
+ from michelangelo.utils.visualizers import html_util
23
+
24
+ def load_model(args):
25
+
26
+ model_config = get_config_from_file(args.config_path)
27
+ if hasattr(model_config, "model"):
28
+ model_config = model_config.model
29
+
30
+ model = instantiate_from_config(model_config, ckpt_path=args.ckpt_path)
31
+ model = model.cuda()
32
+ model = model.eval()
33
+
34
+ return model
35
+
36
+ def load_surface(fp):
37
+
38
+ with np.load(fp) as input_pc:
39
+ surface = input_pc['points']
40
+ normal = input_pc['normals']
41
+
42
+ rng = np.random.default_rng()
43
+ ind = rng.choice(surface.shape[0], 4096, replace=False)
44
+ surface = torch.FloatTensor(surface[ind])
45
+ normal = torch.FloatTensor(normal[ind])
46
+
47
+ surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
48
+
49
+ return surface
50
+
51
+ def prepare_image(args, number_samples=2):
52
+
53
+ image = cv2.imread(f"{args.image_path}")
54
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55
+
56
+ image_pt = torch.tensor(image).float()
57
+ image_pt = image_pt / 255 * 2 - 1
58
+ image_pt = rearrange(image_pt, "h w c -> c h w")
59
+
60
+ image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
61
+
62
+ return image_pt
63
+
64
+ def save_output(args, mesh_outputs):
65
+
66
+ os.makedirs(args.output_dir, exist_ok=True)
67
+ for i, mesh in enumerate(mesh_outputs):
68
+ mesh.mesh_f = mesh.mesh_f[:, ::-1]
69
+ mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
70
+
71
+ name = str(i) + "_out_mesh.obj"
72
+ mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)
73
+
74
+ print(f'-----------------------------------------------------------------------------')
75
+ print(f'>>> Finished and mesh saved in {args.output_dir}')
76
+ print(f'-----------------------------------------------------------------------------')
77
+
78
+ return 0
79
+
80
+ def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
81
+
82
+ surface = load_surface(args.pointcloud_path)
83
+
84
+ # encoding
85
+ shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
86
+ shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
87
+
88
+ # decoding
89
+ latents = model.model.shape_model.decode(shape_zq)
90
+ geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
91
+
92
+ # reconstruction
93
+ mesh_v_f, has_surface = extract_geometry(
94
+ geometric_func=geometric_func,
95
+ device=surface.device,
96
+ batch_size=surface.shape[0],
97
+ bounds=bounds,
98
+ octree_depth=octree_depth,
99
+ num_chunks=num_chunks,
100
+ )
101
+ recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])
102
+
103
+ # save
104
+ os.makedirs(args.output_dir, exist_ok=True)
105
+ recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj'))
106
+
107
+ print(f'-----------------------------------------------------------------------------')
108
+ print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}')
109
+ print(f'-----------------------------------------------------------------------------')
110
+
111
+ return 0
112
+
113
+ def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7):
114
+
115
+ sample_inputs = {
116
+ "image": prepare_image(args)
117
+ }
118
+
119
+ mesh_outputs = model.sample(
120
+ sample_inputs,
121
+ sample_times=1,
122
+ guidance_scale=guidance_scale,
123
+ return_intermediates=False,
124
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
125
+ octree_depth=octree_depth,
126
+ )[0]
127
+
128
+ save_output(args, mesh_outputs)
129
+
130
+ return 0
131
+
132
+ def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
133
+
134
+ sample_inputs = {
135
+ "text": [args.text] * num_samples
136
+ }
137
+ mesh_outputs = model.sample(
138
+ sample_inputs,
139
+ sample_times=1,
140
+ guidance_scale=guidance_scale,
141
+ return_intermediates=False,
142
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
143
+ octree_depth=octree_depth,
144
+ )[0]
145
+
146
+ save_output(args, mesh_outputs)
147
+
148
+ return 0
149
+
150
+ task_dict = {
151
+ 'reconstruction': reconstruction,
152
+ 'image2mesh': image2mesh,
153
+ 'text2mesh': text2mesh,
154
+ }
155
+
156
+ if __name__ == "__main__":
157
+ '''
158
+ 1. Reconstruct point cloud
159
+ 2. Image-conditioned generation
160
+ 3. Text-conditioned generation
161
+ '''
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True)
164
+ parser.add_argument("--config_path", type=str, required=True)
165
+ parser.add_argument("--ckpt_path", type=str, required=True)
166
+ parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
167
+ parser.add_argument("--image_path", type=str, help='Path to the input image')
168
+ parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
169
+ parser.add_argument("--output_dir", type=str, default='./output')
170
+ parser.add_argument("-s", "--seed", type=int, default=0)
171
+ args = parser.parse_args()
172
+
173
+ pl.seed_everything(args.seed)
174
+
175
+ print(f'-----------------------------------------------------------------------------')
176
+ print(f'>>> Running {args.task}')
177
+ args.output_dir = os.path.join(args.output_dir, args.task)
178
+ print(f'>>> Output directory: {args.output_dir}')
179
+ print(f'-----------------------------------------------------------------------------')
180
+
181
+ task_dict[args.task](args, load_model(args))
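
load_surface() above expects a .npz archive with "points" and "normals" arrays of matching length (at least 4096 rows), from which it draws 4096 surface samples. A minimal sketch of producing such a file and feeding it to the reconstruction task (not part of this commit; the synthetic point cloud is purely illustrative):

# Sketch only: write a point cloud in the format load_surface() reads.
import numpy as np

points = (np.random.rand(100000, 3).astype(np.float32) * 2 - 1) * 0.9    # xyz roughly inside [-1, 1]
normals = np.random.randn(100000, 3).astype(np.float32)
normals /= np.linalg.norm(normals, axis=1, keepdims=True)                 # unit normals
np.savez("./example_data/surface/surface.npz", points=points, normals=normals)

# Then, for example:
#   python inference.py --task reconstruction \
#       --config_path ./configs/aligned_shape_latents/shapevae-256.yaml \
#       --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt \
#       --pointcloud_path ./example_data/surface/surface.npz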
michelangelo/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # -*- coding: utf-8 -*-
michelangelo/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (176 Bytes)
 
michelangelo/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # -*- coding: utf-8 -*-
michelangelo/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (181 Bytes)
 
michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc ADDED
Binary file (9.43 kB)
 
michelangelo/data/__pycache__/tokenizer.cpython-39.pyc ADDED
Binary file (6.48 kB)
 
michelangelo/data/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (11.4 kB)
 
michelangelo/data/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.13 kB)
 
michelangelo/data/templates.json ADDED
@@ -0,0 +1,69 @@
1
+ {
2
+ "shape": [
3
+ "a point cloud model of {}.",
4
+ "There is a {} in the scene.",
5
+ "There is the {} in the scene.",
6
+ "a photo of a {} in the scene.",
7
+ "a photo of the {} in the scene.",
8
+ "a photo of one {} in the scene.",
9
+ "itap of a {}.",
10
+ "itap of my {}.",
11
+ "itap of the {}.",
12
+ "a photo of a {}.",
13
+ "a photo of my {}.",
14
+ "a photo of the {}.",
15
+ "a photo of one {}.",
16
+ "a photo of many {}.",
17
+ "a good photo of a {}.",
18
+ "a good photo of the {}.",
19
+ "a bad photo of a {}.",
20
+ "a bad photo of the {}.",
21
+ "a photo of a nice {}.",
22
+ "a photo of the nice {}.",
23
+ "a photo of a cool {}.",
24
+ "a photo of the cool {}.",
25
+ "a photo of a weird {}.",
26
+ "a photo of the weird {}.",
27
+ "a photo of a small {}.",
28
+ "a photo of the small {}.",
29
+ "a photo of a large {}.",
30
+ "a photo of the large {}.",
31
+ "a photo of a clean {}.",
32
+ "a photo of the clean {}.",
33
+ "a photo of a dirty {}.",
34
+ "a photo of the dirty {}.",
35
+ "a bright photo of a {}.",
36
+ "a bright photo of the {}.",
37
+ "a dark photo of a {}.",
38
+ "a dark photo of the {}.",
39
+ "a photo of a hard to see {}.",
40
+ "a photo of the hard to see {}.",
41
+ "a low resolution photo of a {}.",
42
+ "a low resolution photo of the {}.",
43
+ "a cropped photo of a {}.",
44
+ "a cropped photo of the {}.",
45
+ "a close-up photo of a {}.",
46
+ "a close-up photo of the {}.",
47
+ "a jpeg corrupted photo of a {}.",
48
+ "a jpeg corrupted photo of the {}.",
49
+ "a blurry photo of a {}.",
50
+ "a blurry photo of the {}.",
51
+ "a pixelated photo of a {}.",
52
+ "a pixelated photo of the {}.",
53
+ "a black and white photo of the {}.",
54
+ "a black and white photo of a {}",
55
+ "a plastic {}.",
56
+ "the plastic {}.",
57
+ "a toy {}.",
58
+ "the toy {}.",
59
+ "a plushie {}.",
60
+ "the plushie {}.",
61
+ "a cartoon {}.",
62
+ "the cartoon {}.",
63
+ "an embroidered {}.",
64
+ "the embroidered {}.",
65
+ "a painting of the {}.",
66
+ "a painting of a {}."
67
+ ]
68
+
69
+ }
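
A minimal sketch of how these prompt templates are typically consumed (an assumption about usage, not code from this commit): each entry in the "shape" list is a str.format pattern that a category name is substituted into, CLIP-style, before text encoding.

# Sketch only: expand the shape templates for one category.
import json

with open("michelangelo/data/templates.json") as f:
    templates = json.load(f)["shape"]

prompts = [t.format("chair") for t in templates]
print(prompts[0])  # a point cloud model of chair.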
michelangelo/data/transforms.py ADDED
@@ -0,0 +1,407 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ import numpy as np
5
+ import warnings
6
+ import random
7
+ from omegaconf.listconfig import ListConfig
8
+ from webdataset import pipelinefilter
9
+ import torch
10
+ import torchvision.transforms.functional as TVF
11
+ from torchvision.transforms import InterpolationMode
12
+ from torchvision.transforms.transforms import _interpolation_modes_from_int
13
+ from typing import Sequence
14
+
15
+ from michelangelo.utils import instantiate_from_config
16
+
17
+
18
+ def _uid_buffer_pick(buf_dict, rng):
19
+ uid_keys = list(buf_dict.keys())
20
+ selected_uid = rng.choice(uid_keys)
21
+ buf = buf_dict[selected_uid]
22
+
23
+ k = rng.randint(0, len(buf) - 1)
24
+ sample = buf[k]
25
+ buf[k] = buf[-1]
26
+ buf.pop()
27
+
28
+ if len(buf) == 0:
29
+ del buf_dict[selected_uid]
30
+
31
+ return sample
32
+
33
+
34
+ def _add_to_buf_dict(buf_dict, sample):
35
+ key = sample["__key__"]
36
+ uid, uid_sample_id = key.split("_")
37
+ if uid not in buf_dict:
38
+ buf_dict[uid] = []
39
+ buf_dict[uid].append(sample)
40
+
41
+ return buf_dict
42
+
43
+
44
+ def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
45
+ """Shuffle the data in the stream.
46
+
47
+ This uses a buffer of size `bufsize`. Shuffling at
48
+ startup is less random; this is traded off against
49
+ yielding samples quickly.
50
+
51
+ data: iterator
52
+ bufsize: buffer size for shuffling
53
+ returns: iterator
54
+ rng: either random module or random.Random instance
55
+
56
+ """
57
+ if rng is None:
58
+ rng = random.Random(int((os.getpid() + time.time()) * 1e9))
59
+ initial = min(initial, bufsize)
60
+ buf_dict = dict()
61
+ current_samples = 0
62
+ for sample in data:
63
+ _add_to_buf_dict(buf_dict, sample)
64
+ current_samples += 1
65
+
66
+ if current_samples < bufsize:
67
+ try:
68
+ _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708
69
+ current_samples += 1
70
+ except StopIteration:
71
+ pass
72
+
73
+ if current_samples >= initial:
74
+ current_samples -= 1
75
+ yield _uid_buffer_pick(buf_dict, rng)
76
+
77
+ while current_samples > 0:
78
+ current_samples -= 1
79
+ yield _uid_buffer_pick(buf_dict, rng)
80
+
81
+
82
+ uid_shuffle = pipelinefilter(_uid_shuffle)
83
+
84
+
85
+ class RandomSample(object):
86
+ def __init__(self,
87
+ num_volume_samples: int = 1024,
88
+ num_near_samples: int = 1024):
89
+
90
+ super().__init__()
91
+
92
+ self.num_volume_samples = num_volume_samples
93
+ self.num_near_samples = num_near_samples
94
+
95
+ def __call__(self, sample):
96
+ rng = np.random.default_rng()
97
+
98
+ # 1. sample surface input
99
+ total_surface = sample["surface"]
100
+ ind = rng.choice(total_surface.shape[0], replace=False)
101
+ surface = total_surface[ind]
102
+
103
+ # 2. sample volume/near geometric points
104
+ vol_points = sample["vol_points"]
105
+ vol_label = sample["vol_label"]
106
+ near_points = sample["near_points"]
107
+ near_label = sample["near_label"]
108
+
109
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
110
+ vol_points = vol_points[ind]
111
+ vol_label = vol_label[ind]
112
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
113
+
114
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
115
+ near_points = near_points[ind]
116
+ near_label = near_label[ind]
117
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
118
+
119
+ # concat sampled volume and near points
120
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
121
+
122
+ sample = {
123
+ "surface": surface,
124
+ "geo_points": geo_points
125
+ }
126
+
127
+ return sample
128
+
129
+
130
+ class SplitRandomSample(object):
131
+ def __init__(self,
132
+ use_surface_sample: bool = False,
133
+ num_surface_samples: int = 4096,
134
+ num_volume_samples: int = 1024,
135
+ num_near_samples: int = 1024):
136
+
137
+ super().__init__()
138
+
139
+ self.use_surface_sample = use_surface_sample
140
+ self.num_surface_samples = num_surface_samples
141
+ self.num_volume_samples = num_volume_samples
142
+ self.num_near_samples = num_near_samples
143
+
144
+ def __call__(self, sample):
145
+
146
+ rng = np.random.default_rng()
147
+
148
+ # 1. sample surface input
149
+ surface = sample["surface"]
150
+
151
+ if self.use_surface_sample:
152
+ replace = surface.shape[0] < self.num_surface_samples
153
+ ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
154
+ surface = surface[ind]
155
+
156
+ # 2. sample volume/near geometric points
157
+ vol_points = sample["vol_points"]
158
+ vol_label = sample["vol_label"]
159
+ near_points = sample["near_points"]
160
+ near_label = sample["near_label"]
161
+
162
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
163
+ vol_points = vol_points[ind]
164
+ vol_label = vol_label[ind]
165
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
166
+
167
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
168
+ near_points = near_points[ind]
169
+ near_label = near_label[ind]
170
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
171
+
172
+ # concat sampled volume and near points
173
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
174
+
175
+ sample = {
176
+ "surface": surface,
177
+ "geo_points": geo_points
178
+ }
179
+
180
+ return sample
181
+
182
+
183
+ class FeatureSelection(object):
184
+
185
+ VALID_SURFACE_FEATURE_DIMS = {
186
+ "none": [0, 1, 2], # xyz
187
+ "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal
188
+ "normal": [0, 1, 2, 6, 7, 8]
189
+ }
190
+
191
+ def __init__(self, surface_feature_type: str):
192
+
193
+ self.surface_feature_type = surface_feature_type
194
+ self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
195
+
196
+ def __call__(self, sample):
197
+ sample["surface"] = sample["surface"][:, self.surface_dims]
198
+ return sample
199
+
200
+
201
+ class AxisScaleTransform(object):
202
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
203
+ assert isinstance(interval, (tuple, list, ListConfig))
204
+ self.interval = interval
205
+ self.min_val = interval[0]
206
+ self.max_val = interval[1]
207
+ self.inter_size = interval[1] - interval[0]
208
+ self.jitter = jitter
209
+ self.jitter_scale = jitter_scale
210
+
211
+ def __call__(self, sample):
212
+
213
+ surface = sample["surface"][..., 0:3]
214
+ geo_points = sample["geo_points"][..., 0:3]
215
+
216
+ scaling = torch.rand(1, 3) * self.inter_size + self.min_val
217
+ # print(scaling)
218
+ surface = surface * scaling
219
+ geo_points = geo_points * scaling
220
+
221
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
222
+ surface *= scale
223
+ geo_points *= scale
224
+
225
+ if self.jitter:
226
+ surface += self.jitter_scale * torch.randn_like(surface)
227
+ surface.clamp_(min=-1.015, max=1.015)
228
+
229
+ sample["surface"][..., 0:3] = surface
230
+ sample["geo_points"][..., 0:3] = geo_points
231
+
232
+ return sample
233
+
234
+
235
+ class ToTensor(object):
236
+
237
+ def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
238
+ self.tensor_keys = tensor_keys
239
+
240
+ def __call__(self, sample):
241
+ for key in self.tensor_keys:
242
+ if key not in sample:
243
+ continue
244
+
245
+ sample[key] = torch.tensor(sample[key], dtype=torch.float32)
246
+
247
+ return sample
248
+
249
+
250
+ class AxisScale(object):
251
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
252
+ assert isinstance(interval, (tuple, list, ListConfig))
253
+ self.interval = interval
254
+ self.jitter = jitter
255
+ self.jitter_scale = jitter_scale
256
+
257
+ def __call__(self, surface, *args):
258
+ scaling = torch.rand(1, 3) * 0.5 + 0.75
259
+ # print(scaling)
260
+ surface = surface * scaling
261
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
262
+ surface *= scale
263
+
264
+ args_outputs = []
265
+ for _arg in args:
266
+ _arg = _arg * scaling * scale
267
+ args_outputs.append(_arg)
268
+
269
+ if self.jitter:
270
+ surface += self.jitter_scale * torch.randn_like(surface)
271
+ surface.clamp_(min=-1, max=1)
272
+
273
+ if len(args) == 0:
274
+ return surface
275
+ else:
276
+ return surface, *args_outputs
277
+
278
+
279
+ class RandomResize(torch.nn.Module):
280
+ """Resize by a random ratio drawn from ``resize_radio``, then resize back to the target ``size``."""
281
+
282
+ def __init__(
283
+ self,
284
+ size,
285
+ resize_radio=(0.5, 1),
286
+ allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
287
+ interpolation=InterpolationMode.BICUBIC,
288
+ max_size=None,
289
+ antialias=None,
290
+ ):
291
+ super().__init__()
292
+ if not isinstance(size, (int, Sequence)):
293
+ raise TypeError(f"Size should be int or sequence. Got {type(size)}")
294
+ if isinstance(size, Sequence) and len(size) not in (1, 2):
295
+ raise ValueError("If size is a sequence, it should have 1 or 2 values")
296
+
297
+ self.size = size
298
+ self.max_size = max_size
299
+ # Backward compatibility with integer value
300
+ if isinstance(interpolation, int):
301
+ warnings.warn(
302
+ "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
303
+ "Please use InterpolationMode enum."
304
+ )
305
+ interpolation = _interpolation_modes_from_int(interpolation)
306
+
307
+ self.interpolation = interpolation
308
+ self.antialias = antialias
309
+
310
+ self.resize_radio = resize_radio
311
+ self.allow_resize_interpolations = allow_resize_interpolations
312
+
313
+ def random_resize_params(self):
314
+ radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
315
+
316
+ if isinstance(self.size, int):
317
+ size = int(self.size * radio)
318
+ elif isinstance(self.size, Sequence):
319
+ size = list(self.size)
320
+ size = (int(size[0] * radio), int(size[1] * radio))
321
+ else:
322
+ raise RuntimeError()
323
+
324
+ interpolation = self.allow_resize_interpolations[
325
+ torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
326
+ ]
327
+ return size, interpolation
328
+
329
+ def forward(self, img):
330
+ size, interpolation = self.random_resize_params()
331
+ img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
332
+ img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
333
+ return img
334
+
335
+ def __repr__(self) -> str:
336
+ detail = f"(size={self.size}, interpolation={self.interpolation.value}, "
337
+ detail += f"max_size={self.max_size}, antialias={self.antialias}, resize_radio={self.resize_radio})"
338
+ return f"{self.__class__.__name__}{detail}"
339
+
340
+
341
+ class Compose(object):
342
+ """Composes several transforms together. This transform does not support torchscript.
343
+ Please, see the note below.
344
+
345
+ Args:
346
+ transforms (list of ``Transform`` objects): list of transforms to compose.
347
+
348
+ Example:
349
+ >>> transforms.Compose([
350
+ >>> transforms.CenterCrop(10),
351
+ >>> transforms.ToTensor(),
352
+ >>> ])
353
+
354
+ .. note::
355
+ In order to script the transformations, please use ``torch.nn.Sequential`` as below.
356
+
357
+ >>> transforms = torch.nn.Sequential(
358
+ >>> transforms.CenterCrop(10),
359
+ >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
360
+ >>> )
361
+ >>> scripted_transforms = torch.jit.script(transforms)
362
+
363
+ Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
364
+ `lambda` functions or ``PIL.Image``.
365
+
366
+ """
367
+
368
+ def __init__(self, transforms):
369
+ self.transforms = transforms
370
+
371
+ def __call__(self, *args):
372
+ for t in self.transforms:
373
+ args = t(*args)
374
+ return args
375
+
376
+ def __repr__(self):
377
+ format_string = self.__class__.__name__ + '('
378
+ for t in self.transforms:
379
+ format_string += '\n'
380
+ format_string += ' {0}'.format(t)
381
+ format_string += '\n)'
382
+ return format_string
383
+
384
+
385
+ def identity(*args, **kwargs):
386
+ if len(args) == 1:
387
+ return args[0]
388
+ else:
389
+ return args
390
+
391
+
392
+ def build_transforms(cfg):
393
+
394
+ if cfg is None:
395
+ return identity
396
+
397
+ transforms = []
398
+
399
+ for transform_name, cfg_instance in cfg.items():
400
+ transform_instance = instantiate_from_config(cfg_instance)
401
+ transforms.append(transform_instance)
402
+ print(f"Build transform: {transform_instance}")
403
+
404
+ transforms = Compose(transforms)
405
+
406
+ return transforms
407
+
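A rough usage sketch for the dict-style transforms above (not part of the commit): the arrays, their shapes, and the assumption that the surface array stores xyz in columns 0-2, watertight-mesh normals in 3-5, and normals in 6-8 are illustrative guesses based on the `VALID_SURFACE_FEATURE_DIMS` table; in the repo these transforms would normally be assembled from a YAML config via `build_transforms`.

```python
import numpy as np

from michelangelo.data.transforms import AxisScaleTransform, FeatureSelection, ToTensor

# Fake sample: 9 surface columns assumed to be xyz (0-2), watertight normals (3-5),
# normals (6-8); geo_points are xyz plus an occupancy label, as sampled above.
sample = {
    "surface": np.random.randn(2048, 9).astype(np.float32),
    "geo_points": np.random.randn(4096, 4).astype(np.float32),
}

sample = FeatureSelection(surface_feature_type="normal")(sample)         # keep xyz + normals -> (2048, 6)
sample = ToTensor(tensor_keys=("surface", "geo_points"))(sample)         # numpy -> float32 tensors
sample = AxisScaleTransform(interval=(0.75, 1.25), jitter=True)(sample)  # random per-axis scale + jitter

print(sample["surface"].shape, sample["geo_points"].shape)
```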
michelangelo/data/utils.py ADDED
@@ -0,0 +1,59 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+
7
+ def worker_init_fn(_):
8
+ worker_info = torch.utils.data.get_worker_info()
9
+ worker_id = worker_info.id
10
+
11
+ # dataset = worker_info.dataset
12
+ # split_size = dataset.num_records // worker_info.num_workers
13
+ # # reset num_records to the true number to retain reliable length information
14
+ # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
15
+ # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
16
+ # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
17
+
18
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
19
+
20
+
21
+ def collation_fn(samples, combine_tensors=True, combine_scalars=True):
22
+ """
23
+
24
+ Args:
25
+ samples (list[dict]): per-sample dicts that share the same keys.
26
+ combine_tensors (bool): if True, stack torch.Tensor / np.ndarray values along a new batch dimension.
27
+ combine_scalars (bool): if True, pack int/float values into an np.ndarray.
28
+
29
+ Returns:
30
+ dict: one entry per key containing the collated values.
31
+ """
32
+
33
+ result = {}
34
+
35
+ keys = samples[0].keys()
36
+
37
+ for key in keys:
38
+ result[key] = []
39
+
40
+ for sample in samples:
41
+ for key in keys:
42
+ val = sample[key]
43
+ result[key].append(val)
44
+
45
+ for key in keys:
46
+ val_list = result[key]
47
+ if isinstance(val_list[0], (int, float)):
48
+ if combine_scalars:
49
+ result[key] = np.array(result[key])
50
+
51
+ elif isinstance(val_list[0], torch.Tensor):
52
+ if combine_tensors:
53
+ result[key] = torch.stack(val_list)
54
+
55
+ elif isinstance(val_list[0], np.ndarray):
56
+ if combine_tensors:
57
+ result[key] = np.stack(val_list)
58
+
59
+ return result
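A minimal sketch of how `collation_fn` and `worker_init_fn` plug into a standard PyTorch `DataLoader`; the `ToyDataset` and its keys are stand-ins invented for the example, not classes from this commit.

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader

from michelangelo.data.utils import collation_fn, worker_init_fn


class ToyDataset(Dataset):
    """Stand-in dataset returning dict samples like the transforms above produce."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "surface": np.random.randn(2048, 6).astype(np.float32),
            "geo_points": np.random.randn(4096, 4).astype(np.float32),
        }


if __name__ == "__main__":
    loader = DataLoader(
        ToyDataset(),
        batch_size=4,
        num_workers=2,                  # worker_init_fn only runs when workers > 0
        collate_fn=collation_fn,        # stacks ndarray/tensor values per key
        worker_init_fn=worker_init_fn,  # reseeds numpy differently in each worker
    )
    batch = next(iter(loader))
    print(batch["surface"].shape)       # (4, 2048, 6), stacked as a numpy array
```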
michelangelo/graphics/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
michelangelo/graphics/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (185 Bytes).
 
michelangelo/graphics/primitives/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .volume import generate_dense_grid_points
4
+
5
+ from .mesh import (
6
+ MeshOutput,
7
+ save_obj,
8
+ savemeshtes2
9
+ )
michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (334 Bytes).
 
michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc ADDED
Binary file (2.46 kB).
 
michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc ADDED
Binary file (2.93 kB).
 
michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc ADDED
Binary file (860 Bytes).
 
michelangelo/graphics/primitives/mesh.py ADDED
@@ -0,0 +1,114 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ import PIL.Image
7
+ from typing import Optional
8
+
9
+ import trimesh
10
+
11
+
12
+ def save_obj(pointnp_px3, facenp_fx3, fname):
13
+ fid = open(fname, "w")
14
+ write_str = ""
15
+ for pidx, p in enumerate(pointnp_px3):
16
+ pp = p
17
+ write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18
+
19
+ for i, f in enumerate(facenp_fx3):
20
+ f1 = f + 1
21
+ write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22
+ fid.write(write_str)
23
+ fid.close()
24
+ return
25
+
26
+
27
+ def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28
+ fol, na = os.path.split(fname)
29
+ na, _ = os.path.splitext(na)
30
+
31
+ matname = "%s/%s.mtl" % (fol, na)
32
+ fid = open(matname, "w")
33
+ fid.write("newmtl material_0\n")
34
+ fid.write("Kd 1 1 1\n")
35
+ fid.write("Ka 0 0 0\n")
36
+ fid.write("Ks 0.4 0.4 0.4\n")
37
+ fid.write("Ns 10\n")
38
+ fid.write("illum 2\n")
39
+ fid.write("map_Kd %s.png\n" % na)
40
+ fid.close()
41
+ ####
42
+
43
+ fid = open(fname, "w")
44
+ fid.write("mtllib %s.mtl\n" % na)
45
+
46
+ for pidx, p in enumerate(pointnp_px3):
47
+ pp = p
48
+ fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49
+
50
+ for pidx, p in enumerate(tcoords_px2):
51
+ pp = p
52
+ fid.write("vt %f %f\n" % (pp[0], pp[1]))
53
+
54
+ fid.write("usemtl material_0\n")
55
+ for i, f in enumerate(facenp_fx3):
56
+ f1 = f + 1
57
+ f2 = facetex_fx3[i] + 1
58
+ fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59
+ fid.close()
60
+
61
+ PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62
+ os.path.join(fol, "%s.png" % na))
63
+
64
+ return
65
+
66
+
67
+ class MeshOutput(object):
68
+
69
+ def __init__(self,
70
+ mesh_v: np.ndarray,
71
+ mesh_f: np.ndarray,
72
+ vertex_colors: Optional[np.ndarray] = None,
73
+ uvs: Optional[np.ndarray] = None,
74
+ mesh_tex_idx: Optional[np.ndarray] = None,
75
+ tex_map: Optional[np.ndarray] = None):
76
+
77
+ self.mesh_v = mesh_v
78
+ self.mesh_f = mesh_f
79
+ self.vertex_colors = vertex_colors
80
+ self.uvs = uvs
81
+ self.mesh_tex_idx = mesh_tex_idx
82
+ self.tex_map = tex_map
83
+
84
+ def contain_uv_texture(self):
85
+ return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86
+
87
+ def contain_vertex_colors(self):
88
+ return self.vertex_colors is not None
89
+
90
+ def export(self, fname):
91
+
92
+ if self.contain_uv_texture():
93
+ savemeshtes2(
94
+ self.mesh_v,
95
+ self.uvs,
96
+ self.mesh_f,
97
+ self.mesh_tex_idx,
98
+ self.tex_map,
99
+ fname
100
+ )
101
+
102
+ elif self.contain_vertex_colors():
103
+ mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104
+ mesh_obj.export(fname)
105
+
106
+ else:
107
+ save_obj(
108
+ self.mesh_v,
109
+ self.mesh_f,
110
+ fname
111
+ )
112
+
113
+
114
+
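A small usage sketch for `MeshOutput` (the triangle data is made up): with no UVs or vertex colors, `export` falls back to the plain `save_obj` writer defined above.

```python
import numpy as np

from michelangelo.graphics.primitives import MeshOutput

mesh_v = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]], dtype=np.float32)
mesh_f = np.array([[0, 1, 2]], dtype=np.int64)

# No UVs / vertex colors, so export() takes the plain Wavefront OBJ path via save_obj.
MeshOutput(mesh_v=mesh_v, mesh_f=mesh_f).export("triangle.obj")
```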
michelangelo/graphics/primitives/volume.py ADDED
@@ -0,0 +1,21 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+
5
+
6
+ def generate_dense_grid_points(bbox_min: np.ndarray,
7
+ bbox_max: np.ndarray,
8
+ octree_depth: int,
9
+ indexing: str = "ij"):
10
+ length = bbox_max - bbox_min
11
+ num_cells = np.exp2(octree_depth)
12
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16
+ xyz = np.stack((xs, ys, zs), axis=-1)
17
+ xyz = xyz.reshape(-1, 3)
18
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19
+
20
+ return xyz, grid_size, length
21
+
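A quick sketch of calling `generate_dense_grid_points`; the bounding box and octree depth are arbitrary example values. The returned flat point list is the kind of dense query grid that would typically be fed to an occupancy/SDF decoder before iso-surface extraction.

```python
import numpy as np

from michelangelo.graphics.primitives import generate_dense_grid_points

bbox_min = np.array([-1.0, -1.0, -1.0])
bbox_max = np.array([1.0, 1.0, 1.0])

# octree_depth=7 -> 2**7 cells per axis -> 129 grid points per axis
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
print(xyz.shape, grid_size, length)  # (2146689, 3) [129, 129, 129] [2. 2. 2.]
```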
michelangelo/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
michelangelo/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (183 Bytes).
 
michelangelo/models/asl_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (197 Bytes).
 
michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc ADDED
Binary file (2.64 kB).
 
michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc ADDED
Binary file (9.87 kB).
 
michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc ADDED
Binary file (1.75 kB).