File size: 22,946 Bytes
d945eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
import math
from typing import Tuple

import torch
import torch.nn.functional as F
from jaxtyping import Float, Integer
from torch import Tensor

from sf3d.models.utils import dot, triangle_intersection_2d


def _box_assign_vertex_to_cube_face(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    bbox: Float[Tensor, "2 3"],
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
    # Test to not have a scaled model to fit the space better
    # bbox_min = bbox[:1].mean(-1, keepdim=True)
    # bbox_max = bbox[1:].mean(-1, keepdim=True)
    # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)

    # Create a [0, 1] normalized vertex position
    v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
    # And to [-1, 1]
    v_pos_normalized = 2.0 * v_pos_normalized - 1.0

    # Get all vertex positions for each triangle
    # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
    v0 = v_pos_normalized[triangle_idxs[:, 0]]
    v1 = v_pos_normalized[triangle_idxs[:, 1]]
    v2 = v_pos_normalized[triangle_idxs[:, 2]]
    tri_stack = torch.stack([v0, v1, v2], dim=1)

    vn0 = vertex_normals[triangle_idxs[:, 0]]
    vn1 = vertex_normals[triangle_idxs[:, 1]]
    vn2 = vertex_normals[triangle_idxs[:, 2]]
    tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)

    # Just average the normals per face
    face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)

    # Now decide based on the face normal in which box map we project
    # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
    abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)

    axis = torch.tensor(
        [
            [1, 0, 0],  # 0
            [-1, 0, 0],  # 1
            [0, 1, 0],  # 2
            [0, -1, 0],  # 3
            [0, 0, 1],  # 4
            [0, 0, -1],  # 5
        ],
        device=face_normal.device,
        dtype=face_normal.dtype,
    )
    face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
    index = face_normal_axis.argmax(-1)

    max_axis, uc, vc = (
        torch.ones_like(abs_x),
        torch.zeros_like(tri_stack[..., :1]),
        torch.zeros_like(tri_stack[..., :1]),
    )
    mask_pos_x = index == 0
    max_axis[mask_pos_x] = abs_x[mask_pos_x]
    uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
    vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]

    mask_neg_x = index == 1
    max_axis[mask_neg_x] = abs_x[mask_neg_x]
    uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
    vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]

    mask_pos_y = index == 2
    max_axis[mask_pos_y] = abs_y[mask_pos_y]
    uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
    vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]

    mask_neg_y = index == 3
    max_axis[mask_neg_y] = abs_y[mask_neg_y]
    uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
    vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]

    mask_pos_z = index == 4
    max_axis[mask_pos_z] = abs_z[mask_pos_z]
    uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
    vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]

    mask_neg_z = index == 5
    max_axis[mask_neg_z] = abs_z[mask_neg_z]
    uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
    vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]

    # UC from [-1, 1] to [0, 1]
    max_dim_div = max_axis.max(dim=0, keepdims=True).values
    uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
    vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)

    uv = torch.stack([uc, vc], dim=-1)

    return uv, index


def _assign_faces_uv_to_atlas_index(
    vertex_positions: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
    face_index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Integer[Tensor, "Nf"]:  # noqa: F821
    """Assign every triangle to a slice of the UV atlas, resolving UV overlaps.

    Triangles start in the slice of their cube face (``face_index``, 0-5),
    which is placed in the upper two thirds of the atlas. Within a slice,
    triangles whose projected UVs intersect another triangle of the same slice
    are pushed to the next slice (index + 6), placed in the left half of the
    lower third. Triangles that still overlap there are pushed once more;
    every index at or above 12 is clamped to 12, the shared "remaining" region.

    Args:
        vertex_positions: Vertex positions. Triangle centroids decide which
            triangle of an overlapping pair is moved.
        triangle_idxs: Vertex indices of every triangle.
        face_uv: Per-corner projected UVs of every triangle.
        face_index: Initial cube-face index of every triangle.

    Returns:
        Atlas slice index per triangle, clamped to [0, 12].
    """
    triangle_pos = vertex_positions[triangle_idxs]
    assign_idx = face_index.clone()
    # Only two passes can influence the result. The final clamp maps every
    # index >= 12 to 12, so a third pass (which would only move triangles from
    # 12-17 to 18-23) is skipped.
    for overlap_step in range(2):
        overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
        for i in range(overlap_step * 6, (overlap_step + 1) * 6):
            mask = assign_idx == i
            if not mask.any():
                continue
            # All triangles currently in this slice
            uv_triangle = face_uv[mask]
            cur_triangle_pos = triangle_pos[mask]
            # UV centroid and bounding radius of every triangle
            center_uv = uv_triangle.mean(dim=1, keepdim=True)
            uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values

            # Pairwise centroid distances. The large offset on the diagonal
            # keeps a triangle from being paired with itself.
            center_dist = (center_uv[None, ...] - center_uv[:, None]).norm(
                dim=-1
            ) + torch.eye(
                uv_triangle.shape[0],
                device=uv_triangle.device,
                dtype=uv_triangle.dtype,
            ).unsqueeze(-1) * 1000
            # Cheap pre-filter so the exact intersection test only runs on
            # close pairs
            potentially_overlapping_mask = (
                center_dist <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
            ).squeeze(-1)
            overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
            if overlap_coords.shape[0] == 0:
                # No candidate pairs: nothing in this slice can overlap
                continue

            # A|B and B|A are the same pair; keep one ordered copy of each
            f = torch.min(overlap_coords, dim=-1).values
            s = torch.max(overlap_coords, dim=-1).values
            overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
            first, second = overlap_coords.unbind(-1)

            tri_1 = uv_triangle[first]
            tri_2 = uv_triangle[second]
            # NOTE(review): assumed to return one flag per pair that is truthy
            # when the two UV triangles intersect.
            its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)

            # Decide which triangle of each pair moves by comparing centroids
            # along the projection axis of this slice's cube face. Use i % 6
            # so the second pass (i = 6..11) sorts along the same axis as the
            # first. Comparing i itself against 2 / 4 always selected z there.
            face = i % 6
            ax = face // 2
            use_max = face % 2 == 1

            tri1_c = cur_triangle_pos[first].mean(dim=1)
            tri2_c = cur_triangle_pos[second].mean(dim=1)

            mark_first = (
                (tri1_c[..., ax] > tri2_c[..., ax])
                if use_max
                else (tri1_c[..., ax] < tri2_c[..., ax])
            )
            # After this, ``first`` holds the triangle of each pair that moves.
            # That is the one with the larger centroid coordinate for even
            # (+axis) faces and the smaller one for odd (-axis) faces.
            first[mark_first] = second[mark_first]

            # A triangle can appear in several pairs; it moves if any of its
            # pairs intersects.
            unique_idx, rev_idx = torch.unique(first, return_inverse=True)
            add = torch.zeros_like(unique_idx, dtype=torch.float32)
            add.index_add_(0, rev_idx, its.float())
            its_mask = add > 0

            # Map slice-local indices back to global triangle indices
            idx = torch.where(mask)[0][unique_idx]
            overlapping_indicator[idx] = its_mask

        # Move the overlapping triangles to the next slice (shift by 6)
        assign_idx[overlapping_indicator] += 6

    # Everything beyond the first two slices shares the "remaining" region
    max_idx = 6 * 2
    return assign_idx.clamp(0, max_idx)


def _find_slice_offset_and_scale(
    index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Tuple[
    Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"]  # noqa: F821
]:  # noqa: F821
    # 6 due to the 6 cube faces
    off = 1 / 3
    dupl_off = 1 / 6

    # Here, we need to decide how to pack the textures in the case of overlap
    def x_offset_calc(x, i):
        offset_calc = i // 6
        # Initial coordinates - just 3x2 grid
        if offset_calc == 0:
            return off * x
        else:
            # Smaller 3x2 grid plus eventual shift to right for
            # second overlap
            return dupl_off * x + min(offset_calc - 1, 1) * 0.5

    def y_offset_calc(x, i):
        offset_calc = i // 6
        # Initial coordinates - just a 3x2 grid
        if offset_calc == 0:
            return off * x
        else:
            # Smaller coordinates in the lowest row
            return dupl_off * x + off * 2

    offset_x = torch.zeros_like(index, dtype=torch.float32)
    offset_y = torch.zeros_like(index, dtype=torch.float32)
    offset_x_vals = [0, 1, 2, 0, 1, 2]
    offset_y_vals = [0, 0, 0, 1, 1, 1]
    for i in range(index.max().item() + 1):
        mask = index == i
        if not mask.any():
            continue
        offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
        offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)

    div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
    # All overlap elements are saved in half scale
    div_x[index >= 6] = 6
    div_y = div_x.clone()  # Same for y
    # Except for the random overlaps
    div_x[index >= 12] = 2
    # But the random overlaps are saved in a large block in the lower thirds
    div_y[index >= 12] = 3

    return offset_x, offset_y, div_x, div_y


def rotation_flip_matrix_2d(
    rad: float, flip_x: bool = False, flip_y: bool = False
) -> Float[Tensor, "2 2"]:
    """Return a float32 2x2 matrix that rotates by ``rad`` and then flips.

    The rotation is counter-clockwise by ``rad`` radians. The flip mirrors
    the x and/or y axis and is applied after the rotation, i.e. the result is
    ``flip @ rotation``.
    """
    c, s = math.cos(rad), math.sin(rad)
    sign_x = -1 if flip_x else 1
    sign_y = -1 if flip_y else 1

    rotation = torch.tensor([[c, -s], [s, c]], dtype=torch.float32)
    flip = torch.diag(torch.tensor([sign_x, sign_y], dtype=torch.float32))
    return flip @ rotation


def calculate_tangents(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
) -> Float[Tensor, "Nv 3"]:
    """Compute per-vertex tangents from vertex positions and per-corner UVs.

    Each triangle's tangent (the position-space direction of increasing u) is
    accumulated on its three vertices and averaged. It is then normalized and
    made perpendicular to the vertex normal.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals (presumably unit length, since they
            are used as-is for the orthogonalization; verify with callers).
        triangle_idxs: Vertex indices of every triangle.
        face_uv: UV coordinate of every triangle corner.

    Returns:
        One unit tangent per vertex. Vertices not referenced by any triangle
        get a zero vector.
    """
    tex0, tex1, tex2 = face_uv.unbind(1)
    pos0, pos1, pos2 = vertex_positions[triangle_idxs].unbind(1)

    # Tangent of each triangle from its edge vectors in position and UV space
    duv1 = tex1 - tex0
    duv2 = tex2 - tex0
    dpos1 = pos1 - pos0
    dpos2 = pos2 - pos0

    tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
    denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

    # Avoid division by zero for degenerate texture coordinates but keep the
    # sign of the UV-space area. Mirrored triangles have a negative denom, and
    # clipping it to +1e-6 would flip their tangent and inflate it ~1e6x, so
    # those triangles would dominate the per-vertex average.
    denom_safe = torch.where(
        denom > 0.0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
    )
    tang = tng_nom / denom_safe

    # Accumulate every triangle's tangent on its three vertices. Triangles are
    # not normalized first, so larger triangles influence the result more.
    tangents = torch.zeros_like(vertex_normals)
    tansum = torch.zeros_like(vertex_normals[:, :1])
    ones = tansum.new_ones((triangle_idxs.shape[0], 1))
    for corner in range(3):
        vertex_idx = triangle_idxs[:, corner]
        tangents.index_add_(0, vertex_idx, tang)
        tansum.index_add_(0, vertex_idx, ones)
    # Unreferenced vertices have tansum == 0; clamping keeps them at zero
    # instead of producing 0 / 0 = NaN
    tangents = tangents / tansum.clamp(min=1)

    # Normalize and make sure the tangent is perpendicular to the normal
    tangents = F.normalize(tangents, dim=1)
    tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)

    return tangents


def _rotate_uv_slices_consistent_space(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Float[Tensor, "Nf 3 2"]:
    """Rotate each cube-face UV island so its tangents follow a shared direction.

    For every cube face (``index % 6``), the mean UV tangent of its triangles
    (from ``calculate_tangents``) is compared with a mean reference direction.
    The reference is ``(-y, x, 0)``, the direction of rotation around the z
    axis, projected onto each vertex's tangent plane. The island's UVs are
    rotated by the angle between the two and then rescaled to [0, 1].

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of every triangle.
        uv: Per-corner UVs of every triangle, expected in [0, 1].
        index: Cube-face index of every triangle.

    Returns:
        New per-corner UVs; ``uv`` itself is not modified.
    """
    tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
    pos_stack = torch.stack(
        [
            -vertex_positions[..., 1],
            vertex_positions[..., 0],
            torch.zeros_like(vertex_positions[..., 0]),
        ],
        dim=-1,
    )
    # n x (p x n) removes the normal component of p (for unit normals).
    # ``dim`` must be passed by keyword: the second positional parameter of
    # F.normalize is the norm order ``p``, so ``F.normalize(x, -1)`` divided by
    # a p = -1 "norm" that blows up to ~1e12 whenever a component is zero.
    expected_tangents = F.normalize(
        torch.linalg.cross(
            vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
        ),
        dim=-1,
    )

    actual_tangents = tangents[triangle_idxs]
    expected_tangents = expected_tangents[triangle_idxs]

    # Work on a copy so the caller's tensor is not modified in place
    uv = uv.clone()
    index_mod = index % 6  # Indices >= 6 should not occur here; wrap for safety
    for i in range(6):
        mask = index_mod == i
        if not mask.any():
            continue

        actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
        expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))

        # Signed angle between the mean tangents.
        # NOTE(review): the dot product uses all three components while the
        # cross term only uses x/y; confirm this mix is intended.
        dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
        cross_product = (
            actual_mean_tangent[0] * expected_mean_tangent[1]
            - actual_mean_tangent[1] * expected_mean_tangent[0]
        )
        angle = torch.atan2(cross_product, dot_product)

        c, s = torch.cos(angle), torch.sin(angle)
        rot_matrix = torch.stack([torch.stack([c, -s]), torch.stack([s, c])]).to(
            device=uv.device, dtype=uv.dtype
        )
        # Center the island around 0 ([0, 1] -> [-1, 1]) and rotate it
        uv_cur = uv[mask] * 2 - 1
        rotated = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)

        # Rescale into [0, 1]; an island with zero extent is only shifted to 0
        # instead of being divided by zero
        lo = rotated.min()
        span = rotated.max() - lo
        uv[mask] = (rotated - lo) / span if span > 0 else rotated - lo

    return uv


def _handle_slice_uvs(
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
    max_index: int = 6 * 2,
) -> Float[Tensor, "Nf 3 2"]:  # noqa: F821
    uc, vc = uv.unbind(-1)

    # Get the second slice (The first overlap)
    index_filter = [index == i for i in range(6, max_index)]

    # Normalize them to always fully fill the atlas patch
    for i, fi in enumerate(index_filter):
        if fi.sum() > 0:
            # Scale the slice but only up to a factor of 2
            # This keeps the texture resolution with the first slice in line (Half space in UV)
            uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
            vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)

    uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
    vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)

    return torch.stack([uc_padded, vc_padded], dim=-1)


def _handle_remaining_uvs(
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
) -> Float[Tensor, "Nf 3 2"]:
    uc, vc = uv.unbind(-1)
    # Get all remaining elements
    remaining_filter = index >= 6 * 2
    squares_left = remaining_filter.sum()

    if squares_left == 0:
        return uv

    uc = uc[remaining_filter]
    vc = vc[remaining_filter]

    # Or remaining triangles are distributed in a rectangle
    # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
    ratio = 0.5 * (1 / 3)  # 1.5
    # sqrt(744/(0.5*(1/3)))

    mult = math.sqrt(squares_left / ratio)
    num_square_width = int(math.ceil(0.5 * mult))
    num_square_height = int(math.ceil(squares_left / num_square_width))

    width = 1 / num_square_width
    height = 1 / num_square_height

    # The idea is again to keep the texture resolution consistent with the first slice
    # This only occupys half the region in the texture chart but the scaling on the squares
    # assumes full coverage.
    clip_val = min(width, height) * 1.5
    # Now normalize the UVs with taking into account the maximum scaling
    uc = (uc - uc.min(dim=1, keepdim=True).values) / (
        uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
    ).clip(clip_val)
    vc = (vc - vc.min(dim=1, keepdim=True).values) / (
        vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
    ).clip(clip_val)
    # Add a small padding
    uc = (
        uc * (1 - island_padding * num_square_width * 0.5)
        + island_padding * num_square_width * 0.25
    ).clip(0, 1)
    vc = (
        vc * (1 - island_padding * num_square_height * 0.5)
        + island_padding * num_square_height * 0.25
    ).clip(0, 1)

    uc = uc * width
    vc = vc * height

    # And calculate offsets for each element
    idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
    x_idx = idx % num_square_width
    y_idx = idx // num_square_width
    # And move each triangle to its own spot
    uc = uc + x_idx[:, None] * width
    vc = vc + y_idx[:, None] * height

    uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
    vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)

    uv[remaining_filter] = torch.stack([uc, vc], dim=-1)

    return uv


def _distribute_individual_uvs_in_atlas(
    face_uv: Float[Tensor, "Nf 3 2"],
    assigned_faces: Integer[Tensor, "Nf"],  # noqa: F821
    offset_x: Float[Tensor, "Nf"],  # noqa: F821
    offset_y: Float[Tensor, "Nf"],  # noqa: F821
    div_x: Float[Tensor, "Nf"],  # noqa: F821
    div_y: Float[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
):
    """Place every triangle's UVs at its final position in the texture atlas.

    The overlap slices and the leftover triangles are first laid out in their
    local [0, 1] space. Each triangle is then scaled by ``1 / div`` and shifted
    by ``offset`` into its atlas tile.

    Returns:
        The atlas UVs flattened to one row per triangle corner, shape
        ``(Nf * 3, 2)``.
    """
    slice_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
    packed_uv = _handle_remaining_uvs(slice_uv, assigned_faces, island_padding)

    # Per-triangle (u, v) scale and origin, broadcast over the three corners
    divisor = torch.stack([div_x, div_y], dim=-1).unsqueeze(1)
    origin = torch.stack([offset_x, offset_y], dim=-1).unsqueeze(1)
    atlas_uv = packed_uv / divisor + origin

    return atlas_uv.reshape(-1, 2)


def _get_unique_face_uv(
    uv: Float[Tensor, "Nf 3 2"],
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]:  # noqa: F821
    unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
    # And add the face to uv index mapping
    vtex_idx = unique_idx.view(-1, 3)

    return unique_uv, vtex_idx


def _align_mesh_with_main_axis(
    vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
    """Rotate the mesh so its principal axes line up with the canonical axes.

    The two leading principal directions of the vertex positions come from
    ``torch.pca_lowrank``; the third is their cross product. Each direction is
    assigned to the canonical axis of its largest absolute component. Collisions
    are resolved so every canonical axis is used exactly once, and positions
    and normals are rotated into that frame.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals, rotated with the same matrix.

    Returns:
        The rotated ``(vertex_positions, vertex_normals)``.

    Raises:
        ValueError: If no distinct canonical axis can be assigned to each
            principal direction.
    """
    # pca_lowrank uses random projections, so seed it for repeatable results.
    # Seed inside fork_rng so the CPU RNG (and the RNG of the mesh's CUDA
    # device) is restored afterwards instead of leaving the caller's global
    # random state reset to seed 0.
    # NOTE(review): torch.manual_seed also reseeds generators of other
    # accelerator devices, which are not restored here.
    fork_devices = [vertex_positions.device] if vertex_positions.is_cuda else []
    with torch.random.fork_rng(devices=fork_devices):
        torch.manual_seed(0)
        _, _, v = torch.pca_lowrank(vertex_positions, q=2)
    main_axis, secondary_axis = v[:, 0], v[:, 1]

    main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
    # Orthogonalize the secondary axis against the main axis
    secondary_axis: Float[Tensor, "3"] = F.normalize(
        secondary_axis - dot(secondary_axis, main_axis) * main_axis, eps=1e-6, dim=-1
    )
    # The third axis is perpendicular to both
    third_axis: Float[Tensor, "3"] = F.normalize(
        torch.linalg.cross(main_axis, secondary_axis), dim=-1, eps=1e-6
    )

    # Canonical axis each principal direction is most aligned with
    main_axis_max_idx = main_axis.abs().argmax().item()
    secondary_axis_max_idx = secondary_axis.abs().argmax().item()
    third_axis_max_idx = third_axis.abs().argmax().item()

    # Turn the assignment into a permutation of {0, 1, 2}. The main axis keeps
    # its slot. On a collision the third axis (smallest contribution to the
    # overall shape) moves to a free slot first, then the secondary axis.
    all_possible_axis = {0, 1, 2}
    for step in range(2):
        taken = {main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx}
        if len(taken) == 3:
            break
        missing_axis = (all_possible_axis - taken).pop()
        if step == 0:
            third_axis_max_idx = missing_axis
        else:
            secondary_axis_max_idx = missing_axis

    if len({main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx}) != 3:
        raise ValueError("Could not find 3 unique axis")

    axes = [None] * 3
    axes[main_axis_max_idx] = main_axis
    axes[secondary_axis_max_idx] = secondary_axis
    axes[third_axis_max_idx] = third_axis
    # Rotation matrix whose rows are the principal directions
    rot_mat = torch.stack(axes, dim=1).T

    # Rotate positions and normals so the mesh aligns with the main axes
    vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
    vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

    return vertex_positions, vertex_normals


def box_projection_uv_unwrap(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    island_padding: float,
) -> Tuple[Float[Tensor, "Utex 2"], Integer[Tensor, "Nf 3"]]:
    """Unwrap a triangle mesh into a UV atlas using box (cube) projection.

    Steps: align the mesh with its principal axes, project every triangle onto
    the cube face that matches its normal, rotate each cube-face island, detect
    UV overlaps to choose an atlas slice per triangle, and place every triangle
    in the atlas.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of every triangle.
        island_padding: Padding around UV islands, in the normalized [0, 1]
            units used by the placement helpers.

    Returns:
        ``(unique_uv, vtex_idx)``: the deduplicated UV coordinates and, per
        triangle, the indices of its three corners into ``unique_uv``.
    """
    # Align the mesh with its main axis directions first
    vertex_positions, vertex_normals = _align_mesh_with_main_axis(
        vertex_positions, vertex_normals
    )

    bbox: Float[Tensor, "2 3"] = torch.stack(
        [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
    )
    # First decide on which cube face each triangle is projected
    face_uv, face_index = _box_assign_vertex_to_cube_face(
        vertex_positions, vertex_normals, triangle_idxs, bbox
    )

    # Rotate the UV islands so they align with the radial z tangent space
    face_uv = _rotate_uv_slices_consistent_space(
        vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
    )

    # Then find where each face is placed in the atlas; this detects overlaps
    assigned_atlas_index = _assign_faces_uv_to_atlas_index(
        vertex_positions, triangle_idxs, face_uv, face_index
    )

    # Then figure out the final place in the atlas based on the assignment
    offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
        assigned_atlas_index
    )

    # Next distribute the faces in the UV atlas
    placed_uv = _distribute_individual_uvs_in_atlas(
        face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
    )

    # And get the unique per-triangle UV coordinates
    return _get_unique_face_uv(placed_uv)