Spaces:
Runtime error
Runtime error
jiaweir
commited on
Commit
β’
5b9bbe2
1
Parent(s):
fc94f83
optimize
Browse files- app.py +184 -19
- configs/4d_demo.yaml +1 -1
- lgm/infer_demo.py +197 -0
- main_4d_demo.py +616 -0
app.py
CHANGED
@@ -7,6 +7,26 @@ import numpy
|
|
7 |
import hashlib
|
8 |
import shlex
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import spaces
|
11 |
|
12 |
|
@@ -27,45 +47,179 @@ function refresh() {
|
|
27 |
}
|
28 |
"""
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# check if there is a picture uploaded or selected
|
31 |
def check_img_input(control_image):
|
32 |
if control_image is None:
|
33 |
raise gr.Error("Please select or upload an input image")
|
34 |
|
35 |
# check if there is a picture uploaded or selected
|
36 |
-
def
|
37 |
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
38 |
if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
|
39 |
raise gr.Error("Please generate a video first")
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
@spaces.GPU()
|
43 |
-
def
|
44 |
if not os.path.exists('tmp_data'):
|
45 |
os.makedirs('tmp_data')
|
46 |
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
47 |
-
if
|
48 |
-
|
49 |
-
|
|
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
56 |
|
57 |
# stage 1
|
58 |
-
subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True)
|
59 |
-
|
60 |
# return [os.path.join('logs', 'tmp_rgba_model.ply')]
|
61 |
return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
|
62 |
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
|
65 |
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
66 |
|
67 |
# stage 2
|
68 |
-
subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True)
|
|
|
69 |
# os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
|
70 |
image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
|
71 |
# return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')]
|
@@ -83,7 +237,7 @@ if __name__ == "__main__":
|
|
83 |
</div>
|
84 |
We present DreamGausssion4D, an efficient 4D generation framework that builds on Gaussian Splatting.
|
85 |
'''
|
86 |
-
_IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above),
|
87 |
|
88 |
# load images in 'data' folder as examples
|
89 |
example_folder = os.path.join(os.path.dirname(__file__), 'data')
|
@@ -104,7 +258,8 @@ if __name__ == "__main__":
|
|
104 |
image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
|
105 |
|
106 |
# elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
|
107 |
-
seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
|
|
|
108 |
gr.Markdown(
|
109 |
"random seed for video generation.")
|
110 |
|
@@ -120,20 +275,30 @@ if __name__ == "__main__":
|
|
120 |
examples_per_page=40
|
121 |
)
|
122 |
img_run_btn = gr.Button("Generate Video")
|
|
|
123 |
fourd_run_btn = gr.Button("Generate 4D")
|
124 |
img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
|
125 |
|
126 |
with gr.Column(scale=5):
|
127 |
-
|
|
|
128 |
obj4d = Model4DGS(label="4D Model", height=500, fps=14)
|
129 |
|
130 |
-
|
|
|
131 |
inputs=[image_block,
|
132 |
preprocess_chk,
|
133 |
seed_slider],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
outputs=[
|
135 |
obj3d])
|
136 |
-
fourd_run_btn.click(
|
137 |
|
138 |
# demo.queue().launch(share=True)
|
139 |
demo.queue(max_size=10) # <-- Sets up a queue with default parameters
|
|
|
7 |
import hashlib
|
8 |
import shlex
|
9 |
|
10 |
+
import rembg
|
11 |
+
import glob
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
from diffusers import StableVideoDiffusionPipeline
|
15 |
+
from scripts.gen_vid import *
|
16 |
+
|
17 |
+
import sys
|
18 |
+
sys.path.append('lgm')
|
19 |
+
from safetensors.torch import load_file
|
20 |
+
from kiui.cam import orbit_camera
|
21 |
+
from core.options import config_defaults, Options
|
22 |
+
from core.models import LGM
|
23 |
+
from mvdream.pipeline_mvdream import MVDreamPipeline
|
24 |
+
from infer_demo import process as process_lgm
|
25 |
+
|
26 |
+
from main_4d_demo import process as process_dg4d
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
import spaces
|
31 |
|
32 |
|
|
|
47 |
}
|
48 |
"""
|
49 |
|
50 |
+
|
51 |
+
device = torch.device('cuda')
|
52 |
+
# device = torch.device('cpu')
|
53 |
+
|
54 |
+
session = rembg.new_session(model_name='u2net')
|
55 |
+
|
56 |
+
pipe = StableVideoDiffusionPipeline.from_pretrained(
|
57 |
+
"stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
|
58 |
+
)
|
59 |
+
pipe.to(device)
|
60 |
+
|
61 |
+
opt = config_defaults['big']
|
62 |
+
opt.resume = ckpt_path
|
63 |
+
# model
|
64 |
+
model = LGM(opt)
|
65 |
+
|
66 |
+
# resume pretrained checkpoint
|
67 |
+
if opt.resume is not None:
|
68 |
+
if opt.resume.endswith('safetensors'):
|
69 |
+
ckpt = load_file(opt.resume, device='cpu')
|
70 |
+
else:
|
71 |
+
ckpt = torch.load(opt.resume, map_location='cpu')
|
72 |
+
model.load_state_dict(ckpt, strict=False)
|
73 |
+
print(f'[INFO] Loaded checkpoint from {opt.resume}')
|
74 |
+
else:
|
75 |
+
print(f'[WARN] model randomly initialized, are you sure?')
|
76 |
+
|
77 |
+
# device
|
78 |
+
model = model.half().to(device)
|
79 |
+
model.eval()
|
80 |
+
rays_embeddings = model.prepare_default_rays(device)
|
81 |
+
|
82 |
+
# load image dream
|
83 |
+
pipe_mvdream = MVDreamPipeline.from_pretrained(
|
84 |
+
"ashawkey/imagedream-ipmv-diffusers", # remote weights
|
85 |
+
torch_dtype=torch.float16,
|
86 |
+
trust_remote_code=True,
|
87 |
+
# local_files_only=True,
|
88 |
+
)
|
89 |
+
pipe_mvdream = pipe_mvdream.to(device)
|
90 |
+
|
91 |
+
from guidance.zero123_utils import Zero123
|
92 |
+
guidance_zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')
|
93 |
+
|
94 |
+
def preprocess(path, recenter=True, size=256, border_ratio=0.2):
|
95 |
+
files = [path]
|
96 |
+
out_dir = os.path.dirname(path)
|
97 |
+
|
98 |
+
for file in files:
|
99 |
+
|
100 |
+
out_base = os.path.basename(file).split('.')[0]
|
101 |
+
out_rgba = os.path.join(out_dir, out_base + '_rgba.png')
|
102 |
+
|
103 |
+
# load image
|
104 |
+
print(f'[INFO] loading image {file}...')
|
105 |
+
image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
|
106 |
+
|
107 |
+
# carve background
|
108 |
+
print(f'[INFO] background removal...')
|
109 |
+
carved_image = rembg.remove(image, session=session) # [H, W, 4]
|
110 |
+
mask = carved_image[..., -1] > 0
|
111 |
+
|
112 |
+
# recenter
|
113 |
+
if recenter:
|
114 |
+
print(f'[INFO] recenter...')
|
115 |
+
final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
|
116 |
+
|
117 |
+
coords = np.nonzero(mask)
|
118 |
+
x_min, x_max = coords[0].min(), coords[0].max()
|
119 |
+
y_min, y_max = coords[1].min(), coords[1].max()
|
120 |
+
h = x_max - x_min
|
121 |
+
w = y_max - y_min
|
122 |
+
desired_size = int(size * (1 - border_ratio))
|
123 |
+
scale = desired_size / max(h, w)
|
124 |
+
h2 = int(h * scale)
|
125 |
+
w2 = int(w * scale)
|
126 |
+
x2_min = (size - h2) // 2
|
127 |
+
x2_max = x2_min + h2
|
128 |
+
y2_min = (size - w2) // 2
|
129 |
+
y2_max = y2_min + w2
|
130 |
+
final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
|
131 |
+
|
132 |
+
else:
|
133 |
+
final_rgba = carved_image
|
134 |
+
|
135 |
+
# write image
|
136 |
+
cv2.imwrite(out_rgba, final_rgba)
|
137 |
+
|
138 |
+
def gen_vid(input_path, seed, bg='white'):
|
139 |
+
name = input_path.split('/')[-1].split('.')[0]
|
140 |
+
input_dir = os.path.dirname(input_path)
|
141 |
+
height, width = 512, 512
|
142 |
+
|
143 |
+
image = load_image(input_path, width, height, bg)
|
144 |
+
|
145 |
+
generator = torch.manual_seed(seed)
|
146 |
+
# frames = pipe(image, height, width, decode_chunk_size=2, generator=generator).frames[0]
|
147 |
+
frames = pipe(image, height, width, generator=generator).frames[0]
|
148 |
+
|
149 |
+
imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7)
|
150 |
+
os.makedirs(f"{input_dir}/{name}_frames", exist_ok=True)
|
151 |
+
for idx, img in enumerate(frames):
|
152 |
+
img.save(f"{input_dir}/{name}_frames/{idx:03}.png")
|
153 |
+
|
154 |
# check if there is a picture uploaded or selected
|
155 |
def check_img_input(control_image):
|
156 |
if control_image is None:
|
157 |
raise gr.Error("Please select or upload an input image")
|
158 |
|
159 |
# check if there is a picture uploaded or selected
|
160 |
+
def check_video_3d_input(image_block: Image.Image):
|
161 |
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
162 |
if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
|
163 |
raise gr.Error("Please generate a video first")
|
164 |
+
if not os.path.exists(os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')):
|
165 |
+
raise gr.Error("Please generate a 3D first")
|
166 |
+
|
167 |
|
168 |
|
169 |
@spaces.GPU()
|
170 |
+
def optimize_stage_0(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
|
171 |
if not os.path.exists('tmp_data'):
|
172 |
os.makedirs('tmp_data')
|
173 |
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
174 |
+
if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba.png')):
|
175 |
+
if preprocess_chk:
|
176 |
+
# save image to a designated path
|
177 |
+
image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))
|
178 |
|
179 |
+
# preprocess image
|
180 |
+
# print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
|
181 |
+
# subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
|
182 |
+
preprocess(os.path.join("tmp_data", f"{img_hash}.png"))
|
183 |
+
else:
|
184 |
+
image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))
|
185 |
|
186 |
# stage 1
|
187 |
+
# subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True)
|
188 |
+
gen_vid(f'tmp_data/{img_hash}_rgba.png', seed_slider)
|
189 |
# return [os.path.join('logs', 'tmp_rgba_model.ply')]
|
190 |
return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')
|
191 |
|
192 |
+
|
193 |
+
@spaces.GPU()
|
194 |
+
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
|
195 |
+
if not os.path.exists('tmp_data'):
|
196 |
+
os.makedirs('tmp_data')
|
197 |
+
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
198 |
+
if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba.png')):
|
199 |
+
if preprocess_chk:
|
200 |
+
# save image to a designated path
|
201 |
+
image_block.save(os.path.join('tmp_data', f'{img_hash}.png'))
|
202 |
+
|
203 |
+
# preprocess image
|
204 |
+
# print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}')
|
205 |
+
# subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True)
|
206 |
+
preprocess(os.path.join("tmp_data", f"{img_hash}.png"))
|
207 |
+
else:
|
208 |
+
image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png'))
|
209 |
+
|
210 |
+
# stage 1
|
211 |
+
# subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True)
|
212 |
+
process_lgm(opt, f'tmp_data/{img_hash}_rgba.png', pipe_mvdream, model, rays_embeddings)
|
213 |
+
# return [os.path.join('logs', 'tmp_rgba_model.ply')]
|
214 |
+
return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')
|
215 |
+
|
216 |
+
@spaces.GPU(duration=120)
|
217 |
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
|
218 |
img_hash = hashlib.sha256(image_block.tobytes()).hexdigest()
|
219 |
|
220 |
# stage 2
|
221 |
+
# subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True)
|
222 |
+
process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
|
223 |
# os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
|
224 |
image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
|
225 |
# return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')]
|
|
|
237 |
</div>
|
238 |
We present DreamGausssion4D, an efficient 4D generation framework that builds on Gaussian Splatting.
|
239 |
'''
|
240 |
+
_IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D**. Finally, click **Generate 4D**."
|
241 |
|
242 |
# load images in 'data' folder as examples
|
243 |
example_folder = os.path.join(os.path.dirname(__file__), 'data')
|
|
|
258 |
image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
|
259 |
|
260 |
# elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
|
261 |
+
seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (Video)')
|
262 |
+
seed_slider2 = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (3D)')
|
263 |
gr.Markdown(
|
264 |
"random seed for video generation.")
|
265 |
|
|
|
275 |
examples_per_page=40
|
276 |
)
|
277 |
img_run_btn = gr.Button("Generate Video")
|
278 |
+
threed_run_btn = gr.Button("Generate 3D")
|
279 |
fourd_run_btn = gr.Button("Generate 4D")
|
280 |
img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
|
281 |
|
282 |
with gr.Column(scale=5):
|
283 |
+
dirving_video = gr.Video(label="video",height=290)
|
284 |
+
obj3d = gr.Video(label="3D Model",height=290)
|
285 |
obj4d = Model4DGS(label="4D Model", height=500, fps=14)
|
286 |
|
287 |
+
|
288 |
+
img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_0,
|
289 |
inputs=[image_block,
|
290 |
preprocess_chk,
|
291 |
seed_slider],
|
292 |
+
outputs=[
|
293 |
+
dirving_video])
|
294 |
+
|
295 |
+
threed_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_1,
|
296 |
+
inputs=[image_block,
|
297 |
+
preprocess_chk,
|
298 |
+
seed_slider2],
|
299 |
outputs=[
|
300 |
obj3d])
|
301 |
+
fourd_run_btn.click(check_video_3d_input, inputs=[image_block], queue=False).success(optimize_stage_2, inputs=[image_block, seed_slider], outputs=[obj4d])
|
302 |
|
303 |
# demo.queue().launch(share=True)
|
304 |
demo.queue(max_size=10) # <-- Sets up a queue with default parameters
|
configs/4d_demo.yaml
CHANGED
@@ -30,7 +30,7 @@ lambda_svd: 0
|
|
30 |
# training batch size per iter
|
31 |
batch_size: 7
|
32 |
# training iterations for stage 1
|
33 |
-
iters:
|
34 |
# training iterations for stage 2
|
35 |
iters_refine: 50
|
36 |
# training camera radius
|
|
|
30 |
# training batch size per iter
|
31 |
batch_size: 7
|
32 |
# training iterations for stage 1
|
33 |
+
iters: 300
|
34 |
# training iterations for stage 2
|
35 |
iters_refine: 50
|
36 |
# training camera radius
|
lgm/infer_demo.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import tyro
|
4 |
+
import glob
|
5 |
+
import imageio
|
6 |
+
import numpy as np
|
7 |
+
import tqdm
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torchvision.transforms.functional as TF
|
12 |
+
from safetensors.torch import load_file
|
13 |
+
|
14 |
+
import kiui
|
15 |
+
from kiui.op import recenter
|
16 |
+
from kiui.cam import orbit_camera
|
17 |
+
|
18 |
+
from core.options import AllConfigs, Options
|
19 |
+
from core.models import LGM
|
20 |
+
from mvdream.pipeline_mvdream import MVDreamPipeline
|
21 |
+
import cv2
|
22 |
+
|
23 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
24 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
25 |
+
|
26 |
+
# opt = tyro.cli(AllConfigs)
|
27 |
+
|
28 |
+
# # model
|
29 |
+
# model = LGM(opt)
|
30 |
+
|
31 |
+
# # resume pretrained checkpoint
|
32 |
+
# if opt.resume is not None:
|
33 |
+
# if opt.resume.endswith('safetensors'):
|
34 |
+
# ckpt = load_file(opt.resume, device='cpu')
|
35 |
+
# else:
|
36 |
+
# ckpt = torch.load(opt.resume, map_location='cpu')
|
37 |
+
# model.load_state_dict(ckpt, strict=False)
|
38 |
+
# print(f'[INFO] Loaded checkpoint from {opt.resume}')
|
39 |
+
# else:
|
40 |
+
# print(f'[WARN] model randomly initialized, are you sure?')
|
41 |
+
|
42 |
+
# # device
|
43 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
44 |
+
# model = model.half().to(device)
|
45 |
+
# model.eval()
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
# process function
|
50 |
+
def process(opt: Options, path, pipe, model, rays_embeddings):
|
51 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
52 |
+
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
|
53 |
+
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
|
54 |
+
proj_matrix[0, 0] = 1 / tan_half_fov
|
55 |
+
proj_matrix[1, 1] = 1 / tan_half_fov
|
56 |
+
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
57 |
+
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
|
58 |
+
proj_matrix[2, 3] = 1
|
59 |
+
|
60 |
+
|
61 |
+
name = os.path.splitext(os.path.basename(path))[0]
|
62 |
+
print(f'[INFO] Processing {path} --> {name}')
|
63 |
+
os.makedirs('vis_data', exist_ok=True)
|
64 |
+
os.makedirs('logs', exist_ok=True)
|
65 |
+
|
66 |
+
image = kiui.read_image(path, mode='uint8')
|
67 |
+
|
68 |
+
# generate mv
|
69 |
+
image = image.astype(np.float32) / 255.0
|
70 |
+
|
71 |
+
# rgba to rgb white bg
|
72 |
+
if image.shape[-1] == 4:
|
73 |
+
image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
|
74 |
+
|
75 |
+
mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0)
|
76 |
+
mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
|
77 |
+
|
78 |
+
# generate gaussians
|
79 |
+
input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
|
80 |
+
input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
|
81 |
+
input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
82 |
+
|
83 |
+
input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
|
84 |
+
|
85 |
+
with torch.inference_mode():
|
86 |
+
############## align azimuth #####################
|
87 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
88 |
+
# generate gaussians
|
89 |
+
gaussians = model.forward_gaussians(input_image)
|
90 |
+
|
91 |
+
best_azi = 0
|
92 |
+
best_diff = 1e8
|
93 |
+
for v, azi in enumerate(np.arange(-180, 180, 1)):
|
94 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
95 |
+
|
96 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
97 |
+
|
98 |
+
# cameras needed by gaussian rasterizer
|
99 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
100 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
101 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
102 |
+
|
103 |
+
# scale = min(azi / 360, 1)
|
104 |
+
scale = 1
|
105 |
+
|
106 |
+
|
107 |
+
result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)
|
108 |
+
rendered_image = result['image']
|
109 |
+
|
110 |
+
rendered_image = rendered_image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy()
|
111 |
+
rendered_image = cv2.resize(rendered_image, (image.shape[0], image.shape[1]), interpolation=cv2.INTER_AREA)
|
112 |
+
|
113 |
+
diff = np.mean((rendered_image- image) ** 2)
|
114 |
+
|
115 |
+
if diff < best_diff:
|
116 |
+
best_diff = diff
|
117 |
+
best_azi = azi
|
118 |
+
print("Best aligned azimuth: ", best_azi)
|
119 |
+
|
120 |
+
mv_image = []
|
121 |
+
for v, azi in enumerate([0, 90, 180, 270]):
|
122 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
123 |
+
|
124 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
125 |
+
|
126 |
+
# cameras needed by gaussian rasterizer
|
127 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
128 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
129 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
130 |
+
|
131 |
+
# scale = min(azi / 360, 1)
|
132 |
+
scale = 1
|
133 |
+
|
134 |
+
|
135 |
+
result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)
|
136 |
+
rendered_image = result['image']
|
137 |
+
rendered_image = rendered_image.squeeze(1)
|
138 |
+
rendered_image = F.interpolate(rendered_image, (256, 256))
|
139 |
+
rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
|
140 |
+
mv_image.append(rendered_image)
|
141 |
+
mv_image = np.concatenate(mv_image, axis=0)
|
142 |
+
|
143 |
+
input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
|
144 |
+
input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
|
145 |
+
input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
146 |
+
|
147 |
+
input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
|
148 |
+
|
149 |
+
################################
|
150 |
+
|
151 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
152 |
+
# generate gaussians
|
153 |
+
gaussians = model.forward_gaussians(input_image)
|
154 |
+
|
155 |
+
# save gaussians
|
156 |
+
model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))
|
157 |
+
|
158 |
+
# render 360 video
|
159 |
+
images = []
|
160 |
+
elevation = 0
|
161 |
+
|
162 |
+
if opt.fancy_video:
|
163 |
+
|
164 |
+
azimuth = np.arange(0, 720, 4, dtype=np.int32)
|
165 |
+
for azi in tqdm.tqdm(azimuth):
|
166 |
+
|
167 |
+
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
168 |
+
|
169 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
170 |
+
|
171 |
+
# cameras needed by gaussian rasterizer
|
172 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
173 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
174 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
175 |
+
|
176 |
+
scale = min(azi / 360, 1)
|
177 |
+
|
178 |
+
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
|
179 |
+
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
180 |
+
else:
|
181 |
+
azimuth = np.arange(0, 360, 2, dtype=np.int32)
|
182 |
+
for azi in tqdm.tqdm(azimuth):
|
183 |
+
|
184 |
+
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
185 |
+
|
186 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
187 |
+
|
188 |
+
# cameras needed by gaussian rasterizer
|
189 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
190 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
191 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
192 |
+
|
193 |
+
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
|
194 |
+
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
195 |
+
|
196 |
+
images = np.concatenate(images, axis=0)
|
197 |
+
imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
|
main_4d_demo.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import tqdm
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import rembg
|
11 |
+
|
12 |
+
from cam_utils import orbit_camera, OrbitCamera
|
13 |
+
from gs_renderer_4d import Renderer, MiniCam
|
14 |
+
|
15 |
+
from grid_put import mipmap_linear_grid_put_2d
|
16 |
+
import imageio
|
17 |
+
|
18 |
+
import copy
|
19 |
+
from omegaconf import OmegaConf
|
20 |
+
|
21 |
+
class GUI:
|
22 |
+
def __init__(self, opt, guidance_zero123):
|
23 |
+
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
24 |
+
self.gui = opt.gui # enable gui
|
25 |
+
self.W = opt.W
|
26 |
+
self.H = opt.H
|
27 |
+
self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
|
28 |
+
|
29 |
+
self.mode = "image"
|
30 |
+
# self.seed = "random"
|
31 |
+
self.seed = 888
|
32 |
+
|
33 |
+
self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
|
34 |
+
self.need_update = True # update buffer_image
|
35 |
+
|
36 |
+
# models
|
37 |
+
self.device = torch.device("cuda")
|
38 |
+
self.bg_remover = None
|
39 |
+
|
40 |
+
self.guidance_sd = None
|
41 |
+
self.guidance_zero123 = guidance_zero123
|
42 |
+
self.guidance_svd = None
|
43 |
+
|
44 |
+
|
45 |
+
self.enable_sd = False
|
46 |
+
self.enable_zero123 = False
|
47 |
+
self.enable_svd = False
|
48 |
+
|
49 |
+
|
50 |
+
# renderer
|
51 |
+
self.renderer = Renderer(self.opt, sh_degree=self.opt.sh_degree)
|
52 |
+
self.gaussain_scale_factor = 1
|
53 |
+
|
54 |
+
# input image
|
55 |
+
self.input_img = None
|
56 |
+
self.input_mask = None
|
57 |
+
self.input_img_torch = None
|
58 |
+
self.input_mask_torch = None
|
59 |
+
self.overlay_input_img = False
|
60 |
+
self.overlay_input_img_ratio = 0.5
|
61 |
+
|
62 |
+
self.input_img_list = None
|
63 |
+
self.input_mask_list = None
|
64 |
+
self.input_img_torch_list = None
|
65 |
+
self.input_mask_torch_list = None
|
66 |
+
|
67 |
+
# input text
|
68 |
+
self.prompt = ""
|
69 |
+
self.negative_prompt = ""
|
70 |
+
|
71 |
+
# training stuff
|
72 |
+
self.training = False
|
73 |
+
self.optimizer = None
|
74 |
+
self.step = 0
|
75 |
+
self.train_steps = 1 # steps per rendering loop
|
76 |
+
|
77 |
+
# load input data from cmdline
|
78 |
+
if self.opt.input is not None: # True
|
79 |
+
self.load_input(self.opt.input) # load imgs, if has bg, then rm bg; or just load imgs
|
80 |
+
|
81 |
+
# override prompt from cmdline
|
82 |
+
if self.opt.prompt is not None: # None
|
83 |
+
self.prompt = self.opt.prompt
|
84 |
+
|
85 |
+
# override if provide a checkpoint
|
86 |
+
if self.opt.load is not None: # not None
|
87 |
+
self.renderer.initialize(self.opt.load)
|
88 |
+
# self.renderer.gaussians.load_model(opt.outdir, opt.save_path)
|
89 |
+
else:
|
90 |
+
# initialize gaussians to a blob
|
91 |
+
self.renderer.initialize(num_pts=self.opt.num_pts)
|
92 |
+
|
93 |
+
self.seed_everything()
|
94 |
+
|
95 |
+
def seed_everything(self):
|
96 |
+
try:
|
97 |
+
seed = int(self.seed)
|
98 |
+
except:
|
99 |
+
seed = np.random.randint(0, 1000000)
|
100 |
+
|
101 |
+
print(f'Seed: {seed:d}')
|
102 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
103 |
+
np.random.seed(seed)
|
104 |
+
torch.manual_seed(seed)
|
105 |
+
torch.cuda.manual_seed(seed)
|
106 |
+
torch.backends.cudnn.deterministic = True
|
107 |
+
torch.backends.cudnn.benchmark = True
|
108 |
+
|
109 |
+
self.last_seed = seed
|
110 |
+
|
111 |
+
def prepare_train(self):
|
112 |
+
|
113 |
+
self.step = 0
|
114 |
+
|
115 |
+
# setup training
|
116 |
+
self.renderer.gaussians.training_setup(self.opt)
|
117 |
+
|
118 |
+
# # do not do progressive sh-level
|
119 |
+
self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
|
120 |
+
self.optimizer = self.renderer.gaussians.optimizer
|
121 |
+
|
122 |
+
# default camera
|
123 |
+
if self.opt.mvdream or self.opt.imagedream:
|
124 |
+
# the second view is the front view for mvdream/imagedream.
|
125 |
+
pose = orbit_camera(self.opt.elevation, 90, self.opt.radius)
|
126 |
+
else:
|
127 |
+
pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
|
128 |
+
self.fixed_cam = MiniCam(
|
129 |
+
pose,
|
130 |
+
self.opt.ref_size,
|
131 |
+
self.opt.ref_size,
|
132 |
+
self.cam.fovy,
|
133 |
+
self.cam.fovx,
|
134 |
+
self.cam.near,
|
135 |
+
self.cam.far,
|
136 |
+
)
|
137 |
+
|
138 |
+
self.enable_sd = self.opt.lambda_sd > 0
|
139 |
+
self.enable_zero123 = self.opt.lambda_zero123 > 0
|
140 |
+
self.enable_svd = self.opt.lambda_svd > 0 and self.input_img is not None
|
141 |
+
|
142 |
+
# lazy load guidance model
|
143 |
+
if self.guidance_sd is None and self.enable_sd:
|
144 |
+
if self.opt.mvdream:
|
145 |
+
print(f"[INFO] loading MVDream...")
|
146 |
+
from guidance.mvdream_utils import MVDream
|
147 |
+
self.guidance_sd = MVDream(self.device)
|
148 |
+
print(f"[INFO] loaded MVDream!")
|
149 |
+
elif self.opt.imagedream:
|
150 |
+
print(f"[INFO] loading ImageDream...")
|
151 |
+
from guidance.imagedream_utils import ImageDream
|
152 |
+
self.guidance_sd = ImageDream(self.device)
|
153 |
+
print(f"[INFO] loaded ImageDream!")
|
154 |
+
else:
|
155 |
+
print(f"[INFO] loading SD...")
|
156 |
+
from guidance.sd_utils import StableDiffusion
|
157 |
+
self.guidance_sd = StableDiffusion(self.device)
|
158 |
+
print(f"[INFO] loaded SD!")
|
159 |
+
|
160 |
+
if self.guidance_zero123 is None and self.enable_zero123:
|
161 |
+
print(f"[INFO] loading zero123...")
|
162 |
+
from guidance.zero123_utils import Zero123
|
163 |
+
if self.opt.stable_zero123:
|
164 |
+
self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/stable-zero123-diffusers')
|
165 |
+
else:
|
166 |
+
self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
|
167 |
+
print(f"[INFO] loaded zero123!")
|
168 |
+
|
169 |
+
if self.guidance_svd is None and self.enable_svd: # False
|
170 |
+
print(f"[INFO] loading SVD...")
|
171 |
+
from guidance.svd_utils import StableVideoDiffusion
|
172 |
+
self.guidance_svd = StableVideoDiffusion(self.device)
|
173 |
+
print(f"[INFO] loaded SVD!")
|
174 |
+
|
175 |
+
# input image
|
176 |
+
if self.input_img is not None:
|
177 |
+
self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
|
178 |
+
self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
|
179 |
+
|
180 |
+
self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
|
181 |
+
self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
|
182 |
+
|
183 |
+
if self.input_img_list is not None:
|
184 |
+
self.input_img_torch_list = [torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_img in self.input_img_list]
|
185 |
+
self.input_img_torch_list = [F.interpolate(input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_img_torch in self.input_img_torch_list]
|
186 |
+
|
187 |
+
self.input_mask_torch_list = [torch.from_numpy(input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_mask in self.input_mask_list]
|
188 |
+
self.input_mask_torch_list = [F.interpolate(input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_mask_torch in self.input_mask_torch_list]
|
189 |
+
# prepare embeddings
|
190 |
+
with torch.no_grad():
|
191 |
+
|
192 |
+
if self.enable_sd:
|
193 |
+
if self.opt.imagedream:
|
194 |
+
img_pos_list, img_neg_list, ip_pos_list, ip_neg_list, emb_pos_list, emb_neg_list = [], [], [], [], [], []
|
195 |
+
for _ in range(self.opt.n_views):
|
196 |
+
for input_img_torch in self.input_img_torch_list:
|
197 |
+
img_pos, img_neg, ip_pos, ip_neg, emb_pos, emb_neg = self.guidance_sd.get_image_text_embeds(input_img_torch, [self.prompt], [self.negative_prompt])
|
198 |
+
img_pos_list.append(img_pos)
|
199 |
+
img_neg_list.append(img_neg)
|
200 |
+
ip_pos_list.append(ip_pos)
|
201 |
+
ip_neg_list.append(ip_neg)
|
202 |
+
emb_pos_list.append(emb_pos)
|
203 |
+
emb_neg_list.append(emb_neg)
|
204 |
+
self.guidance_sd.image_embeddings['pos'] = torch.cat(img_pos_list, 0)
|
205 |
+
self.guidance_sd.image_embeddings['neg'] = torch.cat(img_pos_list, 0)
|
206 |
+
self.guidance_sd.image_embeddings['ip_img'] = torch.cat(ip_pos_list, 0)
|
207 |
+
self.guidance_sd.image_embeddings['neg_ip_img'] = torch.cat(ip_neg_list, 0)
|
208 |
+
self.guidance_sd.embeddings['pos'] = torch.cat(emb_pos_list, 0)
|
209 |
+
self.guidance_sd.embeddings['neg'] = torch.cat(emb_neg_list, 0)
|
210 |
+
else:
|
211 |
+
self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])
|
212 |
+
|
213 |
+
if self.enable_zero123:
|
214 |
+
c_list, v_list = [], []
|
215 |
+
for _ in range(self.opt.n_views):
|
216 |
+
for input_img_torch in self.input_img_torch_list:
|
217 |
+
c, v = self.guidance_zero123.get_img_embeds(input_img_torch)
|
218 |
+
c_list.append(c)
|
219 |
+
v_list.append(v)
|
220 |
+
self.guidance_zero123.embeddings = [torch.cat(c_list, 0), torch.cat(v_list, 0)]
|
221 |
+
|
222 |
+
if self.enable_svd:
|
223 |
+
self.guidance_svd.get_img_embeds(self.input_img)
|
224 |
+
|
225 |
+
def train_step(self):
|
226 |
+
starter = torch.cuda.Event(enable_timing=True)
|
227 |
+
ender = torch.cuda.Event(enable_timing=True)
|
228 |
+
starter.record()
|
229 |
+
|
230 |
+
for _ in range(self.train_steps): # 1
|
231 |
+
|
232 |
+
self.step += 1 # self.step starts from 0
|
233 |
+
step_ratio = min(1, self.step / self.opt.iters) # 1, step / 500
|
234 |
+
|
235 |
+
# update lr
|
236 |
+
self.renderer.gaussians.update_learning_rate(self.step)
|
237 |
+
|
238 |
+
loss = 0
|
239 |
+
|
240 |
+
self.renderer.prepare_render()
|
241 |
+
|
242 |
+
### known view
|
243 |
+
if not self.opt.imagedream:
|
244 |
+
for b_idx in range(self.opt.batch_size):
|
245 |
+
cur_cam = copy.deepcopy(self.fixed_cam)
|
246 |
+
cur_cam.time = b_idx
|
247 |
+
out = self.renderer.render(cur_cam)
|
248 |
+
|
249 |
+
# rgb loss
|
250 |
+
image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
|
251 |
+
loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch_list[b_idx]) / self.opt.batch_size
|
252 |
+
|
253 |
+
# mask loss
|
254 |
+
mask = out["alpha"].unsqueeze(0) # [1, 1, H, W] in [0, 1]
|
255 |
+
loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch_list[b_idx]) / self.opt.batch_size
|
256 |
+
|
257 |
+
### novel view (manual batch)
|
258 |
+
render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
|
259 |
+
# render_resolution = 512
|
260 |
+
images = []
|
261 |
+
poses = []
|
262 |
+
vers, hors, radii = [], [], []
|
263 |
+
# avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
|
264 |
+
min_ver = max(min(self.opt.min_ver, self.opt.min_ver - self.opt.elevation), -80 - self.opt.elevation)
|
265 |
+
max_ver = min(max(self.opt.max_ver, self.opt.max_ver - self.opt.elevation), 80 - self.opt.elevation)
|
266 |
+
|
267 |
+
for _ in range(self.opt.n_views):
|
268 |
+
for b_idx in range(self.opt.batch_size):
|
269 |
+
|
270 |
+
# render random view
|
271 |
+
ver = np.random.randint(min_ver, max_ver)
|
272 |
+
hor = np.random.randint(-180, 180)
|
273 |
+
radius = 0
|
274 |
+
|
275 |
+
vers.append(ver)
|
276 |
+
hors.append(hor)
|
277 |
+
radii.append(radius)
|
278 |
+
|
279 |
+
pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
|
280 |
+
poses.append(pose)
|
281 |
+
|
282 |
+
cur_cam = MiniCam(pose, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, time=b_idx)
|
283 |
+
|
284 |
+
bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda")
|
285 |
+
out = self.renderer.render(cur_cam, bg_color=bg_color)
|
286 |
+
|
287 |
+
image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
|
288 |
+
images.append(image)
|
289 |
+
|
290 |
+
# enable mvdream training
|
291 |
+
if self.opt.mvdream or self.opt.imagedream: # False
|
292 |
+
for view_i in range(1, 4):
|
293 |
+
pose_i = orbit_camera(self.opt.elevation + ver, hor + 90 * view_i, self.opt.radius + radius)
|
294 |
+
poses.append(pose_i)
|
295 |
+
|
296 |
+
cur_cam_i = MiniCam(pose_i, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far)
|
297 |
+
|
298 |
+
# bg_color = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device="cuda")
|
299 |
+
out_i = self.renderer.render(cur_cam_i, bg_color=bg_color)
|
300 |
+
|
301 |
+
image = out_i["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
|
302 |
+
images.append(image)
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
images = torch.cat(images, dim=0)
|
307 |
+
poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)
|
308 |
+
|
309 |
+
# guidance loss
|
310 |
+
if self.enable_sd:
|
311 |
+
if self.opt.mvdream or self.opt.imagedream:
|
312 |
+
loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, poses, step_ratio)
|
313 |
+
else:
|
314 |
+
loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio)
|
315 |
+
|
316 |
+
if self.enable_zero123:
|
317 |
+
loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio) / (self.opt.batch_size * self.opt.n_views)
|
318 |
+
|
319 |
+
if self.enable_svd:
|
320 |
+
loss = loss + self.opt.lambda_svd * self.guidance_svd.train_step(images, step_ratio)
|
321 |
+
|
322 |
+
# optimize step
|
323 |
+
loss.backward()
|
324 |
+
self.optimizer.step()
|
325 |
+
self.optimizer.zero_grad()
|
326 |
+
|
327 |
+
# densify and prune
|
328 |
+
if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
|
329 |
+
viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
|
330 |
+
self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
331 |
+
self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
|
332 |
+
|
333 |
+
if self.step % self.opt.densification_interval == 0:
|
334 |
+
# size_threshold = 20 if self.step > self.opt.opacity_reset_interval else None
|
335 |
+
self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=0.5, max_screen_size=1)
|
336 |
+
|
337 |
+
if self.step % self.opt.opacity_reset_interval == 0:
|
338 |
+
self.renderer.gaussians.reset_opacity()
|
339 |
+
|
340 |
+
ender.record()
|
341 |
+
torch.cuda.synchronize()
|
342 |
+
t = starter.elapsed_time(ender)
|
343 |
+
|
344 |
+
self.need_update = True
|
345 |
+
|
346 |
+
|
347 |
+
def load_input(self, file):
|
348 |
+
if self.opt.data_mode == 'c4d':
|
349 |
+
file_list = [os.path.join(file, f'{x * self.opt.downsample_rate}.png') for x in range(self.opt.batch_size)]
|
350 |
+
elif self.opt.data_mode == 'svd':
|
351 |
+
# file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}_rgba.png') for x in range(self.opt.batch_size)]
|
352 |
+
# file_list = [x if os.path.exists(x) else (x.replace('_rgba.png', '.png')) for x in file_list]
|
353 |
+
file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}.png') for x in range(self.opt.batch_size)]
|
354 |
+
else:
|
355 |
+
raise NotImplementedError
|
356 |
+
self.input_img_list, self.input_mask_list = [], []
|
357 |
+
for file in file_list:
|
358 |
+
# load image
|
359 |
+
print(f'[INFO] load image from {file}...')
|
360 |
+
img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
|
361 |
+
if img.shape[-1] == 3:
|
362 |
+
if self.bg_remover is None:
|
363 |
+
self.bg_remover = rembg.new_session()
|
364 |
+
img = rembg.remove(img, session=self.bg_remover)
|
365 |
+
# cv2.imwrite(file.replace('.png', '_rgba.png'), img)
|
366 |
+
img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
|
367 |
+
img = img.astype(np.float32) / 255.0
|
368 |
+
input_mask = img[..., 3:]
|
369 |
+
# white bg
|
370 |
+
input_img = img[..., :3] * input_mask + (1 - input_mask)
|
371 |
+
# bgr to rgb
|
372 |
+
input_img = input_img[..., ::-1].copy()
|
373 |
+
self.input_img_list.append(input_img)
|
374 |
+
self.input_mask_list.append(input_mask)
|
375 |
+
|
376 |
+
@torch.no_grad()
|
377 |
+
def save_model(self, mode='geo', texture_size=1024, interp=1):
|
378 |
+
os.makedirs(self.opt.outdir, exist_ok=True)
|
379 |
+
if mode == 'geo':
|
380 |
+
path = f'logs/{opt.save_path}_mesh_{t:03d}.ply'
|
381 |
+
mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t)
|
382 |
+
mesh.write_ply(path)
|
383 |
+
|
384 |
+
elif mode == 'geo+tex':
|
385 |
+
from mesh import Mesh, safe_normalize
|
386 |
+
os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_meshes'), exist_ok=True)
|
387 |
+
for t in range(self.opt.batch_size):
|
388 |
+
path = os.path.join(self.opt.outdir, self.opt.save_path+'_meshes', f'{t:03d}.obj')
|
389 |
+
mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t)
|
390 |
+
|
391 |
+
# perform texture extraction
|
392 |
+
print(f"[INFO] unwrap uv...")
|
393 |
+
h = w = texture_size
|
394 |
+
mesh.auto_uv()
|
395 |
+
mesh.auto_normal()
|
396 |
+
|
397 |
+
albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
|
398 |
+
cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)
|
399 |
+
|
400 |
+
vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
|
401 |
+
hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]
|
402 |
+
|
403 |
+
render_resolution = 512
|
404 |
+
|
405 |
+
import nvdiffrast.torch as dr
|
406 |
+
|
407 |
+
if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
|
408 |
+
glctx = dr.RasterizeGLContext()
|
409 |
+
else:
|
410 |
+
glctx = dr.RasterizeCudaContext()
|
411 |
+
|
412 |
+
for ver, hor in zip(vers, hors):
|
413 |
+
# render image
|
414 |
+
pose = orbit_camera(ver, hor, self.cam.radius)
|
415 |
+
|
416 |
+
cur_cam = MiniCam(
|
417 |
+
pose,
|
418 |
+
render_resolution,
|
419 |
+
render_resolution,
|
420 |
+
self.cam.fovy,
|
421 |
+
self.cam.fovx,
|
422 |
+
self.cam.near,
|
423 |
+
self.cam.far,
|
424 |
+
time=t
|
425 |
+
)
|
426 |
+
|
427 |
+
cur_out = self.renderer.render(cur_cam)
|
428 |
+
|
429 |
+
rgbs = cur_out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
|
430 |
+
|
431 |
+
# get coordinate in texture image
|
432 |
+
pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
|
433 |
+
proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)
|
434 |
+
|
435 |
+
v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
|
436 |
+
v_clip = v_cam @ proj.T
|
437 |
+
rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))
|
438 |
+
|
439 |
+
depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f) # [1, H, W, 1]
|
440 |
+
depth = depth.squeeze(0) # [H, W, 1]
|
441 |
+
|
442 |
+
alpha = (rast[0, ..., 3:] > 0).float()
|
443 |
+
|
444 |
+
uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft) # [1, 512, 512, 2] in [0, 1]
|
445 |
+
|
446 |
+
# use normal to produce a back-project mask
|
447 |
+
normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
|
448 |
+
normal = safe_normalize(normal[0])
|
449 |
+
|
450 |
+
# rotated normal (where [0, 0, 1] always faces camera)
|
451 |
+
rot_normal = normal @ pose[:3, :3]
|
452 |
+
viewcos = rot_normal[..., [2]]
|
453 |
+
|
454 |
+
mask = (alpha > 0) & (viewcos > 0.5) # [H, W, 1]
|
455 |
+
mask = mask.view(-1)
|
456 |
+
|
457 |
+
uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
|
458 |
+
rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()
|
459 |
+
|
460 |
+
# update texture image
|
461 |
+
cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
|
462 |
+
h, w,
|
463 |
+
uvs[..., [1, 0]] * 2 - 1,
|
464 |
+
rgbs,
|
465 |
+
min_resolution=256,
|
466 |
+
return_count=True,
|
467 |
+
)
|
468 |
+
|
469 |
+
mask = cnt.squeeze(-1) < 0.1
|
470 |
+
albedo[mask] += cur_albedo[mask]
|
471 |
+
cnt[mask] += cur_cnt[mask]
|
472 |
+
|
473 |
+
mask = cnt.squeeze(-1) > 0
|
474 |
+
albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)
|
475 |
+
|
476 |
+
mask = mask.view(h, w)
|
477 |
+
|
478 |
+
albedo = albedo.detach().cpu().numpy()
|
479 |
+
mask = mask.detach().cpu().numpy()
|
480 |
+
|
481 |
+
# dilate texture
|
482 |
+
from sklearn.neighbors import NearestNeighbors
|
483 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
484 |
+
|
485 |
+
inpaint_region = binary_dilation(mask, iterations=32)
|
486 |
+
inpaint_region[mask] = 0
|
487 |
+
|
488 |
+
search_region = mask.copy()
|
489 |
+
not_search_region = binary_erosion(search_region, iterations=3)
|
490 |
+
search_region[not_search_region] = 0
|
491 |
+
|
492 |
+
search_coords = np.stack(np.nonzero(search_region), axis=-1)
|
493 |
+
inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
|
494 |
+
|
495 |
+
knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
|
496 |
+
search_coords
|
497 |
+
)
|
498 |
+
_, indices = knn.kneighbors(inpaint_coords)
|
499 |
+
|
500 |
+
albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]
|
501 |
+
|
502 |
+
mesh.albedo = torch.from_numpy(albedo).to(self.device)
|
503 |
+
mesh.write(path)
|
504 |
+
|
505 |
+
|
506 |
+
elif mode == 'frames':
|
507 |
+
os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_frames'), exist_ok=True)
|
508 |
+
for t in range(self.opt.batch_size * interp):
|
509 |
+
tt = t / interp
|
510 |
+
path = os.path.join(self.opt.outdir, self.opt.save_path+'_frames', f'{t:03d}.ply')
|
511 |
+
self.renderer.gaussians.save_frame_ply(path, tt)
|
512 |
+
else:
|
513 |
+
path = os.path.join(self.opt.outdir, self.opt.save_path + '_4d_model.ply')
|
514 |
+
self.renderer.gaussians.save_ply(path)
|
515 |
+
self.renderer.gaussians.save_deformation(self.opt.outdir, self.opt.save_path)
|
516 |
+
|
517 |
+
print(f"[INFO] save model to {path}.")
|
518 |
+
|
519 |
+
# no gui mode
|
520 |
+
def train(self, iters=500, ui=False):
|
521 |
+
if self.gui:
|
522 |
+
from visualizer.visergui import ViserViewer
|
523 |
+
self.viser_gui = ViserViewer(device="cuda", viewer_port=8080)
|
524 |
+
if iters > 0:
|
525 |
+
self.prepare_train()
|
526 |
+
if self.gui:
|
527 |
+
self.viser_gui.set_renderer(self.renderer, self.fixed_cam)
|
528 |
+
|
529 |
+
for i in tqdm.trange(iters):
|
530 |
+
self.train_step()
|
531 |
+
if self.gui:
|
532 |
+
self.viser_gui.update()
|
533 |
+
if self.opt.mesh_format == 'frames':
|
534 |
+
self.save_model(mode='frames', interp=4)
|
535 |
+
elif self.opt.mesh_format == 'obj':
|
536 |
+
self.save_model(mode='geo+tex')
|
537 |
+
|
538 |
+
if self.opt.save_model:
|
539 |
+
self.save_model(mode='model')
|
540 |
+
|
541 |
+
# render eval
|
542 |
+
image_list =[]
|
543 |
+
nframes = self.opt.batch_size * 7 + 15 * 7
|
544 |
+
hor = 180
|
545 |
+
delta_hor = 45 / 15
|
546 |
+
delta_time = 1
|
547 |
+
for i in range(8):
|
548 |
+
time = 0
|
549 |
+
for j in range(self.opt.batch_size + 15):
|
550 |
+
pose = orbit_camera(self.opt.elevation, hor-180, self.opt.radius)
|
551 |
+
cur_cam = MiniCam(
|
552 |
+
pose,
|
553 |
+
512,
|
554 |
+
512,
|
555 |
+
self.cam.fovy,
|
556 |
+
self.cam.fovx,
|
557 |
+
self.cam.near,
|
558 |
+
self.cam.far,
|
559 |
+
time=time
|
560 |
+
)
|
561 |
+
with torch.no_grad():
|
562 |
+
outputs = self.renderer.render(cur_cam)
|
563 |
+
|
564 |
+
out = outputs["image"].cpu().detach().numpy().astype(np.float32)
|
565 |
+
out = np.transpose(out, (1, 2, 0))
|
566 |
+
out = np.uint8(out*255)
|
567 |
+
image_list.append(out)
|
568 |
+
|
569 |
+
time = (time + delta_time) % self.opt.batch_size
|
570 |
+
if j >= self.opt.batch_size:
|
571 |
+
hor = (hor+delta_hor) % 360
|
572 |
+
|
573 |
+
|
574 |
+
imageio.mimwrite(f'vis_data/{opt.save_path}.mp4', image_list, fps=7)
|
575 |
+
|
576 |
+
if self.gui:
|
577 |
+
while True:
|
578 |
+
self.viser_gui.update()
|
579 |
+
|
580 |
+
def process(config, input_path, guidance):
|
581 |
+
# override default config from cli
|
582 |
+
opt = OmegaConf.load(config)
|
583 |
+
opt.input = input_path
|
584 |
+
opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path
|
585 |
+
|
586 |
+
|
587 |
+
# auto find mesh from stage 1
|
588 |
+
opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply')
|
589 |
+
|
590 |
+
gui = GUI(opt, guidance)
|
591 |
+
|
592 |
+
gui.train(opt.iters)
|
593 |
+
|
594 |
+
|
595 |
+
if __name__ == "__main__":
|
596 |
+
import argparse
|
597 |
+
from omegaconf import OmegaConf
|
598 |
+
|
599 |
+
parser = argparse.ArgumentParser()
|
600 |
+
parser.add_argument("--config", required=True, help="path to the yaml config file")
|
601 |
+
args, extras = parser.parse_known_args()
|
602 |
+
|
603 |
+
# override default config from cli
|
604 |
+
opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
|
605 |
+
opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path
|
606 |
+
|
607 |
+
|
608 |
+
# auto find mesh from stage 1
|
609 |
+
opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply')
|
610 |
+
|
611 |
+
gui = GUI(opt)
|
612 |
+
|
613 |
+
gui.train(opt.iters)
|
614 |
+
|
615 |
+
|
616 |
+
# python main_4d.py --config configs/4d_low.yaml input=data/CONSISTENT4D_DATA/in-the-wild/blooming_rose
|