jiaweir committed on
Commit
5b9bbe2
•
1 Parent(s): fc94f83
Files changed (4)
  1. app.py +184 -19
  2. configs/4d_demo.yaml +1 -1
  3. lgm/infer_demo.py +197 -0
  4. main_4d_demo.py +616 -0
app.py CHANGED
@@ -7,6 +7,26 @@ import numpy
 import hashlib
 import shlex

+import rembg
+import glob
+import cv2
+import numpy as np
+from diffusers import StableVideoDiffusionPipeline
+from scripts.gen_vid import *
+
+import sys
+sys.path.append('lgm')
+from safetensors.torch import load_file
+from kiui.cam import orbit_camera
+from core.options import config_defaults, Options
+from core.models import LGM
+from mvdream.pipeline_mvdream import MVDreamPipeline
+from infer_demo import process as process_lgm
+
+from main_4d_demo import process as process_dg4d
+
+
+
 import spaces


@@ -27,45 +47,179 @@ function refresh() {
 }
 """

+
+device = torch.device('cuda')
+# device = torch.device('cpu')
+
+session = rembg.new_session(model_name='u2net')
+
+pipe = StableVideoDiffusionPipeline.from_pretrained(
+    "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
+)
+pipe.to(device)
+
+opt = config_defaults['big']
+opt.resume = ckpt_path
+# model
+model = LGM(opt)
+
+# resume pretrained checkpoint
+if opt.resume is not None:
+    if opt.resume.endswith('safetensors'):
+        ckpt = load_file(opt.resume, device='cpu')
+    else:
+        ckpt = torch.load(opt.resume, map_location='cpu')
+    model.load_state_dict(ckpt, strict=False)
+    print(f'[INFO] Loaded checkpoint from {opt.resume}')
+else:
+    print(f'[WARN] model randomly initialized, are you sure?')
+
+# device
+model = model.half().to(device)
+model.eval()
+rays_embeddings = model.prepare_default_rays(device)
+
+# load image dream
+pipe_mvdream = MVDreamPipeline.from_pretrained(
+    "ashawkey/imagedream-ipmv-diffusers", # remote weights
+    torch_dtype=torch.float16,
+    trust_remote_code=True,
+    # local_files_only=True,
+)
+pipe_mvdream = pipe_mvdream.to(device)
+
+from guidance.zero123_utils import Zero123
+guidance_zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')
+
+def preprocess(path, recenter=True, size=256, border_ratio=0.2):
+    files = [path]
+    out_dir = os.path.dirname(path)
+
+    for file in files:
+
+        out_base = os.path.basename(file).split('.')[0]
+        out_rgba = os.path.join(out_dir, out_base + '_rgba.png')
+
+        # load image
+        print(f'[INFO] loading image {file}...')
+        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
+
+        # carve background
+        print(f'[INFO] background removal...')
+        carved_image = rembg.remove(image, session=session) # [H, W, 4]
+        mask = carved_image[..., -1] > 0
+
+        # recenter
+        if recenter:
+            print(f'[INFO] recenter...')
+            final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
+
+            coords = np.nonzero(mask)
+            x_min, x_max = coords[0].min(), coords[0].max()
+            y_min, y_max = coords[1].min(), coords[1].max()
+            h = x_max - x_min
+            w = y_max - y_min
+            desired_size = int(size * (1 - border_ratio))
+            scale = desired_size / max(h, w)
+            h2 = int(h * scale)
+            w2 = int(w * scale)
+            x2_min = (size - h2) // 2
+            x2_max = x2_min + h2
+            y2_min = (size - w2) // 2
+            y2_max = y2_min + w2
+            final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
+
+        else:
+            final_rgba = carved_image
+
+        # write image
+        cv2.imwrite(out_rgba, final_rgba)
+
+def gen_vid(input_path, seed, bg='white'):
+    name = input_path.split('/')[-1].split('.')[0]
+    input_dir = os.path.dirname(input_path)
+    height, width = 512, 512
+
+    image = load_image(input_path, width, height, bg)
+
+    generator = torch.manual_seed(seed)
+    # frames = pipe(image, height, width, decode_chunk_size=2, generator=generator).frames[0]
+    frames = pipe(image, height, width, generator=generator).frames[0]
+
+    imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7)
+    os.makedirs(f"{input_dir}/{name}_frames", exist_ok=True)
+    for idx, img in enumerate(frames):
+        img.save(f"{input_dir}/{name}_frames/{idx:03}.png")
+
 # check if there is a picture uploaded or selected
 def check_img_input(control_image):
     if control_image is None:
         raise gr.Error("Please select or upload an input image")

 # check if there is a picture uploaded or selected
-def check_video_input(image_block: Image.Image):
+def check_video_3d_input(image_block: Image.Image):
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
     if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
         raise gr.Error("Please generate a video first")
+    if not os.path.exists(os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')):
+        raise gr.Error("Please generate a 3D first")
+


 @spaces.GPU()
-def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
+def optimize_stage_0(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
     if not os.path.exists('tmp_data'):
         os.makedirs('tmp_data')
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
-    if preprocess_chk:
-        # save image to a designated path
-        image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))
+    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba.png')):
+        if preprocess_chk:
+            # save image to a designated path
+            image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))

-        # preprocess image
-        print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
-        subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
-    else:
-        image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))
+            # preprocess image
+            # print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
+            # subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
+            preprocess(os.path.join("tmp_data", f"{img_hash}.png"))
+        else:
+            image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))

     # stage 1
-    subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True)
-    subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True)
+    # subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True)
+    gen_vid(f'tmp_data/{img_hash}_rgba.png', seed_slider)
     # return [os.path.join('logs', 'tmp_rgba_model.ply')]
     return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')

-@spaces.GPU(duration=200)
+
+@spaces.GPU()
+def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
+    if not os.path.exists('tmp_data'):
+        os.makedirs('tmp_data')
+    img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
+    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba.png')):
+        if preprocess_chk:
+            # save image to a designated path
+            image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))
+
+            # preprocess image
+            # print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
+            # subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
+            preprocess(os.path.join("tmp_data", f"{img_hash}.png"))
+        else:
+            image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))
+
+    # stage 1
+    # subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True)
+    process_lgm(opt, f'tmp_data/{img_hash}_rgba.png', pipe_mvdream, model, rays_embeddings)
+    # return [os.path.join('logs', 'tmp_rgba_model.ply')]
+    return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
+
+@spaces.GPU(duration=120)
 def optimize_stage_2(image_block: Image.Image, seed_slider: int):
     img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()

     # stage 2
-    subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True)
+    # subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True)
+    process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
     # os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
     image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
     # return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')]
@@ -83,7 +237,7 @@ if __name__ == "__main__":
 </div>
 We present DreamGausssion4D, an efficient 4D generation framework that builds on Gaussian Splatting.
 '''
-_IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**."
+_IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D**. Finally, click **Generate 4D**."

 # load images in 'data' folder as examples
 example_folder = os.path.join(os.path.dirname(__file__), 'data')
@@ -104,7 +258,8 @@ if __name__ == "__main__":
 image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

 # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
-seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
+seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (Video)')
+seed_slider2 = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (3D)')
 gr.Markdown(
     "random seed for video generation.")

@@ -120,20 +275,30 @@ if __name__ == "__main__":
 examples_per_page=40
 )
 img_run_btn = gr.Button("Generate Video")
+threed_run_btn = gr.Button("Generate 3D")
 fourd_run_btn = gr.Button("Generate 4D")
 img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

 with gr.Column(scale=5):
-    obj3d = gr.Video(label="video",height=290)
+    dirving_video = gr.Video(label="video",height=290)
+    obj3d = gr.Video(label="3D Model",height=290)
     obj4d = Model4DGS(label="4D Model", height=500, fps=14)

-img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_1,
+
+img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_0,
                                                                               inputs=[image_block,
                                                                                       preprocess_chk,
                                                                                       seed_slider],
+                                                                              outputs=[
+                                                                                  dirving_video])
+
+threed_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_1,
+                                                                                 inputs=[image_block,
+                                                                                         preprocess_chk,
+                                                                                         seed_slider2],
                                                                                  outputs=[
                                                                                      obj3d])
-fourd_run_btn.click(check_video_input, inputs=[image_block], queue=False).success(optimize_stage_2, inputs=[image_block, seed_slider], outputs=[obj4d])
+fourd_run_btn.click(check_video_3d_input, inputs=[image_block], queue=False).success(optimize_stage_2, inputs=[image_block, seed_slider], outputs=[obj4d])

 # demo.queue().launch(share=True)
 demo.queue(max_size=10) # <-- Sets up a queue with default parameters
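For reference, a minimal sketch of how the three staged entry points defined above could be exercised outside the Gradio UI. This is an illustration only, not part of the commit: it assumes app.py is importable as a module (its model checkpoints load at import time), and the image path is a placeholder. Function names and signatures are taken from the diff above.

# Hypothetical driver for the staged pipeline in app.py (sketch only, assumptions noted above).
from PIL import Image
import app  # importing app.py loads SVD, LGM, ImageDream and Zero123 onto the GPU

img = Image.open("data/example.png").convert("RGBA")  # placeholder input image

video_path = app.optimize_stage_0(img, True, 0)   # stage 0: preprocess + driving video (SVD)
static_path = app.optimize_stage_1(img, True, 0)  # stage 1: static 3D Gaussians (LGM)
app.optimize_stage_2(img, 0)                      # stage 2: 4D optimization via process_dg4d
print(video_path, static_path)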
configs/4d_demo.yaml CHANGED
@@ -30,7 +30,7 @@ lambda_svd: 0
 # training batch size per iter
 batch_size: 7
 # training iterations for stage 1
-iters: 500
+iters: 300
 # training iterations for stage 2
 iters_refine: 50
 # training camera radius
lgm/infer_demo.py ADDED
@@ -0,0 +1,197 @@
+
+import os
+import tyro
+import glob
+import imageio
+import numpy as np
+import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from safetensors.torch import load_file
+
+import kiui
+from kiui.op import recenter
+from kiui.cam import orbit_camera
+
+from core.options import AllConfigs, Options
+from core.models import LGM
+from mvdream.pipeline_mvdream import MVDreamPipeline
+import cv2
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+# opt = tyro.cli(AllConfigs)
+
+# # model
+# model = LGM(opt)
+
+# # resume pretrained checkpoint
+# if opt.resume is not None:
+#     if opt.resume.endswith('safetensors'):
+#         ckpt = load_file(opt.resume, device='cpu')
+#     else:
+#         ckpt = torch.load(opt.resume, map_location='cpu')
+#     model.load_state_dict(ckpt, strict=False)
+#     print(f'[INFO] Loaded checkpoint from {opt.resume}')
+# else:
+#     print(f'[WARN] model randomly initialized, are you sure?')
+
+# # device
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# model = model.half().to(device)
+# model.eval()
+
+
+
+# process function
+def process(opt: Options, path, pipe, model, rays_embeddings):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+    proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
+    proj_matrix[0, 0] = 1 / tan_half_fov
+    proj_matrix[1, 1] = 1 / tan_half_fov
+    proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+    proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+    proj_matrix[2, 3] = 1
+
+
+    name = os.path.splitext(os.path.basename(path))[0]
+    print(f'[INFO] Processing {path} --> {name}')
+    os.makedirs('vis_data', exist_ok=True)
+    os.makedirs('logs', exist_ok=True)
+
+    image = kiui.read_image(path, mode='uint8')
+
+    # generate mv
+    image = image.astype(np.float32) / 255.0
+
+    # rgba to rgb white bg
+    if image.shape[-1] == 4:
+        image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+
+    mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0)
+    mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
+
+    # generate gaussians
+    input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
+    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
+    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+    input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
+
+    with torch.inference_mode():
+        ############## align azimuth #####################
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            # generate gaussians
+            gaussians = model.forward_gaussians(input_image)
+
+        best_azi = 0
+        best_diff = 1e8
+        for v, azi in enumerate(np.arange(-180, 180, 1)):
+            cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+            cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+            # cameras needed by gaussian rasterizer
+            cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+            cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+            cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+            # scale = min(azi / 360, 1)
+            scale = 1
+
+
+            result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)
+            rendered_image = result['image']
+
+            rendered_image = rendered_image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy()
+            rendered_image = cv2.resize(rendered_image, (image.shape[0], image.shape[1]), interpolation=cv2.INTER_AREA)
+
+            diff = np.mean((rendered_image- image) ** 2)
+
+            if diff < best_diff:
+                best_diff = diff
+                best_azi = azi
+        print("Best aligned azimuth: ", best_azi)
+
+        mv_image = []
+        for v, azi in enumerate([0, 90, 180, 270]):
+            cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+            cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+            # cameras needed by gaussian rasterizer
+            cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+            cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+            cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+            # scale = min(azi / 360, 1)
+            scale = 1
+
+
+            result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)
+            rendered_image = result['image']
+            rendered_image = rendered_image.squeeze(1)
+            rendered_image = F.interpolate(rendered_image, (256, 256))
+            rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
+            mv_image.append(rendered_image)
+        mv_image = np.concatenate(mv_image, axis=0)
+
+        input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
+        input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
+        input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
+
+        input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
+
+        ################################
+
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            # generate gaussians
+            gaussians = model.forward_gaussians(input_image)
+
+            # save gaussians
+            model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))
+
+            # render 360 video
+            images = []
+            elevation = 0
+
+            if opt.fancy_video:
+
+                azimuth = np.arange(0, 720, 4, dtype=np.int32)
+                for azi in tqdm.tqdm(azimuth):
+
+                    cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+                    cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+                    # cameras needed by gaussian rasterizer
+                    cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+                    cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+                    cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+                    scale = min(azi / 360, 1)
+
+                    image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
+                    images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+            else:
+                azimuth = np.arange(0, 360, 2, dtype=np.int32)
+                for azi in tqdm.tqdm(azimuth):
+
+                    cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
+
+                    cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+                    # cameras needed by gaussian rasterizer
+                    cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+                    cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
+                    cam_pos = - cam_poses[:, :3, 3] # [V, 3]
+
+                    image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
+                    images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
+
+    images = np.concatenate(images, axis=0)
+    imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
main_4d_demo.py ADDED
@@ -0,0 +1,616 @@
+import os
+import cv2
+import time
+import tqdm
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+import rembg
+
+from cam_utils import orbit_camera, OrbitCamera
+from gs_renderer_4d import Renderer, MiniCam
+
+from grid_put import mipmap_linear_grid_put_2d
+import imageio
+
+import copy
+from omegaconf import OmegaConf
+
+class GUI:
+    def __init__(self, opt, guidance_zero123):
+        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
+        self.gui = opt.gui # enable gui
+        self.W = opt.W
+        self.H = opt.H
+        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
+
+        self.mode = "image"
+        # self.seed = "random"
+        self.seed = 888
+
+        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
+        self.need_update = True # update buffer_image
+
+        # models
+        self.device = torch.device("cuda")
+        self.bg_remover = None
+
+        self.guidance_sd = None
+        self.guidance_zero123 = guidance_zero123
+        self.guidance_svd = None
+
+
+        self.enable_sd = False
+        self.enable_zero123 = False
+        self.enable_svd = False
+
+
+        # renderer
+        self.renderer = Renderer(self.opt, sh_degree=self.opt.sh_degree)
+        self.gaussain_scale_factor = 1
+
+        # input image
+        self.input_img = None
+        self.input_mask = None
+        self.input_img_torch = None
+        self.input_mask_torch = None
+        self.overlay_input_img = False
+        self.overlay_input_img_ratio = 0.5
+
+        self.input_img_list = None
+        self.input_mask_list = None
+        self.input_img_torch_list = None
+        self.input_mask_torch_list = None
+
+        # input text
+        self.prompt = ""
+        self.negative_prompt = ""
+
+        # training stuff
+        self.training = False
+        self.optimizer = None
+        self.step = 0
+        self.train_steps = 1 # steps per rendering loop
+
+        # load input data from cmdline
+        if self.opt.input is not None: # True
+            self.load_input(self.opt.input) # load imgs, if has bg, then rm bg; or just load imgs
+
+        # override prompt from cmdline
+        if self.opt.prompt is not None: # None
+            self.prompt = self.opt.prompt
+
+        # override if provide a checkpoint
+        if self.opt.load is not None: # not None
+            self.renderer.initialize(self.opt.load)
+            # self.renderer.gaussians.load_model(opt.outdir, opt.save_path)
+        else:
+            # initialize gaussians to a blob
+            self.renderer.initialize(num_pts=self.opt.num_pts)
+
+        self.seed_everything()
+
+    def seed_everything(self):
+        try:
+            seed = int(self.seed)
+        except:
+            seed = np.random.randint(0, 1000000)
+
+        print(f'Seed: {seed:d}')
+        os.environ["PYTHONHASHSEED"] = str(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = True
+
+        self.last_seed = seed
+
+    def prepare_train(self):
+
+        self.step = 0
+
+        # setup training
+        self.renderer.gaussians.training_setup(self.opt)
+
+        # # do not do progressive sh-level
+        self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
+        self.optimizer = self.renderer.gaussians.optimizer
+
+        # default camera
+        if self.opt.mvdream or self.opt.imagedream:
+            # the second view is the front view for mvdream/imagedream.
+            pose = orbit_camera(self.opt.elevation, 90, self.opt.radius)
+        else:
+            pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
+        self.fixed_cam = MiniCam(
+            pose,
+            self.opt.ref_size,
+            self.opt.ref_size,
+            self.cam.fovy,
+            self.cam.fovx,
+            self.cam.near,
+            self.cam.far,
+        )
+
+        self.enable_sd = self.opt.lambda_sd > 0
+        self.enable_zero123 = self.opt.lambda_zero123 > 0
+        self.enable_svd = self.opt.lambda_svd > 0 and self.input_img is not None
+
+        # lazy load guidance model
+        if self.guidance_sd is None and self.enable_sd:
+            if self.opt.mvdream:
+                print(f"[INFO] loading MVDream...")
+                from guidance.mvdream_utils import MVDream
+                self.guidance_sd = MVDream(self.device)
+                print(f"[INFO] loaded MVDream!")
+            elif self.opt.imagedream:
+                print(f"[INFO] loading ImageDream...")
+                from guidance.imagedream_utils import ImageDream
+                self.guidance_sd = ImageDream(self.device)
+                print(f"[INFO] loaded ImageDream!")
+            else:
+                print(f"[INFO] loading SD...")
+                from guidance.sd_utils import StableDiffusion
+                self.guidance_sd = StableDiffusion(self.device)
+                print(f"[INFO] loaded SD!")
+
+        if self.guidance_zero123 is None and self.enable_zero123:
+            print(f"[INFO] loading zero123...")
+            from guidance.zero123_utils import Zero123
+            if self.opt.stable_zero123:
+                self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/stable-zero123-diffusers')
+            else:
+                self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
+            print(f"[INFO] loaded zero123!")
+
+        if self.guidance_svd is None and self.enable_svd: # False
+            print(f"[INFO] loading SVD...")
+            from guidance.svd_utils import StableVideoDiffusion
+            self.guidance_svd = StableVideoDiffusion(self.device)
+            print(f"[INFO] loaded SVD!")
+
+        # input image
+        if self.input_img is not None:
+            self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
+            self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
+
+            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
+            self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
+
+        if self.input_img_list is not None:
+            self.input_img_torch_list = [torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_img in self.input_img_list]
+            self.input_img_torch_list = [F.interpolate(input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_img_torch in self.input_img_torch_list]
+
+            self.input_mask_torch_list = [torch.from_numpy(input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_mask in self.input_mask_list]
+            self.input_mask_torch_list = [F.interpolate(input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_mask_torch in self.input_mask_torch_list]
+        # prepare embeddings
+        with torch.no_grad():
+
+            if self.enable_sd:
+                if self.opt.imagedream:
+                    img_pos_list, img_neg_list, ip_pos_list, ip_neg_list, emb_pos_list, emb_neg_list = [], [], [], [], [], []
+                    for _ in range(self.opt.n_views):
+                        for input_img_torch in self.input_img_torch_list:
+                            img_pos, img_neg, ip_pos, ip_neg, emb_pos, emb_neg = self.guidance_sd.get_image_text_embeds(input_img_torch, [self.prompt], [self.negative_prompt])
+                            img_pos_list.append(img_pos)
+                            img_neg_list.append(img_neg)
+                            ip_pos_list.append(ip_pos)
+                            ip_neg_list.append(ip_neg)
+                            emb_pos_list.append(emb_pos)
+                            emb_neg_list.append(emb_neg)
+                    self.guidance_sd.image_embeddings['pos'] = torch.cat(img_pos_list, 0)
+                    self.guidance_sd.image_embeddings['neg'] = torch.cat(img_pos_list, 0)
+                    self.guidance_sd.image_embeddings['ip_img'] = torch.cat(ip_pos_list, 0)
+                    self.guidance_sd.image_embeddings['neg_ip_img'] = torch.cat(ip_neg_list, 0)
+                    self.guidance_sd.embeddings['pos'] = torch.cat(emb_pos_list, 0)
+                    self.guidance_sd.embeddings['neg'] = torch.cat(emb_neg_list, 0)
+                else:
+                    self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])
+
+            if self.enable_zero123:
+                c_list, v_list = [], []
+                for _ in range(self.opt.n_views):
+                    for input_img_torch in self.input_img_torch_list:
+                        c, v = self.guidance_zero123.get_img_embeds(input_img_torch)
+                        c_list.append(c)
+                        v_list.append(v)
+                self.guidance_zero123.embeddings = [torch.cat(c_list, 0), torch.cat(v_list, 0)]
+
+            if self.enable_svd:
+                self.guidance_svd.get_img_embeds(self.input_img)
+
+    def train_step(self):
+        starter = torch.cuda.Event(enable_timing=True)
+        ender = torch.cuda.Event(enable_timing=True)
+        starter.record()
+
+        for _ in range(self.train_steps): # 1
+
+            self.step += 1 # self.step starts from 0
+            step_ratio = min(1, self.step / self.opt.iters) # 1, step / 500
+
+            # update lr
+            self.renderer.gaussians.update_learning_rate(self.step)
+
+            loss = 0
+
+            self.renderer.prepare_render()
+
+            ### known view
+            if not self.opt.imagedream:
+                for b_idx in range(self.opt.batch_size):
+                    cur_cam = copy.deepcopy(self.fixed_cam)
+                    cur_cam.time = b_idx
+                    out = self.renderer.render(cur_cam)
+
+                    # rgb loss
+                    image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
+                    loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch_list[b_idx]) / self.opt.batch_size
+
+                    # mask loss
+                    mask = out["alpha"].unsqueeze(0) # [1, 1, H, W] in [0, 1]
+                    loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch_list[b_idx]) / self.opt.batch_size
+
+            ### novel view (manual batch)
+            render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
+            # render_resolution = 512
+            images = []
+            poses = []
+            vers, hors, radii = [], [], []
+            # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
+            min_ver = max(min(self.opt.min_ver, self.opt.min_ver - self.opt.elevation), -80 - self.opt.elevation)
+            max_ver = min(max(self.opt.max_ver, self.opt.max_ver - self.opt.elevation), 80 - self.opt.elevation)
+
+            for _ in range(self.opt.n_views):
+                for b_idx in range(self.opt.batch_size):
+
+                    # render random view
+                    ver = np.random.randint(min_ver, max_ver)
+                    hor = np.random.randint(-180, 180)
+                    radius = 0
+
+                    vers.append(ver)
+                    hors.append(hor)
+                    radii.append(radius)
+
+                    pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
+                    poses.append(pose)
+
+                    cur_cam = MiniCam(pose, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, time=b_idx)
+
+                    bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda")
+                    out = self.renderer.render(cur_cam, bg_color=bg_color)
+
+                    image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
+                    images.append(image)
+
+                    # enable mvdream training
+                    if self.opt.mvdream or self.opt.imagedream: # False
+                        for view_i in range(1, 4):
+                            pose_i = orbit_camera(self.opt.elevation + ver, hor + 90 * view_i, self.opt.radius + radius)
+                            poses.append(pose_i)
+
+                            cur_cam_i = MiniCam(pose_i, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far)
+
+                            # bg_color = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device="cuda")
+                            out_i = self.renderer.render(cur_cam_i, bg_color=bg_color)
+
+                            image = out_i["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
+                            images.append(image)
+
+
+
+            images = torch.cat(images, dim=0)
+            poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)
+
+            # guidance loss
+            if self.enable_sd:
+                if self.opt.mvdream or self.opt.imagedream:
+                    loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, poses, step_ratio)
+                else:
+                    loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio)
+
+            if self.enable_zero123:
+                loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio) / (self.opt.batch_size * self.opt.n_views)
+
+            if self.enable_svd:
+                loss = loss + self.opt.lambda_svd * self.guidance_svd.train_step(images, step_ratio)
+
+            # optimize step
+            loss.backward()
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+            # densify and prune
+            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
+                viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
+                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                if self.step % self.opt.densification_interval == 0:
+                    # size_threshold = 20 if self.step > self.opt.opacity_reset_interval else None
+                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=0.5, max_screen_size=1)
+
+                if self.step % self.opt.opacity_reset_interval == 0:
+                    self.renderer.gaussians.reset_opacity()
+
+        ender.record()
+        torch.cuda.synchronize()
+        t = starter.elapsed_time(ender)
+
+        self.need_update = True
+
+
+    def load_input(self, file):
+        if self.opt.data_mode == 'c4d':
+            file_list = [os.path.join(file, f'{x * self.opt.downsample_rate}.png') for x in range(self.opt.batch_size)]
+        elif self.opt.data_mode == 'svd':
+            # file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}_rgba.png') for x in range(self.opt.batch_size)]
+            # file_list = [x if os.path.exists(x) else (x.replace('_rgba.png', '.png')) for x in file_list]
+            file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}.png') for x in range(self.opt.batch_size)]
+        else:
+            raise NotImplementedError
+        self.input_img_list, self.input_mask_list = [], []
+        for file in file_list:
+            # load image
+            print(f'[INFO] load image from {file}...')
+            img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
+            if img.shape[-1] == 3:
+                if self.bg_remover is None:
+                    self.bg_remover = rembg.new_session()
+                img = rembg.remove(img, session=self.bg_remover)
+                # cv2.imwrite(file.replace('.png', '_rgba.png'), img)
+            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
+            img = img.astype(np.float32) / 255.0
+            input_mask = img[..., 3:]
+            # white bg
+            input_img = img[..., :3] * input_mask + (1 - input_mask)
+            # bgr to rgb
+            input_img = input_img[..., ::-1].copy()
+            self.input_img_list.append(input_img)
+            self.input_mask_list.append(input_mask)
+
+    @torch.no_grad()
+    def save_model(self, mode='geo', texture_size=1024, interp=1):
+        os.makedirs(self.opt.outdir, exist_ok=True)
+        if mode == 'geo':
+            path = f'logs/{opt.save_path}_mesh_{t:03d}.ply'
+            mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t)
+            mesh.write_ply(path)
+
+        elif mode == 'geo+tex':
+            from mesh import Mesh, safe_normalize
+            os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_meshes'), exist_ok=True)
+            for t in range(self.opt.batch_size):
+                path = os.path.join(self.opt.outdir, self.opt.save_path+'_meshes', f'{t:03d}.obj')
+                mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t)
+
+                # perform texture extraction
+                print(f"[INFO] unwrap uv...")
+                h = w = texture_size
+                mesh.auto_uv()
+                mesh.auto_normal()
+
+                albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
+                cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)
+
+                vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
+                hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]
+
+                render_resolution = 512
+
+                import nvdiffrast.torch as dr
+
+                if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
+                    glctx = dr.RasterizeGLContext()
+                else:
+                    glctx = dr.RasterizeCudaContext()
+
+                for ver, hor in zip(vers, hors):
+                    # render image
+                    pose = orbit_camera(ver, hor, self.cam.radius)
+
+                    cur_cam = MiniCam(
+                        pose,
+                        render_resolution,
+                        render_resolution,
+                        self.cam.fovy,
+                        self.cam.fovx,
+                        self.cam.near,
+                        self.cam.far,
+                        time=t
+                    )
+
+                    cur_out = self.renderer.render(cur_cam)
+
+                    rgbs = cur_out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
+
+                    # get coordinate in texture image
+                    pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
+                    proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)
+
+                    v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
+                    v_clip = v_cam @ proj.T
+                    rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))
+
+                    depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f) # [1, H, W, 1]
+                    depth = depth.squeeze(0) # [H, W, 1]
+
+                    alpha = (rast[0, ..., 3:] > 0).float()
+
+                    uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft) # [1, 512, 512, 2] in [0, 1]
+
+                    # use normal to produce a back-project mask
+                    normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
+                    normal = safe_normalize(normal[0])
+
+                    # rotated normal (where [0, 0, 1] always faces camera)
+                    rot_normal = normal @ pose[:3, :3]
+                    viewcos = rot_normal[..., [2]]
+
+                    mask = (alpha > 0) & (viewcos > 0.5) # [H, W, 1]
+                    mask = mask.view(-1)
+
+                    uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
+                    rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()
+
+                    # update texture image
+                    cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
+                        h, w,
+                        uvs[..., [1, 0]] * 2 - 1,
+                        rgbs,
+                        min_resolution=256,
+                        return_count=True,
+                    )
+
+                    mask = cnt.squeeze(-1) < 0.1
+                    albedo[mask] += cur_albedo[mask]
+                    cnt[mask] += cur_cnt[mask]
+
+                mask = cnt.squeeze(-1) > 0
+                albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)
+
+                mask = mask.view(h, w)
+
+                albedo = albedo.detach().cpu().numpy()
+                mask = mask.detach().cpu().numpy()
+
+                # dilate texture
+                from sklearn.neighbors import NearestNeighbors
+                from scipy.ndimage import binary_dilation, binary_erosion
+
+                inpaint_region = binary_dilation(mask, iterations=32)
+                inpaint_region[mask] = 0
+
+                search_region = mask.copy()
+                not_search_region = binary_erosion(search_region, iterations=3)
+                search_region[not_search_region] = 0
+
+                search_coords = np.stack(np.nonzero(search_region), axis=-1)
+                inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
+
+                knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
+                    search_coords
+                )
+                _, indices = knn.kneighbors(inpaint_coords)
+
+                albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]
+
+                mesh.albedo = torch.from_numpy(albedo).to(self.device)
+                mesh.write(path)
+
+
+        elif mode == 'frames':
+            os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_frames'), exist_ok=True)
+            for t in range(self.opt.batch_size * interp):
+                tt = t / interp
+                path = os.path.join(self.opt.outdir, self.opt.save_path+'_frames', f'{t:03d}.ply')
+                self.renderer.gaussians.save_frame_ply(path, tt)
+        else:
+            path = os.path.join(self.opt.outdir, self.opt.save_path + '_4d_model.ply')
+            self.renderer.gaussians.save_ply(path)
+            self.renderer.gaussians.save_deformation(self.opt.outdir, self.opt.save_path)
+
+        print(f"[INFO] save model to {path}.")
+
+    # no gui mode
+    def train(self, iters=500, ui=False):
+        if self.gui:
+            from visualizer.visergui import ViserViewer
+            self.viser_gui = ViserViewer(device="cuda", viewer_port=8080)
+        if iters > 0:
+            self.prepare_train()
+            if self.gui:
+                self.viser_gui.set_renderer(self.renderer, self.fixed_cam)
+
+            for i in tqdm.trange(iters):
+                self.train_step()
+                if self.gui:
+                    self.viser_gui.update()
+        if self.opt.mesh_format == 'frames':
+            self.save_model(mode='frames', interp=4)
+        elif self.opt.mesh_format == 'obj':
+            self.save_model(mode='geo+tex')
+
+        if self.opt.save_model:
+            self.save_model(mode='model')
+
+        # render eval
+        image_list =[]
+        nframes = self.opt.batch_size * 7 + 15 * 7
+        hor = 180
+        delta_hor = 45 / 15
+        delta_time = 1
+        for i in range(8):
+            time = 0
+            for j in range(self.opt.batch_size + 15):
+                pose = orbit_camera(self.opt.elevation, hor-180, self.opt.radius)
+                cur_cam = MiniCam(
+                    pose,
+                    512,
+                    512,
+                    self.cam.fovy,
+                    self.cam.fovx,
+                    self.cam.near,
+                    self.cam.far,
+                    time=time
+                )
+                with torch.no_grad():
+                    outputs = self.renderer.render(cur_cam)
+
+                out = outputs["image"].cpu().detach().numpy().astype(np.float32)
+                out = np.transpose(out, (1, 2, 0))
+                out = np.uint8(out*255)
+                image_list.append(out)
+
+                time = (time + delta_time) % self.opt.batch_size
+                if j >= self.opt.batch_size:
+                    hor = (hor+delta_hor) % 360
+
+
+        imageio.mimwrite(f'vis_data/{opt.save_path}.mp4', image_list, fps=7)
+
+        if self.gui:
+            while True:
+                self.viser_gui.update()
+
+def process(config, input_path, guidance):
+    # override default config from cli
+    opt = OmegaConf.load(config)
+    opt.input = input_path
+    opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path
+
+
+    # auto find mesh from stage 1
+    opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply')
+
+    gui = GUI(opt, guidance)
+
+    gui.train(opt.iters)
+
+
+if __name__ == "__main__":
+    import argparse
+    from omegaconf import OmegaConf
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", required=True, help="path to the yaml config file")
+    args, extras = parser.parse_known_args()
+
+    # override default config from cli
+    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
+    opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path
+
+
+    # auto find mesh from stage 1
+    opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply')
+
+    gui = GUI(opt)
+
+    gui.train(opt.iters)
+
+
+# python main_4d.py --config configs/4d_low.yaml input=data/CONSISTENT4D_DATA/in-the-wild/blooming_rose
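For completeness, a hedged sketch of calling the new process() entry point directly, mirroring how app.py invokes it above. The input path is a placeholder, and it assumes the stage-1 output logs/<name>_model.ply already exists (as produced by lgm/infer_demo.py).

# Sketch: programmatic use of main_4d_demo.process(), per the wiring shown in app.py.
import torch
from guidance.zero123_utils import Zero123
from main_4d_demo import process

device = torch.device('cuda')
guidance = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')

# placeholder RGBA frame path; its SVD frames folder and stage-1 .ply must already exist
process('configs/4d_demo.yaml', 'tmp_data/<hash>_rgba.png', guidance)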