Spaces:
Runtime error
Runtime error
dmitriitochilkin
commited on
Commit
•
ff49a48
1
Parent(s):
a2337f4
add dependencies
Browse files- requirements.txt +8 -0
- tsr/__pycache__/system.cpython-310.pyc +0 -0
- tsr/__pycache__/utils.cpython-310.pyc +0 -0
- tsr/models/__pycache__/camera.cpython-310.pyc +0 -0
- tsr/models/__pycache__/isosurface.cpython-310.pyc +0 -0
- tsr/models/__pycache__/nerf_renderer.cpython-310.pyc +0 -0
- tsr/models/__pycache__/network_utils.cpython-310.pyc +0 -0
- tsr/models/isosurface.py +48 -0
- tsr/models/nerf_renderer.py +180 -0
- tsr/models/network_utils.py +124 -0
- tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc +0 -0
- tsr/models/tokenizers/__pycache__/image.cpython-310.pyc +0 -0
- tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc +0 -0
- tsr/models/tokenizers/image.py +67 -0
- tsr/models/tokenizers/triplane.py +45 -0
- tsr/models/transformer/__pycache__/attention.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc +0 -0
- tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc +0 -0
- tsr/models/transformer/attention.py +628 -0
- tsr/models/transformer/basic_transformer_block.py +314 -0
- tsr/models/transformer/transformer_1d.py +216 -0
- tsr/system.py +203 -0
- tsr/utils.py +492 -0
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
omegaconf==2.3.0
|
2 |
+
Pillow==10.1.0
|
3 |
+
einops==0.7.0
|
4 |
+
git+https://github.com/tatsy/torchmcubes.git
|
5 |
+
transformers==4.35.0
|
6 |
+
trimesh==4.0.5
|
7 |
+
rembg
|
8 |
+
huggingface-hub
|
tsr/__pycache__/system.cpython-310.pyc
ADDED
Binary file (5.41 kB). View file
|
|
tsr/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (13.8 kB). View file
|
|
tsr/models/__pycache__/camera.cpython-310.pyc
ADDED
Binary file (1.48 kB). View file
|
|
tsr/models/__pycache__/isosurface.cpython-310.pyc
ADDED
Binary file (2.04 kB). View file
|
|
tsr/models/__pycache__/nerf_renderer.cpython-310.pyc
ADDED
Binary file (5.29 kB). View file
|
|
tsr/models/__pycache__/network_utils.cpython-310.pyc
ADDED
Binary file (3.42 kB). View file
|
|
tsr/models/isosurface.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchmcubes import marching_cubes
|
7 |
+
|
8 |
+
|
9 |
+
class IsosurfaceHelper(nn.Module):
|
10 |
+
points_range: Tuple[float, float] = (0, 1)
|
11 |
+
|
12 |
+
@property
|
13 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
14 |
+
raise NotImplementedError
|
15 |
+
|
16 |
+
|
17 |
+
class MarchingCubeHelper(IsosurfaceHelper):
|
18 |
+
def __init__(self, resolution: int) -> None:
|
19 |
+
super().__init__()
|
20 |
+
self.resolution = resolution
|
21 |
+
self.mc_func: Callable = marching_cubes
|
22 |
+
self._grid_vertices: Optional[torch.FloatTensor] = None
|
23 |
+
|
24 |
+
@property
|
25 |
+
def grid_vertices(self) -> torch.FloatTensor:
|
26 |
+
if self._grid_vertices is None:
|
27 |
+
# keep the vertices on CPU so that we can support very large resolution
|
28 |
+
x, y, z = (
|
29 |
+
torch.linspace(*self.points_range, self.resolution),
|
30 |
+
torch.linspace(*self.points_range, self.resolution),
|
31 |
+
torch.linspace(*self.points_range, self.resolution),
|
32 |
+
)
|
33 |
+
x, y, z = torch.meshgrid(x, y, z, indexing="ij")
|
34 |
+
verts = torch.cat(
|
35 |
+
[x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
|
36 |
+
).reshape(-1, 3)
|
37 |
+
self._grid_vertices = verts
|
38 |
+
return self._grid_vertices
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
level: torch.FloatTensor,
|
43 |
+
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
44 |
+
level = -level.view(self.resolution, self.resolution, self.resolution)
|
45 |
+
v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
|
46 |
+
v_pos = v_pos[..., [2, 1, 0]]
|
47 |
+
v_pos = v_pos / (self.resolution - 1.0)
|
48 |
+
return v_pos, t_pos_idx
|
tsr/models/nerf_renderer.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange, reduce
|
7 |
+
|
8 |
+
from ..utils import (
|
9 |
+
BaseModule,
|
10 |
+
chunk_batch,
|
11 |
+
get_activation,
|
12 |
+
rays_intersect_bbox,
|
13 |
+
scale_tensor,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class TriplaneNeRFRenderer(BaseModule):
|
18 |
+
@dataclass
|
19 |
+
class Config(BaseModule.Config):
|
20 |
+
radius: float
|
21 |
+
|
22 |
+
feature_reduction: str = "concat"
|
23 |
+
density_activation: str = "trunc_exp"
|
24 |
+
density_bias: float = -1.0
|
25 |
+
color_activation: str = "sigmoid"
|
26 |
+
num_samples_per_ray: int = 128
|
27 |
+
randomized: bool = False
|
28 |
+
|
29 |
+
cfg: Config
|
30 |
+
|
31 |
+
def configure(self) -> None:
|
32 |
+
assert self.cfg.feature_reduction in ["concat", "mean"]
|
33 |
+
self.chunk_size = 0
|
34 |
+
|
35 |
+
def set_chunk_size(self, chunk_size: int):
|
36 |
+
assert (
|
37 |
+
chunk_size >= 0
|
38 |
+
), "chunk_size must be a non-negative integer (0 for no chunking)."
|
39 |
+
self.chunk_size = chunk_size
|
40 |
+
|
41 |
+
def query_triplane(
|
42 |
+
self,
|
43 |
+
decoder: torch.nn.Module,
|
44 |
+
positions: torch.Tensor,
|
45 |
+
triplane: torch.Tensor,
|
46 |
+
) -> Dict[str, torch.Tensor]:
|
47 |
+
input_shape = positions.shape[:-1]
|
48 |
+
positions = positions.view(-1, 3)
|
49 |
+
|
50 |
+
# positions in (-radius, radius)
|
51 |
+
# normalized to (-1, 1) for grid sample
|
52 |
+
positions = scale_tensor(
|
53 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
54 |
+
)
|
55 |
+
|
56 |
+
def _query_chunk(x):
|
57 |
+
indices2D: torch.Tensor = torch.stack(
|
58 |
+
(x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
|
59 |
+
dim=-3,
|
60 |
+
)
|
61 |
+
out: torch.Tensor = F.grid_sample(
|
62 |
+
rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
|
63 |
+
rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
|
64 |
+
align_corners=False,
|
65 |
+
mode="bilinear",
|
66 |
+
)
|
67 |
+
if self.cfg.feature_reduction == "concat":
|
68 |
+
out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
|
69 |
+
elif self.cfg.feature_reduction == "mean":
|
70 |
+
out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
|
71 |
+
else:
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
net_out: Dict[str, torch.Tensor] = decoder(out)
|
75 |
+
return net_out
|
76 |
+
|
77 |
+
if self.chunk_size > 0:
|
78 |
+
net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
|
79 |
+
else:
|
80 |
+
net_out = _query_chunk(positions)
|
81 |
+
|
82 |
+
net_out["density_act"] = get_activation(self.cfg.density_activation)(
|
83 |
+
net_out["density"] + self.cfg.density_bias
|
84 |
+
)
|
85 |
+
net_out["color"] = get_activation(self.cfg.color_activation)(
|
86 |
+
net_out["features"]
|
87 |
+
)
|
88 |
+
|
89 |
+
net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
|
90 |
+
|
91 |
+
return net_out
|
92 |
+
|
93 |
+
def _forward(
|
94 |
+
self,
|
95 |
+
decoder: torch.nn.Module,
|
96 |
+
triplane: torch.Tensor,
|
97 |
+
rays_o: torch.Tensor,
|
98 |
+
rays_d: torch.Tensor,
|
99 |
+
**kwargs,
|
100 |
+
):
|
101 |
+
rays_shape = rays_o.shape[:-1]
|
102 |
+
rays_o = rays_o.view(-1, 3)
|
103 |
+
rays_d = rays_d.view(-1, 3)
|
104 |
+
n_rays = rays_o.shape[0]
|
105 |
+
|
106 |
+
t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
|
107 |
+
t_near, t_far = t_near[rays_valid], t_far[rays_valid]
|
108 |
+
|
109 |
+
t_vals = torch.linspace(
|
110 |
+
0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
|
111 |
+
)
|
112 |
+
t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
|
113 |
+
z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
|
114 |
+
|
115 |
+
xyz = (
|
116 |
+
rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
|
117 |
+
) # (N_rays, N_sample, 3)
|
118 |
+
|
119 |
+
mlp_out = self.query_triplane(
|
120 |
+
decoder=decoder,
|
121 |
+
positions=xyz,
|
122 |
+
triplane=triplane,
|
123 |
+
)
|
124 |
+
|
125 |
+
eps = 1e-10
|
126 |
+
# deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
|
127 |
+
deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
|
128 |
+
alpha = 1 - torch.exp(
|
129 |
+
-deltas * mlp_out["density_act"][..., 0]
|
130 |
+
) # (N_rays, N_samples)
|
131 |
+
accum_prod = torch.cat(
|
132 |
+
[
|
133 |
+
torch.ones_like(alpha[:, :1]),
|
134 |
+
torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
|
135 |
+
],
|
136 |
+
dim=-1,
|
137 |
+
)
|
138 |
+
weights = alpha * accum_prod # (N_rays, N_samples)
|
139 |
+
comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
|
140 |
+
opacity_ = weights.sum(dim=-1) # (N_rays)
|
141 |
+
|
142 |
+
comp_rgb = torch.zeros(
|
143 |
+
n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
|
144 |
+
)
|
145 |
+
opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
|
146 |
+
comp_rgb[rays_valid] = comp_rgb_
|
147 |
+
opacity[rays_valid] = opacity_
|
148 |
+
|
149 |
+
comp_rgb += 1 - opacity[..., None]
|
150 |
+
comp_rgb = comp_rgb.view(*rays_shape, 3)
|
151 |
+
|
152 |
+
return comp_rgb
|
153 |
+
|
154 |
+
def forward(
|
155 |
+
self,
|
156 |
+
decoder: torch.nn.Module,
|
157 |
+
triplane: torch.Tensor,
|
158 |
+
rays_o: torch.Tensor,
|
159 |
+
rays_d: torch.Tensor,
|
160 |
+
) -> Dict[str, torch.Tensor]:
|
161 |
+
if triplane.ndim == 4:
|
162 |
+
comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
|
163 |
+
else:
|
164 |
+
comp_rgb = torch.stack(
|
165 |
+
[
|
166 |
+
self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
|
167 |
+
for i in range(triplane.shape[0])
|
168 |
+
],
|
169 |
+
dim=0,
|
170 |
+
)
|
171 |
+
|
172 |
+
return comp_rgb
|
173 |
+
|
174 |
+
def train(self, mode=True):
|
175 |
+
self.randomized = mode and self.cfg.randomized
|
176 |
+
return super().train(mode=mode)
|
177 |
+
|
178 |
+
def eval(self):
|
179 |
+
self.randomized = False
|
180 |
+
return super().eval()
|
tsr/models/network_utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from ..utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class TriplaneUpsampleNetwork(BaseModule):
|
12 |
+
@dataclass
|
13 |
+
class Config(BaseModule.Config):
|
14 |
+
in_channels: int
|
15 |
+
out_channels: int
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.upsample = nn.ConvTranspose2d(
|
21 |
+
self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
|
25 |
+
triplanes_up = rearrange(
|
26 |
+
self.upsample(
|
27 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
28 |
+
),
|
29 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
30 |
+
Np=3,
|
31 |
+
)
|
32 |
+
return triplanes_up
|
33 |
+
|
34 |
+
|
35 |
+
class NeRFMLP(BaseModule):
|
36 |
+
@dataclass
|
37 |
+
class Config(BaseModule.Config):
|
38 |
+
in_channels: int
|
39 |
+
n_neurons: int
|
40 |
+
n_hidden_layers: int
|
41 |
+
activation: str = "relu"
|
42 |
+
bias: bool = True
|
43 |
+
weight_init: Optional[str] = "kaiming_uniform"
|
44 |
+
bias_init: Optional[str] = None
|
45 |
+
|
46 |
+
cfg: Config
|
47 |
+
|
48 |
+
def configure(self) -> None:
|
49 |
+
layers = [
|
50 |
+
self.make_linear(
|
51 |
+
self.cfg.in_channels,
|
52 |
+
self.cfg.n_neurons,
|
53 |
+
bias=self.cfg.bias,
|
54 |
+
weight_init=self.cfg.weight_init,
|
55 |
+
bias_init=self.cfg.bias_init,
|
56 |
+
),
|
57 |
+
self.make_activation(self.cfg.activation),
|
58 |
+
]
|
59 |
+
for i in range(self.cfg.n_hidden_layers - 1):
|
60 |
+
layers += [
|
61 |
+
self.make_linear(
|
62 |
+
self.cfg.n_neurons,
|
63 |
+
self.cfg.n_neurons,
|
64 |
+
bias=self.cfg.bias,
|
65 |
+
weight_init=self.cfg.weight_init,
|
66 |
+
bias_init=self.cfg.bias_init,
|
67 |
+
),
|
68 |
+
self.make_activation(self.cfg.activation),
|
69 |
+
]
|
70 |
+
layers += [
|
71 |
+
self.make_linear(
|
72 |
+
self.cfg.n_neurons,
|
73 |
+
4, # density 1 + features 3
|
74 |
+
bias=self.cfg.bias,
|
75 |
+
weight_init=self.cfg.weight_init,
|
76 |
+
bias_init=self.cfg.bias_init,
|
77 |
+
)
|
78 |
+
]
|
79 |
+
self.layers = nn.Sequential(*layers)
|
80 |
+
|
81 |
+
def make_linear(
|
82 |
+
self,
|
83 |
+
dim_in,
|
84 |
+
dim_out,
|
85 |
+
bias=True,
|
86 |
+
weight_init=None,
|
87 |
+
bias_init=None,
|
88 |
+
):
|
89 |
+
layer = nn.Linear(dim_in, dim_out, bias=bias)
|
90 |
+
|
91 |
+
if weight_init is None:
|
92 |
+
pass
|
93 |
+
elif weight_init == "kaiming_uniform":
|
94 |
+
torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
if bias:
|
99 |
+
if bias_init is None:
|
100 |
+
pass
|
101 |
+
elif bias_init == "zero":
|
102 |
+
torch.nn.init.zeros_(layer.bias)
|
103 |
+
else:
|
104 |
+
raise NotImplementedError
|
105 |
+
|
106 |
+
return layer
|
107 |
+
|
108 |
+
def make_activation(self, activation):
|
109 |
+
if activation == "relu":
|
110 |
+
return nn.ReLU(inplace=True)
|
111 |
+
elif activation == "silu":
|
112 |
+
return nn.SiLU(inplace=True)
|
113 |
+
else:
|
114 |
+
raise NotImplementedError
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
inp_shape = x.shape[:-1]
|
118 |
+
x = x.reshape(-1, x.shape[-1])
|
119 |
+
|
120 |
+
features = self.layers(x)
|
121 |
+
features = features.reshape(*inp_shape, -1)
|
122 |
+
out = {"density": features[..., 0:1], "features": features[..., 1:4]}
|
123 |
+
|
124 |
+
return out
|
tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc
ADDED
Binary file (18.6 kB). View file
|
|
tsr/models/tokenizers/__pycache__/image.cpython-310.pyc
ADDED
Binary file (2.42 kB). View file
|
|
tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc
ADDED
Binary file (1.76 kB). View file
|
|
tsr/models/tokenizers/image.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
from transformers.models.vit.modeling_vit import ViTModel
|
9 |
+
|
10 |
+
from ...utils import BaseModule
|
11 |
+
|
12 |
+
|
13 |
+
class DINOSingleImageTokenizer(BaseModule):
|
14 |
+
@dataclass
|
15 |
+
class Config(BaseModule.Config):
|
16 |
+
pretrained_model_name_or_path: str = "facebook/dino-vitb16"
|
17 |
+
enable_gradient_checkpointing: bool = False
|
18 |
+
|
19 |
+
cfg: Config
|
20 |
+
|
21 |
+
def configure(self) -> None:
|
22 |
+
self.model: ViTModel = ViTModel(
|
23 |
+
ViTModel.config_class.from_pretrained(
|
24 |
+
hf_hub_download(
|
25 |
+
repo_id=self.cfg.pretrained_model_name_or_path,
|
26 |
+
filename="config.json",
|
27 |
+
)
|
28 |
+
)
|
29 |
+
)
|
30 |
+
|
31 |
+
if self.cfg.enable_gradient_checkpointing:
|
32 |
+
self.model.encoder.gradient_checkpointing = True
|
33 |
+
|
34 |
+
self.register_buffer(
|
35 |
+
"image_mean",
|
36 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
37 |
+
persistent=False,
|
38 |
+
)
|
39 |
+
self.register_buffer(
|
40 |
+
"image_std",
|
41 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
42 |
+
persistent=False,
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
46 |
+
packed = False
|
47 |
+
if images.ndim == 4:
|
48 |
+
packed = True
|
49 |
+
images = images.unsqueeze(1)
|
50 |
+
|
51 |
+
batch_size, n_input_views = images.shape[:2]
|
52 |
+
images = (images - self.image_mean) / self.image_std
|
53 |
+
out = self.model(
|
54 |
+
rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
|
55 |
+
)
|
56 |
+
local_features, global_features = out.last_hidden_state, out.pooler_output
|
57 |
+
local_features = local_features.permute(0, 2, 1)
|
58 |
+
local_features = rearrange(
|
59 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
60 |
+
)
|
61 |
+
if packed:
|
62 |
+
local_features = local_features.squeeze(1)
|
63 |
+
|
64 |
+
return local_features
|
65 |
+
|
66 |
+
def detokenize(self, *args, **kwargs):
|
67 |
+
raise NotImplementedError
|
tsr/models/tokenizers/triplane.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
|
8 |
+
from ...utils import BaseModule
|
9 |
+
|
10 |
+
|
11 |
+
class Triplane1DTokenizer(BaseModule):
|
12 |
+
@dataclass
|
13 |
+
class Config(BaseModule.Config):
|
14 |
+
plane_size: int
|
15 |
+
num_channels: int
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.embeddings = nn.Parameter(
|
21 |
+
torch.randn(
|
22 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
23 |
+
dtype=torch.float32,
|
24 |
+
)
|
25 |
+
* 1
|
26 |
+
/ math.sqrt(self.cfg.num_channels)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, batch_size: int) -> torch.Tensor:
|
30 |
+
return rearrange(
|
31 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
32 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
33 |
+
)
|
34 |
+
|
35 |
+
def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
|
36 |
+
batch_size, Ct, Nt = tokens.shape
|
37 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
38 |
+
assert Ct == self.cfg.num_channels
|
39 |
+
return rearrange(
|
40 |
+
tokens,
|
41 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
42 |
+
Np=3,
|
43 |
+
Hp=self.cfg.plane_size,
|
44 |
+
Wp=self.cfg.plane_size,
|
45 |
+
)
|
tsr/models/transformer/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (15.3 kB). View file
|
|
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc
ADDED
Binary file (9.96 kB). View file
|
|
tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc
ADDED
Binary file (7.47 kB). View file
|
|
tsr/models/transformer/attention.py
ADDED
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
class Attention(nn.Module):
|
22 |
+
r"""
|
23 |
+
A cross attention layer.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
query_dim (`int`):
|
27 |
+
The number of channels in the query.
|
28 |
+
cross_attention_dim (`int`, *optional*):
|
29 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
30 |
+
heads (`int`, *optional*, defaults to 8):
|
31 |
+
The number of heads to use for multi-head attention.
|
32 |
+
dim_head (`int`, *optional*, defaults to 64):
|
33 |
+
The number of channels in each head.
|
34 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
35 |
+
The dropout probability to use.
|
36 |
+
bias (`bool`, *optional*, defaults to False):
|
37 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
38 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
39 |
+
Set to `True` to upcast the attention computation to `float32`.
|
40 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
41 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
42 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
43 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
44 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
45 |
+
The number of groups to use for the group norm in the cross attention.
|
46 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
47 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
48 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
49 |
+
The number of groups to use for the group norm in the attention.
|
50 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
51 |
+
The number of channels to use for the spatial normalization.
|
52 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
53 |
+
Set to `True` to use a bias in the output linear layer.
|
54 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
55 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
56 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
57 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
58 |
+
`added_kv_proj_dim` is not `None`.
|
59 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
60 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
61 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
62 |
+
A factor to rescale the output by dividing it with this value.
|
63 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
64 |
+
Set to `True` to add the residual connection to the output.
|
65 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
66 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
67 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
68 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
69 |
+
`AttnProcessor` otherwise.
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
query_dim: int,
|
75 |
+
cross_attention_dim: Optional[int] = None,
|
76 |
+
heads: int = 8,
|
77 |
+
dim_head: int = 64,
|
78 |
+
dropout: float = 0.0,
|
79 |
+
bias: bool = False,
|
80 |
+
upcast_attention: bool = False,
|
81 |
+
upcast_softmax: bool = False,
|
82 |
+
cross_attention_norm: Optional[str] = None,
|
83 |
+
cross_attention_norm_num_groups: int = 32,
|
84 |
+
added_kv_proj_dim: Optional[int] = None,
|
85 |
+
norm_num_groups: Optional[int] = None,
|
86 |
+
out_bias: bool = True,
|
87 |
+
scale_qk: bool = True,
|
88 |
+
only_cross_attention: bool = False,
|
89 |
+
eps: float = 1e-5,
|
90 |
+
rescale_output_factor: float = 1.0,
|
91 |
+
residual_connection: bool = False,
|
92 |
+
_from_deprecated_attn_block: bool = False,
|
93 |
+
processor: Optional["AttnProcessor"] = None,
|
94 |
+
out_dim: int = None,
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
98 |
+
self.query_dim = query_dim
|
99 |
+
self.cross_attention_dim = (
|
100 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
101 |
+
)
|
102 |
+
self.upcast_attention = upcast_attention
|
103 |
+
self.upcast_softmax = upcast_softmax
|
104 |
+
self.rescale_output_factor = rescale_output_factor
|
105 |
+
self.residual_connection = residual_connection
|
106 |
+
self.dropout = dropout
|
107 |
+
self.fused_projections = False
|
108 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
109 |
+
|
110 |
+
# we make use of this private variable to know whether this class is loaded
|
111 |
+
# with an deprecated state dict so that we can convert it on the fly
|
112 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
113 |
+
|
114 |
+
self.scale_qk = scale_qk
|
115 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
116 |
+
|
117 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
118 |
+
# for slice_size > 0 the attention score computation
|
119 |
+
# is split across the batch axis to save memory
|
120 |
+
# You can set slice_size with `set_attention_slice`
|
121 |
+
self.sliceable_head_dim = heads
|
122 |
+
|
123 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
124 |
+
self.only_cross_attention = only_cross_attention
|
125 |
+
|
126 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
127 |
+
raise ValueError(
|
128 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
129 |
+
)
|
130 |
+
|
131 |
+
if norm_num_groups is not None:
|
132 |
+
self.group_norm = nn.GroupNorm(
|
133 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
self.group_norm = None
|
137 |
+
|
138 |
+
self.spatial_norm = None
|
139 |
+
|
140 |
+
if cross_attention_norm is None:
|
141 |
+
self.norm_cross = None
|
142 |
+
elif cross_attention_norm == "layer_norm":
|
143 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
144 |
+
elif cross_attention_norm == "group_norm":
|
145 |
+
if self.added_kv_proj_dim is not None:
|
146 |
+
# The given `encoder_hidden_states` are initially of shape
|
147 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
148 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
149 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
150 |
+
# the number of channels for the group norm.
|
151 |
+
norm_cross_num_channels = added_kv_proj_dim
|
152 |
+
else:
|
153 |
+
norm_cross_num_channels = self.cross_attention_dim
|
154 |
+
|
155 |
+
self.norm_cross = nn.GroupNorm(
|
156 |
+
num_channels=norm_cross_num_channels,
|
157 |
+
num_groups=cross_attention_norm_num_groups,
|
158 |
+
eps=1e-5,
|
159 |
+
affine=True,
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
raise ValueError(
|
163 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
164 |
+
)
|
165 |
+
|
166 |
+
linear_cls = nn.Linear
|
167 |
+
|
168 |
+
self.linear_cls = linear_cls
|
169 |
+
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
170 |
+
|
171 |
+
if not self.only_cross_attention:
|
172 |
+
# only relevant for the `AddedKVProcessor` classes
|
173 |
+
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
174 |
+
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
175 |
+
else:
|
176 |
+
self.to_k = None
|
177 |
+
self.to_v = None
|
178 |
+
|
179 |
+
if self.added_kv_proj_dim is not None:
|
180 |
+
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
181 |
+
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
182 |
+
|
183 |
+
self.to_out = nn.ModuleList([])
|
184 |
+
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
185 |
+
self.to_out.append(nn.Dropout(dropout))
|
186 |
+
|
187 |
+
# set attention processor
|
188 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
189 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
190 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
191 |
+
if processor is None:
|
192 |
+
processor = (
|
193 |
+
AttnProcessor2_0()
|
194 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
195 |
+
else AttnProcessor()
|
196 |
+
)
|
197 |
+
self.set_processor(processor)
|
198 |
+
|
199 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
200 |
+
self.processor = processor
|
201 |
+
|
202 |
+
def forward(
|
203 |
+
self,
|
204 |
+
hidden_states: torch.FloatTensor,
|
205 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
206 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
207 |
+
**cross_attention_kwargs,
|
208 |
+
) -> torch.Tensor:
|
209 |
+
r"""
|
210 |
+
The forward method of the `Attention` class.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
hidden_states (`torch.Tensor`):
|
214 |
+
The hidden states of the query.
|
215 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
216 |
+
The hidden states of the encoder.
|
217 |
+
attention_mask (`torch.Tensor`, *optional*):
|
218 |
+
The attention mask to use. If `None`, no mask is applied.
|
219 |
+
**cross_attention_kwargs:
|
220 |
+
Additional keyword arguments to pass along to the cross attention.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
`torch.Tensor`: The output of the attention layer.
|
224 |
+
"""
|
225 |
+
# The `Attention` class can call different attention processors / attention functions
|
226 |
+
# here we simply pass along all tensors to the selected processor class
|
227 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
228 |
+
return self.processor(
|
229 |
+
self,
|
230 |
+
hidden_states,
|
231 |
+
encoder_hidden_states=encoder_hidden_states,
|
232 |
+
attention_mask=attention_mask,
|
233 |
+
**cross_attention_kwargs,
|
234 |
+
)
|
235 |
+
|
236 |
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
237 |
+
r"""
|
238 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
239 |
+
is the number of heads initialized while constructing the `Attention` class.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
`torch.Tensor`: The reshaped tensor.
|
246 |
+
"""
|
247 |
+
head_size = self.heads
|
248 |
+
batch_size, seq_len, dim = tensor.shape
|
249 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
250 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
251 |
+
batch_size // head_size, seq_len, dim * head_size
|
252 |
+
)
|
253 |
+
return tensor
|
254 |
+
|
255 |
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
256 |
+
r"""
|
257 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
|
258 |
+
the number of heads initialized while constructing the `Attention` class.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
262 |
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
263 |
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
`torch.Tensor`: The reshaped tensor.
|
267 |
+
"""
|
268 |
+
head_size = self.heads
|
269 |
+
batch_size, seq_len, dim = tensor.shape
|
270 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
271 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
272 |
+
|
273 |
+
if out_dim == 3:
|
274 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
275 |
+
|
276 |
+
return tensor
|
277 |
+
|
278 |
+
def get_attention_scores(
|
279 |
+
self,
|
280 |
+
query: torch.Tensor,
|
281 |
+
key: torch.Tensor,
|
282 |
+
attention_mask: torch.Tensor = None,
|
283 |
+
) -> torch.Tensor:
|
284 |
+
r"""
|
285 |
+
Compute the attention scores.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
query (`torch.Tensor`): The query tensor.
|
289 |
+
key (`torch.Tensor`): The key tensor.
|
290 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
`torch.Tensor`: The attention probabilities/scores.
|
294 |
+
"""
|
295 |
+
dtype = query.dtype
|
296 |
+
if self.upcast_attention:
|
297 |
+
query = query.float()
|
298 |
+
key = key.float()
|
299 |
+
|
300 |
+
if attention_mask is None:
|
301 |
+
baddbmm_input = torch.empty(
|
302 |
+
query.shape[0],
|
303 |
+
query.shape[1],
|
304 |
+
key.shape[1],
|
305 |
+
dtype=query.dtype,
|
306 |
+
device=query.device,
|
307 |
+
)
|
308 |
+
beta = 0
|
309 |
+
else:
|
310 |
+
baddbmm_input = attention_mask
|
311 |
+
beta = 1
|
312 |
+
|
313 |
+
attention_scores = torch.baddbmm(
|
314 |
+
baddbmm_input,
|
315 |
+
query,
|
316 |
+
key.transpose(-1, -2),
|
317 |
+
beta=beta,
|
318 |
+
alpha=self.scale,
|
319 |
+
)
|
320 |
+
del baddbmm_input
|
321 |
+
|
322 |
+
if self.upcast_softmax:
|
323 |
+
attention_scores = attention_scores.float()
|
324 |
+
|
325 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
326 |
+
del attention_scores
|
327 |
+
|
328 |
+
attention_probs = attention_probs.to(dtype)
|
329 |
+
|
330 |
+
return attention_probs
|
331 |
+
|
332 |
+
def prepare_attention_mask(
|
333 |
+
self,
|
334 |
+
attention_mask: torch.Tensor,
|
335 |
+
target_length: int,
|
336 |
+
batch_size: int,
|
337 |
+
out_dim: int = 3,
|
338 |
+
) -> torch.Tensor:
|
339 |
+
r"""
|
340 |
+
Prepare the attention mask for the attention computation.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
attention_mask (`torch.Tensor`):
|
344 |
+
The attention mask to prepare.
|
345 |
+
target_length (`int`):
|
346 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
347 |
+
batch_size (`int`):
|
348 |
+
The batch size, which is used to repeat the attention mask.
|
349 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
350 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
`torch.Tensor`: The prepared attention mask.
|
354 |
+
"""
|
355 |
+
head_size = self.heads
|
356 |
+
if attention_mask is None:
|
357 |
+
return attention_mask
|
358 |
+
|
359 |
+
current_length: int = attention_mask.shape[-1]
|
360 |
+
if current_length != target_length:
|
361 |
+
if attention_mask.device.type == "mps":
|
362 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
363 |
+
# Instead, we can manually construct the padding tensor.
|
364 |
+
padding_shape = (
|
365 |
+
attention_mask.shape[0],
|
366 |
+
attention_mask.shape[1],
|
367 |
+
target_length,
|
368 |
+
)
|
369 |
+
padding = torch.zeros(
|
370 |
+
padding_shape,
|
371 |
+
dtype=attention_mask.dtype,
|
372 |
+
device=attention_mask.device,
|
373 |
+
)
|
374 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
375 |
+
else:
|
376 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
377 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
378 |
+
# remaining_length: int = target_length - current_length
|
379 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
380 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
381 |
+
|
382 |
+
if out_dim == 3:
|
383 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
384 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
385 |
+
elif out_dim == 4:
|
386 |
+
attention_mask = attention_mask.unsqueeze(1)
|
387 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
388 |
+
|
389 |
+
return attention_mask
|
390 |
+
|
391 |
+
def norm_encoder_hidden_states(
|
392 |
+
self, encoder_hidden_states: torch.Tensor
|
393 |
+
) -> torch.Tensor:
|
394 |
+
r"""
|
395 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
396 |
+
`Attention` class.
|
397 |
+
|
398 |
+
Args:
|
399 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
403 |
+
"""
|
404 |
+
assert (
|
405 |
+
self.norm_cross is not None
|
406 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
407 |
+
|
408 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
409 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
410 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
411 |
+
# Group norm norms along the channels dimension and expects
|
412 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
413 |
+
# to norm along the hidden dimension, so we need to move
|
414 |
+
# (batch_size, sequence_length, hidden_size) ->
|
415 |
+
# (batch_size, hidden_size, sequence_length)
|
416 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
417 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
418 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
419 |
+
else:
|
420 |
+
assert False
|
421 |
+
|
422 |
+
return encoder_hidden_states
|
423 |
+
|
424 |
+
@torch.no_grad()
|
425 |
+
def fuse_projections(self, fuse=True):
|
426 |
+
is_cross_attention = self.cross_attention_dim != self.query_dim
|
427 |
+
device = self.to_q.weight.data.device
|
428 |
+
dtype = self.to_q.weight.data.dtype
|
429 |
+
|
430 |
+
if not is_cross_attention:
|
431 |
+
# fetch weight matrices.
|
432 |
+
concatenated_weights = torch.cat(
|
433 |
+
[self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
|
434 |
+
)
|
435 |
+
in_features = concatenated_weights.shape[1]
|
436 |
+
out_features = concatenated_weights.shape[0]
|
437 |
+
|
438 |
+
# create a new single projection layer and copy over the weights.
|
439 |
+
self.to_qkv = self.linear_cls(
|
440 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
441 |
+
)
|
442 |
+
self.to_qkv.weight.copy_(concatenated_weights)
|
443 |
+
|
444 |
+
else:
|
445 |
+
concatenated_weights = torch.cat(
|
446 |
+
[self.to_k.weight.data, self.to_v.weight.data]
|
447 |
+
)
|
448 |
+
in_features = concatenated_weights.shape[1]
|
449 |
+
out_features = concatenated_weights.shape[0]
|
450 |
+
|
451 |
+
self.to_kv = self.linear_cls(
|
452 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
453 |
+
)
|
454 |
+
self.to_kv.weight.copy_(concatenated_weights)
|
455 |
+
|
456 |
+
self.fused_projections = fuse
|
457 |
+
|
458 |
+
|
459 |
+
class AttnProcessor:
|
460 |
+
r"""
|
461 |
+
Default processor for performing attention-related computations.
|
462 |
+
"""
|
463 |
+
|
464 |
+
def __call__(
|
465 |
+
self,
|
466 |
+
attn: Attention,
|
467 |
+
hidden_states: torch.FloatTensor,
|
468 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
469 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
470 |
+
) -> torch.Tensor:
|
471 |
+
residual = hidden_states
|
472 |
+
|
473 |
+
input_ndim = hidden_states.ndim
|
474 |
+
|
475 |
+
if input_ndim == 4:
|
476 |
+
batch_size, channel, height, width = hidden_states.shape
|
477 |
+
hidden_states = hidden_states.view(
|
478 |
+
batch_size, channel, height * width
|
479 |
+
).transpose(1, 2)
|
480 |
+
|
481 |
+
batch_size, sequence_length, _ = (
|
482 |
+
hidden_states.shape
|
483 |
+
if encoder_hidden_states is None
|
484 |
+
else encoder_hidden_states.shape
|
485 |
+
)
|
486 |
+
attention_mask = attn.prepare_attention_mask(
|
487 |
+
attention_mask, sequence_length, batch_size
|
488 |
+
)
|
489 |
+
|
490 |
+
if attn.group_norm is not None:
|
491 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
492 |
+
1, 2
|
493 |
+
)
|
494 |
+
|
495 |
+
query = attn.to_q(hidden_states)
|
496 |
+
|
497 |
+
if encoder_hidden_states is None:
|
498 |
+
encoder_hidden_states = hidden_states
|
499 |
+
elif attn.norm_cross:
|
500 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
501 |
+
encoder_hidden_states
|
502 |
+
)
|
503 |
+
|
504 |
+
key = attn.to_k(encoder_hidden_states)
|
505 |
+
value = attn.to_v(encoder_hidden_states)
|
506 |
+
|
507 |
+
query = attn.head_to_batch_dim(query)
|
508 |
+
key = attn.head_to_batch_dim(key)
|
509 |
+
value = attn.head_to_batch_dim(value)
|
510 |
+
|
511 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
512 |
+
hidden_states = torch.bmm(attention_probs, value)
|
513 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
514 |
+
|
515 |
+
# linear proj
|
516 |
+
hidden_states = attn.to_out[0](hidden_states)
|
517 |
+
# dropout
|
518 |
+
hidden_states = attn.to_out[1](hidden_states)
|
519 |
+
|
520 |
+
if input_ndim == 4:
|
521 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
522 |
+
batch_size, channel, height, width
|
523 |
+
)
|
524 |
+
|
525 |
+
if attn.residual_connection:
|
526 |
+
hidden_states = hidden_states + residual
|
527 |
+
|
528 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
529 |
+
|
530 |
+
return hidden_states
|
531 |
+
|
532 |
+
|
533 |
+
class AttnProcessor2_0:
|
534 |
+
r"""
|
535 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
536 |
+
"""
|
537 |
+
|
538 |
+
def __init__(self):
|
539 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
540 |
+
raise ImportError(
|
541 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
542 |
+
)
|
543 |
+
|
544 |
+
def __call__(
|
545 |
+
self,
|
546 |
+
attn: Attention,
|
547 |
+
hidden_states: torch.FloatTensor,
|
548 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
549 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
550 |
+
) -> torch.FloatTensor:
|
551 |
+
residual = hidden_states
|
552 |
+
|
553 |
+
input_ndim = hidden_states.ndim
|
554 |
+
|
555 |
+
if input_ndim == 4:
|
556 |
+
batch_size, channel, height, width = hidden_states.shape
|
557 |
+
hidden_states = hidden_states.view(
|
558 |
+
batch_size, channel, height * width
|
559 |
+
).transpose(1, 2)
|
560 |
+
|
561 |
+
batch_size, sequence_length, _ = (
|
562 |
+
hidden_states.shape
|
563 |
+
if encoder_hidden_states is None
|
564 |
+
else encoder_hidden_states.shape
|
565 |
+
)
|
566 |
+
|
567 |
+
if attention_mask is not None:
|
568 |
+
attention_mask = attn.prepare_attention_mask(
|
569 |
+
attention_mask, sequence_length, batch_size
|
570 |
+
)
|
571 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
572 |
+
# (batch, heads, source_length, target_length)
|
573 |
+
attention_mask = attention_mask.view(
|
574 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
575 |
+
)
|
576 |
+
|
577 |
+
if attn.group_norm is not None:
|
578 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
579 |
+
1, 2
|
580 |
+
)
|
581 |
+
|
582 |
+
query = attn.to_q(hidden_states)
|
583 |
+
|
584 |
+
if encoder_hidden_states is None:
|
585 |
+
encoder_hidden_states = hidden_states
|
586 |
+
elif attn.norm_cross:
|
587 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
588 |
+
encoder_hidden_states
|
589 |
+
)
|
590 |
+
|
591 |
+
key = attn.to_k(encoder_hidden_states)
|
592 |
+
value = attn.to_v(encoder_hidden_states)
|
593 |
+
|
594 |
+
inner_dim = key.shape[-1]
|
595 |
+
head_dim = inner_dim // attn.heads
|
596 |
+
|
597 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
598 |
+
|
599 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
600 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
601 |
+
|
602 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
603 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
604 |
+
hidden_states = F.scaled_dot_product_attention(
|
605 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
606 |
+
)
|
607 |
+
|
608 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
609 |
+
batch_size, -1, attn.heads * head_dim
|
610 |
+
)
|
611 |
+
hidden_states = hidden_states.to(query.dtype)
|
612 |
+
|
613 |
+
# linear proj
|
614 |
+
hidden_states = attn.to_out[0](hidden_states)
|
615 |
+
# dropout
|
616 |
+
hidden_states = attn.to_out[1](hidden_states)
|
617 |
+
|
618 |
+
if input_ndim == 4:
|
619 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
620 |
+
batch_size, channel, height, width
|
621 |
+
)
|
622 |
+
|
623 |
+
if attn.residual_connection:
|
624 |
+
hidden_states = hidden_states + residual
|
625 |
+
|
626 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
627 |
+
|
628 |
+
return hidden_states
|
tsr/models/transformer/basic_transformer_block.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from .attention import Attention
|
22 |
+
|
23 |
+
|
24 |
+
class BasicTransformerBlock(nn.Module):
|
25 |
+
r"""
|
26 |
+
A basic Transformer block.
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
dim (`int`): The number of channels in the input and output.
|
30 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
31 |
+
attention_head_dim (`int`): The number of channels in each head.
|
32 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
33 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
34 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
35 |
+
num_embeds_ada_norm (:
|
36 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
37 |
+
attention_bias (:
|
38 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
39 |
+
only_cross_attention (`bool`, *optional*):
|
40 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
41 |
+
double_self_attention (`bool`, *optional*):
|
42 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
43 |
+
upcast_attention (`bool`, *optional*):
|
44 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
45 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
46 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
47 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
48 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
49 |
+
final_dropout (`bool` *optional*, defaults to False):
|
50 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
51 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
52 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
dim: int,
|
58 |
+
num_attention_heads: int,
|
59 |
+
attention_head_dim: int,
|
60 |
+
dropout=0.0,
|
61 |
+
cross_attention_dim: Optional[int] = None,
|
62 |
+
activation_fn: str = "geglu",
|
63 |
+
attention_bias: bool = False,
|
64 |
+
only_cross_attention: bool = False,
|
65 |
+
double_self_attention: bool = False,
|
66 |
+
upcast_attention: bool = False,
|
67 |
+
norm_elementwise_affine: bool = True,
|
68 |
+
norm_type: str = "layer_norm",
|
69 |
+
final_dropout: bool = False,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
self.only_cross_attention = only_cross_attention
|
73 |
+
|
74 |
+
assert norm_type == "layer_norm"
|
75 |
+
|
76 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
77 |
+
# 1. Self-Attn
|
78 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
79 |
+
self.attn1 = Attention(
|
80 |
+
query_dim=dim,
|
81 |
+
heads=num_attention_heads,
|
82 |
+
dim_head=attention_head_dim,
|
83 |
+
dropout=dropout,
|
84 |
+
bias=attention_bias,
|
85 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
86 |
+
upcast_attention=upcast_attention,
|
87 |
+
)
|
88 |
+
|
89 |
+
# 2. Cross-Attn
|
90 |
+
if cross_attention_dim is not None or double_self_attention:
|
91 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
92 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
93 |
+
# the second cross attention block.
|
94 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
95 |
+
|
96 |
+
self.attn2 = Attention(
|
97 |
+
query_dim=dim,
|
98 |
+
cross_attention_dim=cross_attention_dim
|
99 |
+
if not double_self_attention
|
100 |
+
else None,
|
101 |
+
heads=num_attention_heads,
|
102 |
+
dim_head=attention_head_dim,
|
103 |
+
dropout=dropout,
|
104 |
+
bias=attention_bias,
|
105 |
+
upcast_attention=upcast_attention,
|
106 |
+
) # is self-attn if encoder_hidden_states is none
|
107 |
+
else:
|
108 |
+
self.norm2 = None
|
109 |
+
self.attn2 = None
|
110 |
+
|
111 |
+
# 3. Feed-forward
|
112 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
113 |
+
self.ff = FeedForward(
|
114 |
+
dim,
|
115 |
+
dropout=dropout,
|
116 |
+
activation_fn=activation_fn,
|
117 |
+
final_dropout=final_dropout,
|
118 |
+
)
|
119 |
+
|
120 |
+
# let chunk size default to None
|
121 |
+
self._chunk_size = None
|
122 |
+
self._chunk_dim = 0
|
123 |
+
|
124 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
125 |
+
# Sets chunk feed-forward
|
126 |
+
self._chunk_size = chunk_size
|
127 |
+
self._chunk_dim = dim
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
hidden_states: torch.FloatTensor,
|
132 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
133 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
134 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
135 |
+
) -> torch.FloatTensor:
|
136 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
137 |
+
# 0. Self-Attention
|
138 |
+
norm_hidden_states = self.norm1(hidden_states)
|
139 |
+
|
140 |
+
attn_output = self.attn1(
|
141 |
+
norm_hidden_states,
|
142 |
+
encoder_hidden_states=encoder_hidden_states
|
143 |
+
if self.only_cross_attention
|
144 |
+
else None,
|
145 |
+
attention_mask=attention_mask,
|
146 |
+
)
|
147 |
+
|
148 |
+
hidden_states = attn_output + hidden_states
|
149 |
+
|
150 |
+
# 3. Cross-Attention
|
151 |
+
if self.attn2 is not None:
|
152 |
+
norm_hidden_states = self.norm2(hidden_states)
|
153 |
+
|
154 |
+
attn_output = self.attn2(
|
155 |
+
norm_hidden_states,
|
156 |
+
encoder_hidden_states=encoder_hidden_states,
|
157 |
+
attention_mask=encoder_attention_mask,
|
158 |
+
)
|
159 |
+
hidden_states = attn_output + hidden_states
|
160 |
+
|
161 |
+
# 4. Feed-forward
|
162 |
+
norm_hidden_states = self.norm3(hidden_states)
|
163 |
+
|
164 |
+
if self._chunk_size is not None:
|
165 |
+
# "feed_forward_chunk_size" can be used to save memory
|
166 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
167 |
+
raise ValueError(
|
168 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
169 |
+
)
|
170 |
+
|
171 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
172 |
+
ff_output = torch.cat(
|
173 |
+
[
|
174 |
+
self.ff(hid_slice)
|
175 |
+
for hid_slice in norm_hidden_states.chunk(
|
176 |
+
num_chunks, dim=self._chunk_dim
|
177 |
+
)
|
178 |
+
],
|
179 |
+
dim=self._chunk_dim,
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
ff_output = self.ff(norm_hidden_states)
|
183 |
+
|
184 |
+
hidden_states = ff_output + hidden_states
|
185 |
+
|
186 |
+
return hidden_states
|
187 |
+
|
188 |
+
|
189 |
+
class FeedForward(nn.Module):
|
190 |
+
r"""
|
191 |
+
A feed-forward layer.
|
192 |
+
|
193 |
+
Parameters:
|
194 |
+
dim (`int`): The number of channels in the input.
|
195 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
196 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
197 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
198 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
199 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
200 |
+
"""
|
201 |
+
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
dim: int,
|
205 |
+
dim_out: Optional[int] = None,
|
206 |
+
mult: int = 4,
|
207 |
+
dropout: float = 0.0,
|
208 |
+
activation_fn: str = "geglu",
|
209 |
+
final_dropout: bool = False,
|
210 |
+
):
|
211 |
+
super().__init__()
|
212 |
+
inner_dim = int(dim * mult)
|
213 |
+
dim_out = dim_out if dim_out is not None else dim
|
214 |
+
linear_cls = nn.Linear
|
215 |
+
|
216 |
+
if activation_fn == "gelu":
|
217 |
+
act_fn = GELU(dim, inner_dim)
|
218 |
+
if activation_fn == "gelu-approximate":
|
219 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
220 |
+
elif activation_fn == "geglu":
|
221 |
+
act_fn = GEGLU(dim, inner_dim)
|
222 |
+
elif activation_fn == "geglu-approximate":
|
223 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
224 |
+
|
225 |
+
self.net = nn.ModuleList([])
|
226 |
+
# project in
|
227 |
+
self.net.append(act_fn)
|
228 |
+
# project dropout
|
229 |
+
self.net.append(nn.Dropout(dropout))
|
230 |
+
# project out
|
231 |
+
self.net.append(linear_cls(inner_dim, dim_out))
|
232 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
233 |
+
if final_dropout:
|
234 |
+
self.net.append(nn.Dropout(dropout))
|
235 |
+
|
236 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
237 |
+
for module in self.net:
|
238 |
+
hidden_states = module(hidden_states)
|
239 |
+
return hidden_states
|
240 |
+
|
241 |
+
|
242 |
+
class GELU(nn.Module):
|
243 |
+
r"""
|
244 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
dim_in (`int`): The number of channels in the input.
|
248 |
+
dim_out (`int`): The number of channels in the output.
|
249 |
+
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
250 |
+
"""
|
251 |
+
|
252 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
253 |
+
super().__init__()
|
254 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
255 |
+
self.approximate = approximate
|
256 |
+
|
257 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
258 |
+
if gate.device.type != "mps":
|
259 |
+
return F.gelu(gate, approximate=self.approximate)
|
260 |
+
# mps: gelu is not implemented for float16
|
261 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
|
262 |
+
dtype=gate.dtype
|
263 |
+
)
|
264 |
+
|
265 |
+
def forward(self, hidden_states):
|
266 |
+
hidden_states = self.proj(hidden_states)
|
267 |
+
hidden_states = self.gelu(hidden_states)
|
268 |
+
return hidden_states
|
269 |
+
|
270 |
+
|
271 |
+
class GEGLU(nn.Module):
|
272 |
+
r"""
|
273 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
274 |
+
|
275 |
+
Parameters:
|
276 |
+
dim_in (`int`): The number of channels in the input.
|
277 |
+
dim_out (`int`): The number of channels in the output.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(self, dim_in: int, dim_out: int):
|
281 |
+
super().__init__()
|
282 |
+
linear_cls = nn.Linear
|
283 |
+
|
284 |
+
self.proj = linear_cls(dim_in, dim_out * 2)
|
285 |
+
|
286 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
287 |
+
if gate.device.type != "mps":
|
288 |
+
return F.gelu(gate)
|
289 |
+
# mps: gelu is not implemented for float16
|
290 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
291 |
+
|
292 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
293 |
+
args = ()
|
294 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
295 |
+
return hidden_states * self.gelu(gate)
|
296 |
+
|
297 |
+
|
298 |
+
class ApproximateGELU(nn.Module):
|
299 |
+
r"""
|
300 |
+
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
|
301 |
+
https://arxiv.org/abs/1606.08415.
|
302 |
+
|
303 |
+
Parameters:
|
304 |
+
dim_in (`int`): The number of channels in the input.
|
305 |
+
dim_out (`int`): The number of channels in the output.
|
306 |
+
"""
|
307 |
+
|
308 |
+
def __init__(self, dim_in: int, dim_out: int):
|
309 |
+
super().__init__()
|
310 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
311 |
+
|
312 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
313 |
+
x = self.proj(x)
|
314 |
+
return x * torch.sigmoid(1.702 * x)
|
tsr/models/transformer/transformer_1d.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from ...utils import BaseModule
|
9 |
+
from .basic_transformer_block import BasicTransformerBlock
|
10 |
+
|
11 |
+
|
12 |
+
class Transformer1D(BaseModule):
|
13 |
+
"""
|
14 |
+
A 1D Transformer model for sequence data.
|
15 |
+
|
16 |
+
Parameters:
|
17 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
18 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
19 |
+
in_channels (`int`, *optional*):
|
20 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
21 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
22 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
23 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
24 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
25 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
26 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
27 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
28 |
+
added to the hidden states.
|
29 |
+
|
30 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
31 |
+
attention_bias (`bool`, *optional*):
|
32 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
33 |
+
"""
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class Config(BaseModule.Config):
|
37 |
+
num_attention_heads: int = 16
|
38 |
+
attention_head_dim: int = 88
|
39 |
+
in_channels: Optional[int] = None
|
40 |
+
out_channels: Optional[int] = None
|
41 |
+
num_layers: int = 1
|
42 |
+
dropout: float = 0.0
|
43 |
+
norm_num_groups: int = 32
|
44 |
+
cross_attention_dim: Optional[int] = None
|
45 |
+
attention_bias: bool = False
|
46 |
+
activation_fn: str = "geglu"
|
47 |
+
only_cross_attention: bool = False
|
48 |
+
double_self_attention: bool = False
|
49 |
+
upcast_attention: bool = False
|
50 |
+
norm_type: str = "layer_norm"
|
51 |
+
norm_elementwise_affine: bool = True
|
52 |
+
gradient_checkpointing: bool = False
|
53 |
+
|
54 |
+
cfg: Config
|
55 |
+
|
56 |
+
def configure(self) -> None:
|
57 |
+
self.num_attention_heads = self.cfg.num_attention_heads
|
58 |
+
self.attention_head_dim = self.cfg.attention_head_dim
|
59 |
+
inner_dim = self.num_attention_heads * self.attention_head_dim
|
60 |
+
|
61 |
+
linear_cls = nn.Linear
|
62 |
+
|
63 |
+
# 2. Define input layers
|
64 |
+
self.in_channels = self.cfg.in_channels
|
65 |
+
|
66 |
+
self.norm = torch.nn.GroupNorm(
|
67 |
+
num_groups=self.cfg.norm_num_groups,
|
68 |
+
num_channels=self.cfg.in_channels,
|
69 |
+
eps=1e-6,
|
70 |
+
affine=True,
|
71 |
+
)
|
72 |
+
self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
|
73 |
+
|
74 |
+
# 3. Define transformers blocks
|
75 |
+
self.transformer_blocks = nn.ModuleList(
|
76 |
+
[
|
77 |
+
BasicTransformerBlock(
|
78 |
+
inner_dim,
|
79 |
+
self.num_attention_heads,
|
80 |
+
self.attention_head_dim,
|
81 |
+
dropout=self.cfg.dropout,
|
82 |
+
cross_attention_dim=self.cfg.cross_attention_dim,
|
83 |
+
activation_fn=self.cfg.activation_fn,
|
84 |
+
attention_bias=self.cfg.attention_bias,
|
85 |
+
only_cross_attention=self.cfg.only_cross_attention,
|
86 |
+
double_self_attention=self.cfg.double_self_attention,
|
87 |
+
upcast_attention=self.cfg.upcast_attention,
|
88 |
+
norm_type=self.cfg.norm_type,
|
89 |
+
norm_elementwise_affine=self.cfg.norm_elementwise_affine,
|
90 |
+
)
|
91 |
+
for d in range(self.cfg.num_layers)
|
92 |
+
]
|
93 |
+
)
|
94 |
+
|
95 |
+
# 4. Define output layers
|
96 |
+
self.out_channels = (
|
97 |
+
self.cfg.in_channels
|
98 |
+
if self.cfg.out_channels is None
|
99 |
+
else self.cfg.out_channels
|
100 |
+
)
|
101 |
+
|
102 |
+
self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
|
103 |
+
|
104 |
+
self.gradient_checkpointing = self.cfg.gradient_checkpointing
|
105 |
+
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
hidden_states: torch.Tensor,
|
109 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
110 |
+
attention_mask: Optional[torch.Tensor] = None,
|
111 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
112 |
+
):
|
113 |
+
"""
|
114 |
+
The [`Transformer1DModel`] forward method.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
118 |
+
Input `hidden_states`.
|
119 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
120 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
121 |
+
self-attention.
|
122 |
+
timestep ( `torch.LongTensor`, *optional*):
|
123 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
124 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
125 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
126 |
+
`AdaLayerZeroNorm`.
|
127 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
128 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
129 |
+
`self.processor` in
|
130 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
131 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
132 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
133 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
134 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
135 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
136 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
137 |
+
|
138 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
139 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
140 |
+
|
141 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
142 |
+
above. This bias will be added to the cross-attention scores.
|
143 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
144 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
145 |
+
tuple.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
149 |
+
`tuple` where the first element is the sample tensor.
|
150 |
+
"""
|
151 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
152 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
153 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
154 |
+
# expects mask of shape:
|
155 |
+
# [batch, key_tokens]
|
156 |
+
# adds singleton query_tokens dimension:
|
157 |
+
# [batch, 1, key_tokens]
|
158 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
159 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
160 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
161 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
162 |
+
# assume that mask is expressed as:
|
163 |
+
# (1 = keep, 0 = discard)
|
164 |
+
# convert mask into a bias that can be added to attention scores:
|
165 |
+
# (keep = +0, discard = -10000.0)
|
166 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
167 |
+
attention_mask = attention_mask.unsqueeze(1)
|
168 |
+
|
169 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
170 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
171 |
+
encoder_attention_mask = (
|
172 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
173 |
+
) * -10000.0
|
174 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
175 |
+
|
176 |
+
# 1. Input
|
177 |
+
batch, _, seq_len = hidden_states.shape
|
178 |
+
residual = hidden_states
|
179 |
+
|
180 |
+
hidden_states = self.norm(hidden_states)
|
181 |
+
inner_dim = hidden_states.shape[1]
|
182 |
+
hidden_states = hidden_states.permute(0, 2, 1).reshape(
|
183 |
+
batch, seq_len, inner_dim
|
184 |
+
)
|
185 |
+
hidden_states = self.proj_in(hidden_states)
|
186 |
+
|
187 |
+
# 2. Blocks
|
188 |
+
for block in self.transformer_blocks:
|
189 |
+
if self.training and self.gradient_checkpointing:
|
190 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
191 |
+
block,
|
192 |
+
hidden_states,
|
193 |
+
attention_mask,
|
194 |
+
encoder_hidden_states,
|
195 |
+
encoder_attention_mask,
|
196 |
+
use_reentrant=False,
|
197 |
+
)
|
198 |
+
else:
|
199 |
+
hidden_states = block(
|
200 |
+
hidden_states,
|
201 |
+
attention_mask=attention_mask,
|
202 |
+
encoder_hidden_states=encoder_hidden_states,
|
203 |
+
encoder_attention_mask=encoder_attention_mask,
|
204 |
+
)
|
205 |
+
|
206 |
+
# 3. Output
|
207 |
+
hidden_states = self.proj_out(hidden_states)
|
208 |
+
hidden_states = (
|
209 |
+
hidden_states.reshape(batch, seq_len, inner_dim)
|
210 |
+
.permute(0, 2, 1)
|
211 |
+
.contiguous()
|
212 |
+
)
|
213 |
+
|
214 |
+
output = hidden_states + residual
|
215 |
+
|
216 |
+
return output
|
tsr/system.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from typing import List, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL.Image
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import trimesh
|
11 |
+
from einops import rearrange
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from .models.isosurface import MarchingCubeHelper
|
17 |
+
from .utils import (
|
18 |
+
BaseModule,
|
19 |
+
ImagePreprocessor,
|
20 |
+
find_class,
|
21 |
+
get_spherical_cameras,
|
22 |
+
scale_tensor,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class TSR(BaseModule):
|
27 |
+
@dataclass
|
28 |
+
class Config(BaseModule.Config):
|
29 |
+
cond_image_size: int
|
30 |
+
|
31 |
+
image_tokenizer_cls: str
|
32 |
+
image_tokenizer: dict
|
33 |
+
|
34 |
+
tokenizer_cls: str
|
35 |
+
tokenizer: dict
|
36 |
+
|
37 |
+
backbone_cls: str
|
38 |
+
backbone: dict
|
39 |
+
|
40 |
+
post_processor_cls: str
|
41 |
+
post_processor: dict
|
42 |
+
|
43 |
+
decoder_cls: str
|
44 |
+
decoder: dict
|
45 |
+
|
46 |
+
renderer_cls: str
|
47 |
+
renderer: dict
|
48 |
+
|
49 |
+
cfg: Config
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def from_pretrained(
|
53 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
54 |
+
):
|
55 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
56 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
57 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
58 |
+
else:
|
59 |
+
config_path = hf_hub_download(
|
60 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
61 |
+
)
|
62 |
+
weight_path = hf_hub_download(
|
63 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
64 |
+
)
|
65 |
+
|
66 |
+
cfg = OmegaConf.load(config_path)
|
67 |
+
OmegaConf.resolve(cfg)
|
68 |
+
model = cls(cfg)
|
69 |
+
ckpt = torch.load(weight_path, map_location="cpu")
|
70 |
+
model.load_state_dict(ckpt)
|
71 |
+
return model
|
72 |
+
|
73 |
+
def configure(self):
|
74 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
75 |
+
self.cfg.image_tokenizer
|
76 |
+
)
|
77 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
78 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
79 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
80 |
+
self.cfg.post_processor
|
81 |
+
)
|
82 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
83 |
+
self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
|
84 |
+
self.image_processor = ImagePreprocessor()
|
85 |
+
self.isosurface_helper = None
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
image: Union[
|
90 |
+
PIL.Image.Image,
|
91 |
+
np.ndarray,
|
92 |
+
torch.FloatTensor,
|
93 |
+
List[PIL.Image.Image],
|
94 |
+
List[np.ndarray],
|
95 |
+
List[torch.FloatTensor],
|
96 |
+
],
|
97 |
+
device: str,
|
98 |
+
) -> torch.FloatTensor:
|
99 |
+
rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
|
100 |
+
device
|
101 |
+
)
|
102 |
+
batch_size = rgb_cond.shape[0]
|
103 |
+
|
104 |
+
input_image_tokens: torch.Tensor = self.image_tokenizer(
|
105 |
+
rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
|
106 |
+
)
|
107 |
+
|
108 |
+
input_image_tokens = rearrange(
|
109 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
|
110 |
+
)
|
111 |
+
|
112 |
+
tokens: torch.Tensor = self.tokenizer(batch_size)
|
113 |
+
|
114 |
+
tokens = self.backbone(
|
115 |
+
tokens,
|
116 |
+
encoder_hidden_states=input_image_tokens,
|
117 |
+
)
|
118 |
+
|
119 |
+
scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
|
120 |
+
return scene_codes
|
121 |
+
|
122 |
+
def render(
|
123 |
+
self,
|
124 |
+
scene_codes,
|
125 |
+
n_views: int,
|
126 |
+
elevation_deg: float = 0.0,
|
127 |
+
camera_distance: float = 1.9,
|
128 |
+
fovy_deg: float = 40.0,
|
129 |
+
height: int = 256,
|
130 |
+
width: int = 256,
|
131 |
+
return_type: str = "pil",
|
132 |
+
):
|
133 |
+
rays_o, rays_d = get_spherical_cameras(
|
134 |
+
n_views, elevation_deg, camera_distance, fovy_deg, height, width
|
135 |
+
)
|
136 |
+
rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
|
137 |
+
|
138 |
+
def process_output(image: torch.FloatTensor):
|
139 |
+
if return_type == "pt":
|
140 |
+
return image
|
141 |
+
elif return_type == "np":
|
142 |
+
return image.detach().cpu().numpy()
|
143 |
+
elif return_type == "pil":
|
144 |
+
return Image.fromarray(
|
145 |
+
(image.detach().cpu().numpy() * 255.0).astype(np.uint8)
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
raise NotImplementedError
|
149 |
+
|
150 |
+
images = []
|
151 |
+
for scene_code in scene_codes:
|
152 |
+
images_ = []
|
153 |
+
for i in range(n_views):
|
154 |
+
with torch.no_grad():
|
155 |
+
image = self.renderer(
|
156 |
+
self.decoder, scene_code, rays_o[i], rays_d[i]
|
157 |
+
)
|
158 |
+
images_.append(process_output(image))
|
159 |
+
images.append(images_)
|
160 |
+
|
161 |
+
return images
|
162 |
+
|
163 |
+
def set_marching_cubes_resolution(self, resolution: int):
|
164 |
+
if (
|
165 |
+
self.isosurface_helper is not None
|
166 |
+
and self.isosurface_helper.resolution == resolution
|
167 |
+
):
|
168 |
+
return
|
169 |
+
self.isosurface_helper = MarchingCubeHelper(resolution)
|
170 |
+
|
171 |
+
def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 20.0):
|
172 |
+
self.set_marching_cubes_resolution(resolution)
|
173 |
+
meshes = []
|
174 |
+
for scene_code in scene_codes:
|
175 |
+
with torch.no_grad():
|
176 |
+
density = self.renderer.query_triplane(
|
177 |
+
self.decoder,
|
178 |
+
scale_tensor(
|
179 |
+
self.isosurface_helper.grid_vertices.to(scene_codes.device),
|
180 |
+
self.isosurface_helper.points_range,
|
181 |
+
(-self.renderer.cfg.radius, self.renderer.cfg.radius),
|
182 |
+
),
|
183 |
+
scene_code,
|
184 |
+
)["density_act"]
|
185 |
+
v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
|
186 |
+
v_pos = scale_tensor(
|
187 |
+
v_pos,
|
188 |
+
self.isosurface_helper.points_range,
|
189 |
+
(-self.renderer.cfg.radius, self.renderer.cfg.radius),
|
190 |
+
)
|
191 |
+
with torch.no_grad():
|
192 |
+
color = self.renderer.query_triplane(
|
193 |
+
self.decoder,
|
194 |
+
v_pos,
|
195 |
+
scene_code,
|
196 |
+
)["color"]
|
197 |
+
mesh = trimesh.Trimesh(
|
198 |
+
vertices=v_pos.cpu().numpy(),
|
199 |
+
faces=t_pos_idx.cpu().numpy(),
|
200 |
+
vertex_colors=color.cpu().numpy(),
|
201 |
+
)
|
202 |
+
meshes.append(mesh)
|
203 |
+
return meshes
|
tsr/utils.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import math
|
3 |
+
from collections import defaultdict
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import imageio
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import rembg
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from omegaconf import DictConfig, OmegaConf
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
|
18 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
19 |
+
scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
|
20 |
+
return scfg
|
21 |
+
|
22 |
+
|
23 |
+
def find_class(cls_string):
|
24 |
+
module_string = ".".join(cls_string.split(".")[:-1])
|
25 |
+
cls_name = cls_string.split(".")[-1]
|
26 |
+
module = importlib.import_module(module_string, package=None)
|
27 |
+
cls = getattr(module, cls_name)
|
28 |
+
return cls
|
29 |
+
|
30 |
+
|
31 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
|
32 |
+
focal_length = 0.5 * H / np.tan(0.5 * fov)
|
33 |
+
intrinsic = np.identity(3, dtype=np.float32)
|
34 |
+
intrinsic[0, 0] = focal_length
|
35 |
+
intrinsic[1, 1] = focal_length
|
36 |
+
intrinsic[0, 2] = W / 2.0
|
37 |
+
intrinsic[1, 2] = H / 2.0
|
38 |
+
|
39 |
+
if bs > 0:
|
40 |
+
intrinsic = intrinsic[None].repeat(bs, axis=0)
|
41 |
+
|
42 |
+
return torch.from_numpy(intrinsic)
|
43 |
+
|
44 |
+
|
45 |
+
class BaseModule(nn.Module):
|
46 |
+
@dataclass
|
47 |
+
class Config:
|
48 |
+
pass
|
49 |
+
|
50 |
+
cfg: Config # add this to every subclass of BaseModule to enable static type checking
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
|
54 |
+
) -> None:
|
55 |
+
super().__init__()
|
56 |
+
self.cfg = parse_structured(self.Config, cfg)
|
57 |
+
self.configure(*args, **kwargs)
|
58 |
+
|
59 |
+
def configure(self, *args, **kwargs) -> None:
|
60 |
+
raise NotImplementedError
|
61 |
+
|
62 |
+
|
63 |
+
class ImagePreprocessor:
|
64 |
+
def convert_and_resize(
|
65 |
+
self,
|
66 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
67 |
+
size: int,
|
68 |
+
):
|
69 |
+
if isinstance(image, PIL.Image.Image):
|
70 |
+
image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
71 |
+
elif isinstance(image, np.ndarray):
|
72 |
+
if image.dtype == np.uint8:
|
73 |
+
image = torch.from_numpy(image.astype(np.float32) / 255.0)
|
74 |
+
else:
|
75 |
+
image = torch.from_numpy(image)
|
76 |
+
elif isinstance(image, torch.Tensor):
|
77 |
+
pass
|
78 |
+
|
79 |
+
batched = image.ndim == 4
|
80 |
+
|
81 |
+
if not batched:
|
82 |
+
image = image[None, ...]
|
83 |
+
image = F.interpolate(
|
84 |
+
image.permute(0, 3, 1, 2),
|
85 |
+
(size, size),
|
86 |
+
mode="bilinear",
|
87 |
+
align_corners=False,
|
88 |
+
antialias=True,
|
89 |
+
).permute(0, 2, 3, 1)
|
90 |
+
if not batched:
|
91 |
+
image = image[0]
|
92 |
+
return image
|
93 |
+
|
94 |
+
def __call__(
|
95 |
+
self,
|
96 |
+
image: Union[
|
97 |
+
PIL.Image.Image,
|
98 |
+
np.ndarray,
|
99 |
+
torch.FloatTensor,
|
100 |
+
List[PIL.Image.Image],
|
101 |
+
List[np.ndarray],
|
102 |
+
List[torch.FloatTensor],
|
103 |
+
],
|
104 |
+
size: int,
|
105 |
+
) -> Any:
|
106 |
+
if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
|
107 |
+
image = self.convert_and_resize(image, size)
|
108 |
+
else:
|
109 |
+
if not isinstance(image, list):
|
110 |
+
image = [image]
|
111 |
+
image = [self.convert_and_resize(im, size) for im in image]
|
112 |
+
image = torch.stack(image, dim=0)
|
113 |
+
return image
|
114 |
+
|
115 |
+
|
116 |
+
def rays_intersect_bbox(
|
117 |
+
rays_o: torch.Tensor,
|
118 |
+
rays_d: torch.Tensor,
|
119 |
+
radius: float,
|
120 |
+
near: float = 0.0,
|
121 |
+
valid_thresh: float = 0.01,
|
122 |
+
):
|
123 |
+
input_shape = rays_o.shape[:-1]
|
124 |
+
rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
|
125 |
+
rays_d_valid = torch.where(
|
126 |
+
rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
|
127 |
+
)
|
128 |
+
if type(radius) in [int, float]:
|
129 |
+
radius = torch.FloatTensor(
|
130 |
+
[[-radius, radius], [-radius, radius], [-radius, radius]]
|
131 |
+
).to(rays_o.device)
|
132 |
+
radius = (
|
133 |
+
1.0 - 1.0e-3
|
134 |
+
) * radius # tighten the radius to make sure the intersection point lies in the bounding box
|
135 |
+
interx0 = (radius[..., 1] - rays_o) / rays_d_valid
|
136 |
+
interx1 = (radius[..., 0] - rays_o) / rays_d_valid
|
137 |
+
t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
|
138 |
+
t_far = torch.maximum(interx0, interx1).amin(dim=-1)
|
139 |
+
|
140 |
+
# check wheter a ray intersects the bbox or not
|
141 |
+
rays_valid = t_far - t_near > valid_thresh
|
142 |
+
|
143 |
+
t_near[torch.where(~rays_valid)] = 0.0
|
144 |
+
t_far[torch.where(~rays_valid)] = 0.0
|
145 |
+
|
146 |
+
t_near = t_near.view(*input_shape, 1)
|
147 |
+
t_far = t_far.view(*input_shape, 1)
|
148 |
+
rays_valid = rays_valid.view(*input_shape)
|
149 |
+
|
150 |
+
return t_near, t_far, rays_valid
|
151 |
+
|
152 |
+
|
153 |
+
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
|
154 |
+
if chunk_size <= 0:
|
155 |
+
return func(*args, **kwargs)
|
156 |
+
B = None
|
157 |
+
for arg in list(args) + list(kwargs.values()):
|
158 |
+
if isinstance(arg, torch.Tensor):
|
159 |
+
B = arg.shape[0]
|
160 |
+
break
|
161 |
+
assert (
|
162 |
+
B is not None
|
163 |
+
), "No tensor found in args or kwargs, cannot determine batch size."
|
164 |
+
out = defaultdict(list)
|
165 |
+
out_type = None
|
166 |
+
# max(1, B) to support B == 0
|
167 |
+
for i in range(0, max(1, B), chunk_size):
|
168 |
+
out_chunk = func(
|
169 |
+
*[
|
170 |
+
arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
171 |
+
for arg in args
|
172 |
+
],
|
173 |
+
**{
|
174 |
+
k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
175 |
+
for k, arg in kwargs.items()
|
176 |
+
},
|
177 |
+
)
|
178 |
+
if out_chunk is None:
|
179 |
+
continue
|
180 |
+
out_type = type(out_chunk)
|
181 |
+
if isinstance(out_chunk, torch.Tensor):
|
182 |
+
out_chunk = {0: out_chunk}
|
183 |
+
elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
|
184 |
+
chunk_length = len(out_chunk)
|
185 |
+
out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
|
186 |
+
elif isinstance(out_chunk, dict):
|
187 |
+
pass
|
188 |
+
else:
|
189 |
+
print(
|
190 |
+
f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
|
191 |
+
)
|
192 |
+
exit(1)
|
193 |
+
for k, v in out_chunk.items():
|
194 |
+
v = v if torch.is_grad_enabled() else v.detach()
|
195 |
+
out[k].append(v)
|
196 |
+
|
197 |
+
if out_type is None:
|
198 |
+
return None
|
199 |
+
|
200 |
+
out_merged: Dict[Any, Optional[torch.Tensor]] = {}
|
201 |
+
for k, v in out.items():
|
202 |
+
if all([vv is None for vv in v]):
|
203 |
+
# allow None in return value
|
204 |
+
out_merged[k] = None
|
205 |
+
elif all([isinstance(vv, torch.Tensor) for vv in v]):
|
206 |
+
out_merged[k] = torch.cat(v, dim=0)
|
207 |
+
else:
|
208 |
+
raise TypeError(
|
209 |
+
f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
|
210 |
+
)
|
211 |
+
|
212 |
+
if out_type is torch.Tensor:
|
213 |
+
return out_merged[0]
|
214 |
+
elif out_type in [tuple, list]:
|
215 |
+
return out_type([out_merged[i] for i in range(chunk_length)])
|
216 |
+
elif out_type is dict:
|
217 |
+
return out_merged
|
218 |
+
|
219 |
+
|
220 |
+
ValidScale = Union[Tuple[float, float], torch.FloatTensor]
|
221 |
+
|
222 |
+
|
223 |
+
def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
|
224 |
+
if inp_scale is None:
|
225 |
+
inp_scale = (0, 1)
|
226 |
+
if tgt_scale is None:
|
227 |
+
tgt_scale = (0, 1)
|
228 |
+
if isinstance(tgt_scale, torch.FloatTensor):
|
229 |
+
assert dat.shape[-1] == tgt_scale.shape[-1]
|
230 |
+
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
231 |
+
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
232 |
+
return dat
|
233 |
+
|
234 |
+
|
235 |
+
def get_activation(name) -> Callable:
|
236 |
+
if name is None:
|
237 |
+
return lambda x: x
|
238 |
+
name = name.lower()
|
239 |
+
if name == "none":
|
240 |
+
return lambda x: x
|
241 |
+
elif name == "exp":
|
242 |
+
return lambda x: torch.exp(x)
|
243 |
+
elif name == "sigmoid":
|
244 |
+
return lambda x: torch.sigmoid(x)
|
245 |
+
elif name == "tanh":
|
246 |
+
return lambda x: torch.tanh(x)
|
247 |
+
elif name == "softplus":
|
248 |
+
return lambda x: F.softplus(x)
|
249 |
+
else:
|
250 |
+
try:
|
251 |
+
return getattr(F, name)
|
252 |
+
except AttributeError:
|
253 |
+
raise ValueError(f"Unknown activation function: {name}")
|
254 |
+
|
255 |
+
|
256 |
+
def get_ray_directions(
|
257 |
+
H: int,
|
258 |
+
W: int,
|
259 |
+
focal: Union[float, Tuple[float, float]],
|
260 |
+
principal: Optional[Tuple[float, float]] = None,
|
261 |
+
use_pixel_centers: bool = True,
|
262 |
+
normalize: bool = True,
|
263 |
+
) -> torch.FloatTensor:
|
264 |
+
"""
|
265 |
+
Get ray directions for all pixels in camera coordinate.
|
266 |
+
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
|
267 |
+
ray-tracing-generating-camera-rays/standard-coordinate-systems
|
268 |
+
|
269 |
+
Inputs:
|
270 |
+
H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
|
271 |
+
Outputs:
|
272 |
+
directions: (H, W, 3), the direction of the rays in camera coordinate
|
273 |
+
"""
|
274 |
+
pixel_center = 0.5 if use_pixel_centers else 0
|
275 |
+
|
276 |
+
if isinstance(focal, float):
|
277 |
+
fx, fy = focal, focal
|
278 |
+
cx, cy = W / 2, H / 2
|
279 |
+
else:
|
280 |
+
fx, fy = focal
|
281 |
+
assert principal is not None
|
282 |
+
cx, cy = principal
|
283 |
+
|
284 |
+
i, j = torch.meshgrid(
|
285 |
+
torch.arange(W, dtype=torch.float32) + pixel_center,
|
286 |
+
torch.arange(H, dtype=torch.float32) + pixel_center,
|
287 |
+
indexing="xy",
|
288 |
+
)
|
289 |
+
|
290 |
+
directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
|
291 |
+
|
292 |
+
if normalize:
|
293 |
+
directions = F.normalize(directions, dim=-1)
|
294 |
+
|
295 |
+
return directions
|
296 |
+
|
297 |
+
|
298 |
+
def get_rays(
|
299 |
+
directions,
|
300 |
+
c2w,
|
301 |
+
keepdim=False,
|
302 |
+
noise_scale=0.0,
|
303 |
+
normalize=False,
|
304 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
305 |
+
# Rotate ray directions from camera coordinate to the world coordinate
|
306 |
+
assert directions.shape[-1] == 3
|
307 |
+
|
308 |
+
if directions.ndim == 2: # (N_rays, 3)
|
309 |
+
if c2w.ndim == 2: # (4, 4)
|
310 |
+
c2w = c2w[None, :, :]
|
311 |
+
assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
|
312 |
+
rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
|
313 |
+
rays_o = c2w[:, :3, 3].expand(rays_d.shape)
|
314 |
+
elif directions.ndim == 3: # (H, W, 3)
|
315 |
+
assert c2w.ndim in [2, 3]
|
316 |
+
if c2w.ndim == 2: # (4, 4)
|
317 |
+
rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
|
318 |
+
-1
|
319 |
+
) # (H, W, 3)
|
320 |
+
rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
|
321 |
+
elif c2w.ndim == 3: # (B, 4, 4)
|
322 |
+
rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
323 |
+
-1
|
324 |
+
) # (B, H, W, 3)
|
325 |
+
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
326 |
+
elif directions.ndim == 4: # (B, H, W, 3)
|
327 |
+
assert c2w.ndim == 3 # (B, 4, 4)
|
328 |
+
rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
329 |
+
-1
|
330 |
+
) # (B, H, W, 3)
|
331 |
+
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
332 |
+
|
333 |
+
# add camera noise to avoid grid-like artifect
|
334 |
+
# https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
|
335 |
+
if noise_scale > 0:
|
336 |
+
rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
|
337 |
+
rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
|
338 |
+
|
339 |
+
if normalize:
|
340 |
+
rays_d = F.normalize(rays_d, dim=-1)
|
341 |
+
if not keepdim:
|
342 |
+
rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
|
343 |
+
|
344 |
+
return rays_o, rays_d
|
345 |
+
|
346 |
+
|
347 |
+
def get_spherical_cameras(
|
348 |
+
n_views: int,
|
349 |
+
elevation_deg: float,
|
350 |
+
camera_distance: float,
|
351 |
+
fovy_deg: float,
|
352 |
+
height: int,
|
353 |
+
width: int,
|
354 |
+
):
|
355 |
+
azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
|
356 |
+
elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
|
357 |
+
camera_distances = torch.full_like(elevation_deg, camera_distance)
|
358 |
+
|
359 |
+
elevation = elevation_deg * math.pi / 180
|
360 |
+
azimuth = azimuth_deg * math.pi / 180
|
361 |
+
|
362 |
+
# convert spherical coordinates to cartesian coordinates
|
363 |
+
# right hand coordinate system, x back, y right, z up
|
364 |
+
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
|
365 |
+
camera_positions = torch.stack(
|
366 |
+
[
|
367 |
+
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
|
368 |
+
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
|
369 |
+
camera_distances * torch.sin(elevation),
|
370 |
+
],
|
371 |
+
dim=-1,
|
372 |
+
)
|
373 |
+
|
374 |
+
# default scene center at origin
|
375 |
+
center = torch.zeros_like(camera_positions)
|
376 |
+
# default camera up direction as +z
|
377 |
+
up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
|
378 |
+
|
379 |
+
fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
|
380 |
+
|
381 |
+
lookat = F.normalize(center - camera_positions, dim=-1)
|
382 |
+
right = F.normalize(torch.cross(lookat, up), dim=-1)
|
383 |
+
up = F.normalize(torch.cross(right, lookat), dim=-1)
|
384 |
+
c2w3x4 = torch.cat(
|
385 |
+
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
|
386 |
+
dim=-1,
|
387 |
+
)
|
388 |
+
c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
|
389 |
+
c2w[:, 3, 3] = 1.0
|
390 |
+
|
391 |
+
# get directions by dividing directions_unit_focal by focal length
|
392 |
+
focal_length = 0.5 * height / torch.tan(0.5 * fovy)
|
393 |
+
directions_unit_focal = get_ray_directions(
|
394 |
+
H=height,
|
395 |
+
W=width,
|
396 |
+
focal=1.0,
|
397 |
+
)
|
398 |
+
directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
|
399 |
+
directions[:, :, :, :2] = (
|
400 |
+
directions[:, :, :, :2] / focal_length[:, None, None, None]
|
401 |
+
)
|
402 |
+
# must use normalize=True to normalize directions here
|
403 |
+
rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
|
404 |
+
|
405 |
+
return rays_o, rays_d
|
406 |
+
|
407 |
+
|
408 |
+
def remove_background(
|
409 |
+
image: PIL.Image.Image,
|
410 |
+
rembg_session: Any = None,
|
411 |
+
force: bool = False,
|
412 |
+
**rembg_kwargs,
|
413 |
+
) -> PIL.Image.Image:
|
414 |
+
do_remove = True
|
415 |
+
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
416 |
+
do_remove = False
|
417 |
+
do_remove = do_remove or force
|
418 |
+
if do_remove:
|
419 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
420 |
+
return image
|
421 |
+
|
422 |
+
|
423 |
+
def resize_foreground(
|
424 |
+
image: PIL.Image.Image,
|
425 |
+
ratio: float,
|
426 |
+
) -> PIL.Image.Image:
|
427 |
+
image = np.array(image)
|
428 |
+
assert image.shape[-1] == 4
|
429 |
+
alpha = np.where(image[..., 3] > 0)
|
430 |
+
y1, y2, x1, x2 = (
|
431 |
+
alpha[0].min(),
|
432 |
+
alpha[0].max(),
|
433 |
+
alpha[1].min(),
|
434 |
+
alpha[1].max(),
|
435 |
+
)
|
436 |
+
# crop the foreground
|
437 |
+
fg = image[y1:y2, x1:x2]
|
438 |
+
# pad to square
|
439 |
+
size = max(fg.shape[0], fg.shape[1])
|
440 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
441 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
442 |
+
new_image = np.pad(
|
443 |
+
fg,
|
444 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
445 |
+
mode="constant",
|
446 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
447 |
+
)
|
448 |
+
|
449 |
+
# compute padding according to the ratio
|
450 |
+
new_size = int(new_image.shape[0] / ratio)
|
451 |
+
# pad to size, double side
|
452 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
453 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
454 |
+
new_image = np.pad(
|
455 |
+
new_image,
|
456 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
457 |
+
mode="constant",
|
458 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
459 |
+
)
|
460 |
+
new_image = PIL.Image.fromarray(new_image)
|
461 |
+
return new_image
|
462 |
+
|
463 |
+
|
464 |
+
def save_video(
|
465 |
+
frames: List[PIL.Image.Image],
|
466 |
+
output_path: str,
|
467 |
+
fps: int = 30,
|
468 |
+
):
|
469 |
+
# use imageio to save video
|
470 |
+
frames = [np.array(frame) for frame in frames]
|
471 |
+
writer = imageio.get_writer(output_path, fps=fps)
|
472 |
+
for frame in frames:
|
473 |
+
writer.append_data(frame)
|
474 |
+
writer.close()
|
475 |
+
|
476 |
+
|
477 |
+
_dir2vec = {
|
478 |
+
"+x": np.array([1, 0, 0]),
|
479 |
+
"+y": np.array([0, 1, 0]),
|
480 |
+
"+z": np.array([0, 0, 1]),
|
481 |
+
"-x": np.array([-1, 0, 0]),
|
482 |
+
"-y": np.array([0, -1, 0]),
|
483 |
+
"-z": np.array([0, 0, -1]),
|
484 |
+
}
|
485 |
+
|
486 |
+
|
487 |
+
def to_gradio_3d_orientation(vertices):
|
488 |
+
z_, x_ = _dir2vec["+y"], _dir2vec["-z"]
|
489 |
+
y_ = np.cross(z_, x_)
|
490 |
+
std2mesh = np.stack([x_, y_, z_], axis=0).T
|
491 |
+
vertices = np.dot(std2mesh, vertices.T).T
|
492 |
+
return vertices
|