Spaces:
Running
on
L4
Running
on
L4
FrozenBurning
commited on
Commit
•
cb029d0
1
Parent(s):
e3760d0
Update fast uv unwrap
Browse files- app.py +15 -14
- configs/inference_dit.yml +1 -0
- inference.py +26 -25
- requirements.txt +2 -1
- utils/uv_unwrap.py +685 -0
app.py
CHANGED
@@ -139,23 +139,23 @@ def process(input_cond, input_num_steps, input_seed=42, input_cfg=6.0):
|
|
139 |
recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
|
140 |
visualize_video_primvolume(config.output_dir, batch, recon_param, 15, rm, device)
|
141 |
prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
|
142 |
-
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
def export_mesh(remesh="No", mc_resolution=256, decimate=100000):
|
147 |
# exporting GLB mesh
|
148 |
output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
|
149 |
if remesh == "No":
|
150 |
config.inference.remesh = False
|
151 |
elif remesh == "Yes":
|
152 |
config.inference.remesh = True
|
|
|
|
|
|
|
|
|
153 |
config.inference.decimate = decimate
|
154 |
config.inference.mc_resolution = mc_resolution
|
155 |
config.inference.batch_size = 8192
|
156 |
-
|
157 |
-
primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
|
158 |
-
model_primx.load_state_dict(primx_ckpt_weight)
|
159 |
model_primx.to(device)
|
160 |
model_primx.eval()
|
161 |
with torch.no_grad():
|
@@ -179,6 +179,7 @@ _DESCRIPTION = '''
|
|
179 |
block = gr.Blocks(title=_TITLE).queue()
|
180 |
with block:
|
181 |
current_fg_state = gr.State()
|
|
|
182 |
with gr.Row():
|
183 |
with gr.Column(scale=1):
|
184 |
gr.Markdown('# ' + _TITLE)
|
@@ -192,17 +193,17 @@ with block:
|
|
192 |
# background removal
|
193 |
removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
|
194 |
# inference steps
|
195 |
-
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
|
196 |
# random seed
|
197 |
input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=0.5, value=6, info="Typically CFG in a range of 4-7")
|
198 |
# random seed
|
199 |
input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
|
200 |
-
# gen button
|
201 |
-
button_gen = gr.Button(value="Generate", interactive=False)
|
202 |
with gr.Row():
|
203 |
-
input_mc_resolution = gr.Radio(choices=[64, 128, 256], label="MC Resolution", value=128, info="Cube resolution for mesh extraction")
|
204 |
input_remesh = gr.Radio(choices=["No", "Yes"], label="Remesh", value="No", info="Remesh or not?")
|
205 |
-
|
|
|
|
|
206 |
|
207 |
with gr.Column(scale=1):
|
208 |
with gr.Row():
|
@@ -246,8 +247,8 @@ with block:
|
|
246 |
)
|
247 |
|
248 |
input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
|
249 |
-
button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn])
|
250 |
-
|
251 |
|
252 |
gr.Examples(
|
253 |
examples=[
|
|
|
139 |
recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
|
140 |
visualize_video_primvolume(config.output_dir, batch, recon_param, 15, rm, device)
|
141 |
prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
|
142 |
+
return output_rgb_video_path, output_prim_video_path, output_mat_video_path, gr.update(interactive=True), prim_params
|
143 |
|
144 |
+
def export_mesh(prim_params, uv_unwrap="Faster", remesh="No", mc_resolution=256, decimate=100000):
|
|
|
|
|
145 |
# exporting GLB mesh
|
146 |
output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
|
147 |
if remesh == "No":
|
148 |
config.inference.remesh = False
|
149 |
elif remesh == "Yes":
|
150 |
config.inference.remesh = True
|
151 |
+
if uv_unwrap == "Faster":
|
152 |
+
config.inference.fast_unwrap = True
|
153 |
+
elif uv_unwrap == "Better":
|
154 |
+
config.inference.fast_unwrap = False
|
155 |
config.inference.decimate = decimate
|
156 |
config.inference.mc_resolution = mc_resolution
|
157 |
config.inference.batch_size = 8192
|
158 |
+
model_primx.load_state_dict(prim_params)
|
|
|
|
|
159 |
model_primx.to(device)
|
160 |
model_primx.eval()
|
161 |
with torch.no_grad():
|
|
|
179 |
block = gr.Blocks(title=_TITLE).queue()
|
180 |
with block:
|
181 |
current_fg_state = gr.State()
|
182 |
+
prim_param_state = gr.State()
|
183 |
with gr.Row():
|
184 |
with gr.Column(scale=1):
|
185 |
gr.Markdown('# ' + _TITLE)
|
|
|
193 |
# background removal
|
194 |
removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
|
195 |
# inference steps
|
196 |
+
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps. Larger for robustness but slower.", value=25)
|
197 |
# random seed
|
198 |
input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=0.5, value=6, info="Typically CFG in a range of 4-7")
|
199 |
# random seed
|
200 |
input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
|
|
|
|
|
201 |
with gr.Row():
|
202 |
+
input_mc_resolution = gr.Radio(choices=[64, 128, 256], label="MC Resolution", value=128, info="Cube resolution for mesh extraction. Larger for better quality but slower.")
|
203 |
input_remesh = gr.Radio(choices=["No", "Yes"], label="Remesh", value="No", info="Remesh or not?")
|
204 |
+
input_unwrap = gr.Radio(choices=["Faster", "Better"], label="UV", value="Faster", info="UV unwrapping algorithm. Trade-off between quality and speed.")
|
205 |
+
# gen button
|
206 |
+
button_gen = gr.Button(value="Generate", interactive=False)
|
207 |
|
208 |
with gr.Column(scale=1):
|
209 |
with gr.Row():
|
|
|
247 |
)
|
248 |
|
249 |
input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
|
250 |
+
button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn, prim_param_state])
|
251 |
+
prim_param_state.change(export_mesh, inputs=[prim_param_state, input_unwrap, input_remesh, input_mc_resolution], outputs=[output_glb, hdr_row])
|
252 |
|
253 |
gr.Examples(
|
254 |
examples=[
|
configs/inference_dit.yml
CHANGED
@@ -10,6 +10,7 @@ inference:
|
|
10 |
seed: ${global_seed}
|
11 |
precision: fp16
|
12 |
export_glb: True
|
|
|
13 |
decimate: 100000
|
14 |
mc_resolution: 256
|
15 |
batch_size: 4096
|
|
|
10 |
seed: ${global_seed}
|
11 |
precision: fp16
|
12 |
export_glb: True
|
13 |
+
fast_unwrap: False
|
14 |
decimate: 100000
|
15 |
mc_resolution: 256
|
16 |
batch_size: 4096
|
inference.py
CHANGED
@@ -25,8 +25,9 @@ from scipy.ndimage import binary_dilation, binary_erosion
|
|
25 |
from sklearn.neighbors import NearestNeighbors
|
26 |
from utils.meshutils import clean_mesh, decimate_mesh
|
27 |
from utils.mesh import Mesh
|
|
|
28 |
logger = logging.getLogger("inference.py")
|
29 |
-
|
30 |
|
31 |
def remove_background(image: PIL.Image.Image,
|
32 |
rembg_session = None,
|
@@ -114,22 +115,31 @@ def extract_texmesh(args, model, output_path, device):
|
|
114 |
w0 = 1024
|
115 |
ssaa = 1
|
116 |
fp16 = True
|
117 |
-
glctx = dr.RasterizeCudaContext()
|
118 |
v_np = vertices.astype(np.float32)
|
119 |
f_np = triangles.astype(np.int64)
|
120 |
v = torch.from_numpy(vertices).float().contiguous().to(device)
|
121 |
-
f = torch.from_numpy(triangles.astype(np.int64)).
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device)
|
135 |
ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device)
|
@@ -143,8 +153,8 @@ def extract_texmesh(args, model, output_path, device):
|
|
143 |
h, w = h0, w0
|
144 |
|
145 |
rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
|
146 |
-
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
|
147 |
-
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
|
148 |
# masked query
|
149 |
xyzs = xyzs.view(-1, 3)
|
150 |
mask = (mask > 0).view(-1)
|
@@ -182,15 +192,6 @@ def extract_texmesh(args, model, output_path, device):
|
|
182 |
knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
|
183 |
_, indices = knn.kneighbors(inpaint_coords)
|
184 |
feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
|
185 |
-
# do ssaa after the NN search, in numpy
|
186 |
-
feats0 = cv2.cvtColor(feats[..., :3].astype(np.uint8), cv2.COLOR_RGB2BGR) # albedo
|
187 |
-
feats1 = cv2.cvtColor(feats[..., 3:].astype(np.uint8), cv2.COLOR_RGB2BGR) # visibility features
|
188 |
-
if ssaa > 1:
|
189 |
-
feats0 = cv2.resize(feats0, (w0, h0), interpolation=cv2.INTER_LINEAR)
|
190 |
-
feats1 = cv2.resize(feats1, (w0, h0), interpolation=cv2.INTER_LINEAR)
|
191 |
-
|
192 |
-
cv2.imwrite(os.path.join(ins_dir, f'texture.jpg'), feats0)
|
193 |
-
cv2.imwrite(os.path.join(ins_dir, f'roughness_metallic.jpg'), feats1)
|
194 |
|
195 |
target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
|
196 |
target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
|
|
|
25 |
from sklearn.neighbors import NearestNeighbors
|
26 |
from utils.meshutils import clean_mesh, decimate_mesh
|
27 |
from utils.mesh import Mesh
|
28 |
+
from utils.uv_unwrap import box_projection_uv_unwrap, compute_vertex_normal
|
29 |
logger = logging.getLogger("inference.py")
|
30 |
+
glctx = dr.RasterizeCudaContext()
|
31 |
|
32 |
def remove_background(image: PIL.Image.Image,
|
33 |
rembg_session = None,
|
|
|
115 |
w0 = 1024
|
116 |
ssaa = 1
|
117 |
fp16 = True
|
|
|
118 |
v_np = vertices.astype(np.float32)
|
119 |
f_np = triangles.astype(np.int64)
|
120 |
v = torch.from_numpy(vertices).float().contiguous().to(device)
|
121 |
+
f = torch.from_numpy(triangles.astype(np.int64)).to(torch.int64).contiguous().to(device)
|
122 |
+
if args.fast_unwrap:
|
123 |
+
print(f'[INFO] running box-based fast unwrapping to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
|
124 |
+
v_normal = compute_vertex_normal(v, f)
|
125 |
+
uv, indices = box_projection_uv_unwrap(v, v_normal, f, 0.02)
|
126 |
+
indv_v = v[f].reshape(-1, 3)
|
127 |
+
indv_faces = torch.arange(indv_v.shape[0], device=device, dtype=f.dtype).reshape(-1, 3)
|
128 |
+
uv_flat = uv[indices].reshape((-1, 2))
|
129 |
+
v = indv_v.contiguous()
|
130 |
+
f = indv_faces.contiguous()
|
131 |
+
ft_np = f.cpu().numpy()
|
132 |
+
vt_np = uv_flat.cpu().numpy()
|
133 |
+
else:
|
134 |
+
print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
|
135 |
+
# unwrap uv in contracted space
|
136 |
+
atlas = xatlas.Atlas()
|
137 |
+
atlas.add_mesh(v_np, f_np)
|
138 |
+
chart_options = xatlas.ChartOptions()
|
139 |
+
chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
|
140 |
+
pack_options = xatlas.PackOptions()
|
141 |
+
atlas.generate(chart_options=chart_options, pack_options=pack_options)
|
142 |
+
_, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
|
143 |
|
144 |
vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device)
|
145 |
ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device)
|
|
|
153 |
h, w = h0, w0
|
154 |
|
155 |
rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
|
156 |
+
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f.int()) # [1, h, w, 3]
|
157 |
+
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f.int()) # [1, h, w, 1]
|
158 |
# masked query
|
159 |
xyzs = xyzs.view(-1, 3)
|
160 |
mask = (mask > 0).view(-1)
|
|
|
192 |
knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
|
193 |
_, indices = knn.kneighbors(inpaint_coords)
|
194 |
feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
|
197 |
target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
|
requirements.txt
CHANGED
@@ -21,4 +21,5 @@ diffusers==0.19.3
|
|
21 |
ninja
|
22 |
imageio
|
23 |
imageio-ffmpeg
|
24 |
-
gradio-litmodel3d==0.0.1
|
|
|
|
21 |
ninja
|
22 |
imageio
|
23 |
imageio-ffmpeg
|
24 |
+
gradio-litmodel3d==0.0.1
|
25 |
+
jaxtyping==0.2.31
|
utils/uv_unwrap.py
ADDED
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from jaxtyping import Bool, Float, Integer, Int, Num
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
|
10 |
+
# One pad for determinant
|
11 |
+
tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
|
12 |
+
det_tri = torch.det(tri_sq)
|
13 |
+
tri_rev = torch.cat(
|
14 |
+
(tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
|
15 |
+
)
|
16 |
+
tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
|
17 |
+
return tri_sq
|
18 |
+
|
19 |
+
def triangle_intersection_2d(
|
20 |
+
t1: Float[Tensor, "*B 3 2"],
|
21 |
+
t2: Float[Tensor, "*B 3 2"],
|
22 |
+
eps=1e-12,
|
23 |
+
) -> Float[Tensor, "*B"]: # noqa: F821
|
24 |
+
"""Returns True if triangles collide, False otherwise"""
|
25 |
+
|
26 |
+
def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
|
27 |
+
logdetx = torch.logdet(x.double())
|
28 |
+
if eps is None:
|
29 |
+
return ~torch.isfinite(logdetx)
|
30 |
+
return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
|
31 |
+
|
32 |
+
t1s = tri_winding(t1)
|
33 |
+
t2s = tri_winding(t2)
|
34 |
+
|
35 |
+
# Assume the triangles do not collide in the begging
|
36 |
+
ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
|
37 |
+
for i in range(3):
|
38 |
+
edge = torch.roll(t1s, i, dims=1)[:, :2, :]
|
39 |
+
# Check if all points of triangle 2 lay on the external side of edge E.
|
40 |
+
# If this is the case the triangle do not collide
|
41 |
+
upd = (
|
42 |
+
chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
|
43 |
+
& chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
|
44 |
+
& chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
|
45 |
+
)
|
46 |
+
# Here no collision is still True due to inversion
|
47 |
+
ret = ret | upd
|
48 |
+
|
49 |
+
for i in range(3):
|
50 |
+
edge = torch.roll(t2s, i, dims=1)[:, :2, :]
|
51 |
+
|
52 |
+
upd = (
|
53 |
+
chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
|
54 |
+
& chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
|
55 |
+
& chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
|
56 |
+
)
|
57 |
+
# Here no collision is still True due to inversion
|
58 |
+
ret = ret | upd
|
59 |
+
|
60 |
+
return ~ret # Do the inversion
|
61 |
+
|
62 |
+
def dot(x, y, dim=-1):
|
63 |
+
return torch.sum(x * y, dim, keepdim=True)
|
64 |
+
|
65 |
+
def compute_vertex_normal(v_pos, t_pos_idx):
|
66 |
+
i0 = t_pos_idx[:, 0]
|
67 |
+
i1 = t_pos_idx[:, 1]
|
68 |
+
i2 = t_pos_idx[:, 2]
|
69 |
+
v0 = v_pos[i0, :]
|
70 |
+
v1 = v_pos[i1, :]
|
71 |
+
v2 = v_pos[i2, :]
|
72 |
+
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
73 |
+
# Splat face normals to vertices
|
74 |
+
v_nrm = torch.zeros_like(v_pos)
|
75 |
+
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
76 |
+
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
77 |
+
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
78 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
79 |
+
v_nrm = torch.where(
|
80 |
+
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
81 |
+
)
|
82 |
+
v_nrm = F.normalize(v_nrm, dim=1)
|
83 |
+
if torch.is_anomaly_enabled():
|
84 |
+
assert torch.all(torch.isfinite(v_nrm))
|
85 |
+
return v_nrm
|
86 |
+
|
87 |
+
def _box_assign_vertex_to_cube_face(
|
88 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
89 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
90 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
91 |
+
bbox: Float[Tensor, "2 3"],
|
92 |
+
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
|
93 |
+
# Test to not have a scaled model to fit the space better
|
94 |
+
# bbox_min = bbox[:1].mean(-1, keepdim=True)
|
95 |
+
# bbox_max = bbox[1:].mean(-1, keepdim=True)
|
96 |
+
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
|
97 |
+
|
98 |
+
# Create a [0, 1] normalized vertex position
|
99 |
+
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
|
100 |
+
# And to [-1, 1]
|
101 |
+
v_pos_normalized = 2.0 * v_pos_normalized - 1.0
|
102 |
+
|
103 |
+
# Get all vertex positions for each triangle
|
104 |
+
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
|
105 |
+
v0 = v_pos_normalized[triangle_idxs[:, 0]]
|
106 |
+
v1 = v_pos_normalized[triangle_idxs[:, 1]]
|
107 |
+
v2 = v_pos_normalized[triangle_idxs[:, 2]]
|
108 |
+
tri_stack = torch.stack([v0, v1, v2], dim=1)
|
109 |
+
|
110 |
+
vn0 = vertex_normals[triangle_idxs[:, 0]]
|
111 |
+
vn1 = vertex_normals[triangle_idxs[:, 1]]
|
112 |
+
vn2 = vertex_normals[triangle_idxs[:, 2]]
|
113 |
+
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
|
114 |
+
|
115 |
+
# Just average the normals per face
|
116 |
+
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
|
117 |
+
|
118 |
+
# Now decide based on the face normal in which box map we project
|
119 |
+
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
|
120 |
+
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
|
121 |
+
|
122 |
+
axis = torch.tensor(
|
123 |
+
[
|
124 |
+
[1, 0, 0], # 0
|
125 |
+
[-1, 0, 0], # 1
|
126 |
+
[0, 1, 0], # 2
|
127 |
+
[0, -1, 0], # 3
|
128 |
+
[0, 0, 1], # 4
|
129 |
+
[0, 0, -1], # 5
|
130 |
+
],
|
131 |
+
device=face_normal.device,
|
132 |
+
dtype=face_normal.dtype,
|
133 |
+
)
|
134 |
+
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
|
135 |
+
index = face_normal_axis.argmax(-1)
|
136 |
+
|
137 |
+
max_axis, uc, vc = (
|
138 |
+
torch.ones_like(abs_x),
|
139 |
+
torch.zeros_like(tri_stack[..., :1]),
|
140 |
+
torch.zeros_like(tri_stack[..., :1]),
|
141 |
+
)
|
142 |
+
mask_pos_x = index == 0
|
143 |
+
max_axis[mask_pos_x] = abs_x[mask_pos_x]
|
144 |
+
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
|
145 |
+
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
|
146 |
+
|
147 |
+
mask_neg_x = index == 1
|
148 |
+
max_axis[mask_neg_x] = abs_x[mask_neg_x]
|
149 |
+
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
|
150 |
+
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
|
151 |
+
|
152 |
+
mask_pos_y = index == 2
|
153 |
+
max_axis[mask_pos_y] = abs_y[mask_pos_y]
|
154 |
+
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
|
155 |
+
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
|
156 |
+
|
157 |
+
mask_neg_y = index == 3
|
158 |
+
max_axis[mask_neg_y] = abs_y[mask_neg_y]
|
159 |
+
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
|
160 |
+
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
|
161 |
+
|
162 |
+
mask_pos_z = index == 4
|
163 |
+
max_axis[mask_pos_z] = abs_z[mask_pos_z]
|
164 |
+
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
|
165 |
+
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
|
166 |
+
|
167 |
+
mask_neg_z = index == 5
|
168 |
+
max_axis[mask_neg_z] = abs_z[mask_neg_z]
|
169 |
+
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
|
170 |
+
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
|
171 |
+
|
172 |
+
# UC from [-1, 1] to [0, 1]
|
173 |
+
max_dim_div = max_axis.max(dim=0, keepdims=True).values
|
174 |
+
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
175 |
+
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
176 |
+
|
177 |
+
uv = torch.stack([uc, vc], dim=-1)
|
178 |
+
|
179 |
+
return uv, index
|
180 |
+
|
181 |
+
|
182 |
+
def _assign_faces_uv_to_atlas_index(
|
183 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
184 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
185 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
186 |
+
face_index: Integer[Tensor, "Nf 3"],
|
187 |
+
) -> Integer[Tensor, "Nf"]: # noqa: F821
|
188 |
+
triangle_pos = vertex_positions[triangle_idxs]
|
189 |
+
# We need to do perform 3 overlap checks.
|
190 |
+
# The first set is placed in the upper two thirds of the UV atlas.
|
191 |
+
# Conceptually, this is the direct visible surfaces from the each cube side
|
192 |
+
# The second set is placed in the lower thirds and the left half of the UV atlas.
|
193 |
+
# This is the first set of occluded surfaces. They will also be saved in the projected fashion
|
194 |
+
# The third pass finds all non assigned faces. They will be placed in the bottom right half of
|
195 |
+
# the UV atlas in scattered fashion.
|
196 |
+
assign_idx = face_index.clone()
|
197 |
+
for overlap_step in range(3):
|
198 |
+
overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
|
199 |
+
for i in range(overlap_step * 6, (overlap_step + 1) * 6):
|
200 |
+
mask = assign_idx == i
|
201 |
+
if not mask.any():
|
202 |
+
continue
|
203 |
+
# Get all elements belonging to the projection face
|
204 |
+
uv_triangle = face_uv[mask]
|
205 |
+
cur_triangle_pos = triangle_pos[mask]
|
206 |
+
# Find the center of the uv coordinates
|
207 |
+
center_uv = uv_triangle.mean(dim=1, keepdim=True)
|
208 |
+
# And also the radius of the triangle
|
209 |
+
uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
|
210 |
+
|
211 |
+
potentially_overlapping_mask = (
|
212 |
+
# Find all close triangles
|
213 |
+
(center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
|
214 |
+
# Do not select the same element by offseting with an large valued identity matrix
|
215 |
+
+ torch.eye(
|
216 |
+
uv_triangle.shape[0],
|
217 |
+
device=uv_triangle.device,
|
218 |
+
dtype=uv_triangle.dtype,
|
219 |
+
).unsqueeze(-1)
|
220 |
+
* 1000
|
221 |
+
)
|
222 |
+
# Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
|
223 |
+
potentially_overlapping_mask = (
|
224 |
+
potentially_overlapping_mask
|
225 |
+
<= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
|
226 |
+
).squeeze(-1)
|
227 |
+
overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
|
228 |
+
|
229 |
+
# Only unique triangles (A|B and B|A should be the same)
|
230 |
+
f = torch.min(overlap_coords, dim=-1).values
|
231 |
+
s = torch.max(overlap_coords, dim=-1).values
|
232 |
+
overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
|
233 |
+
first, second = overlap_coords.unbind(-1)
|
234 |
+
|
235 |
+
# Get the triangles
|
236 |
+
tri_1 = uv_triangle[first]
|
237 |
+
tri_2 = uv_triangle[second]
|
238 |
+
|
239 |
+
# Perform the actual set with the reduced number of potentially overlapping triangles
|
240 |
+
its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
|
241 |
+
|
242 |
+
# So we now need to detect which triangles are the occluded ones.
|
243 |
+
# We always assume the first to be the visible one (the others should move)
|
244 |
+
# In the previous step we use a lexigraphical sort to get the unique pairs
|
245 |
+
# In this we use a sort based on the orthographic projection
|
246 |
+
ax = 0 if i < 2 else 1 if i < 4 else 2
|
247 |
+
use_max = i % 2 == 1
|
248 |
+
|
249 |
+
tri1_c = cur_triangle_pos[first].mean(dim=1)
|
250 |
+
tri2_c = cur_triangle_pos[second].mean(dim=1)
|
251 |
+
|
252 |
+
mark_first = (
|
253 |
+
(tri1_c[..., ax] > tri2_c[..., ax])
|
254 |
+
if use_max
|
255 |
+
else (tri1_c[..., ax] < tri2_c[..., ax])
|
256 |
+
)
|
257 |
+
first[mark_first] = second[mark_first]
|
258 |
+
|
259 |
+
# Lastly the same index can be tested multiple times.
|
260 |
+
# If one marks it as overlapping we keep it marked as such.
|
261 |
+
# We do this by testing if it has been marked at least once.
|
262 |
+
unique_idx, rev_idx = torch.unique(first, return_inverse=True)
|
263 |
+
|
264 |
+
add = torch.zeros_like(unique_idx, dtype=torch.float32)
|
265 |
+
add.index_add_(0, rev_idx, its.float())
|
266 |
+
its_mask = add > 0
|
267 |
+
|
268 |
+
# And fill it in the overlapping indicator
|
269 |
+
idx = torch.where(mask)[0][unique_idx]
|
270 |
+
overlapping_indicator[idx] = its_mask
|
271 |
+
|
272 |
+
# Move the index to the overlap regions (shift by 6)
|
273 |
+
assign_idx[overlapping_indicator] += 6
|
274 |
+
|
275 |
+
# We do not care about the correct face placement after the first 2 slices
|
276 |
+
max_idx = 6 * 2
|
277 |
+
return assign_idx.clamp(0, max_idx)
|
278 |
+
|
279 |
+
|
280 |
+
def _find_slice_offset_and_scale(
|
281 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
282 |
+
) -> Tuple[
|
283 |
+
Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
|
284 |
+
]: # noqa: F821
|
285 |
+
# 6 due to the 6 cube faces
|
286 |
+
off = 1 / 3
|
287 |
+
dupl_off = 1 / 6
|
288 |
+
|
289 |
+
# Here, we need to decide how to pack the textures in the case of overlap
|
290 |
+
def x_offset_calc(x, i):
|
291 |
+
offset_calc = i // 6
|
292 |
+
# Initial coordinates - just 3x2 grid
|
293 |
+
if offset_calc == 0:
|
294 |
+
return off * x
|
295 |
+
else:
|
296 |
+
# Smaller 3x2 grid plus eventual shift to right for
|
297 |
+
# second overlap
|
298 |
+
return dupl_off * x + min(offset_calc - 1, 1) * 0.5
|
299 |
+
|
300 |
+
def y_offset_calc(x, i):
|
301 |
+
offset_calc = i // 6
|
302 |
+
# Initial coordinates - just a 3x2 grid
|
303 |
+
if offset_calc == 0:
|
304 |
+
return off * x
|
305 |
+
else:
|
306 |
+
# Smaller coordinates in the lowest row
|
307 |
+
return dupl_off * x + off * 2
|
308 |
+
|
309 |
+
offset_x = torch.zeros_like(index, dtype=torch.float32)
|
310 |
+
offset_y = torch.zeros_like(index, dtype=torch.float32)
|
311 |
+
offset_x_vals = [0, 1, 2, 0, 1, 2]
|
312 |
+
offset_y_vals = [0, 0, 0, 1, 1, 1]
|
313 |
+
for i in range(index.max().item() + 1):
|
314 |
+
mask = index == i
|
315 |
+
if not mask.any():
|
316 |
+
continue
|
317 |
+
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
|
318 |
+
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
|
319 |
+
|
320 |
+
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
|
321 |
+
# All overlap elements are saved in half scale
|
322 |
+
div_x[index >= 6] = 6
|
323 |
+
div_y = div_x.clone() # Same for y
|
324 |
+
# Except for the random overlaps
|
325 |
+
div_x[index >= 12] = 2
|
326 |
+
# But the random overlaps are saved in a large block in the lower thirds
|
327 |
+
div_y[index >= 12] = 3
|
328 |
+
|
329 |
+
return offset_x, offset_y, div_x, div_y
|
330 |
+
|
331 |
+
|
332 |
+
def rotation_flip_matrix_2d(
|
333 |
+
rad: float, flip_x: bool = False, flip_y: bool = False
|
334 |
+
) -> Float[Tensor, "2 2"]:
|
335 |
+
cos = math.cos(rad)
|
336 |
+
sin = math.sin(rad)
|
337 |
+
rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
|
338 |
+
flip_mat = torch.tensor(
|
339 |
+
[
|
340 |
+
[-1 if flip_x else 1, 0],
|
341 |
+
[0, -1 if flip_y else 1],
|
342 |
+
],
|
343 |
+
dtype=torch.float32,
|
344 |
+
)
|
345 |
+
|
346 |
+
return flip_mat @ rot_mat
|
347 |
+
|
348 |
+
|
349 |
+
def calculate_tangents(
|
350 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
351 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
352 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
353 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
354 |
+
) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
|
355 |
+
vn_idx = [None] * 3
|
356 |
+
pos = [None] * 3
|
357 |
+
tex = face_uv.unbind(1)
|
358 |
+
for i in range(0, 3):
|
359 |
+
pos[i] = vertex_positions[triangle_idxs[:, i]]
|
360 |
+
# t_nrm_idx is always the same as t_pos_idx
|
361 |
+
vn_idx[i] = triangle_idxs[:, i]
|
362 |
+
|
363 |
+
tangents = torch.zeros_like(vertex_normals)
|
364 |
+
tansum = torch.zeros_like(vertex_normals)
|
365 |
+
|
366 |
+
# Compute tangent space for each triangle
|
367 |
+
duv1 = tex[1] - tex[0]
|
368 |
+
duv2 = tex[2] - tex[0]
|
369 |
+
dpos1 = pos[1] - pos[0]
|
370 |
+
dpos2 = pos[2] - pos[0]
|
371 |
+
|
372 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
373 |
+
|
374 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
375 |
+
|
376 |
+
# Avoid division by zero for degenerated texture coordinates
|
377 |
+
denom_safe = denom.clip(1e-6)
|
378 |
+
tang = tng_nom / denom_safe
|
379 |
+
|
380 |
+
# Update all 3 vertices
|
381 |
+
for i in range(0, 3):
|
382 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
383 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
384 |
+
tansum.scatter_add_(
|
385 |
+
0, idx, torch.ones_like(tang)
|
386 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
387 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
388 |
+
# triangles influence the tangent space more
|
389 |
+
tangents = tangents / tansum
|
390 |
+
|
391 |
+
# Normalize and make sure tangent is perpendicular to normal
|
392 |
+
tangents = F.normalize(tangents, dim=1)
|
393 |
+
tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
|
394 |
+
|
395 |
+
return tangents
|
396 |
+
|
397 |
+
|
398 |
+
def _rotate_uv_slices_consistent_space(
|
399 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
400 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
401 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
402 |
+
uv: Float[Tensor, "Nf 3 2"],
|
403 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
404 |
+
):
|
405 |
+
tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
|
406 |
+
pos_stack = torch.stack(
|
407 |
+
[
|
408 |
+
-vertex_positions[..., 1],
|
409 |
+
vertex_positions[..., 0],
|
410 |
+
torch.zeros_like(vertex_positions[..., 0]),
|
411 |
+
],
|
412 |
+
dim=-1,
|
413 |
+
)
|
414 |
+
expected_tangents = F.normalize(
|
415 |
+
torch.linalg.cross(
|
416 |
+
vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
|
417 |
+
),
|
418 |
+
-1,
|
419 |
+
)
|
420 |
+
|
421 |
+
actual_tangents = tangents[triangle_idxs]
|
422 |
+
expected_tangents = expected_tangents[triangle_idxs]
|
423 |
+
|
424 |
+
def rotation_matrix_2d(theta):
|
425 |
+
c, s = torch.cos(theta), torch.sin(theta)
|
426 |
+
return torch.tensor([[c, -s], [s, c]])
|
427 |
+
|
428 |
+
# Now find the rotation
|
429 |
+
index_mod = index % 6 # Shouldn't happen. Just for safety
|
430 |
+
for i in range(6):
|
431 |
+
mask = index_mod == i
|
432 |
+
if not mask.any():
|
433 |
+
continue
|
434 |
+
|
435 |
+
actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
|
436 |
+
expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
|
437 |
+
|
438 |
+
dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
|
439 |
+
cross_product = (
|
440 |
+
actual_mean_tangent[0] * expected_mean_tangent[1]
|
441 |
+
- actual_mean_tangent[1] * expected_mean_tangent[0]
|
442 |
+
)
|
443 |
+
angle = torch.atan2(cross_product, dot_product)
|
444 |
+
|
445 |
+
rot_matrix = rotation_matrix_2d(angle).to(mask.device)
|
446 |
+
# Center the uv coordinate to be in the range of -1 to 1 and 0 centered
|
447 |
+
uv_cur = uv[mask] * 2 - 1 # Center it first
|
448 |
+
# Rotate it
|
449 |
+
uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
|
450 |
+
|
451 |
+
# Rescale uv[mask] to be within the 0-1 range
|
452 |
+
uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
|
453 |
+
|
454 |
+
return uv
|
455 |
+
|
456 |
+
|
457 |
+
def _handle_slice_uvs(
|
458 |
+
uv: Float[Tensor, "Nf 3 2"],
|
459 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
460 |
+
island_padding: float,
|
461 |
+
max_index: int = 6 * 2,
|
462 |
+
) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
|
463 |
+
uc, vc = uv.unbind(-1)
|
464 |
+
|
465 |
+
# Get the second slice (The first overlap)
|
466 |
+
index_filter = [index == i for i in range(6, max_index)]
|
467 |
+
|
468 |
+
# Normalize them to always fully fill the atlas patch
|
469 |
+
for i, fi in enumerate(index_filter):
|
470 |
+
if fi.sum() > 0:
|
471 |
+
# Scale the slice but only up to a factor of 2
|
472 |
+
# This keeps the texture resolution with the first slice in line (Half space in UV)
|
473 |
+
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
|
474 |
+
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
|
475 |
+
|
476 |
+
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
477 |
+
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
478 |
+
|
479 |
+
return torch.stack([uc_padded, vc_padded], dim=-1)
|
480 |
+
|
481 |
+
|
482 |
+
def _handle_remaining_uvs(
|
483 |
+
uv: Float[Tensor, "Nf 3 2"],
|
484 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
485 |
+
island_padding: float,
|
486 |
+
) -> Float[Tensor, "Nf 3 2"]:
|
487 |
+
uc, vc = uv.unbind(-1)
|
488 |
+
# Get all remaining elements
|
489 |
+
remaining_filter = index >= 6 * 2
|
490 |
+
squares_left = remaining_filter.sum()
|
491 |
+
|
492 |
+
if squares_left == 0:
|
493 |
+
return uv
|
494 |
+
|
495 |
+
uc = uc[remaining_filter]
|
496 |
+
vc = vc[remaining_filter]
|
497 |
+
|
498 |
+
# Or remaining triangles are distributed in a rectangle
|
499 |
+
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
|
500 |
+
ratio = 0.5 * (1 / 3) # 1.5
|
501 |
+
# sqrt(744/(0.5*(1/3)))
|
502 |
+
|
503 |
+
mult = math.sqrt(squares_left / ratio)
|
504 |
+
num_square_width = int(math.ceil(0.5 * mult))
|
505 |
+
num_square_height = int(math.ceil(squares_left / num_square_width))
|
506 |
+
|
507 |
+
width = 1 / num_square_width
|
508 |
+
height = 1 / num_square_height
|
509 |
+
|
510 |
+
# The idea is again to keep the texture resolution consistent with the first slice
|
511 |
+
# This only occupys half the region in the texture chart but the scaling on the squares
|
512 |
+
# assumes full coverage.
|
513 |
+
clip_val = min(width, height) * 1.5
|
514 |
+
# Now normalize the UVs with taking into account the maximum scaling
|
515 |
+
uc = (uc - uc.min(dim=1, keepdim=True).values) / (
|
516 |
+
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
|
517 |
+
).clip(clip_val)
|
518 |
+
vc = (vc - vc.min(dim=1, keepdim=True).values) / (
|
519 |
+
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
|
520 |
+
).clip(clip_val)
|
521 |
+
# Add a small padding
|
522 |
+
uc = (
|
523 |
+
uc * (1 - island_padding * num_square_width * 0.5)
|
524 |
+
+ island_padding * num_square_width * 0.25
|
525 |
+
).clip(0, 1)
|
526 |
+
vc = (
|
527 |
+
vc * (1 - island_padding * num_square_height * 0.5)
|
528 |
+
+ island_padding * num_square_height * 0.25
|
529 |
+
).clip(0, 1)
|
530 |
+
|
531 |
+
uc = uc * width
|
532 |
+
vc = vc * height
|
533 |
+
|
534 |
+
# And calculate offsets for each element
|
535 |
+
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
|
536 |
+
x_idx = idx % num_square_width
|
537 |
+
y_idx = idx // num_square_width
|
538 |
+
# And move each triangle to its own spot
|
539 |
+
uc = uc + x_idx[:, None] * width
|
540 |
+
vc = vc + y_idx[:, None] * height
|
541 |
+
|
542 |
+
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
543 |
+
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
544 |
+
|
545 |
+
uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
|
546 |
+
|
547 |
+
return uv
|
548 |
+
|
549 |
+
|
550 |
+
def _distribute_individual_uvs_in_atlas(
|
551 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
552 |
+
assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
|
553 |
+
offset_x: Float[Tensor, "Nf"], # noqa: F821
|
554 |
+
offset_y: Float[Tensor, "Nf"], # noqa: F821
|
555 |
+
div_x: Float[Tensor, "Nf"], # noqa: F821
|
556 |
+
div_y: Float[Tensor, "Nf"], # noqa: F821
|
557 |
+
island_padding: float,
|
558 |
+
):
|
559 |
+
# Place the slice first
|
560 |
+
placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
|
561 |
+
# Then handle the remaining overlap elements
|
562 |
+
placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
|
563 |
+
|
564 |
+
uc, vc = placed_uv.unbind(-1)
|
565 |
+
uc = uc / div_x[:, None] + offset_x[:, None]
|
566 |
+
vc = vc / div_y[:, None] + offset_y[:, None]
|
567 |
+
|
568 |
+
uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
|
569 |
+
|
570 |
+
return uv
|
571 |
+
|
572 |
+
|
573 |
+
def _get_unique_face_uv(
|
574 |
+
uv: Float[Tensor, "Nf 3 2"],
|
575 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
576 |
+
unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
|
577 |
+
# And add the face to uv index mapping
|
578 |
+
vtex_idx = unique_idx.view(-1, 3)
|
579 |
+
|
580 |
+
return unique_uv, vtex_idx
|
581 |
+
|
582 |
+
|
583 |
+
def _align_mesh_with_main_axis(
|
584 |
+
vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
|
585 |
+
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
|
586 |
+
# Use pca to find the 2 main axis (third is derived by cross product)
|
587 |
+
# Set the random seed so it's repeatable
|
588 |
+
torch.manual_seed(0)
|
589 |
+
_, _, v = torch.pca_lowrank(vertex_positions, q=2)
|
590 |
+
main_axis, seconday_axis = v[:, 0], v[:, 1]
|
591 |
+
|
592 |
+
main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
|
593 |
+
# Orthogonalize the second axis
|
594 |
+
seconday_axis: Float[Tensor, "3"] = F.normalize(
|
595 |
+
seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
|
596 |
+
)
|
597 |
+
# Create perpendicular third axis
|
598 |
+
third_axis: Float[Tensor, "3"] = F.normalize(
|
599 |
+
torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
|
600 |
+
)
|
601 |
+
|
602 |
+
# Check to which canonical axis each aligns
|
603 |
+
main_axis_max_idx = main_axis.abs().argmax().item()
|
604 |
+
seconday_axis_max_idx = seconday_axis.abs().argmax().item()
|
605 |
+
third_axis_max_idx = third_axis.abs().argmax().item()
|
606 |
+
|
607 |
+
# Now sort the axes based on the argmax so they align with thecanonoical axes
|
608 |
+
# If two axes have the same argmax move one of them
|
609 |
+
all_possible_axis = {0, 1, 2}
|
610 |
+
cur_index = 1
|
611 |
+
while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
|
612 |
+
# Find missing axis
|
613 |
+
missing_axis = all_possible_axis - set(
|
614 |
+
[main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
|
615 |
+
)
|
616 |
+
missing_axis = missing_axis.pop()
|
617 |
+
# Just assign it to third axis as it had the smallest contribution to the
|
618 |
+
# overall shape
|
619 |
+
if cur_index == 1:
|
620 |
+
third_axis_max_idx = missing_axis
|
621 |
+
elif cur_index == 2:
|
622 |
+
seconday_axis_max_idx = missing_axis
|
623 |
+
else:
|
624 |
+
raise ValueError("Could not find 3 unique axis")
|
625 |
+
cur_index += 1
|
626 |
+
|
627 |
+
if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
|
628 |
+
raise ValueError("Could not find 3 unique axis")
|
629 |
+
|
630 |
+
axes = [None] * 3
|
631 |
+
axes[main_axis_max_idx] = main_axis
|
632 |
+
axes[seconday_axis_max_idx] = seconday_axis
|
633 |
+
axes[third_axis_max_idx] = third_axis
|
634 |
+
# Create rotation matrix from the individual axes
|
635 |
+
rot_mat = torch.stack(axes, dim=1).T
|
636 |
+
|
637 |
+
# Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
|
638 |
+
vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
|
639 |
+
vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
|
640 |
+
|
641 |
+
return vertex_positions, vertex_normals
|
642 |
+
|
643 |
+
|
644 |
+
def box_projection_uv_unwrap(
|
645 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
646 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
647 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
648 |
+
island_padding: float,
|
649 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
650 |
+
# Align the mesh with main axis directions first
|
651 |
+
# vertex_positions, vertex_normals = _align_mesh_with_main_axis(
|
652 |
+
# vertex_positions, vertex_normals
|
653 |
+
# )
|
654 |
+
|
655 |
+
bbox: Float[Tensor, "2 3"] = torch.stack(
|
656 |
+
[vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
|
657 |
+
)
|
658 |
+
# First decide in which cube face the triangle is placed
|
659 |
+
face_uv, face_index = _box_assign_vertex_to_cube_face(
|
660 |
+
vertex_positions, vertex_normals, triangle_idxs, bbox
|
661 |
+
)
|
662 |
+
|
663 |
+
# Rotate the UV islands in a way that they align with the radial z tangent space
|
664 |
+
face_uv = _rotate_uv_slices_consistent_space(
|
665 |
+
vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
|
666 |
+
)
|
667 |
+
|
668 |
+
# Then find where where the face is placed in the atlas.
|
669 |
+
# This has to detect potential overlaps
|
670 |
+
assigned_atlas_index = _assign_faces_uv_to_atlas_index(
|
671 |
+
vertex_positions, triangle_idxs, face_uv, face_index
|
672 |
+
)
|
673 |
+
|
674 |
+
# Then figure out the final place in the atlas based on the assignment
|
675 |
+
offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
|
676 |
+
assigned_atlas_index
|
677 |
+
)
|
678 |
+
|
679 |
+
# Next distribute the faces in the uv atlas
|
680 |
+
placed_uv = _distribute_individual_uvs_in_atlas(
|
681 |
+
face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
|
682 |
+
)
|
683 |
+
|
684 |
+
# And get the unique per-triangle UV coordinates
|
685 |
+
return _get_unique_face_uv(placed_uv)
|