Spaces:
Running
on
Zero
Running
on
Zero
kxhit
commited on
Commit
•
23aae87
1
Parent(s):
ad4ee48
clean
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +0 -43
- 3drecon/configs/neus_36.yaml +0 -26
- 3drecon/raymarching/__init__.py +0 -1
- 3drecon/raymarching/backend.py +0 -40
- 3drecon/raymarching/raymarching.py +0 -373
- 3drecon/raymarching/setup.py +0 -62
- 3drecon/raymarching/src/bindings.cpp +0 -19
- 3drecon/raymarching/src/raymarching.cu +0 -914
- 3drecon/raymarching/src/raymarching.h +0 -18
- 3drecon/renderer/agg_net.py +0 -83
- 3drecon/renderer/cost_reg_net.py +0 -95
- 3drecon/renderer/dummy_dataset.py +0 -40
- 3drecon/renderer/feature_net.py +0 -42
- 3drecon/renderer/neus_networks.py +0 -503
- 3drecon/renderer/ngp_renderer.py +0 -721
- 3drecon/renderer/renderer.py +0 -640
- 3drecon/run_NeuS.py +0 -32
- 3drecon/train_renderer.py +0 -188
- 3drecon/util.py +0 -54
- 4DoF/CN_encoder.py +0 -36
- 4DoF/dataset.py +0 -228
- 4DoF/diffusers/__init__.py +0 -281
- 4DoF/diffusers/commands/__init__.py +0 -27
- 4DoF/diffusers/commands/diffusers_cli.py +0 -41
- 4DoF/diffusers/commands/env.py +0 -84
- 4DoF/diffusers/configuration_utils.py +0 -664
- 4DoF/diffusers/dependency_versions_check.py +0 -47
- 4DoF/diffusers/dependency_versions_table.py +0 -44
- 4DoF/diffusers/experimental/__init__.py +0 -1
- 4DoF/diffusers/experimental/rl/__init__.py +0 -1
- 4DoF/diffusers/experimental/rl/value_guided_sampling.py +0 -152
- 4DoF/diffusers/image_processor.py +0 -366
- 4DoF/diffusers/loaders.py +0 -1492
- 4DoF/diffusers/models/__init__.py +0 -35
- 4DoF/diffusers/models/activations.py +0 -12
- 4DoF/diffusers/models/attention.py +0 -392
- 4DoF/diffusers/models/attention_flax.py +0 -446
- 4DoF/diffusers/models/attention_processor.py +0 -1714
- 4DoF/diffusers/models/autoencoder_kl.py +0 -411
- 4DoF/diffusers/models/controlnet.py +0 -705
- 4DoF/diffusers/models/controlnet_flax.py +0 -394
- 4DoF/diffusers/models/cross_attention.py +0 -94
- 4DoF/diffusers/models/dual_transformer_2d.py +0 -151
- 4DoF/diffusers/models/embeddings.py +0 -546
- 4DoF/diffusers/models/embeddings_flax.py +0 -95
- 4DoF/diffusers/models/modeling_flax_pytorch_utils.py +0 -118
- 4DoF/diffusers/models/modeling_flax_utils.py +0 -534
- 4DoF/diffusers/models/modeling_pytorch_flax_utils.py +0 -161
- 4DoF/diffusers/models/modeling_utils.py +0 -980
- 4DoF/diffusers/models/prior_transformer.py +0 -364
.gitattributes
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
-
logs/user_object/eschernet/output.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
-
logs/user_object/scene.glb filter=lfs diff=lfs merge=lfs -text
|
38 |
-
3drecon/ours_GSO_T1/NeuS/grandmother/mesh.ply filter=lfs diff=lfs merge=lfs -text
|
39 |
-
3drecon/ours_GSO_T1/NeuS/lion/mesh.ply filter=lfs diff=lfs merge=lfs -text
|
40 |
-
gradio_demo/examples/bike/003.jpg filter=lfs diff=lfs merge=lfs -text
|
41 |
-
gradio_demo/examples/bike/027.jpg filter=lfs diff=lfs merge=lfs -text
|
42 |
-
gradio_demo/examples/bike/bike_0.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
-
gradio_demo/examples/bike/bike_2.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/configs/neus_36.yaml
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
model:
|
2 |
-
base_lr: 5.0e-4
|
3 |
-
target: renderer.renderer.RendererTrainer
|
4 |
-
params:
|
5 |
-
total_steps: 2000
|
6 |
-
warm_up_steps: 100
|
7 |
-
train_batch_num: 2560
|
8 |
-
train_batch_fg_num: 512
|
9 |
-
test_batch_num: 4096
|
10 |
-
use_mask: true
|
11 |
-
lambda_rgb_loss: 0.5
|
12 |
-
lambda_mask_loss: 1.0
|
13 |
-
lambda_eikonal_loss: 0.1
|
14 |
-
use_warm_up: true
|
15 |
-
|
16 |
-
data:
|
17 |
-
target: renderer.dummy_dataset.DummyDataset
|
18 |
-
params: {}
|
19 |
-
|
20 |
-
callbacks:
|
21 |
-
save_interval: 500
|
22 |
-
|
23 |
-
trainer:
|
24 |
-
val_check_interval: 500
|
25 |
-
max_steps: 2000
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/raymarching/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .raymarching import *
|
|
|
|
3drecon/raymarching/backend.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from torch.utils.cpp_extension import load
|
3 |
-
|
4 |
-
_src_path = os.path.dirname(os.path.abspath(__file__))
|
5 |
-
|
6 |
-
nvcc_flags = [
|
7 |
-
'-O3', '-std=c++14',
|
8 |
-
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
9 |
-
]
|
10 |
-
|
11 |
-
if os.name == "posix":
|
12 |
-
c_flags = ['-O3', '-std=c++14']
|
13 |
-
elif os.name == "nt":
|
14 |
-
c_flags = ['/O2', '/std:c++17']
|
15 |
-
|
16 |
-
# find cl.exe
|
17 |
-
def find_cl_path():
|
18 |
-
import glob
|
19 |
-
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
20 |
-
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
21 |
-
if paths:
|
22 |
-
return paths[0]
|
23 |
-
|
24 |
-
# If cl.exe is not on path, try to find it.
|
25 |
-
if os.system("where cl.exe >nul 2>nul") != 0:
|
26 |
-
cl_path = find_cl_path()
|
27 |
-
if cl_path is None:
|
28 |
-
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
29 |
-
os.environ["PATH"] += ";" + cl_path
|
30 |
-
|
31 |
-
_backend = load(name='_raymarching',
|
32 |
-
extra_cflags=c_flags,
|
33 |
-
extra_cuda_cflags=nvcc_flags,
|
34 |
-
sources=[os.path.join(_src_path, 'src', f) for f in [
|
35 |
-
'raymarching.cu',
|
36 |
-
'bindings.cpp',
|
37 |
-
]],
|
38 |
-
)
|
39 |
-
|
40 |
-
__all__ = ['_backend']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/raymarching/raymarching.py
DELETED
@@ -1,373 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import time
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from torch.autograd import Function
|
7 |
-
from torch.cuda.amp import custom_bwd, custom_fwd
|
8 |
-
|
9 |
-
try:
|
10 |
-
import _raymarching as _backend
|
11 |
-
except ImportError:
|
12 |
-
from .backend import _backend
|
13 |
-
|
14 |
-
|
15 |
-
# ----------------------------------------
|
16 |
-
# utils
|
17 |
-
# ----------------------------------------
|
18 |
-
|
19 |
-
class _near_far_from_aabb(Function):
|
20 |
-
@staticmethod
|
21 |
-
@custom_fwd(cast_inputs=torch.float32)
|
22 |
-
def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
|
23 |
-
''' near_far_from_aabb, CUDA implementation
|
24 |
-
Calculate rays' intersection time (near and far) with aabb
|
25 |
-
Args:
|
26 |
-
rays_o: float, [N, 3]
|
27 |
-
rays_d: float, [N, 3]
|
28 |
-
aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
|
29 |
-
min_near: float, scalar
|
30 |
-
Returns:
|
31 |
-
nears: float, [N]
|
32 |
-
fars: float, [N]
|
33 |
-
'''
|
34 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
35 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
36 |
-
|
37 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
38 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
39 |
-
|
40 |
-
N = rays_o.shape[0] # num rays
|
41 |
-
|
42 |
-
nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
43 |
-
fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
44 |
-
|
45 |
-
_backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
|
46 |
-
|
47 |
-
return nears, fars
|
48 |
-
|
49 |
-
near_far_from_aabb = _near_far_from_aabb.apply
|
50 |
-
|
51 |
-
|
52 |
-
class _sph_from_ray(Function):
|
53 |
-
@staticmethod
|
54 |
-
@custom_fwd(cast_inputs=torch.float32)
|
55 |
-
def forward(ctx, rays_o, rays_d, radius):
|
56 |
-
''' sph_from_ray, CUDA implementation
|
57 |
-
get spherical coordinate on the background sphere from rays.
|
58 |
-
Assume rays_o are inside the Sphere(radius).
|
59 |
-
Args:
|
60 |
-
rays_o: [N, 3]
|
61 |
-
rays_d: [N, 3]
|
62 |
-
radius: scalar, float
|
63 |
-
Return:
|
64 |
-
coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
|
65 |
-
'''
|
66 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
67 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
68 |
-
|
69 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
70 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
71 |
-
|
72 |
-
N = rays_o.shape[0] # num rays
|
73 |
-
|
74 |
-
coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
|
75 |
-
|
76 |
-
_backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
|
77 |
-
|
78 |
-
return coords
|
79 |
-
|
80 |
-
sph_from_ray = _sph_from_ray.apply
|
81 |
-
|
82 |
-
|
83 |
-
class _morton3D(Function):
|
84 |
-
@staticmethod
|
85 |
-
def forward(ctx, coords):
|
86 |
-
''' morton3D, CUDA implementation
|
87 |
-
Args:
|
88 |
-
coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
|
89 |
-
TODO: check if the coord range is valid! (current 128 is safe)
|
90 |
-
Returns:
|
91 |
-
indices: [N], int32, in [0, 128^3)
|
92 |
-
|
93 |
-
'''
|
94 |
-
if not coords.is_cuda: coords = coords.cuda()
|
95 |
-
|
96 |
-
N = coords.shape[0]
|
97 |
-
|
98 |
-
indices = torch.empty(N, dtype=torch.int32, device=coords.device)
|
99 |
-
|
100 |
-
_backend.morton3D(coords.int(), N, indices)
|
101 |
-
|
102 |
-
return indices
|
103 |
-
|
104 |
-
morton3D = _morton3D.apply
|
105 |
-
|
106 |
-
class _morton3D_invert(Function):
|
107 |
-
@staticmethod
|
108 |
-
def forward(ctx, indices):
|
109 |
-
''' morton3D_invert, CUDA implementation
|
110 |
-
Args:
|
111 |
-
indices: [N], int32, in [0, 128^3)
|
112 |
-
Returns:
|
113 |
-
coords: [N, 3], int32, in [0, 128)
|
114 |
-
|
115 |
-
'''
|
116 |
-
if not indices.is_cuda: indices = indices.cuda()
|
117 |
-
|
118 |
-
N = indices.shape[0]
|
119 |
-
|
120 |
-
coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
|
121 |
-
|
122 |
-
_backend.morton3D_invert(indices.int(), N, coords)
|
123 |
-
|
124 |
-
return coords
|
125 |
-
|
126 |
-
morton3D_invert = _morton3D_invert.apply
|
127 |
-
|
128 |
-
|
129 |
-
class _packbits(Function):
|
130 |
-
@staticmethod
|
131 |
-
@custom_fwd(cast_inputs=torch.float32)
|
132 |
-
def forward(ctx, grid, thresh, bitfield=None):
|
133 |
-
''' packbits, CUDA implementation
|
134 |
-
Pack up the density grid into a bit field to accelerate ray marching.
|
135 |
-
Args:
|
136 |
-
grid: float, [C, H * H * H], assume H % 2 == 0
|
137 |
-
thresh: float, threshold
|
138 |
-
Returns:
|
139 |
-
bitfield: uint8, [C, H * H * H / 8]
|
140 |
-
'''
|
141 |
-
if not grid.is_cuda: grid = grid.cuda()
|
142 |
-
grid = grid.contiguous()
|
143 |
-
|
144 |
-
C = grid.shape[0]
|
145 |
-
H3 = grid.shape[1]
|
146 |
-
N = C * H3 // 8
|
147 |
-
|
148 |
-
if bitfield is None:
|
149 |
-
bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
|
150 |
-
|
151 |
-
_backend.packbits(grid, N, thresh, bitfield)
|
152 |
-
|
153 |
-
return bitfield
|
154 |
-
|
155 |
-
packbits = _packbits.apply
|
156 |
-
|
157 |
-
# ----------------------------------------
|
158 |
-
# train functions
|
159 |
-
# ----------------------------------------
|
160 |
-
|
161 |
-
class _march_rays_train(Function):
|
162 |
-
@staticmethod
|
163 |
-
@custom_fwd(cast_inputs=torch.float32)
|
164 |
-
def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
|
165 |
-
''' march rays to generate points (forward only)
|
166 |
-
Args:
|
167 |
-
rays_o/d: float, [N, 3]
|
168 |
-
bound: float, scalar
|
169 |
-
density_bitfield: uint8: [CHHH // 8]
|
170 |
-
C: int
|
171 |
-
H: int
|
172 |
-
nears/fars: float, [N]
|
173 |
-
step_counter: int32, (2), used to count the actual number of generated points.
|
174 |
-
mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
|
175 |
-
perturb: bool
|
176 |
-
align: int, pad output so its size is dividable by align, set to -1 to disable.
|
177 |
-
force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
|
178 |
-
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
|
179 |
-
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
|
180 |
-
Returns:
|
181 |
-
xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
|
182 |
-
dirs: float, [M, 3], all generated points' view dirs.
|
183 |
-
deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
|
184 |
-
rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
|
185 |
-
'''
|
186 |
-
|
187 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
188 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
189 |
-
if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
|
190 |
-
|
191 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
192 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
193 |
-
density_bitfield = density_bitfield.contiguous()
|
194 |
-
|
195 |
-
N = rays_o.shape[0] # num rays
|
196 |
-
M = N * max_steps # init max points number in total
|
197 |
-
|
198 |
-
# running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
|
199 |
-
# It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
|
200 |
-
if not force_all_rays and mean_count > 0:
|
201 |
-
if align > 0:
|
202 |
-
mean_count += align - mean_count % align
|
203 |
-
M = mean_count
|
204 |
-
|
205 |
-
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
206 |
-
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
207 |
-
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
|
208 |
-
rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
|
209 |
-
|
210 |
-
if step_counter is None:
|
211 |
-
step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
|
212 |
-
|
213 |
-
if perturb:
|
214 |
-
noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
|
215 |
-
else:
|
216 |
-
noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
|
217 |
-
|
218 |
-
_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
|
219 |
-
|
220 |
-
#print(step_counter, M)
|
221 |
-
|
222 |
-
# only used at the first (few) epochs.
|
223 |
-
if force_all_rays or mean_count <= 0:
|
224 |
-
m = step_counter[0].item() # D2H copy
|
225 |
-
if align > 0:
|
226 |
-
m += align - m % align
|
227 |
-
xyzs = xyzs[:m]
|
228 |
-
dirs = dirs[:m]
|
229 |
-
deltas = deltas[:m]
|
230 |
-
|
231 |
-
torch.cuda.empty_cache()
|
232 |
-
|
233 |
-
return xyzs, dirs, deltas, rays
|
234 |
-
|
235 |
-
march_rays_train = _march_rays_train.apply
|
236 |
-
|
237 |
-
|
238 |
-
class _composite_rays_train(Function):
|
239 |
-
@staticmethod
|
240 |
-
@custom_fwd(cast_inputs=torch.float32)
|
241 |
-
def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
|
242 |
-
''' composite rays' rgbs, according to the ray marching formula.
|
243 |
-
Args:
|
244 |
-
rgbs: float, [M, 3]
|
245 |
-
sigmas: float, [M,]
|
246 |
-
deltas: float, [M, 2]
|
247 |
-
rays: int32, [N, 3]
|
248 |
-
Returns:
|
249 |
-
weights_sum: float, [N,], the alpha channel
|
250 |
-
depth: float, [N, ], the Depth
|
251 |
-
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
252 |
-
'''
|
253 |
-
|
254 |
-
sigmas = sigmas.contiguous()
|
255 |
-
rgbs = rgbs.contiguous()
|
256 |
-
|
257 |
-
M = sigmas.shape[0]
|
258 |
-
N = rays.shape[0]
|
259 |
-
|
260 |
-
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
261 |
-
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
262 |
-
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
263 |
-
|
264 |
-
_backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
|
265 |
-
|
266 |
-
ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
|
267 |
-
ctx.dims = [M, N, T_thresh]
|
268 |
-
|
269 |
-
return weights_sum, depth, image
|
270 |
-
|
271 |
-
@staticmethod
|
272 |
-
@custom_bwd
|
273 |
-
def backward(ctx, grad_weights_sum, grad_depth, grad_image):
|
274 |
-
|
275 |
-
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
276 |
-
|
277 |
-
grad_weights_sum = grad_weights_sum.contiguous()
|
278 |
-
grad_image = grad_image.contiguous()
|
279 |
-
|
280 |
-
sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
|
281 |
-
M, N, T_thresh = ctx.dims
|
282 |
-
|
283 |
-
grad_sigmas = torch.zeros_like(sigmas)
|
284 |
-
grad_rgbs = torch.zeros_like(rgbs)
|
285 |
-
|
286 |
-
_backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
|
287 |
-
|
288 |
-
return grad_sigmas, grad_rgbs, None, None, None
|
289 |
-
|
290 |
-
|
291 |
-
composite_rays_train = _composite_rays_train.apply
|
292 |
-
|
293 |
-
# ----------------------------------------
|
294 |
-
# infer functions
|
295 |
-
# ----------------------------------------
|
296 |
-
|
297 |
-
class _march_rays(Function):
|
298 |
-
@staticmethod
|
299 |
-
@custom_fwd(cast_inputs=torch.float32)
|
300 |
-
def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
|
301 |
-
''' march rays to generate points (forward only, for inference)
|
302 |
-
Args:
|
303 |
-
n_alive: int, number of alive rays
|
304 |
-
n_step: int, how many steps we march
|
305 |
-
rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
|
306 |
-
rays_t: float, [N], the alive rays' time, we only use the first n_alive.
|
307 |
-
rays_o/d: float, [N, 3]
|
308 |
-
bound: float, scalar
|
309 |
-
density_bitfield: uint8: [CHHH // 8]
|
310 |
-
C: int
|
311 |
-
H: int
|
312 |
-
nears/fars: float, [N]
|
313 |
-
align: int, pad output so its size is dividable by align, set to -1 to disable.
|
314 |
-
perturb: bool/int, int > 0 is used as the random seed.
|
315 |
-
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
|
316 |
-
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
|
317 |
-
Returns:
|
318 |
-
xyzs: float, [n_alive * n_step, 3], all generated points' coords
|
319 |
-
dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
|
320 |
-
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
321 |
-
'''
|
322 |
-
|
323 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
324 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
325 |
-
|
326 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
327 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
328 |
-
|
329 |
-
M = n_alive * n_step
|
330 |
-
|
331 |
-
if align > 0:
|
332 |
-
M += align - (M % align)
|
333 |
-
|
334 |
-
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
335 |
-
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
336 |
-
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
|
337 |
-
|
338 |
-
if perturb:
|
339 |
-
# torch.manual_seed(perturb) # test_gui uses spp index as seed
|
340 |
-
noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
341 |
-
else:
|
342 |
-
noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
343 |
-
|
344 |
-
_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
|
345 |
-
|
346 |
-
return xyzs, dirs, deltas
|
347 |
-
|
348 |
-
march_rays = _march_rays.apply
|
349 |
-
|
350 |
-
|
351 |
-
class _composite_rays(Function):
|
352 |
-
@staticmethod
|
353 |
-
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
354 |
-
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
|
355 |
-
''' composite rays' rgbs, according to the ray marching formula. (for inference)
|
356 |
-
Args:
|
357 |
-
n_alive: int, number of alive rays
|
358 |
-
n_step: int, how many steps we march
|
359 |
-
rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
|
360 |
-
rays_t: float, [N], the alive rays' time
|
361 |
-
sigmas: float, [n_alive * n_step,]
|
362 |
-
rgbs: float, [n_alive * n_step, 3]
|
363 |
-
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
364 |
-
In-place Outputs:
|
365 |
-
weights_sum: float, [N,], the alpha channel
|
366 |
-
depth: float, [N,], the depth value
|
367 |
-
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
368 |
-
'''
|
369 |
-
_backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
|
370 |
-
return tuple()
|
371 |
-
|
372 |
-
|
373 |
-
composite_rays = _composite_rays.apply
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/raymarching/setup.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from setuptools import setup
|
3 |
-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
4 |
-
|
5 |
-
_src_path = os.path.dirname(os.path.abspath(__file__))
|
6 |
-
|
7 |
-
nvcc_flags = [
|
8 |
-
'-O3', '-std=c++14',
|
9 |
-
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
10 |
-
]
|
11 |
-
|
12 |
-
if os.name == "posix":
|
13 |
-
c_flags = ['-O3', '-std=c++14']
|
14 |
-
elif os.name == "nt":
|
15 |
-
c_flags = ['/O2', '/std:c++17']
|
16 |
-
|
17 |
-
# find cl.exe
|
18 |
-
def find_cl_path():
|
19 |
-
import glob
|
20 |
-
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
21 |
-
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
22 |
-
if paths:
|
23 |
-
return paths[0]
|
24 |
-
|
25 |
-
# If cl.exe is not on path, try to find it.
|
26 |
-
if os.system("where cl.exe >nul 2>nul") != 0:
|
27 |
-
cl_path = find_cl_path()
|
28 |
-
if cl_path is None:
|
29 |
-
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
30 |
-
os.environ["PATH"] += ";" + cl_path
|
31 |
-
|
32 |
-
'''
|
33 |
-
Usage:
|
34 |
-
|
35 |
-
python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
|
36 |
-
|
37 |
-
python setup.py install # build extensions and install (copy) to PATH.
|
38 |
-
pip install . # ditto but better (e.g., dependency & metadata handling)
|
39 |
-
|
40 |
-
python setup.py develop # build extensions and install (symbolic) to PATH.
|
41 |
-
pip install -e . # ditto but better (e.g., dependency & metadata handling)
|
42 |
-
|
43 |
-
'''
|
44 |
-
setup(
|
45 |
-
name='raymarching', # package name, import this to use python API
|
46 |
-
ext_modules=[
|
47 |
-
CUDAExtension(
|
48 |
-
name='_raymarching', # extension name, import this to use CUDA API
|
49 |
-
sources=[os.path.join(_src_path, 'src', f) for f in [
|
50 |
-
'raymarching.cu',
|
51 |
-
'bindings.cpp',
|
52 |
-
]],
|
53 |
-
extra_compile_args={
|
54 |
-
'cxx': c_flags,
|
55 |
-
'nvcc': nvcc_flags,
|
56 |
-
}
|
57 |
-
),
|
58 |
-
],
|
59 |
-
cmdclass={
|
60 |
-
'build_ext': BuildExtension,
|
61 |
-
}
|
62 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/raymarching/src/bindings.cpp
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
#include <torch/extension.h>
|
2 |
-
|
3 |
-
#include "raymarching.h"
|
4 |
-
|
5 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
6 |
-
// utils
|
7 |
-
m.def("packbits", &packbits, "packbits (CUDA)");
|
8 |
-
m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
|
9 |
-
m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
|
10 |
-
m.def("morton3D", &morton3D, "morton3D (CUDA)");
|
11 |
-
m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
|
12 |
-
// train
|
13 |
-
m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
|
14 |
-
m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
|
15 |
-
m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
|
16 |
-
// infer
|
17 |
-
m.def("march_rays", &march_rays, "march rays (CUDA)");
|
18 |
-
m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
|
19 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/raymarching/src/raymarching.cu
DELETED
@@ -1,914 +0,0 @@
|
|
1 |
-
#include <cuda.h>
|
2 |
-
#include <cuda_fp16.h>
|
3 |
-
#include <cuda_runtime.h>
|
4 |
-
|
5 |
-
#include <ATen/cuda/CUDAContext.h>
|
6 |
-
#include <torch/torch.h>
|
7 |
-
|
8 |
-
#include <cstdio>
|
9 |
-
#include <stdint.h>
|
10 |
-
#include <stdexcept>
|
11 |
-
#include <limits>
|
12 |
-
|
13 |
-
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
14 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
15 |
-
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
16 |
-
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
17 |
-
|
18 |
-
|
19 |
-
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
|
20 |
-
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
|
21 |
-
inline constexpr __device__ float PI() { return 3.141592653589793f; }
|
22 |
-
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
|
23 |
-
|
24 |
-
|
25 |
-
template <typename T>
|
26 |
-
inline __host__ __device__ T div_round_up(T val, T divisor) {
|
27 |
-
return (val + divisor - 1) / divisor;
|
28 |
-
}
|
29 |
-
|
30 |
-
inline __host__ __device__ float signf(const float x) {
|
31 |
-
return copysignf(1.0, x);
|
32 |
-
}
|
33 |
-
|
34 |
-
inline __host__ __device__ float clamp(const float x, const float min, const float max) {
|
35 |
-
return fminf(max, fmaxf(min, x));
|
36 |
-
}
|
37 |
-
|
38 |
-
inline __host__ __device__ void swapf(float& a, float& b) {
|
39 |
-
float c = a; a = b; b = c;
|
40 |
-
}
|
41 |
-
|
42 |
-
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
|
43 |
-
const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
|
44 |
-
int exponent;
|
45 |
-
frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
|
46 |
-
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
47 |
-
}
|
48 |
-
|
49 |
-
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
|
50 |
-
const float mx = dt * H * 0.5;
|
51 |
-
int exponent;
|
52 |
-
frexpf(mx, &exponent);
|
53 |
-
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
54 |
-
}
|
55 |
-
|
56 |
-
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
|
57 |
-
{
|
58 |
-
v = (v * 0x00010001u) & 0xFF0000FFu;
|
59 |
-
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
60 |
-
v = (v * 0x00000011u) & 0xC30C30C3u;
|
61 |
-
v = (v * 0x00000005u) & 0x49249249u;
|
62 |
-
return v;
|
63 |
-
}
|
64 |
-
|
65 |
-
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
|
66 |
-
{
|
67 |
-
uint32_t xx = __expand_bits(x);
|
68 |
-
uint32_t yy = __expand_bits(y);
|
69 |
-
uint32_t zz = __expand_bits(z);
|
70 |
-
return xx | (yy << 1) | (zz << 2);
|
71 |
-
}
|
72 |
-
|
73 |
-
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
|
74 |
-
{
|
75 |
-
x = x & 0x49249249;
|
76 |
-
x = (x | (x >> 2)) & 0xc30c30c3;
|
77 |
-
x = (x | (x >> 4)) & 0x0f00f00f;
|
78 |
-
x = (x | (x >> 8)) & 0xff0000ff;
|
79 |
-
x = (x | (x >> 16)) & 0x0000ffff;
|
80 |
-
return x;
|
81 |
-
}
|
82 |
-
|
83 |
-
|
84 |
-
////////////////////////////////////////////////////
|
85 |
-
///////////// utils /////////////
|
86 |
-
////////////////////////////////////////////////////
|
87 |
-
|
88 |
-
// rays_o/d: [N, 3]
|
89 |
-
// nears/fars: [N]
|
90 |
-
// scalar_t should always be float in use.
|
91 |
-
template <typename scalar_t>
|
92 |
-
__global__ void kernel_near_far_from_aabb(
|
93 |
-
const scalar_t * __restrict__ rays_o,
|
94 |
-
const scalar_t * __restrict__ rays_d,
|
95 |
-
const scalar_t * __restrict__ aabb,
|
96 |
-
const uint32_t N,
|
97 |
-
const float min_near,
|
98 |
-
scalar_t * nears, scalar_t * fars
|
99 |
-
) {
|
100 |
-
// parallel per ray
|
101 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
102 |
-
if (n >= N) return;
|
103 |
-
|
104 |
-
// locate
|
105 |
-
rays_o += n * 3;
|
106 |
-
rays_d += n * 3;
|
107 |
-
|
108 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
109 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
110 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
111 |
-
|
112 |
-
// get near far (assume cube scene)
|
113 |
-
float near = (aabb[0] - ox) * rdx;
|
114 |
-
float far = (aabb[3] - ox) * rdx;
|
115 |
-
if (near > far) swapf(near, far);
|
116 |
-
|
117 |
-
float near_y = (aabb[1] - oy) * rdy;
|
118 |
-
float far_y = (aabb[4] - oy) * rdy;
|
119 |
-
if (near_y > far_y) swapf(near_y, far_y);
|
120 |
-
|
121 |
-
if (near > far_y || near_y > far) {
|
122 |
-
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
123 |
-
return;
|
124 |
-
}
|
125 |
-
|
126 |
-
if (near_y > near) near = near_y;
|
127 |
-
if (far_y < far) far = far_y;
|
128 |
-
|
129 |
-
float near_z = (aabb[2] - oz) * rdz;
|
130 |
-
float far_z = (aabb[5] - oz) * rdz;
|
131 |
-
if (near_z > far_z) swapf(near_z, far_z);
|
132 |
-
|
133 |
-
if (near > far_z || near_z > far) {
|
134 |
-
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
135 |
-
return;
|
136 |
-
}
|
137 |
-
|
138 |
-
if (near_z > near) near = near_z;
|
139 |
-
if (far_z < far) far = far_z;
|
140 |
-
|
141 |
-
if (near < min_near) near = min_near;
|
142 |
-
|
143 |
-
nears[n] = near;
|
144 |
-
fars[n] = far;
|
145 |
-
}
|
146 |
-
|
147 |
-
|
148 |
-
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
|
149 |
-
|
150 |
-
static constexpr uint32_t N_THREAD = 128;
|
151 |
-
|
152 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
153 |
-
rays_o.scalar_type(), "near_far_from_aabb", ([&] {
|
154 |
-
kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
|
155 |
-
}));
|
156 |
-
}
|
157 |
-
|
158 |
-
|
159 |
-
// rays_o/d: [N, 3]
|
160 |
-
// radius: float
|
161 |
-
// coords: [N, 2]
|
162 |
-
template <typename scalar_t>
|
163 |
-
__global__ void kernel_sph_from_ray(
|
164 |
-
const scalar_t * __restrict__ rays_o,
|
165 |
-
const scalar_t * __restrict__ rays_d,
|
166 |
-
const float radius,
|
167 |
-
const uint32_t N,
|
168 |
-
scalar_t * coords
|
169 |
-
) {
|
170 |
-
// parallel per ray
|
171 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
172 |
-
if (n >= N) return;
|
173 |
-
|
174 |
-
// locate
|
175 |
-
rays_o += n * 3;
|
176 |
-
rays_d += n * 3;
|
177 |
-
coords += n * 2;
|
178 |
-
|
179 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
180 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
181 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
182 |
-
|
183 |
-
// solve t from || o + td || = radius
|
184 |
-
const float A = dx * dx + dy * dy + dz * dz;
|
185 |
-
const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
|
186 |
-
const float C = ox * ox + oy * oy + oz * oz - radius * radius;
|
187 |
-
|
188 |
-
const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
|
189 |
-
|
190 |
-
// solve theta, phi (assume y is the up axis)
|
191 |
-
const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
|
192 |
-
const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
|
193 |
-
const float phi = atan2(z, x); // [-PI, PI)
|
194 |
-
|
195 |
-
// normalize to [-1, 1]
|
196 |
-
coords[0] = 2 * theta * RPI() - 1;
|
197 |
-
coords[1] = phi * RPI();
|
198 |
-
}
|
199 |
-
|
200 |
-
|
201 |
-
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
|
202 |
-
|
203 |
-
static constexpr uint32_t N_THREAD = 128;
|
204 |
-
|
205 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
206 |
-
rays_o.scalar_type(), "sph_from_ray", ([&] {
|
207 |
-
kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
|
208 |
-
}));
|
209 |
-
}
|
210 |
-
|
211 |
-
|
212 |
-
// coords: int32, [N, 3]
|
213 |
-
// indices: int32, [N]
|
214 |
-
__global__ void kernel_morton3D(
|
215 |
-
const int * __restrict__ coords,
|
216 |
-
const uint32_t N,
|
217 |
-
int * indices
|
218 |
-
) {
|
219 |
-
// parallel
|
220 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
221 |
-
if (n >= N) return;
|
222 |
-
|
223 |
-
// locate
|
224 |
-
coords += n * 3;
|
225 |
-
indices[n] = __morton3D(coords[0], coords[1], coords[2]);
|
226 |
-
}
|
227 |
-
|
228 |
-
|
229 |
-
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
|
230 |
-
static constexpr uint32_t N_THREAD = 128;
|
231 |
-
kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
|
232 |
-
}
|
233 |
-
|
234 |
-
|
235 |
-
// indices: int32, [N]
|
236 |
-
// coords: int32, [N, 3]
|
237 |
-
__global__ void kernel_morton3D_invert(
|
238 |
-
const int * __restrict__ indices,
|
239 |
-
const uint32_t N,
|
240 |
-
int * coords
|
241 |
-
) {
|
242 |
-
// parallel
|
243 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
244 |
-
if (n >= N) return;
|
245 |
-
|
246 |
-
// locate
|
247 |
-
coords += n * 3;
|
248 |
-
|
249 |
-
const int ind = indices[n];
|
250 |
-
|
251 |
-
coords[0] = __morton3D_invert(ind >> 0);
|
252 |
-
coords[1] = __morton3D_invert(ind >> 1);
|
253 |
-
coords[2] = __morton3D_invert(ind >> 2);
|
254 |
-
}
|
255 |
-
|
256 |
-
|
257 |
-
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
|
258 |
-
static constexpr uint32_t N_THREAD = 128;
|
259 |
-
kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
|
260 |
-
}
|
261 |
-
|
262 |
-
|
263 |
-
// grid: float, [C, H, H, H]
|
264 |
-
// N: int, C * H * H * H / 8
|
265 |
-
// density_thresh: float
|
266 |
-
// bitfield: uint8, [N]
|
267 |
-
template <typename scalar_t>
|
268 |
-
__global__ void kernel_packbits(
|
269 |
-
const scalar_t * __restrict__ grid,
|
270 |
-
const uint32_t N,
|
271 |
-
const float density_thresh,
|
272 |
-
uint8_t * bitfield
|
273 |
-
) {
|
274 |
-
// parallel per byte
|
275 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
276 |
-
if (n >= N) return;
|
277 |
-
|
278 |
-
// locate
|
279 |
-
grid += n * 8;
|
280 |
-
|
281 |
-
uint8_t bits = 0;
|
282 |
-
|
283 |
-
#pragma unroll
|
284 |
-
for (uint8_t i = 0; i < 8; i++) {
|
285 |
-
bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
|
286 |
-
}
|
287 |
-
|
288 |
-
bitfield[n] = bits;
|
289 |
-
}
|
290 |
-
|
291 |
-
|
292 |
-
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
|
293 |
-
|
294 |
-
static constexpr uint32_t N_THREAD = 128;
|
295 |
-
|
296 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
297 |
-
grid.scalar_type(), "packbits", ([&] {
|
298 |
-
kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
|
299 |
-
}));
|
300 |
-
}
|
301 |
-
|
302 |
-
////////////////////////////////////////////////////
|
303 |
-
///////////// training /////////////
|
304 |
-
////////////////////////////////////////////////////
|
305 |
-
|
306 |
-
// rays_o/d: [N, 3]
|
307 |
-
// grid: [CHHH / 8]
|
308 |
-
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
|
309 |
-
// dirs: [M, 3]
|
310 |
-
// rays: [N, 3], idx, offset, num_steps
|
311 |
-
template <typename scalar_t>
|
312 |
-
__global__ void kernel_march_rays_train(
|
313 |
-
const scalar_t * __restrict__ rays_o,
|
314 |
-
const scalar_t * __restrict__ rays_d,
|
315 |
-
const uint8_t * __restrict__ grid,
|
316 |
-
const float bound,
|
317 |
-
const float dt_gamma, const uint32_t max_steps,
|
318 |
-
const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
|
319 |
-
const scalar_t* __restrict__ nears,
|
320 |
-
const scalar_t* __restrict__ fars,
|
321 |
-
scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
|
322 |
-
int * rays,
|
323 |
-
int * counter,
|
324 |
-
const scalar_t* __restrict__ noises
|
325 |
-
) {
|
326 |
-
// parallel per ray
|
327 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
328 |
-
if (n >= N) return;
|
329 |
-
|
330 |
-
// locate
|
331 |
-
rays_o += n * 3;
|
332 |
-
rays_d += n * 3;
|
333 |
-
|
334 |
-
// ray marching
|
335 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
336 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
337 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
338 |
-
const float rH = 1 / (float)H;
|
339 |
-
const float H3 = H * H * H;
|
340 |
-
|
341 |
-
const float near = nears[n];
|
342 |
-
const float far = fars[n];
|
343 |
-
const float noise = noises[n];
|
344 |
-
|
345 |
-
const float dt_min = 2 * SQRT3() / max_steps;
|
346 |
-
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
|
347 |
-
|
348 |
-
float t0 = near;
|
349 |
-
|
350 |
-
// perturb
|
351 |
-
t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
|
352 |
-
|
353 |
-
// first pass: estimation of num_steps
|
354 |
-
float t = t0;
|
355 |
-
uint32_t num_steps = 0;
|
356 |
-
|
357 |
-
//if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
|
358 |
-
|
359 |
-
while (t < far && num_steps < max_steps) {
|
360 |
-
// current point
|
361 |
-
const float x = clamp(ox + t * dx, -bound, bound);
|
362 |
-
const float y = clamp(oy + t * dy, -bound, bound);
|
363 |
-
const float z = clamp(oz + t * dz, -bound, bound);
|
364 |
-
|
365 |
-
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
366 |
-
|
367 |
-
// get mip level
|
368 |
-
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
369 |
-
|
370 |
-
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
|
371 |
-
const float mip_rbound = 1 / mip_bound;
|
372 |
-
|
373 |
-
// convert to nearest grid position
|
374 |
-
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
375 |
-
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
376 |
-
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
377 |
-
|
378 |
-
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
379 |
-
const bool occ = grid[index / 8] & (1 << (index % 8));
|
380 |
-
|
381 |
-
// if occpuied, advance a small step, and write to output
|
382 |
-
//if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
|
383 |
-
|
384 |
-
if (occ) {
|
385 |
-
num_steps++;
|
386 |
-
t += dt;
|
387 |
-
// else, skip a large step (basically skip a voxel grid)
|
388 |
-
} else {
|
389 |
-
// calc distance to next voxel
|
390 |
-
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
391 |
-
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
392 |
-
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
393 |
-
|
394 |
-
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
395 |
-
// step until next voxel
|
396 |
-
do {
|
397 |
-
t += clamp(t * dt_gamma, dt_min, dt_max);
|
398 |
-
} while (t < tt);
|
399 |
-
}
|
400 |
-
}
|
401 |
-
|
402 |
-
//printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
|
403 |
-
|
404 |
-
// second pass: really locate and write points & dirs
|
405 |
-
uint32_t point_index = atomicAdd(counter, num_steps);
|
406 |
-
uint32_t ray_index = atomicAdd(counter + 1, 1);
|
407 |
-
|
408 |
-
//printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
|
409 |
-
|
410 |
-
// write rays
|
411 |
-
rays[ray_index * 3] = n;
|
412 |
-
rays[ray_index * 3 + 1] = point_index;
|
413 |
-
rays[ray_index * 3 + 2] = num_steps;
|
414 |
-
|
415 |
-
if (num_steps == 0) return;
|
416 |
-
if (point_index + num_steps > M) return;
|
417 |
-
|
418 |
-
xyzs += point_index * 3;
|
419 |
-
dirs += point_index * 3;
|
420 |
-
deltas += point_index * 2;
|
421 |
-
|
422 |
-
t = t0;
|
423 |
-
uint32_t step = 0;
|
424 |
-
|
425 |
-
float last_t = t;
|
426 |
-
|
427 |
-
while (t < far && step < num_steps) {
|
428 |
-
// current point
|
429 |
-
const float x = clamp(ox + t * dx, -bound, bound);
|
430 |
-
const float y = clamp(oy + t * dy, -bound, bound);
|
431 |
-
const float z = clamp(oz + t * dz, -bound, bound);
|
432 |
-
|
433 |
-
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
434 |
-
|
435 |
-
// get mip level
|
436 |
-
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
437 |
-
|
438 |
-
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
|
439 |
-
const float mip_rbound = 1 / mip_bound;
|
440 |
-
|
441 |
-
// convert to nearest grid position
|
442 |
-
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
443 |
-
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
444 |
-
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
445 |
-
|
446 |
-
// query grid
|
447 |
-
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
448 |
-
const bool occ = grid[index / 8] & (1 << (index % 8));
|
449 |
-
|
450 |
-
// if occpuied, advance a small step, and write to output
|
451 |
-
if (occ) {
|
452 |
-
// write step
|
453 |
-
xyzs[0] = x;
|
454 |
-
xyzs[1] = y;
|
455 |
-
xyzs[2] = z;
|
456 |
-
dirs[0] = dx;
|
457 |
-
dirs[1] = dy;
|
458 |
-
dirs[2] = dz;
|
459 |
-
t += dt;
|
460 |
-
deltas[0] = dt;
|
461 |
-
deltas[1] = t - last_t; // used to calc depth
|
462 |
-
last_t = t;
|
463 |
-
xyzs += 3;
|
464 |
-
dirs += 3;
|
465 |
-
deltas += 2;
|
466 |
-
step++;
|
467 |
-
// else, skip a large step (basically skip a voxel grid)
|
468 |
-
} else {
|
469 |
-
// calc distance to next voxel
|
470 |
-
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
471 |
-
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
472 |
-
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
473 |
-
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
474 |
-
// step until next voxel
|
475 |
-
do {
|
476 |
-
t += clamp(t * dt_gamma, dt_min, dt_max);
|
477 |
-
} while (t < tt);
|
478 |
-
}
|
479 |
-
}
|
480 |
-
}
|
481 |
-
|
482 |
-
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
|
483 |
-
|
484 |
-
static constexpr uint32_t N_THREAD = 128;
|
485 |
-
|
486 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
487 |
-
rays_o.scalar_type(), "march_rays_train", ([&] {
|
488 |
-
kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
|
489 |
-
}));
|
490 |
-
}
|
491 |
-
|
492 |
-
|
493 |
-
// sigmas: [M]
|
494 |
-
// rgbs: [M, 3]
|
495 |
-
// deltas: [M, 2]
|
496 |
-
// rays: [N, 3], idx, offset, num_steps
|
497 |
-
// weights_sum: [N], final pixel alpha
|
498 |
-
// depth: [N,]
|
499 |
-
// image: [N, 3]
|
500 |
-
template <typename scalar_t>
|
501 |
-
__global__ void kernel_composite_rays_train_forward(
|
502 |
-
const scalar_t * __restrict__ sigmas,
|
503 |
-
const scalar_t * __restrict__ rgbs,
|
504 |
-
const scalar_t * __restrict__ deltas,
|
505 |
-
const int * __restrict__ rays,
|
506 |
-
const uint32_t M, const uint32_t N, const float T_thresh,
|
507 |
-
scalar_t * weights_sum,
|
508 |
-
scalar_t * depth,
|
509 |
-
scalar_t * image
|
510 |
-
) {
|
511 |
-
// parallel per ray
|
512 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
513 |
-
if (n >= N) return;
|
514 |
-
|
515 |
-
// locate
|
516 |
-
uint32_t index = rays[n * 3];
|
517 |
-
uint32_t offset = rays[n * 3 + 1];
|
518 |
-
uint32_t num_steps = rays[n * 3 + 2];
|
519 |
-
|
520 |
-
// empty ray, or ray that exceed max step count.
|
521 |
-
if (num_steps == 0 || offset + num_steps > M) {
|
522 |
-
weights_sum[index] = 0;
|
523 |
-
depth[index] = 0;
|
524 |
-
image[index * 3] = 0;
|
525 |
-
image[index * 3 + 1] = 0;
|
526 |
-
image[index * 3 + 2] = 0;
|
527 |
-
return;
|
528 |
-
}
|
529 |
-
|
530 |
-
sigmas += offset;
|
531 |
-
rgbs += offset * 3;
|
532 |
-
deltas += offset * 2;
|
533 |
-
|
534 |
-
// accumulate
|
535 |
-
uint32_t step = 0;
|
536 |
-
|
537 |
-
scalar_t T = 1.0f;
|
538 |
-
scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
|
539 |
-
|
540 |
-
while (step < num_steps) {
|
541 |
-
|
542 |
-
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
543 |
-
const scalar_t weight = alpha * T;
|
544 |
-
|
545 |
-
r += weight * rgbs[0];
|
546 |
-
g += weight * rgbs[1];
|
547 |
-
b += weight * rgbs[2];
|
548 |
-
|
549 |
-
t += deltas[1]; // real delta
|
550 |
-
d += weight * t;
|
551 |
-
|
552 |
-
ws += weight;
|
553 |
-
|
554 |
-
T *= 1.0f - alpha;
|
555 |
-
|
556 |
-
// minimal remained transmittence
|
557 |
-
if (T < T_thresh) break;
|
558 |
-
|
559 |
-
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
|
560 |
-
|
561 |
-
// locate
|
562 |
-
sigmas++;
|
563 |
-
rgbs += 3;
|
564 |
-
deltas += 2;
|
565 |
-
|
566 |
-
step++;
|
567 |
-
}
|
568 |
-
|
569 |
-
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
|
570 |
-
|
571 |
-
// write
|
572 |
-
weights_sum[index] = ws; // weights_sum
|
573 |
-
depth[index] = d;
|
574 |
-
image[index * 3] = r;
|
575 |
-
image[index * 3 + 1] = g;
|
576 |
-
image[index * 3 + 2] = b;
|
577 |
-
}
|
578 |
-
|
579 |
-
|
580 |
-
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
|
581 |
-
|
582 |
-
static constexpr uint32_t N_THREAD = 128;
|
583 |
-
|
584 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
585 |
-
sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
|
586 |
-
kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
|
587 |
-
}));
|
588 |
-
}
|
589 |
-
|
590 |
-
|
591 |
-
// grad_weights_sum: [N,]
|
592 |
-
// grad: [N, 3]
|
593 |
-
// sigmas: [M]
|
594 |
-
// rgbs: [M, 3]
|
595 |
-
// deltas: [M, 2]
|
596 |
-
// rays: [N, 3], idx, offset, num_steps
|
597 |
-
// weights_sum: [N,], weights_sum here
|
598 |
-
// image: [N, 3]
|
599 |
-
// grad_sigmas: [M]
|
600 |
-
// grad_rgbs: [M, 3]
|
601 |
-
template <typename scalar_t>
|
602 |
-
__global__ void kernel_composite_rays_train_backward(
|
603 |
-
const scalar_t * __restrict__ grad_weights_sum,
|
604 |
-
const scalar_t * __restrict__ grad_image,
|
605 |
-
const scalar_t * __restrict__ sigmas,
|
606 |
-
const scalar_t * __restrict__ rgbs,
|
607 |
-
const scalar_t * __restrict__ deltas,
|
608 |
-
const int * __restrict__ rays,
|
609 |
-
const scalar_t * __restrict__ weights_sum,
|
610 |
-
const scalar_t * __restrict__ image,
|
611 |
-
const uint32_t M, const uint32_t N, const float T_thresh,
|
612 |
-
scalar_t * grad_sigmas,
|
613 |
-
scalar_t * grad_rgbs
|
614 |
-
) {
|
615 |
-
// parallel per ray
|
616 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
    if (n >= N) return;

    // locate
    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    if (num_steps == 0 || offset + num_steps > M) return;

    grad_weights_sum += index;
    grad_image += index * 3;
    weights_sum += index;
    image += index * 3;
    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;
    grad_sigmas += offset;
    grad_rgbs += offset * 3;

    // accumulate
    uint32_t step = 0;

    scalar_t T = 1.0f;
    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
    scalar_t r = 0, g = 0, b = 0, ws = 0;

    while (step < num_steps) {

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];
        ws += weight;

        T *= 1.0f - alpha;

        // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
        // write grad_rgbs
        grad_rgbs[0] = grad_image[0] * weight;
        grad_rgbs[1] = grad_image[1] * weight;
        grad_rgbs[2] = grad_image[2] * weight;

        // write grad_sigmas
        grad_sigmas[0] = deltas[0] * (
            grad_image[0] * (T * rgbs[0] - (r_final - r)) +
            grad_image[1] * (T * rgbs[1] - (g_final - g)) +
            grad_image[2] * (T * rgbs[2] - (b_final - b)) +
            grad_weights_sum[0] * (1 - ws_final)
        );

        //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
        // stop once the remaining transmittance is negligible
        if (T < T_thresh) break;

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;
        grad_sigmas++;
        grad_rgbs += 3;

        step++;
    }
}


void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
        kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
    }));
}


////////////////////////////////////////////////////
/////////////       inference       /////////////
////////////////////////////////////////////////////

template <typename scalar_t>
__global__ void kernel_march_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const int* __restrict__ rays_alive,
    const scalar_t* __restrict__ rays_t,
    const scalar_t* __restrict__ rays_o,
    const scalar_t* __restrict__ rays_d,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t C, const uint32_t H,
    const uint8_t * __restrict__ grid,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
    const scalar_t* __restrict__ noises
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id
    const float noise = noises[n];

    // locate
    rays_o += index * 3;
    rays_d += index * 3;
    xyzs += n * n_step * 3;
    dirs += n * n_step * 3;
    deltas += n * n_step * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 / (float)H;
    const float H3 = H * H * H;

    float t = rays_t[index]; // current ray's t
    const float near = nears[index], far = fars[index];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    // march for n_step steps, record points
    uint32_t step = 0;

    // introduce some randomness
    t += clamp(t * dt_gamma, dt_min, dt_max) * noise;

    float last_t = t;

    while (t < far && step < n_step) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf(scalbnf(1, level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        if (occ) {
            // write step
            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            // calc dt
            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t; // used to calc depth
            last_t = t;
            // step
            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;

        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}


void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays", ([&] {
        kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
    }));
}


template <typename scalar_t>
__global__ void kernel_composite_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const float T_thresh,
    int* rays_alive,
    scalar_t* rays_t,
    const scalar_t* __restrict__ sigmas,
    const scalar_t* __restrict__ rgbs,
    const scalar_t* __restrict__ deltas,
    scalar_t* weights_sum, scalar_t* depth, scalar_t* image
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id

    // locate
    sigmas += n * n_step;
    rgbs += n * n_step * 3;
    deltas += n * n_step * 2;

    rays_t += index;
    weights_sum += index;
    depth += index;
    image += index * 3;

    scalar_t t = rays_t[0]; // current ray's t

    scalar_t weight_sum = weights_sum[0];
    scalar_t d = depth[0];
    scalar_t r = image[0];
    scalar_t g = image[1];
    scalar_t b = image[2];

    // accumulate
    uint32_t step = 0;
    while (step < n_step) {

        // ray is terminated if delta == 0
        if (deltas[0] == 0) break;

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);

        /*
        T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
        w_i = alpha_i * T_i
        -->
        T_i = 1 - \sum_{j=0}^{i-1} w_j
        */
        const scalar_t T = 1 - weight_sum;
        const scalar_t weight = alpha * T;
        weight_sum += weight;

        t += deltas[1]; // real delta
        d += weight * t;
        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // ray is terminated if T is too small
        // use a larger bound to further accelerate inference
        if (T < T_thresh) break;

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;
        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // rays_alive = -1 means the ray terminated early.
    if (step < n_step) {
        rays_alive[n] = -1;
    } else {
        rays_t[0] = t;
    }

    weights_sum[0] = weight_sum; // accumulated opacity for this ray
    depth[0] = d;
    image[0] = r;
    image[1] = g;
    image[2] = b;
}


void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
    static constexpr uint32_t N_THREAD = 128;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    image.scalar_type(), "composite_rays", ([&] {
        kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}
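Note: a minimal sketch of how the train-time compositing kernels above could be hooked into autograd on the Python side. It assumes the compiled extension is importable as `_raymarching` and that its functions match the C++ signatures in this file; the names and the T_thresh default are illustrative, not a copy of the deleted Python wrapper.

# Hypothetical autograd wrapper for composite_rays_train_{forward,backward} (assumed module `_raymarching`).
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

import _raymarching  # assumed compiled extension exposing the C++ launchers above


class CompositeRaysTrain(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
        sigmas, rgbs = sigmas.contiguous(), rgbs.contiguous()
        M, N = sigmas.shape[0], rays.shape[0]
        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
        _raymarching.composite_rays_train_forward(
            sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
        ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, image)
        ctx.dims = (M, N, T_thresh)
        return weights_sum, depth, image

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_weights_sum, grad_depth, grad_image):
        # the backward kernel above only consumes grad_weights_sum and grad_image
        sigmas, rgbs, deltas, rays, weights_sum, image = ctx.saved_tensors
        M, N, T_thresh = ctx.dims
        grad_sigmas = torch.zeros_like(sigmas)
        grad_rgbs = torch.zeros_like(rgbs)
        _raymarching.composite_rays_train_backward(
            grad_weights_sum.contiguous(), grad_image.contiguous(),
            sigmas, rgbs, deltas, rays, weights_sum, image,
            M, N, T_thresh, grad_sigmas, grad_rgbs)
        return grad_sigmas, grad_rgbs, None, None, None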
3drecon/raymarching/src/raymarching.h
DELETED
@@ -1,18 +0,0 @@
#pragma once

#include <stdint.h>
#include <torch/torch.h>


void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);

void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);

void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
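Note: the header only declares the CUDA launchers; a build of this extension also needs a binding translation unit and a build step. A minimal sketch of building it with PyTorch's JIT loader follows, assuming source paths and a module name chosen for illustration; the actual deleted build setup may have differed.

# Hypothetical sketch: JIT-compiling the extension declared above.
from torch.utils.cpp_extension import load

_backend = load(
    name="raymarching_cuda",
    sources=[
        "3drecon/raymarching/src/raymarching.cu",
        "3drecon/raymarching/src/bindings.cpp",  # assumed binding file
    ],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=True,
)

# after loading, the declared functions are callable, e.g.:
# _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)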
3drecon/renderer/agg_net.py
DELETED
@@ -1,83 +0,0 @@
import torch.nn.functional as F
import torch.nn as nn
import torch

def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.zeros_(m.bias.data)

class NeRF(nn.Module):
    def __init__(self, vol_n=8+8, feat_ch=8+16+32+3, hid_n=64):
        super(NeRF, self).__init__()
        self.hid_n = hid_n
        self.agg = Agg(feat_ch)
        self.lr0 = nn.Sequential(nn.Linear(vol_n+16, hid_n), nn.ReLU())
        self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
        self.color = nn.Sequential(
            nn.Linear(16+vol_n+feat_ch+hid_n+4, hid_n), # agg_feats+vox_feat+img_feat+lr0_feats+dir
            nn.ReLU(),
            nn.Linear(hid_n, 1)
        )
        self.lr0.apply(weights_init)
        self.sigma.apply(weights_init)
        self.color.apply(weights_init)

    def forward(self, vox_feat, img_feat_rgb_dir, source_img_mask):
        # assert torch.sum(torch.sum(source_img_mask,1)<2)==0
        b, d, n, _ = img_feat_rgb_dir.shape # b,d,n,f=8+16+32+3+4
        agg_feat = self.agg(img_feat_rgb_dir, source_img_mask) # b,d,f=16
        x = self.lr0(torch.cat((vox_feat, agg_feat), dim=-1)) # b,d,f=64
        sigma = self.sigma(x) # b,d,1

        x = torch.cat((x, vox_feat, agg_feat), dim=-1) # b,d,f=16+16+64
        x = x.view(b, d, 1, x.shape[-1]).repeat(1, 1, n, 1)
        x = torch.cat((x, img_feat_rgb_dir), dim=-1)
        logits = self.color(x)
        source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
        logits[source_img_mask_] = -1e7
        color_weight = F.softmax(logits, dim=-2)
        color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2)
        return color, sigma

class Agg(nn.Module):
    def __init__(self, feat_ch):
        super(Agg, self).__init__()
        self.feat_ch = feat_ch
        self.view_fc = nn.Sequential(nn.Linear(4, feat_ch), nn.ReLU())
        self.view_fc.apply(weights_init)
        self.global_fc = nn.Sequential(nn.Linear(feat_ch*3, 32), nn.ReLU())

        self.agg_w_fc = nn.Linear(32, 1)
        self.fc = nn.Linear(32, 16)
        self.global_fc.apply(weights_init)
        self.agg_w_fc.apply(weights_init)
        self.fc.apply(weights_init)

    def masked_mean_var(self, img_feat_rgb, source_img_mask):
        # img_feat_rgb: b,d,n,f source_img_mask: b,n
        b, n = source_img_mask.shape
        source_img_mask = source_img_mask.view(b, 1, n, 1)
        mean = torch.sum(source_img_mask * img_feat_rgb, dim=-2) / (torch.sum(source_img_mask, dim=-2) + 1e-5)
        var = torch.sum((img_feat_rgb - mean.unsqueeze(-2)) ** 2 * source_img_mask, dim=-2) / (torch.sum(source_img_mask, dim=-2) + 1e-5)
        return mean, var

    def forward(self, img_feat_rgb_dir, source_img_mask):
        # img_feat_rgb_dir b,d,n,f
        b, d, n, _ = img_feat_rgb_dir.shape
        view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) # b,d,n,f-4
        img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat

        mean_feat, var_feat = self.masked_mean_var(img_feat_rgb, source_img_mask)
        var_feat = var_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
        avg_feat = mean_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)

        feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) # b,d,n,f
        global_feat = self.global_fc(feat) # b,d,n,f
        logits = self.agg_w_fc(global_feat) # b,d,n,1
        source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
        logits[source_img_mask_] = -1e7
        agg_w = F.softmax(logits, dim=-2)
        im_feat = (global_feat * agg_w).sum(dim=-2)
        return self.fc(im_feat)
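Note: a hypothetical shape check for the aggregation network above, using the default dimensions (vol_n = 16, per-view feature 8+16+32+3 = 59 plus 4 direction channels). The batch/sample/view sizes are illustrative only.

# Hypothetical usage sketch for NeRF/Agg shapes (not part of the deleted code).
import torch

net = NeRF(vol_n=16, feat_ch=8 + 16 + 32 + 3, hid_n=64)
b, d, n = 2, 32, 4                                 # batch, depth samples, source views
vox_feat = torch.rand(b, d, 16)                    # volume feature per sample
img_feat_rgb_dir = torch.rand(b, d, n, 8 + 16 + 32 + 3 + 4)
source_img_mask = torch.ones(b, n)                 # 1 = view is valid
color, sigma = net(vox_feat, img_feat_rgb_dir, source_img_mask)
print(color.shape, sigma.shape)                    # torch.Size([2, 32, 3]) torch.Size([2, 32, 1])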
3drecon/renderer/cost_reg_net.py
DELETED
@@ -1,95 +0,0 @@
import torch.nn as nn

class ConvBnReLU3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm3d):
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = norm_act(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class CostRegNet(nn.Module):
    def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
        super(CostRegNet, self).__init__()
        self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)

        self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
        self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)

        self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
        self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)

        self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
        self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)

        self.conv7 = nn.Sequential(
            nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, stride=2, bias=False),
            norm_act(32)
        )

        self.conv9 = nn.Sequential(
            nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, stride=2, bias=False),
            norm_act(16)
        )

        self.conv11 = nn.Sequential(
            nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, stride=2, bias=False),
            norm_act(8)
        )
        self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
        self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))

    def forward(self, x):
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))
        x = self.conv6(self.conv5(conv4))
        x = conv4 + self.conv7(x)
        del conv4
        x = conv2 + self.conv9(x)
        del conv2
        x = conv0 + self.conv11(x)
        del conv0
        feat = self.feat_conv(x)
        depth = self.depth_conv(x)
        return feat, depth


class MinCostRegNet(nn.Module):
    def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
        super(MinCostRegNet, self).__init__()
        self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)

        self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
        self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)

        self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
        self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)

        self.conv9 = nn.Sequential(
            nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
                               stride=2, bias=False),
            norm_act(16))

        self.conv11 = nn.Sequential(
            nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
                               stride=2, bias=False),
            norm_act(8))

        self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
        self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))

    def forward(self, x):
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))
        x = conv4
        x = conv2 + self.conv9(x)
        del conv2
        x = conv0 + self.conv11(x)
        del conv0
        feat = self.feat_conv(x)
        depth = self.depth_conv(x)
        return feat, depth
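Note: a hypothetical usage sketch for the 3D cost-regularization networks above. Spatial dimensions should be divisible by 8 for CostRegNet (three stride-2 stages) and by 4 for MinCostRegNet (two stride-2 stages), so the decoder skip additions line up; the tensor sizes here are illustrative.

# Hypothetical usage sketch (not part of the deleted code).
import torch

cost_volume = torch.rand(1, 8, 32, 32, 32)     # (B, C, D, H, W), C = in_channels
feat, depth = CostRegNet(in_channels=8)(cost_volume)
print(feat.shape, depth.shape)                 # (1, 8, 32, 32, 32), (1, 1, 32, 32, 32)

feat_m, depth_m = MinCostRegNet(in_channels=8)(cost_volume)
print(feat_m.shape, depth_m.shape)             # same spatial size, shallower decoder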
3drecon/renderer/dummy_dataset.py
DELETED
@@ -1,40 +0,0 @@
import pytorch_lightning as pl
from torch.utils.data import Dataset
import webdataset as wds
from torch.utils.data.distributed import DistributedSampler

class DummyDataset(pl.LightningDataModule):
    def __init__(self, seed):
        super().__init__()

    def setup(self, stage):
        if stage in ['fit']:
            self.train_dataset = DummyData(True)
            self.val_dataset = DummyData(False)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        return wds.WebLoader(self.train_dataset, batch_size=1, num_workers=0, shuffle=False)

    def val_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=1, num_workers=0, shuffle=False)

    def test_dataloader(self):
        return wds.WebLoader(DummyData(False))

class DummyData(Dataset):
    def __init__(self, is_train):
        self.is_train = is_train

    def __len__(self):
        if self.is_train:
            return 99999999
        else:
            return 1

    def __getitem__(self, index):
        return {}
3drecon/renderer/feature_net.py
DELETED
@@ -1,42 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F

class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm2d):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = norm_act(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class FeatureNet(nn.Module):
    def __init__(self, norm_act=nn.BatchNorm2d):
        super(FeatureNet, self).__init__()
        self.conv0 = nn.Sequential(ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
        self.conv1 = nn.Sequential(ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
        self.conv2 = nn.Sequential(ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))

        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        feat2 = self.toplayer(conv2)
        feat1 = self._upsample_add(feat2, self.lat1(conv1))
        feat0 = self._upsample_add(feat1, self.lat0(conv0))
        feat1 = self.smooth1(feat1)
        feat0 = self.smooth0(feat0)
        return feat2, feat1, feat0
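Note: a hypothetical usage sketch for the FPN-style feature extractor above; input height and width should be divisible by 4 (two stride-2 stages). The image size is illustrative.

# Hypothetical usage sketch (not part of the deleted code).
import torch

net = FeatureNet()
imgs = torch.rand(1, 3, 256, 256)
feat2, feat1, feat0 = net(imgs)
print(feat2.shape)  # (1, 32, 64, 64)   coarsest level
print(feat1.shape)  # (1, 16, 128, 128)
print(feat0.shape)  # (1, 8, 256, 256)  finest level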
3drecon/renderer/neus_networks.py
DELETED
@@ -1,503 +0,0 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tinycudann as tcnn

# Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, input_dims=3):
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj): return eo.embed(x)

    return embed, embedder_obj.out_dim


class SDFNetwork(nn.Module):
    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, inputs):
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()

class SDFNetworkWithFeature(nn.Module):
    def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
        super().__init__()

        self.register_buffer("cube", cube)
        self.cube_length = cube_length
        dims = [dp_in+df_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + df_in

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, points):
        points = points * self.scale

        # note: point*2 because the cube is [-0.5,0.5]
        with torch.no_grad():
            feats = F.grid_sample(self.cube, points.view(1,-1,1,1,3)/self.cube_length, mode='bilinear', align_corners=True, padding_mode='zeros').detach()
        feats = feats.view(self.cube.shape[1], -1).permute(1,0).view(*points.shape[:-1], -1)
        if self.embed_fn_fine is not None:
            points = self.embed_fn_fine(points)

        x = torch.cat([points, feats], -1)
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, points, feats], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        # concat feats
        x = torch.cat([x, feats], -1)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()


class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = 0.5
        self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
        for i in range(self.n_hidden_layers - 1):
            self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
        self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
        self.layers = nn.Sequential(*self.layers)

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        x = self.layers(x.float())
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
        if self.sphere_init:
            if is_last:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
            elif is_first:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
                torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
            else:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        if self.sphere_init:
            return nn.Softplus(beta=100)
        else:
            return nn.ReLU(inplace=True)


class SDFHashGridNetwork(nn.Module):
    def __init__(self, bound=0.5, feats_dim=13):
        super().__init__()
        self.bound = bound
        # max_resolution = 32
        # base_resolution = 16
        # n_levels = 4
        # log2_hashmap_size = 16
        # n_features_per_level = 8
        max_resolution = 2048
        base_resolution = 16
        n_levels = 16
        log2_hashmap_size = 19
        n_features_per_level = 2

        # max_res = base_res * t^(k-1)
        per_level_scale = (max_resolution / base_resolution) ** (1 / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )
        self.sdf_mlp = VanillaMLP(n_levels*n_features_per_level+3, feats_dim, 64, 1)

    def forward(self, x):
        shape = x.shape[:-1]
        x = x.reshape(-1, 3)
        x_ = (x + self.bound) / (2 * self.bound)
        feats = self.encoder(x_)
        feats = torch.cat([x, feats], 1)

        feats = self.sdf_mlp(feats)
        feats = feats.reshape(*shape, -1)
        return feats

    def sdf(self, x):
        return self(x)[..., :1]

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()

class RenderingFFNetwork(nn.Module):
    def __init__(self, in_feats_dim=12):
        super().__init__()
        self.dir_encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )
        self.color_mlp = tcnn.Network(
            n_input_dims = in_feats_dim + 3 + self.dir_encoder.n_output_dims,
            n_output_dims = 3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 2,
            },
        )

    def forward(self, points, normals, view_dirs, feature_vectors):
        normals = F.normalize(normals, dim=-1)
        view_dirs = F.normalize(view_dirs, dim=-1)
        reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs

        x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
        colors = self.color_mlp(x).float()
        colors = F.sigmoid(colors)
        return colors

# This implementation is borrowed from IDR: https://github.com/lioryariv/idr
class RenderingNetwork(nn.Module):
    def __init__(self, d_feature, d_in, d_out, d_hidden,
                 n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
        super().__init__()

        self.squeeze_out = squeeze_out
        self.rgb_act = F.sigmoid
        self.use_view_dir = use_view_dir

        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        if self.use_view_dir:
            view_dirs = F.normalize(view_dirs, dim=-1)
            normals = F.normalize(normals, dim=-1)
            reflective = torch.sum(view_dirs*normals, -1, keepdim=True) * normals * 2 - view_dirs
            if self.embedview_fn is not None: reflective = self.embedview_fn(reflective)
            rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
        else:
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = self.rgb_act(x)
        return x


class SingleVarianceNetwork(nn.Module):
    def __init__(self, init_val, activation='exp'):
        super(SingleVarianceNetwork, self).__init__()
        self.act = activation
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        device = x.device
        if self.act == 'exp':
            return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
        else:
            raise NotImplementedError

    def warp(self, x, inv_s):
        device = x.device
        return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
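Note: a hypothetical sketch of how the SDF networks and SingleVarianceNetwork above are typically combined into per-section alphas in a NeuS-style renderer. This is the standard NeuS formulation, not necessarily the exact computation used by the deleted renderer.

# Hypothetical NeuS-style alpha from SDF values and the learned inv_s (illustrative only).
import torch

def neus_alpha(sdf_prev, sdf_next, inv_s):
    # alpha_i = max((Phi(s * f_prev) - Phi(s * f_next)) / Phi(s * f_prev), 0)
    prev_cdf = torch.sigmoid(sdf_prev * inv_s)
    next_cdf = torch.sigmoid(sdf_next * inv_s)
    return ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clamp(0.0, 1.0)

# inv_s comes from SingleVarianceNetwork, a single learned scalar broadcast per sample:
# dev_net = SingleVarianceNetwork(init_val=0.3)
# inv_s = dev_net(torch.zeros(1, 3))[:, :1]   # = exp(10 * variance)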
3drecon/renderer/ngp_renderer.py
DELETED
@@ -1,721 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import trimesh
|
3 |
-
import numpy as np
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import torch.nn as nn
|
7 |
-
import torch.nn.functional as F
|
8 |
-
from packaging import version as pver
|
9 |
-
|
10 |
-
import tinycudann as tcnn
|
11 |
-
from torch.autograd import Function
|
12 |
-
|
13 |
-
from torch.cuda.amp import custom_bwd, custom_fwd
|
14 |
-
|
15 |
-
import raymarching
|
16 |
-
|
17 |
-
def custom_meshgrid(*args):
|
18 |
-
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
19 |
-
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
20 |
-
return torch.meshgrid(*args)
|
21 |
-
else:
|
22 |
-
return torch.meshgrid(*args, indexing='ij')
|
23 |
-
|
24 |
-
def sample_pdf(bins, weights, n_samples, det=False):
|
25 |
-
# This implementation is from NeRF
|
26 |
-
# bins: [B, T], old_z_vals
|
27 |
-
# weights: [B, T - 1], bin weights.
|
28 |
-
# return: [B, n_samples], new_z_vals
|
29 |
-
|
30 |
-
# Get pdf
|
31 |
-
weights = weights + 1e-5 # prevent nans
|
32 |
-
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
33 |
-
cdf = torch.cumsum(pdf, -1)
|
34 |
-
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
35 |
-
# Take uniform samples
|
36 |
-
if det:
|
37 |
-
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
|
38 |
-
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
39 |
-
else:
|
40 |
-
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
|
41 |
-
|
42 |
-
# Invert CDF
|
43 |
-
u = u.contiguous()
|
44 |
-
inds = torch.searchsorted(cdf, u, right=True)
|
45 |
-
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
46 |
-
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
47 |
-
inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
|
48 |
-
|
49 |
-
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
50 |
-
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
51 |
-
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
52 |
-
|
53 |
-
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
54 |
-
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
55 |
-
t = (u - cdf_g[..., 0]) / denom
|
56 |
-
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
57 |
-
|
58 |
-
return samples
|
59 |
-
|
60 |
-
|
61 |
-
def plot_pointcloud(pc, color=None):
|
62 |
-
# pc: [N, 3]
|
63 |
-
# color: [N, 3/4]
|
64 |
-
print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
|
65 |
-
pc = trimesh.PointCloud(pc, color)
|
66 |
-
# axis
|
67 |
-
axes = trimesh.creation.axis(axis_length=4)
|
68 |
-
# sphere
|
69 |
-
sphere = trimesh.creation.icosphere(radius=1)
|
70 |
-
trimesh.Scene([pc, axes, sphere]).show()
|
71 |
-
|
72 |
-
|
73 |
-
class NGPRenderer(nn.Module):
|
74 |
-
def __init__(self,
|
75 |
-
bound=1,
|
76 |
-
cuda_ray=True,
|
77 |
-
density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
|
78 |
-
min_near=0.2,
|
79 |
-
density_thresh=0.01,
|
80 |
-
bg_radius=-1,
|
81 |
-
):
|
82 |
-
super().__init__()
|
83 |
-
|
84 |
-
self.bound = bound
|
85 |
-
self.cascade = 1
|
86 |
-
self.grid_size = 128
|
87 |
-
self.density_scale = density_scale
|
88 |
-
self.min_near = min_near
|
89 |
-
self.density_thresh = density_thresh
|
90 |
-
self.bg_radius = bg_radius # radius of the background sphere.
|
91 |
-
|
92 |
-
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
|
93 |
-
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
|
94 |
-
aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
|
95 |
-
aabb_infer = aabb_train.clone()
|
96 |
-
self.register_buffer('aabb_train', aabb_train)
|
97 |
-
self.register_buffer('aabb_infer', aabb_infer)
|
98 |
-
|
99 |
-
# extra state for cuda raymarching
|
100 |
-
self.cuda_ray = cuda_ray
|
101 |
-
if cuda_ray:
|
102 |
-
# density grid
|
103 |
-
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
|
104 |
-
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
|
105 |
-
self.register_buffer('density_grid', density_grid)
|
106 |
-
self.register_buffer('density_bitfield', density_bitfield)
|
107 |
-
self.mean_density = 0
|
108 |
-
self.iter_density = 0
|
109 |
-
# step counter
|
110 |
-
step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
|
111 |
-
self.register_buffer('step_counter', step_counter)
|
112 |
-
self.mean_count = 0
|
113 |
-
self.local_step = 0
|
114 |
-
|
115 |
-
def forward(self, x, d):
|
116 |
-
raise NotImplementedError()
|
117 |
-
|
118 |
-
# separated density and color query (can accelerate non-cuda-ray mode.)
|
119 |
-
def density(self, x):
|
120 |
-
raise NotImplementedError()
|
121 |
-
|
122 |
-
def color(self, x, d, mask=None, **kwargs):
|
123 |
-
raise NotImplementedError()
|
124 |
-
|
125 |
-
def reset_extra_state(self):
|
126 |
-
if not self.cuda_ray:
|
127 |
-
return
|
128 |
-
# density grid
|
129 |
-
self.density_grid.zero_()
|
130 |
-
self.mean_density = 0
|
131 |
-
self.iter_density = 0
|
132 |
-
# step counter
|
133 |
-
self.step_counter.zero_()
|
134 |
-
self.mean_count = 0
|
135 |
-
self.local_step = 0
|
136 |
-
|
137 |
-
def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
|
138 |
-
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
139 |
-
# bg_color: [3] in range [0, 1]
|
140 |
-
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
            #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

        # query SDF and RGB
        density_outputs = self.density(xyzs.reshape(-1, 3))

        #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():

                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

                alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1))  # [N, T]
                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+1]
                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1))  # [N, T+t]
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        mask = weights > 1e-4  # hard coded
        rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        #print(xyzs.shape, 'valid_rgb:', mask.sum().item())

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)  # [N, 2] in [-1, 1]
            bg_color = self.background(sph, rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        # tmp: reg loss in mip-nerf 360
        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()

        return {
            'depth': depth,
            'image': image,
            'weights_sum': weights_sum,
        }

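The compositing above turns per-sample densities into weights via alpha compositing. A minimal, self-contained sketch of that math with toy tensors and plain PyTorch (the function name and inputs are illustrative, not part of the deleted file):

import torch

def composite(sigmas, z_vals, rgbs, density_scale=1.0):
    # sigmas: [N, T], z_vals: [N, T], rgbs: [N, T, 3]
    deltas = z_vals[..., 1:] - z_vals[..., :-1]                   # [N, T-1]
    deltas = torch.cat([deltas, deltas[..., -1:]], dim=-1)        # pad last interval
    alphas = 1 - torch.exp(-deltas * density_scale * sigmas)      # [N, T]
    # transmittance = product of (1 - alpha) over the preceding samples
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1), dim=-1)[..., :-1]
    weights = alphas * trans                                      # [N, T]
    image = (weights.unsqueeze(-1) * rgbs).sum(dim=-2)            # [N, 3]
    depth = (weights * z_vals).sum(dim=-1)                        # [N]
    return image, depth, weights.sum(dim=-1)

# toy check: one nearly opaque sample dominates the output colour
sigmas = torch.tensor([[0.0, 50.0, 0.0]])
z_vals = torch.tensor([[0.5, 1.0, 1.5]])
rgbs   = torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])
print(composite(sigmas, z_vals, rgbs))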
    def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)  # [N, 2] in [-1, 1]
            bg_color = self.background(sph, rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        results = {}

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

            sigmas, rgbs = self(xyzs, dirs)
            sigmas = self.density_scale * sigmas

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
            depth = torch.clamp(depth - nears, min=0) / (fars - nears)
            image = image.view(*prefix, 3)
            depth = depth.view(*prefix)

        else:

            # allocate outputs
            # if use autocast, must init as half so it won't be autocasted and lose reference.
            #dtype = torch.half if torch.is_autocast_enabled() else torch.float32
            # output should always be float32! only network inference uses half.
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:

                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs = self(xyzs, dirs)
                # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
                # sigmas = density_outputs['sigma']
                # rgbs = self.color(xyzs, dirs, **density_outputs)
                sigmas = self.density_scale * sigmas

                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

                rays_alive = rays_alive[rays_alive >= 0]

                #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

                step += n_step

            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
            depth = torch.clamp(depth - nears, min=0) / (fars - nears)
            image = image.view(*prefix, 3)
            depth = depth.view(*prefix)

        results['weights_sum'] = weights_sum
        results['depth'] = depth
        results['image'] = image

        return results

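The inference branch above keeps only the surviving rays and marches more steps per kernel launch as rays terminate (n_step = max(min(N // n_alive, 8), 1)). A tiny illustration of how that step budget grows, with made-up ray counts:

N = 4096
for n_alive in (4096, 2048, 1024, 512, 64):
    n_step = max(min(N // n_alive, 8), 1)
    print(f"n_alive={n_alive:5d} -> march {n_step} steps per launch")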
    @torch.no_grad()
    def mark_untrained_grid(self, poses, intrinsic, S=64):
        # poses: [B, 4, 4]
        # intrinsic: [3, 3]

        if not self.cuda_ray:
            return

        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        B = poses.shape[0]

        fx, fy, cx, cy = intrinsic

        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        count = torch.zeros_like(self.density_grid)
        poses = poses.to(count.device)

        # 5-level loop, forgive me...

        for xs in X:
            for ys in Y:
                for zs in Z:

                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0)  # [1, N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)

                        # split batch to avoid OOM
                        head = 0
                        while head < B:
                            tail = min(head + S, B)

                            # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
                            cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3]  # [S, N, 3]

                            # query if point is covered by any camera
                            mask_z = cam_xyzs[:, :, 2] > 0  # [S, N]
                            mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1)  # [N]

                            # update count
                            count[cas, indices] += mask
                            head += S

        # mark untrained grid as -1
        self.density_grid[count == 0] = -1

        print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')

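The visibility test above projects each grid-cell centre into every camera and keeps cells that fall inside at least one pinhole frustum. A small standalone sketch of that check on toy camera-space points (illustrative helper, hypothetical numbers):

import torch

def covered(cam_xyz, fx, fy, cx, cy, margin):
    # cam_xyz: [..., 3] points already expressed in camera coordinates
    x, y, z = cam_xyz.unbind(-1)
    mask_z = z > 0                           # in front of the camera
    mask_x = x.abs() < cx / fx * z + margin  # inside the horizontal frustum (+ half-cell margin)
    mask_y = y.abs() < cy / fy * z + margin  # inside the vertical frustum
    return mask_z & mask_x & mask_y

pts = torch.tensor([[0.0, 0.0, 2.0],    # straight ahead  -> covered
                    [5.0, 0.0, 2.0],    # far off-axis    -> not covered
                    [0.0, 0.0, -1.0]])  # behind camera   -> not covered
print(covered(pts, fx=280.0, fy=280.0, cx=128.0, cy=128.0, margin=0.05))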
    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        # call before each epoch to update extra states.

        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = - torch.ones_like(self.density_grid)

        # full update.
        if self.iter_density < 16:
        #if True:
            X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

            for xs in X:
                for ys in Y:
                    for zs in Z:

                        # construct points
                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
                        coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                        indices = raymarching.morton3D(coords).long()  # [N]
                        xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]

                        # cascading
                        for cas in range(self.cascade):
                            bound = min(2 ** cas, self.bound)
                            half_grid_size = bound / self.grid_size
                            # scale to current cascade's resolution
                            cas_xyzs = xyzs * (bound - half_grid_size)
                            # add noise in [-hgs, hgs]
                            cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                            # query density
                            sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                            sigmas *= self.density_scale
                            # assign
                            tmp_grid[cas, indices] = sigmas

        # partial update (half the computation)
        # TODO: why no need of maxpool ?
        else:
            N = self.grid_size ** 3 // 4  # H * H * H / 4
            for cas in range(self.cascade):
                # random sample some positions
                coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device)  # [N, 3], in [0, 128)
                indices = raymarching.morton3D(coords).long()  # [N]
                # random sample occupied positions
                occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1)  # [Nz]
                rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device)
                occ_indices = occ_indices[rand_mask]  # [Nz] --> [N], allow for duplication
                occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]
                # concat
                indices = torch.cat([indices, occ_indices], dim=0)
                coords = torch.cat([coords, occ_coords], dim=0)
                # same below
                xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]
                bound = min(2 ** cas, self.bound)
                half_grid_size = bound / self.grid_size
                # scale to current cascade's resolution
                cas_xyzs = xyzs * (bound - half_grid_size)
                # add noise in [-hgs, hgs]
                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                # query density
                sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                sigmas *= self.density_scale
                # assign
                tmp_grid[cas, indices] = sigmas

        ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
        # invalid_mask = tmp_grid < 0
        # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
        # tmp_grid[invalid_mask] = -1

        # ema update
        valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item()  # -1 regions are viewed as 0 density.
        #self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

        #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')

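The tail of update_extra_state combines an exponential-moving-average update of the density grid with a threshold-and-pack step. A sketch of that idea using only torch/numpy in place of the custom raymarching.packbits kernel (grid sizes and values are illustrative):

import numpy as np
import torch

decay, density_thresh = 0.95, 0.01
density_grid = torch.rand(128 ** 3) * 0.1   # current grid (toy values)
new_samples  = torch.rand(128 ** 3) * 0.1   # freshly queried sigmas (toy values)

# EMA-style update: keep the larger of the decayed old value and the new sample
valid = (density_grid >= 0) & (new_samples >= 0)
density_grid[valid] = torch.maximum(density_grid[valid] * decay, new_samples[valid])

# threshold against min(mean density, fixed threshold), then pack 8 cells per byte
mean_density = density_grid.clamp(min=0).mean().item()
thresh = min(mean_density, density_thresh)
occupied = (density_grid > thresh).numpy()
bitfield = np.packbits(occupied)
print(bitfield.shape, occupied.mean())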
    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: pred_rgb: [B, N, 3]

        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        results = _run(rays_o, rays_d, **kwargs)
        return results

class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))

trunc_exp = _trunc_exp.apply

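The clamp in the backward pass above bounds the gradient of exp for large activations. A short demonstration, assuming the trunc_exp defined above is in scope:

import torch

x = torch.tensor([100.0], requires_grad=True)
y = trunc_exp(x)              # forward: exp(100) overflows to inf in float32
y.backward(torch.ones_like(y))
print(y, x.grad)              # gradient uses exp(clamp(100, -15, 15)) = exp(15), which stays finite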
class NGPNetwork(NGPRenderer):
    def __init__(self,
                 num_layers=2,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=3,
                 hidden_dim_color=64,
                 bound=0.5,
                 max_resolution=128,
                 base_resolution=16,
                 n_levels=16,
                 **kwargs
                 ):
        super().__init__(bound, **kwargs)

        # sigma network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.bound = bound

        log2_hashmap_size = 19
        n_features_per_level = 2

        per_level_scale = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )

        self.sigma_net = tcnn.Network(
            n_input_dims=n_levels * 2,
            n_output_dims=1 + self.geo_feat_dim,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim,
                "n_hidden_layers": num_layers - 1,
            },
        )

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color

        self.encoder_dir = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )

        self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim

        self.color_net = tcnn.Network(
            n_input_dims=self.in_dim_color,
            n_output_dims=3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim_color,
                "n_hidden_layers": num_layers_color - 1,
            },
        )
        self.density_scale, self.density_std = 10.0, 0.25

    def forward(self, x, d):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], normalized in [-1, 1]

        # sigma
        x_raw = x
        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]
        x = self.encoder(x)
        h = self.sigma_net(x)

        # sigma = F.relu(h[..., 0])
        density = h[..., 0]
        # add density bias
        dist = torch.norm(x_raw, dim=-1)
        density_bias = (1 - dist / self.density_std) * self.density_scale
        density = density_bias + density
        sigma = F.softplus(density)
        geo_feat = h[..., 1:]

        # color
        d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]
        d = self.encoder_dir(d)

        # p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
        h = torch.cat([d, geo_feat], dim=-1)
        h = self.color_net(h)

        # sigmoid activation for rgb
        color = torch.sigmoid(h)

        return sigma, color

    def density(self, x):
        # x: [N, 3], in [-bound, bound]
        x_raw = x
        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]
        x = self.encoder(x)
        h = self.sigma_net(x)

        # sigma = F.relu(h[..., 0])
        density = h[..., 0]
        # add density bias
        dist = torch.norm(x_raw, dim=-1)
        density_bias = (1 - dist / self.density_std) * self.density_scale
        density = density_bias + density
        sigma = F.softplus(density)
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
        }

    # allow masked inference
    def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        # x: [N, 3] in [-bound, bound]
        # mask: [N,], bool, indicates where we actually need to compute rgb.

        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]

        if mask is not None:
            rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device)  # [N, 3]
            # in case of empty mask
            if not mask.any():
                return rgbs
            x = x[mask]
            d = d[mask]
            geo_feat = geo_feat[mask]

        # color
        d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]
        d = self.encoder_dir(d)

        h = torch.cat([d, geo_feat], dim=-1)
        h = self.color_net(h)

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is not None:
            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        else:
            rgbs = h

        return rgbs
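
The per_level_scale computed in NGPNetwork.__init__ is the geometric growth factor of the hash-grid resolutions: level l has resolution base_resolution * b**l, with b chosen so the last level reaches max_resolution. A small standalone check of that formula (no tcnn dependency):

import numpy as np

base_resolution, max_resolution, n_levels = 16, 128, 16
b = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))
resolutions = [int(np.round(base_resolution * b ** l)) for l in range(n_levels)]
print(b, resolutions)   # first level 16, last level 128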
3drecon/renderer/renderer.py
DELETED
@@ -1,640 +0,0 @@
import abc
import os
from pathlib import Path

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from skimage.io import imread, imsave
from PIL import Image
from torch.optim.lr_scheduler import LambdaLR

from renderer.neus_networks import SDFNetwork, RenderingNetwork, SingleVarianceNetwork, SDFHashGridNetwork, RenderingFFNetwork
from renderer.ngp_renderer import NGPNetwork
from util import instantiate_from_config, read_pickle, concat_images_list

DEFAULT_RADIUS = np.sqrt(3)/2
DEFAULT_SIDE_LENGTH = 0.6

def sample_pdf(bins, weights, n_samples, det=True):
    device = bins.device
    dtype = bins.dtype
    # This implementation is from NeRF
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples, dtype=dtype, device=device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples], dtype=dtype, device=device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples

def near_far_from_sphere(rays_o, rays_d, radius=DEFAULT_RADIUS):
    a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)
    b = torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
    mid = -b / a
    near = mid - radius
    far = mid + radius
    return near, far

class BackgroundRemoval:
    def __init__(self, device='cuda'):
        from carvekit.api.high import HiInterface
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )

    @torch.no_grad()
    def __call__(self, image):
        # image: [H, W, 3] array in [0, 255].
        image = Image.fromarray(image)
        image = self.interface([image])[0]
        image = np.array(image)
        return image


class BaseRenderer(nn.Module):
    def __init__(self, train_batch_num, test_batch_num):
        super().__init__()
        self.train_batch_num = train_batch_num
        self.test_batch_num = test_batch_num

    @abc.abstractmethod
    def render_impl(self, ray_batch, is_train, step):
        pass

    @abc.abstractmethod
    def render_with_loss(self, ray_batch, is_train, step):
        pass

    def render(self, ray_batch, is_train, step):
        batch_num = self.train_batch_num if is_train else self.test_batch_num
        ray_num = ray_batch['rays_o'].shape[0]
        outputs = {}
        for ri in range(0, ray_num, batch_num):
            cur_ray_batch = {}
            for k, v in ray_batch.items():
                cur_ray_batch[k] = v[ri:ri + batch_num]
            cur_outputs = self.render_impl(cur_ray_batch, is_train, step)
            for k, v in cur_outputs.items():
                if k not in outputs: outputs[k] = []
                outputs[k].append(v)

        for k, v in outputs.items():
            outputs[k] = torch.cat(v, 0)
        return outputs

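sample_pdf above is standard inverse-CDF (importance) sampling: bins with larger weights receive proportionally more of the new samples. A miniature standalone version of the same idea, with toy bins and weights (not the file's function, just an illustration):

import torch

bins    = torch.linspace(0.0, 1.0, 6).unsqueeze(0)        # [1, 6] bin edges, width 0.2
weights = torch.tensor([[0.05, 0.05, 0.8, 0.05, 0.05]])   # most mass in the middle bin
pdf = weights / weights.sum(-1, keepdim=True)
cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, -1)], -1)  # [1, 6]
u = torch.rand(1, 1000)
idx = torch.searchsorted(cdf, u, right=True).clamp(1, bins.shape[-1] - 1)
t = (u - cdf.gather(-1, idx - 1)) / pdf.gather(-1, idx - 1)
samples = bins.gather(-1, idx - 1) + t * (bins[..., 1] - bins[..., 0])
print(((samples > 0.4) & (samples < 0.6)).float().mean())  # ~0.8 of samples land in the heavy bin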
class NeuSRenderer(BaseRenderer):
    def __init__(self, train_batch_num, test_batch_num, lambda_eikonal_loss=0.1, use_mask=True,
                 lambda_rgb_loss=1.0, lambda_mask_loss=0.0, rgb_loss='soft_l1', coarse_sn=64, fine_sn=64):
        super().__init__(train_batch_num, test_batch_num)
        self.n_samples = coarse_sn
        self.n_importance = fine_sn
        self.up_sample_steps = 4
        self.anneal_end = 200
        self.use_mask = use_mask
        self.lambda_eikonal_loss = lambda_eikonal_loss
        self.lambda_rgb_loss = lambda_rgb_loss
        self.lambda_mask_loss = lambda_mask_loss
        self.rgb_loss = rgb_loss

        self.sdf_network = SDFNetwork(d_out=257, d_in=3, d_hidden=256, n_layers=8, skip_in=[4], multires=6, bias=0.5, scale=1.0, geometric_init=True, weight_norm=True)
        self.color_network = RenderingNetwork(d_feature=256, d_in=9, d_out=3, d_hidden=256, n_layers=4, weight_norm=True, multires_view=4, squeeze_out=True)
        self.default_dtype = torch.float32
        self.deviation_network = SingleVarianceNetwork(0.3)

    @torch.no_grad()
    def get_vertex_colors(self, vertices):
        """
        @param vertices: n,3
        @return:
        """
        V = vertices.shape[0]
        bn = 20480
        verts_colors = []
        with torch.no_grad():
            for vi in range(0, V, bn):
                verts = torch.from_numpy(vertices[vi:vi+bn].astype(np.float32)).cuda()
                feats = self.sdf_network(verts)[..., 1:]
                gradients = self.sdf_network.gradient(verts)  # ...,3
                gradients = F.normalize(gradients, dim=-1)
                colors = self.color_network(verts, gradients, gradients, feats)
                colors = torch.clamp(colors, min=0, max=1).cpu().numpy()
                verts_colors.append(colors)

        verts_colors = (np.concatenate(verts_colors, 0)*255).astype(np.uint8)
        return verts_colors

    def upsample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
        """
        Up sampling give a fixed inv_s
        """
        device = rays_o.device
        batch_size, n_samples = z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # n_rays, n_samples, 3
        inner_mask = self.get_inner_mask(pts)
        # radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
        inside_sphere = inner_mask[:, :-1] | inner_mask[:, 1:]
        sdf = sdf.reshape(batch_size, n_samples)
        prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
        prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
        mid_sdf = (prev_sdf + next_sdf) * 0.5
        cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)

        prev_cos_val = torch.cat([torch.zeros([batch_size, 1], dtype=self.default_dtype, device=device), cos_val[:, :-1]], dim=-1)
        cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
        cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
        cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere

        dist = (next_z_vals - prev_z_vals)
        prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
        next_esti_sdf = mid_sdf + cos_val * dist * 0.5
        prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
        next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
        alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
        weights = alpha * torch.cumprod(
            torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[:, :-1]

        z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
        return z_samples

    def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
        batch_size, n_samples = z_vals.shape
        _, n_importance = new_z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
        z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
        z_vals, index = torch.sort(z_vals, dim=-1)

        if not last:
            device = pts.device
            new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
            sdf = torch.cat([sdf, new_sdf], dim=-1)
            xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1).to(device)
            index = index.reshape(-1)
            sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)

        return z_vals, sdf

    def sample_depth(self, rays_o, rays_d, near, far, perturb):
        n_samples = self.n_samples
        n_importance = self.n_importance
        up_sample_steps = self.up_sample_steps
        device = rays_o.device

        # sample points
        batch_size = len(rays_o)
        z_vals = torch.linspace(0.0, 1.0, n_samples, dtype=self.default_dtype, device=device)  # sn
        z_vals = near + (far - near) * z_vals[None, :]  # rn,sn

        if perturb > 0:
            t_rand = (torch.rand([batch_size, 1]).to(device) - 0.5)
            z_vals = z_vals + t_rand * 2.0 / n_samples

        # Up sample
        with torch.no_grad():
            pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
            sdf = self.sdf_network.sdf(pts).reshape(batch_size, n_samples)

            for i in range(up_sample_steps):
                rn, sn = z_vals.shape
                inv_s = torch.ones(rn, sn - 1, dtype=self.default_dtype, device=device) * 64 * 2 ** i
                new_z_vals = self.upsample(rays_o, rays_d, z_vals, sdf, n_importance // up_sample_steps, inv_s)
                z_vals, sdf = self.cat_z_vals(rays_o, rays_d, z_vals, new_z_vals, sdf, last=(i + 1 == up_sample_steps))

        return z_vals

    def compute_sdf_alpha(self, points, dists, dirs, cos_anneal_ratio, step):
        # points [...,3] dists [...] dirs[...,3]
        sdf_nn_output = self.sdf_network(points)
        sdf = sdf_nn_output[..., 0]
        feature_vector = sdf_nn_output[..., 1:]

        gradients = self.sdf_network.gradient(points)  # ...,3
        inv_s = self.deviation_network(points).clip(1e-6, 1e6)  # ...,1
        inv_s = inv_s[..., 0]

        true_cos = (dirs * gradients).sum(-1)  # [...]
        iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
                     F.relu(-true_cos) * cos_anneal_ratio)  # always non-positive

        # Estimate signed distances at section points
        estimated_next_sdf = sdf + iter_cos * dists * 0.5
        estimated_prev_sdf = sdf - iter_cos * dists * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)  # [...]
        return alpha, gradients, feature_vector, inv_s, sdf

    def get_anneal_val(self, step):
        if self.anneal_end < 0:
            return 1.0
        else:
            return np.min([1.0, step / self.anneal_end])

    def get_inner_mask(self, points):
        return torch.sum(torch.abs(points) <= DEFAULT_SIDE_LENGTH, -1) == 3

    def render_impl(self, ray_batch, is_train, step):
        near, far = near_far_from_sphere(ray_batch['rays_o'], ray_batch['rays_d'])
        rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
        z_vals = self.sample_depth(rays_o, rays_d, near, far, is_train)

        batch_size, n_samples = z_vals.shape

        # section length in original space
        dists = z_vals[..., 1:] - z_vals[..., :-1]  # rn,sn-1
        dists = torch.cat([dists, dists[..., -1:]], -1)  # rn,sn
        mid_z_vals = z_vals + dists * 0.5

        points = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * mid_z_vals.unsqueeze(-1)  # rn, sn, 3
        inner_mask = self.get_inner_mask(points)

        dirs = rays_d.unsqueeze(-2).expand(batch_size, n_samples, 3)
        dirs = F.normalize(dirs, dim=-1)
        device = rays_o.device
        alpha, sampled_color, gradient_error, normal = torch.zeros(batch_size, n_samples, dtype=self.default_dtype, device=device), \
            torch.zeros(batch_size, n_samples, 3, dtype=self.default_dtype, device=device), \
            torch.zeros([batch_size, n_samples], dtype=self.default_dtype, device=device), \
            torch.zeros([batch_size, n_samples, 3], dtype=self.default_dtype, device=device)
        if torch.sum(inner_mask) > 0:
            cos_anneal_ratio = self.get_anneal_val(step) if is_train else 1.0
            alpha[inner_mask], gradients, feature_vector, inv_s, sdf = self.compute_sdf_alpha(points[inner_mask], dists[inner_mask], dirs[inner_mask], cos_anneal_ratio, step)
            sampled_color[inner_mask] = self.color_network(points[inner_mask], gradients, -dirs[inner_mask], feature_vector)
            # Eikonal loss
            gradient_error[inner_mask] = (torch.linalg.norm(gradients, ord=2, dim=-1) - 1.0) ** 2  # rn,sn
            normal[inner_mask] = F.normalize(gradients, dim=-1)

        weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[..., :-1]  # rn,sn
        mask = torch.sum(weights, dim=1).unsqueeze(-1)  # rn,1
        color = (sampled_color * weights[..., None]).sum(dim=1) + (1 - mask)  # add white background
        normal = (normal * weights[..., None]).sum(dim=1)

        outputs = {
            'rgb': color,  # rn,3
            'gradient_error': gradient_error,  # rn,sn
            'inner_mask': inner_mask,  # rn,sn
            'normal': normal,  # rn,3
            'mask': mask,  # rn,1
        }
        return outputs

    def render_with_loss(self, ray_batch, is_train, step):
        render_outputs = self.render(ray_batch, is_train, step)

        rgb_gt = ray_batch['rgb']
        rgb_pr = render_outputs['rgb']
        if self.rgb_loss == 'soft_l1':
            epsilon = 0.001
            rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
        elif self.rgb_loss == 'mse':
            rgb_loss = F.mse_loss(rgb_pr, rgb_gt, reduction='none')
        else:
            raise NotImplementedError
        rgb_loss = torch.mean(rgb_loss)

        eikonal_loss = torch.sum(render_outputs['gradient_error'] * render_outputs['inner_mask']) / torch.sum(render_outputs['inner_mask'] + 1e-5)
        loss = rgb_loss * self.lambda_rgb_loss + eikonal_loss * self.lambda_eikonal_loss
        loss_batch = {
            'eikonal': eikonal_loss,
            'rendering': rgb_loss,
            # 'mask': mask_loss,
        }
        if self.lambda_mask_loss > 0 and self.use_mask:
            mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none').mean()
            loss += mask_loss * self.lambda_mask_loss
            loss_batch['mask'] = mask_loss
        return loss, loss_batch

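The core of compute_sdf_alpha above is the NeuS conversion from signed distance to opacity: alpha is the relative drop of a sigmoid CDF of the SDF across each ray section, so it peaks where the SDF crosses zero. A toy one-ray sketch of that formula (made-up numbers, plain PyTorch):

import torch

inv_s    = torch.tensor(64.0)                               # sharpness (deviation network output)
sdf      = torch.tensor([0.30, 0.15, 0.02, -0.12, -0.30])   # SDF at section mid-points along a ray
dists    = torch.full_like(sdf, 0.15)                       # section lengths
iter_cos = torch.full_like(sdf, -1.0)                       # ray roughly anti-parallel to the normal

est_next_sdf = sdf + iter_cos * dists * 0.5
est_prev_sdf = sdf - iter_cos * dists * 0.5
prev_cdf = torch.sigmoid(est_prev_sdf * inv_s)
next_cdf = torch.sigmoid(est_next_sdf * inv_s)
alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)
print(alpha)   # near zero in front of the surface, rising sharply at the zero crossing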
class NeRFRenderer(BaseRenderer):
    def __init__(self, train_batch_num, test_batch_num, bound=0.5, use_mask=False, lambda_rgb_loss=1.0, lambda_mask_loss=0.0):
        super().__init__(train_batch_num, test_batch_num)
        self.train_batch_num = train_batch_num
        self.test_batch_num = test_batch_num
        self.use_mask = use_mask
        self.field = NGPNetwork(bound=bound)

        self.update_interval = 16
        self.fp16 = True
        self.lambda_rgb_loss = lambda_rgb_loss
        self.lambda_mask_loss = lambda_mask_loss

    def render_impl(self, ray_batch, is_train, step):
        rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
        with torch.cuda.amp.autocast(enabled=self.fp16):
            if step % self.update_interval == 0:
                self.field.update_extra_state()

            outputs = self.field.render(rays_o, rays_d,)

        renderings = {
            'rgb': outputs['image'],
            'depth': outputs['depth'],
            'mask': outputs['weights_sum'].unsqueeze(-1),
        }
        return renderings

    def render_with_loss(self, ray_batch, is_train, step):
        render_outputs = self.render(ray_batch, is_train, step)

        rgb_gt = ray_batch['rgb']
        rgb_pr = render_outputs['rgb']
        epsilon = 0.001
        rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
        rgb_loss = torch.mean(rgb_loss)
        loss = rgb_loss * self.lambda_rgb_loss
        loss_batch = {'rendering': rgb_loss}

        if self.use_mask:
            mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none')
            mask_loss = torch.mean(mask_loss)
            loss = loss + mask_loss * self.lambda_mask_loss
            loss_batch['mask'] = mask_loss
        return loss, loss_batch

def cartesian_to_spherical(xyz):
    ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
    xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
    z = np.sqrt(xy + xyz[:, 2] ** 2)
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # for elevation angle defined from Z-axis down
    # ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, z])

def get_pose(target_RT):
    R, T = target_RT[:3, :3], target_RT[:, -1]
    T_target = -R.T @ T
    theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
    return theta_target, azimuth_target, z_target

class RendererTrainer(pl.LightningModule):
    def __init__(self, image_path, data_path, total_steps, warm_up_steps, log_dir, train_batch_fg_num=0,
                 use_cube_feats=False, cube_ckpt=None, cube_cfg=None, cube_bound=0.5,
                 train_batch_num=4096, test_batch_num=8192, use_warm_up=True, use_mask=True,
                 lambda_rgb_loss=1.0, lambda_mask_loss=0.0, renderer='neus',
                 # used in neus
                 lambda_eikonal_loss=0.1,
                 coarse_sn=64, fine_sn=64):
        super().__init__()
        self.num_images = 36  # todo ours 36, syncdreamer 16
        self.image_size = 256
        self.log_dir = log_dir
        (Path(log_dir)/'images').mkdir(exist_ok=True, parents=True)
        self.train_batch_num = train_batch_num
        self.train_batch_fg_num = train_batch_fg_num
        self.test_batch_num = test_batch_num
        self.image_path = image_path
        self.data_path = data_path
        self.total_steps = total_steps
        self.warm_up_steps = warm_up_steps
        self.use_mask = use_mask
        self.lambda_eikonal_loss = lambda_eikonal_loss
        self.lambda_rgb_loss = lambda_rgb_loss
        self.lambda_mask_loss = lambda_mask_loss
        self.use_warm_up = use_warm_up

        self.use_cube_feats, self.cube_cfg, self.cube_ckpt = use_cube_feats, cube_cfg, cube_ckpt

        self._init_dataset()
        if renderer == 'neus':
            self.renderer = NeuSRenderer(train_batch_num, test_batch_num,
                                         lambda_rgb_loss=lambda_rgb_loss,
                                         lambda_eikonal_loss=lambda_eikonal_loss,
                                         lambda_mask_loss=lambda_mask_loss,
                                         coarse_sn=coarse_sn, fine_sn=fine_sn)
        elif renderer == 'ngp':
            self.renderer = NeRFRenderer(train_batch_num, test_batch_num, bound=cube_bound, use_mask=use_mask, lambda_mask_loss=lambda_mask_loss, lambda_rgb_loss=lambda_rgb_loss,)
        else:
            raise NotImplementedError
        self.validation_index = 0

    def _construct_ray_batch(self, images_info):
        image_num = images_info['images'].shape[0]
        _, h, w, _ = images_info['images'].shape
        coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)]  # h,w,2
        coords = coords.float()[None, :, :, :].repeat(image_num, 1, 1, 1)  # imn,h,w,2
        coords = coords.reshape(image_num, h * w, 2)
        coords = torch.cat([coords, torch.ones(image_num, h * w, 1, dtype=torch.float32)], 2)  # imn,h*w,3

        # imn,h*w,3 @ imn,3,3 => imn,h*w,3
        rays_d = coords @ torch.inverse(images_info['Ks']).permute(0, 2, 1)
        poses = images_info['poses']  # imn,3,4
        R, t = poses[:, :, :3], poses[:, :, 3:]
        rays_d = rays_d @ R
        rays_d = F.normalize(rays_d, dim=-1)
        rays_o = -R.permute(0, 2, 1) @ t  # imn,3,3 @ imn,3,1
        rays_o = rays_o.permute(0, 2, 1).repeat(1, h*w, 1)  # imn,h*w,3

        ray_batch = {
            'rgb': images_info['images'].reshape(image_num*h*w, 3),
            'mask': images_info['masks'].reshape(image_num*h*w, 1),
            'rays_o': rays_o.reshape(image_num*h*w, 3).float(),
            'rays_d': rays_d.reshape(image_num*h*w, 3).float(),
        }
        return ray_batch

    @staticmethod
    def load_model(cfg, ckpt):
        config = OmegaConf.load(cfg)
        model = instantiate_from_config(config.model)
        print(f'loading model from {ckpt} ...')
        ckpt = torch.load(ckpt)
        model.load_state_dict(ckpt['state_dict'])
        model = model.cuda().eval()
        return model

    def _init_dataset(self):
        mask_predictor = BackgroundRemoval()
        # syncdreamer fixed 16 views
        # self.K, self.azs, self.els, self.dists, self.poses = read_pickle(f'meta_info/camera-{self.num_images}.pkl')
        # for ours+NeuS, we pre fix 36 views
        self.K = np.array([[280., 0., 128.], [0., 280., 128.], [0., 0., 1.]], dtype=np.float32)
        data_dir = os.path.join(self.data_path, "mario/render_sync_36_single/model/")  # fixed 36 views
        # get all files .npy
        self.azs = []
        self.els = []
        self.dists = []
        self.poses = []
        for index in range(self.num_images):
            pose = np.load(os.path.join(data_dir, "%03d.npy" % index))[:3, :]  # in blender
            self.poses.append(pose)
            theta, azimuth, radius = get_pose(pose)
            self.azs.append(azimuth)
            self.els.append(theta)
            self.dists.append(radius)
        # stack to numpy along axis 0
        self.azs = np.stack(self.azs, axis=0)      # [25,]
        self.els = np.stack(self.els, axis=0)      # [25,]
        self.dists = np.stack(self.dists, axis=0)  # [25,]
        self.poses = np.stack(self.poses, axis=0)  # [25, 3, 4]

        self.images_info = {'images': [], 'masks': [], 'Ks': [], 'poses': []}

        img = imread(self.image_path)

        for index in range(self.num_images):
            rgb = np.copy(img[:, index*self.image_size:(index+1)*self.image_size, :])
            # predict mask
            if self.use_mask:
                imsave(f'{self.log_dir}/input-{index}.png', rgb)
                masked_image = mask_predictor(rgb)
                imsave(f'{self.log_dir}/masked-{index}.png', masked_image)
                mask = masked_image[:, :, 3].astype(np.float32)/255
            else:
                h, w, _ = rgb.shape
                mask = np.zeros([h, w], np.float32)

            rgb = rgb.astype(np.float32)/255
            K, pose = np.copy(self.K), self.poses[index]
            self.images_info['images'].append(torch.from_numpy(rgb.astype(np.float32)))  # h,w,3
            self.images_info['masks'].append(torch.from_numpy(mask.astype(np.float32)))  # h,w
            self.images_info['Ks'].append(torch.from_numpy(K.astype(np.float32)))
            self.images_info['poses'].append(torch.from_numpy(pose.astype(np.float32)))

        for k, v in self.images_info.items(): self.images_info[k] = torch.stack(v, 0)  # stack all values

        self.train_batch = self._construct_ray_batch(self.images_info)
        self.train_batch_pseudo_fg = {}
        pseudo_fg_mask = torch.sum(self.train_batch['rgb'] > 0.99, 1) != 3
        for k, v in self.train_batch.items():
            self.train_batch_pseudo_fg[k] = v[pseudo_fg_mask]
        self.train_ray_fg_num = int(torch.sum(pseudo_fg_mask).cpu().numpy())
        self.train_ray_num = self.num_images * self.image_size ** 2
        self._shuffle_train_batch()
        self._shuffle_train_fg_batch()

    def _shuffle_train_batch(self):
        self.train_batch_i = 0
        shuffle_idxs = torch.randperm(self.train_ray_num, device='cpu')  # shuffle
        for k, v in self.train_batch.items():
            self.train_batch[k] = v[shuffle_idxs]

    def _shuffle_train_fg_batch(self):
        self.train_batch_fg_i = 0
        shuffle_idxs = torch.randperm(self.train_ray_fg_num, device='cpu')  # shuffle
        for k, v in self.train_batch_pseudo_fg.items():
            self.train_batch_pseudo_fg[k] = v[shuffle_idxs]


    def training_step(self, batch, batch_idx):
        train_ray_batch = {k: v[self.train_batch_i:self.train_batch_i + self.train_batch_num].cuda() for k, v in self.train_batch.items()}
        self.train_batch_i += self.train_batch_num
        if self.train_batch_i + self.train_batch_num >= self.train_ray_num: self._shuffle_train_batch()

        if self.train_batch_fg_num > 0:
            train_ray_batch_fg = {k: v[self.train_batch_fg_i:self.train_batch_fg_i+self.train_batch_fg_num].cuda() for k, v in self.train_batch_pseudo_fg.items()}
            self.train_batch_fg_i += self.train_batch_fg_num
            if self.train_batch_fg_i + self.train_batch_fg_num >= self.train_ray_fg_num: self._shuffle_train_fg_batch()
            for k, v in train_ray_batch_fg.items():
                train_ray_batch[k] = torch.cat([train_ray_batch[k], v], 0)

        loss, loss_batch = self.renderer.render_with_loss(train_ray_batch, is_train=True, step=self.global_step)
        self.log_dict(loss_batch, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)

        self.log('step', self.global_step, prog_bar=True, on_step=True, on_epoch=False, logger=False, rank_zero_only=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
        return loss

    def _slice_images_info(self, index):
        return {k: v[index:index+1] for k, v in self.images_info.items()}

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            if self.global_rank == 0:
                # we output an rendering image
                images_info = self._slice_images_info(self.validation_index)
                self.validation_index += 1
                self.validation_index %= self.num_images

                test_ray_batch = self._construct_ray_batch(images_info)
                test_ray_batch = {k: v.cuda() for k, v in test_ray_batch.items()}
                test_ray_batch['near'], test_ray_batch['far'] = near_far_from_sphere(test_ray_batch['rays_o'], test_ray_batch['rays_d'])
                render_outputs = self.renderer.render(test_ray_batch, False, self.global_step)

                process = lambda x: (x.cpu().numpy() * 255).astype(np.uint8)
                h, w = self.image_size, self.image_size
                rgb = torch.clamp(render_outputs['rgb'].reshape(h, w, 3), max=1.0, min=0.0)
                mask = torch.clamp(render_outputs['mask'].reshape(h, w, 1), max=1.0, min=0.0)
                mask_ = torch.repeat_interleave(mask, 3, dim=-1)
                output_image = concat_images_list(process(rgb), process(mask_))
                if 'normal' in render_outputs:
                    normal = torch.clamp((render_outputs['normal'].reshape(h, w, 3) + 1) / 2, max=1.0, min=0.0)
                    normal = normal * mask  # we only show foregound normal
                    output_image = concat_images_list(output_image, process(normal))

                # save images
                imsave(f'{self.log_dir}/images/{self.global_step}.jpg', output_image)

    def configure_optimizers(self):
        lr = self.learning_rate
        opt = torch.optim.AdamW([{"params": self.renderer.parameters(), "lr": lr},], lr=lr)

        def schedule_fn(step):
            total_step = self.total_steps
            warm_up_step = self.warm_up_steps
            warm_up_init = 0.02
            warm_up_end = 1.0
            final_lr = 0.02
            interval = 1000
            times = total_step // interval
            ratio = np.power(final_lr, 1/times)
            if step < warm_up_step:
                learning_rate = (step / warm_up_step) * (warm_up_end - warm_up_init) + warm_up_init
            else:
                learning_rate = ratio ** (step // interval) * warm_up_end
            return learning_rate

        if self.use_warm_up:
            scheduler = [{
                'scheduler': LambdaLR(opt, lr_lambda=schedule_fn),
                'interval': 'step',
                'frequency': 1
            }]
        else:
            scheduler = []
        return [opt], scheduler

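schedule_fn above produces a linear warm-up followed by a stepwise geometric decay chosen so the multiplier ends at final_lr. A standalone sketch of the same schedule with illustrative parameters (2k warm-up steps, 100k total, decay every 1k steps):

import numpy as np

total_step, warm_up_step, interval = 100_000, 2_000, 1_000
warm_up_init, warm_up_end, final_lr = 0.02, 1.0, 0.02
ratio = np.power(final_lr, 1 / (total_step // interval))

def schedule_fn(step):
    if step < warm_up_step:
        return (step / warm_up_step) * (warm_up_end - warm_up_init) + warm_up_init
    return ratio ** (step // interval) * warm_up_end

for s in (0, 1_000, 2_000, 10_000, 50_000, 100_000):
    print(s, round(schedule_fn(s), 4))   # ramps to 1.0, then decays toward 0.02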
3drecon/run_NeuS.py
DELETED
@@ -1,32 +0,0 @@
import os
import numpy as np
from tqdm import tqdm

# ours + NeuS
DATA_DIR = "/home/xin/data/EscherNet/Data/GSO30"  # GSO
exp_dir = "/home/xin/6DoF/GSO3D/"

config = "configs/neus_36.yaml"
exps = [1]
# exps = [1, 2, 3, 5, 10]

for exp in exps:
    OUTPUT_DIR = os.path.join(exp_dir, f"logs_GSO_T{exp}M36_99k")
    output_NeuS = f"ours_GSO_T{exp}"
    os.makedirs(output_NeuS, exist_ok=True)
    obj_names = os.listdir(DATA_DIR)
    for obj_name in tqdm(obj_names):
        if os.path.exists(os.path.join(output_NeuS, "NeuS", obj_name, "mesh.ply")):
            print("NeuS already trained for: ", obj_name)
            continue
        # remove the folder for new training
        os.system(f"rm -rf {output_NeuS}/NeuS/{obj_name}")
        print("Training NeuS for: ", obj_name)
        input_img = os.path.join(OUTPUT_DIR, obj_name, "0.png")
        # input_img = os.path.join(OUTPUT_DIR, obj_name, "gt.png") # ground truth image
        cmd = f"python train_renderer.py -i {input_img} \
                -d {DATA_DIR} \
                -n {obj_name} \
                -b {config} \
                -l {output_NeuS}/NeuS"
        os.system(cmd)

|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/train_renderer.py
DELETED
@@ -1,188 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
|
3 |
-
import imageio
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
import torch.nn.functional as F
|
7 |
-
from pathlib import Path
|
8 |
-
|
9 |
-
import trimesh
|
10 |
-
from omegaconf import OmegaConf
|
11 |
-
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
|
12 |
-
from pytorch_lightning.loggers import TensorBoardLogger
|
13 |
-
from pytorch_lightning import Trainer
|
14 |
-
from skimage.io import imsave
|
15 |
-
from tqdm import tqdm
|
16 |
-
|
17 |
-
import mcubes
|
18 |
-
|
19 |
-
from renderer.renderer import NeuSRenderer, DEFAULT_SIDE_LENGTH
|
20 |
-
from util import instantiate_from_config, read_pickle
|
21 |
-
|
22 |
-
class ResumeCallBacks(Callback):
|
23 |
-
def __init__(self):
|
24 |
-
pass
|
25 |
-
|
26 |
-
def on_train_start(self, trainer, pl_module):
|
27 |
-
pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups
|
28 |
-
|
29 |
-
def render_images(model, output,):
|
30 |
-
# render from model
|
31 |
-
n = 180
|
32 |
-
azimuths = (np.arange(n) / n * np.pi * 2).astype(np.float32)
|
33 |
-
elevations = np.deg2rad(np.asarray([30] * n).astype(np.float32))
|
34 |
-
K, _, _, _, poses = read_pickle(f'meta_info/camera-16.pkl')
|
35 |
-
output_points
|
36 |
-
h, w = 256, 256
|
37 |
-
default_size = 256
|
38 |
-
K = np.diag([w/default_size,h/default_size,1.0]) @ K
|
39 |
-
imgs = []
|
40 |
-
for ni in tqdm(range(n)):
|
41 |
-
# R = euler2mat(azimuths[ni], elevations[ni], 0, 'szyx')
|
42 |
-
# R = np.asarray([[0,-1,0],[0,0,-1],[1,0,0]]) @ R
|
43 |
-
e, a = elevations[ni], azimuths[ni]
|
44 |
-
row1 = np.asarray([np.sin(e)*np.cos(a),np.sin(e)*np.sin(a),-np.cos(e)])
|
45 |
-
row0 = np.asarray([-np.sin(a),np.cos(a), 0])
|
46 |
-
row2 = np.cross(row0, row1)
|
47 |
-
R = np.stack([row0,row1,row2],0)
|
48 |
-
t = np.asarray([0,0,1.5])
|
49 |
-
pose = np.concatenate([R,t[:,None]],1)
|
50 |
-
pose_ = torch.from_numpy(pose.astype(np.float32)).unsqueeze(0)
|
51 |
-
K_ = torch.from_numpy(K.astype(np.float32)).unsqueeze(0) # [1,3,3]
|
52 |
-
|
53 |
-
coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
|
54 |
-
coords = coords.float()[None, :, :, :].repeat(1, 1, 1, 1) # imn,h,w,2
|
55 |
-
coords = coords.reshape(1, h * w, 2)
|
56 |
-
coords = torch.cat([coords, torch.ones(1, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
|
57 |
-
|
58 |
-
# imn,h*w,3 @ imn,3,3 => imn,h*w,3
|
59 |
-
rays_d = coords @ torch.inverse(K_).permute(0, 2, 1)
|
60 |
-
R, t = pose_[:, :, :3], pose_[:, :, 3:]
|
61 |
-
rays_d = rays_d @ R
|
62 |
-
rays_d = F.normalize(rays_d, dim=-1)
|
63 |
-
rays_o = -R.permute(0, 2, 1) @ t # imn,3,3 @ imn,3,1
|
64 |
-
rays_o = rays_o.permute(0, 2, 1).repeat(1, h * w, 1) # imn,h*w,3
|
65 |
-
|
66 |
-
ray_batch = {
|
67 |
-
'rays_o': rays_o.reshape(-1,3).cuda(),
|
68 |
-
'rays_d': rays_d.reshape(-1,3).cuda(),
|
69 |
-
}
|
70 |
-
with torch.no_grad():
|
71 |
-
image = model.renderer.render(ray_batch,False,5000)['rgb'].reshape(h,w,3)
|
72 |
-
image = (image.cpu().numpy() * 255).astype(np.uint8)
|
73 |
-
imgs.append(image)
|
74 |
-
|
75 |
-
imageio.mimsave(f'{output}/rendering.mp4', imgs, fps=30)
|
76 |
-
|
77 |
-
def extract_fields(bound_min, bound_max, resolution, query_func, batch_size=64, outside_val=1.0):
|
78 |
-
N = batch_size
|
79 |
-
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
80 |
-
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
81 |
-
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
82 |
-
|
83 |
-
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
84 |
-
with torch.no_grad():
|
85 |
-
for xi, xs in enumerate(X):
|
86 |
-
for yi, ys in enumerate(Y):
|
87 |
-
for zi, zs in enumerate(Z):
|
88 |
-
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
89 |
-
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda()
|
90 |
-
val = query_func(pts).detach()
|
91 |
-
outside_mask = torch.norm(pts,dim=-1)>=1.0
|
92 |
-
val[outside_mask]=outside_val
|
93 |
-
val = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
|
94 |
-
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
95 |
-
return u
|
96 |
-
|
97 |
-
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, color_func, outside_val=1.0):
|
98 |
-
u = extract_fields(bound_min, bound_max, resolution, query_func, outside_val=outside_val)
|
99 |
-
vertices, triangles = mcubes.marching_cubes(u, threshold)
|
100 |
-
b_max_np = bound_max.detach().cpu().numpy()
|
101 |
-
b_min_np = bound_min.detach().cpu().numpy()
|
102 |
-
|
103 |
-
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
104 |
-
vertex_colors = color_func(vertices)
|
105 |
-
return vertices, triangles, vertex_colors
|
106 |
-
|
107 |
-
def extract_mesh(model, output, resolution=512):
|
108 |
-
if not isinstance(model.renderer, NeuSRenderer): return
|
109 |
-
bbox_min = -torch.ones(3)*DEFAULT_SIDE_LENGTH
|
110 |
-
bbox_max = torch.ones(3)*DEFAULT_SIDE_LENGTH
|
111 |
-
with torch.no_grad():
|
112 |
-
vertices, triangles, vertex_colors = extract_geometry(bbox_min, bbox_max, resolution, 0, lambda x: model.renderer.sdf_network.sdf(x), lambda x: model.renderer.get_vertex_colors(x))
|
113 |
-
|
114 |
-
# output geometry
|
115 |
-
mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
|
116 |
-
mesh.export(str(f'{output}/mesh.ply'))
|
117 |
-
|
118 |
-
def main():
|
119 |
-
parser = argparse.ArgumentParser()
|
120 |
-
parser.add_argument('-i', '--image_path', type=str, required=True)
|
121 |
-
parser.add_argument('-n', '--name', type=str, required=True)
|
122 |
-
parser.add_argument('-b', '--base', type=str, default='configs/neus.yaml')
|
123 |
-
parser.add_argument('-d', '--data_path', type=str, default='/data/GSO/')
|
124 |
-
parser.add_argument('-l', '--log', type=str, default='output/renderer')
|
125 |
-
parser.add_argument('-s', '--seed', type=int, default=6033)
|
126 |
-
parser.add_argument('-g', '--gpus', type=str, default='0,')
|
127 |
-
parser.add_argument('-r', '--resume', action='store_true', default=False, dest='resume')
|
128 |
-
parser.add_argument('--fp16', action='store_true', default=False, dest='fp16')
|
129 |
-
opt = parser.parse_args()
|
130 |
-
# seed_everything(opt.seed)
|
131 |
-
|
132 |
-
# configs
|
133 |
-
cfg = OmegaConf.load(opt.base)
|
134 |
-
name = opt.name
|
135 |
-
log_dir, ckpt_dir = Path(opt.log) / name, Path(opt.log) / name / 'ckpt'
|
136 |
-
cfg.model.params['image_path'] = opt.image_path
|
137 |
-
cfg.model.params['log_dir'] = log_dir
|
138 |
-
cfg.model.params['data_path'] = opt.data_path
|
139 |
-
|
140 |
-
# setup
|
141 |
-
log_dir.mkdir(exist_ok=True, parents=True)
|
142 |
-
ckpt_dir.mkdir(exist_ok=True, parents=True)
|
143 |
-
trainer_config = cfg.trainer
|
144 |
-
callback_config = cfg.callbacks
|
145 |
-
model_config = cfg.model
|
146 |
-
data_config = cfg.data
|
147 |
-
|
148 |
-
data_config.params.seed = opt.seed
|
149 |
-
data = instantiate_from_config(data_config)
|
150 |
-
data.prepare_data()
|
151 |
-
data.setup('fit')
|
152 |
-
|
153 |
-
model = instantiate_from_config(model_config,)
|
154 |
-
model.cpu()
|
155 |
-
model.learning_rate = model_config.base_lr
|
156 |
-
|
157 |
-
# logger
|
158 |
-
logger = TensorBoardLogger(save_dir=log_dir, name='tensorboard_logs')
|
159 |
-
callbacks=[]
|
160 |
-
callbacks.append(LearningRateMonitor(logging_interval='step'))
|
161 |
-
callbacks.append(ModelCheckpoint(dirpath=ckpt_dir, filename="{epoch:06}", verbose=True, save_last=True, every_n_train_steps=callback_config.save_interval))
|
162 |
-
|
163 |
-
# trainer
|
164 |
-
trainer_config.update({
|
165 |
-
"accelerator": "cuda", "check_val_every_n_epoch": None,
|
166 |
-
"benchmark": True, "num_sanity_val_steps": 0,
|
167 |
-
"devices": 1, "gpus": opt.gpus,
|
168 |
-
})
|
169 |
-
if opt.fp16:
|
170 |
-
trainer_config['precision']=16
|
171 |
-
|
172 |
-
if opt.resume:
|
173 |
-
callbacks.append(ResumeCallBacks())
|
174 |
-
trainer_config['resume_from_checkpoint'] = str(ckpt_dir / 'last.ckpt')
|
175 |
-
else:
|
176 |
-
if (ckpt_dir / 'last.ckpt').exists():
|
177 |
-
raise RuntimeError(f"checkpoint {ckpt_dir / 'last.ckpt'} existing ...")
|
178 |
-
trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config, logger=logger, callbacks=callbacks)
|
179 |
-
|
180 |
-
trainer.fit(model, data)
|
181 |
-
|
182 |
-
model = model.cuda().eval()
|
183 |
-
|
184 |
-
# render_images(model, log_dir)
|
185 |
-
extract_mesh(model, log_dir)
|
186 |
-
|
187 |
-
if __name__=="__main__":
|
188 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3drecon/util.py
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
import importlib
|
2 |
-
import pickle
|
3 |
-
import numpy as np
|
4 |
-
import cv2
|
5 |
-
|
6 |
-
def instantiate_from_config(config):
|
7 |
-
if not "target" in config:
|
8 |
-
if config == '__is_first_stage__':
|
9 |
-
return None
|
10 |
-
elif config == "__is_unconditional__":
|
11 |
-
return None
|
12 |
-
raise KeyError("Expected key `target` to instantiate.")
|
13 |
-
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
14 |
-
|
15 |
-
|
16 |
-
def get_obj_from_str(string, reload=False):
|
17 |
-
module, cls = string.rsplit(".", 1)
|
18 |
-
if reload:
|
19 |
-
module_imp = importlib.import_module(module)
|
20 |
-
importlib.reload(module_imp)
|
21 |
-
return getattr(importlib.import_module(module, package=None), cls)
|
22 |
-
|
23 |
-
def read_pickle(pkl_path):
|
24 |
-
with open(pkl_path, 'rb') as f:
|
25 |
-
return pickle.load(f)
|
26 |
-
|
27 |
-
def output_points(fn,pts,colors=None):
|
28 |
-
with open(fn, 'w') as f:
|
29 |
-
for pi, pt in enumerate(pts):
|
30 |
-
f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
|
31 |
-
if colors is not None:
|
32 |
-
f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}')
|
33 |
-
f.write('\n')
|
34 |
-
|
35 |
-
def concat_images(img0,img1,vert=False):
|
36 |
-
if not vert:
|
37 |
-
h0,h1=img0.shape[0],img1.shape[0],
|
38 |
-
if h0<h1: img0=cv2.copyMakeBorder(img0,0,h1-h0,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
|
39 |
-
if h1<h0: img1=cv2.copyMakeBorder(img1,0,h0-h1,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
|
40 |
-
img = np.concatenate([img0, img1], axis=1)
|
41 |
-
else:
|
42 |
-
w0,w1=img0.shape[1],img1.shape[1]
|
43 |
-
if w0<w1: img0=cv2.copyMakeBorder(img0,0,0,0,w1-w0,borderType=cv2.BORDER_CONSTANT,value=0)
|
44 |
-
if w1<w0: img1=cv2.copyMakeBorder(img1,0,0,0,w0-w1,borderType=cv2.BORDER_CONSTANT,value=0)
|
45 |
-
img = np.concatenate([img0, img1], axis=0)
|
46 |
-
|
47 |
-
return img
|
48 |
-
|
49 |
-
def concat_images_list(*args,vert=False):
|
50 |
-
if len(args)==1: return args[0]
|
51 |
-
img_out=args[0]
|
52 |
-
for img in args[1:]:
|
53 |
-
img_out=concat_images(img_out,img,vert)
|
54 |
-
return img_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/CN_encoder.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
from transformers import ConvNextV2Model
|
2 |
-
import torch
|
3 |
-
from typing import Optional
|
4 |
-
import einops
|
5 |
-
|
6 |
-
class CN_encoder(ConvNextV2Model):
|
7 |
-
def __init__(self, config):
|
8 |
-
super().__init__(config)
|
9 |
-
|
10 |
-
def forward(
|
11 |
-
self,
|
12 |
-
pixel_values: torch.FloatTensor = None,
|
13 |
-
output_hidden_states: Optional[bool] = None,
|
14 |
-
return_dict: Optional[bool] = None,
|
15 |
-
):
|
16 |
-
output_hidden_states = (
|
17 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
18 |
-
)
|
19 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
20 |
-
|
21 |
-
if pixel_values is None:
|
22 |
-
raise ValueError("You have to specify pixel_values")
|
23 |
-
|
24 |
-
embedding_output = self.embeddings(pixel_values)
|
25 |
-
|
26 |
-
encoder_outputs = self.encoder(
|
27 |
-
embedding_output,
|
28 |
-
output_hidden_states=output_hidden_states,
|
29 |
-
return_dict=return_dict,
|
30 |
-
)
|
31 |
-
|
32 |
-
last_hidden_state = encoder_outputs[0]
|
33 |
-
image_embeddings = einops.rearrange(last_hidden_state, 'b c h w -> b (h w) c')
|
34 |
-
image_embeddings = self.layernorm(image_embeddings)
|
35 |
-
|
36 |
-
return image_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/dataset.py
DELETED
@@ -1,228 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import math
|
3 |
-
from pathlib import Path
|
4 |
-
import torch
|
5 |
-
import torchvision
|
6 |
-
from torch.utils.data import Dataset, DataLoader
|
7 |
-
from torchvision import transforms
|
8 |
-
from PIL import Image
|
9 |
-
import numpy as np
|
10 |
-
import webdataset as wds
|
11 |
-
from torch.utils.data.distributed import DistributedSampler
|
12 |
-
import matplotlib.pyplot as plt
|
13 |
-
import sys
|
14 |
-
|
15 |
-
class ObjaverseDataLoader():
|
16 |
-
def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
|
17 |
-
self.root_dir = root_dir
|
18 |
-
self.batch_size = batch_size
|
19 |
-
self.num_workers = num_workers
|
20 |
-
self.total_view = total_view
|
21 |
-
|
22 |
-
image_transforms = [torchvision.transforms.Resize((256, 256)),
|
23 |
-
transforms.ToTensor(),
|
24 |
-
transforms.Normalize([0.5], [0.5])]
|
25 |
-
self.image_transforms = torchvision.transforms.Compose(image_transforms)
|
26 |
-
|
27 |
-
def train_dataloader(self):
|
28 |
-
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
|
29 |
-
image_transforms=self.image_transforms)
|
30 |
-
# sampler = DistributedSampler(dataset)
|
31 |
-
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
32 |
-
# sampler=sampler)
|
33 |
-
|
34 |
-
def val_dataloader(self):
|
35 |
-
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
|
36 |
-
image_transforms=self.image_transforms)
|
37 |
-
sampler = DistributedSampler(dataset)
|
38 |
-
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
39 |
-
|
40 |
-
def cartesian_to_spherical(xyz):
|
41 |
-
ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
|
42 |
-
xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
|
43 |
-
z = np.sqrt(xy + xyz[:, 2] ** 2)
|
44 |
-
theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
|
45 |
-
# ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
|
46 |
-
azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
|
47 |
-
return np.array([theta, azimuth, z])
|
48 |
-
|
49 |
-
def get_pose(target_RT):
|
50 |
-
target_RT = target_RT[:3, :]
|
51 |
-
R, T = target_RT[:3, :3], target_RT[:, -1]
|
52 |
-
T_target = -R.T @ T
|
53 |
-
theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
|
54 |
-
# assert if z_target is out of range
|
55 |
-
if z_target.item() < 1.5 or z_target.item() > 2.2:
|
56 |
-
# print('z_target out of range 1.5-2.2', z_target.item())
|
57 |
-
z_target = np.clip(z_target.item(), 1.5, 2.2)
|
58 |
-
# with log scale for radius
|
59 |
-
target_T = torch.tensor([theta_target.item(), azimuth_target.item(), (np.log(z_target.item()) - np.log(1.5))/(np.log(2.2)-np.log(1.5)) * torch.pi, torch.tensor(0)])
|
60 |
-
assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
|
61 |
-
return target_T.numpy()
|
62 |
-
|
63 |
-
class ObjaverseData(Dataset):
|
64 |
-
def __init__(self,
|
65 |
-
root_dir='.objaverse/hf-objaverse-v1/views',
|
66 |
-
image_transforms=None,
|
67 |
-
total_view=12,
|
68 |
-
validation=False,
|
69 |
-
T_in=1,
|
70 |
-
T_out=1,
|
71 |
-
fix_sample=False,
|
72 |
-
) -> None:
|
73 |
-
"""Create a dataset from a folder of images.
|
74 |
-
If you pass in a root directory it will be searched for images
|
75 |
-
ending in ext (ext can be a list)
|
76 |
-
"""
|
77 |
-
self.root_dir = Path(root_dir)
|
78 |
-
self.total_view = total_view
|
79 |
-
self.T_in = T_in
|
80 |
-
self.T_out = T_out
|
81 |
-
self.fix_sample = fix_sample
|
82 |
-
|
83 |
-
self.paths = []
|
84 |
-
# # include all folders
|
85 |
-
# for folder in os.listdir(self.root_dir):
|
86 |
-
# if os.path.isdir(os.path.join(self.root_dir, folder)):
|
87 |
-
# self.paths.append(folder)
|
88 |
-
# load ids from .npy so we have exactly the same ids/order
|
89 |
-
self.paths = np.load("../scripts/obj_ids.npy")
|
90 |
-
# # only use 100K objects for ablation study
|
91 |
-
# self.paths = self.paths[:100000]
|
92 |
-
total_objects = len(self.paths)
|
93 |
-
assert total_objects == 790152, 'total objects %d' % total_objects
|
94 |
-
if validation:
|
95 |
-
self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
|
96 |
-
else:
|
97 |
-
self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
|
98 |
-
print('============= length of dataset %d =============' % len(self.paths))
|
99 |
-
self.tform = image_transforms
|
100 |
-
|
101 |
-
downscale = 512 / 256.
|
102 |
-
self.fx = 560. / downscale
|
103 |
-
self.fy = 560. / downscale
|
104 |
-
self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
|
105 |
-
|
106 |
-
def __len__(self):
|
107 |
-
return len(self.paths)
|
108 |
-
|
109 |
-
def cartesian_to_spherical(self, xyz):
|
110 |
-
ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
|
111 |
-
xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
|
112 |
-
z = np.sqrt(xy + xyz[:, 2] ** 2)
|
113 |
-
theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
|
114 |
-
# ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
|
115 |
-
azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
|
116 |
-
return np.array([theta, azimuth, z])
|
117 |
-
|
118 |
-
def get_T(self, target_RT, cond_RT):
|
119 |
-
R, T = target_RT[:3, :3], target_RT[:, -1]
|
120 |
-
T_target = -R.T @ T
|
121 |
-
|
122 |
-
R, T = cond_RT[:3, :3], cond_RT[:, -1]
|
123 |
-
T_cond = -R.T @ T
|
124 |
-
|
125 |
-
theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
|
126 |
-
theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
|
127 |
-
|
128 |
-
d_theta = theta_target - theta_cond
|
129 |
-
d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
|
130 |
-
d_z = z_target - z_cond
|
131 |
-
|
132 |
-
d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
|
133 |
-
return d_T
|
134 |
-
|
135 |
-
def get_pose(self, target_RT):
|
136 |
-
R, T = target_RT[:3, :3], target_RT[:, -1]
|
137 |
-
T_target = -R.T @ T
|
138 |
-
theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
|
139 |
-
# assert if z_target is out of range
|
140 |
-
if z_target.item() < 1.5 or z_target.item() > 2.2:
|
141 |
-
# print('z_target out of range 1.5-2.2', z_target.item())
|
142 |
-
z_target = np.clip(z_target.item(), 1.5, 2.2)
|
143 |
-
# with log scale for radius
|
144 |
-
target_T = torch.tensor([theta_target.item(), azimuth_target.item(), (np.log(z_target.item()) - np.log(1.5))/(np.log(2.2)-np.log(1.5)) * torch.pi, torch.tensor(0)])
|
145 |
-
assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
|
146 |
-
return target_T
|
147 |
-
|
148 |
-
def load_im(self, path, color):
|
149 |
-
'''
|
150 |
-
replace background pixel with random color in rendering
|
151 |
-
'''
|
152 |
-
try:
|
153 |
-
img = plt.imread(path)
|
154 |
-
except:
|
155 |
-
print(path)
|
156 |
-
sys.exit()
|
157 |
-
img[img[:, :, -1] == 0.] = color
|
158 |
-
img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
|
159 |
-
return img
|
160 |
-
|
161 |
-
def __getitem__(self, index):
|
162 |
-
data = {}
|
163 |
-
total_view = 12
|
164 |
-
|
165 |
-
if self.fix_sample:
|
166 |
-
if self.T_out > 1:
|
167 |
-
indexes = range(total_view)
|
168 |
-
index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
|
169 |
-
index_inputs = indexes[1:self.T_in+1] # one overlap identity
|
170 |
-
else:
|
171 |
-
indexes = range(total_view)
|
172 |
-
index_targets = indexes[:self.T_out]
|
173 |
-
index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1] # one overlap identity
|
174 |
-
else:
|
175 |
-
assert self.T_in + self.T_out <= total_view
|
176 |
-
# training with replace, including identity
|
177 |
-
indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
|
178 |
-
index_inputs = indexes[:self.T_in]
|
179 |
-
index_targets = indexes[self.T_in:]
|
180 |
-
filename = os.path.join(self.root_dir, self.paths[index])
|
181 |
-
|
182 |
-
color = [1., 1., 1., 1.]
|
183 |
-
|
184 |
-
try:
|
185 |
-
input_ims = []
|
186 |
-
target_ims = []
|
187 |
-
target_Ts = []
|
188 |
-
cond_Ts = []
|
189 |
-
for i, index_input in enumerate(index_inputs):
|
190 |
-
input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
|
191 |
-
input_ims.append(input_im)
|
192 |
-
input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
|
193 |
-
cond_Ts.append(self.get_pose(input_RT))
|
194 |
-
for i, index_target in enumerate(index_targets):
|
195 |
-
target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
|
196 |
-
target_ims.append(target_im)
|
197 |
-
target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
|
198 |
-
target_Ts.append(self.get_pose(target_RT))
|
199 |
-
except:
|
200 |
-
print('error loading data ', filename)
|
201 |
-
filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8') # this one we know is valid
|
202 |
-
input_ims = []
|
203 |
-
target_ims = []
|
204 |
-
target_Ts = []
|
205 |
-
cond_Ts = []
|
206 |
-
# very hacky solution, sorry about this
|
207 |
-
for i, index_input in enumerate(index_inputs):
|
208 |
-
input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
|
209 |
-
input_ims.append(input_im)
|
210 |
-
input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
|
211 |
-
cond_Ts.append(self.get_pose(input_RT))
|
212 |
-
for i, index_target in enumerate(index_targets):
|
213 |
-
target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
|
214 |
-
target_ims.append(target_im)
|
215 |
-
target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
|
216 |
-
target_Ts.append(self.get_pose(target_RT))
|
217 |
-
|
218 |
-
# stack to batch
|
219 |
-
data['image_input'] = torch.stack(input_ims, dim=0)
|
220 |
-
data['image_target'] = torch.stack(target_ims, dim=0)
|
221 |
-
data['pose_out'] = torch.stack(target_Ts, dim=0)
|
222 |
-
data['pose_in'] = torch.stack(cond_Ts, dim=0)
|
223 |
-
|
224 |
-
return data
|
225 |
-
|
226 |
-
def process_im(self, im):
|
227 |
-
im = im.convert("RGB")
|
228 |
-
return self.tform(im)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/__init__.py
DELETED
@@ -1,281 +0,0 @@
|
|
1 |
-
__version__ = "0.18.2"
|
2 |
-
|
3 |
-
from .configuration_utils import ConfigMixin
|
4 |
-
from .utils import (
|
5 |
-
OptionalDependencyNotAvailable,
|
6 |
-
is_flax_available,
|
7 |
-
is_inflect_available,
|
8 |
-
is_invisible_watermark_available,
|
9 |
-
is_k_diffusion_available,
|
10 |
-
is_k_diffusion_version,
|
11 |
-
is_librosa_available,
|
12 |
-
is_note_seq_available,
|
13 |
-
is_onnx_available,
|
14 |
-
is_scipy_available,
|
15 |
-
is_torch_available,
|
16 |
-
is_torchsde_available,
|
17 |
-
is_transformers_available,
|
18 |
-
is_transformers_version,
|
19 |
-
is_unidecode_available,
|
20 |
-
logging,
|
21 |
-
)
|
22 |
-
|
23 |
-
|
24 |
-
try:
|
25 |
-
if not is_onnx_available():
|
26 |
-
raise OptionalDependencyNotAvailable()
|
27 |
-
except OptionalDependencyNotAvailable:
|
28 |
-
from .utils.dummy_onnx_objects import * # noqa F403
|
29 |
-
else:
|
30 |
-
from .pipelines import OnnxRuntimeModel
|
31 |
-
|
32 |
-
try:
|
33 |
-
if not is_torch_available():
|
34 |
-
raise OptionalDependencyNotAvailable()
|
35 |
-
except OptionalDependencyNotAvailable:
|
36 |
-
from .utils.dummy_pt_objects import * # noqa F403
|
37 |
-
else:
|
38 |
-
from .models import (
|
39 |
-
AutoencoderKL,
|
40 |
-
ControlNetModel,
|
41 |
-
ModelMixin,
|
42 |
-
PriorTransformer,
|
43 |
-
T5FilmDecoder,
|
44 |
-
Transformer2DModel,
|
45 |
-
UNet1DModel,
|
46 |
-
UNet2DConditionModel,
|
47 |
-
UNet2DModel,
|
48 |
-
UNet3DConditionModel,
|
49 |
-
VQModel,
|
50 |
-
)
|
51 |
-
from .optimization import (
|
52 |
-
get_constant_schedule,
|
53 |
-
get_constant_schedule_with_warmup,
|
54 |
-
get_cosine_schedule_with_warmup,
|
55 |
-
get_cosine_with_hard_restarts_schedule_with_warmup,
|
56 |
-
get_linear_schedule_with_warmup,
|
57 |
-
get_polynomial_decay_schedule_with_warmup,
|
58 |
-
get_scheduler,
|
59 |
-
)
|
60 |
-
from .pipelines import (
|
61 |
-
AudioPipelineOutput,
|
62 |
-
ConsistencyModelPipeline,
|
63 |
-
DanceDiffusionPipeline,
|
64 |
-
DDIMPipeline,
|
65 |
-
DDPMPipeline,
|
66 |
-
DiffusionPipeline,
|
67 |
-
DiTPipeline,
|
68 |
-
ImagePipelineOutput,
|
69 |
-
KarrasVePipeline,
|
70 |
-
LDMPipeline,
|
71 |
-
LDMSuperResolutionPipeline,
|
72 |
-
PNDMPipeline,
|
73 |
-
RePaintPipeline,
|
74 |
-
ScoreSdeVePipeline,
|
75 |
-
)
|
76 |
-
from .schedulers import (
|
77 |
-
CMStochasticIterativeScheduler,
|
78 |
-
DDIMInverseScheduler,
|
79 |
-
DDIMParallelScheduler,
|
80 |
-
DDIMScheduler,
|
81 |
-
DDPMParallelScheduler,
|
82 |
-
DDPMScheduler,
|
83 |
-
DEISMultistepScheduler,
|
84 |
-
DPMSolverMultistepInverseScheduler,
|
85 |
-
DPMSolverMultistepScheduler,
|
86 |
-
DPMSolverSinglestepScheduler,
|
87 |
-
EulerAncestralDiscreteScheduler,
|
88 |
-
EulerDiscreteScheduler,
|
89 |
-
HeunDiscreteScheduler,
|
90 |
-
IPNDMScheduler,
|
91 |
-
KarrasVeScheduler,
|
92 |
-
KDPM2AncestralDiscreteScheduler,
|
93 |
-
KDPM2DiscreteScheduler,
|
94 |
-
PNDMScheduler,
|
95 |
-
RePaintScheduler,
|
96 |
-
SchedulerMixin,
|
97 |
-
ScoreSdeVeScheduler,
|
98 |
-
UnCLIPScheduler,
|
99 |
-
UniPCMultistepScheduler,
|
100 |
-
VQDiffusionScheduler,
|
101 |
-
)
|
102 |
-
from .training_utils import EMAModel
|
103 |
-
|
104 |
-
try:
|
105 |
-
if not (is_torch_available() and is_scipy_available()):
|
106 |
-
raise OptionalDependencyNotAvailable()
|
107 |
-
except OptionalDependencyNotAvailable:
|
108 |
-
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
109 |
-
else:
|
110 |
-
from .schedulers import LMSDiscreteScheduler
|
111 |
-
|
112 |
-
try:
|
113 |
-
if not (is_torch_available() and is_torchsde_available()):
|
114 |
-
raise OptionalDependencyNotAvailable()
|
115 |
-
except OptionalDependencyNotAvailable:
|
116 |
-
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
|
117 |
-
else:
|
118 |
-
from .schedulers import DPMSolverSDEScheduler
|
119 |
-
|
120 |
-
try:
|
121 |
-
if not (is_torch_available() and is_transformers_available()):
|
122 |
-
raise OptionalDependencyNotAvailable()
|
123 |
-
except OptionalDependencyNotAvailable:
|
124 |
-
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
125 |
-
else:
|
126 |
-
from .pipelines import (
|
127 |
-
AltDiffusionImg2ImgPipeline,
|
128 |
-
AltDiffusionPipeline,
|
129 |
-
AudioLDMPipeline,
|
130 |
-
CycleDiffusionPipeline,
|
131 |
-
IFImg2ImgPipeline,
|
132 |
-
IFImg2ImgSuperResolutionPipeline,
|
133 |
-
IFInpaintingPipeline,
|
134 |
-
IFInpaintingSuperResolutionPipeline,
|
135 |
-
IFPipeline,
|
136 |
-
IFSuperResolutionPipeline,
|
137 |
-
ImageTextPipelineOutput,
|
138 |
-
KandinskyImg2ImgPipeline,
|
139 |
-
KandinskyInpaintPipeline,
|
140 |
-
KandinskyPipeline,
|
141 |
-
KandinskyPriorPipeline,
|
142 |
-
KandinskyV22ControlnetImg2ImgPipeline,
|
143 |
-
KandinskyV22ControlnetPipeline,
|
144 |
-
KandinskyV22Img2ImgPipeline,
|
145 |
-
KandinskyV22InpaintPipeline,
|
146 |
-
KandinskyV22Pipeline,
|
147 |
-
KandinskyV22PriorEmb2EmbPipeline,
|
148 |
-
KandinskyV22PriorPipeline,
|
149 |
-
LDMTextToImagePipeline,
|
150 |
-
PaintByExamplePipeline,
|
151 |
-
SemanticStableDiffusionPipeline,
|
152 |
-
ShapEImg2ImgPipeline,
|
153 |
-
ShapEPipeline,
|
154 |
-
StableDiffusionAttendAndExcitePipeline,
|
155 |
-
StableDiffusionControlNetImg2ImgPipeline,
|
156 |
-
StableDiffusionControlNetInpaintPipeline,
|
157 |
-
StableDiffusionControlNetPipeline,
|
158 |
-
StableDiffusionDepth2ImgPipeline,
|
159 |
-
StableDiffusionDiffEditPipeline,
|
160 |
-
StableDiffusionImageVariationPipeline,
|
161 |
-
StableDiffusionImg2ImgPipeline,
|
162 |
-
StableDiffusionInpaintPipeline,
|
163 |
-
StableDiffusionInpaintPipelineLegacy,
|
164 |
-
StableDiffusionInstructPix2PixPipeline,
|
165 |
-
StableDiffusionLatentUpscalePipeline,
|
166 |
-
StableDiffusionLDM3DPipeline,
|
167 |
-
StableDiffusionModelEditingPipeline,
|
168 |
-
StableDiffusionPanoramaPipeline,
|
169 |
-
StableDiffusionParadigmsPipeline,
|
170 |
-
StableDiffusionPipeline,
|
171 |
-
StableDiffusionPipelineSafe,
|
172 |
-
StableDiffusionPix2PixZeroPipeline,
|
173 |
-
StableDiffusionSAGPipeline,
|
174 |
-
StableDiffusionUpscalePipeline,
|
175 |
-
StableUnCLIPImg2ImgPipeline,
|
176 |
-
StableUnCLIPPipeline,
|
177 |
-
TextToVideoSDPipeline,
|
178 |
-
TextToVideoZeroPipeline,
|
179 |
-
UnCLIPImageVariationPipeline,
|
180 |
-
UnCLIPPipeline,
|
181 |
-
UniDiffuserModel,
|
182 |
-
UniDiffuserPipeline,
|
183 |
-
UniDiffuserTextDecoder,
|
184 |
-
VersatileDiffusionDualGuidedPipeline,
|
185 |
-
VersatileDiffusionImageVariationPipeline,
|
186 |
-
VersatileDiffusionPipeline,
|
187 |
-
VersatileDiffusionTextToImagePipeline,
|
188 |
-
VideoToVideoSDPipeline,
|
189 |
-
VQDiffusionPipeline,
|
190 |
-
)
|
191 |
-
|
192 |
-
try:
|
193 |
-
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
194 |
-
raise OptionalDependencyNotAvailable()
|
195 |
-
except OptionalDependencyNotAvailable:
|
196 |
-
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
197 |
-
else:
|
198 |
-
from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
|
199 |
-
|
200 |
-
try:
|
201 |
-
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
202 |
-
raise OptionalDependencyNotAvailable()
|
203 |
-
except OptionalDependencyNotAvailable:
|
204 |
-
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
205 |
-
else:
|
206 |
-
from .pipelines import StableDiffusionKDiffusionPipeline
|
207 |
-
|
208 |
-
try:
|
209 |
-
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
210 |
-
raise OptionalDependencyNotAvailable()
|
211 |
-
except OptionalDependencyNotAvailable:
|
212 |
-
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
213 |
-
else:
|
214 |
-
from .pipelines import (
|
215 |
-
OnnxStableDiffusionImg2ImgPipeline,
|
216 |
-
OnnxStableDiffusionInpaintPipeline,
|
217 |
-
OnnxStableDiffusionInpaintPipelineLegacy,
|
218 |
-
OnnxStableDiffusionPipeline,
|
219 |
-
OnnxStableDiffusionUpscalePipeline,
|
220 |
-
StableDiffusionOnnxPipeline,
|
221 |
-
)
|
222 |
-
|
223 |
-
try:
|
224 |
-
if not (is_torch_available() and is_librosa_available()):
|
225 |
-
raise OptionalDependencyNotAvailable()
|
226 |
-
except OptionalDependencyNotAvailable:
|
227 |
-
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
|
228 |
-
else:
|
229 |
-
from .pipelines import AudioDiffusionPipeline, Mel
|
230 |
-
|
231 |
-
try:
|
232 |
-
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
233 |
-
raise OptionalDependencyNotAvailable()
|
234 |
-
except OptionalDependencyNotAvailable:
|
235 |
-
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
236 |
-
else:
|
237 |
-
from .pipelines import SpectrogramDiffusionPipeline
|
238 |
-
|
239 |
-
try:
|
240 |
-
if not is_flax_available():
|
241 |
-
raise OptionalDependencyNotAvailable()
|
242 |
-
except OptionalDependencyNotAvailable:
|
243 |
-
from .utils.dummy_flax_objects import * # noqa F403
|
244 |
-
else:
|
245 |
-
from .models.controlnet_flax import FlaxControlNetModel
|
246 |
-
from .models.modeling_flax_utils import FlaxModelMixin
|
247 |
-
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
248 |
-
from .models.vae_flax import FlaxAutoencoderKL
|
249 |
-
from .pipelines import FlaxDiffusionPipeline
|
250 |
-
from .schedulers import (
|
251 |
-
FlaxDDIMScheduler,
|
252 |
-
FlaxDDPMScheduler,
|
253 |
-
FlaxDPMSolverMultistepScheduler,
|
254 |
-
FlaxKarrasVeScheduler,
|
255 |
-
FlaxLMSDiscreteScheduler,
|
256 |
-
FlaxPNDMScheduler,
|
257 |
-
FlaxSchedulerMixin,
|
258 |
-
FlaxScoreSdeVeScheduler,
|
259 |
-
)
|
260 |
-
|
261 |
-
|
262 |
-
try:
|
263 |
-
if not (is_flax_available() and is_transformers_available()):
|
264 |
-
raise OptionalDependencyNotAvailable()
|
265 |
-
except OptionalDependencyNotAvailable:
|
266 |
-
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
267 |
-
else:
|
268 |
-
from .pipelines import (
|
269 |
-
FlaxStableDiffusionControlNetPipeline,
|
270 |
-
FlaxStableDiffusionImg2ImgPipeline,
|
271 |
-
FlaxStableDiffusionInpaintPipeline,
|
272 |
-
FlaxStableDiffusionPipeline,
|
273 |
-
)
|
274 |
-
|
275 |
-
try:
|
276 |
-
if not (is_note_seq_available()):
|
277 |
-
raise OptionalDependencyNotAvailable()
|
278 |
-
except OptionalDependencyNotAvailable:
|
279 |
-
from .utils.dummy_note_seq_objects import * # noqa F403
|
280 |
-
else:
|
281 |
-
from .pipelines import MidiProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/commands/__init__.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
from abc import ABC, abstractmethod
|
16 |
-
from argparse import ArgumentParser
|
17 |
-
|
18 |
-
|
19 |
-
class BaseDiffusersCLICommand(ABC):
|
20 |
-
@staticmethod
|
21 |
-
@abstractmethod
|
22 |
-
def register_subcommand(parser: ArgumentParser):
|
23 |
-
raise NotImplementedError()
|
24 |
-
|
25 |
-
@abstractmethod
|
26 |
-
def run(self):
|
27 |
-
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/commands/diffusers_cli.py
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
from argparse import ArgumentParser
|
17 |
-
|
18 |
-
from .env import EnvironmentCommand
|
19 |
-
|
20 |
-
|
21 |
-
def main():
|
22 |
-
parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
|
23 |
-
commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
|
24 |
-
|
25 |
-
# Register commands
|
26 |
-
EnvironmentCommand.register_subcommand(commands_parser)
|
27 |
-
|
28 |
-
# Let's go
|
29 |
-
args = parser.parse_args()
|
30 |
-
|
31 |
-
if not hasattr(args, "func"):
|
32 |
-
parser.print_help()
|
33 |
-
exit(1)
|
34 |
-
|
35 |
-
# Run
|
36 |
-
service = args.func(args)
|
37 |
-
service.run()
|
38 |
-
|
39 |
-
|
40 |
-
if __name__ == "__main__":
|
41 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/commands/env.py
DELETED
@@ -1,84 +0,0 @@
|
|
1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import platform
|
16 |
-
from argparse import ArgumentParser
|
17 |
-
|
18 |
-
import huggingface_hub
|
19 |
-
|
20 |
-
from .. import __version__ as version
|
21 |
-
from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
|
22 |
-
from . import BaseDiffusersCLICommand
|
23 |
-
|
24 |
-
|
25 |
-
def info_command_factory(_):
|
26 |
-
return EnvironmentCommand()
|
27 |
-
|
28 |
-
|
29 |
-
class EnvironmentCommand(BaseDiffusersCLICommand):
|
30 |
-
@staticmethod
|
31 |
-
def register_subcommand(parser: ArgumentParser):
|
32 |
-
download_parser = parser.add_parser("env")
|
33 |
-
download_parser.set_defaults(func=info_command_factory)
|
34 |
-
|
35 |
-
def run(self):
|
36 |
-
hub_version = huggingface_hub.__version__
|
37 |
-
|
38 |
-
pt_version = "not installed"
|
39 |
-
pt_cuda_available = "NA"
|
40 |
-
if is_torch_available():
|
41 |
-
import torch
|
42 |
-
|
43 |
-
pt_version = torch.__version__
|
44 |
-
pt_cuda_available = torch.cuda.is_available()
|
45 |
-
|
46 |
-
transformers_version = "not installed"
|
47 |
-
if is_transformers_available():
|
48 |
-
import transformers
|
49 |
-
|
50 |
-
transformers_version = transformers.__version__
|
51 |
-
|
52 |
-
accelerate_version = "not installed"
|
53 |
-
if is_accelerate_available():
|
54 |
-
import accelerate
|
55 |
-
|
56 |
-
accelerate_version = accelerate.__version__
|
57 |
-
|
58 |
-
xformers_version = "not installed"
|
59 |
-
if is_xformers_available():
|
60 |
-
import xformers
|
61 |
-
|
62 |
-
xformers_version = xformers.__version__
|
63 |
-
|
64 |
-
info = {
|
65 |
-
"`diffusers` version": version,
|
66 |
-
"Platform": platform.platform(),
|
67 |
-
"Python version": platform.python_version(),
|
68 |
-
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
69 |
-
"Huggingface_hub version": hub_version,
|
70 |
-
"Transformers version": transformers_version,
|
71 |
-
"Accelerate version": accelerate_version,
|
72 |
-
"xFormers version": xformers_version,
|
73 |
-
"Using GPU in script?": "<fill in>",
|
74 |
-
"Using distributed or parallel set-up in script?": "<fill in>",
|
75 |
-
}
|
76 |
-
|
77 |
-
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
78 |
-
print(self.format_dict(info))
|
79 |
-
|
80 |
-
return info
|
81 |
-
|
82 |
-
@staticmethod
|
83 |
-
def format_dict(d):
|
84 |
-
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/configuration_utils.py
DELETED
@@ -1,664 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
-
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
#
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
#
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
""" ConfigMixin base class and utilities."""
|
17 |
-
import dataclasses
|
18 |
-
import functools
|
19 |
-
import importlib
|
20 |
-
import inspect
|
21 |
-
import json
|
22 |
-
import os
|
23 |
-
import re
|
24 |
-
from collections import OrderedDict
|
25 |
-
from pathlib import PosixPath
|
26 |
-
from typing import Any, Dict, Tuple, Union
|
27 |
-
|
28 |
-
import numpy as np
|
29 |
-
from huggingface_hub import hf_hub_download
|
30 |
-
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
31 |
-
from requests import HTTPError
|
32 |
-
|
33 |
-
from . import __version__
|
34 |
-
from .utils import (
|
35 |
-
DIFFUSERS_CACHE,
|
36 |
-
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
37 |
-
DummyObject,
|
38 |
-
deprecate,
|
39 |
-
extract_commit_hash,
|
40 |
-
http_user_agent,
|
41 |
-
logging,
|
42 |
-
)
|
43 |
-
|
44 |
-
|
45 |
-
logger = logging.get_logger(__name__)
|
46 |
-
|
47 |
-
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
48 |
-
|
49 |
-
|
50 |
-
class FrozenDict(OrderedDict):
|
51 |
-
def __init__(self, *args, **kwargs):
|
52 |
-
super().__init__(*args, **kwargs)
|
53 |
-
|
54 |
-
for key, value in self.items():
|
55 |
-
setattr(self, key, value)
|
56 |
-
|
57 |
-
self.__frozen = True
|
58 |
-
|
59 |
-
def __delitem__(self, *args, **kwargs):
|
60 |
-
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
61 |
-
|
62 |
-
def setdefault(self, *args, **kwargs):
|
63 |
-
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
64 |
-
|
65 |
-
def pop(self, *args, **kwargs):
|
66 |
-
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
67 |
-
|
68 |
-
def update(self, *args, **kwargs):
|
69 |
-
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
70 |
-
|
71 |
-
def __setattr__(self, name, value):
|
72 |
-
if hasattr(self, "__frozen") and self.__frozen:
|
73 |
-
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
74 |
-
super().__setattr__(name, value)
|
75 |
-
|
76 |
-
def __setitem__(self, name, value):
|
77 |
-
if hasattr(self, "__frozen") and self.__frozen:
|
78 |
-
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
79 |
-
super().__setitem__(name, value)
|
80 |
-
|
81 |
-
|
82 |
-
class ConfigMixin:
|
83 |
-
r"""
|
84 |
-
Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
|
85 |
-
provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
|
86 |
-
saving classes that inherit from [`ConfigMixin`].
|
87 |
-
|
88 |
-
Class attributes:
|
89 |
-
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
90 |
-
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
91 |
-
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
92 |
-
overridden by subclass).
|
93 |
-
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
94 |
-
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
95 |
-
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
96 |
-
subclass).
|
97 |
-
"""
|
98 |
-
config_name = None
|
99 |
-
ignore_for_config = []
|
100 |
-
has_compatibles = False
|
101 |
-
|
102 |
-
_deprecated_kwargs = []
|
103 |
-
|
104 |
-
def register_to_config(self, **kwargs):
|
105 |
-
if self.config_name is None:
|
106 |
-
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
107 |
-
# Special case for `kwargs` used in deprecation warning added to schedulers
|
108 |
-
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
109 |
-
# or solve in a more general way.
|
110 |
-
kwargs.pop("kwargs", None)
|
111 |
-
|
112 |
-
if not hasattr(self, "_internal_dict"):
|
113 |
-
internal_dict = kwargs
|
114 |
-
else:
|
115 |
-
previous_dict = dict(self._internal_dict)
|
116 |
-
internal_dict = {**self._internal_dict, **kwargs}
|
117 |
-
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
118 |
-
|
119 |
-
self._internal_dict = FrozenDict(internal_dict)
|
120 |
-
|
121 |
-
def __getattr__(self, name: str) -> Any:
|
122 |
-
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
123 |
-
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
|
124 |
-
|
125 |
-
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
|
126 |
-
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
127 |
-
"""
|
128 |
-
|
129 |
-
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
130 |
-
is_attribute = name in self.__dict__
|
131 |
-
|
132 |
-
if is_in_config and not is_attribute:
|
133 |
-
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
|
134 |
-
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
|
135 |
-
return self._internal_dict[name]
|
136 |
-
|
137 |
-
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
138 |
-
|
139 |
-
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
140 |
-
"""
|
141 |
-
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
142 |
-
[`~ConfigMixin.from_config`] class method.
|
143 |
-
|
144 |
-
Args:
|
145 |
-
save_directory (`str` or `os.PathLike`):
|
146 |
-
Directory where the configuration JSON file is saved (will be created if it does not exist).
|
147 |
-
"""
|
148 |
-
if os.path.isfile(save_directory):
|
149 |
-
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
150 |
-
|
151 |
-
os.makedirs(save_directory, exist_ok=True)
|
152 |
-
|
153 |
-
# If we save using the predefined names, we can load using `from_config`
|
154 |
-
output_config_file = os.path.join(save_directory, self.config_name)
|
155 |
-
|
156 |
-
self.to_json_file(output_config_file)
|
157 |
-
logger.info(f"Configuration saved in {output_config_file}")
|
158 |
-
|
159 |
-
@classmethod
|
160 |
-
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
161 |
-
r"""
|
162 |
-
Instantiate a Python class from a config dictionary.
|
163 |
-
|
164 |
-
Parameters:
|
165 |
-
config (`Dict[str, Any]`):
|
166 |
-
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
|
167 |
-
files of compatible classes.
|
168 |
-
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
169 |
-
Whether kwargs that are not consumed by the Python class should be returned or not.
|
170 |
-
kwargs (remaining dictionary of keyword arguments, *optional*):
|
171 |
-
Can be used to update the configuration object (after it is loaded) and initiate the Python class.
|
172 |
-
`**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
|
173 |
-
overwrite the same named arguments in `config`.
|
174 |
-
|
175 |
-
Returns:
|
176 |
-
[`ModelMixin`] or [`SchedulerMixin`]:
|
177 |
-
A model or scheduler object instantiated from a config dictionary.
|
178 |
-
|
179 |
-
Examples:
|
180 |
-
|
181 |
-
```python
|
182 |
-
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
183 |
-
|
184 |
-
>>> # Download scheduler from huggingface.co and cache.
|
185 |
-
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
186 |
-
|
187 |
-
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
188 |
-
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
189 |
-
|
190 |
-
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
191 |
-
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
192 |
-
```
|
193 |
-
"""
|
194 |
-
# <===== TO BE REMOVED WITH DEPRECATION
|
195 |
-
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
196 |
-
if "pretrained_model_name_or_path" in kwargs:
|
197 |
-
config = kwargs.pop("pretrained_model_name_or_path")
|
198 |
-
|
199 |
-
if config is None:
|
200 |
-
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
201 |
-
# ======>
|
202 |
-
|
203 |
-
if not isinstance(config, dict):
|
204 |
-
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
205 |
-
if "Scheduler" in cls.__name__:
|
206 |
-
deprecation_message += (
|
207 |
-
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
208 |
-
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
209 |
-
" be removed in v1.0.0."
|
210 |
-
)
|
211 |
-
elif "Model" in cls.__name__:
|
212 |
-
deprecation_message += (
|
213 |
-
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
214 |
-
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
215 |
-
" instead. This functionality will be removed in v1.0.0."
|
216 |
-
)
|
217 |
-
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
218 |
-
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
219 |
-
|
220 |
-
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
221 |
-
|
222 |
-
# Allow dtype to be specified on initialization
|
223 |
-
if "dtype" in unused_kwargs:
|
224 |
-
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
225 |
-
|
226 |
-
# add possible deprecated kwargs
|
227 |
-
for deprecated_kwarg in cls._deprecated_kwargs:
|
228 |
-
if deprecated_kwarg in unused_kwargs:
|
229 |
-
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
230 |
-
|
231 |
-
# Return model and optionally state and/or unused_kwargs
|
232 |
-
model = cls(**init_dict)
|
233 |
-
|
234 |
-
# make sure to also save config parameters that might be used for compatible classes
|
235 |
-
model.register_to_config(**hidden_dict)
|
236 |
-
|
237 |
-
# add hidden kwargs of compatible classes to unused_kwargs
|
238 |
-
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
239 |
-
|
240 |
-
if return_unused_kwargs:
|
241 |
-
return (model, unused_kwargs)
|
242 |
-
else:
|
243 |
-
return model
|
244 |
-
|
245 |
-
@classmethod
|
246 |
-
def get_config_dict(cls, *args, **kwargs):
|
247 |
-
deprecation_message = (
|
248 |
-
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
249 |
-
" removed in version v1.0.0"
|
250 |
-
)
|
251 |
-
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
252 |
-
return cls.load_config(*args, **kwargs)
|
253 |
-
|
254 |
-
@classmethod
|
255 |
-
def load_config(
|
256 |
-
cls,
|
257 |
-
pretrained_model_name_or_path: Union[str, os.PathLike],
|
258 |
-
return_unused_kwargs=False,
|
259 |
-
return_commit_hash=False,
|
260 |
-
**kwargs,
|
261 |
-
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
262 |
-
r"""
|
263 |
-
Load a model or scheduler configuration.
|
264 |
-
|
265 |
-
Parameters:
|
266 |
-
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
267 |
-
Can be either:
|
268 |
-
|
269 |
-
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
270 |
-
the Hub.
|
271 |
-
- A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
|
272 |
-
[`~ConfigMixin.save_config`].
|
273 |
-
|
274 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
275 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
276 |
-
is not used.
|
277 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
278 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
279 |
-
cached versions if they exist.
|
280 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
281 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
282 |
-
incompletely downloaded files are deleted.
|
283 |
-
proxies (`Dict[str, str]`, *optional*):
|
284 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
285 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
286 |
-
output_loading_info(`bool`, *optional*, defaults to `False`):
|
287 |
-
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
288 |
-
local_files_only (`bool`, *optional*, defaults to `False`):
|
289 |
-
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
290 |
-
won't be downloaded from the Hub.
|
291 |
-
use_auth_token (`str` or *bool*, *optional*):
|
292 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
293 |
-
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
294 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
295 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
296 |
-
allowed by Git.
|
297 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
298 |
-
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
299 |
-
return_unused_kwargs (`bool`, *optional*, defaults to `False):
|
300 |
-
Whether unused keyword arguments of the config are returned.
|
301 |
-
return_commit_hash (`bool`, *optional*, defaults to `False):
|
302 |
-
Whether the `commit_hash` of the loaded configuration are returned.
|
303 |
-
|
304 |
-
Returns:
|
305 |
-
`dict`:
|
306 |
-
A dictionary of all the parameters stored in a JSON configuration file.
|
307 |
-
|
308 |
-
"""
|
309 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
310 |
-
force_download = kwargs.pop("force_download", False)
|
311 |
-
resume_download = kwargs.pop("resume_download", False)
|
312 |
-
proxies = kwargs.pop("proxies", None)
|
313 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
314 |
-
local_files_only = kwargs.pop("local_files_only", False)
|
315 |
-
revision = kwargs.pop("revision", None)
|
316 |
-
_ = kwargs.pop("mirror", None)
|
317 |
-
subfolder = kwargs.pop("subfolder", None)
|
318 |
-
user_agent = kwargs.pop("user_agent", {})
|
319 |
-
|
320 |
-
user_agent = {**user_agent, "file_type": "config"}
|
321 |
-
user_agent = http_user_agent(user_agent)
|
322 |
-
|
323 |
-
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
324 |
-
|
325 |
-
if cls.config_name is None:
|
326 |
-
raise ValueError(
|
327 |
-
"`self.config_name` is not defined. Note that one should not load a config from "
|
328 |
-
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
329 |
-
)
|
330 |
-
|
331 |
-
if os.path.isfile(pretrained_model_name_or_path):
|
332 |
-
config_file = pretrained_model_name_or_path
|
333 |
-
elif os.path.isdir(pretrained_model_name_or_path):
|
334 |
-
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
335 |
-
# Load from a PyTorch checkpoint
|
336 |
-
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
337 |
-
elif subfolder is not None and os.path.isfile(
|
338 |
-
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
339 |
-
):
|
340 |
-
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
341 |
-
else:
|
342 |
-
raise EnvironmentError(
|
343 |
-
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
344 |
-
)
|
345 |
-
else:
|
346 |
-
try:
|
347 |
-
# Load from URL or cache if already cached
|
348 |
-
config_file = hf_hub_download(
|
349 |
-
pretrained_model_name_or_path,
|
350 |
-
filename=cls.config_name,
|
351 |
-
cache_dir=cache_dir,
|
352 |
-
force_download=force_download,
|
353 |
-
proxies=proxies,
|
354 |
-
resume_download=resume_download,
|
355 |
-
local_files_only=local_files_only,
|
356 |
-
use_auth_token=use_auth_token,
|
357 |
-
user_agent=user_agent,
|
358 |
-
subfolder=subfolder,
|
359 |
-
revision=revision,
|
360 |
-
)
|
361 |
-
except RepositoryNotFoundError:
|
362 |
-
raise EnvironmentError(
|
363 |
-
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
364 |
-
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
365 |
-
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
|
366 |
-
" login`."
|
367 |
-
)
|
368 |
-
except RevisionNotFoundError:
|
369 |
-
raise EnvironmentError(
|
370 |
-
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
|
371 |
-
" this model name. Check the model page at"
|
372 |
-
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
373 |
-
)
|
374 |
-
except EntryNotFoundError:
|
375 |
-
raise EnvironmentError(
|
376 |
-
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
377 |
-
)
|
378 |
-
except HTTPError as err:
|
379 |
-
raise EnvironmentError(
|
380 |
-
"There was a specific connection error when trying to load"
|
381 |
-
f" {pretrained_model_name_or_path}:\n{err}"
|
382 |
-
)
|
383 |
-
except ValueError:
|
384 |
-
raise EnvironmentError(
|
385 |
-
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
386 |
-
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
387 |
-
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
388 |
-
" run the library in offline mode at"
|
389 |
-
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
390 |
-
)
|
391 |
-
except EnvironmentError:
|
392 |
-
raise EnvironmentError(
|
393 |
-
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
394 |
-
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
395 |
-
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
396 |
-
f"containing a {cls.config_name} file"
|
397 |
-
)
|
398 |
-
|
399 |
-
try:
|
400 |
-
# Load config dict
|
401 |
-
config_dict = cls._dict_from_json_file(config_file)
|
402 |
-
|
403 |
-
commit_hash = extract_commit_hash(config_file)
|
404 |
-
except (json.JSONDecodeError, UnicodeDecodeError):
|
405 |
-
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
406 |
-
|
407 |
-
if not (return_unused_kwargs or return_commit_hash):
|
408 |
-
return config_dict
|
409 |
-
|
410 |
-
outputs = (config_dict,)
|
411 |
-
|
412 |
-
if return_unused_kwargs:
|
413 |
-
outputs += (kwargs,)
|
414 |
-
|
415 |
-
if return_commit_hash:
|
416 |
-
outputs += (commit_hash,)
|
417 |
-
|
418 |
-
return outputs
|
419 |
-
|
420 |
-
@staticmethod
|
421 |
-
def _get_init_keys(cls):
|
422 |
-
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
423 |
-
|
424 |
-
@classmethod
|
425 |
-
def extract_init_dict(cls, config_dict, **kwargs):
|
426 |
-
# Skip keys that were not present in the original config, so default __init__ values were used
|
427 |
-
used_defaults = config_dict.get("_use_default_values", [])
|
428 |
-
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
|
429 |
-
|
430 |
-
# 0. Copy origin config dict
|
431 |
-
original_dict = dict(config_dict.items())
|
432 |
-
|
433 |
-
# 1. Retrieve expected config attributes from __init__ signature
|
434 |
-
expected_keys = cls._get_init_keys(cls)
|
435 |
-
expected_keys.remove("self")
|
436 |
-
# remove general kwargs if present in dict
|
437 |
-
if "kwargs" in expected_keys:
|
438 |
-
expected_keys.remove("kwargs")
|
439 |
-
# remove flax internal keys
|
440 |
-
if hasattr(cls, "_flax_internal_args"):
|
441 |
-
for arg in cls._flax_internal_args:
|
442 |
-
expected_keys.remove(arg)
|
443 |
-
|
444 |
-
# 2. Remove attributes that cannot be expected from expected config attributes
|
445 |
-
# remove keys to be ignored
|
446 |
-
if len(cls.ignore_for_config) > 0:
|
447 |
-
expected_keys = expected_keys - set(cls.ignore_for_config)
|
448 |
-
|
449 |
-
# load diffusers library to import compatible and original scheduler
|
450 |
-
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
451 |
-
|
452 |
-
if cls.has_compatibles:
|
453 |
-
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
|
454 |
-
else:
|
455 |
-
compatible_classes = []
|
456 |
-
|
457 |
-
expected_keys_comp_cls = set()
|
458 |
-
for c in compatible_classes:
|
459 |
-
expected_keys_c = cls._get_init_keys(c)
|
460 |
-
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
|
461 |
-
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
|
462 |
-
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
|
463 |
-
|
464 |
-
# remove attributes from orig class that cannot be expected
|
465 |
-
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
466 |
-
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
467 |
-
orig_cls = getattr(diffusers_library, orig_cls_name)
|
468 |
-
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
469 |
-
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
470 |
-
|
471 |
-
# remove private attributes
|
472 |
-
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
473 |
-
|
474 |
-
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
475 |
-
init_dict = {}
|
476 |
-
for key in expected_keys:
|
477 |
-
# if config param is passed to kwarg and is present in config dict
|
478 |
-
# it should overwrite existing config dict key
|
479 |
-
if key in kwargs and key in config_dict:
|
480 |
-
config_dict[key] = kwargs.pop(key)
|
481 |
-
|
482 |
-
if key in kwargs:
|
483 |
-
# overwrite key
|
484 |
-
init_dict[key] = kwargs.pop(key)
|
485 |
-
elif key in config_dict:
|
486 |
-
# use value from config dict
|
487 |
-
init_dict[key] = config_dict.pop(key)
|
488 |
-
|
489 |
-
# 4. Give nice warning if unexpected values have been passed
|
490 |
-
if len(config_dict) > 0:
|
491 |
-
logger.warning(
|
492 |
-
f"The config attributes {config_dict} were passed to {cls.__name__}, "
|
493 |
-
"but are not expected and will be ignored. Please verify your "
|
494 |
-
f"{cls.config_name} configuration file."
|
495 |
-
)
|
496 |
-
|
497 |
-
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
|
498 |
-
passed_keys = set(init_dict.keys())
|
499 |
-
if len(expected_keys - passed_keys) > 0:
|
500 |
-
logger.info(
|
501 |
-
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
502 |
-
)
|
503 |
-
|
504 |
-
# 6. Define unused keyword arguments
|
505 |
-
unused_kwargs = {**config_dict, **kwargs}
|
506 |
-
|
507 |
-
# 7. Define "hidden" config parameters that were saved for compatible classes
|
508 |
-
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
|
509 |
-
|
510 |
-
return init_dict, unused_kwargs, hidden_config_dict
|
511 |
-
|
512 |
-
@classmethod
|
513 |
-
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
514 |
-
with open(json_file, "r", encoding="utf-8") as reader:
|
515 |
-
text = reader.read()
|
516 |
-
return json.loads(text)
|
517 |
-
|
518 |
-
def __repr__(self):
|
519 |
-
return f"{self.__class__.__name__} {self.to_json_string()}"
|
520 |
-
|
521 |
-
@property
|
522 |
-
def config(self) -> Dict[str, Any]:
|
523 |
-
"""
|
524 |
-
Returns the config of the class as a frozen dictionary
|
525 |
-
|
526 |
-
Returns:
|
527 |
-
`Dict[str, Any]`: Config of the class.
|
528 |
-
"""
|
529 |
-
return self._internal_dict
|
530 |
-
|
531 |
-
def to_json_string(self) -> str:
|
532 |
-
"""
|
533 |
-
Serializes the configuration instance to a JSON string.
|
534 |
-
|
535 |
-
Returns:
|
536 |
-
`str`:
|
537 |
-
String containing all the attributes that make up the configuration instance in JSON format.
|
538 |
-
"""
|
539 |
-
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
540 |
-
config_dict["_class_name"] = self.__class__.__name__
|
541 |
-
config_dict["_diffusers_version"] = __version__
|
542 |
-
|
543 |
-
def to_json_saveable(value):
|
544 |
-
if isinstance(value, np.ndarray):
|
545 |
-
value = value.tolist()
|
546 |
-
elif isinstance(value, PosixPath):
|
547 |
-
value = str(value)
|
548 |
-
return value
|
549 |
-
|
550 |
-
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
551 |
-
# Don't save "_ignore_files" or "_use_default_values"
|
552 |
-
config_dict.pop("_ignore_files", None)
|
553 |
-
config_dict.pop("_use_default_values", None)
|
554 |
-
|
555 |
-
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
556 |
-
|
557 |
-
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
558 |
-
"""
|
559 |
-
Save the configuration instance's parameters to a JSON file.
|
560 |
-
|
561 |
-
Args:
|
562 |
-
json_file_path (`str` or `os.PathLike`):
|
563 |
-
Path to the JSON file to save a configuration instance's parameters.
|
564 |
-
"""
|
565 |
-
with open(json_file_path, "w", encoding="utf-8") as writer:
|
566 |
-
writer.write(self.to_json_string())
|
567 |
-
|
568 |
-
|
569 |
-
def register_to_config(init):
|
570 |
-
r"""
|
571 |
-
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
572 |
-
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
573 |
-
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
574 |
-
|
575 |
-
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
576 |
-
"""
|
577 |
-
|
578 |
-
@functools.wraps(init)
|
579 |
-
def inner_init(self, *args, **kwargs):
|
580 |
-
# Ignore private kwargs in the init.
|
581 |
-
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
582 |
-
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
583 |
-
if not isinstance(self, ConfigMixin):
|
584 |
-
raise RuntimeError(
|
585 |
-
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
586 |
-
"not inherit from `ConfigMixin`."
|
587 |
-
)
|
588 |
-
|
589 |
-
ignore = getattr(self, "ignore_for_config", [])
|
590 |
-
# Get positional arguments aligned with kwargs
|
591 |
-
new_kwargs = {}
|
592 |
-
signature = inspect.signature(init)
|
593 |
-
parameters = {
|
594 |
-
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
|
595 |
-
}
|
596 |
-
for arg, name in zip(args, parameters.keys()):
|
597 |
-
new_kwargs[name] = arg
|
598 |
-
|
599 |
-
# Then add all kwargs
|
600 |
-
new_kwargs.update(
|
601 |
-
{
|
602 |
-
k: init_kwargs.get(k, default)
|
603 |
-
for k, default in parameters.items()
|
604 |
-
if k not in ignore and k not in new_kwargs
|
605 |
-
}
|
606 |
-
)
|
607 |
-
|
608 |
-
# Take note of the parameters that were not present in the loaded config
|
609 |
-
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
610 |
-
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
611 |
-
|
612 |
-
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
613 |
-
getattr(self, "register_to_config")(**new_kwargs)
|
614 |
-
init(self, *args, **init_kwargs)
|
615 |
-
|
616 |
-
return inner_init
|
617 |
-
|
618 |
-
|
619 |
-
def flax_register_to_config(cls):
|
620 |
-
original_init = cls.__init__
|
621 |
-
|
622 |
-
@functools.wraps(original_init)
|
623 |
-
def init(self, *args, **kwargs):
|
624 |
-
if not isinstance(self, ConfigMixin):
|
625 |
-
raise RuntimeError(
|
626 |
-
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
627 |
-
"not inherit from `ConfigMixin`."
|
628 |
-
)
|
629 |
-
|
630 |
-
# Ignore private kwargs in the init. Retrieve all passed attributes
|
631 |
-
init_kwargs = dict(kwargs.items())
|
632 |
-
|
633 |
-
# Retrieve default values
|
634 |
-
fields = dataclasses.fields(self)
|
635 |
-
default_kwargs = {}
|
636 |
-
for field in fields:
|
637 |
-
# ignore flax specific attributes
|
638 |
-
if field.name in self._flax_internal_args:
|
639 |
-
continue
|
640 |
-
if type(field.default) == dataclasses._MISSING_TYPE:
|
641 |
-
default_kwargs[field.name] = None
|
642 |
-
else:
|
643 |
-
default_kwargs[field.name] = getattr(self, field.name)
|
644 |
-
|
645 |
-
# Make sure init_kwargs override default kwargs
|
646 |
-
new_kwargs = {**default_kwargs, **init_kwargs}
|
647 |
-
# dtype should be part of `init_kwargs`, but not `new_kwargs`
|
648 |
-
if "dtype" in new_kwargs:
|
649 |
-
new_kwargs.pop("dtype")
|
650 |
-
|
651 |
-
# Get positional arguments aligned with kwargs
|
652 |
-
for i, arg in enumerate(args):
|
653 |
-
name = fields[i].name
|
654 |
-
new_kwargs[name] = arg
|
655 |
-
|
656 |
-
# Take note of the parameters that were not present in the loaded config
|
657 |
-
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
658 |
-
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
659 |
-
|
660 |
-
getattr(self, "register_to_config")(**new_kwargs)
|
661 |
-
original_init(self, *args, **kwargs)
|
662 |
-
|
663 |
-
cls.__init__ = init
|
664 |
-
return cls
|
|
|
|
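A minimal usage sketch of the `load_config` / `from_config` pair implemented in the deleted configuration_utils.py above, assuming the upstream `diffusers` package is installed; the model id is the one used in the docstring example:

from diffusers import DDPMScheduler

# load_config downloads (or reads locally) and parses the JSON config;
# return_unused_kwargs=True also returns the kwargs it did not consume.
config, unused_kwargs = DDPMScheduler.load_config(
    "google/ddpm-celebahq-256", subfolder="scheduler", return_unused_kwargs=True
)

# from_config instantiates the class from that dict; passing a model path here
# instead of a dict is the deprecated branch handled at the top of from_config.
scheduler = DDPMScheduler.from_config(config)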
4DoF/diffusers/dependency_versions_check.py
DELETED
@@ -1,47 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core


# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed

        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


def dep_version_check(pkg, hint=None):
    require_version(deps[pkg], hint)
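A minimal sketch of how the deleted helper is meant to be called, assuming the same module layout as upstream diffusers (diffusers/dependency_versions_check.py):

from diffusers.dependency_versions_check import dep_version_check

dep_version_check("torch")  # raises if the installed torch does not satisfy deps["torch"] ("torch>=1.4")
dep_version_check("transformers", hint="needed by most pipelines")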
4DoF/diffusers/dependency_versions_table.py
DELETED
@@ -1,44 +0,0 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "compel": "compel==0.1.8",
    "black": "black~=23.1",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.13.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2",
    "jaxlib": "jaxlib>=0.1.65",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "omegaconf": "omegaconf",
    "parameterized": "parameterized",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "ruff": "ruff>=0.0.241",
    "safetensors": "safetensors",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
    "urllib3": "urllib3<=2.0.0",
}
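The table above is plain data: it maps a dependency name to the pip requirement string consumed by the version checks. A small lookup sketch, assuming the upstream module path:

from diffusers.dependency_versions_table import deps

print(deps["huggingface-hub"])  # huggingface-hub>=0.13.2
print(deps["jax"])              # jax>=0.2.8,!=0.3.2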
4DoF/diffusers/experimental/__init__.py
DELETED
@@ -1 +0,0 @@
from .rl import ValueGuidedRLPipeline
4DoF/diffusers/experimental/rl/__init__.py
DELETED
@@ -1 +0,0 @@
from .value_guided_sampling import ValueGuidedRLPipeline
4DoF/diffusers/experimental/rl/value_guided_sampling.py
DELETED
@@ -1,152 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import tqdm

from ...models.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils import randn_tensor
from ...utils.dummy_pt_objects import DDPMScheduler


class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except:  # noqa: E722
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except:  # noqa: E722
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        if type(x_in) is dict:
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
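A hedged usage sketch for the deleted ValueGuidedRLPipeline, following its docstring (a Hopper-style gym environment exposing a D4RL-style get_dataset()). The environment id and checkpoint name below are placeholders, not values taken from this repository:

import gym  # plus d4rl, which registers the offline-RL environments

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "some-org/hopper-value-function",  # placeholder checkpoint id
    env=env,
)

obs = env.reset()
for _ in range(10):
    # returns the de-normalized action with the highest predicted value
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, info = env.step(action)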
4DoF/diffusers/image_processor.py
DELETED
@@ -1,366 +0,0 @@
|
|
1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import warnings
|
16 |
-
from typing import List, Optional, Union
|
17 |
-
|
18 |
-
import numpy as np
|
19 |
-
import PIL
|
20 |
-
import torch
|
21 |
-
from PIL import Image
|
22 |
-
|
23 |
-
from .configuration_utils import ConfigMixin, register_to_config
|
24 |
-
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
25 |
-
|
26 |
-
|
27 |
-
class VaeImageProcessor(ConfigMixin):
|
28 |
-
"""
|
29 |
-
Image processor for VAE.
|
30 |
-
|
31 |
-
Args:
|
32 |
-
do_resize (`bool`, *optional*, defaults to `True`):
|
33 |
-
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
34 |
-
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
35 |
-
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
36 |
-
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
37 |
-
resample (`str`, *optional*, defaults to `lanczos`):
|
38 |
-
Resampling filter to use when resizing the image.
|
39 |
-
do_normalize (`bool`, *optional*, defaults to `True`):
|
40 |
-
Whether to normalize the image to [-1,1].
|
41 |
-
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
42 |
-
Whether to convert the images to RGB format.
|
43 |
-
"""
|
44 |
-
|
45 |
-
config_name = CONFIG_NAME
|
46 |
-
|
47 |
-
@register_to_config
|
48 |
-
def __init__(
|
49 |
-
self,
|
50 |
-
do_resize: bool = True,
|
51 |
-
vae_scale_factor: int = 8,
|
52 |
-
resample: str = "lanczos",
|
53 |
-
do_normalize: bool = True,
|
54 |
-
do_convert_rgb: bool = False,
|
55 |
-
):
|
56 |
-
super().__init__()
|
57 |
-
|
58 |
-
@staticmethod
|
59 |
-
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
|
60 |
-
"""
|
61 |
-
Convert a numpy image or a batch of images to a PIL image.
|
62 |
-
"""
|
63 |
-
if images.ndim == 3:
|
64 |
-
images = images[None, ...]
|
65 |
-
images = (images * 255).round().astype("uint8")
|
66 |
-
if images.shape[-1] == 1:
|
67 |
-
# special case for grayscale (single channel) images
|
68 |
-
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
69 |
-
else:
|
70 |
-
pil_images = [Image.fromarray(image) for image in images]
|
71 |
-
|
72 |
-
return pil_images
|
73 |
-
|
74 |
-
@staticmethod
|
75 |
-
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
76 |
-
"""
|
77 |
-
Convert a PIL image or a list of PIL images to NumPy arrays.
|
78 |
-
"""
|
79 |
-
if not isinstance(images, list):
|
80 |
-
images = [images]
|
81 |
-
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
82 |
-
images = np.stack(images, axis=0)
|
83 |
-
|
84 |
-
return images
|
85 |
-
|
86 |
-
@staticmethod
|
87 |
-
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
|
88 |
-
"""
|
89 |
-
Convert a NumPy image to a PyTorch tensor.
|
90 |
-
"""
|
91 |
-
if images.ndim == 3:
|
92 |
-
images = images[..., None]
|
93 |
-
|
94 |
-
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
95 |
-
return images
|
96 |
-
|
97 |
-
@staticmethod
|
98 |
-
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
|
99 |
-
"""
|
100 |
-
Convert a PyTorch tensor to a NumPy image.
|
101 |
-
"""
|
102 |
-
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
103 |
-
return images
|
104 |
-
|
105 |
-
@staticmethod
|
106 |
-
def normalize(images):
|
107 |
-
"""
|
108 |
-
Normalize an image array to [-1,1].
|
109 |
-
"""
|
110 |
-
return 2.0 * images - 1.0
|
111 |
-
|
112 |
-
@staticmethod
|
113 |
-
def denormalize(images):
|
114 |
-
"""
|
115 |
-
Denormalize an image array to [0,1].
|
116 |
-
"""
|
117 |
-
return (images / 2 + 0.5).clamp(0, 1)
|
118 |
-
|
119 |
-
@staticmethod
|
120 |
-
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
121 |
-
"""
|
122 |
-
Converts an image to RGB format.
|
123 |
-
"""
|
124 |
-
image = image.convert("RGB")
|
125 |
-
return image
|
126 |
-
|
127 |
-
def resize(
|
128 |
-
self,
|
129 |
-
image: PIL.Image.Image,
|
130 |
-
height: Optional[int] = None,
|
131 |
-
width: Optional[int] = None,
|
132 |
-
) -> PIL.Image.Image:
|
133 |
-
"""
|
134 |
-
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
|
135 |
-
"""
|
136 |
-
if height is None:
|
137 |
-
height = image.height
|
138 |
-
if width is None:
|
139 |
-
width = image.width
|
140 |
-
|
141 |
-
width, height = (
|
142 |
-
x - x % self.config.vae_scale_factor for x in (width, height)
|
143 |
-
) # resize to integer multiple of vae_scale_factor
|
144 |
-
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
145 |
-
return image
|
146 |
-
|
147 |
-
def preprocess(
|
148 |
-
self,
|
149 |
-
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
150 |
-
height: Optional[int] = None,
|
151 |
-
width: Optional[int] = None,
|
152 |
-
) -> torch.Tensor:
|
153 |
-
"""
|
154 |
-
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
155 |
-
"""
|
156 |
-
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
157 |
-
if isinstance(image, supported_formats):
|
158 |
-
image = [image]
|
159 |
-
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
|
160 |
-
raise ValueError(
|
161 |
-
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
|
162 |
-
)
|
163 |
-
|
164 |
-
if isinstance(image[0], PIL.Image.Image):
|
165 |
-
if self.config.do_convert_rgb:
|
166 |
-
image = [self.convert_to_rgb(i) for i in image]
|
167 |
-
if self.config.do_resize:
|
168 |
-
image = [self.resize(i, height, width) for i in image]
|
169 |
-
image = self.pil_to_numpy(image) # to np
|
170 |
-
image = self.numpy_to_pt(image) # to pt
|
171 |
-
|
172 |
-
elif isinstance(image[0], np.ndarray):
|
173 |
-
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
174 |
-
image = self.numpy_to_pt(image)
|
175 |
-
_, _, height, width = image.shape
|
176 |
-
if self.config.do_resize and (
|
177 |
-
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
178 |
-
):
|
179 |
-
raise ValueError(
|
180 |
-
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
|
181 |
-
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
182 |
-
)
|
183 |
-
|
184 |
-
elif isinstance(image[0], torch.Tensor):
|
185 |
-
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
186 |
-
_, channel, height, width = image.shape
|
187 |
-
|
188 |
-
# don't need any preprocess if the image is latents
|
189 |
-
if channel == 4:
|
190 |
-
return image
|
191 |
-
|
192 |
-
if self.config.do_resize and (
|
193 |
-
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
194 |
-
):
|
195 |
-
raise ValueError(
|
196 |
-
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
|
197 |
-
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
198 |
-
)
|
199 |
-
|
200 |
-
# expected range [0,1], normalize to [-1,1]
|
201 |
-
do_normalize = self.config.do_normalize
|
202 |
-
if image.min() < 0:
|
203 |
-
warnings.warn(
|
204 |
-
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
205 |
-
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
206 |
-
FutureWarning,
|
207 |
-
)
|
208 |
-
do_normalize = False
|
209 |
-
|
210 |
-
if do_normalize:
|
211 |
-
image = self.normalize(image)
|
212 |
-
|
213 |
-
return image
|
214 |
-
|
215 |
-
def postprocess(
|
216 |
-
self,
|
217 |
-
image: torch.FloatTensor,
|
218 |
-
output_type: str = "pil",
|
219 |
-
do_denormalize: Optional[List[bool]] = None,
|
220 |
-
):
|
221 |
-
if not isinstance(image, torch.Tensor):
|
222 |
-
raise ValueError(
|
223 |
-
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
224 |
-
)
|
225 |
-
if output_type not in ["latent", "pt", "np", "pil"]:
|
226 |
-
deprecation_message = (
|
227 |
-
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
228 |
-
"`pil`, `np`, `pt`, `latent`"
|
229 |
-
)
|
230 |
-
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
231 |
-
output_type = "np"
|
232 |
-
|
233 |
-
if output_type == "latent":
|
234 |
-
return image
|
235 |
-
|
236 |
-
if do_denormalize is None:
|
237 |
-
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
238 |
-
|
239 |
-
image = torch.stack(
|
240 |
-
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
241 |
-
)
|
242 |
-
|
243 |
-
if output_type == "pt":
|
244 |
-
return image
|
245 |
-
|
246 |
-
image = self.pt_to_numpy(image)
|
247 |
-
|
248 |
-
if output_type == "np":
|
249 |
-
return image
|
250 |
-
|
251 |
-
if output_type == "pil":
|
252 |
-
return self.numpy_to_pil(image)
|
253 |
-
|
254 |
-
|
255 |
-
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
256 |
-
"""
|
257 |
-
Image processor for VAE LDM3D.
|
258 |
-
|
259 |
-
Args:
|
260 |
-
do_resize (`bool`, *optional*, defaults to `True`):
|
261 |
-
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
262 |
-
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
263 |
-
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
264 |
-
resample (`str`, *optional*, defaults to `lanczos`):
|
265 |
-
Resampling filter to use when resizing the image.
|
266 |
-
do_normalize (`bool`, *optional*, defaults to `True`):
|
267 |
-
Whether to normalize the image to [-1,1].
|
268 |
-
"""
|
269 |
-
|
270 |
-
config_name = CONFIG_NAME
|
271 |
-
|
272 |
-
@register_to_config
|
273 |
-
def __init__(
|
274 |
-
self,
|
275 |
-
do_resize: bool = True,
|
276 |
-
vae_scale_factor: int = 8,
|
277 |
-
resample: str = "lanczos",
|
278 |
-
do_normalize: bool = True,
|
279 |
-
):
|
280 |
-
super().__init__()
|
281 |
-
|
282 |
-
@staticmethod
|
283 |
-
def numpy_to_pil(images):
|
284 |
-
"""
|
285 |
-
Convert a NumPy image or a batch of images to a PIL image.
|
286 |
-
"""
|
287 |
-
if images.ndim == 3:
|
288 |
-
images = images[None, ...]
|
289 |
-
images = (images * 255).round().astype("uint8")
|
290 |
-
if images.shape[-1] == 1:
|
291 |
-
# special case for grayscale (single channel) images
|
292 |
-
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
293 |
-
else:
|
294 |
-
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
295 |
-
|
296 |
-
return pil_images
|
297 |
-
|
298 |
-
@staticmethod
|
299 |
-
def rgblike_to_depthmap(image):
|
300 |
-
"""
|
301 |
-
Args:
|
302 |
-
image: RGB-like depth image
|
303 |
-
|
304 |
-
Returns: depth map
|
305 |
-
|
306 |
-
"""
|
307 |
-
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
308 |
-
|
309 |
-
def numpy_to_depth(self, images):
|
310 |
-
"""
|
311 |
-
Convert a NumPy depth image or a batch of images to a PIL image.
|
312 |
-
"""
|
313 |
-
if images.ndim == 3:
|
314 |
-
images = images[None, ...]
|
315 |
-
images_depth = images[:, :, :, 3:]
|
316 |
-
if images.shape[-1] == 6:
|
317 |
-
images_depth = (images_depth * 255).round().astype("uint8")
|
318 |
-
pil_images = [
|
319 |
-
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
|
320 |
-
]
|
321 |
-
elif images.shape[-1] == 4:
|
322 |
-
images_depth = (images_depth * 65535.0).astype(np.uint16)
|
323 |
-
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
|
324 |
-
else:
|
325 |
-
raise Exception("Not supported")
|
326 |
-
|
327 |
-
return pil_images
|
328 |
-
|
329 |
-
def postprocess(
|
330 |
-
self,
|
331 |
-
image: torch.FloatTensor,
|
332 |
-
output_type: str = "pil",
|
333 |
-
do_denormalize: Optional[List[bool]] = None,
|
334 |
-
):
|
335 |
-
if not isinstance(image, torch.Tensor):
|
336 |
-
raise ValueError(
|
337 |
-
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
338 |
-
)
|
339 |
-
if output_type not in ["latent", "pt", "np", "pil"]:
|
340 |
-
deprecation_message = (
|
341 |
-
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
342 |
-
"`pil`, `np`, `pt`, `latent`"
|
343 |
-
)
|
344 |
-
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
345 |
-
output_type = "np"
|
346 |
-
|
347 |
-
if do_denormalize is None:
|
348 |
-
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
349 |
-
|
350 |
-
image = torch.stack(
|
351 |
-
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
352 |
-
)
|
353 |
-
|
354 |
-
image = self.pt_to_numpy(image)
|
355 |
-
|
356 |
-
if output_type == "np":
|
357 |
-
if image.shape[-1] == 6:
|
358 |
-
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
|
359 |
-
else:
|
360 |
-
image_depth = image[:, :, :, 3:]
|
361 |
-
return image[:, :, :, :3], image_depth
|
362 |
-
|
363 |
-
if output_type == "pil":
|
364 |
-
return self.numpy_to_pil(image), self.numpy_to_depth(image)
|
365 |
-
else:
|
366 |
-
raise Exception(f"This type {output_type} is not supported")
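A short sketch of the round trip the deleted VaeImageProcessor implements (PIL image in, [-1, 1] tensor out, and back), assuming the upstream diffusers layout; the file paths are placeholders:

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

image = PIL.Image.open("input.png").convert("RGB")   # placeholder path
pixel_values = processor.preprocess(image)           # torch.FloatTensor, NCHW, values in [-1, 1]

# ... a VAE / diffusion model would consume `pixel_values` here ...

restored = processor.postprocess(pixel_values, output_type="pil")[0]
restored.save("output.png")                          # placeholder path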
4DoF/diffusers/loaders.py
DELETED
@@ -1,1492 +0,0 @@
|
|
1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
import os
|
15 |
-
import warnings
|
16 |
-
from collections import defaultdict
|
17 |
-
from pathlib import Path
|
18 |
-
from typing import Callable, Dict, List, Optional, Union
|
19 |
-
|
20 |
-
import torch
|
21 |
-
import torch.nn.functional as F
|
22 |
-
from huggingface_hub import hf_hub_download
|
23 |
-
|
24 |
-
from .models.attention_processor import (
|
25 |
-
AttnAddedKVProcessor,
|
26 |
-
AttnAddedKVProcessor2_0,
|
27 |
-
CustomDiffusionAttnProcessor,
|
28 |
-
CustomDiffusionXFormersAttnProcessor,
|
29 |
-
LoRAAttnAddedKVProcessor,
|
30 |
-
LoRAAttnProcessor,
|
31 |
-
LoRAAttnProcessor2_0,
|
32 |
-
LoRAXFormersAttnProcessor,
|
33 |
-
SlicedAttnAddedKVProcessor,
|
34 |
-
XFormersAttnProcessor,
|
35 |
-
)
|
36 |
-
from .utils import (
|
37 |
-
DIFFUSERS_CACHE,
|
38 |
-
HF_HUB_OFFLINE,
|
39 |
-
TEXT_ENCODER_ATTN_MODULE,
|
40 |
-
_get_model_file,
|
41 |
-
deprecate,
|
42 |
-
is_safetensors_available,
|
43 |
-
is_transformers_available,
|
44 |
-
logging,
|
45 |
-
)
|
46 |
-
|
47 |
-
|
48 |
-
if is_safetensors_available():
|
49 |
-
import safetensors
|
50 |
-
|
51 |
-
if is_transformers_available():
|
52 |
-
from transformers import PreTrainedModel, PreTrainedTokenizer
|
53 |
-
|
54 |
-
|
55 |
-
logger = logging.get_logger(__name__)
|
56 |
-
|
57 |
-
TEXT_ENCODER_NAME = "text_encoder"
|
58 |
-
UNET_NAME = "unet"
|
59 |
-
|
60 |
-
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
61 |
-
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
62 |
-
|
63 |
-
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
64 |
-
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
65 |
-
|
66 |
-
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
67 |
-
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
68 |
-
|
69 |
-
|
70 |
-
class AttnProcsLayers(torch.nn.Module):
|
71 |
-
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
72 |
-
super().__init__()
|
73 |
-
self.layers = torch.nn.ModuleList(state_dict.values())
|
74 |
-
self.mapping = dict(enumerate(state_dict.keys()))
|
75 |
-
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
|
76 |
-
|
77 |
-
# .processor for unet, .self_attn for text encoder
|
78 |
-
self.split_keys = [".processor", ".self_attn"]
|
79 |
-
|
80 |
-
# we add a hook to state_dict() and load_state_dict() so that the
|
81 |
-
# naming fits with `unet.attn_processors`
|
82 |
-
def map_to(module, state_dict, *args, **kwargs):
|
83 |
-
new_state_dict = {}
|
84 |
-
for key, value in state_dict.items():
|
85 |
-
num = int(key.split(".")[1]) # 0 is always "layers"
|
86 |
-
new_key = key.replace(f"layers.{num}", module.mapping[num])
|
87 |
-
new_state_dict[new_key] = value
|
88 |
-
|
89 |
-
return new_state_dict
|
90 |
-
|
91 |
-
def remap_key(key, state_dict):
|
92 |
-
for k in self.split_keys:
|
93 |
-
if k in key:
|
94 |
-
return key.split(k)[0] + k
|
95 |
-
|
96 |
-
raise ValueError(
|
97 |
-
f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
|
98 |
-
)
|
99 |
-
|
100 |
-
def map_from(module, state_dict, *args, **kwargs):
|
101 |
-
all_keys = list(state_dict.keys())
|
102 |
-
for key in all_keys:
|
103 |
-
replace_key = remap_key(key, state_dict)
|
104 |
-
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
|
105 |
-
state_dict[new_key] = state_dict[key]
|
106 |
-
del state_dict[key]
|
107 |
-
|
108 |
-
self._register_state_dict_hook(map_to)
|
109 |
-
self._register_load_state_dict_pre_hook(map_from, with_module=True)
|
110 |
-
|
111 |
-
|
112 |
-
class UNet2DConditionLoadersMixin:
|
113 |
-
text_encoder_name = TEXT_ENCODER_NAME
|
114 |
-
unet_name = UNET_NAME
|
115 |
-
|
116 |
-
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
117 |
-
r"""
|
118 |
-
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
|
119 |
-
defined in
|
120 |
-
[`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
|
121 |
-
and be a `torch.nn.Module` class.
|
122 |
-
|
123 |
-
Parameters:
|
124 |
-
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
125 |
-
Can be either:
|
126 |
-
|
127 |
-
- A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
128 |
-
the Hub.
|
129 |
-
- A path to a directory (for example `./my_model_directory`) containing the model weights saved
|
130 |
-
with [`ModelMixin.save_pretrained`].
|
131 |
-
- A [torch state
|
132 |
-
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
133 |
-
|
134 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
135 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
136 |
-
is not used.
|
137 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
138 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
139 |
-
cached versions if they exist.
|
140 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
141 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
142 |
-
incompletely downloaded files are deleted.
|
143 |
-
proxies (`Dict[str, str]`, *optional*):
|
144 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
145 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
146 |
-
local_files_only (`bool`, *optional*, defaults to `False`):
|
147 |
-
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
148 |
-
won't be downloaded from the Hub.
|
149 |
-
use_auth_token (`str` or *bool*, *optional*):
|
150 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
151 |
-
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
152 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
153 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
154 |
-
allowed by Git.
|
155 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
156 |
-
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
157 |
-
mirror (`str`, *optional*):
|
158 |
-
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
159 |
-
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
160 |
-
information.
|
161 |
-
|
162 |
-
"""
|
163 |
-
|
164 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
165 |
-
force_download = kwargs.pop("force_download", False)
|
166 |
-
resume_download = kwargs.pop("resume_download", False)
|
167 |
-
proxies = kwargs.pop("proxies", None)
|
168 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
169 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
170 |
-
revision = kwargs.pop("revision", None)
|
171 |
-
subfolder = kwargs.pop("subfolder", None)
|
172 |
-
weight_name = kwargs.pop("weight_name", None)
|
173 |
-
use_safetensors = kwargs.pop("use_safetensors", None)
|
174 |
-
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
175 |
-
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
176 |
-
network_alpha = kwargs.pop("network_alpha", None)
|
177 |
-
|
178 |
-
if use_safetensors and not is_safetensors_available():
|
179 |
-
raise ValueError(
|
180 |
-
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
181 |
-
)
|
182 |
-
|
183 |
-
allow_pickle = False
|
184 |
-
if use_safetensors is None:
|
185 |
-
use_safetensors = is_safetensors_available()
|
186 |
-
allow_pickle = True
|
187 |
-
|
188 |
-
user_agent = {
|
189 |
-
"file_type": "attn_procs_weights",
|
190 |
-
"framework": "pytorch",
|
191 |
-
}
|
192 |
-
|
193 |
-
model_file = None
|
194 |
-
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
195 |
-
# Let's first try to load .safetensors weights
|
196 |
-
if (use_safetensors and weight_name is None) or (
|
197 |
-
weight_name is not None and weight_name.endswith(".safetensors")
|
198 |
-
):
|
199 |
-
try:
|
200 |
-
model_file = _get_model_file(
|
201 |
-
pretrained_model_name_or_path_or_dict,
|
202 |
-
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
203 |
-
cache_dir=cache_dir,
|
204 |
-
force_download=force_download,
|
205 |
-
resume_download=resume_download,
|
206 |
-
proxies=proxies,
|
207 |
-
local_files_only=local_files_only,
|
208 |
-
use_auth_token=use_auth_token,
|
209 |
-
revision=revision,
|
210 |
-
subfolder=subfolder,
|
211 |
-
user_agent=user_agent,
|
212 |
-
)
|
213 |
-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
214 |
-
except IOError as e:
|
215 |
-
if not allow_pickle:
|
216 |
-
raise e
|
217 |
-
# try loading non-safetensors weights
|
218 |
-
pass
|
219 |
-
if model_file is None:
|
220 |
-
model_file = _get_model_file(
|
221 |
-
pretrained_model_name_or_path_or_dict,
|
222 |
-
weights_name=weight_name or LORA_WEIGHT_NAME,
|
223 |
-
cache_dir=cache_dir,
|
224 |
-
force_download=force_download,
|
225 |
-
resume_download=resume_download,
|
226 |
-
proxies=proxies,
|
227 |
-
local_files_only=local_files_only,
|
228 |
-
use_auth_token=use_auth_token,
|
229 |
-
revision=revision,
|
230 |
-
subfolder=subfolder,
|
231 |
-
user_agent=user_agent,
|
232 |
-
)
|
233 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
234 |
-
else:
|
235 |
-
state_dict = pretrained_model_name_or_path_or_dict
|
236 |
-
|
237 |
-
# fill attn processors
|
238 |
-
attn_processors = {}
|
239 |
-
|
240 |
-
is_lora = all("lora" in k for k in state_dict.keys())
|
241 |
-
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
242 |
-
|
243 |
-
if is_lora:
|
244 |
-
is_new_lora_format = all(
|
245 |
-
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
246 |
-
)
|
247 |
-
if is_new_lora_format:
|
248 |
-
# Strip the `"unet"` prefix.
|
249 |
-
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
250 |
-
if is_text_encoder_present:
|
251 |
-
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
252 |
-
warnings.warn(warn_message)
|
253 |
-
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
254 |
-
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
255 |
-
|
256 |
-
lora_grouped_dict = defaultdict(dict)
|
257 |
-
for key, value in state_dict.items():
|
258 |
-
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
259 |
-
lora_grouped_dict[attn_processor_key][sub_key] = value
|
260 |
-
|
261 |
-
for key, value_dict in lora_grouped_dict.items():
|
262 |
-
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
263 |
-
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
264 |
-
|
265 |
-
attn_processor = self
|
266 |
-
for sub_key in key.split("."):
|
267 |
-
attn_processor = getattr(attn_processor, sub_key)
|
268 |
-
|
269 |
-
if isinstance(
|
270 |
-
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
|
271 |
-
):
|
272 |
-
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
|
273 |
-
attn_processor_class = LoRAAttnAddedKVProcessor
|
274 |
-
else:
|
275 |
-
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
276 |
-
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
277 |
-
attn_processor_class = LoRAXFormersAttnProcessor
|
278 |
-
else:
|
279 |
-
attn_processor_class = (
|
280 |
-
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
281 |
-
)
|
282 |
-
|
283 |
-
attn_processors[key] = attn_processor_class(
|
284 |
-
hidden_size=hidden_size,
|
285 |
-
cross_attention_dim=cross_attention_dim,
|
286 |
-
rank=rank,
|
287 |
-
network_alpha=network_alpha,
|
288 |
-
)
|
289 |
-
attn_processors[key].load_state_dict(value_dict)
|
290 |
-
elif is_custom_diffusion:
|
291 |
-
custom_diffusion_grouped_dict = defaultdict(dict)
|
292 |
-
for key, value in state_dict.items():
|
293 |
-
if len(value) == 0:
|
294 |
-
custom_diffusion_grouped_dict[key] = {}
|
295 |
-
else:
|
296 |
-
if "to_out" in key:
|
297 |
-
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
298 |
-
else:
|
299 |
-
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
300 |
-
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
301 |
-
|
302 |
-
for key, value_dict in custom_diffusion_grouped_dict.items():
|
303 |
-
if len(value_dict) == 0:
|
304 |
-
attn_processors[key] = CustomDiffusionAttnProcessor(
|
305 |
-
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
306 |
-
)
|
307 |
-
else:
|
308 |
-
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
309 |
-
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
310 |
-
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
311 |
-
attn_processors[key] = CustomDiffusionAttnProcessor(
|
312 |
-
train_kv=True,
|
313 |
-
train_q_out=train_q_out,
|
314 |
-
hidden_size=hidden_size,
|
315 |
-
cross_attention_dim=cross_attention_dim,
|
316 |
-
)
|
317 |
-
attn_processors[key].load_state_dict(value_dict)
|
318 |
-
else:
|
319 |
-
raise ValueError(
|
320 |
-
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
321 |
-
)
|
322 |
-
|
323 |
-
# set correct dtype & device
|
324 |
-
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
325 |
-
|
326 |
-
# set layers
|
327 |
-
self.set_attn_processor(attn_processors)
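A minimal usage sketch for `load_attn_procs`, assuming a Stable Diffusion pipeline; the LoRA source path below is illustrative, not a real repository:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# The pipeline's UNet mixes in UNet2DConditionLoadersMixin, so LoRA or Custom
# Diffusion attention processors can be loaded onto it from a Hub repo id, a
# local directory, or an in-memory state dict.
pipe.unet.load_attn_procs("path/to/lora-attn-procs")  # illustrative source

image = pipe("a prompt in the style the processors were trained on").images[0]
```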
|
328 |
-
|
329 |
-
def save_attn_procs(
|
330 |
-
self,
|
331 |
-
save_directory: Union[str, os.PathLike],
|
332 |
-
is_main_process: bool = True,
|
333 |
-
weight_name: str = None,
|
334 |
-
save_function: Callable = None,
|
335 |
-
safe_serialization: bool = False,
|
336 |
-
**kwargs,
|
337 |
-
):
|
338 |
-
r"""
|
339 |
-
Save an attention processor to a directory so that it can be reloaded using the
|
340 |
-
[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
|
341 |
-
|
342 |
-
Arguments:
|
343 |
-
save_directory (`str` or `os.PathLike`):
|
344 |
-
Directory to save an attention processor to. Will be created if it doesn't exist.
|
345 |
-
is_main_process (`bool`, *optional*, defaults to `True`):
|
346 |
-
Whether the process calling this is the main process or not. Useful during distributed training when you
|
347 |
-
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
348 |
-
process to avoid race conditions.
|
349 |
-
save_function (`Callable`):
|
350 |
-
The function to use to save the state dictionary. Useful during distributed training when you need to
|
351 |
-
replace `torch.save` with another method. Can be configured with the environment variable
|
352 |
-
`DIFFUSERS_SAVE_MODE`.
|
353 |
-
|
354 |
-
"""
|
355 |
-
weight_name = weight_name or deprecate(
|
356 |
-
"weights_name",
|
357 |
-
"0.20.0",
|
358 |
-
"`weights_name` is deprecated, please use `weight_name` instead.",
|
359 |
-
take_from=kwargs,
|
360 |
-
)
|
361 |
-
if os.path.isfile(save_directory):
|
362 |
-
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
363 |
-
return
|
364 |
-
|
365 |
-
if save_function is None:
|
366 |
-
if safe_serialization:
|
367 |
-
|
368 |
-
def save_function(weights, filename):
|
369 |
-
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
370 |
-
|
371 |
-
else:
|
372 |
-
save_function = torch.save
|
373 |
-
|
374 |
-
os.makedirs(save_directory, exist_ok=True)
|
375 |
-
|
376 |
-
is_custom_diffusion = any(
|
377 |
-
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
378 |
-
for (_, x) in self.attn_processors.items()
|
379 |
-
)
|
380 |
-
if is_custom_diffusion:
|
381 |
-
model_to_save = AttnProcsLayers(
|
382 |
-
{
|
383 |
-
y: x
|
384 |
-
for (y, x) in self.attn_processors.items()
|
385 |
-
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
386 |
-
}
|
387 |
-
)
|
388 |
-
state_dict = model_to_save.state_dict()
|
389 |
-
for name, attn in self.attn_processors.items():
|
390 |
-
if len(attn.state_dict()) == 0:
|
391 |
-
state_dict[name] = {}
|
392 |
-
else:
|
393 |
-
model_to_save = AttnProcsLayers(self.attn_processors)
|
394 |
-
state_dict = model_to_save.state_dict()
|
395 |
-
|
396 |
-
if weight_name is None:
|
397 |
-
if safe_serialization:
|
398 |
-
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
399 |
-
else:
|
400 |
-
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
401 |
-
|
402 |
-
# Save the model
|
403 |
-
save_function(state_dict, os.path.join(save_directory, weight_name))
|
404 |
-
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
405 |
-
|
406 |
-
|
407 |
-
class TextualInversionLoaderMixin:
|
408 |
-
r"""
|
409 |
-
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
|
410 |
-
"""
|
411 |
-
|
412 |
-
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
|
413 |
-
r"""
|
414 |
-
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
415 |
-
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
416 |
-
inversion token or if the textual inversion token is a single vector, the input prompt is returned.
|
417 |
-
|
418 |
-
Parameters:
|
419 |
-
prompt (`str` or list of `str`):
|
420 |
-
The prompt or prompts to guide the image generation.
|
421 |
-
tokenizer (`PreTrainedTokenizer`):
|
422 |
-
The tokenizer responsible for encoding the prompt into input tokens.
|
423 |
-
|
424 |
-
Returns:
|
425 |
-
`str` or list of `str`: The converted prompt
|
426 |
-
"""
|
427 |
-
if not isinstance(prompt, List):
|
428 |
-
prompts = [prompt]
|
429 |
-
else:
|
430 |
-
prompts = prompt
|
431 |
-
|
432 |
-
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
|
433 |
-
|
434 |
-
if not isinstance(prompt, List):
|
435 |
-
return prompts[0]
|
436 |
-
|
437 |
-
return prompts
|
438 |
-
|
439 |
-
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
|
440 |
-
r"""
|
441 |
-
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
|
442 |
-
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
|
443 |
-
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
444 |
-
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
|
445 |
-
|
446 |
-
Parameters:
|
447 |
-
prompt (`str`):
|
448 |
-
The prompt to guide the image generation.
|
449 |
-
tokenizer (`PreTrainedTokenizer`):
|
450 |
-
The tokenizer responsible for encoding the prompt into input tokens.
|
451 |
-
|
452 |
-
Returns:
|
453 |
-
`str`: The converted prompt
|
454 |
-
"""
|
455 |
-
tokens = tokenizer.tokenize(prompt)
|
456 |
-
unique_tokens = set(tokens)
|
457 |
-
for token in unique_tokens:
|
458 |
-
if token in tokenizer.added_tokens_encoder:
|
459 |
-
replacement = token
|
460 |
-
i = 1
|
461 |
-
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
462 |
-
replacement += f" {token}_{i}"
|
463 |
-
i += 1
|
464 |
-
|
465 |
-
prompt = prompt.replace(token, replacement)
|
466 |
-
|
467 |
-
return prompt
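A self-contained toy version of the expansion performed above; `added` is a hypothetical stand-in for `tokenizer.added_tokens_encoder` after a three-vector embedding has been registered, and the whitespace split stands in for real tokenization:

```py
# Stand-in for tokenizer.added_tokens_encoder after loading a 3-vector embedding.
added = {"<cat-toy>": 49408, "<cat-toy>_1": 49409, "<cat-toy>_2": 49410}


def expand(prompt: str, added: dict) -> str:
    # Mirror the replacement loop: append "_1", "_2", ... while such tokens exist.
    for token in set(prompt.split()):
        if token in added:
            replacement = token
            i = 1
            while f"{token}_{i}" in added:
                replacement += f" {token}_{i}"
                i += 1
            prompt = prompt.replace(token, replacement)
    return prompt


print(expand("A <cat-toy> backpack", added))
# A <cat-toy> <cat-toy>_1 <cat-toy>_2 backpack
```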
|
468 |
-
|
469 |
-
def load_textual_inversion(
|
470 |
-
self,
|
471 |
-
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
472 |
-
token: Optional[Union[str, List[str]]] = None,
|
473 |
-
**kwargs,
|
474 |
-
):
|
475 |
-
r"""
|
476 |
-
Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
|
477 |
-
Automatic1111 formats are supported).
|
478 |
-
|
479 |
-
Parameters:
|
480 |
-
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
481 |
-
Can be either one of the following or a list of them:
|
482 |
-
|
483 |
-
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
|
484 |
-
pretrained model hosted on the Hub.
|
485 |
-
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
|
486 |
-
inversion weights.
|
487 |
-
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
|
488 |
-
- A [torch state
|
489 |
-
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
490 |
-
|
491 |
-
token (`str` or `List[str]`, *optional*):
|
492 |
-
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
493 |
-
list, then `token` must also be a list of equal length.
|
494 |
-
weight_name (`str`, *optional*):
|
495 |
-
Name of a custom weight file. This should be used when:
|
496 |
-
|
497 |
-
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
|
498 |
-
name such as `text_inv.bin`.
|
499 |
-
- The saved textual inversion file is in the Automatic1111 format.
|
500 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
501 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
502 |
-
is not used.
|
503 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
504 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
505 |
-
cached versions if they exist.
|
506 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
507 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
508 |
-
incompletely downloaded files are deleted.
|
509 |
-
proxies (`Dict[str, str]`, *optional*):
|
510 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
511 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
512 |
-
local_files_only (`bool`, *optional*, defaults to `False`):
|
513 |
-
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
514 |
-
won't be downloaded from the Hub.
|
515 |
-
use_auth_token (`str` or *bool*, *optional*):
|
516 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
517 |
-
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
518 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
519 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
520 |
-
allowed by Git.
|
521 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
522 |
-
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
523 |
-
mirror (`str`, *optional*):
|
524 |
-
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
525 |
-
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
526 |
-
information.
|
527 |
-
|
528 |
-
Example:
|
529 |
-
|
530 |
-
To load a textual inversion embedding vector in 🤗 Diffusers format:
|
531 |
-
|
532 |
-
```py
|
533 |
-
from diffusers import StableDiffusionPipeline
|
534 |
-
import torch
|
535 |
-
|
536 |
-
model_id = "runwayml/stable-diffusion-v1-5"
|
537 |
-
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
538 |
-
|
539 |
-
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
|
540 |
-
|
541 |
-
prompt = "A <cat-toy> backpack"
|
542 |
-
|
543 |
-
image = pipe(prompt, num_inference_steps=50).images[0]
|
544 |
-
image.save("cat-backpack.png")
|
545 |
-
```
|
546 |
-
|
547 |
-
To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
|
548 |
-
(for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
|
549 |
-
locally:
|
550 |
-
|
551 |
-
```py
|
552 |
-
from diffusers import StableDiffusionPipeline
|
553 |
-
import torch
|
554 |
-
|
555 |
-
model_id = "runwayml/stable-diffusion-v1-5"
|
556 |
-
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
557 |
-
|
558 |
-
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
|
559 |
-
|
560 |
-
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
|
561 |
-
|
562 |
-
image = pipe(prompt, num_inference_steps=50).images[0]
|
563 |
-
image.save("character.png")
|
564 |
-
```
|
565 |
-
|
566 |
-
"""
|
567 |
-
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
|
568 |
-
raise ValueError(
|
569 |
-
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
|
570 |
-
f" `{self.load_textual_inversion.__name__}`"
|
571 |
-
)
|
572 |
-
|
573 |
-
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
|
574 |
-
raise ValueError(
|
575 |
-
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
|
576 |
-
f" `{self.load_textual_inversion.__name__}`"
|
577 |
-
)
|
578 |
-
|
579 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
580 |
-
force_download = kwargs.pop("force_download", False)
|
581 |
-
resume_download = kwargs.pop("resume_download", False)
|
582 |
-
proxies = kwargs.pop("proxies", None)
|
583 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
584 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
585 |
-
revision = kwargs.pop("revision", None)
|
586 |
-
subfolder = kwargs.pop("subfolder", None)
|
587 |
-
weight_name = kwargs.pop("weight_name", None)
|
588 |
-
use_safetensors = kwargs.pop("use_safetensors", None)
|
589 |
-
|
590 |
-
if use_safetensors and not is_safetensors_available():
|
591 |
-
raise ValueError(
|
592 |
-
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
593 |
-
)
|
594 |
-
|
595 |
-
allow_pickle = False
|
596 |
-
if use_safetensors is None:
|
597 |
-
use_safetensors = is_safetensors_available()
|
598 |
-
allow_pickle = True
|
599 |
-
|
600 |
-
user_agent = {
|
601 |
-
"file_type": "text_inversion",
|
602 |
-
"framework": "pytorch",
|
603 |
-
}
|
604 |
-
|
605 |
-
if not isinstance(pretrained_model_name_or_path, list):
|
606 |
-
pretrained_model_name_or_paths = [pretrained_model_name_or_path]
|
607 |
-
else:
|
608 |
-
pretrained_model_name_or_paths = pretrained_model_name_or_path
|
609 |
-
|
610 |
-
if isinstance(token, str):
|
611 |
-
tokens = [token]
|
612 |
-
elif token is None:
|
613 |
-
tokens = [None] * len(pretrained_model_name_or_paths)
|
614 |
-
else:
|
615 |
-
tokens = token
|
616 |
-
|
617 |
-
if len(pretrained_model_name_or_paths) != len(tokens):
|
618 |
-
raise ValueError(
|
619 |
-
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
|
620 |
-
f"Make sure both lists have the same length."
|
621 |
-
)
|
622 |
-
|
623 |
-
valid_tokens = [t for t in tokens if t is not None]
|
624 |
-
if len(set(valid_tokens)) < len(valid_tokens):
|
625 |
-
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
|
626 |
-
|
627 |
-
token_ids_and_embeddings = []
|
628 |
-
|
629 |
-
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
|
630 |
-
if not isinstance(pretrained_model_name_or_path, dict):
|
631 |
-
# 1. Load textual inversion file
|
632 |
-
model_file = None
|
633 |
-
# Let's first try to load .safetensors weights
|
634 |
-
if (use_safetensors and weight_name is None) or (
|
635 |
-
weight_name is not None and weight_name.endswith(".safetensors")
|
636 |
-
):
|
637 |
-
try:
|
638 |
-
model_file = _get_model_file(
|
639 |
-
pretrained_model_name_or_path,
|
640 |
-
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
|
641 |
-
cache_dir=cache_dir,
|
642 |
-
force_download=force_download,
|
643 |
-
resume_download=resume_download,
|
644 |
-
proxies=proxies,
|
645 |
-
local_files_only=local_files_only,
|
646 |
-
use_auth_token=use_auth_token,
|
647 |
-
revision=revision,
|
648 |
-
subfolder=subfolder,
|
649 |
-
user_agent=user_agent,
|
650 |
-
)
|
651 |
-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
652 |
-
except Exception as e:
|
653 |
-
if not allow_pickle:
|
654 |
-
raise e
|
655 |
-
|
656 |
-
model_file = None
|
657 |
-
|
658 |
-
if model_file is None:
|
659 |
-
model_file = _get_model_file(
|
660 |
-
pretrained_model_name_or_path,
|
661 |
-
weights_name=weight_name or TEXT_INVERSION_NAME,
|
662 |
-
cache_dir=cache_dir,
|
663 |
-
force_download=force_download,
|
664 |
-
resume_download=resume_download,
|
665 |
-
proxies=proxies,
|
666 |
-
local_files_only=local_files_only,
|
667 |
-
use_auth_token=use_auth_token,
|
668 |
-
revision=revision,
|
669 |
-
subfolder=subfolder,
|
670 |
-
user_agent=user_agent,
|
671 |
-
)
|
672 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
673 |
-
else:
|
674 |
-
state_dict = pretrained_model_name_or_path
|
675 |
-
|
676 |
-
# 2. Load token and embedding correctly from file
|
677 |
-
loaded_token = None
|
678 |
-
if isinstance(state_dict, torch.Tensor):
|
679 |
-
if token is None:
|
680 |
-
raise ValueError(
|
681 |
-
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
|
682 |
-
)
|
683 |
-
embedding = state_dict
|
684 |
-
elif len(state_dict) == 1:
|
685 |
-
# diffusers
|
686 |
-
loaded_token, embedding = next(iter(state_dict.items()))
|
687 |
-
elif "string_to_param" in state_dict:
|
688 |
-
# A1111
|
689 |
-
loaded_token = state_dict["name"]
|
690 |
-
embedding = state_dict["string_to_param"]["*"]
|
691 |
-
|
692 |
-
if token is not None and loaded_token != token:
|
693 |
-
logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
|
694 |
-
else:
|
695 |
-
token = loaded_token
|
696 |
-
|
697 |
-
embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
|
698 |
-
|
699 |
-
# 3. Make sure we don't mess up the tokenizer or text encoder
|
700 |
-
vocab = self.tokenizer.get_vocab()
|
701 |
-
if token in vocab:
|
702 |
-
raise ValueError(
|
703 |
-
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
|
704 |
-
)
|
705 |
-
elif f"{token}_1" in vocab:
|
706 |
-
multi_vector_tokens = [token]
|
707 |
-
i = 1
|
708 |
-
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
|
709 |
-
multi_vector_tokens.append(f"{token}_{i}")
|
710 |
-
i += 1
|
711 |
-
|
712 |
-
raise ValueError(
|
713 |
-
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
|
714 |
-
)
|
715 |
-
|
716 |
-
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
|
717 |
-
|
718 |
-
if is_multi_vector:
|
719 |
-
tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
|
720 |
-
embeddings = [e for e in embedding] # noqa: C416
|
721 |
-
else:
|
722 |
-
tokens = [token]
|
723 |
-
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
|
724 |
-
|
725 |
-
# add tokens and get ids
|
726 |
-
self.tokenizer.add_tokens(tokens)
|
727 |
-
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
728 |
-
token_ids_and_embeddings += zip(token_ids, embeddings)
|
729 |
-
|
730 |
-
logger.info(f"Loaded textual inversion embedding for {token}.")
|
731 |
-
|
732 |
-
# resize token embeddings and set all new embeddings
|
733 |
-
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
734 |
-
for token_id, embedding in token_ids_and_embeddings:
|
735 |
-
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
736 |
-
|
737 |
-
|
738 |
-
class LoraLoaderMixin:
|
739 |
-
r"""
|
740 |
-
Load LoRA layers into [`UNet2DConditionModel`] and
|
741 |
-
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
742 |
-
"""
|
743 |
-
text_encoder_name = TEXT_ENCODER_NAME
|
744 |
-
unet_name = UNET_NAME
|
745 |
-
|
746 |
-
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
747 |
-
r"""
|
748 |
-
Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
|
749 |
-
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
750 |
-
|
751 |
-
Parameters:
|
752 |
-
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
753 |
-
Can be either:
|
754 |
-
|
755 |
-
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
756 |
-
the Hub.
|
757 |
-
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
758 |
-
with [`ModelMixin.save_pretrained`].
|
759 |
-
- A [torch state
|
760 |
-
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
761 |
-
|
762 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
763 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
764 |
-
is not used.
|
765 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
766 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
767 |
-
cached versions if they exist.
|
768 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
769 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
770 |
-
incompletely downloaded files are deleted.
|
771 |
-
proxies (`Dict[str, str]`, *optional*):
|
772 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
773 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
774 |
-
local_files_only (`bool`, *optional*, defaults to `False`):
|
775 |
-
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
776 |
-
won't be downloaded from the Hub.
|
777 |
-
use_auth_token (`str` or *bool*, *optional*):
|
778 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
779 |
-
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
780 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
781 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
782 |
-
allowed by Git.
|
783 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
784 |
-
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
785 |
-
mirror (`str`, *optional*):
|
786 |
-
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
787 |
-
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
788 |
-
information.
|
789 |
-
|
790 |
-
"""
|
791 |
-
# Load the main state dict first which has the LoRA layers for either of
|
792 |
-
# UNet and text encoder or both.
|
793 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
794 |
-
force_download = kwargs.pop("force_download", False)
|
795 |
-
resume_download = kwargs.pop("resume_download", False)
|
796 |
-
proxies = kwargs.pop("proxies", None)
|
797 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
798 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
799 |
-
revision = kwargs.pop("revision", None)
|
800 |
-
subfolder = kwargs.pop("subfolder", None)
|
801 |
-
weight_name = kwargs.pop("weight_name", None)
|
802 |
-
use_safetensors = kwargs.pop("use_safetensors", None)
|
803 |
-
|
804 |
-
# set lora scale to a reasonable default
|
805 |
-
self._lora_scale = 1.0
|
806 |
-
|
807 |
-
if use_safetensors and not is_safetensors_available():
|
808 |
-
raise ValueError(
|
809 |
-
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
810 |
-
)
|
811 |
-
|
812 |
-
allow_pickle = False
|
813 |
-
if use_safetensors is None:
|
814 |
-
use_safetensors = is_safetensors_available()
|
815 |
-
allow_pickle = True
|
816 |
-
|
817 |
-
user_agent = {
|
818 |
-
"file_type": "attn_procs_weights",
|
819 |
-
"framework": "pytorch",
|
820 |
-
}
|
821 |
-
|
822 |
-
model_file = None
|
823 |
-
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
824 |
-
# Let's first try to load .safetensors weights
|
825 |
-
if (use_safetensors and weight_name is None) or (
|
826 |
-
weight_name is not None and weight_name.endswith(".safetensors")
|
827 |
-
):
|
828 |
-
try:
|
829 |
-
model_file = _get_model_file(
|
830 |
-
pretrained_model_name_or_path_or_dict,
|
831 |
-
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
832 |
-
cache_dir=cache_dir,
|
833 |
-
force_download=force_download,
|
834 |
-
resume_download=resume_download,
|
835 |
-
proxies=proxies,
|
836 |
-
local_files_only=local_files_only,
|
837 |
-
use_auth_token=use_auth_token,
|
838 |
-
revision=revision,
|
839 |
-
subfolder=subfolder,
|
840 |
-
user_agent=user_agent,
|
841 |
-
)
|
842 |
-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
843 |
-
except IOError as e:
|
844 |
-
if not allow_pickle:
|
845 |
-
raise e
|
846 |
-
# try loading non-safetensors weights
|
847 |
-
pass
|
848 |
-
if model_file is None:
|
849 |
-
model_file = _get_model_file(
|
850 |
-
pretrained_model_name_or_path_or_dict,
|
851 |
-
weights_name=weight_name or LORA_WEIGHT_NAME,
|
852 |
-
cache_dir=cache_dir,
|
853 |
-
force_download=force_download,
|
854 |
-
resume_download=resume_download,
|
855 |
-
proxies=proxies,
|
856 |
-
local_files_only=local_files_only,
|
857 |
-
use_auth_token=use_auth_token,
|
858 |
-
revision=revision,
|
859 |
-
subfolder=subfolder,
|
860 |
-
user_agent=user_agent,
|
861 |
-
)
|
862 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
863 |
-
else:
|
864 |
-
state_dict = pretrained_model_name_or_path_or_dict
|
865 |
-
|
866 |
-
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
|
867 |
-
network_alpha = None
|
868 |
-
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
|
869 |
-
state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
|
870 |
-
|
871 |
-
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
872 |
-
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
873 |
-
# their prefixes.
|
874 |
-
keys = list(state_dict.keys())
|
875 |
-
if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
|
876 |
-
# Load the layers corresponding to UNet.
|
877 |
-
unet_keys = [k for k in keys if k.startswith(self.unet_name)]
|
878 |
-
logger.info(f"Loading {self.unet_name}.")
|
879 |
-
unet_lora_state_dict = {
|
880 |
-
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
|
881 |
-
}
|
882 |
-
self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
|
883 |
-
|
884 |
-
# Load the layers corresponding to text encoder and make necessary adjustments.
|
885 |
-
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
|
886 |
-
text_encoder_lora_state_dict = {
|
887 |
-
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
888 |
-
}
|
889 |
-
if len(text_encoder_lora_state_dict) > 0:
|
890 |
-
logger.info(f"Loading {self.text_encoder_name}.")
|
891 |
-
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
|
892 |
-
text_encoder_lora_state_dict, network_alpha=network_alpha
|
893 |
-
)
|
894 |
-
self._modify_text_encoder(attn_procs_text_encoder)
|
895 |
-
|
896 |
-
# save lora attn procs of text encoder so that it can be easily retrieved
|
897 |
-
self._text_encoder_lora_attn_procs = attn_procs_text_encoder
|
898 |
-
|
899 |
-
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
900 |
-
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
901 |
-
elif not all(
|
902 |
-
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
903 |
-
):
|
904 |
-
self.unet.load_attn_procs(state_dict)
|
905 |
-
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
|
906 |
-
warnings.warn(warn_message)
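A minimal end-to-end sketch for `load_lora_weights`; the LoRA repository id is illustrative:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Loads the UNet LoRA layers and, when present in the state dict, the text
# encoder LoRA layers too. Accepts a Hub repo id, a local directory, or a
# state dict; kohya-ss style files are converted on the fly (see below).
pipe.load_lora_weights("some-user/some-lora")  # illustrative repo id

image = pipe("a prompt matching the LoRA's training data", num_inference_steps=30).images[0]
```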
|
907 |
-
|
908 |
-
@property
|
909 |
-
def lora_scale(self) -> float:
|
910 |
-
# property function that returns the lora scale which can be set at run time by the pipeline.
|
911 |
-
# if _lora_scale has not been set, return 1
|
912 |
-
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
913 |
-
|
914 |
-
@property
|
915 |
-
def text_encoder_lora_attn_procs(self):
|
916 |
-
if hasattr(self, "_text_encoder_lora_attn_procs"):
|
917 |
-
return self._text_encoder_lora_attn_procs
|
918 |
-
return
|
919 |
-
|
920 |
-
def _remove_text_encoder_monkey_patch(self):
|
921 |
-
# Loop over the CLIPAttention module of text_encoder
|
922 |
-
for name, attn_module in self.text_encoder.named_modules():
|
923 |
-
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
924 |
-
# Loop over the LoRA layers
|
925 |
-
for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
|
926 |
-
# Retrieve the q/k/v/out projection of CLIPAttention
|
927 |
-
module = attn_module.get_submodule(text_encoder_attr)
|
928 |
-
if hasattr(module, "old_forward"):
|
929 |
-
# restore original `forward` to remove monkey-patch
|
930 |
-
module.forward = module.old_forward
|
931 |
-
delattr(module, "old_forward")
|
932 |
-
|
933 |
-
def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
|
934 |
-
r"""
|
935 |
-
Monkey-patches the forward passes of attention modules of the text encoder.
|
936 |
-
|
937 |
-
Parameters:
|
938 |
-
attn_processors: Dict[str, `LoRAAttnProcessor`]:
|
939 |
-
A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
|
940 |
-
"""
|
941 |
-
|
942 |
-
# First, remove any monkey-patch that might have been applied before
|
943 |
-
self._remove_text_encoder_monkey_patch()
|
944 |
-
|
945 |
-
# Loop over the CLIPAttention module of text_encoder
|
946 |
-
for name, attn_module in self.text_encoder.named_modules():
|
947 |
-
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
948 |
-
# Loop over the LoRA layers
|
949 |
-
for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
|
950 |
-
# Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
|
951 |
-
module = attn_module.get_submodule(text_encoder_attr)
|
952 |
-
lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
|
953 |
-
|
954 |
-
# save old_forward to module that can be used to remove monkey-patch
|
955 |
-
old_forward = module.old_forward = module.forward
|
956 |
-
|
957 |
-
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
|
958 |
-
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
|
959 |
-
def make_new_forward(old_forward, lora_layer):
|
960 |
-
def new_forward(x):
|
961 |
-
result = old_forward(x) + self.lora_scale * lora_layer(x)
|
962 |
-
return result
|
963 |
-
|
964 |
-
return new_forward
|
965 |
-
|
966 |
-
# Monkey-patch.
|
967 |
-
module.forward = make_new_forward(old_forward, lora_layer)
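The `make_new_forward` factory matters because Python closures late-bind loop variables; a short, generic illustration of the difference (unrelated to any diffusers API):

```py
# Without a factory, every closure sees the final loop value.
late_bound = [lambda: i for i in range(3)]
print([f() for f in late_bound])   # [2, 2, 2]


# Passing the value through a factory argument freezes it per closure, which is
# exactly what make_new_forward does for old_forward and lora_layer.
def make(i):
    return lambda: i


early_bound = [make(i) for i in range(3)]
print([f() for f in early_bound])  # [0, 1, 2]
```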
|
968 |
-
|
969 |
-
@property
|
970 |
-
def _lora_attn_processor_attr_to_text_encoder_attr(self):
|
971 |
-
return {
|
972 |
-
"to_q_lora": "q_proj",
|
973 |
-
"to_k_lora": "k_proj",
|
974 |
-
"to_v_lora": "v_proj",
|
975 |
-
"to_out_lora": "out_proj",
|
976 |
-
}
|
977 |
-
|
978 |
-
def _load_text_encoder_attn_procs(
|
979 |
-
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
|
980 |
-
):
|
981 |
-
r"""
|
982 |
-
Load pretrained attention processor layers for
|
983 |
-
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
984 |
-
|
985 |
-
<Tip warning={true}>
|
986 |
-
|
987 |
-
This function is experimental and might change in the future.
|
988 |
-
|
989 |
-
</Tip>
|
990 |
-
|
991 |
-
Parameters:
|
992 |
-
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
993 |
-
Can be either:
|
994 |
-
|
995 |
-
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
996 |
-
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
997 |
-
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
998 |
-
`./my_model_directory/`.
|
999 |
-
- A [torch state
|
1000 |
-
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
1001 |
-
|
1002 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1003 |
-
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
1004 |
-
standard cache should not be used.
|
1005 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
1006 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1007 |
-
cached versions if they exist.
|
1008 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
1009 |
-
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
1010 |
-
file exists.
|
1011 |
-
proxies (`Dict[str, str]`, *optional*):
|
1012 |
-
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
1013 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1014 |
-
local_files_only (`bool`, *optional*, defaults to `False`):
|
1015 |
-
Whether or not to only look at local files (i.e., do not try to download the model).
|
1016 |
-
use_auth_token (`str` or *bool*, *optional*):
|
1017 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
1018 |
-
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
1019 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
1020 |
-
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
1021 |
-
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
1022 |
-
identifier allowed by git.
|
1023 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
1024 |
-
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
1025 |
-
huggingface.co or downloaded locally), you can specify the folder name here.
|
1026 |
-
mirror (`str`, *optional*):
|
1027 |
-
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
1028 |
-
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
1029 |
-
Please refer to the mirror site for more information.
|
1030 |
-
|
1031 |
-
Returns:
|
1032 |
-
`Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
|
1033 |
-
[`LoRAAttnProcessor`].
|
1034 |
-
|
1035 |
-
<Tip>
|
1036 |
-
|
1037 |
-
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
1038 |
-
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
1039 |
-
|
1040 |
-
</Tip>
|
1041 |
-
"""
|
1042 |
-
|
1043 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1044 |
-
force_download = kwargs.pop("force_download", False)
|
1045 |
-
resume_download = kwargs.pop("resume_download", False)
|
1046 |
-
proxies = kwargs.pop("proxies", None)
|
1047 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1048 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
1049 |
-
revision = kwargs.pop("revision", None)
|
1050 |
-
subfolder = kwargs.pop("subfolder", None)
|
1051 |
-
weight_name = kwargs.pop("weight_name", None)
|
1052 |
-
use_safetensors = kwargs.pop("use_safetensors", None)
|
1053 |
-
network_alpha = kwargs.pop("network_alpha", None)
|
1054 |
-
|
1055 |
-
if use_safetensors and not is_safetensors_available():
|
1056 |
-
raise ValueError(
|
1057 |
-
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
1058 |
-
)
|
1059 |
-
|
1060 |
-
allow_pickle = False
|
1061 |
-
if use_safetensors is None:
|
1062 |
-
use_safetensors = is_safetensors_available()
|
1063 |
-
allow_pickle = True
|
1064 |
-
|
1065 |
-
user_agent = {
|
1066 |
-
"file_type": "attn_procs_weights",
|
1067 |
-
"framework": "pytorch",
|
1068 |
-
}
|
1069 |
-
|
1070 |
-
model_file = None
|
1071 |
-
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
1072 |
-
# Let's first try to load .safetensors weights
|
1073 |
-
if (use_safetensors and weight_name is None) or (
|
1074 |
-
weight_name is not None and weight_name.endswith(".safetensors")
|
1075 |
-
):
|
1076 |
-
try:
|
1077 |
-
model_file = _get_model_file(
|
1078 |
-
pretrained_model_name_or_path_or_dict,
|
1079 |
-
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
1080 |
-
cache_dir=cache_dir,
|
1081 |
-
force_download=force_download,
|
1082 |
-
resume_download=resume_download,
|
1083 |
-
proxies=proxies,
|
1084 |
-
local_files_only=local_files_only,
|
1085 |
-
use_auth_token=use_auth_token,
|
1086 |
-
revision=revision,
|
1087 |
-
subfolder=subfolder,
|
1088 |
-
user_agent=user_agent,
|
1089 |
-
)
|
1090 |
-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
1091 |
-
except IOError as e:
|
1092 |
-
if not allow_pickle:
|
1093 |
-
raise e
|
1094 |
-
# try loading non-safetensors weights
|
1095 |
-
pass
|
1096 |
-
if model_file is None:
|
1097 |
-
model_file = _get_model_file(
|
1098 |
-
pretrained_model_name_or_path_or_dict,
|
1099 |
-
weights_name=weight_name or LORA_WEIGHT_NAME,
|
1100 |
-
cache_dir=cache_dir,
|
1101 |
-
force_download=force_download,
|
1102 |
-
resume_download=resume_download,
|
1103 |
-
proxies=proxies,
|
1104 |
-
local_files_only=local_files_only,
|
1105 |
-
use_auth_token=use_auth_token,
|
1106 |
-
revision=revision,
|
1107 |
-
subfolder=subfolder,
|
1108 |
-
user_agent=user_agent,
|
1109 |
-
)
|
1110 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
1111 |
-
else:
|
1112 |
-
state_dict = pretrained_model_name_or_path_or_dict
|
1113 |
-
|
1114 |
-
# fill attn processors
|
1115 |
-
attn_processors = {}
|
1116 |
-
|
1117 |
-
is_lora = all("lora" in k for k in state_dict.keys())
|
1118 |
-
|
1119 |
-
if is_lora:
|
1120 |
-
lora_grouped_dict = defaultdict(dict)
|
1121 |
-
for key, value in state_dict.items():
|
1122 |
-
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
1123 |
-
lora_grouped_dict[attn_processor_key][sub_key] = value
|
1124 |
-
|
1125 |
-
for key, value_dict in lora_grouped_dict.items():
|
1126 |
-
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
1127 |
-
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
1128 |
-
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
1129 |
-
|
1130 |
-
attn_processor_class = (
|
1131 |
-
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
1132 |
-
)
|
1133 |
-
attn_processors[key] = attn_processor_class(
|
1134 |
-
hidden_size=hidden_size,
|
1135 |
-
cross_attention_dim=cross_attention_dim,
|
1136 |
-
rank=rank,
|
1137 |
-
network_alpha=network_alpha,
|
1138 |
-
)
|
1139 |
-
attn_processors[key].load_state_dict(value_dict)
|
1140 |
-
|
1141 |
-
else:
|
1142 |
-
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
1143 |
-
|
1144 |
-
# set correct dtype & device
|
1145 |
-
attn_processors = {
|
1146 |
-
k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
|
1147 |
-
}
|
1148 |
-
return attn_processors
|
1149 |
-
|
1150 |
-
@classmethod
|
1151 |
-
def save_lora_weights(
|
1152 |
-
self,
|
1153 |
-
save_directory: Union[str, os.PathLike],
|
1154 |
-
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1155 |
-
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
1156 |
-
is_main_process: bool = True,
|
1157 |
-
weight_name: str = None,
|
1158 |
-
save_function: Callable = None,
|
1159 |
-
safe_serialization: bool = False,
|
1160 |
-
):
|
1161 |
-
r"""
|
1162 |
-
Save the LoRA parameters corresponding to the UNet and text encoder.
|
1163 |
-
|
1164 |
-
Arguments:
|
1165 |
-
save_directory (`str` or `os.PathLike`):
|
1166 |
-
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
1167 |
-
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
1168 |
-
State dict of the LoRA layers corresponding to the UNet.
|
1169 |
-
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
1170 |
-
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
1171 |
-
encoder LoRA state dict because it comes from 🤗 Transformers.
|
1172 |
-
is_main_process (`bool`, *optional*, defaults to `True`):
|
1173 |
-
Whether the process calling this is the main process or not. Useful during distributed training when you
|
1174 |
-
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
1175 |
-
process to avoid race conditions.
|
1176 |
-
save_function (`Callable`):
|
1177 |
-
The function to use to save the state dictionary. Useful during distributed training when you need to
|
1178 |
-
replace `torch.save` with another method. Can be configured with the environment variable
|
1179 |
-
`DIFFUSERS_SAVE_MODE`.
|
1180 |
-
"""
|
1181 |
-
if os.path.isfile(save_directory):
|
1182 |
-
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
1183 |
-
return
|
1184 |
-
|
1185 |
-
if save_function is None:
|
1186 |
-
if safe_serialization:
|
1187 |
-
|
1188 |
-
def save_function(weights, filename):
|
1189 |
-
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
1190 |
-
|
1191 |
-
else:
|
1192 |
-
save_function = torch.save
|
1193 |
-
|
1194 |
-
os.makedirs(save_directory, exist_ok=True)
|
1195 |
-
|
1196 |
-
# Create a flat dictionary.
|
1197 |
-
state_dict = {}
|
1198 |
-
if unet_lora_layers is not None:
|
1199 |
-
weights = (
|
1200 |
-
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
|
1201 |
-
)
|
1202 |
-
|
1203 |
-
unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
1204 |
-
state_dict.update(unet_lora_state_dict)
|
1205 |
-
|
1206 |
-
if text_encoder_lora_layers is not None:
|
1207 |
-
weights = (
|
1208 |
-
text_encoder_lora_layers.state_dict()
|
1209 |
-
if isinstance(text_encoder_lora_layers, torch.nn.Module)
|
1210 |
-
else text_encoder_lora_layers
|
1211 |
-
)
|
1212 |
-
|
1213 |
-
text_encoder_lora_state_dict = {
|
1214 |
-
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
1215 |
-
}
|
1216 |
-
state_dict.update(text_encoder_lora_state_dict)
|
1217 |
-
|
1218 |
-
# Save the model
|
1219 |
-
if weight_name is None:
|
1220 |
-
if safe_serialization:
|
1221 |
-
weight_name = LORA_WEIGHT_NAME_SAFE
|
1222 |
-
else:
|
1223 |
-
weight_name = LORA_WEIGHT_NAME
|
1224 |
-
|
1225 |
-
save_function(state_dict, os.path.join(save_directory, weight_name))
|
1226 |
-
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
1227 |
-
|
1228 |
-
def _convert_kohya_lora_to_diffusers(self, state_dict):
|
1229 |
-
unet_state_dict = {}
|
1230 |
-
te_state_dict = {}
|
1231 |
-
network_alpha = None
|
1232 |
-
|
1233 |
-
for key, value in state_dict.items():
|
1234 |
-
if "lora_down" in key:
|
1235 |
-
lora_name = key.split(".")[0]
|
1236 |
-
lora_name_up = lora_name + ".lora_up.weight"
|
1237 |
-
lora_name_alpha = lora_name + ".alpha"
|
1238 |
-
if lora_name_alpha in state_dict:
|
1239 |
-
alpha = state_dict[lora_name_alpha].item()
|
1240 |
-
if network_alpha is None:
|
1241 |
-
network_alpha = alpha
|
1242 |
-
elif network_alpha != alpha:
|
1243 |
-
raise ValueError("Network alpha is not consistent")
|
1244 |
-
|
1245 |
-
if lora_name.startswith("lora_unet_"):
|
1246 |
-
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
1247 |
-
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
1248 |
-
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
1249 |
-
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
1250 |
-
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
1251 |
-
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
1252 |
-
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
1253 |
-
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
1254 |
-
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
1255 |
-
if "transformer_blocks" in diffusers_name:
|
1256 |
-
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
1257 |
-
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
1258 |
-
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
1259 |
-
unet_state_dict[diffusers_name] = value
|
1260 |
-
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
1261 |
-
elif lora_name.startswith("lora_te_"):
|
1262 |
-
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
|
1263 |
-
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
1264 |
-
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
1265 |
-
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
1266 |
-
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
1267 |
-
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
1268 |
-
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
1269 |
-
if "self_attn" in diffusers_name:
|
1270 |
-
te_state_dict[diffusers_name] = value
|
1271 |
-
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
1272 |
-
|
1273 |
-
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
1274 |
-
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
|
1275 |
-
new_state_dict = {**unet_state_dict, **te_state_dict}
|
1276 |
-
return new_state_dict, network_alpha
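To make the renaming concrete, a standalone trace of the substitutions applied to one hypothetical kohya-ss UNet key (real files also carry matching `.lora_up.weight` and `.alpha` entries):

```py
key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"

name = key.replace("lora_unet_", "").replace("_", ".")
for old, new in [
    ("down.blocks", "down_blocks"),
    ("mid.block", "mid_block"),
    ("up.blocks", "up_blocks"),
    ("transformer.blocks", "transformer_blocks"),
    ("to.q.lora", "to_q_lora"),
    ("to.k.lora", "to_k_lora"),
    ("to.v.lora", "to_v_lora"),
    ("to.out.0.lora", "to_out_lora"),
]:
    name = name.replace(old, new)

# Attention layers inside transformer blocks additionally gain a ".processor" segment.
name = name.replace("attn1", "attn1.processor").replace("attn2", "attn2.processor")

print(f"unet.{name}")
# unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight
```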
|
1277 |
-
|
1278 |
-
|
1279 |
-
class FromSingleFileMixin:
|
1280 |
-
"""
|
1281 |
-
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
1282 |
-
"""
|
1283 |
-
|
1284 |
-
@classmethod
|
1285 |
-
def from_ckpt(cls, *args, **kwargs):
|
1286 |
-
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
1287 |
-
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
1288 |
-
return cls.from_single_file(*args, **kwargs)
|
1289 |
-
|
1290 |
-
@classmethod
|
1291 |
-
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
1292 |
-
r"""
|
1293 |
-
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
|
1294 |
-
is set in evaluation mode (`model.eval()`) by default.
|
1295 |
-
|
1296 |
-
Parameters:
|
1297 |
-
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
1298 |
-
Can be either:
|
1299 |
-
- A link to the `.ckpt` file (for example
|
1300 |
-
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
1301 |
-
- A path to a *file* containing all pipeline weights.
|
1302 |
-
torch_dtype (`str` or `torch.dtype`, *optional*):
|
1303 |
-
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
1304 |
-
dtype is automatically derived from the model's weights.
|
1305 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
1306 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1307 |
-
cached versions if they exist.
|
1308 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1309 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
1310 |
-
is not used.
|
1311 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
1312 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
1313 |
-
incompletely downloaded files are deleted.
|
1314 |
-
proxies (`Dict[str, str]`, *optional*):
|
1315 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1316 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1317 |
-
local_files_only (`bool`, *optional*, defaults to `False`):
|
1318 |
-
Whether to only load local model weights and configuration files or not. If set to True, the model
|
1319 |
-
won't be downloaded from the Hub.
|
1320 |
-
use_auth_token (`str` or *bool*, *optional*):
|
1321 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
1322 |
-
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
1323 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
1324 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
1325 |
-
allowed by Git.
|
1326 |
-
use_safetensors (`bool`, *optional*, defaults to `None`):
|
1327 |
-
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
1328 |
-
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
1329 |
-
weights. If set to `False`, safetensors weights are not loaded.
|
1330 |
-
extract_ema (`bool`, *optional*, defaults to `False`):
|
1331 |
-
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
1332 |
-
higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
|
1333 |
-
upcast_attention (`bool`, *optional*, defaults to `None`):
|
1334 |
-
Whether the attention computation should always be upcasted.
|
1335 |
-
image_size (`int`, *optional*, defaults to 512):
|
1336 |
-
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
1337 |
-
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
1338 |
-
prediction_type (`str`, *optional*):
|
1339 |
-
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
1340 |
-
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
1341 |
-
num_in_channels (`int`, *optional*, defaults to `None`):
|
1342 |
-
The number of input channels. If `None`, it will be automatically inferred.
|
1343 |
-
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
1344 |
-
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
1345 |
-
"ddim"]`.
|
1346 |
-
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
1347 |
-
Whether to load the safety checker or not.
|
1348 |
-
text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
|
1349 |
-
An instance of
|
1350 |
-
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
|
1351 |
-
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
|
1352 |
-
variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
|
1353 |
-
needed.
|
1354 |
-
tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
|
1355 |
-
An instance of
|
1356 |
-
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
1357 |
-
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
|
1358 |
-
itself, if needed.
|
1359 |
-
kwargs (remaining dictionary of keyword arguments, *optional*):
|
1360 |
-
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
1361 |
-
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
1362 |
-
method. See example below for more information.
|
1363 |
-
|
1364 |
-
Examples:
|
1365 |
-
|
1366 |
-
```py
|
1367 |
-
>>> from diffusers import StableDiffusionPipeline
|
1368 |
-
|
1369 |
-
>>> # Download pipeline from huggingface.co and cache.
|
1370 |
-
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
1371 |
-
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
1372 |
-
... )
|
1373 |
-
|
1374 |
-
>>> # Download pipeline from local file
|
1375 |
-
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
1376 |
-
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
|
1377 |
-
|
1378 |
-
>>> # Enable float16 and move to GPU
|
1379 |
-
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
1380 |
-
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
1381 |
-
... torch_dtype=torch.float16,
|
1382 |
-
... )
|
1383 |
-
>>> pipeline.to("cuda")
|
1384 |
-
```
|
1385 |
-
"""
|
1386 |
-
# import here to avoid circular dependency
|
1387 |
-
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
1388 |
-
|
1389 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1390 |
-
resume_download = kwargs.pop("resume_download", False)
|
1391 |
-
force_download = kwargs.pop("force_download", False)
|
1392 |
-
proxies = kwargs.pop("proxies", None)
|
1393 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1394 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
1395 |
-
revision = kwargs.pop("revision", None)
|
1396 |
-
extract_ema = kwargs.pop("extract_ema", False)
|
1397 |
-
image_size = kwargs.pop("image_size", None)
|
1398 |
-
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
1399 |
-
num_in_channels = kwargs.pop("num_in_channels", None)
|
1400 |
-
upcast_attention = kwargs.pop("upcast_attention", None)
|
1401 |
-
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
1402 |
-
prediction_type = kwargs.pop("prediction_type", None)
|
1403 |
-
text_encoder = kwargs.pop("text_encoder", None)
|
1404 |
-
tokenizer = kwargs.pop("tokenizer", None)
|
1405 |
-
|
1406 |
-
torch_dtype = kwargs.pop("torch_dtype", None)
|
1407 |
-
|
1408 |
-
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
1409 |
-
|
1410 |
-
pipeline_name = cls.__name__
|
1411 |
-
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
1412 |
-
from_safetensors = file_extension == "safetensors"
|
1413 |
-
|
1414 |
-
if from_safetensors and use_safetensors is False:
|
1415 |
-
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
1416 |
-
|
1417 |
-
# TODO: For now we only support stable diffusion
|
1418 |
-
stable_unclip = None
|
1419 |
-
model_type = None
|
1420 |
-
controlnet = False
|
1421 |
-
|
1422 |
-
if pipeline_name == "StableDiffusionControlNetPipeline":
|
1423 |
-
# Model type will be inferred from the checkpoint.
|
1424 |
-
controlnet = True
|
1425 |
-
elif "StableDiffusion" in pipeline_name:
|
1426 |
-
# Model type will be inferred from the checkpoint.
|
1427 |
-
pass
|
1428 |
-
elif pipeline_name == "StableUnCLIPPipeline":
|
1429 |
-
model_type = "FrozenOpenCLIPEmbedder"
|
1430 |
-
stable_unclip = "txt2img"
|
1431 |
-
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
1432 |
-
model_type = "FrozenOpenCLIPEmbedder"
|
1433 |
-
stable_unclip = "img2img"
|
1434 |
-
elif pipeline_name == "PaintByExamplePipeline":
|
1435 |
-
model_type = "PaintByExample"
|
1436 |
-
elif pipeline_name == "LDMTextToImagePipeline":
|
1437 |
-
model_type = "LDMTextToImage"
|
1438 |
-
else:
|
1439 |
-
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
1440 |
-
|
1441 |
-
# remove huggingface url
|
1442 |
-
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
1443 |
-
if pretrained_model_link_or_path.startswith(prefix):
|
1444 |
-
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
1445 |
-
|
1446 |
-
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
1447 |
-
ckpt_path = Path(pretrained_model_link_or_path)
|
1448 |
-
if not ckpt_path.is_file():
|
1449 |
-
# get repo_id and (potentially nested) file path of ckpt in repo
|
1450 |
-
repo_id = "/".join(ckpt_path.parts[:2])
|
1451 |
-
file_path = "/".join(ckpt_path.parts[2:])
|
1452 |
-
|
1453 |
-
if file_path.startswith("blob/"):
|
1454 |
-
file_path = file_path[len("blob/") :]
|
1455 |
-
|
1456 |
-
if file_path.startswith("main/"):
|
1457 |
-
file_path = file_path[len("main/") :]
|
1458 |
-
|
1459 |
-
pretrained_model_link_or_path = hf_hub_download(
|
1460 |
-
repo_id,
|
1461 |
-
filename=file_path,
|
1462 |
-
cache_dir=cache_dir,
|
1463 |
-
resume_download=resume_download,
|
1464 |
-
proxies=proxies,
|
1465 |
-
local_files_only=local_files_only,
|
1466 |
-
use_auth_token=use_auth_token,
|
1467 |
-
revision=revision,
|
1468 |
-
force_download=force_download,
|
1469 |
-
)
|
1470 |
-
|
1471 |
-
pipe = download_from_original_stable_diffusion_ckpt(
|
1472 |
-
pretrained_model_link_or_path,
|
1473 |
-
pipeline_class=cls,
|
1474 |
-
model_type=model_type,
|
1475 |
-
stable_unclip=stable_unclip,
|
1476 |
-
controlnet=controlnet,
|
1477 |
-
from_safetensors=from_safetensors,
|
1478 |
-
extract_ema=extract_ema,
|
1479 |
-
image_size=image_size,
|
1480 |
-
scheduler_type=scheduler_type,
|
1481 |
-
num_in_channels=num_in_channels,
|
1482 |
-
upcast_attention=upcast_attention,
|
1483 |
-
load_safety_checker=load_safety_checker,
|
1484 |
-
prediction_type=prediction_type,
|
1485 |
-
text_encoder=text_encoder,
|
1486 |
-
tokenizer=tokenizer,
|
1487 |
-
)
|
1488 |
-
|
1489 |
-
if torch_dtype is not None:
|
1490 |
-
pipe.to(torch_dtype=torch_dtype)
|
1491 |
-
|
1492 |
-
return pipe
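For reference, the Hub-link handling above reduces a blob URL to a `repo_id` plus an in-repo file path before handing it to `hf_hub_download`. A minimal standalone sketch of that parsing (the URL is only an example, not a required checkpoint):

```py
from pathlib import Path

link = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
    if link.startswith(prefix):
        link = link[len(prefix):]

parts = Path(link).parts
repo_id = "/".join(parts[:2])    # "runwayml/stable-diffusion-v1-5"
file_path = "/".join(parts[2:])  # "blob/main/v1-5-pruned-emaonly.ckpt"
for marker in ("blob/", "main/"):
    if file_path.startswith(marker):
        file_path = file_path[len(marker):]

print(repo_id, file_path)  # runwayml/stable-diffusion-v1-5 v1-5-pruned-emaonly.ckpt
```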
4DoF/diffusers/models/__init__.py
DELETED
@@ -1,35 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_flax_available, is_torch_available


if is_torch_available():
    from .autoencoder_kl import AutoencoderKL
    from .controlnet import ControlNetModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .modeling_utils import ModelMixin
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
    from .unet_1d import UNet1DModel
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
    from .unet_3d_condition import UNet3DConditionModel
    from .vq_model import VQModel

if is_flax_available():
    from .controlnet_flax import FlaxControlNetModel
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .vae_flax import FlaxAutoencoderKL
4DoF/diffusers/models/activations.py
DELETED
@@ -1,12 +0,0 @@
from torch import nn


def get_activation(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
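For quick reference, a self-contained usage sketch (not taken from the repository) of the activation factory above; the tensor shapes are arbitrary:

```py
import torch
from torch import nn


def get_activation(act_fn):
    # same mapping as the deleted module above
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


x = torch.randn(2, 8)
print(get_activation("swish")(x).shape)  # torch.Size([2, 8]); "swish" and "silu" are aliases
```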
4DoF/diffusers/models/attention.py
DELETED
@@ -1,392 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        posemb: Optional = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            posemb=posemb,  # todo: in self attn, posemb should be [pose_in, pose_in]?
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                posemb=posemb,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        if act_fn is None:
            self.act = None
        else:
            self.act = get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
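A short instantiation sketch for the block above, not taken from the repository. It assumes the vendored package root (the `4DoF/` directory) is on `sys.path`, so this copy is imported as `diffusers.models.attention`; all shapes are illustrative, and passing `posemb=None` is assumed to fall back to plain attention in this copy:

```py
import torch
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768
)
hidden_states = torch.randn(2, 64, 320)          # (batch, tokens, channels)
encoder_hidden_states = torch.randn(2, 77, 768)  # conditioning tokens
# posemb carries the relative camera pose for the CaPE-aware attention processor.
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states, posemb=None)
print(out.shape)  # torch.Size([2, 64, 320])
```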
4DoF/diffusers/models/attention_flax.py
DELETED
@@ -1,446 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array; value must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array; value must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # unused, ignore it
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back


class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

    Parameters:
        query_dim (:obj:`int`):
            Input hidden states dimension
        heads (:obj:`int`, *optional*, defaults to 8):
            Number of heads
        dim_head (:obj:`int`, *optional*, defaults to 64):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def __call__(self, hidden_states, context=None, deterministic=True):
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        query_states = self.reshape_heads_to_batch_dim(query_proj)
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)

        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if statement creates a chunk size for each layer of the unet
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet

            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=2)

            # attend to values
            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxBasicTransformerBlock(nn.Module):
    r"""
    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
    https://arxiv.org/abs/1706.03762


    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        # cross attention
        self.attn2 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
        else:
            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf


    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`): tbd
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        if self.use_linear_projection:
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_in = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
            )
            for _ in range(self.depth)
        ]

        if self.use_linear_projection:
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_out = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height * width, channels)
            hidden_states = self.proj_in(hidden_states)
        else:
            hidden_states = self.proj_in(hidden_states)
            hidden_states = hidden_states.reshape(batch, height * width, channels)

        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channels)
        else:
            hidden_states = hidden_states.reshape(batch, height, width, channels)
            hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states + residual
        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
      https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
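Likewise, a minimal sketch (not from the repository) of initialising and applying the Flax transformer defined above, under the same assumption that this vendored copy is importable as `diffusers.models.attention_flax`; shapes are illustrative:

```py
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxTransformer2DModel

model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1)
hidden_states = jnp.zeros((1, 32, 32, 320))  # NHWC layout, as used in __call__ above
context = jnp.zeros((1, 77, 768))            # conditioning tokens for cross attention
params = model.init(jax.random.PRNGKey(0), hidden_states, context)
out = model.apply(params, hidden_states, context)
print(out.shape)  # (1, 32, 32, 320)
```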
4DoF/diffusers/models/attention_processor.py
DELETED
@@ -1,1714 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


# 4DoF CaPE
import einops


def rotate_every_two(x):
    x = einops.rearrange(x, '... (d j) -> ... d j', j=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return einops.rearrange(x, '... d j -> ... (d j)')


def cape(x, p):
    d, l, n = x.shape[-1], p.shape[-2], p.shape[-1]
    assert d % (2 * n) == 0
    m = einops.repeat(p, 'b l n -> b l (n k)', k=d // n)
    return m


def cape_embed(p1, p2, qq, kk):
    """
    Embed camera position encoding into attention map
    Args:
        p1: query pose b, l_q, pose_dim
        p2: key pose b, l_k, pose_dim
        qq: query feature map b, l_q, feature_dim
        kk: key feature map b, l_k, feature_dim

    Returns: cape embedded attention map b, l_q, l_k

    """
    assert p1.shape[-1] == p2.shape[-1]
    assert qq.shape[-1] == kk.shape[-1]
    assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
    assert p1.shape[1] == qq.shape[1]
    assert p2.shape[1] == kk.shape[1]

    m1 = cape(qq, p1)
    m2 = cape(kk, p2)

    q = (qq * m1.cos()) + (rotate_every_two(qq) * m1.sin())
    k = (kk * m2.cos()) + (rotate_every_two(kk) * m2.sin())

    return q, k
|
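The three helpers above implement the 4DoF camera position embedding (CaPE) as a rotary-style rotation of the query and key features before the attention dot product. A minimal sketch of how they combine, not part of the original file; the batch size, token count, feature width, and pose dimension are illustrative assumptions:

# Illustrative shapes only: 2 views, 64 tokens each, 320-dim features, 4-DoF pose per token.
b, l, d, pose_dim = 2, 64, 320, 4          # d must be divisible by 2 * pose_dim (asserted in cape())
qq = torch.randn(b, l, d)                  # query features
kk = torch.randn(b, l, d)                  # key features
p_q = torch.randn(b, l, pose_dim)          # query-side camera pose, repeated per token
p_k = torch.randn(b, l, pose_dim)          # key-side camera pose, repeated per token

q, k = cape_embed(p_q, p_k, qq, kk)        # rotate q and k by their respective poses
scores = torch.einsum('bqd,bkd->bqk', q, k) / d ** 0.5   # relative-pose-aware attention scores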
74 |
-
@maybe_allow_in_graph
|
75 |
-
class Attention(nn.Module):
|
76 |
-
r"""
|
77 |
-
A cross attention layer.
|
78 |
-
|
79 |
-
Parameters:
|
80 |
-
query_dim (`int`): The number of channels in the query.
|
81 |
-
cross_attention_dim (`int`, *optional*):
|
82 |
-
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
83 |
-
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
84 |
-
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
85 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
86 |
-
bias (`bool`, *optional*, defaults to False):
|
87 |
-
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
88 |
-
"""
|
89 |
-
|
90 |
-
def __init__(
|
91 |
-
self,
|
92 |
-
query_dim: int,
|
93 |
-
cross_attention_dim: Optional[int] = None,
|
94 |
-
heads: int = 8,
|
95 |
-
dim_head: int = 64,
|
96 |
-
dropout: float = 0.0,
|
97 |
-
bias=False,
|
98 |
-
upcast_attention: bool = False,
|
99 |
-
upcast_softmax: bool = False,
|
100 |
-
cross_attention_norm: Optional[str] = None,
|
101 |
-
cross_attention_norm_num_groups: int = 32,
|
102 |
-
added_kv_proj_dim: Optional[int] = None,
|
103 |
-
norm_num_groups: Optional[int] = None,
|
104 |
-
spatial_norm_dim: Optional[int] = None,
|
105 |
-
out_bias: bool = True,
|
106 |
-
scale_qk: bool = True,
|
107 |
-
only_cross_attention: bool = False,
|
108 |
-
eps: float = 1e-5,
|
109 |
-
rescale_output_factor: float = 1.0,
|
110 |
-
residual_connection: bool = False,
|
111 |
-
_from_deprecated_attn_block=False,
|
112 |
-
processor: Optional["AttnProcessor"] = None,
|
113 |
-
):
|
114 |
-
super().__init__()
|
115 |
-
inner_dim = dim_head * heads
|
116 |
-
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
117 |
-
self.upcast_attention = upcast_attention
|
118 |
-
self.upcast_softmax = upcast_softmax
|
119 |
-
self.rescale_output_factor = rescale_output_factor
|
120 |
-
self.residual_connection = residual_connection
|
121 |
-
self.dropout = dropout
|
122 |
-
|
123 |
-
# we make use of this private variable to know whether this class is loaded
|
124 |
-
# with an deprecated state dict so that we can convert it on the fly
|
125 |
-
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
126 |
-
|
127 |
-
self.scale_qk = scale_qk
|
128 |
-
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
129 |
-
|
130 |
-
self.heads = heads
|
131 |
-
# for slice_size > 0 the attention score computation
|
132 |
-
# is split across the batch axis to save memory
|
133 |
-
# You can set slice_size with `set_attention_slice`
|
134 |
-
self.sliceable_head_dim = heads
|
135 |
-
|
136 |
-
self.added_kv_proj_dim = added_kv_proj_dim
|
137 |
-
self.only_cross_attention = only_cross_attention
|
138 |
-
|
139 |
-
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
140 |
-
raise ValueError(
|
141 |
-
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
142 |
-
)
|
143 |
-
|
144 |
-
if norm_num_groups is not None:
|
145 |
-
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
146 |
-
else:
|
147 |
-
self.group_norm = None
|
148 |
-
|
149 |
-
if spatial_norm_dim is not None:
|
150 |
-
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
151 |
-
else:
|
152 |
-
self.spatial_norm = None
|
153 |
-
|
154 |
-
if cross_attention_norm is None:
|
155 |
-
self.norm_cross = None
|
156 |
-
elif cross_attention_norm == "layer_norm":
|
157 |
-
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
158 |
-
elif cross_attention_norm == "group_norm":
|
159 |
-
if self.added_kv_proj_dim is not None:
|
160 |
-
# The given `encoder_hidden_states` are initially of shape
|
161 |
-
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
162 |
-
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
163 |
-
# before the projection, so we need to use `added_kv_proj_dim` as
|
164 |
-
# the number of channels for the group norm.
|
165 |
-
norm_cross_num_channels = added_kv_proj_dim
|
166 |
-
else:
|
167 |
-
norm_cross_num_channels = cross_attention_dim
|
168 |
-
|
169 |
-
self.norm_cross = nn.GroupNorm(
|
170 |
-
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
171 |
-
)
|
172 |
-
else:
|
173 |
-
raise ValueError(
|
174 |
-
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
175 |
-
)
|
176 |
-
|
177 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
178 |
-
|
179 |
-
if not self.only_cross_attention:
|
180 |
-
# only relevant for the `AddedKVProcessor` classes
|
181 |
-
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
182 |
-
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
183 |
-
else:
|
184 |
-
self.to_k = None
|
185 |
-
self.to_v = None
|
186 |
-
|
187 |
-
if self.added_kv_proj_dim is not None:
|
188 |
-
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
189 |
-
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
190 |
-
|
191 |
-
self.to_out = nn.ModuleList([])
|
192 |
-
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
193 |
-
self.to_out.append(nn.Dropout(dropout))
|
194 |
-
|
195 |
-
# set attention processor
|
196 |
-
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
197 |
-
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
198 |
-
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
199 |
-
if processor is None:
|
200 |
-
processor = (
|
201 |
-
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
202 |
-
)
|
203 |
-
self.set_processor(processor)
|
204 |
-
|
205 |
-
def set_use_memory_efficient_attention_xformers(
|
206 |
-
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
207 |
-
):
|
208 |
-
is_lora = hasattr(self, "processor") and isinstance(
|
209 |
-
self.processor,
|
210 |
-
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
|
211 |
-
)
|
212 |
-
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
213 |
-
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
214 |
-
)
|
215 |
-
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
216 |
-
self.processor,
|
217 |
-
(
|
218 |
-
AttnAddedKVProcessor,
|
219 |
-
AttnAddedKVProcessor2_0,
|
220 |
-
SlicedAttnAddedKVProcessor,
|
221 |
-
XFormersAttnAddedKVProcessor,
|
222 |
-
LoRAAttnAddedKVProcessor,
|
223 |
-
),
|
224 |
-
)
|
225 |
-
|
226 |
-
if use_memory_efficient_attention_xformers:
|
227 |
-
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
228 |
-
raise NotImplementedError(
|
229 |
-
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
230 |
-
)
|
231 |
-
if not is_xformers_available():
|
232 |
-
raise ModuleNotFoundError(
|
233 |
-
(
|
234 |
-
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
235 |
-
" xformers"
|
236 |
-
),
|
237 |
-
name="xformers",
|
238 |
-
)
|
239 |
-
elif not torch.cuda.is_available():
|
240 |
-
raise ValueError(
|
241 |
-
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
242 |
-
" only available for GPU "
|
243 |
-
)
|
244 |
-
else:
|
245 |
-
try:
|
246 |
-
# Make sure we can run the memory efficient attention
|
247 |
-
_ = xformers.ops.memory_efficient_attention(
|
248 |
-
torch.randn((1, 2, 40), device="cuda"),
|
249 |
-
torch.randn((1, 2, 40), device="cuda"),
|
250 |
-
torch.randn((1, 2, 40), device="cuda"),
|
251 |
-
)
|
252 |
-
except Exception as e:
|
253 |
-
raise e
|
254 |
-
|
255 |
-
if is_lora:
|
256 |
-
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
257 |
-
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
258 |
-
processor = LoRAXFormersAttnProcessor(
|
259 |
-
hidden_size=self.processor.hidden_size,
|
260 |
-
cross_attention_dim=self.processor.cross_attention_dim,
|
261 |
-
rank=self.processor.rank,
|
262 |
-
attention_op=attention_op,
|
263 |
-
)
|
264 |
-
processor.load_state_dict(self.processor.state_dict())
|
265 |
-
processor.to(self.processor.to_q_lora.up.weight.device)
|
266 |
-
elif is_custom_diffusion:
|
267 |
-
processor = CustomDiffusionXFormersAttnProcessor(
|
268 |
-
train_kv=self.processor.train_kv,
|
269 |
-
train_q_out=self.processor.train_q_out,
|
270 |
-
hidden_size=self.processor.hidden_size,
|
271 |
-
cross_attention_dim=self.processor.cross_attention_dim,
|
272 |
-
attention_op=attention_op,
|
273 |
-
)
|
274 |
-
processor.load_state_dict(self.processor.state_dict())
|
275 |
-
if hasattr(self.processor, "to_k_custom_diffusion"):
|
276 |
-
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
277 |
-
elif is_added_kv_processor:
|
278 |
-
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
279 |
-
# which uses this type of cross attention ONLY because the attention mask of format
|
280 |
-
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
281 |
-
# throw warning
|
282 |
-
logger.info(
|
283 |
-
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
284 |
-
)
|
285 |
-
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
286 |
-
else:
|
287 |
-
processor = XFormersAttnProcessor(attention_op=attention_op)
|
288 |
-
else:
|
289 |
-
if is_lora:
|
290 |
-
attn_processor_class = (
|
291 |
-
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
292 |
-
)
|
293 |
-
processor = attn_processor_class(
|
294 |
-
hidden_size=self.processor.hidden_size,
|
295 |
-
cross_attention_dim=self.processor.cross_attention_dim,
|
296 |
-
rank=self.processor.rank,
|
297 |
-
)
|
298 |
-
processor.load_state_dict(self.processor.state_dict())
|
299 |
-
processor.to(self.processor.to_q_lora.up.weight.device)
|
300 |
-
elif is_custom_diffusion:
|
301 |
-
processor = CustomDiffusionAttnProcessor(
|
302 |
-
train_kv=self.processor.train_kv,
|
303 |
-
train_q_out=self.processor.train_q_out,
|
304 |
-
hidden_size=self.processor.hidden_size,
|
305 |
-
cross_attention_dim=self.processor.cross_attention_dim,
|
306 |
-
)
|
307 |
-
processor.load_state_dict(self.processor.state_dict())
|
308 |
-
if hasattr(self.processor, "to_k_custom_diffusion"):
|
309 |
-
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
310 |
-
else:
|
311 |
-
# set attention processor
|
312 |
-
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
313 |
-
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
314 |
-
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
315 |
-
processor = (
|
316 |
-
AttnProcessor2_0()
|
317 |
-
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
318 |
-
else AttnProcessor()
|
319 |
-
)
|
320 |
-
|
321 |
-
self.set_processor(processor)
|
322 |
-
|
323 |
-
def set_attention_slice(self, slice_size):
|
324 |
-
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
325 |
-
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
326 |
-
|
327 |
-
if slice_size is not None and self.added_kv_proj_dim is not None:
|
328 |
-
processor = SlicedAttnAddedKVProcessor(slice_size)
|
329 |
-
elif slice_size is not None:
|
330 |
-
processor = SlicedAttnProcessor(slice_size)
|
331 |
-
elif self.added_kv_proj_dim is not None:
|
332 |
-
processor = AttnAddedKVProcessor()
|
333 |
-
else:
|
334 |
-
# set attention processor
|
335 |
-
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
336 |
-
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
337 |
-
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
338 |
-
processor = (
|
339 |
-
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
340 |
-
)
|
341 |
-
|
342 |
-
self.set_processor(processor)
|
343 |
-
|
344 |
-
def set_processor(self, processor: "AttnProcessor"):
|
345 |
-
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
346 |
-
# pop `processor` from `self._modules`
|
347 |
-
if (
|
348 |
-
hasattr(self, "processor")
|
349 |
-
and isinstance(self.processor, torch.nn.Module)
|
350 |
-
and not isinstance(processor, torch.nn.Module)
|
351 |
-
):
|
352 |
-
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
353 |
-
self._modules.pop("processor")
|
354 |
-
|
355 |
-
self.processor = processor
|
356 |
-
|
357 |
-
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
358 |
-
# The `Attention` class can call different attention processors / attention functions
|
359 |
-
# here we simply pass along all tensors to the selected processor class
|
360 |
-
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
361 |
-
return self.processor(
|
362 |
-
self,
|
363 |
-
hidden_states,
|
364 |
-
encoder_hidden_states=encoder_hidden_states,
|
365 |
-
attention_mask=attention_mask,
|
366 |
-
**cross_attention_kwargs,
|
367 |
-
)
|
368 |
-
|
369 |
-
def batch_to_head_dim(self, tensor):
|
370 |
-
head_size = self.heads
|
371 |
-
batch_size, seq_len, dim = tensor.shape
|
372 |
-
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
373 |
-
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
374 |
-
return tensor
|
375 |
-
|
376 |
-
def head_to_batch_dim(self, tensor, out_dim=3):
|
377 |
-
head_size = self.heads
|
378 |
-
batch_size, seq_len, dim = tensor.shape
|
379 |
-
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
380 |
-
tensor = tensor.permute(0, 2, 1, 3)
|
381 |
-
|
382 |
-
if out_dim == 3:
|
383 |
-
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
384 |
-
|
385 |
-
return tensor
|
386 |
-
|
387 |
-
def get_attention_scores(self, query, key, attention_mask=None):
|
388 |
-
dtype = query.dtype
|
389 |
-
if self.upcast_attention:
|
390 |
-
query = query.float()
|
391 |
-
key = key.float()
|
392 |
-
|
393 |
-
if attention_mask is None:
|
394 |
-
baddbmm_input = torch.empty(
|
395 |
-
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
396 |
-
)
|
397 |
-
beta = 0
|
398 |
-
else:
|
399 |
-
baddbmm_input = attention_mask
|
400 |
-
beta = 1
|
401 |
-
|
402 |
-
attention_scores = torch.baddbmm(
|
403 |
-
baddbmm_input,
|
404 |
-
query,
|
405 |
-
key.transpose(-1, -2),
|
406 |
-
beta=beta,
|
407 |
-
alpha=self.scale,
|
408 |
-
)
|
409 |
-
del baddbmm_input
|
410 |
-
|
411 |
-
if self.upcast_softmax:
|
412 |
-
attention_scores = attention_scores.float()
|
413 |
-
|
414 |
-
attention_probs = attention_scores.softmax(dim=-1)
|
415 |
-
del attention_scores
|
416 |
-
|
417 |
-
attention_probs = attention_probs.to(dtype)
|
418 |
-
|
419 |
-
return attention_probs
|
420 |
-
|
421 |
-
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
422 |
-
if batch_size is None:
|
423 |
-
deprecate(
|
424 |
-
"batch_size=None",
|
425 |
-
"0.0.15",
|
426 |
-
(
|
427 |
-
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
428 |
-
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
429 |
-
" `prepare_attention_mask` when preparing the attention_mask."
|
430 |
-
),
|
431 |
-
)
|
432 |
-
batch_size = 1
|
433 |
-
|
434 |
-
head_size = self.heads
|
435 |
-
if attention_mask is None:
|
436 |
-
return attention_mask
|
437 |
-
|
438 |
-
current_length: int = attention_mask.shape[-1]
|
439 |
-
if current_length != target_length:
|
440 |
-
if attention_mask.device.type == "mps":
|
441 |
-
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
442 |
-
# Instead, we can manually construct the padding tensor.
|
443 |
-
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
444 |
-
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
445 |
-
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
446 |
-
else:
|
447 |
-
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
448 |
-
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
449 |
-
# remaining_length: int = target_length - current_length
|
450 |
-
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
451 |
-
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
452 |
-
|
453 |
-
if out_dim == 3:
|
454 |
-
if attention_mask.shape[0] < batch_size * head_size:
|
455 |
-
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
456 |
-
elif out_dim == 4:
|
457 |
-
attention_mask = attention_mask.unsqueeze(1)
|
458 |
-
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
459 |
-
|
460 |
-
return attention_mask
|
461 |
-
|
462 |
-
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
463 |
-
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
464 |
-
|
465 |
-
if isinstance(self.norm_cross, nn.LayerNorm):
|
466 |
-
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
467 |
-
elif isinstance(self.norm_cross, nn.GroupNorm):
|
468 |
-
# Group norm norms along the channels dimension and expects
|
469 |
-
# input to be in the shape of (N, C, *). In this case, we want
|
470 |
-
# to norm along the hidden dimension, so we need to move
|
471 |
-
# (batch_size, sequence_length, hidden_size) ->
|
472 |
-
# (batch_size, hidden_size, sequence_length)
|
473 |
-
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
474 |
-
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
475 |
-
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
476 |
-
else:
|
477 |
-
assert False
|
478 |
-
|
479 |
-
return encoder_hidden_states
|
480 |
-
|
481 |
-
|
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
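Every processor in this file is a plain callable that the Attention module above dispatches to from its forward pass, so swapping implementations only requires set_processor. A hypothetical sketch; the layer sizes and shapes are illustrative assumptions, not values used by the repository:

# Illustrative only: a small cross-attention layer driven by the default processor.
attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(AttnProcessor())

hidden_states = torch.randn(2, 64, 320)            # (batch, image tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)    # e.g. text-encoder output
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)   # (2, 64, 320)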
class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)
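LoRALinearLayer only produces the low-rank residual; the LoRA processors below add it to the frozen base projection, scaled by the caller. A small illustrative sketch, with feature sizes chosen as assumptions:

# Illustrative only: frozen base projection plus a scaled LoRA residual, as the LoRA processors do.
base = nn.Linear(320, 320)
lora = LoRALinearLayer(320, 320, rank=4, network_alpha=4)

x = torch.randn(2, 64, 320)
scale = 1.0
y = base(x) + scale * lora(x)   # lora(x) is zero at initialization because `up` starts from zeros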
578 |
-
class LoRAAttnProcessor(nn.Module):
|
579 |
-
r"""
|
580 |
-
Processor for implementing the LoRA attention mechanism.
|
581 |
-
|
582 |
-
Args:
|
583 |
-
hidden_size (`int`, *optional*):
|
584 |
-
The hidden size of the attention layer.
|
585 |
-
cross_attention_dim (`int`, *optional*):
|
586 |
-
The number of channels in the `encoder_hidden_states`.
|
587 |
-
rank (`int`, defaults to 4):
|
588 |
-
The dimension of the LoRA update matrices.
|
589 |
-
network_alpha (`int`, *optional*):
|
590 |
-
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
591 |
-
"""
|
592 |
-
|
593 |
-
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
594 |
-
super().__init__()
|
595 |
-
|
596 |
-
self.hidden_size = hidden_size
|
597 |
-
self.cross_attention_dim = cross_attention_dim
|
598 |
-
self.rank = rank
|
599 |
-
|
600 |
-
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
601 |
-
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
602 |
-
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
603 |
-
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
604 |
-
|
605 |
-
def __call__(
|
606 |
-
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
607 |
-
):
|
608 |
-
residual = hidden_states
|
609 |
-
|
610 |
-
if attn.spatial_norm is not None:
|
611 |
-
hidden_states = attn.spatial_norm(hidden_states, temb)
|
612 |
-
|
613 |
-
input_ndim = hidden_states.ndim
|
614 |
-
|
615 |
-
if input_ndim == 4:
|
616 |
-
batch_size, channel, height, width = hidden_states.shape
|
617 |
-
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
618 |
-
|
619 |
-
batch_size, sequence_length, _ = (
|
620 |
-
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
621 |
-
)
|
622 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
623 |
-
|
624 |
-
if attn.group_norm is not None:
|
625 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
626 |
-
|
627 |
-
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
628 |
-
query = attn.head_to_batch_dim(query)
|
629 |
-
|
630 |
-
if encoder_hidden_states is None:
|
631 |
-
encoder_hidden_states = hidden_states
|
632 |
-
elif attn.norm_cross:
|
633 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
634 |
-
|
635 |
-
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
636 |
-
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
637 |
-
|
638 |
-
key = attn.head_to_batch_dim(key)
|
639 |
-
value = attn.head_to_batch_dim(value)
|
640 |
-
|
641 |
-
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
642 |
-
hidden_states = torch.bmm(attention_probs, value)
|
643 |
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
644 |
-
|
645 |
-
# linear proj
|
646 |
-
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
647 |
-
# dropout
|
648 |
-
hidden_states = attn.to_out[1](hidden_states)
|
649 |
-
|
650 |
-
if input_ndim == 4:
|
651 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
652 |
-
|
653 |
-
if attn.residual_connection:
|
654 |
-
hidden_states = hidden_states + residual
|
655 |
-
|
656 |
-
hidden_states = hidden_states / attn.rescale_output_factor
|
657 |
-
|
658 |
-
return hidden_states
|
659 |
-
|
660 |
-
|
661 |
-
class CustomDiffusionAttnProcessor(nn.Module):
|
662 |
-
r"""
|
663 |
-
Processor for implementing attention for the Custom Diffusion method.
|
664 |
-
|
665 |
-
Args:
|
666 |
-
train_kv (`bool`, defaults to `True`):
|
667 |
-
Whether to newly train the key and value matrices corresponding to the text features.
|
668 |
-
train_q_out (`bool`, defaults to `True`):
|
669 |
-
Whether to newly train query matrices corresponding to the latent image features.
|
670 |
-
hidden_size (`int`, *optional*, defaults to `None`):
|
671 |
-
The hidden size of the attention layer.
|
672 |
-
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
673 |
-
The number of channels in the `encoder_hidden_states`.
|
674 |
-
out_bias (`bool`, defaults to `True`):
|
675 |
-
Whether to include the bias parameter in `train_q_out`.
|
676 |
-
dropout (`float`, *optional*, defaults to 0.0):
|
677 |
-
The dropout probability to use.
|
678 |
-
"""
|
679 |
-
|
680 |
-
def __init__(
|
681 |
-
self,
|
682 |
-
train_kv=True,
|
683 |
-
train_q_out=True,
|
684 |
-
hidden_size=None,
|
685 |
-
cross_attention_dim=None,
|
686 |
-
out_bias=True,
|
687 |
-
dropout=0.0,
|
688 |
-
):
|
689 |
-
super().__init__()
|
690 |
-
self.train_kv = train_kv
|
691 |
-
self.train_q_out = train_q_out
|
692 |
-
|
693 |
-
self.hidden_size = hidden_size
|
694 |
-
self.cross_attention_dim = cross_attention_dim
|
695 |
-
|
696 |
-
# `_custom_diffusion` id for easy serialization and loading.
|
697 |
-
if self.train_kv:
|
698 |
-
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
699 |
-
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
700 |
-
if self.train_q_out:
|
701 |
-
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
702 |
-
self.to_out_custom_diffusion = nn.ModuleList([])
|
703 |
-
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
704 |
-
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
705 |
-
|
706 |
-
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
707 |
-
batch_size, sequence_length, _ = hidden_states.shape
|
708 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
709 |
-
if self.train_q_out:
|
710 |
-
query = self.to_q_custom_diffusion(hidden_states)
|
711 |
-
else:
|
712 |
-
query = attn.to_q(hidden_states)
|
713 |
-
|
714 |
-
if encoder_hidden_states is None:
|
715 |
-
crossattn = False
|
716 |
-
encoder_hidden_states = hidden_states
|
717 |
-
else:
|
718 |
-
crossattn = True
|
719 |
-
if attn.norm_cross:
|
720 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
721 |
-
|
722 |
-
if self.train_kv:
|
723 |
-
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
724 |
-
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
725 |
-
else:
|
726 |
-
key = attn.to_k(encoder_hidden_states)
|
727 |
-
value = attn.to_v(encoder_hidden_states)
|
728 |
-
|
729 |
-
if crossattn:
|
730 |
-
detach = torch.ones_like(key)
|
731 |
-
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
732 |
-
key = detach * key + (1 - detach) * key.detach()
|
733 |
-
value = detach * value + (1 - detach) * value.detach()
|
734 |
-
|
735 |
-
query = attn.head_to_batch_dim(query)
|
736 |
-
key = attn.head_to_batch_dim(key)
|
737 |
-
value = attn.head_to_batch_dim(value)
|
738 |
-
|
739 |
-
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
740 |
-
hidden_states = torch.bmm(attention_probs, value)
|
741 |
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
742 |
-
|
743 |
-
if self.train_q_out:
|
744 |
-
# linear proj
|
745 |
-
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
746 |
-
# dropout
|
747 |
-
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
748 |
-
else:
|
749 |
-
# linear proj
|
750 |
-
hidden_states = attn.to_out[0](hidden_states)
|
751 |
-
# dropout
|
752 |
-
hidden_states = attn.to_out[1](hidden_states)
|
753 |
-
|
754 |
-
return hidden_states
|
755 |
-
|
756 |
-
|
757 |
-
class AttnAddedKVProcessor:
|
758 |
-
r"""
|
759 |
-
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
760 |
-
encoder.
|
761 |
-
"""
|
762 |
-
|
763 |
-
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
764 |
-
residual = hidden_states
|
765 |
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
766 |
-
batch_size, sequence_length, _ = hidden_states.shape
|
767 |
-
|
768 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
769 |
-
|
770 |
-
if encoder_hidden_states is None:
|
771 |
-
encoder_hidden_states = hidden_states
|
772 |
-
elif attn.norm_cross:
|
773 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
774 |
-
|
775 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
776 |
-
|
777 |
-
query = attn.to_q(hidden_states)
|
778 |
-
query = attn.head_to_batch_dim(query)
|
779 |
-
|
780 |
-
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
781 |
-
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
782 |
-
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
783 |
-
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
784 |
-
|
785 |
-
if not attn.only_cross_attention:
|
786 |
-
key = attn.to_k(hidden_states)
|
787 |
-
value = attn.to_v(hidden_states)
|
788 |
-
key = attn.head_to_batch_dim(key)
|
789 |
-
value = attn.head_to_batch_dim(value)
|
790 |
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
791 |
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
792 |
-
else:
|
793 |
-
key = encoder_hidden_states_key_proj
|
794 |
-
value = encoder_hidden_states_value_proj
|
795 |
-
|
796 |
-
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
797 |
-
hidden_states = torch.bmm(attention_probs, value)
|
798 |
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
799 |
-
|
800 |
-
# linear proj
|
801 |
-
hidden_states = attn.to_out[0](hidden_states)
|
802 |
-
# dropout
|
803 |
-
hidden_states = attn.to_out[1](hidden_states)
|
804 |
-
|
805 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
806 |
-
hidden_states = hidden_states + residual
|
807 |
-
|
808 |
-
return hidden_states
|
809 |
-
|
810 |
-
|
811 |
-
class AttnAddedKVProcessor2_0:
|
812 |
-
r"""
|
813 |
-
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
814 |
-
learnable key and value matrices for the text encoder.
|
815 |
-
"""
|
816 |
-
|
817 |
-
def __init__(self):
|
818 |
-
if not hasattr(F, "scaled_dot_product_attention"):
|
819 |
-
raise ImportError(
|
820 |
-
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
821 |
-
)
|
822 |
-
|
823 |
-
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
824 |
-
residual = hidden_states
|
825 |
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
826 |
-
batch_size, sequence_length, _ = hidden_states.shape
|
827 |
-
|
828 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
829 |
-
|
830 |
-
if encoder_hidden_states is None:
|
831 |
-
encoder_hidden_states = hidden_states
|
832 |
-
elif attn.norm_cross:
|
833 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
834 |
-
|
835 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
836 |
-
|
837 |
-
query = attn.to_q(hidden_states)
|
838 |
-
query = attn.head_to_batch_dim(query, out_dim=4)
|
839 |
-
|
840 |
-
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
841 |
-
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
842 |
-
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
843 |
-
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
844 |
-
|
845 |
-
if not attn.only_cross_attention:
|
846 |
-
key = attn.to_k(hidden_states)
|
847 |
-
value = attn.to_v(hidden_states)
|
848 |
-
key = attn.head_to_batch_dim(key, out_dim=4)
|
849 |
-
value = attn.head_to_batch_dim(value, out_dim=4)
|
850 |
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
851 |
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
852 |
-
else:
|
853 |
-
key = encoder_hidden_states_key_proj
|
854 |
-
value = encoder_hidden_states_value_proj
|
855 |
-
|
856 |
-
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
857 |
-
# TODO: add support for attn.scale when we move to Torch 2.1
|
858 |
-
hidden_states = F.scaled_dot_product_attention(
|
859 |
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
860 |
-
)
|
861 |
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
862 |
-
|
863 |
-
# linear proj
|
864 |
-
hidden_states = attn.to_out[0](hidden_states)
|
865 |
-
# dropout
|
866 |
-
hidden_states = attn.to_out[1](hidden_states)
|
867 |
-
|
868 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
869 |
-
hidden_states = hidden_states + residual
|
870 |
-
|
871 |
-
return hidden_states
|
872 |
-
|
873 |
-
|
874 |
-
class LoRAAttnAddedKVProcessor(nn.Module):
|
875 |
-
r"""
|
876 |
-
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
877 |
-
encoder.
|
878 |
-
|
879 |
-
Args:
|
880 |
-
hidden_size (`int`, *optional*):
|
881 |
-
The hidden size of the attention layer.
|
882 |
-
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
883 |
-
The number of channels in the `encoder_hidden_states`.
|
884 |
-
rank (`int`, defaults to 4):
|
885 |
-
The dimension of the LoRA update matrices.
|
886 |
-
|
887 |
-
"""
|
888 |
-
|
889 |
-
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
890 |
-
super().__init__()
|
891 |
-
|
892 |
-
self.hidden_size = hidden_size
|
893 |
-
self.cross_attention_dim = cross_attention_dim
|
894 |
-
self.rank = rank
|
895 |
-
|
896 |
-
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
897 |
-
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
898 |
-
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
899 |
-
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
900 |
-
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
901 |
-
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
902 |
-
|
903 |
-
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
904 |
-
residual = hidden_states
|
905 |
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
906 |
-
batch_size, sequence_length, _ = hidden_states.shape
|
907 |
-
|
908 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
909 |
-
|
910 |
-
if encoder_hidden_states is None:
|
911 |
-
encoder_hidden_states = hidden_states
|
912 |
-
elif attn.norm_cross:
|
913 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
914 |
-
|
915 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
916 |
-
|
917 |
-
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
918 |
-
query = attn.head_to_batch_dim(query)
|
919 |
-
|
920 |
-
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
|
921 |
-
encoder_hidden_states
|
922 |
-
)
|
923 |
-
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
|
924 |
-
encoder_hidden_states
|
925 |
-
)
|
926 |
-
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
927 |
-
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
928 |
-
|
929 |
-
if not attn.only_cross_attention:
|
930 |
-
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
|
931 |
-
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
|
932 |
-
key = attn.head_to_batch_dim(key)
|
933 |
-
value = attn.head_to_batch_dim(value)
|
934 |
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
935 |
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
936 |
-
else:
|
937 |
-
key = encoder_hidden_states_key_proj
|
938 |
-
value = encoder_hidden_states_value_proj
|
939 |
-
|
940 |
-
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
941 |
-
hidden_states = torch.bmm(attention_probs, value)
|
942 |
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
943 |
-
|
944 |
-
# linear proj
|
945 |
-
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
946 |
-
# dropout
|
947 |
-
hidden_states = attn.to_out[1](hidden_states)
|
948 |
-
|
949 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
950 |
-
hidden_states = hidden_states + residual
|
951 |
-
|
952 |
-
return hidden_states
|
953 |
-
|
954 |
-
|
955 |
-
class XFormersAttnAddedKVProcessor:
|
956 |
-
r"""
|
957 |
-
Processor for implementing memory efficient attention using xFormers.
|
958 |
-
|
959 |
-
Args:
|
960 |
-
attention_op (`Callable`, *optional*, defaults to `None`):
|
961 |
-
The base
|
962 |
-
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
963 |
-
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
964 |
-
operator.
|
965 |
-
"""
|
966 |
-
|
967 |
-
def __init__(self, attention_op: Optional[Callable] = None):
|
968 |
-
self.attention_op = attention_op
|
969 |
-
|
970 |
-
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
971 |
-
residual = hidden_states
|
972 |
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
973 |
-
batch_size, sequence_length, _ = hidden_states.shape
|
974 |
-
|
975 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
976 |
-
|
977 |
-
if encoder_hidden_states is None:
|
978 |
-
encoder_hidden_states = hidden_states
|
979 |
-
elif attn.norm_cross:
|
980 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
981 |
-
|
982 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
983 |
-
|
984 |
-
query = attn.to_q(hidden_states)
|
985 |
-
query = attn.head_to_batch_dim(query)
|
986 |
-
|
987 |
-
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
988 |
-
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
989 |
-
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
990 |
-
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
991 |
-
|
992 |
-
if not attn.only_cross_attention:
|
993 |
-
key = attn.to_k(hidden_states)
|
994 |
-
value = attn.to_v(hidden_states)
|
995 |
-
key = attn.head_to_batch_dim(key)
|
996 |
-
value = attn.head_to_batch_dim(value)
|
997 |
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
998 |
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
999 |
-
else:
|
1000 |
-
key = encoder_hidden_states_key_proj
|
1001 |
-
value = encoder_hidden_states_value_proj
|
1002 |
-
|
1003 |
-
hidden_states = xformers.ops.memory_efficient_attention(
|
1004 |
-
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1005 |
-
)
|
1006 |
-
hidden_states = hidden_states.to(query.dtype)
|
1007 |
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1008 |
-
|
1009 |
-
# linear proj
|
1010 |
-
hidden_states = attn.to_out[0](hidden_states)
|
1011 |
-
# dropout
|
1012 |
-
hidden_states = attn.to_out[1](hidden_states)
|
1013 |
-
|
1014 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1015 |
-
hidden_states = hidden_states + residual
|
1016 |
-
|
1017 |
-
return hidden_states
|
1018 |
-
|
1019 |
-
|
1020 |
-
class XFormersAttnProcessor:
|
1021 |
-
r"""
|
1022 |
-
Processor for implementing memory efficient attention using xFormers.
|
1023 |
-
|
1024 |
-
Args:
|
1025 |
-
attention_op (`Callable`, *optional*, defaults to `None`):
|
1026 |
-
The base
|
1027 |
-
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1028 |
-
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1029 |
-
operator.
|
1030 |
-
"""
|
1031 |
-
|
1032 |
-
def __init__(self, attention_op: Optional[Callable] = None):
|
1033 |
-
self.attention_op = attention_op
|
1034 |
-
|
1035 |
-
def __call__(
|
1036 |
-
self,
|
1037 |
-
attn: Attention,
|
1038 |
-
hidden_states: torch.FloatTensor,
|
1039 |
-
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1040 |
-
attention_mask: Optional[torch.FloatTensor] = None,
|
1041 |
-
temb: Optional[torch.FloatTensor] = None,
|
1042 |
-
posemb: Optional = None,
|
1043 |
-
):
|
1044 |
-
residual = hidden_states
|
1045 |
-
|
1046 |
-
if attn.spatial_norm is not None:
|
1047 |
-
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1048 |
-
|
1049 |
-
input_ndim = hidden_states.ndim
|
1050 |
-
|
1051 |
-
if input_ndim == 4:
|
1052 |
-
batch_size, channel, height, width = hidden_states.shape
|
1053 |
-
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1054 |
-
|
1055 |
-
if posemb is not None:
|
1056 |
-
# turn 2d attention into multiview attention
|
1057 |
-
self_attn = encoder_hidden_states is None # check if self attn or cross attn
|
1058 |
-
p_out, p_in = posemb
|
1059 |
-
t_out, t_in = p_out.shape[1], p_in.shape[1] # t size
|
1060 |
-
hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)
|
1061 |
-
|
1062 |
-
batch_size, key_tokens, _ = (
|
1063 |
-
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1064 |
-
)
|
1065 |
-
|
1066 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1067 |
-
if attention_mask is not None:
|
1068 |
-
# expand our mask's singleton query_tokens dimension:
|
1069 |
-
# [batch*heads, 1, key_tokens] ->
|
1070 |
-
# [batch*heads, query_tokens, key_tokens]
|
1071 |
-
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1072 |
-
# [batch*heads, query_tokens, key_tokens]
|
1073 |
-
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1074 |
-
_, query_tokens, _ = hidden_states.shape
|
1075 |
-
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1076 |
-
|
1077 |
-
if attn.group_norm is not None:
|
1078 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1079 |
-
|
1080 |
-
query = attn.to_q(hidden_states)
|
1081 |
-
if encoder_hidden_states is None:
|
1082 |
-
encoder_hidden_states = hidden_states
|
1083 |
-
elif attn.norm_cross:
|
1084 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1085 |
-
|
1086 |
-
key = attn.to_k(encoder_hidden_states)
|
1087 |
-
value = attn.to_v(encoder_hidden_states)
|
1088 |
-
|
1089 |
-
|
1090 |
-
# apply 4DoF CaPE, todo now only for xformer processor
|
1091 |
-
if posemb is not None:
|
1092 |
-
p_out = einops.repeat(p_out, 'b t_out d -> b (t_out l) d', l=query.shape[1]//t_out) # query shape
|
1093 |
-
if self_attn:
|
1094 |
-
p_in = p_out
|
1095 |
-
else:
|
1096 |
-
p_in = einops.repeat(p_in, 'b t_in d -> b (t_in l) d', l=key.shape[1] // t_in) # key shape
|
1097 |
-
query, key = cape_embed(p_out, p_in, query, key)
|
1098 |
-
|
1099 |
-
query = attn.head_to_batch_dim(query).contiguous()
|
1100 |
-
key = attn.head_to_batch_dim(key).contiguous()
|
1101 |
-
value = attn.head_to_batch_dim(value).contiguous()
|
1102 |
-
|
1103 |
-
# self-ttn (bm) l c x (bm) l c -> (bm) l c
|
1104 |
-
# cross-ttn (bm) l c x b (nl) c -> (bm) l c
|
1105 |
-
# reuse 2d attention for multiview attention
|
1106 |
-
# self-ttn b (ml) c x b (ml) c -> b (ml) c
|
1107 |
-
# cross-ttn b (ml) c x b (nl) c -> b (ml) c
|
1108 |
-
hidden_states = xformers.ops.memory_efficient_attention( # query: (bm) l c -> b (ml) c; key: b (nl) c
|
1109 |
-
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1110 |
-
)
|
1111 |
-
hidden_states = hidden_states.to(query.dtype)
|
1112 |
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1113 |
-
|
1114 |
-
# linear proj
|
1115 |
-
hidden_states = attn.to_out[0](hidden_states)
|
1116 |
-
# dropout
|
1117 |
-
hidden_states = attn.to_out[1](hidden_states)
|
1118 |
-
|
1119 |
-
if posemb is not None:
|
1120 |
-
# reshape back
|
1121 |
-
hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)
|
1122 |
-
|
1123 |
-
if input_ndim == 4:
|
1124 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1125 |
-
|
1126 |
-
if attn.residual_connection:
|
1127 |
-
hidden_states = hidden_states + residual
|
1128 |
-
|
1129 |
-
hidden_states = hidden_states / attn.rescale_output_factor
|
1130 |
-
|
1131 |
-
|
1132 |
-
return hidden_states
|
1133 |
-
|
1134 |
-
|
1135 |
-
class AttnProcessor2_0:
|
1136 |
-
r"""
|
1137 |
-
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1138 |
-
"""
|
1139 |
-
|
1140 |
-
def __init__(self):
|
1141 |
-
if not hasattr(F, "scaled_dot_product_attention"):
|
1142 |
-
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1143 |
-
|
1144 |
-
def __call__(
|
1145 |
-
self,
|
1146 |
-
attn: Attention,
|
1147 |
-
hidden_states,
|
1148 |
-
encoder_hidden_states=None,
|
1149 |
-
attention_mask=None,
|
1150 |
-
temb=None,
|
1151 |
-
):
|
1152 |
-
residual = hidden_states
|
1153 |
-
|
1154 |
-
if attn.spatial_norm is not None:
|
1155 |
-
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1156 |
-
|
1157 |
-
input_ndim = hidden_states.ndim
|
1158 |
-
|
1159 |
-
if input_ndim == 4:
|
1160 |
-
batch_size, channel, height, width = hidden_states.shape
|
1161 |
-
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1162 |
-
|
1163 |
-
batch_size, sequence_length, _ = (
|
1164 |
-
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1165 |
-
)
|
1166 |
-
inner_dim = hidden_states.shape[-1]
|
1167 |
-
|
1168 |
-
if attention_mask is not None:
|
1169 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1170 |
-
# scaled_dot_product_attention expects attention_mask shape to be
|
1171 |
-
# (batch, heads, source_length, target_length)
|
1172 |
-
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1173 |
-
|
1174 |
-
if attn.group_norm is not None:
|
1175 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1176 |
-
|
1177 |
-
query = attn.to_q(hidden_states)
|
1178 |
-
|
1179 |
-
if encoder_hidden_states is None:
|
1180 |
-
encoder_hidden_states = hidden_states
|
1181 |
-
elif attn.norm_cross:
|
1182 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1183 |
-
|
1184 |
-
key = attn.to_k(encoder_hidden_states)
|
1185 |
-
value = attn.to_v(encoder_hidden_states)
|
1186 |
-
|
1187 |
-
head_dim = inner_dim // attn.heads
|
1188 |
-
|
1189 |
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1190 |
-
|
1191 |
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1192 |
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1193 |
-
|
1194 |
-
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1195 |
-
# TODO: add support for attn.scale when we move to Torch 2.1
|
1196 |
-
hidden_states = F.scaled_dot_product_attention(
|
1197 |
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1198 |
-
)
|
1199 |
-
|
1200 |
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1201 |
-
hidden_states = hidden_states.to(query.dtype)
|
1202 |
-
|
1203 |
-
# linear proj
|
1204 |
-
hidden_states = attn.to_out[0](hidden_states)
|
1205 |
-
# dropout
|
1206 |
-
hidden_states = attn.to_out[1](hidden_states)
|
1207 |
-
|
1208 |
-
if input_ndim == 4:
|
1209 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1210 |
-
|
1211 |
-
if attn.residual_connection:
|
1212 |
-
hidden_states = hidden_states + residual
|
1213 |
-
|
1214 |
-
hidden_states = hidden_states / attn.rescale_output_factor
|
1215 |
-
|
1216 |
-
return hidden_states
|
1217 |
-
|
1218 |
-
|
class LoRAXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(
        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class LoRAAttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
    attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

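The two LoRA processors above are normally registered across a whole UNet rather than built for a single layer. The sketch below is illustrative and not part of the deleted file; it assumes the upstream diffusers package (with PyTorch 2.0 for LoRAAttnProcessor2_0) and the standard Stable Diffusion UNet layout, and the checkpoint id is only an example.

# Illustrative sketch (not from the deleted file): one LoRA attention processor per
# attention layer of a Stable Diffusion UNet, keyed by the paths the model reports.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

lora_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no encoder states), attn2 is cross-attention on the text features
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # the hidden size depends on which down/mid/up block the layer sits in
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_procs[name] = LoRAAttnProcessor2_0(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_procs)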
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
            as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=False,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states)
            value = self.to_v_custom_diffusion(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

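The `crossattn` branch above partially freezes the projected text features: the first token is routed through `.detach()` while the remaining tokens keep their gradient path. A tiny, self-contained illustration of that masking (not part of the deleted file):

# Illustrative sketch of the per-token detach trick used above.
import torch

key = torch.randn(1, 4, 8, requires_grad=True)   # (batch, tokens, channels)
detach = torch.ones_like(key)
detach[:, :1, :] = 0.0                           # mask out token 0

mixed = detach * key + (1 - detach) * key.detach()
mixed.sum().backward()

print(key.grad[0, 0].abs().sum())  # 0.0 -> token 0 receives no gradient
print(key.grad[0, 1].abs().sum())  # > 0 -> the other tokens do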
class SlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

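In practice these sliced processors are rarely instantiated by hand. A hedged usage sketch, assuming the upstream diffusers pipeline API, where enabling attention slicing on a pipeline typically ends up installing sliced processors on the attention layers:

from diffusers import StableDiffusionPipeline

# the checkpoint id is only an example
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_attention_slicing()          # trade a little speed for lower peak memory
# pipe.enable_attention_slicing("max")   # slice size 1, lowest memory use
image = pipe("a photo of an astronaut riding a horse").images[0]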
class SlicedAttnAddedKVProcessor:
    r"""
    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states

AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
]

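As a hedged usage note (not from the deleted file), this union is the type that `set_attn_processor` accepts on diffusers models, either as a single instance shared by every attention layer or as a dict keyed by the paths reported by `attn_processors`:

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel()                      # randomly initialized, default config
unet.set_attn_processor(AttnProcessor())           # one processor instance for every layer

per_layer = {name: AttnProcessor() for name in unet.attn_processors.keys()}
unet.set_attn_processor(per_layer)                 # or an explicit per-layer mapping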
class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq):
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f
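
A minimal shape check for SpatialNorm (illustrative only; the import path assumes the upstream diffusers package): `f` is the feature map being normalized and `zq` is the smaller conditioning latent, which is upsampled to `f`'s spatial size before producing a per-pixel scale and shift.

import torch
from diffusers.models.attention_processor import SpatialNorm

norm = SpatialNorm(f_channels=128, zq_channels=4)
f = torch.randn(1, 128, 32, 32)   # decoder feature map to normalize
zq = torch.randn(1, 4, 8, 8)      # conditioning latent, upsampled internally to 32x32
out = norm(f, zq)
print(out.shape)                  # torch.Size([1, 128, 32, 32])
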
4DoF/diffusers/models/autoencoder_kl.py DELETED @@ -1,411 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"


class AutoencoderKL(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a, b, blend_extent):
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
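
A hedged usage sketch for the class above (not part of the deleted file; the checkpoint id is only an example, and any `AutoencoderKL` weights behave the same). It shows the slicing/tiling toggles and the `scaling_factor` convention described in the docstring.

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()
vae.enable_slicing()   # encode/decode a batch one sample at a time
vae.enable_tiling()    # tile very large inputs to keep memory bounded

image = torch.rand(1, 3, 512, 512) * 2 - 1   # images are expected in [-1, 1]
with torch.no_grad():
    posterior = vae.encode(image).latent_dist
    latents = posterior.sample() * vae.config.scaling_factor        # z = z * scaling_factor
    decoded = vae.decode(latents / vae.config.scaling_factor).sample
print(latents.shape)   # torch.Size([1, 4, 64, 64])
print(decoded.shape)   # torch.Size([1, 3, 512, 512])
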
4DoF/diffusers/models/controlnet.py DELETED @@ -1,705 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from .unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

-
class ControlNetModel(ModelMixin, ConfigMixin):
|
104 |
-
"""
|
105 |
-
A ControlNet model.
|
106 |
-
|
107 |
-
Args:
|
108 |
-
in_channels (`int`, defaults to 4):
|
109 |
-
The number of channels in the input sample.
|
110 |
-
flip_sin_to_cos (`bool`, defaults to `True`):
|
111 |
-
Whether to flip the sin to cos in the time embedding.
|
112 |
-
freq_shift (`int`, defaults to 0):
|
113 |
-
The frequency shift to apply to the time embedding.
|
114 |
-
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
115 |
-
The tuple of downsample blocks to use.
|
116 |
-
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
117 |
-
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
118 |
-
The tuple of output channels for each block.
|
119 |
-
layers_per_block (`int`, defaults to 2):
|
120 |
-
The number of layers per block.
|
121 |
-
downsample_padding (`int`, defaults to 1):
|
122 |
-
The padding to use for the downsampling convolution.
|
123 |
-
mid_block_scale_factor (`float`, defaults to 1):
|
124 |
-
The scale factor to use for the mid block.
|
125 |
-
act_fn (`str`, defaults to "silu"):
|
126 |
-
The activation function to use.
|
127 |
-
norm_num_groups (`int`, *optional*, defaults to 32):
|
128 |
-
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
129 |
-
in post-processing.
|
130 |
-
norm_eps (`float`, defaults to 1e-5):
|
131 |
-
The epsilon to use for the normalization.
|
132 |
-
cross_attention_dim (`int`, defaults to 1280):
|
133 |
-
The dimension of the cross attention features.
|
134 |
-
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
135 |
-
The dimension of the attention heads.
|
136 |
-
use_linear_projection (`bool`, defaults to `False`):
|
137 |
-
class_embed_type (`str`, *optional*, defaults to `None`):
|
138 |
-
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
139 |
-
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
140 |
-
num_class_embeds (`int`, *optional*, defaults to 0):
|
141 |
-
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
142 |
-
class conditioning with `class_embed_type` equal to `None`.
|
143 |
-
upcast_attention (`bool`, defaults to `False`):
|
144 |
-
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
145 |
-
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
146 |
-
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
147 |
-
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
148 |
-
`class_embed_type="projection"`.
|
149 |
-
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
150 |
-
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
151 |
-
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
152 |
-
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
153 |
-
global_pool_conditions (`bool`, defaults to `False`):
|
154 |
-
"""
|
155 |
-
|
156 |
-
_supports_gradient_checkpointing = True
|
157 |
-
|
158 |
-
@register_to_config
|
159 |
-
def __init__(
|
160 |
-
self,
|
161 |
-
in_channels: int = 4,
|
162 |
-
conditioning_channels: int = 3,
|
163 |
-
flip_sin_to_cos: bool = True,
|
164 |
-
freq_shift: int = 0,
|
165 |
-
down_block_types: Tuple[str] = (
|
166 |
-
"CrossAttnDownBlock2D",
|
167 |
-
"CrossAttnDownBlock2D",
|
168 |
-
"CrossAttnDownBlock2D",
|
169 |
-
"DownBlock2D",
|
170 |
-
),
|
171 |
-
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
172 |
-
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
173 |
-
layers_per_block: int = 2,
|
174 |
-
downsample_padding: int = 1,
|
175 |
-
mid_block_scale_factor: float = 1,
|
176 |
-
act_fn: str = "silu",
|
177 |
-
norm_num_groups: Optional[int] = 32,
|
178 |
-
norm_eps: float = 1e-5,
|
179 |
-
cross_attention_dim: int = 1280,
|
180 |
-
attention_head_dim: Union[int, Tuple[int]] = 8,
|
181 |
-
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
182 |
-
use_linear_projection: bool = False,
|
183 |
-
class_embed_type: Optional[str] = None,
|
184 |
-
num_class_embeds: Optional[int] = None,
|
185 |
-
upcast_attention: bool = False,
|
186 |
-
resnet_time_scale_shift: str = "default",
|
187 |
-
projection_class_embeddings_input_dim: Optional[int] = None,
|
188 |
-
controlnet_conditioning_channel_order: str = "rgb",
|
189 |
-
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
190 |
-
global_pool_conditions: bool = False,
|
191 |
-
):
|
192 |
-
super().__init__()
|
193 |
-
|
194 |
-
# If `num_attention_heads` is not defined (which is the case for most models)
|
195 |
-
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
196 |
-
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
197 |
-
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
198 |
-
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
199 |
-
# which is why we correct for the naming here.
|
200 |
-
num_attention_heads = num_attention_heads or attention_head_dim
|
201 |
-
|
202 |
-
# Check inputs
|
203 |
-
if len(block_out_channels) != len(down_block_types):
|
204 |
-
raise ValueError(
|
205 |
-
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
206 |
-
)
|
207 |
-
|
208 |
-
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
209 |
-
raise ValueError(
|
210 |
-
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
211 |
-
)
|
212 |
-
|
213 |
-
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
214 |
-
raise ValueError(
|
215 |
-
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
216 |
-
)
|
217 |
-
|
218 |
-
# input
|
219 |
-
conv_in_kernel = 3
|
220 |
-
conv_in_padding = (conv_in_kernel - 1) // 2
|
221 |
-
self.conv_in = nn.Conv2d(
|
222 |
-
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
223 |
-
)
|
224 |
-
|
225 |
-
# time
|
226 |
-
time_embed_dim = block_out_channels[0] * 4
|
227 |
-
|
228 |
-
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
229 |
-
timestep_input_dim = block_out_channels[0]
|
230 |
-
|
231 |
-
self.time_embedding = TimestepEmbedding(
|
232 |
-
timestep_input_dim,
|
233 |
-
time_embed_dim,
|
234 |
-
act_fn=act_fn,
|
235 |
-
)
|
236 |
-
|
237 |
-
# class embedding
|
238 |
-
if class_embed_type is None and num_class_embeds is not None:
|
239 |
-
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
240 |
-
elif class_embed_type == "timestep":
|
241 |
-
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
242 |
-
elif class_embed_type == "identity":
|
243 |
-
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
244 |
-
elif class_embed_type == "projection":
|
245 |
-
if projection_class_embeddings_input_dim is None:
|
246 |
-
raise ValueError(
|
247 |
-
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
248 |
-
)
|
249 |
-
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
250 |
-
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
251 |
-
# 2. it projects from an arbitrary input dimension.
|
252 |
-
#
|
253 |
-
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
254 |
-
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
255 |
-
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
256 |
-
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
257 |
-
else:
|
258 |
-
self.class_embedding = None
|
259 |
-
|
260 |
-
# control net conditioning embedding
|
261 |
-
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
262 |
-
conditioning_embedding_channels=block_out_channels[0],
|
263 |
-
block_out_channels=conditioning_embedding_out_channels,
|
264 |
-
conditioning_channels=conditioning_channels,
|
265 |
-
)
|
266 |
-
|
267 |
-
self.down_blocks = nn.ModuleList([])
|
268 |
-
self.controlnet_down_blocks = nn.ModuleList([])
|
269 |
-
|
270 |
-
if isinstance(only_cross_attention, bool):
|
271 |
-
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
272 |
-
|
273 |
-
if isinstance(attention_head_dim, int):
|
274 |
-
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
275 |
-
|
276 |
-
if isinstance(num_attention_heads, int):
|
277 |
-
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
278 |
-
|
279 |
-
# down
|
280 |
-
output_channel = block_out_channels[0]
|
281 |
-
|
282 |
-
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
283 |
-
controlnet_block = zero_module(controlnet_block)
|
284 |
-
self.controlnet_down_blocks.append(controlnet_block)
|
285 |
-
|
286 |
-
for i, down_block_type in enumerate(down_block_types):
|
287 |
-
input_channel = output_channel
|
288 |
-
output_channel = block_out_channels[i]
|
289 |
-
is_final_block = i == len(block_out_channels) - 1
|
290 |
-
|
291 |
-
down_block = get_down_block(
|
292 |
-
down_block_type,
|
293 |
-
num_layers=layers_per_block,
|
294 |
-
in_channels=input_channel,
|
295 |
-
out_channels=output_channel,
|
296 |
-
temb_channels=time_embed_dim,
|
297 |
-
add_downsample=not is_final_block,
|
298 |
-
resnet_eps=norm_eps,
|
299 |
-
resnet_act_fn=act_fn,
|
300 |
-
resnet_groups=norm_num_groups,
|
301 |
-
cross_attention_dim=cross_attention_dim,
|
302 |
-
num_attention_heads=num_attention_heads[i],
|
303 |
-
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
304 |
-
downsample_padding=downsample_padding,
|
305 |
-
use_linear_projection=use_linear_projection,
|
306 |
-
only_cross_attention=only_cross_attention[i],
|
307 |
-
upcast_attention=upcast_attention,
|
308 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
309 |
-
)
|
310 |
-
self.down_blocks.append(down_block)
|
311 |
-
|
312 |
-
for _ in range(layers_per_block):
|
313 |
-
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
314 |
-
controlnet_block = zero_module(controlnet_block)
|
315 |
-
self.controlnet_down_blocks.append(controlnet_block)
|
316 |
-
|
317 |
-
if not is_final_block:
|
318 |
-
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
319 |
-
controlnet_block = zero_module(controlnet_block)
|
320 |
-
self.controlnet_down_blocks.append(controlnet_block)
|
321 |
-
|
322 |
-
# mid
|
323 |
-
mid_block_channel = block_out_channels[-1]
|
324 |
-
|
325 |
-
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
326 |
-
controlnet_block = zero_module(controlnet_block)
|
327 |
-
self.controlnet_mid_block = controlnet_block
|
328 |
-
|
329 |
-
self.mid_block = UNetMidBlock2DCrossAttn(
|
330 |
-
in_channels=mid_block_channel,
|
331 |
-
temb_channels=time_embed_dim,
|
332 |
-
resnet_eps=norm_eps,
|
333 |
-
resnet_act_fn=act_fn,
|
334 |
-
output_scale_factor=mid_block_scale_factor,
|
335 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
336 |
-
cross_attention_dim=cross_attention_dim,
|
337 |
-
num_attention_heads=num_attention_heads[-1],
|
338 |
-
resnet_groups=norm_num_groups,
|
339 |
-
use_linear_projection=use_linear_projection,
|
340 |
-
upcast_attention=upcast_attention,
|
341 |
-
)
|
342 |
-
|
343 |
-
@classmethod
|
344 |
-
def from_unet(
|
345 |
-
cls,
|
346 |
-
unet: UNet2DConditionModel,
|
347 |
-
controlnet_conditioning_channel_order: str = "rgb",
|
348 |
-
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
349 |
-
load_weights_from_unet: bool = True,
|
350 |
-
):
|
351 |
-
r"""
|
352 |
-
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
353 |
-
|
354 |
-
Parameters:
|
355 |
-
unet (`UNet2DConditionModel`):
|
356 |
-
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
357 |
-
where applicable.
|
358 |
-
"""
|
359 |
-
controlnet = cls(
|
360 |
-
in_channels=unet.config.in_channels,
|
361 |
-
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
362 |
-
freq_shift=unet.config.freq_shift,
|
363 |
-
down_block_types=unet.config.down_block_types,
|
364 |
-
only_cross_attention=unet.config.only_cross_attention,
|
365 |
-
block_out_channels=unet.config.block_out_channels,
|
366 |
-
layers_per_block=unet.config.layers_per_block,
|
367 |
-
downsample_padding=unet.config.downsample_padding,
|
368 |
-
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
369 |
-
act_fn=unet.config.act_fn,
|
370 |
-
norm_num_groups=unet.config.norm_num_groups,
|
371 |
-
norm_eps=unet.config.norm_eps,
|
372 |
-
cross_attention_dim=unet.config.cross_attention_dim,
|
373 |
-
attention_head_dim=unet.config.attention_head_dim,
|
374 |
-
num_attention_heads=unet.config.num_attention_heads,
|
375 |
-
use_linear_projection=unet.config.use_linear_projection,
|
376 |
-
class_embed_type=unet.config.class_embed_type,
|
377 |
-
num_class_embeds=unet.config.num_class_embeds,
|
378 |
-
upcast_attention=unet.config.upcast_attention,
|
379 |
-
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
380 |
-
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
381 |
-
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
382 |
-
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
383 |
-
)
|
384 |
-
|
385 |
-
if load_weights_from_unet:
|
386 |
-
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
387 |
-
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
388 |
-
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
389 |
-
|
390 |
-
if controlnet.class_embedding:
|
391 |
-
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
392 |
-
|
393 |
-
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
394 |
-
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
395 |
-
|
396 |
-
return controlnet
|
397 |
-
|
398 |
-
@property
|
399 |
-
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
400 |
-
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
401 |
-
r"""
|
402 |
-
Returns:
|
403 |
-
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
404 |
-
indexed by its weight name.
|
405 |
-
"""
|
406 |
-
# set recursively
|
407 |
-
processors = {}
|
408 |
-
|
409 |
-
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
410 |
-
if hasattr(module, "set_processor"):
|
411 |
-
processors[f"{name}.processor"] = module.processor
|
412 |
-
|
413 |
-
for sub_name, child in module.named_children():
|
414 |
-
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
415 |
-
|
416 |
-
return processors
|
417 |
-
|
418 |
-
for name, module in self.named_children():
|
419 |
-
fn_recursive_add_processors(name, module, processors)
|
420 |
-
|
421 |
-
return processors
|
422 |
-
|
423 |
-
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
424 |
-
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
425 |
-
r"""
|
426 |
-
Sets the attention processor to use to compute attention.
|
427 |
-
|
428 |
-
Parameters:
|
429 |
-
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
430 |
-
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
431 |
-
for **all** `Attention` layers.
|
432 |
-
|
433 |
-
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
434 |
-
processor. This is strongly recommended when setting trainable attention processors.
|
435 |
-
|
436 |
-
"""
|
437 |
-
count = len(self.attn_processors.keys())
|
438 |
-
|
439 |
-
if isinstance(processor, dict) and len(processor) != count:
|
440 |
-
raise ValueError(
|
441 |
-
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
442 |
-
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
443 |
-
)
|
444 |
-
|
445 |
-
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
446 |
-
if hasattr(module, "set_processor"):
|
447 |
-
if not isinstance(processor, dict):
|
448 |
-
module.set_processor(processor)
|
449 |
-
else:
|
450 |
-
module.set_processor(processor.pop(f"{name}.processor"))
|
451 |
-
|
452 |
-
for sub_name, child in module.named_children():
|
453 |
-
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
454 |
-
|
455 |
-
for name, module in self.named_children():
|
456 |
-
fn_recursive_attn_processor(name, module, processor)
|
457 |
-
|
458 |
-
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
459 |
-
def set_default_attn_processor(self):
|
460 |
-
"""
|
461 |
-
Disables custom attention processors and sets the default attention implementation.
|
462 |
-
"""
|
463 |
-
self.set_attn_processor(AttnProcessor())
|
464 |
-
|
465 |
-
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
466 |
-
def set_attention_slice(self, slice_size):
|
467 |
-
r"""
|
468 |
-
Enable sliced attention computation.
|
469 |
-
|
470 |
-
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
471 |
-
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
472 |
-
|
473 |
-
Args:
|
474 |
-
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
475 |
-
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
476 |
-
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
477 |
-
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
478 |
-
must be a multiple of `slice_size`.
|
479 |
-
"""
|
480 |
-
sliceable_head_dims = []
|
481 |
-
|
482 |
-
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
483 |
-
if hasattr(module, "set_attention_slice"):
|
484 |
-
sliceable_head_dims.append(module.sliceable_head_dim)
|
485 |
-
|
486 |
-
for child in module.children():
|
487 |
-
fn_recursive_retrieve_sliceable_dims(child)
|
488 |
-
|
489 |
-
# retrieve number of attention layers
|
490 |
-
for module in self.children():
|
491 |
-
fn_recursive_retrieve_sliceable_dims(module)
|
492 |
-
|
493 |
-
num_sliceable_layers = len(sliceable_head_dims)
|
494 |
-
|
495 |
-
if slice_size == "auto":
|
496 |
-
# half the attention head size is usually a good trade-off between
|
497 |
-
# speed and memory
|
498 |
-
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
499 |
-
elif slice_size == "max":
|
500 |
-
# make smallest slice possible
|
501 |
-
slice_size = num_sliceable_layers * [1]
|
502 |
-
|
503 |
-
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
504 |
-
|
505 |
-
if len(slice_size) != len(sliceable_head_dims):
|
506 |
-
raise ValueError(
|
507 |
-
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
508 |
-
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
509 |
-
)
|
510 |
-
|
511 |
-
for i in range(len(slice_size)):
|
512 |
-
size = slice_size[i]
|
513 |
-
dim = sliceable_head_dims[i]
|
514 |
-
if size is not None and size > dim:
|
515 |
-
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
516 |
-
|
517 |
-
# Recursively walk through all the children.
|
518 |
-
# Any children which exposes the set_attention_slice method
|
519 |
-
# gets the message
|
520 |
-
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
521 |
-
if hasattr(module, "set_attention_slice"):
|
522 |
-
module.set_attention_slice(slice_size.pop())
|
523 |
-
|
524 |
-
for child in module.children():
|
525 |
-
fn_recursive_set_attention_slice(child, slice_size)
|
526 |
-
|
527 |
-
reversed_slice_size = list(reversed(slice_size))
|
528 |
-
for module in self.children():
|
529 |
-
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
530 |
-
|
531 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
532 |
-
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
533 |
-
module.gradient_checkpointing = value
|
534 |
-
|
535 |
-
def forward(
|
536 |
-
self,
|
537 |
-
sample: torch.FloatTensor,
|
538 |
-
timestep: Union[torch.Tensor, float, int],
|
539 |
-
encoder_hidden_states: torch.Tensor,
|
540 |
-
controlnet_cond: torch.FloatTensor,
|
541 |
-
conditioning_scale: float = 1.0,
|
542 |
-
class_labels: Optional[torch.Tensor] = None,
|
543 |
-
timestep_cond: Optional[torch.Tensor] = None,
|
544 |
-
attention_mask: Optional[torch.Tensor] = None,
|
545 |
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
546 |
-
guess_mode: bool = False,
|
547 |
-
return_dict: bool = True,
|
548 |
-
) -> Union[ControlNetOutput, Tuple]:
|
549 |
-
"""
|
550 |
-
The [`ControlNetModel`] forward method.
|
551 |
-
|
552 |
-
Args:
|
553 |
-
sample (`torch.FloatTensor`):
|
554 |
-
The noisy input tensor.
|
555 |
-
timestep (`Union[torch.Tensor, float, int]`):
|
556 |
-
The number of timesteps to denoise an input.
|
557 |
-
encoder_hidden_states (`torch.Tensor`):
|
558 |
-
The encoder hidden states.
|
559 |
-
controlnet_cond (`torch.FloatTensor`):
|
560 |
-
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
561 |
-
conditioning_scale (`float`, defaults to `1.0`):
|
562 |
-
The scale factor for ControlNet outputs.
|
563 |
-
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
564 |
-
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
565 |
-
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
566 |
-
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
567 |
-
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
568 |
-
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
569 |
-
guess_mode (`bool`, defaults to `False`):
|
570 |
-
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
571 |
-
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
572 |
-
return_dict (`bool`, defaults to `True`):
|
573 |
-
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
574 |
-
|
575 |
-
Returns:
|
576 |
-
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
577 |
-
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
578 |
-
returned where the first element is the sample tensor.
|
579 |
-
"""
|
580 |
-
# check channel order
|
581 |
-
channel_order = self.config.controlnet_conditioning_channel_order
|
582 |
-
|
583 |
-
if channel_order == "rgb":
|
584 |
-
# in rgb order by default
|
585 |
-
...
|
586 |
-
elif channel_order == "bgr":
|
587 |
-
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
588 |
-
else:
|
589 |
-
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
590 |
-
|
591 |
-
# prepare attention_mask
|
592 |
-
if attention_mask is not None:
|
593 |
-
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
594 |
-
attention_mask = attention_mask.unsqueeze(1)
|
595 |
-
|
596 |
-
# 1. time
|
597 |
-
timesteps = timestep
|
598 |
-
if not torch.is_tensor(timesteps):
|
599 |
-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
600 |
-
# This would be a good case for the `match` statement (Python 3.10+)
|
601 |
-
is_mps = sample.device.type == "mps"
|
602 |
-
if isinstance(timestep, float):
|
603 |
-
dtype = torch.float32 if is_mps else torch.float64
|
604 |
-
else:
|
605 |
-
dtype = torch.int32 if is_mps else torch.int64
|
606 |
-
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
607 |
-
elif len(timesteps.shape) == 0:
|
608 |
-
timesteps = timesteps[None].to(sample.device)
|
609 |
-
|
610 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
611 |
-
timesteps = timesteps.expand(sample.shape[0])
|
612 |
-
|
613 |
-
t_emb = self.time_proj(timesteps)
|
614 |
-
|
615 |
-
# timesteps does not contain any weights and will always return f32 tensors
|
616 |
-
# but time_embedding might actually be running in fp16. so we need to cast here.
|
617 |
-
# there might be better ways to encapsulate this.
|
618 |
-
t_emb = t_emb.to(dtype=sample.dtype)
|
619 |
-
|
620 |
-
emb = self.time_embedding(t_emb, timestep_cond)
|
621 |
-
|
622 |
-
if self.class_embedding is not None:
|
623 |
-
if class_labels is None:
|
624 |
-
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
625 |
-
|
626 |
-
if self.config.class_embed_type == "timestep":
|
627 |
-
class_labels = self.time_proj(class_labels)
|
628 |
-
|
629 |
-
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
630 |
-
emb = emb + class_emb
|
631 |
-
|
632 |
-
# 2. pre-process
|
633 |
-
sample = self.conv_in(sample)
|
634 |
-
|
635 |
-
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
636 |
-
|
637 |
-
sample = sample + controlnet_cond
|
638 |
-
|
639 |
-
# 3. down
|
640 |
-
down_block_res_samples = (sample,)
|
641 |
-
for downsample_block in self.down_blocks:
|
642 |
-
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
643 |
-
sample, res_samples = downsample_block(
|
644 |
-
hidden_states=sample,
|
645 |
-
temb=emb,
|
646 |
-
encoder_hidden_states=encoder_hidden_states,
|
647 |
-
attention_mask=attention_mask,
|
648 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
649 |
-
)
|
650 |
-
else:
|
651 |
-
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
652 |
-
|
653 |
-
down_block_res_samples += res_samples
|
654 |
-
|
655 |
-
# 4. mid
|
656 |
-
if self.mid_block is not None:
|
657 |
-
sample = self.mid_block(
|
658 |
-
sample,
|
659 |
-
emb,
|
660 |
-
encoder_hidden_states=encoder_hidden_states,
|
661 |
-
attention_mask=attention_mask,
|
662 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
663 |
-
)
|
664 |
-
|
665 |
-
# 5. Control net blocks
|
666 |
-
|
667 |
-
controlnet_down_block_res_samples = ()
|
668 |
-
|
669 |
-
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
670 |
-
down_block_res_sample = controlnet_block(down_block_res_sample)
|
671 |
-
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
672 |
-
|
673 |
-
down_block_res_samples = controlnet_down_block_res_samples
|
674 |
-
|
675 |
-
mid_block_res_sample = self.controlnet_mid_block(sample)
|
676 |
-
|
677 |
-
# 6. scaling
|
678 |
-
if guess_mode and not self.config.global_pool_conditions:
|
679 |
-
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
680 |
-
|
681 |
-
scales = scales * conditioning_scale
|
682 |
-
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
683 |
-
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
684 |
-
else:
|
685 |
-
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
686 |
-
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
687 |
-
|
688 |
-
if self.config.global_pool_conditions:
|
689 |
-
down_block_res_samples = [
|
690 |
-
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
691 |
-
]
|
692 |
-
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
693 |
-
|
694 |
-
if not return_dict:
|
695 |
-
return (down_block_res_samples, mid_block_res_sample)
|
696 |
-
|
697 |
-
return ControlNetOutput(
|
698 |
-
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
699 |
-
)
|
700 |
-
|
701 |
-
|
702 |
-
def zero_module(module):
|
703 |
-
for p in module.parameters():
|
704 |
-
nn.init.zeros_(p)
|
705 |
-
return module
|
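End of the deleted 4DoF/diffusers/models/controlnet.py. For orientation, here is a minimal usage sketch of the two entry points above, `ControlNetModel.from_unet` and `forward` with `guess_mode`. It is not part of the deleted file; it assumes the stock `diffusers` package rather than this vendored copy, and the checkpoint id and tensor shapes are illustrative only.

    # Hypothetical sketch: build a ControlNet from an existing UNet and run one forward pass.
    import torch
    from diffusers import ControlNetModel, UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # assumed checkpoint
    )
    controlnet = ControlNetModel.from_unet(unet)  # copies config and encoder/mid-block weights

    sample = torch.randn(1, 4, 64, 64)               # noisy latents
    timestep = torch.tensor([10])                    # denoising step
    encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder output (SD 1.x width)
    controlnet_cond = torch.rand(1, 3, 512, 512)     # conditioning image in [0, 1]

    out = controlnet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=controlnet_cond,
        conditioning_scale=1.0,
        guess_mode=True,  # residuals get logspace(-1, 0) scales, i.e. 0.1 -> 1.0
    )
    # 12 down-block residuals plus one mid-block residual, to be added into the UNet decoder path.
    print(len(out.down_block_res_samples), out.mid_block_res_sample.shape)

Note the design choice visible in `zero_module`: every `controlnet_block` 1x1 convolution is zero-initialized, so a freshly created ControlNet contributes nothing to the UNet until training moves those weights away from zero.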
4DoF/diffusers/models/controlnet_flax.py
DELETED
@@ -1,394 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
)


@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    """
    The output of [`FlaxControlNetModel`].

    Args:
        down_block_res_samples (`jnp.ndarray`):
        mid_block_res_sample (`jnp.ndarray`):
    """

    down_block_res_samples: jnp.ndarray
    mid_block_res_sample: jnp.ndarray


class FlaxControlNetConditioningEmbedding(nn.Module):
    conditioning_embedding_channels: int
    block_out_channels: Tuple[int] = (16, 32, 96, 256)
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_in = nn.Conv(
            self.block_out_channels[0],
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        blocks = []
        for i in range(len(self.block_out_channels) - 1):
            channel_in = self.block_out_channels[i]
            channel_out = self.block_out_channels[i + 1]
            conv1 = nn.Conv(
                channel_in,
                kernel_size=(3, 3),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv1)
            conv2 = nn.Conv(
                channel_out,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv2)
        self.blocks = blocks

        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = nn.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = nn.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A ControlNet model.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
    """
    sample_size: int = 32
    in_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    controlnet_conditioning_channel_order: str = "rgb"
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

    def setup(self):
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=self.conditioning_embedding_out_channels,
        )

        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # down
        down_blocks = []
        controlnet_down_blocks = []

        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv(
            output_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )
        controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)

            for _ in range(self.layers_per_block):
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

        self.down_blocks = down_blocks
        self.controlnet_down_blocks = controlnet_down_blocks

        # mid
        mid_block_channel = block_out_channels[-1]
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )

        self.controlnet_mid_block = nn.Conv(
            mid_block_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale: float = 1.0,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxControlNetOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timestep (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
            conditioning_scale: (`float`) the scale factor for controlnet outputs
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
                [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        channel_order = self.controlnet_conditioning_channel_order
        if channel_order == "bgr":
            controlnet_cond = jnp.flip(controlnet_cond, axis=1)

        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return FlaxControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
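A minimal sketch of how the Flax model above would be initialized and called through `init_weights` and Flax's `apply`. It is not part of the deleted file; it assumes the stock `diffusers` and `flax` packages rather than this vendored copy, and the shapes, PRNG seed, and config overrides are illustrative only.

    # Hypothetical sketch: initialize FlaxControlNetModel parameters and run one call.
    import jax
    import jax.numpy as jnp
    from diffusers import FlaxControlNetModel

    model = FlaxControlNetModel(sample_size=32, cross_attention_dim=768)  # assumed overrides
    params = model.init_weights(rng=jax.random.PRNGKey(0))

    sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)              # latents, NCHW
    timesteps = jnp.array([10], dtype=jnp.int32)
    encoder_hidden_states = jnp.zeros((1, 77, 768), dtype=jnp.float32)
    controlnet_cond = jnp.zeros((1, 3, 256, 256), dtype=jnp.float32)   # 8x the latent size

    out = model.apply(
        {"params": params},
        sample,
        timesteps,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale=1.0,
    )
    print(out.mid_block_res_sample.shape)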
4DoF/diffusers/models/cross_attention.py
DELETED
@@ -1,94 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import deprecate
from .attention_processor import (  # noqa: F401
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
    LoRALinearLayer,
    LoRAXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    SlicedAttnProcessor,
    XFormersAttnProcessor,
)
from .attention_processor import AttnProcessor as AttnProcessorRename  # noqa: F401


deprecate(
    "cross_attention",
    "0.20.0",
    "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
    standard_warn=False,
)


AttnProcessor = AttentionProcessor


class CrossAttention(Attention):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class CrossAttnProcessor(AttnProcessorRename):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class LoRACrossAttnProcessor(LoRAAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class XFormersCrossAttnProcessor(XFormersAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnProcessor(SlicedAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)
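The module above is purely a backwards-compatibility shim: each class re-exports an `attention_processor` class under its old `Cross*` name and warns callers through diffusers' internal `deprecate` helper. A minimal, generic sketch of that alias-with-warning pattern, using only the standard library (the class names here are illustrative, not diffusers API):

    # Hypothetical sketch of the deprecation-alias pattern, independent of diffusers.
    import warnings


    class NewProcessor:
        def __call__(self, *args, **kwargs):
            return "processed"


    class OldProcessor(NewProcessor):
        """Backwards-compatible alias that tells callers to migrate to `NewProcessor`."""

        def __init__(self, *args, **kwargs):
            warnings.warn(
                "OldProcessor is deprecated; use NewProcessor instead.",
                FutureWarning,
                stacklevel=2,
            )
            super().__init__(*args, **kwargs)


    OldProcessor()  # emits the FutureWarning but still behaves like NewProcessor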
4DoF/diffusers/models/dual_transformer_2d.py
DELETED
@@ -1,151 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from torch import nn

from .transformer_2d import Transformer2DModel, Transformer2DModelOutput


class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`torch.FloatTensor`, *optional*):
                Optional attention mask to be applied in Attention
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            transformer_index = self.transformer_index_for_condition[i]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/models/embeddings.py
DELETED
@@ -1,546 +0,0 @@
|
|
1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
import math
|
15 |
-
from typing import Optional
|
16 |
-
|
17 |
-
import numpy as np
|
18 |
-
import torch
|
19 |
-
from torch import nn
|
20 |
-
|
21 |
-
from .activations import get_activation
|
22 |
-
|
23 |
-
|
24 |
-
def get_timestep_embedding(
|
25 |
-
timesteps: torch.Tensor,
|
26 |
-
embedding_dim: int,
|
27 |
-
flip_sin_to_cos: bool = False,
|
28 |
-
downscale_freq_shift: float = 1,
|
29 |
-
scale: float = 1,
|
30 |
-
max_period: int = 10000,
|
31 |
-
):
|
32 |
-
"""
|
33 |
-
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
34 |
-
|
35 |
-
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
36 |
-
These may be fractional.
|
37 |
-
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
38 |
-
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
39 |
-
"""
|
40 |
-
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
41 |
-
|
42 |
-
half_dim = embedding_dim // 2
|
43 |
-
exponent = -math.log(max_period) * torch.arange(
|
44 |
-
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
45 |
-
)
|
46 |
-
exponent = exponent / (half_dim - downscale_freq_shift)
|
47 |
-
|
48 |
-
emb = torch.exp(exponent)
|
49 |
-
emb = timesteps[:, None].float() * emb[None, :]
|
50 |
-
|
51 |
-
# scale embeddings
|
52 |
-
emb = scale * emb
|
53 |
-
|
54 |
-
# concat sine and cosine embeddings
|
55 |
-
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
56 |
-
|
57 |
-
# flip sine and cosine embeddings
|
58 |
-
if flip_sin_to_cos:
|
59 |
-
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
60 |
-
|
61 |
-
# zero pad
|
62 |
-
if embedding_dim % 2 == 1:
|
63 |
-
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
64 |
-
return emb
|
65 |
-
|
66 |
-
|
67 |
-
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
68 |
-
"""
|
69 |
-
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
70 |
-
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
71 |
-
"""
|
72 |
-
grid_h = np.arange(grid_size, dtype=np.float32)
|
73 |
-
grid_w = np.arange(grid_size, dtype=np.float32)
|
74 |
-
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
75 |
-
grid = np.stack(grid, axis=0)
|
76 |
-
|
77 |
-
grid = grid.reshape([2, 1, grid_size, grid_size])
|
78 |
-
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
79 |
-
if cls_token and extra_tokens > 0:
|
80 |
-
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
81 |
-
return pos_embed
|
82 |
-
|
83 |
-
|
84 |
-
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
85 |
-
if embed_dim % 2 != 0:
|
86 |
-
raise ValueError("embed_dim must be divisible by 2")
|
87 |
-
|
88 |
-
# use half of dimensions to encode grid_h
|
89 |
-
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
90 |
-
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
91 |
-
|
92 |
-
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
93 |
-
return emb
|
94 |
-
|
95 |
-
|
96 |
-
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
97 |
-
"""
|
98 |
-
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
99 |
-
"""
|
100 |
-
if embed_dim % 2 != 0:
|
101 |
-
raise ValueError("embed_dim must be divisible by 2")
|
102 |
-
|
103 |
-
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
104 |
-
omega /= embed_dim / 2.0
|
105 |
-
omega = 1.0 / 10000**omega # (D/2,)
|
106 |
-
|
107 |
-
pos = pos.reshape(-1) # (M,)
|
108 |
-
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
109 |
-
|
110 |
-
emb_sin = np.sin(out) # (M, D/2)
|
111 |
-
emb_cos = np.cos(out) # (M, D/2)
|
112 |
-
|
113 |
-
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
114 |
-
return emb
|
115 |
-
|
116 |
-
|
117 |
-
class PatchEmbed(nn.Module):
|
118 |
-
"""2D Image to Patch Embedding"""
|
119 |
-
|
120 |
-
def __init__(
|
121 |
-
self,
|
122 |
-
height=224,
|
123 |
-
width=224,
|
124 |
-
patch_size=16,
|
125 |
-
in_channels=3,
|
126 |
-
embed_dim=768,
|
127 |
-
layer_norm=False,
|
128 |
-
flatten=True,
|
129 |
-
bias=True,
|
130 |
-
):
|
131 |
-
super().__init__()
|
132 |
-
|
133 |
-
num_patches = (height // patch_size) * (width // patch_size)
|
134 |
-
self.flatten = flatten
|
135 |
-
self.layer_norm = layer_norm
|
136 |
-
|
137 |
-
self.proj = nn.Conv2d(
|
138 |
-
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
139 |
-
)
|
140 |
-
if layer_norm:
|
141 |
-
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
142 |
-
else:
|
143 |
-
self.norm = None
|
144 |
-
|
145 |
-
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
|
146 |
-
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
147 |
-
|
148 |
-
def forward(self, latent):
|
149 |
-
latent = self.proj(latent)
|
150 |
-
if self.flatten:
|
151 |
-
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
152 |
-
if self.layer_norm:
|
153 |
-
latent = self.norm(latent)
|
154 |
-
return latent + self.pos_embed
|
155 |
-
|
156 |
-
|
157 |
-
class TimestepEmbedding(nn.Module):
|
158 |
-
def __init__(
|
159 |
-
self,
|
160 |
-
in_channels: int,
|
161 |
-
time_embed_dim: int,
|
162 |
-
act_fn: str = "silu",
|
163 |
-
out_dim: int = None,
|
164 |
-
post_act_fn: Optional[str] = None,
|
165 |
-
cond_proj_dim=None,
|
166 |
-
):
|
167 |
-
super().__init__()
|
168 |
-
|
169 |
-
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
170 |
-
|
171 |
-
if cond_proj_dim is not None:
|
172 |
-
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
173 |
-
else:
|
174 |
-
self.cond_proj = None
|
175 |
-
|
176 |
-
self.act = get_activation(act_fn)
|
177 |
-
|
178 |
-
if out_dim is not None:
|
179 |
-
time_embed_dim_out = out_dim
|
180 |
-
else:
|
181 |
-
time_embed_dim_out = time_embed_dim
|
182 |
-
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
183 |
-
|
184 |
-
if post_act_fn is None:
|
185 |
-
self.post_act = None
|
186 |
-
else:
|
187 |
-
self.post_act = get_activation(post_act_fn)
|
188 |
-
|
189 |
-
def forward(self, sample, condition=None):
|
190 |
-
if condition is not None:
|
191 |
-
sample = sample + self.cond_proj(condition)
|
192 |
-
sample = self.linear_1(sample)
|
193 |
-
|
194 |
-
if self.act is not None:
|
195 |
-
sample = self.act(sample)
|
196 |
-
|
197 |
-
sample = self.linear_2(sample)
|
198 |
-
|
199 |
-
if self.post_act is not None:
|
200 |
-
sample = self.post_act(sample)
|
201 |
-
return sample
|
202 |
-
|
203 |
-
|
204 |
-
class Timesteps(nn.Module):
|
205 |
-
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
206 |
-
super().__init__()
|
207 |
-
self.num_channels = num_channels
|
208 |
-
self.flip_sin_to_cos = flip_sin_to_cos
|
209 |
-
self.downscale_freq_shift = downscale_freq_shift
|
210 |
-
|
211 |
-
def forward(self, timesteps):
|
212 |
-
t_emb = get_timestep_embedding(
|
213 |
-
timesteps,
|
214 |
-
self.num_channels,
|
215 |
-
flip_sin_to_cos=self.flip_sin_to_cos,
|
216 |
-
downscale_freq_shift=self.downscale_freq_shift,
|
217 |
-
)
|
218 |
-
return t_emb
|
219 |
-
|
220 |
-
|
221 |
-
class GaussianFourierProjection(nn.Module):
|
222 |
-
"""Gaussian Fourier embeddings for noise levels."""
|
223 |
-
|
224 |
-
def __init__(
|
225 |
-
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
226 |
-
):
|
227 |
-
super().__init__()
|
228 |
-
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
229 |
-
self.log = log
|
230 |
-
self.flip_sin_to_cos = flip_sin_to_cos
|
231 |
-
|
232 |
-
if set_W_to_weight:
|
233 |
-
# to delete later
|
234 |
-
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
235 |
-
|
236 |
-
self.weight = self.W
|
237 |
-
|
238 |
-
def forward(self, x):
|
239 |
-
if self.log:
|
240 |
-
x = torch.log(x)
|
241 |
-
|
242 |
-
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
243 |
-
|
244 |
-
if self.flip_sin_to_cos:
|
245 |
-
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
246 |
-
else:
|
247 |
-
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
248 |
-
return out
|
249 |
-
|
250 |
-
|
251 |
-
class ImagePositionalEmbeddings(nn.Module):
|
252 |
-
"""
|
253 |
-
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
254 |
-
height and width of the latent space.
|
255 |
-
|
256 |
-
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
257 |
-
|
258 |
-
For VQ-diffusion:
|
259 |
-
|
260 |
-
Output vector embeddings are used as input for the transformer.
|
261 |
-
|
262 |
-
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
263 |
-
|
264 |
-
Args:
|
265 |
-
num_embed (`int`):
|
266 |
-
Number of embeddings for the latent pixels embeddings.
|
267 |
-
height (`int`):
|
268 |
-
Height of the latent image i.e. the number of height embeddings.
|
269 |
-
width (`int`):
|
270 |
-
Width of the latent image i.e. the number of width embeddings.
|
271 |
-
embed_dim (`int`):
|
272 |
-
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
273 |
-
"""
|
274 |
-
|
275 |
-
def __init__(
|
276 |
-
self,
|
277 |
-
num_embed: int,
|
278 |
-
height: int,
|
279 |
-
width: int,
|
280 |
-
embed_dim: int,
|
281 |
-
):
|
282 |
-
super().__init__()
|
283 |
-
|
284 |
-
self.height = height
|
285 |
-
self.width = width
|
286 |
-
self.num_embed = num_embed
|
287 |
-
self.embed_dim = embed_dim
|
288 |
-
|
289 |
-
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
290 |
-
self.height_emb = nn.Embedding(self.height, embed_dim)
|
291 |
-
self.width_emb = nn.Embedding(self.width, embed_dim)
|
292 |
-
|
293 |
-
def forward(self, index):
|
294 |
-
emb = self.emb(index)
|
295 |
-
|
296 |
-
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
297 |
-
|
298 |
-
# 1 x H x D -> 1 x H x 1 x D
|
299 |
-
height_emb = height_emb.unsqueeze(2)
|
300 |
-
|
301 |
-
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
302 |
-
|
303 |
-
# 1 x W x D -> 1 x 1 x W x D
|
304 |
-
width_emb = width_emb.unsqueeze(1)
|
305 |
-
|
306 |
-
pos_emb = height_emb + width_emb
|
307 |
-
|
308 |
-
# 1 x H x W x D -> 1 x L xD
|
309 |
-
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
310 |
-
|
311 |
-
emb = emb + pos_emb[:, : emb.shape[1], :]
|
312 |
-
|
313 |
-
return emb
|
314 |
-
|
315 |
-
|
316 |
-
class LabelEmbedding(nn.Module):
|
317 |
-
"""
|
318 |
-
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
319 |
-
|
320 |
-
Args:
|
321 |
-
num_classes (`int`): The number of classes.
|
322 |
-
hidden_size (`int`): The size of the vector embeddings.
|
323 |
-
dropout_prob (`float`): The probability of dropping a label.
|
324 |
-
"""
|
325 |
-
|
326 |
-
def __init__(self, num_classes, hidden_size, dropout_prob):
|
327 |
-
super().__init__()
|
328 |
-
use_cfg_embedding = dropout_prob > 0
|
329 |
-
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
330 |
-
self.num_classes = num_classes
|
331 |
-
self.dropout_prob = dropout_prob
|
332 |
-
|
333 |
-
def token_drop(self, labels, force_drop_ids=None):
|
334 |
-
"""
|
335 |
-
Drops labels to enable classifier-free guidance.
|
336 |
-
"""
|
337 |
-
if force_drop_ids is None:
|
338 |
-
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
339 |
-
else:
|
340 |
-
drop_ids = torch.tensor(force_drop_ids == 1)
|
341 |
-
labels = torch.where(drop_ids, self.num_classes, labels)
|
342 |
-
return labels
|
343 |
-
|
344 |
-
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
|
345 |
-
use_dropout = self.dropout_prob > 0
|
346 |
-
if (self.training and use_dropout) or (force_drop_ids is not None):
|
347 |
-
labels = self.token_drop(labels, force_drop_ids)
|
348 |
-
embeddings = self.embedding_table(labels)
|
349 |
-
return embeddings
|
350 |
-
|
351 |
-
|
352 |
-
class TextImageProjection(nn.Module):
|
353 |
-
def __init__(
|
354 |
-
self,
|
355 |
-
text_embed_dim: int = 1024,
|
356 |
-
image_embed_dim: int = 768,
|
357 |
-
cross_attention_dim: int = 768,
|
358 |
-
num_image_text_embeds: int = 10,
|
359 |
-
):
|
360 |
-
super().__init__()
|
361 |
-
|
362 |
-
self.num_image_text_embeds = num_image_text_embeds
|
363 |
-
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
364 |
-
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
|
365 |
-
|
366 |
-
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
367 |
-
batch_size = text_embeds.shape[0]
|
368 |
-
|
369 |
-
# image
|
370 |
-
image_text_embeds = self.image_embeds(image_embeds)
|
371 |
-
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
372 |
-
|
373 |
-
# text
|
374 |
-
text_embeds = self.text_proj(text_embeds)
|
375 |
-
|
376 |
-
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
377 |
-
|
378 |
-
|
379 |
-
class ImageProjection(nn.Module):
|
380 |
-
def __init__(
|
381 |
-
self,
|
382 |
-
image_embed_dim: int = 768,
|
383 |
-
cross_attention_dim: int = 768,
|
384 |
-
num_image_text_embeds: int = 32,
|
385 |
-
):
|
386 |
-
super().__init__()
|
387 |
-
|
388 |
-
self.num_image_text_embeds = num_image_text_embeds
|
389 |
-
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
390 |
-
self.norm = nn.LayerNorm(cross_attention_dim)
|
391 |
-
|
392 |
-
def forward(self, image_embeds: torch.FloatTensor):
|
393 |
-
batch_size = image_embeds.shape[0]
|
394 |
-
|
395 |
-
# image
|
396 |
-
image_embeds = self.image_embeds(image_embeds)
|
397 |
-
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
398 |
-
image_embeds = self.norm(image_embeds)
|
399 |
-
return image_embeds
|
400 |
-
|
401 |
-
|
402 |
-
class CombinedTimestepLabelEmbeddings(nn.Module):
|
403 |
-
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
404 |
-
super().__init__()
|
405 |
-
|
406 |
-
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
407 |
-
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
408 |
-
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
409 |
-
|
410 |
-
def forward(self, timestep, class_labels, hidden_dtype=None):
|
411 |
-
timesteps_proj = self.time_proj(timestep)
|
412 |
-
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
413 |
-
|
414 |
-
class_labels = self.class_embedder(class_labels) # (N, D)
|
415 |
-
|
416 |
-
conditioning = timesteps_emb + class_labels # (N, D)
|
417 |
-
|
418 |
-
return conditioning
|
419 |
-
|
420 |
-
|
421 |
-
class TextTimeEmbedding(nn.Module):
|
422 |
-
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
423 |
-
super().__init__()
|
424 |
-
self.norm1 = nn.LayerNorm(encoder_dim)
|
425 |
-
self.pool = AttentionPooling(num_heads, encoder_dim)
|
426 |
-
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
427 |
-
self.norm2 = nn.LayerNorm(time_embed_dim)
|
428 |
-
|
429 |
-
def forward(self, hidden_states):
|
430 |
-
hidden_states = self.norm1(hidden_states)
|
431 |
-
hidden_states = self.pool(hidden_states)
|
432 |
-
hidden_states = self.proj(hidden_states)
|
433 |
-
hidden_states = self.norm2(hidden_states)
|
434 |
-
return hidden_states
|
435 |
-
|
436 |
-
|
437 |
-
class TextImageTimeEmbedding(nn.Module):
|
438 |
-
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
439 |
-
super().__init__()
|
440 |
-
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
|
441 |
-
self.text_norm = nn.LayerNorm(time_embed_dim)
|
442 |
-
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
443 |
-
|
444 |
-
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
445 |
-
# text
|
446 |
-
time_text_embeds = self.text_proj(text_embeds)
|
447 |
-
time_text_embeds = self.text_norm(time_text_embeds)
|
448 |
-
|
449 |
-
# image
|
450 |
-
time_image_embeds = self.image_proj(image_embeds)
|
451 |
-
|
452 |
-
return time_image_embeds + time_text_embeds
|
453 |
-
|
454 |
-
|
455 |
-
class ImageTimeEmbedding(nn.Module):
|
456 |
-
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
457 |
-
super().__init__()
|
458 |
-
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
459 |
-
self.image_norm = nn.LayerNorm(time_embed_dim)
|
460 |
-
|
461 |
-
def forward(self, image_embeds: torch.FloatTensor):
|
462 |
-
# image
|
463 |
-
time_image_embeds = self.image_proj(image_embeds)
|
464 |
-
time_image_embeds = self.image_norm(time_image_embeds)
|
465 |
-
return time_image_embeds
|
466 |
-
|
467 |
-
|
468 |
-
class ImageHintTimeEmbedding(nn.Module):
|
469 |
-
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
470 |
-
super().__init__()
|
471 |
-
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
472 |
-
self.image_norm = nn.LayerNorm(time_embed_dim)
|
473 |
-
self.input_hint_block = nn.Sequential(
|
474 |
-
nn.Conv2d(3, 16, 3, padding=1),
|
475 |
-
nn.SiLU(),
|
476 |
-
nn.Conv2d(16, 16, 3, padding=1),
|
477 |
-
nn.SiLU(),
|
478 |
-
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
479 |
-
nn.SiLU(),
|
480 |
-
nn.Conv2d(32, 32, 3, padding=1),
|
481 |
-
nn.SiLU(),
|
482 |
-
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
483 |
-
nn.SiLU(),
|
484 |
-
nn.Conv2d(96, 96, 3, padding=1),
|
485 |
-
nn.SiLU(),
|
486 |
-
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
487 |
-
nn.SiLU(),
|
488 |
-
nn.Conv2d(256, 4, 3, padding=1),
|
489 |
-
)
|
490 |
-
|
491 |
-
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
|
492 |
-
# image
|
493 |
-
time_image_embeds = self.image_proj(image_embeds)
|
494 |
-
time_image_embeds = self.image_norm(time_image_embeds)
|
495 |
-
hint = self.input_hint_block(hint)
|
496 |
-
return time_image_embeds, hint
|
497 |
-
|
498 |
-
|
499 |
-
class AttentionPooling(nn.Module):
|
500 |
-
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
501 |
-
|
502 |
-
def __init__(self, num_heads, embed_dim, dtype=None):
|
503 |
-
super().__init__()
|
504 |
-
self.dtype = dtype
|
505 |
-
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
506 |
-
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
507 |
-
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
508 |
-
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
509 |
-
self.num_heads = num_heads
|
510 |
-
self.dim_per_head = embed_dim // self.num_heads
|
511 |
-
|
512 |
-
def forward(self, x):
|
513 |
-
bs, length, width = x.size()
|
514 |
-
|
515 |
-
def shape(x):
|
516 |
-
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
517 |
-
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
518 |
-
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
519 |
-
x = x.transpose(1, 2)
|
520 |
-
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
521 |
-
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
522 |
-
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
523 |
-
x = x.transpose(1, 2)
|
524 |
-
return x
|
525 |
-
|
526 |
-
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
527 |
-
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
528 |
-
|
529 |
-
# (bs*n_heads, class_token_length, dim_per_head)
|
530 |
-
q = shape(self.q_proj(class_token))
|
531 |
-
# (bs*n_heads, length+class_token_length, dim_per_head)
|
532 |
-
k = shape(self.k_proj(x))
|
533 |
-
v = shape(self.v_proj(x))
|
534 |
-
|
535 |
-
# (bs*n_heads, class_token_length, length+class_token_length):
|
536 |
-
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
537 |
-
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
538 |
-
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
539 |
-
|
540 |
-
# (bs*n_heads, dim_per_head, class_token_length)
|
541 |
-
a = torch.einsum("bts,bcs->bct", weight, v)
|
542 |
-
|
543 |
-
# (bs, length+1, width)
|
544 |
-
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
545 |
-
|
546 |
-
return a[:, 0, :] # cls_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/models/embeddings_flax.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
import math
|
15 |
-
|
16 |
-
import flax.linen as nn
|
17 |
-
import jax.numpy as jnp
|
18 |
-
|
19 |
-
|
20 |
-
def get_sinusoidal_embeddings(
|
21 |
-
timesteps: jnp.ndarray,
|
22 |
-
embedding_dim: int,
|
23 |
-
freq_shift: float = 1,
|
24 |
-
min_timescale: float = 1,
|
25 |
-
max_timescale: float = 1.0e4,
|
26 |
-
flip_sin_to_cos: bool = False,
|
27 |
-
scale: float = 1.0,
|
28 |
-
) -> jnp.ndarray:
|
29 |
-
"""Returns the positional encoding (same as Tensor2Tensor).
|
30 |
-
|
31 |
-
Args:
|
32 |
-
timesteps: a 1-D Tensor of N indices, one per batch element.
|
33 |
-
These may be fractional.
|
34 |
-
embedding_dim: The number of output channels.
|
35 |
-
min_timescale: The smallest time unit (should probably be 0.0).
|
36 |
-
max_timescale: The largest time unit.
|
37 |
-
Returns:
|
38 |
-
a Tensor of timing signals [N, num_channels]
|
39 |
-
"""
|
40 |
-
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
|
41 |
-
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
|
42 |
-
num_timescales = float(embedding_dim // 2)
|
43 |
-
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
|
44 |
-
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
|
45 |
-
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
|
46 |
-
|
47 |
-
# scale embeddings
|
48 |
-
scaled_time = scale * emb
|
49 |
-
|
50 |
-
if flip_sin_to_cos:
|
51 |
-
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
|
52 |
-
else:
|
53 |
-
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
|
54 |
-
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
|
55 |
-
return signal
|
56 |
-
|
57 |
-
|
58 |
-
class FlaxTimestepEmbedding(nn.Module):
|
59 |
-
r"""
|
60 |
-
Time step Embedding Module. Learns embeddings for input time steps.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
time_embed_dim (`int`, *optional*, defaults to `32`):
|
64 |
-
Time step embedding dimension
|
65 |
-
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
66 |
-
Parameters `dtype`
|
67 |
-
"""
|
68 |
-
time_embed_dim: int = 32
|
69 |
-
dtype: jnp.dtype = jnp.float32
|
70 |
-
|
71 |
-
@nn.compact
|
72 |
-
def __call__(self, temb):
|
73 |
-
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
|
74 |
-
temb = nn.silu(temb)
|
75 |
-
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
|
76 |
-
return temb
|
77 |
-
|
78 |
-
|
79 |
-
class FlaxTimesteps(nn.Module):
|
80 |
-
r"""
|
81 |
-
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
|
82 |
-
|
83 |
-
Args:
|
84 |
-
dim (`int`, *optional*, defaults to `32`):
|
85 |
-
Time step embedding dimension
|
86 |
-
"""
|
87 |
-
dim: int = 32
|
88 |
-
flip_sin_to_cos: bool = False
|
89 |
-
freq_shift: float = 1
|
90 |
-
|
91 |
-
@nn.compact
|
92 |
-
def __call__(self, timesteps):
|
93 |
-
return get_sinusoidal_embeddings(
|
94 |
-
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
|
95 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/models/modeling_flax_pytorch_utils.py
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
""" PyTorch - Flax general utilities."""
|
16 |
-
import re
|
17 |
-
|
18 |
-
import jax.numpy as jnp
|
19 |
-
from flax.traverse_util import flatten_dict, unflatten_dict
|
20 |
-
from jax.random import PRNGKey
|
21 |
-
|
22 |
-
from ..utils import logging
|
23 |
-
|
24 |
-
|
25 |
-
logger = logging.get_logger(__name__)
|
26 |
-
|
27 |
-
|
28 |
-
def rename_key(key):
|
29 |
-
regex = r"\w+[.]\d+"
|
30 |
-
pats = re.findall(regex, key)
|
31 |
-
for pat in pats:
|
32 |
-
key = key.replace(pat, "_".join(pat.split(".")))
|
33 |
-
return key
|
34 |
-
|
35 |
-
|
36 |
-
#####################
|
37 |
-
# PyTorch => Flax #
|
38 |
-
#####################
|
39 |
-
|
40 |
-
|
41 |
-
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
|
42 |
-
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
43 |
-
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
|
44 |
-
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
|
45 |
-
|
46 |
-
# conv norm or layer norm
|
47 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
48 |
-
if (
|
49 |
-
any("norm" in str_ for str_ in pt_tuple_key)
|
50 |
-
and (pt_tuple_key[-1] == "bias")
|
51 |
-
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
|
52 |
-
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
|
53 |
-
):
|
54 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
55 |
-
return renamed_pt_tuple_key, pt_tensor
|
56 |
-
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
57 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
58 |
-
return renamed_pt_tuple_key, pt_tensor
|
59 |
-
|
60 |
-
# embedding
|
61 |
-
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
62 |
-
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
63 |
-
return renamed_pt_tuple_key, pt_tensor
|
64 |
-
|
65 |
-
# conv layer
|
66 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
67 |
-
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
|
68 |
-
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
|
69 |
-
return renamed_pt_tuple_key, pt_tensor
|
70 |
-
|
71 |
-
# linear layer
|
72 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
73 |
-
if pt_tuple_key[-1] == "weight":
|
74 |
-
pt_tensor = pt_tensor.T
|
75 |
-
return renamed_pt_tuple_key, pt_tensor
|
76 |
-
|
77 |
-
# old PyTorch layer norm weight
|
78 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
79 |
-
if pt_tuple_key[-1] == "gamma":
|
80 |
-
return renamed_pt_tuple_key, pt_tensor
|
81 |
-
|
82 |
-
# old PyTorch layer norm bias
|
83 |
-
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
84 |
-
if pt_tuple_key[-1] == "beta":
|
85 |
-
return renamed_pt_tuple_key, pt_tensor
|
86 |
-
|
87 |
-
return pt_tuple_key, pt_tensor
|
88 |
-
|
89 |
-
|
90 |
-
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
|
91 |
-
# Step 1: Convert pytorch tensor to numpy
|
92 |
-
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
93 |
-
|
94 |
-
# Step 2: Since the model is stateless, get random Flax params
|
95 |
-
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
|
96 |
-
|
97 |
-
random_flax_state_dict = flatten_dict(random_flax_params)
|
98 |
-
flax_state_dict = {}
|
99 |
-
|
100 |
-
# Need to change some parameters name to match Flax names
|
101 |
-
for pt_key, pt_tensor in pt_state_dict.items():
|
102 |
-
renamed_pt_key = rename_key(pt_key)
|
103 |
-
pt_tuple_key = tuple(renamed_pt_key.split("."))
|
104 |
-
|
105 |
-
# Correctly rename weight parameters
|
106 |
-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
|
107 |
-
|
108 |
-
if flax_key in random_flax_state_dict:
|
109 |
-
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
|
110 |
-
raise ValueError(
|
111 |
-
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
112 |
-
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
113 |
-
)
|
114 |
-
|
115 |
-
# also add unexpected weight so that warning is thrown
|
116 |
-
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
117 |
-
|
118 |
-
return unflatten_dict(flax_state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/models/modeling_flax_utils.py
DELETED
@@ -1,534 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
import os
|
17 |
-
from pickle import UnpicklingError
|
18 |
-
from typing import Any, Dict, Union
|
19 |
-
|
20 |
-
import jax
|
21 |
-
import jax.numpy as jnp
|
22 |
-
import msgpack.exceptions
|
23 |
-
from flax.core.frozen_dict import FrozenDict, unfreeze
|
24 |
-
from flax.serialization import from_bytes, to_bytes
|
25 |
-
from flax.traverse_util import flatten_dict, unflatten_dict
|
26 |
-
from huggingface_hub import hf_hub_download
|
27 |
-
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
28 |
-
from requests import HTTPError
|
29 |
-
|
30 |
-
from .. import __version__, is_torch_available
|
31 |
-
from ..utils import (
|
32 |
-
CONFIG_NAME,
|
33 |
-
DIFFUSERS_CACHE,
|
34 |
-
FLAX_WEIGHTS_NAME,
|
35 |
-
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
36 |
-
WEIGHTS_NAME,
|
37 |
-
logging,
|
38 |
-
)
|
39 |
-
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
40 |
-
|
41 |
-
|
42 |
-
logger = logging.get_logger(__name__)
|
43 |
-
|
44 |
-
|
45 |
-
class FlaxModelMixin:
|
46 |
-
r"""
|
47 |
-
Base class for all Flax models.
|
48 |
-
|
49 |
-
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
50 |
-
saving models.
|
51 |
-
|
52 |
-
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
|
53 |
-
"""
|
54 |
-
config_name = CONFIG_NAME
|
55 |
-
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
56 |
-
_flax_internal_args = ["name", "parent", "dtype"]
|
57 |
-
|
58 |
-
@classmethod
|
59 |
-
def _from_config(cls, config, **kwargs):
|
60 |
-
"""
|
61 |
-
All context managers that the model should be initialized under go here.
|
62 |
-
"""
|
63 |
-
return cls(config, **kwargs)
|
64 |
-
|
65 |
-
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
66 |
-
"""
|
67 |
-
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
68 |
-
"""
|
69 |
-
|
70 |
-
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
71 |
-
def conditional_cast(param):
|
72 |
-
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
73 |
-
param = param.astype(dtype)
|
74 |
-
return param
|
75 |
-
|
76 |
-
if mask is None:
|
77 |
-
return jax.tree_map(conditional_cast, params)
|
78 |
-
|
79 |
-
flat_params = flatten_dict(params)
|
80 |
-
flat_mask, _ = jax.tree_flatten(mask)
|
81 |
-
|
82 |
-
for masked, key in zip(flat_mask, flat_params.keys()):
|
83 |
-
if masked:
|
84 |
-
param = flat_params[key]
|
85 |
-
flat_params[key] = conditional_cast(param)
|
86 |
-
|
87 |
-
return unflatten_dict(flat_params)
|
88 |
-
|
89 |
-
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
90 |
-
r"""
|
91 |
-
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
|
92 |
-
the `params` in place.
|
93 |
-
|
94 |
-
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
|
95 |
-
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
|
96 |
-
|
97 |
-
Arguments:
|
98 |
-
params (`Union[Dict, FrozenDict]`):
|
99 |
-
A `PyTree` of model parameters.
|
100 |
-
mask (`Union[Dict, FrozenDict]`):
|
101 |
-
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
102 |
-
for params you want to cast, and `False` for those you want to skip.
|
103 |
-
|
104 |
-
Examples:
|
105 |
-
|
106 |
-
```python
|
107 |
-
>>> from diffusers import FlaxUNet2DConditionModel
|
108 |
-
|
109 |
-
>>> # load model
|
110 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
111 |
-
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
|
112 |
-
>>> params = model.to_bf16(params)
|
113 |
-
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
|
114 |
-
>>> # then pass the mask as follows
|
115 |
-
>>> from flax import traverse_util
|
116 |
-
|
117 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
118 |
-
>>> flat_params = traverse_util.flatten_dict(params)
|
119 |
-
>>> mask = {
|
120 |
-
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
|
121 |
-
... for path in flat_params
|
122 |
-
... }
|
123 |
-
>>> mask = traverse_util.unflatten_dict(mask)
|
124 |
-
>>> params = model.to_bf16(params, mask)
|
125 |
-
```"""
|
126 |
-
return self._cast_floating_to(params, jnp.bfloat16, mask)
|
127 |
-
|
128 |
-
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
129 |
-
r"""
|
130 |
-
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
|
131 |
-
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
|
132 |
-
|
133 |
-
Arguments:
|
134 |
-
params (`Union[Dict, FrozenDict]`):
|
135 |
-
A `PyTree` of model parameters.
|
136 |
-
mask (`Union[Dict, FrozenDict]`):
|
137 |
-
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
138 |
-
for params you want to cast, and `False` for those you want to skip.
|
139 |
-
|
140 |
-
Examples:
|
141 |
-
|
142 |
-
```python
|
143 |
-
>>> from diffusers import FlaxUNet2DConditionModel
|
144 |
-
|
145 |
-
>>> # Download model and configuration from huggingface.co
|
146 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
147 |
-
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
|
148 |
-
>>> # we'll first cast to fp16 and back to fp32
|
149 |
-
>>> params = model.to_f16(params)
|
150 |
-
>>> # now cast back to fp32
|
151 |
-
>>> params = model.to_fp32(params)
|
152 |
-
```"""
|
153 |
-
return self._cast_floating_to(params, jnp.float32, mask)
|
154 |
-
|
155 |
-
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
156 |
-
r"""
|
157 |
-
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
|
158 |
-
`params` in place.
|
159 |
-
|
160 |
-
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
|
161 |
-
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
|
162 |
-
|
163 |
-
Arguments:
|
164 |
-
params (`Union[Dict, FrozenDict]`):
|
165 |
-
A `PyTree` of model parameters.
|
166 |
-
mask (`Union[Dict, FrozenDict]`):
|
167 |
-
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
168 |
-
for params you want to cast, and `False` for those you want to skip.
|
169 |
-
|
170 |
-
Examples:
|
171 |
-
|
172 |
-
```python
|
173 |
-
>>> from diffusers import FlaxUNet2DConditionModel
|
174 |
-
|
175 |
-
>>> # load model
|
176 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
177 |
-
>>> # By default, the model params will be in fp32, to cast these to float16
|
178 |
-
>>> params = model.to_fp16(params)
|
179 |
-
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
180 |
-
>>> # then pass the mask as follows
|
181 |
-
>>> from flax import traverse_util
|
182 |
-
|
183 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
184 |
-
>>> flat_params = traverse_util.flatten_dict(params)
|
185 |
-
>>> mask = {
|
186 |
-
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
|
187 |
-
... for path in flat_params
|
188 |
-
... }
|
189 |
-
>>> mask = traverse_util.unflatten_dict(mask)
|
190 |
-
>>> params = model.to_fp16(params, mask)
|
191 |
-
```"""
|
192 |
-
return self._cast_floating_to(params, jnp.float16, mask)
|
193 |
-
|
194 |
-
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
|
195 |
-
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
196 |
-
|
197 |
-
@classmethod
|
198 |
-
def from_pretrained(
|
199 |
-
cls,
|
200 |
-
pretrained_model_name_or_path: Union[str, os.PathLike],
|
201 |
-
dtype: jnp.dtype = jnp.float32,
|
202 |
-
*model_args,
|
203 |
-
**kwargs,
|
204 |
-
):
|
205 |
-
r"""
|
206 |
-
Instantiate a pretrained Flax model from a pretrained model configuration.
|
207 |
-
|
208 |
-
Parameters:
|
209 |
-
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
210 |
-
Can be either:
|
211 |
-
|
212 |
-
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
|
213 |
-
hosted on the Hub.
|
214 |
-
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
215 |
-
using [`~FlaxModelMixin.save_pretrained`].
|
216 |
-
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
217 |
-
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
218 |
-
`jax.numpy.bfloat16` (on TPUs).
|
219 |
-
|
220 |
-
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
221 |
-
specified, all the computation will be performed with the given `dtype`.
|
222 |
-
|
223 |
-
<Tip>
|
224 |
-
|
225 |
-
This only specifies the dtype of the *computation* and does not influence the dtype of model
|
226 |
-
parameters.
|
227 |
-
|
228 |
-
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
|
229 |
-
[`~FlaxModelMixin.to_bf16`].
|
230 |
-
|
231 |
-
</Tip>
|
232 |
-
|
233 |
-
model_args (sequence of positional arguments, *optional*):
|
234 |
-
All remaining positional arguments are passed to the underlying model's `__init__` method.
|
235 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
236 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
237 |
-
is not used.
|
238 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
239 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
240 |
-
cached versions if they exist.
|
241 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
242 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
243 |
-
incompletely downloaded files are deleted.
|
244 |
-
proxies (`Dict[str, str]`, *optional*):
|
245 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
246 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
247 |
-
local_files_only(`bool`, *optional*, defaults to `False`):
|
248 |
-
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
249 |
-
won't be downloaded from the Hub.
|
250 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
251 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
252 |
-
allowed by Git.
|
253 |
-
from_pt (`bool`, *optional*, defaults to `False`):
|
254 |
-
Load the model weights from a PyTorch checkpoint save file.
|
255 |
-
kwargs (remaining dictionary of keyword arguments, *optional*):
|
256 |
-
Can be used to update the configuration object (after it is loaded) and initiate the model (for
|
257 |
-
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
258 |
-
automatically loaded:
|
259 |
-
|
260 |
-
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
|
261 |
-
model's `__init__` method (we assume all relevant updates to the configuration have already been
|
262 |
-
done).
|
263 |
-
- If a configuration is not provided, `kwargs` are first passed to the configuration class
|
264 |
-
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
|
265 |
-
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
|
266 |
-
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
|
267 |
-
model's `__init__` function.
|
268 |
-
|
269 |
-
Examples:
|
270 |
-
|
271 |
-
```python
|
272 |
-
>>> from diffusers import FlaxUNet2DConditionModel
|
273 |
-
|
274 |
-
>>> # Download model and configuration from huggingface.co and cache.
|
275 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
276 |
-
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
|
277 |
-
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
|
278 |
-
```
|
279 |
-
|
280 |
-
If you get the error message below, you need to finetune the weights for your downstream task:
|
281 |
-
|
282 |
-
```bash
|
283 |
-
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
284 |
-
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
285 |
-
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
286 |
-
```
|
287 |
-
"""
|
288 |
-
config = kwargs.pop("config", None)
|
289 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
290 |
-
force_download = kwargs.pop("force_download", False)
|
291 |
-
from_pt = kwargs.pop("from_pt", False)
|
292 |
-
resume_download = kwargs.pop("resume_download", False)
|
293 |
-
proxies = kwargs.pop("proxies", None)
|
294 |
-
local_files_only = kwargs.pop("local_files_only", False)
|
295 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
296 |
-
revision = kwargs.pop("revision", None)
|
297 |
-
subfolder = kwargs.pop("subfolder", None)
|
298 |
-
|
299 |
-
user_agent = {
|
300 |
-
"diffusers": __version__,
|
301 |
-
"file_type": "model",
|
302 |
-
"framework": "flax",
|
303 |
-
}
|
304 |
-
|
305 |
-
# Load config if we don't provide a configuration
|
306 |
-
config_path = config if config is not None else pretrained_model_name_or_path
|
307 |
-
model, model_kwargs = cls.from_config(
|
308 |
-
config_path,
|
309 |
-
cache_dir=cache_dir,
|
310 |
-
return_unused_kwargs=True,
|
311 |
-
force_download=force_download,
|
312 |
-
resume_download=resume_download,
|
313 |
-
proxies=proxies,
|
314 |
-
local_files_only=local_files_only,
|
315 |
-
use_auth_token=use_auth_token,
|
316 |
-
revision=revision,
|
317 |
-
subfolder=subfolder,
|
318 |
-
# model args
|
319 |
-
dtype=dtype,
|
320 |
-
**kwargs,
|
321 |
-
)
|
322 |
-
|
323 |
-
# Load model
|
324 |
-
pretrained_path_with_subfolder = (
|
325 |
-
pretrained_model_name_or_path
|
326 |
-
if subfolder is None
|
327 |
-
else os.path.join(pretrained_model_name_or_path, subfolder)
|
328 |
-
)
|
329 |
-
if os.path.isdir(pretrained_path_with_subfolder):
|
330 |
-
if from_pt:
|
331 |
-
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
332 |
-
raise EnvironmentError(
|
333 |
-
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
|
334 |
-
)
|
335 |
-
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
|
336 |
-
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
|
337 |
-
# Load from a Flax checkpoint
|
338 |
-
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
|
339 |
-
# Check if pytorch weights exist instead
|
340 |
-
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
341 |
-
raise EnvironmentError(
|
342 |
-
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
|
343 |
-
" using `from_pt=True`."
|
344 |
-
)
|
345 |
-
else:
|
346 |
-
raise EnvironmentError(
|
347 |
-
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
348 |
-
f"{pretrained_path_with_subfolder}."
|
349 |
-
)
|
350 |
-
else:
|
351 |
-
try:
|
352 |
-
model_file = hf_hub_download(
|
353 |
-
pretrained_model_name_or_path,
|
354 |
-
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
|
355 |
-
cache_dir=cache_dir,
|
356 |
-
force_download=force_download,
|
357 |
-
proxies=proxies,
|
358 |
-
resume_download=resume_download,
|
359 |
-
local_files_only=local_files_only,
|
360 |
-
use_auth_token=use_auth_token,
|
361 |
-
user_agent=user_agent,
|
362 |
-
subfolder=subfolder,
|
363 |
-
revision=revision,
|
364 |
-
)
|
365 |
-
|
366 |
-
except RepositoryNotFoundError:
|
367 |
-
raise EnvironmentError(
|
368 |
-
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
369 |
-
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
370 |
-
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
371 |
-
"login`."
|
372 |
-
)
|
373 |
-
except RevisionNotFoundError:
|
374 |
-
raise EnvironmentError(
|
375 |
-
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
376 |
-
"this model name. Check the model page at "
|
377 |
-
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
378 |
-
)
|
379 |
-
except EntryNotFoundError:
|
380 |
-
raise EnvironmentError(
|
381 |
-
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
|
382 |
-
)
|
383 |
-
except HTTPError as err:
|
384 |
-
raise EnvironmentError(
|
385 |
-
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
386 |
-
f"{err}"
|
387 |
-
)
|
388 |
-
except ValueError:
|
389 |
-
raise EnvironmentError(
|
390 |
-
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
391 |
-
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
392 |
-
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
|
393 |
-
" internet connection or see how to run the library in offline mode at"
|
394 |
-
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
395 |
-
)
|
396 |
-
except EnvironmentError:
|
397 |
-
raise EnvironmentError(
|
398 |
-
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
399 |
-
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
400 |
-
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
401 |
-
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
402 |
-
)
|
403 |
-
|
404 |
-
if from_pt:
|
405 |
-
if is_torch_available():
|
406 |
-
from .modeling_utils import load_state_dict
|
407 |
-
else:
|
408 |
-
raise EnvironmentError(
|
409 |
-
"Can't load the model in PyTorch format because PyTorch is not installed. "
|
410 |
-
"Please, install PyTorch or use native Flax weights."
|
411 |
-
)
|
412 |
-
|
413 |
-
# Step 1: Get the pytorch file
|
414 |
-
pytorch_model_file = load_state_dict(model_file)
|
415 |
-
|
416 |
-
# Step 2: Convert the weights
|
417 |
-
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
|
418 |
-
else:
|
419 |
-
try:
|
420 |
-
with open(model_file, "rb") as state_f:
|
421 |
-
state = from_bytes(cls, state_f.read())
|
422 |
-
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
423 |
-
try:
|
424 |
-
with open(model_file) as f:
|
425 |
-
if f.read().startswith("version"):
|
426 |
-
raise OSError(
|
427 |
-
"You seem to have cloned a repository without having git-lfs installed. Please"
|
428 |
-
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
429 |
-
" folder you cloned."
|
430 |
-
)
|
431 |
-
else:
|
432 |
-
raise ValueError from e
|
433 |
-
except (UnicodeDecodeError, ValueError):
|
434 |
-
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
435 |
-
# make sure all arrays are stored as jnp.ndarray
|
436 |
-
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
437 |
-
# https://github.com/google/flax/issues/1261
|
438 |
-
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
439 |
-
|
440 |
-
# flatten dicts
|
441 |
-
state = flatten_dict(state)
|
442 |
-
|
443 |
-
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
|
444 |
-
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
445 |
-
|
446 |
-
shape_state = flatten_dict(unfreeze(params_shape_tree))
|
447 |
-
|
448 |
-
missing_keys = required_params - set(state.keys())
|
449 |
-
unexpected_keys = set(state.keys()) - required_params
|
450 |
-
|
451 |
-
if missing_keys:
|
452 |
-
logger.warning(
|
453 |
-
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
454 |
-
"Make sure to call model.init_weights to initialize the missing weights."
|
455 |
-
)
|
456 |
-
cls._missing_keys = missing_keys
|
457 |
-
|
458 |
-
for key in state.keys():
|
459 |
-
if key in shape_state and state[key].shape != shape_state[key].shape:
|
460 |
-
raise ValueError(
|
461 |
-
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
|
462 |
-
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
|
463 |
-
)
|
464 |
-
|
465 |
-
# remove unexpected keys to not be saved again
|
466 |
-
for unexpected_key in unexpected_keys:
|
467 |
-
del state[unexpected_key]
|
468 |
-
|
469 |
-
if len(unexpected_keys) > 0:
|
470 |
-
logger.warning(
|
471 |
-
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
472 |
-
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
473 |
-
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
474 |
-
" with another architecture."
|
475 |
-
)
|
476 |
-
else:
|
477 |
-
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
478 |
-
|
479 |
-
if len(missing_keys) > 0:
|
480 |
-
logger.warning(
|
481 |
-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
482 |
-
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
483 |
-
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
484 |
-
)
|
485 |
-
else:
|
486 |
-
logger.info(
|
487 |
-
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
488 |
-
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
489 |
-
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
490 |
-
" training."
|
491 |
-
)
|
492 |
-
|
493 |
-
return model, unflatten_dict(state)
|
494 |
-
|
495 |
-
def save_pretrained(
|
496 |
-
self,
|
497 |
-
save_directory: Union[str, os.PathLike],
|
498 |
-
params: Union[Dict, FrozenDict],
|
499 |
-
is_main_process: bool = True,
|
500 |
-
):
|
501 |
-
"""
|
502 |
-
Save a model and its configuration file to a directory so that it can be reloaded using the
|
503 |
-
[`~FlaxModelMixin.from_pretrained`] class method.
|
504 |
-
|
505 |
-
Arguments:
|
506 |
-
save_directory (`str` or `os.PathLike`):
|
507 |
-
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
508 |
-
params (`Union[Dict, FrozenDict]`):
|
509 |
-
A `PyTree` of model parameters.
|
510 |
-
is_main_process (`bool`, *optional*, defaults to `True`):
|
511 |
-
Whether the process calling this is the main process or not. Useful during distributed training and you
|
512 |
-
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
513 |
-
process to avoid race conditions.
|
514 |
-
"""
|
515 |
-
if os.path.isfile(save_directory):
|
516 |
-
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
517 |
-
return
|
518 |
-
|
519 |
-
os.makedirs(save_directory, exist_ok=True)
|
520 |
-
|
521 |
-
model_to_save = self
|
522 |
-
|
523 |
-
# Attach architecture to the config
|
524 |
-
# Save the config
|
525 |
-
if is_main_process:
|
526 |
-
model_to_save.save_config(save_directory)
|
527 |
-
|
528 |
-
# save model
|
529 |
-
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
530 |
-
with open(output_model_file, "wb") as f:
|
531 |
-
model_bytes = to_bytes(params)
|
532 |
-
f.write(model_bytes)
|
533 |
-
|
534 |
-
logger.info(f"Model weights saved in {output_model_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/models/modeling_pytorch_flax_utils.py
DELETED
@@ -1,161 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
""" PyTorch - Flax general utilities."""
|
16 |
-
|
17 |
-
from pickle import UnpicklingError
|
18 |
-
|
19 |
-
import jax
|
20 |
-
import jax.numpy as jnp
|
21 |
-
import numpy as np
|
22 |
-
from flax.serialization import from_bytes
|
23 |
-
from flax.traverse_util import flatten_dict
|
24 |
-
|
25 |
-
from ..utils import logging
|
26 |
-
|
27 |
-
|
28 |
-
logger = logging.get_logger(__name__)
|
29 |
-
|
30 |
-
|
31 |
-
#####################
|
32 |
-
# Flax => PyTorch #
|
33 |
-
#####################
|
34 |
-
|
35 |
-
|
36 |
-
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
|
37 |
-
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
|
38 |
-
try:
|
39 |
-
with open(model_file, "rb") as flax_state_f:
|
40 |
-
flax_state = from_bytes(None, flax_state_f.read())
|
41 |
-
except UnpicklingError as e:
|
42 |
-
try:
|
43 |
-
with open(model_file) as f:
|
44 |
-
if f.read().startswith("version"):
|
45 |
-
raise OSError(
|
46 |
-
"You seem to have cloned a repository without having git-lfs installed. Please"
|
47 |
-
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
48 |
-
" folder you cloned."
|
49 |
-
)
|
50 |
-
else:
|
51 |
-
raise ValueError from e
|
52 |
-
except (UnicodeDecodeError, ValueError):
|
53 |
-
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
54 |
-
|
55 |
-
return load_flax_weights_in_pytorch_model(pt_model, flax_state)
|
56 |
-
|
57 |
-
|
58 |
-
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
59 |
-
"""Load flax checkpoints in a PyTorch model"""
|
60 |
-
|
61 |
-
try:
|
62 |
-
import torch # noqa: F401
|
63 |
-
except ImportError:
|
64 |
-
logger.error(
|
65 |
-
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
|
66 |
-
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
67 |
-
" instructions."
|
68 |
-
)
|
69 |
-
raise
|
70 |
-
|
71 |
-
# check if we have bf16 weights
|
72 |
-
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
|
73 |
-
if any(is_type_bf16):
|
74 |
-
# convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
|
75 |
-
|
76 |
-
# and bf16 is not fully supported in PT yet.
|
77 |
-
logger.warning(
|
78 |
-
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
|
79 |
-
"before loading those in PyTorch model."
|
80 |
-
)
|
81 |
-
flax_state = jax.tree_util.tree_map(
|
82 |
-
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
|
83 |
-
)
|
84 |
-
|
85 |
-
pt_model.base_model_prefix = ""
|
86 |
-
|
87 |
-
flax_state_dict = flatten_dict(flax_state, sep=".")
|
88 |
-
pt_model_dict = pt_model.state_dict()
|
89 |
-
|
90 |
-
# keep track of unexpected & missing keys
|
91 |
-
unexpected_keys = []
|
92 |
-
missing_keys = set(pt_model_dict.keys())
|
93 |
-
|
94 |
-
for flax_key_tuple, flax_tensor in flax_state_dict.items():
|
95 |
-
flax_key_tuple_array = flax_key_tuple.split(".")
|
96 |
-
|
97 |
-
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
|
98 |
-
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
99 |
-
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
|
100 |
-
elif flax_key_tuple_array[-1] == "kernel":
|
101 |
-
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
102 |
-
flax_tensor = flax_tensor.T
|
103 |
-
elif flax_key_tuple_array[-1] == "scale":
|
104 |
-
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
105 |
-
|
106 |
-
if "time_embedding" not in flax_key_tuple_array:
|
107 |
-
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
|
108 |
-
flax_key_tuple_array[i] = (
|
109 |
-
flax_key_tuple_string.replace("_0", ".0")
|
110 |
-
.replace("_1", ".1")
|
111 |
-
.replace("_2", ".2")
|
112 |
-
.replace("_3", ".3")
|
113 |
-
.replace("_4", ".4")
|
114 |
-
.replace("_5", ".5")
|
115 |
-
.replace("_6", ".6")
|
116 |
-
.replace("_7", ".7")
|
117 |
-
.replace("_8", ".8")
|
118 |
-
.replace("_9", ".9")
|
119 |
-
)
|
120 |
-
|
121 |
-
flax_key = ".".join(flax_key_tuple_array)
|
122 |
-
|
123 |
-
if flax_key in pt_model_dict:
|
124 |
-
if flax_tensor.shape != pt_model_dict[flax_key].shape:
|
125 |
-
raise ValueError(
|
126 |
-
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
|
127 |
-
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
128 |
-
)
|
129 |
-
else:
|
130 |
-
# add weight to pytorch dict
|
131 |
-
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
|
132 |
-
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
|
133 |
-
# remove from missing keys
|
134 |
-
missing_keys.remove(flax_key)
|
135 |
-
else:
|
136 |
-
# weight is not expected by PyTorch model
|
137 |
-
unexpected_keys.append(flax_key)
|
138 |
-
|
139 |
-
pt_model.load_state_dict(pt_model_dict)
|
140 |
-
|
141 |
-
# re-transform missing_keys to list
|
142 |
-
missing_keys = list(missing_keys)
|
143 |
-
|
144 |
-
if len(unexpected_keys) > 0:
|
145 |
-
logger.warning(
|
146 |
-
"Some weights of the Flax model were not used when initializing the PyTorch model"
|
147 |
-
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
|
148 |
-
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
|
149 |
-
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
|
150 |
-
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
|
151 |
-
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
|
152 |
-
" FlaxBertForSequenceClassification model)."
|
153 |
-
)
|
154 |
-
if len(missing_keys) > 0:
|
155 |
-
logger.warning(
|
156 |
-
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
|
157 |
-
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
|
158 |
-
" use it for predictions and inference."
|
159 |
-
)
|
160 |
-
|
161 |
-
return pt_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4DoF/diffusers/models/modeling_utils.py
DELETED
@@ -1,980 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
-
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
-
# you may not use this file except in compliance with the License.
|
7 |
-
# You may obtain a copy of the License at
|
8 |
-
#
|
9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
-
#
|
11 |
-
# Unless required by applicable law or agreed to in writing, software
|
12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
-
# See the License for the specific language governing permissions and
|
15 |
-
# limitations under the License.
|
16 |
-
|
17 |
-
import inspect
|
18 |
-
import itertools
|
19 |
-
import os
|
20 |
-
import re
|
21 |
-
from functools import partial
|
22 |
-
from typing import Any, Callable, List, Optional, Tuple, Union
|
23 |
-
|
24 |
-
import torch
|
25 |
-
from torch import Tensor, device, nn
|
26 |
-
|
27 |
-
from .. import __version__
|
28 |
-
from ..utils import (
|
29 |
-
CONFIG_NAME,
|
30 |
-
DIFFUSERS_CACHE,
|
31 |
-
FLAX_WEIGHTS_NAME,
|
32 |
-
HF_HUB_OFFLINE,
|
33 |
-
SAFETENSORS_WEIGHTS_NAME,
|
34 |
-
WEIGHTS_NAME,
|
35 |
-
_add_variant,
|
36 |
-
_get_model_file,
|
37 |
-
deprecate,
|
38 |
-
is_accelerate_available,
|
39 |
-
is_safetensors_available,
|
40 |
-
is_torch_version,
|
41 |
-
logging,
|
42 |
-
)
|
43 |
-
|
44 |
-
|
45 |
-
logger = logging.get_logger(__name__)
|
46 |
-
|
47 |
-
|
48 |
-
if is_torch_version(">=", "1.9.0"):
|
49 |
-
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
50 |
-
else:
|
51 |
-
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
52 |
-
|
53 |
-
|
54 |
-
if is_accelerate_available():
|
55 |
-
import accelerate
|
56 |
-
from accelerate.utils import set_module_tensor_to_device
|
57 |
-
from accelerate.utils.versions import is_torch_version
|
58 |
-
|
59 |
-
if is_safetensors_available():
|
60 |
-
import safetensors
|
61 |
-
|
62 |
-
|
63 |
-
def get_parameter_device(parameter: torch.nn.Module):
|
64 |
-
try:
|
65 |
-
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
66 |
-
return next(parameters_and_buffers).device
|
67 |
-
except StopIteration:
|
68 |
-
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
69 |
-
|
70 |
-
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
71 |
-
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
72 |
-
return tuples
|
73 |
-
|
74 |
-
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
75 |
-
first_tuple = next(gen)
|
76 |
-
return first_tuple[1].device
|
77 |
-
|
78 |
-
|
79 |
-
def get_parameter_dtype(parameter: torch.nn.Module):
|
80 |
-
try:
|
81 |
-
params = tuple(parameter.parameters())
|
82 |
-
if len(params) > 0:
|
83 |
-
return params[0].dtype
|
84 |
-
|
85 |
-
buffers = tuple(parameter.buffers())
|
86 |
-
if len(buffers) > 0:
|
87 |
-
return buffers[0].dtype
|
88 |
-
|
89 |
-
except StopIteration:
|
90 |
-
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
91 |
-
|
92 |
-
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
93 |
-
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
94 |
-
return tuples
|
95 |
-
|
96 |
-
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
97 |
-
first_tuple = next(gen)
|
98 |
-
return first_tuple[1].dtype
|
99 |
-
|
100 |
-
|
101 |
-
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
102 |
-
"""
|
103 |
-
Reads a checkpoint file, returning properly formatted errors if they arise.
|
104 |
-
"""
|
105 |
-
try:
|
106 |
-
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
107 |
-
return torch.load(checkpoint_file, map_location="cpu")
|
108 |
-
else:
|
109 |
-
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
110 |
-
except Exception as e:
|
111 |
-
try:
|
112 |
-
with open(checkpoint_file) as f:
|
113 |
-
if f.read().startswith("version"):
|
114 |
-
raise OSError(
|
115 |
-
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
116 |
-
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
117 |
-
"you cloned."
|
118 |
-
)
|
119 |
-
else:
|
120 |
-
raise ValueError(
|
121 |
-
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
122 |
-
"model. Make sure you have saved the model properly."
|
123 |
-
) from e
|
124 |
-
except (UnicodeDecodeError, ValueError):
|
125 |
-
raise OSError(
|
126 |
-
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
127 |
-
f"at '{checkpoint_file}'. "
|
128 |
-
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
129 |
-
)
|
130 |
-
|
131 |
-
|
132 |
-
def _load_state_dict_into_model(model_to_load, state_dict):
|
133 |
-
# Convert old format to new format if needed from a PyTorch state_dict
|
134 |
-
# copy state_dict so _load_from_state_dict can modify it
|
135 |
-
state_dict = state_dict.copy()
|
136 |
-
error_msgs = []
|
137 |
-
|
138 |
-
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
139 |
-
# so we need to apply the function recursively.
|
140 |
-
def load(module: torch.nn.Module, prefix=""):
|
141 |
-
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
142 |
-
module._load_from_state_dict(*args)
|
143 |
-
|
144 |
-
for name, child in module._modules.items():
|
145 |
-
if child is not None:
|
146 |
-
load(child, prefix + name + ".")
|
147 |
-
|
148 |
-
load(model_to_load)
|
149 |
-
|
150 |
-
return error_msgs
|
151 |
-
|
152 |
-
|
153 |
-
class ModelMixin(torch.nn.Module):
|
154 |
-
r"""
|
155 |
-
Base class for all models.
|
156 |
-
|
157 |
-
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
158 |
-
saving models.
|
159 |
-
|
160 |
-
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
161 |
-
"""
|
162 |
-
config_name = CONFIG_NAME
|
163 |
-
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
164 |
-
_supports_gradient_checkpointing = False
|
165 |
-
_keys_to_ignore_on_load_unexpected = None
|
166 |
-
|
167 |
-
def __init__(self):
|
168 |
-
super().__init__()
|
169 |
-
|
170 |
-
def __getattr__(self, name: str) -> Any:
|
171 |
-
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
172 |
-
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
173 |
-
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
174 |
-
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
175 |
-
"""
|
176 |
-
|
177 |
-
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
178 |
-
is_attribute = name in self.__dict__
|
179 |
-
|
180 |
-
if is_in_config and not is_attribute:
|
181 |
-
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
182 |
-
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
183 |
-
return self._internal_dict[name]
|
184 |
-
|
185 |
-
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
186 |
-
return super().__getattr__(name)
|
187 |
-
|
188 |
-
@property
|
189 |
-
def is_gradient_checkpointing(self) -> bool:
|
190 |
-
"""
|
191 |
-
Whether gradient checkpointing is activated for this model or not.
|
192 |
-
"""
|
193 |
-
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
194 |
-
|
195 |
-
def enable_gradient_checkpointing(self):
|
196 |
-
"""
|
197 |
-
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
198 |
-
*checkpoint activations* in other frameworks).
|
199 |
-
"""
|
200 |
-
if not self._supports_gradient_checkpointing:
|
201 |
-
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
202 |
-
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
203 |
-
|
204 |
-
def disable_gradient_checkpointing(self):
|
205 |
-
"""
|
206 |
-
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
207 |
-
*checkpoint activations* in other frameworks).
|
208 |
-
"""
|
209 |
-
if self._supports_gradient_checkpointing:
|
210 |
-
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
211 |
-
|
212 |
-
def set_use_memory_efficient_attention_xformers(
|
213 |
-
self, valid: bool, attention_op: Optional[Callable] = None
|
214 |
-
) -> None:
|
215 |
-
# Recursively walk through all the children.
|
216 |
-
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
217 |
-
# gets the message
|
218 |
-
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
219 |
-
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
220 |
-
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
221 |
-
|
222 |
-
for child in module.children():
|
223 |
-
fn_recursive_set_mem_eff(child)
|
224 |
-
|
225 |
-
for module in self.children():
|
226 |
-
if isinstance(module, torch.nn.Module):
|
227 |
-
fn_recursive_set_mem_eff(module)
|
228 |
-
|
229 |
-
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
230 |
-
r"""
|
231 |
-
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
232 |
-
|
233 |
-
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
234 |
-
inference. Speed up during training is not guaranteed.
|
235 |
-
|
236 |
-
<Tip warning={true}>
|
237 |
-
|
238 |
-
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
239 |
-
precedent.
|
240 |
-
|
241 |
-
</Tip>
|
242 |
-
|
243 |
-
Parameters:
|
244 |
-
attention_op (`Callable`, *optional*):
|
245 |
-
Override the default `None` operator for use as `op` argument to the
|
246 |
-
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
247 |
-
function of xFormers.
|
248 |
-
|
249 |
-
Examples:
|
250 |
-
|
251 |
-
```py
|
252 |
-
>>> import torch
|
253 |
-
>>> from diffusers import UNet2DConditionModel
|
254 |
-
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
255 |
-
|
256 |
-
>>> model = UNet2DConditionModel.from_pretrained(
|
257 |
-
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
258 |
-
... )
|
259 |
-
>>> model = model.to("cuda")
|
260 |
-
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
261 |
-
```
|
262 |
-
"""
|
263 |
-
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
264 |
-
|
265 |
-
def disable_xformers_memory_efficient_attention(self):
|
266 |
-
r"""
|
267 |
-
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
268 |
-
"""
|
269 |
-
self.set_use_memory_efficient_attention_xformers(False)
|
270 |
-
|
271 |
-
def save_pretrained(
|
272 |
-
self,
|
273 |
-
save_directory: Union[str, os.PathLike],
|
274 |
-
is_main_process: bool = True,
|
275 |
-
save_function: Callable = None,
|
276 |
-
safe_serialization: bool = False,
|
277 |
-
variant: Optional[str] = None,
|
278 |
-
):
|
279 |
-
"""
|
280 |
-
Save a model and its configuration file to a directory so that it can be reloaded using the
|
281 |
-
[`~models.ModelMixin.from_pretrained`] class method.
|
282 |
-
|
283 |
-
Arguments:
|
284 |
-
save_directory (`str` or `os.PathLike`):
|
285 |
-
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
286 |
-
is_main_process (`bool`, *optional*, defaults to `True`):
|
287 |
-
Whether the process calling this is the main process or not. Useful during distributed training and you
|
288 |
-
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
289 |
-
process to avoid race conditions.
|
290 |
-
save_function (`Callable`):
|
291 |
-
The function to use to save the state dictionary. Useful during distributed training when you need to
|
292 |
-
replace `torch.save` with another method. Can be configured with the environment variable
|
293 |
-
`DIFFUSERS_SAVE_MODE`.
|
294 |
-
safe_serialization (`bool`, *optional*, defaults to `False`):
|
295 |
-
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
296 |
-
variant (`str`, *optional*):
|
297 |
-
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
298 |
-
"""
|
299 |
-
if safe_serialization and not is_safetensors_available():
|
300 |
-
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
301 |
-
|
302 |
-
if os.path.isfile(save_directory):
|
303 |
-
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
304 |
-
return
|
305 |
-
|
306 |
-
os.makedirs(save_directory, exist_ok=True)
|
307 |
-
|
308 |
-
model_to_save = self
|
309 |
-
|
310 |
-
# Attach architecture to the config
|
311 |
-
# Save the config
|
312 |
-
if is_main_process:
|
313 |
-
model_to_save.save_config(save_directory)
|
314 |
-
|
315 |
-
# Save the model
|
316 |
-
state_dict = model_to_save.state_dict()
|
317 |
-
|
318 |
-
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
319 |
-
weights_name = _add_variant(weights_name, variant)
|
320 |
-
|
321 |
-
# Save the model
|
322 |
-
if safe_serialization:
|
323 |
-
safetensors.torch.save_file(
|
324 |
-
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
325 |
-
)
|
326 |
-
else:
|
327 |
-
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
328 |
-
|
329 |
-
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
330 |
-
|
331 |
-
@classmethod
|
332 |
-
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
333 |
-
r"""
|
334 |
-
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
335 |
-
|
336 |
-
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
337 |
-
train the model, set it back in training mode with `model.train()`.
|
338 |
-
|
339 |
-
Parameters:
|
340 |
-
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
341 |
-
Can be either:
|
342 |
-
|
343 |
-
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
344 |
-
the Hub.
|
345 |
-
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
346 |
-
with [`~ModelMixin.save_pretrained`].
|
347 |
-
|
348 |
-
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
349 |
-
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
350 |
-
is not used.
|
351 |
-
torch_dtype (`str` or `torch.dtype`, *optional*):
|
352 |
-
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
353 |
-
dtype is automatically derived from the model's weights.
|
354 |
-
force_download (`bool`, *optional*, defaults to `False`):
|
355 |
-
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
356 |
-
cached versions if they exist.
|
357 |
-
resume_download (`bool`, *optional*, defaults to `False`):
|
358 |
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
359 |
-
incompletely downloaded files are deleted.
|
360 |
-
proxies (`Dict[str, str]`, *optional*):
|
361 |
-
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
362 |
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
363 |
-
output_loading_info (`bool`, *optional*, defaults to `False`):
|
364 |
-
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
365 |
-
local_files_only(`bool`, *optional*, defaults to `False`):
|
366 |
-
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
367 |
-
won't be downloaded from the Hub.
|
368 |
-
use_auth_token (`str` or *bool*, *optional*):
|
369 |
-
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
370 |
-
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
371 |
-
revision (`str`, *optional*, defaults to `"main"`):
|
372 |
-
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
373 |
-
allowed by Git.
|
374 |
-
from_flax (`bool`, *optional*, defaults to `False`):
|
375 |
-
Load the model weights from a Flax checkpoint save file.
|
376 |
-
subfolder (`str`, *optional*, defaults to `""`):
|
377 |
-
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
378 |
-
mirror (`str`, *optional*):
|
379 |
-
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
380 |
-
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
381 |
-
information.
|
382 |
-
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
383 |
-
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
384 |
-
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
385 |
-
same device.
|
386 |
-
|
387 |
-
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
388 |
-
more information about each option see [designing a device
|
389 |
-
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
390 |
-
max_memory (`Dict`, *optional*):
|
391 |
-
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
392 |
-
each GPU and the available CPU RAM if unset.
|
393 |
-
offload_folder (`str` or `os.PathLike`, *optional*):
|
394 |
-
The path to offload weights if `device_map` contains the value `"disk"`.
|
395 |
-
offload_state_dict (`bool`, *optional*):
|
396 |
-
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
397 |
-
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
398 |
-
when there is some disk offload.
|
399 |
-
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
400 |
-
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
401 |
-
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
402 |
-
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
403 |
-
argument to `True` will raise an error.
|
404 |
-
variant (`str`, *optional*):
|
405 |
-
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
406 |
-
loading `from_flax`.
|
407 |
-
use_safetensors (`bool`, *optional*, defaults to `None`):
|
408 |
-
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
409 |
-
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
410 |
-
weights. If set to `False`, `safetensors` weights are not loaded.
|
411 |
-
|
412 |
-
<Tip>
|
413 |
-
|
414 |
-
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
415 |
-
`huggingface-cli login`. You can also activate the special
|
416 |
-
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
417 |
-
firewalled environment.
|
418 |
-
|
419 |
-
</Tip>
|
420 |
-
|
421 |
-
Example:
|
422 |
-
|
423 |
-
```py
|
424 |
-
from diffusers import UNet2DConditionModel
|
425 |
-
|
426 |
-
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
427 |
-
```
|
428 |
-
|
429 |
-
If you get the error message below, you need to finetune the weights for your downstream task:
|
430 |
-
|
431 |
-
```bash
|
432 |
-
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
433 |
-
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
434 |
-
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
435 |
-
```
|
436 |
-
"""
|
437 |
-
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
438 |
-
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
439 |
-
force_download = kwargs.pop("force_download", False)
|
440 |
-
from_flax = kwargs.pop("from_flax", False)
|
441 |
-
resume_download = kwargs.pop("resume_download", False)
|
442 |
-
proxies = kwargs.pop("proxies", None)
|
443 |
-
output_loading_info = kwargs.pop("output_loading_info", False)
|
444 |
-
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
445 |
-
use_auth_token = kwargs.pop("use_auth_token", None)
|
446 |
-
revision = kwargs.pop("revision", None)
|
447 |
-
torch_dtype = kwargs.pop("torch_dtype", None)
|
448 |
-
subfolder = kwargs.pop("subfolder", None)
|
449 |
-
device_map = kwargs.pop("device_map", None)
|
450 |
-
max_memory = kwargs.pop("max_memory", None)
|
451 |
-
offload_folder = kwargs.pop("offload_folder", None)
|
452 |
-
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
453 |
-
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
454 |
-
variant = kwargs.pop("variant", None)
|
455 |
-
use_safetensors = kwargs.pop("use_safetensors", None)
|
456 |
-
|
457 |
-
if use_safetensors and not is_safetensors_available():
|
458 |
-
raise ValueError(
|
459 |
-
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
460 |
-
)
|
461 |
-
|
462 |
-
allow_pickle = False
|
463 |
-
if use_safetensors is None:
|
464 |
-
use_safetensors = is_safetensors_available()
|
465 |
-
allow_pickle = True
|
466 |
-
|
467 |
-
if low_cpu_mem_usage and not is_accelerate_available():
|
468 |
-
low_cpu_mem_usage = False
|
469 |
-
logger.warning(
|
470 |
-
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
471 |
-
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
472 |
-
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
473 |
-
" install accelerate\n```\n."
|
474 |
-
)
|
475 |
-
|
476 |
-
if device_map is not None and not is_accelerate_available():
|
477 |
-
raise NotImplementedError(
|
478 |
-
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
479 |
-
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
480 |
-
)
|
481 |
-
|
482 |
-
# Check if we can handle device_map and dispatching the weights
|
483 |
-
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
484 |
-
raise NotImplementedError(
|
485 |
-
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
486 |
-
" `device_map=None`."
|
487 |
-
)
|
488 |
-
|
489 |
-
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
490 |
-
raise NotImplementedError(
|
491 |
-
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
492 |
-
" `low_cpu_mem_usage=False`."
|
493 |
-
)
|
494 |
-
|
495 |
-
if low_cpu_mem_usage is False and device_map is not None:
|
496 |
-
raise ValueError(
|
497 |
-
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
498 |
-
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
499 |
-
)
|
500 |
-
|
501 |
-
# Load config if we don't provide a configuration
|
502 |
-
config_path = pretrained_model_name_or_path
|
503 |
-
|
504 |
-
user_agent = {
|
505 |
-
"diffusers": __version__,
|
506 |
-
"file_type": "model",
|
507 |
-
"framework": "pytorch",
|
508 |
-
}
|
509 |
-
|
510 |
-
# load config
|
511 |
-
config, unused_kwargs, commit_hash = cls.load_config(
|
512 |
-
config_path,
|
513 |
-
cache_dir=cache_dir,
|
514 |
-
return_unused_kwargs=True,
|
515 |
-
return_commit_hash=True,
|
516 |
-
force_download=force_download,
|
517 |
-
resume_download=resume_download,
|
518 |
-
proxies=proxies,
|
519 |
-
local_files_only=local_files_only,
|
520 |
-
use_auth_token=use_auth_token,
|
521 |
-
revision=revision,
|
522 |
-
subfolder=subfolder,
|
523 |
-
device_map=device_map,
|
524 |
-
max_memory=max_memory,
|
525 |
-
offload_folder=offload_folder,
|
526 |
-
offload_state_dict=offload_state_dict,
|
527 |
-
user_agent=user_agent,
|
528 |
-
**kwargs,
|
529 |
-
)
|
530 |
-
|
531 |
-
# load model
|
532 |
-
model_file = None
|
533 |
-
if from_flax:
|
534 |
-
model_file = _get_model_file(
|
535 |
-
pretrained_model_name_or_path,
|
536 |
-
weights_name=FLAX_WEIGHTS_NAME,
|
537 |
-
cache_dir=cache_dir,
|
538 |
-
force_download=force_download,
|
539 |
-
resume_download=resume_download,
|
540 |
-
proxies=proxies,
|
541 |
-
local_files_only=local_files_only,
|
542 |
-
use_auth_token=use_auth_token,
|
543 |
-
revision=revision,
|
544 |
-
subfolder=subfolder,
|
545 |
-
user_agent=user_agent,
|
546 |
-
commit_hash=commit_hash,
|
547 |
-
)
|
548 |
-
model = cls.from_config(config, **unused_kwargs)
|
549 |
-
|
550 |
-
# Convert the weights
|
551 |
-
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
552 |
-
|
553 |
-
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
554 |
-
else:
|
555 |
-
if use_safetensors:
|
556 |
-
try:
|
557 |
-
model_file = _get_model_file(
|
558 |
-
pretrained_model_name_or_path,
|
559 |
-
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
560 |
-
cache_dir=cache_dir,
|
561 |
-
force_download=force_download,
|
562 |
-
resume_download=resume_download,
|
563 |
-
proxies=proxies,
|
564 |
-
local_files_only=local_files_only,
|
565 |
-
use_auth_token=use_auth_token,
|
566 |
-
revision=revision,
|
567 |
-
subfolder=subfolder,
|
568 |
-
user_agent=user_agent,
|
569 |
-
commit_hash=commit_hash,
|
570 |
-
)
|
571 |
-
except IOError as e:
|
572 |
-
if not allow_pickle:
|
573 |
-
raise e
|
574 |
-
pass
|
575 |
-
if model_file is None:
|
576 |
-
model_file = _get_model_file(
|
577 |
-
pretrained_model_name_or_path,
|
578 |
-
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
579 |
-
cache_dir=cache_dir,
|
580 |
-
force_download=force_download,
|
581 |
-
resume_download=resume_download,
|
582 |
-
proxies=proxies,
|
583 |
-
local_files_only=local_files_only,
|
584 |
-
use_auth_token=use_auth_token,
|
585 |
-
revision=revision,
|
586 |
-
subfolder=subfolder,
|
587 |
-
user_agent=user_agent,
|
588 |
-
commit_hash=commit_hash,
|
589 |
-
)
|
590 |
-
|
591 |
-
if low_cpu_mem_usage:
|
592 |
-
# Instantiate model with empty weights
|
593 |
-
with accelerate.init_empty_weights():
|
594 |
-
model = cls.from_config(config, **unused_kwargs)
|
595 |
-
|
596 |
-
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
597 |
-
if device_map is None:
|
598 |
-
param_device = "cpu"
|
599 |
-
state_dict = load_state_dict(model_file, variant=variant)
|
600 |
-
model._convert_deprecated_attention_blocks(state_dict)
|
601 |
-
# move the params from meta device to cpu
|
602 |
-
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
603 |
-
if len(missing_keys) > 0:
|
604 |
-
raise ValueError(
|
605 |
-
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
606 |
-
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
607 |
-
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
608 |
-
" those weights or else make sure your checkpoint file is correct."
|
609 |
-
)
|
610 |
-
unexpected_keys = []
|
611 |
-
|
612 |
-
empty_state_dict = model.state_dict()
|
613 |
-
for param_name, param in state_dict.items():
|
614 |
-
accepts_dtype = "dtype" in set(
|
615 |
-
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
616 |
-
)
|
617 |
-
|
618 |
-
if param_name not in empty_state_dict:
|
619 |
-
unexpected_keys.append(param_name)
|
620 |
-
continue
|
621 |
-
|
622 |
-
if empty_state_dict[param_name].shape != param.shape:
|
623 |
-
raise ValueError(
|
624 |
-
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
625 |
-
)
|
626 |
-
|
627 |
-
if accepts_dtype:
|
628 |
-
set_module_tensor_to_device(
|
629 |
-
model, param_name, param_device, value=param, dtype=torch_dtype
|
630 |
-
)
|
631 |
-
else:
|
632 |
-
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
633 |
-
|
634 |
-
if cls._keys_to_ignore_on_load_unexpected is not None:
|
635 |
-
for pat in cls._keys_to_ignore_on_load_unexpected:
|
636 |
-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
637 |
-
|
638 |
-
if len(unexpected_keys) > 0:
|
639 |
-
logger.warn(
|
640 |
-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
641 |
-
)
|
642 |
-
|
643 |
-
else: # else let accelerate handle loading and dispatching.
|
644 |
-
# Load weights and dispatch according to the device_map
|
645 |
-
# by default the device_map is None and the weights are loaded on the CPU
|
646 |
-
try:
|
647 |
-
accelerate.load_checkpoint_and_dispatch(
|
648 |
-
model,
|
649 |
-
model_file,
|
650 |
-
device_map,
|
651 |
-
max_memory=max_memory,
|
652 |
-
offload_folder=offload_folder,
|
653 |
-
offload_state_dict=offload_state_dict,
|
654 |
-
dtype=torch_dtype,
|
655 |
-
)
|
656 |
-
except AttributeError as e:
|
657 |
-
# When using accelerate loading, we do not have the ability to load the state
|
658 |
-
# dict and rename the weight names manually. Additionally, accelerate skips
|
659 |
-
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
660 |
-
# (which look like they should be private variables?), so we can't use the standard hooks
|
661 |
-
# to rename parameters on load. We need to mimic the original weight names so the correct
|
662 |
-
# attributes are available. After we have loaded the weights, we convert the deprecated
|
663 |
-
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
664 |
-
# the weights so we don't have to do this again.
|
665 |
-
|
666 |
-
if "'Attention' object has no attribute" in str(e):
|
667 |
-
logger.warn(
|
668 |
-
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
669 |
-
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
670 |
-
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
671 |
-
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
672 |
-
" please also re-upload it or open a PR on the original repository."
|
673 |
-
)
|
674 |
-
model._temp_convert_self_to_deprecated_attention_blocks()
|
675 |
-
accelerate.load_checkpoint_and_dispatch(
|
676 |
-
model,
|
677 |
-
model_file,
|
678 |
-
device_map,
|
679 |
-
max_memory=max_memory,
|
680 |
-
offload_folder=offload_folder,
|
681 |
-
offload_state_dict=offload_state_dict,
|
682 |
-
dtype=torch_dtype,
|
683 |
-
)
|
684 |
-
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
685 |
-
else:
|
686 |
-
raise e
|
687 |
-
|
688 |
-
loading_info = {
|
689 |
-
"missing_keys": [],
|
690 |
-
"unexpected_keys": [],
|
691 |
-
"mismatched_keys": [],
|
692 |
-
"error_msgs": [],
|
693 |
-
}
|
694 |
-
else:
|
695 |
-
model = cls.from_config(config, **unused_kwargs)
|
696 |
-
|
697 |
-
state_dict = load_state_dict(model_file, variant=variant)
|
698 |
-
model._convert_deprecated_attention_blocks(state_dict)
|
699 |
-
|
700 |
-
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
701 |
-
model,
|
702 |
-
state_dict,
|
703 |
-
model_file,
|
704 |
-
pretrained_model_name_or_path,
|
705 |
-
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
706 |
-
)
|
707 |
-
|
708 |
-
loading_info = {
|
709 |
-
"missing_keys": missing_keys,
|
710 |
-
"unexpected_keys": unexpected_keys,
|
711 |
-
"mismatched_keys": mismatched_keys,
|
712 |
-
"error_msgs": error_msgs,
|
713 |
-
}
|
714 |
-
|
715 |
-
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
716 |
-
raise ValueError(
|
717 |
-
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
718 |
-
)
|
719 |
-
elif torch_dtype is not None:
|
720 |
-
model = model.to(torch_dtype)
|
721 |
-
|
722 |
-
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
723 |
-
|
724 |
-
# Set model in evaluation mode to deactivate DropOut modules by default
|
725 |
-
model.eval()
|
726 |
-
if output_loading_info:
|
727 |
-
return model, loading_info
|
728 |
-
|
729 |
-
return model
|
730 |
-
|
731 |
-
@classmethod
|
732 |
-
def _load_pretrained_model(
|
733 |
-
cls,
|
734 |
-
model,
|
735 |
-
state_dict,
|
736 |
-
resolved_archive_file,
|
737 |
-
pretrained_model_name_or_path,
|
738 |
-
ignore_mismatched_sizes=False,
|
739 |
-
):
|
740 |
-
# Retrieve missing & unexpected_keys
|
741 |
-
model_state_dict = model.state_dict()
|
742 |
-
loaded_keys = list(state_dict.keys())
|
743 |
-
|
744 |
-
expected_keys = list(model_state_dict.keys())
|
745 |
-
|
746 |
-
original_loaded_keys = loaded_keys
|
747 |
-
|
748 |
-
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
749 |
-
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
750 |
-
|
751 |
-
# Make sure we are able to load base models as well as derived models (with heads)
|
752 |
-
model_to_load = model
|
753 |
-
|
754 |
-
def _find_mismatched_keys(
|
755 |
-
state_dict,
|
756 |
-
model_state_dict,
|
757 |
-
loaded_keys,
|
758 |
-
ignore_mismatched_sizes,
|
759 |
-
):
|
760 |
-
mismatched_keys = []
|
761 |
-
if ignore_mismatched_sizes:
|
762 |
-
for checkpoint_key in loaded_keys:
|
763 |
-
model_key = checkpoint_key
|
764 |
-
|
765 |
-
if (
|
766 |
-
model_key in model_state_dict
|
767 |
-
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
768 |
-
):
|
769 |
-
mismatched_keys.append(
|
770 |
-
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
771 |
-
)
|
772 |
-
del state_dict[checkpoint_key]
|
773 |
-
return mismatched_keys
|
774 |
-
|
775 |
-
if state_dict is not None:
|
776 |
-
# Whole checkpoint
|
777 |
-
mismatched_keys = _find_mismatched_keys(
|
778 |
-
state_dict,
|
779 |
-
model_state_dict,
|
780 |
-
original_loaded_keys,
|
781 |
-
ignore_mismatched_sizes,
|
782 |
-
)
|
783 |
-
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
784 |
-
|
785 |
-
if len(error_msgs) > 0:
|
786 |
-
error_msg = "\n\t".join(error_msgs)
|
787 |
-
if "size mismatch" in error_msg:
|
788 |
-
error_msg += (
|
789 |
-
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
790 |
-
)
|
791 |
-
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
792 |
-
|
793 |
-
if len(unexpected_keys) > 0:
|
794 |
-
logger.warning(
|
795 |
-
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
796 |
-
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
797 |
-
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
798 |
-
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
799 |
-
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
800 |
-
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
801 |
-
" identical (initializing a BertForSequenceClassification model from a"
|
802 |
-
" BertForSequenceClassification model)."
|
803 |
-
)
|
804 |
-
else:
|
805 |
-
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
806 |
-
if len(missing_keys) > 0:
|
807 |
-
logger.warning(
|
808 |
-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
809 |
-
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
810 |
-
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
811 |
-
)
|
812 |
-
elif len(mismatched_keys) == 0:
|
813 |
-
logger.info(
|
814 |
-
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
815 |
-
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
816 |
-
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
817 |
-
" without further training."
|
818 |
-
)
|
819 |
-
if len(mismatched_keys) > 0:
|
820 |
-
mismatched_warning = "\n".join(
|
821 |
-
[
|
822 |
-
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
823 |
-
for key, shape1, shape2 in mismatched_keys
|
824 |
-
]
|
825 |
-
)
|
826 |
-
logger.warning(
|
827 |
-
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
828 |
-
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
829 |
-
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
830 |
-
" able to use it for predictions and inference."
|
831 |
-
)
|
832 |
-
|
833 |
-
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
834 |
-
|
835 |
-
@property
|
836 |
-
def device(self) -> device:
|
837 |
-
"""
|
838 |
-
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
839 |
-
device).
|
840 |
-
"""
|
841 |
-
return get_parameter_device(self)
|
842 |
-
|
843 |
-
@property
|
844 |
-
def dtype(self) -> torch.dtype:
|
845 |
-
"""
|
846 |
-
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
847 |
-
"""
|
848 |
-
return get_parameter_dtype(self)
|
849 |
-
|
850 |
-
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
851 |
-
"""
|
852 |
-
Get number of (trainable or non-embedding) parameters in the module.
|
853 |
-
|
854 |
-
Args:
|
855 |
-
only_trainable (`bool`, *optional*, defaults to `False`):
|
856 |
-
Whether or not to return only the number of trainable parameters.
|
857 |
-
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
858 |
-
Whether or not to return only the number of non-embedding parameters.
|
859 |
-
|
860 |
-
Returns:
|
861 |
-
`int`: The number of parameters.
|
862 |
-
|
863 |
-
Example:
|
864 |
-
|
865 |
-
```py
|
866 |
-
from diffusers import UNet2DConditionModel
|
867 |
-
|
868 |
-
model_id = "runwayml/stable-diffusion-v1-5"
|
869 |
-
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
870 |
-
unet.num_parameters(only_trainable=True)
|
871 |
-
859520964
|
872 |
-
```
|
873 |
-
"""
|
874 |
-
|
875 |
-
if exclude_embeddings:
|
876 |
-
embedding_param_names = [
|
877 |
-
f"{name}.weight"
|
878 |
-
for name, module_type in self.named_modules()
|
879 |
-
if isinstance(module_type, torch.nn.Embedding)
|
880 |
-
]
|
881 |
-
non_embedding_parameters = [
|
882 |
-
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
883 |
-
]
|
884 |
-
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
885 |
-
else:
|
886 |
-
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
887 |
-
|
888 |
-
def _convert_deprecated_attention_blocks(self, state_dict):
|
889 |
-
deprecated_attention_block_paths = []
|
890 |
-
|
891 |
-
def recursive_find_attn_block(name, module):
|
892 |
-
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
893 |
-
deprecated_attention_block_paths.append(name)
|
894 |
-
|
895 |
-
for sub_name, sub_module in module.named_children():
|
896 |
-
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
897 |
-
recursive_find_attn_block(sub_name, sub_module)
|
898 |
-
|
899 |
-
recursive_find_attn_block("", self)
|
900 |
-
|
901 |
-
# NOTE: we have to check if the deprecated parameters are in the state dict
|
902 |
-
# because it is possible we are loading from a state dict that was already
|
903 |
-
# converted
|
904 |
-
|
905 |
-
for path in deprecated_attention_block_paths:
|
906 |
-
# group_norm path stays the same
|
907 |
-
|
908 |
-
# query -> to_q
|
909 |
-
if f"{path}.query.weight" in state_dict:
|
910 |
-
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
911 |
-
if f"{path}.query.bias" in state_dict:
|
912 |
-
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
913 |
-
|
914 |
-
# key -> to_k
|
915 |
-
if f"{path}.key.weight" in state_dict:
|
916 |
-
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
917 |
-
if f"{path}.key.bias" in state_dict:
|
918 |
-
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
919 |
-
|
920 |
-
# value -> to_v
|
921 |
-
if f"{path}.value.weight" in state_dict:
|
922 |
-
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
923 |
-
if f"{path}.value.bias" in state_dict:
|
924 |
-
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
925 |
-
|
926 |
-
# proj_attn -> to_out.0
|
927 |
-
if f"{path}.proj_attn.weight" in state_dict:
|
928 |
-
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
929 |
-
if f"{path}.proj_attn.bias" in state_dict:
|
930 |
-
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
931 |
-
|
932 |
-
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
933 |
-
deprecated_attention_block_modules = []
|
934 |
-
|
935 |
-
def recursive_find_attn_block(module):
|
936 |
-
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
937 |
-
deprecated_attention_block_modules.append(module)
|
938 |
-
|
939 |
-
for sub_module in module.children():
|
940 |
-
recursive_find_attn_block(sub_module)
|
941 |
-
|
942 |
-
recursive_find_attn_block(self)
|
943 |
-
|
944 |
-
for module in deprecated_attention_block_modules:
|
945 |
-
module.query = module.to_q
|
946 |
-
module.key = module.to_k
|
947 |
-
module.value = module.to_v
|
948 |
-
module.proj_attn = module.to_out[0]
|
949 |
-
|
950 |
-
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
951 |
-
# that _all_ the weights are loaded into the new attributes and we're not
|
952 |
-
# making an incorrect assumption that this model should be converted when
|
953 |
-
# it really shouldn't be.
|
954 |
-
del module.to_q
|
955 |
-
del module.to_k
|
956 |
-
del module.to_v
|
957 |
-
del module.to_out
|
958 |
-
|
959 |
-
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
960 |
-
deprecated_attention_block_modules = []
|
961 |
-
|
962 |
-
def recursive_find_attn_block(module):
|
963 |
-
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
964 |
-
deprecated_attention_block_modules.append(module)
|
965 |
-
|
966 |
-
for sub_module in module.children():
|
967 |
-
recursive_find_attn_block(sub_module)
|
968 |
-
|
969 |
-
recursive_find_attn_block(self)
|
970 |
-
|
971 |
-
for module in deprecated_attention_block_modules:
|
972 |
-
module.to_q = module.query
|
973 |
-
module.to_k = module.key
|
974 |
-
module.to_v = module.value
|
975 |
-
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
976 |
-
|
977 |
-
del module.query
|
978 |
-
del module.key
|
979 |
-
del module.value
|
980 |
-
del module.proj_attn
|
4DoF/diffusers/models/prior_transformer.py
DELETED
@@ -1,364 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, Optional, Union
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
-from .attention import BasicTransformerBlock
-from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-
-
-@dataclass
-class PriorTransformerOutput(BaseOutput):
-    """
-    The output of [`PriorTransformer`].
-
-    Args:
-        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
-            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
-    """
-
-    predicted_image_embedding: torch.FloatTensor
-
-
-class PriorTransformer(ModelMixin, ConfigMixin):
-    """
-    A Prior Transformer model.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
-        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
-        num_embeddings (`int`, *optional*, defaults to 77):
-            The number of embeddings of the model input `hidden_states`
-        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
-            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
-            additional_embeddings`.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
-            The activation function to use to create timestep embeddings.
-        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
-            passing to Transformer blocks. Set it to `None` if normalization is not needed.
-        embedding_proj_norm_type (`str`, *optional*, defaults to None):
-            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
-            needed.
-        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
-            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
-            `encoder_hidden_states` is `None`.
-        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
-            Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
-            product between the text embedding and image embedding as proposed in the unclip paper
-            https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
-        time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
-            If None, will be set to `num_attention_heads * attention_head_dim`
-        embedding_proj_dim (`int`, *optional*, default to None):
-            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
-        clip_embed_dim (`int`, *optional*, default to None):
-            The dimension of the output. If None, will be set to `embedding_dim`.
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        num_attention_heads: int = 32,
-        attention_head_dim: int = 64,
-        num_layers: int = 20,
-        embedding_dim: int = 768,
-        num_embeddings=77,
-        additional_embeddings=4,
-        dropout: float = 0.0,
-        time_embed_act_fn: str = "silu",
-        norm_in_type: Optional[str] = None,  # layer
-        embedding_proj_norm_type: Optional[str] = None,  # layer
-        encoder_hid_proj_type: Optional[str] = "linear",  # linear
-        added_emb_type: Optional[str] = "prd",  # prd
-        time_embed_dim: Optional[int] = None,
-        embedding_proj_dim: Optional[int] = None,
-        clip_embed_dim: Optional[int] = None,
-    ):
-        super().__init__()
-        self.num_attention_heads = num_attention_heads
-        self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
-        self.additional_embeddings = additional_embeddings
-
-        time_embed_dim = time_embed_dim or inner_dim
-        embedding_proj_dim = embedding_proj_dim or embedding_dim
-        clip_embed_dim = clip_embed_dim or embedding_dim
-
-        self.time_proj = Timesteps(inner_dim, True, 0)
-        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
-
-        self.proj_in = nn.Linear(embedding_dim, inner_dim)
-
-        if embedding_proj_norm_type is None:
-            self.embedding_proj_norm = None
-        elif embedding_proj_norm_type == "layer":
-            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
-        else:
-            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
-
-        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
-
-        if encoder_hid_proj_type is None:
-            self.encoder_hidden_states_proj = None
-        elif encoder_hid_proj_type == "linear":
-            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
-        else:
-            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
-
-        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
-
-        if added_emb_type == "prd":
-            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
-        elif added_emb_type is None:
-            self.prd_embedding = None
-        else:
-            raise ValueError(
-                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
-            )
-
-        self.transformer_blocks = nn.ModuleList(
-            [
-                BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    activation_fn="gelu",
-                    attention_bias=True,
-                )
-                for d in range(num_layers)
-            ]
-        )
-
-        if norm_in_type == "layer":
-            self.norm_in = nn.LayerNorm(inner_dim)
-        elif norm_in_type is None:
-            self.norm_in = None
-        else:
-            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
-
-        self.norm_out = nn.LayerNorm(inner_dim)
-
-        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
-
-        causal_attention_mask = torch.full(
-            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
-        )
-        causal_attention_mask.triu_(1)
-        causal_attention_mask = causal_attention_mask[None, ...]
-        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
-
-        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
-        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
-
-    @property
-    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "set_processor"):
-                processors[f"{name}.processor"] = module.processor
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
-    def set_default_attn_processor(self):
-        """
-        Disables custom attention processors and sets the default attention implementation.
-        """
-        self.set_attn_processor(AttnProcessor())
-
-    def forward(
-        self,
-        hidden_states,
-        timestep: Union[torch.Tensor, float, int],
-        proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.BoolTensor] = None,
-        return_dict: bool = True,
-    ):
-        """
-        The [`PriorTransformer`] forward method.
-
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
-                The currently predicted image embeddings.
-            timestep (`torch.LongTensor`):
-                Current denoising step.
-            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
-                Projected embedding vector the denoising process is conditioned on.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
-                Hidden states of the text embeddings the denoising process is conditioned on.
-            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
-                Text mask for the text embeddings.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
-                If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
-        """
-        batch_size = hidden_states.shape[0]
-
-        timesteps = timestep
-        if not torch.is_tensor(timesteps):
-            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
-        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(hidden_states.device)
-
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
-
-        timesteps_projected = self.time_proj(timesteps)
-
-        # timesteps does not contain any weights and will always return f32 tensors
-        # but time_embedding might be fp16, so we need to cast here.
-        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
-        time_embeddings = self.time_embedding(timesteps_projected)
-
-        if self.embedding_proj_norm is not None:
-            proj_embedding = self.embedding_proj_norm(proj_embedding)
-
-        proj_embeddings = self.embedding_proj(proj_embedding)
-        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
-            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
-        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
-            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
-
-        hidden_states = self.proj_in(hidden_states)
-
-        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
-
-        additional_embeds = []
-        additional_embeddings_len = 0
-
-        if encoder_hidden_states is not None:
-            additional_embeds.append(encoder_hidden_states)
-            additional_embeddings_len += encoder_hidden_states.shape[1]
-
-        if len(proj_embeddings.shape) == 2:
-            proj_embeddings = proj_embeddings[:, None, :]
-
-        if len(hidden_states.shape) == 2:
-            hidden_states = hidden_states[:, None, :]
-
-        additional_embeds = additional_embeds + [
-            proj_embeddings,
-            time_embeddings[:, None, :],
-            hidden_states,
-        ]
-
-        if self.prd_embedding is not None:
-            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
-            additional_embeds.append(prd_embedding)
-
-        hidden_states = torch.cat(
-            additional_embeds,
-            dim=1,
-        )
-
-        # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
-        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
-        if positional_embeddings.shape[1] < hidden_states.shape[1]:
-            positional_embeddings = F.pad(
-                positional_embeddings,
-                (
-                    0,
-                    0,
-                    additional_embeddings_len,
-                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
-                ),
-                value=0.0,
-            )
-
-        hidden_states = hidden_states + positional_embeddings
-
-        if attention_mask is not None:
-            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
-            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
-            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
-
-        if self.norm_in is not None:
-            hidden_states = self.norm_in(hidden_states)
-
-        for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, attention_mask=attention_mask)
-
-        hidden_states = self.norm_out(hidden_states)
-
-        if self.prd_embedding is not None:
-            hidden_states = hidden_states[:, -1]
-        else:
-            hidden_states = hidden_states[:, additional_embeddings_len:]
-
-        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
-
-        if not return_dict:
-            return (predicted_image_embedding,)
-
-        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
-
-    def post_process_latents(self, prior_latents):
-        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
-        return prior_latents
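For context on the interface of the `PriorTransformer` deleted above: its constructor and `forward` signature match the upstream `diffusers` implementation it was vendored from. The sketch below builds a deliberately tiny, randomly initialized instance just to show the expected input shapes; the small sizes are illustrative assumptions, not the 768-dim, 20-layer defaults documented in the class docstring, and it assumes the upstream `diffusers` package is installed.

```py
# Minimal sketch, assuming the upstream `diffusers` PriorTransformer matches this vendored copy.
import torch
from diffusers import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=32,  # inner_dim = 2 * 32 = 64
    num_layers=2,
    embedding_dim=64,
    num_embeddings=4,       # number of text tokens in encoder_hidden_states
)

batch = 2
out = prior(
    hidden_states=torch.randn(batch, 64),             # current image-embedding estimate
    timestep=10,
    proj_embedding=torch.randn(batch, 64),            # pooled text embedding
    encoder_hidden_states=torch.randn(batch, 4, 64),  # per-token text embeddings
    attention_mask=torch.ones(batch, 4, dtype=torch.bool),
)
print(out.predicted_image_embedding.shape)  # torch.Size([2, 64])
```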