File size: 2,809 Bytes
d945eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os

import slangtorch
import torch
import torch.nn as nn
from jaxtyping import Bool, Float
from torch import Tensor


class TextureBaker(nn.Module):
    """Rasterize a mesh's UV layout and bake per-vertex attributes into a texture.

    The work is done by kernels (``bake_uv`` and ``interpolate``) loaded with
    ``slangtorch`` from ``texture_baker.slang`` located next to this module.
    All tensors handed to the kernels must live on a CUDA device.
    """

    # Threads per block along each of the two launch dimensions. The launch
    # grid is ``resolution // _BLOCK_SIZE`` blocks per side, so a resolution
    # must be an exact multiple of this value for every texel to be visited.
    _BLOCK_SIZE = 16

    def __init__(self):
        super().__init__()
        # Load the Slang kernels shipped alongside this file.
        self.baker = slangtorch.loadModule(
            os.path.join(os.path.dirname(__file__), "texture_baker.slang")
        )

    @classmethod
    def _grid_size(cls, resolution: int) -> int:
        """Return the number of blocks per side needed to cover ``resolution`` texels.

        Raises:
            ValueError: if ``resolution`` is not a positive multiple of
                ``_BLOCK_SIZE``. With floor division the trailing texels would
                never be visited, and a resolution below the block size would
                produce an empty (zero-block) grid.
        """
        if resolution <= 0 or resolution % cls._BLOCK_SIZE != 0:
            raise ValueError(
                f"Resolution must be a positive multiple of {cls._BLOCK_SIZE}, "
                f"got {resolution}"
            )
        return resolution // cls._BLOCK_SIZE

    def rasterize(
        self,
        uv: Float[Tensor, "Nv 2"],
        face_indices: Float[Tensor, "Nf 3"],
        bake_resolution: int,
    ) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
        """Rasterize the mesh's UV layout into a square grid of texels.

        Args:
            uv: Per-vertex UV coordinates; cast to float32 before the launch.
            face_indices: Per-face vertex indices (integer-valued despite the
                annotation); cast to int32 before the launch.
            bake_resolution: Side length of the output grid in texels. Must be
                a positive multiple of ``_BLOCK_SIZE`` (16).

        Returns:
            A float32 tensor of shape ``(bake_resolution, bake_resolution, 4)``
            written by the ``bake_uv`` kernel. The channel layout is defined by
            the kernel; ``get_mask`` treats a non-negative last channel as
            "covered".

        Raises:
            ValueError: if an input tensor is not on CUDA, or if
                ``bake_resolution`` is not a positive multiple of the block size.
        """
        if not face_indices.is_cuda or not uv.is_cuda:
            raise ValueError("All input tensors must be on cuda")
        grid_size = self._grid_size(bake_resolution)

        face_indices = face_indices.to(torch.int32)
        uv = uv.to(torch.float32)

        # torch.empty leaves memory uninitialised: any texel the launch grid
        # does not reach would hold garbage, hence the exact-cover check above.
        rast_result = torch.empty(
            bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
        )

        block = self._BLOCK_SIZE
        self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
            blockSize=(block, block, 1), gridSize=(grid_size, grid_size, 1)
        )

        return rast_result

    def get_mask(
        self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
    ) -> Bool[Tensor, "bake_resolution bake_resolution"]:
        """Return a boolean mask of texels whose last ``rast`` channel is >= 0.

        NOTE(review): presumably ``bake_uv`` writes a negative value into the
        last channel for texels not covered by any face — confirm against
        texture_baker.slang.
        """
        return rast[..., -1] >= 0

    def interpolate(
        self,
        attr: Float[Tensor, "Nv 3"],
        rast: Float[Tensor, "bake_resolution bake_resolution 4"],
        face_indices: Float[Tensor, "Nf 3"],
        uv: Float[Tensor, "Nv 2"],
    ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
        """Interpolate per-vertex attributes over a rasterized UV grid.

        Args:
            attr: Per-vertex 3-component attributes; cast to float32.
            rast: Square rasterization result, e.g. from ``rasterize``; cast to
                float32. Its side length must be a positive multiple of
                ``_BLOCK_SIZE`` (16).
            face_indices: Per-face vertex indices (integer-valued despite the
                annotation); cast to int32.
            uv: Accepted for interface compatibility only; the interpolate
                kernel does not use it, so it is neither validated nor converted.

        Returns:
            A float32 tensor of shape ``(H, W, 3)`` (with ``H == W``), where
            ``H, W`` are ``rast``'s first two dimensions. It is zero-initialised
            before the kernel writes into it.

        Raises:
            ValueError: if an input tensor is not on CUDA, if ``rast`` is not
                square, or if its side length is not a positive multiple of the
                block size.
        """
        # Make sure all tensors consumed by the kernel are on cuda.
        if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
            raise ValueError("All input tensors must be on cuda")
        # The launch grid below is square, so a non-square rast could not be
        # covered exactly.
        if rast.shape[0] != rast.shape[1]:
            raise ValueError(
                f"rast must be square in its first two dimensions, "
                f"got shape {tuple(rast.shape)}"
            )
        grid_size = self._grid_size(rast.shape[0])

        attr = attr.to(torch.float32)
        face_indices = face_indices.to(torch.int32)
        rast = rast.to(torch.float32)

        pos_bake = torch.zeros(
            rast.shape[0],
            rast.shape[1],
            3,
            device=attr.device,
            dtype=attr.dtype,
        )

        block = self._BLOCK_SIZE
        self.baker.interpolate(
            attr=attr, indices=face_indices, rast=rast, output=pos_bake
        ).launchRaw(
            blockSize=(block, block, 1), gridSize=(grid_size, grid_size, 1)
        )

        return pos_bake

    def forward(
        self,
        attr: Float[Tensor, "Nv 3"],
        uv: Float[Tensor, "Nv 2"],
        face_indices: Float[Tensor, "Nf 3"],
        bake_resolution: int,
    ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
        """Rasterize the UV layout, then bake ``attr`` into a square texture.

        Equivalent to ``interpolate(attr, rasterize(uv, face_indices,
        bake_resolution), face_indices, uv)``. The same ``ValueError``
        conditions apply.
        """
        rast = self.rasterize(uv, face_indices, bake_resolution)
        return self.interpolate(attr, rast, face_indices, uv)