kxhit commited on
Commit
23aae87
1 Parent(s): ad4ee48
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +0 -43
  2. 3drecon/configs/neus_36.yaml +0 -26
  3. 3drecon/raymarching/__init__.py +0 -1
  4. 3drecon/raymarching/backend.py +0 -40
  5. 3drecon/raymarching/raymarching.py +0 -373
  6. 3drecon/raymarching/setup.py +0 -62
  7. 3drecon/raymarching/src/bindings.cpp +0 -19
  8. 3drecon/raymarching/src/raymarching.cu +0 -914
  9. 3drecon/raymarching/src/raymarching.h +0 -18
  10. 3drecon/renderer/agg_net.py +0 -83
  11. 3drecon/renderer/cost_reg_net.py +0 -95
  12. 3drecon/renderer/dummy_dataset.py +0 -40
  13. 3drecon/renderer/feature_net.py +0 -42
  14. 3drecon/renderer/neus_networks.py +0 -503
  15. 3drecon/renderer/ngp_renderer.py +0 -721
  16. 3drecon/renderer/renderer.py +0 -640
  17. 3drecon/run_NeuS.py +0 -32
  18. 3drecon/train_renderer.py +0 -188
  19. 3drecon/util.py +0 -54
  20. 4DoF/CN_encoder.py +0 -36
  21. 4DoF/dataset.py +0 -228
  22. 4DoF/diffusers/__init__.py +0 -281
  23. 4DoF/diffusers/commands/__init__.py +0 -27
  24. 4DoF/diffusers/commands/diffusers_cli.py +0 -41
  25. 4DoF/diffusers/commands/env.py +0 -84
  26. 4DoF/diffusers/configuration_utils.py +0 -664
  27. 4DoF/diffusers/dependency_versions_check.py +0 -47
  28. 4DoF/diffusers/dependency_versions_table.py +0 -44
  29. 4DoF/diffusers/experimental/__init__.py +0 -1
  30. 4DoF/diffusers/experimental/rl/__init__.py +0 -1
  31. 4DoF/diffusers/experimental/rl/value_guided_sampling.py +0 -152
  32. 4DoF/diffusers/image_processor.py +0 -366
  33. 4DoF/diffusers/loaders.py +0 -1492
  34. 4DoF/diffusers/models/__init__.py +0 -35
  35. 4DoF/diffusers/models/activations.py +0 -12
  36. 4DoF/diffusers/models/attention.py +0 -392
  37. 4DoF/diffusers/models/attention_flax.py +0 -446
  38. 4DoF/diffusers/models/attention_processor.py +0 -1714
  39. 4DoF/diffusers/models/autoencoder_kl.py +0 -411
  40. 4DoF/diffusers/models/controlnet.py +0 -705
  41. 4DoF/diffusers/models/controlnet_flax.py +0 -394
  42. 4DoF/diffusers/models/cross_attention.py +0 -94
  43. 4DoF/diffusers/models/dual_transformer_2d.py +0 -151
  44. 4DoF/diffusers/models/embeddings.py +0 -546
  45. 4DoF/diffusers/models/embeddings_flax.py +0 -95
  46. 4DoF/diffusers/models/modeling_flax_pytorch_utils.py +0 -118
  47. 4DoF/diffusers/models/modeling_flax_utils.py +0 -534
  48. 4DoF/diffusers/models/modeling_pytorch_flax_utils.py +0 -161
  49. 4DoF/diffusers/models/modeling_utils.py +0 -980
  50. 4DoF/diffusers/models/prior_transformer.py +0 -364
.gitattributes DELETED
@@ -1,43 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- logs/user_object/eschernet/output.gif filter=lfs diff=lfs merge=lfs -text
37
- logs/user_object/scene.glb filter=lfs diff=lfs merge=lfs -text
38
- 3drecon/ours_GSO_T1/NeuS/grandmother/mesh.ply filter=lfs diff=lfs merge=lfs -text
39
- 3drecon/ours_GSO_T1/NeuS/lion/mesh.ply filter=lfs diff=lfs merge=lfs -text
40
- gradio_demo/examples/bike/003.jpg filter=lfs diff=lfs merge=lfs -text
41
- gradio_demo/examples/bike/027.jpg filter=lfs diff=lfs merge=lfs -text
42
- gradio_demo/examples/bike/bike_0.jpg filter=lfs diff=lfs merge=lfs -text
43
- gradio_demo/examples/bike/bike_2.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3drecon/configs/neus_36.yaml DELETED
@@ -1,26 +0,0 @@
1
- model:
2
- base_lr: 5.0e-4
3
- target: renderer.renderer.RendererTrainer
4
- params:
5
- total_steps: 2000
6
- warm_up_steps: 100
7
- train_batch_num: 2560
8
- train_batch_fg_num: 512
9
- test_batch_num: 4096
10
- use_mask: true
11
- lambda_rgb_loss: 0.5
12
- lambda_mask_loss: 1.0
13
- lambda_eikonal_loss: 0.1
14
- use_warm_up: true
15
-
16
- data:
17
- target: renderer.dummy_dataset.DummyDataset
18
- params: {}
19
-
20
- callbacks:
21
- save_interval: 500
22
-
23
- trainer:
24
- val_check_interval: 500
25
- max_steps: 2000
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3drecon/raymarching/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .raymarching import *
 
 
3drecon/raymarching/backend.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- from torch.utils.cpp_extension import load
3
-
4
- _src_path = os.path.dirname(os.path.abspath(__file__))
5
-
6
- nvcc_flags = [
7
- '-O3', '-std=c++14',
8
- '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
- ]
10
-
11
- if os.name == "posix":
12
- c_flags = ['-O3', '-std=c++14']
13
- elif os.name == "nt":
14
- c_flags = ['/O2', '/std:c++17']
15
-
16
- # find cl.exe
17
- def find_cl_path():
18
- import glob
19
- for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
- paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
- if paths:
22
- return paths[0]
23
-
24
- # If cl.exe is not on path, try to find it.
25
- if os.system("where cl.exe >nul 2>nul") != 0:
26
- cl_path = find_cl_path()
27
- if cl_path is None:
28
- raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
- os.environ["PATH"] += ";" + cl_path
30
-
31
- _backend = load(name='_raymarching',
32
- extra_cflags=c_flags,
33
- extra_cuda_cflags=nvcc_flags,
34
- sources=[os.path.join(_src_path, 'src', f) for f in [
35
- 'raymarching.cu',
36
- 'bindings.cpp',
37
- ]],
38
- )
39
-
40
- __all__ = ['_backend']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3drecon/raymarching/raymarching.py DELETED
@@ -1,373 +0,0 @@
1
- import numpy as np
2
- import time
3
-
4
- import torch
5
- import torch.nn as nn
6
- from torch.autograd import Function
7
- from torch.cuda.amp import custom_bwd, custom_fwd
8
-
9
- try:
10
- import _raymarching as _backend
11
- except ImportError:
12
- from .backend import _backend
13
-
14
-
15
- # ----------------------------------------
16
- # utils
17
- # ----------------------------------------
18
-
19
- class _near_far_from_aabb(Function):
20
- @staticmethod
21
- @custom_fwd(cast_inputs=torch.float32)
22
- def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
23
- ''' near_far_from_aabb, CUDA implementation
24
- Calculate rays' intersection time (near and far) with aabb
25
- Args:
26
- rays_o: float, [N, 3]
27
- rays_d: float, [N, 3]
28
- aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
29
- min_near: float, scalar
30
- Returns:
31
- nears: float, [N]
32
- fars: float, [N]
33
- '''
34
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
35
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
36
-
37
- rays_o = rays_o.contiguous().view(-1, 3)
38
- rays_d = rays_d.contiguous().view(-1, 3)
39
-
40
- N = rays_o.shape[0] # num rays
41
-
42
- nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
43
- fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
44
-
45
- _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
46
-
47
- return nears, fars
48
-
49
- near_far_from_aabb = _near_far_from_aabb.apply
50
-
51
-
52
- class _sph_from_ray(Function):
53
- @staticmethod
54
- @custom_fwd(cast_inputs=torch.float32)
55
- def forward(ctx, rays_o, rays_d, radius):
56
- ''' sph_from_ray, CUDA implementation
57
- get spherical coordinate on the background sphere from rays.
58
- Assume rays_o are inside the Sphere(radius).
59
- Args:
60
- rays_o: [N, 3]
61
- rays_d: [N, 3]
62
- radius: scalar, float
63
- Return:
64
- coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
65
- '''
66
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
67
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
68
-
69
- rays_o = rays_o.contiguous().view(-1, 3)
70
- rays_d = rays_d.contiguous().view(-1, 3)
71
-
72
- N = rays_o.shape[0] # num rays
73
-
74
- coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
75
-
76
- _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
77
-
78
- return coords
79
-
80
- sph_from_ray = _sph_from_ray.apply
81
-
82
-
83
- class _morton3D(Function):
84
- @staticmethod
85
- def forward(ctx, coords):
86
- ''' morton3D, CUDA implementation
87
- Args:
88
- coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
89
- TODO: check if the coord range is valid! (current 128 is safe)
90
- Returns:
91
- indices: [N], int32, in [0, 128^3)
92
-
93
- '''
94
- if not coords.is_cuda: coords = coords.cuda()
95
-
96
- N = coords.shape[0]
97
-
98
- indices = torch.empty(N, dtype=torch.int32, device=coords.device)
99
-
100
- _backend.morton3D(coords.int(), N, indices)
101
-
102
- return indices
103
-
104
- morton3D = _morton3D.apply
105
-
106
- class _morton3D_invert(Function):
107
- @staticmethod
108
- def forward(ctx, indices):
109
- ''' morton3D_invert, CUDA implementation
110
- Args:
111
- indices: [N], int32, in [0, 128^3)
112
- Returns:
113
- coords: [N, 3], int32, in [0, 128)
114
-
115
- '''
116
- if not indices.is_cuda: indices = indices.cuda()
117
-
118
- N = indices.shape[0]
119
-
120
- coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
121
-
122
- _backend.morton3D_invert(indices.int(), N, coords)
123
-
124
- return coords
125
-
126
- morton3D_invert = _morton3D_invert.apply
127
-
128
-
129
- class _packbits(Function):
130
- @staticmethod
131
- @custom_fwd(cast_inputs=torch.float32)
132
- def forward(ctx, grid, thresh, bitfield=None):
133
- ''' packbits, CUDA implementation
134
- Pack up the density grid into a bit field to accelerate ray marching.
135
- Args:
136
- grid: float, [C, H * H * H], assume H % 2 == 0
137
- thresh: float, threshold
138
- Returns:
139
- bitfield: uint8, [C, H * H * H / 8]
140
- '''
141
- if not grid.is_cuda: grid = grid.cuda()
142
- grid = grid.contiguous()
143
-
144
- C = grid.shape[0]
145
- H3 = grid.shape[1]
146
- N = C * H3 // 8
147
-
148
- if bitfield is None:
149
- bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
150
-
151
- _backend.packbits(grid, N, thresh, bitfield)
152
-
153
- return bitfield
154
-
155
- packbits = _packbits.apply
156
-
157
- # ----------------------------------------
158
- # train functions
159
- # ----------------------------------------
160
-
161
- class _march_rays_train(Function):
162
- @staticmethod
163
- @custom_fwd(cast_inputs=torch.float32)
164
- def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
165
- ''' march rays to generate points (forward only)
166
- Args:
167
- rays_o/d: float, [N, 3]
168
- bound: float, scalar
169
- density_bitfield: uint8: [CHHH // 8]
170
- C: int
171
- H: int
172
- nears/fars: float, [N]
173
- step_counter: int32, (2), used to count the actual number of generated points.
174
- mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
175
- perturb: bool
176
- align: int, pad output so its size is dividable by align, set to -1 to disable.
177
- force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
178
- dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
179
- max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
180
- Returns:
181
- xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
182
- dirs: float, [M, 3], all generated points' view dirs.
183
- deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
184
- rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
185
- '''
186
-
187
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
188
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
189
- if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
190
-
191
- rays_o = rays_o.contiguous().view(-1, 3)
192
- rays_d = rays_d.contiguous().view(-1, 3)
193
- density_bitfield = density_bitfield.contiguous()
194
-
195
- N = rays_o.shape[0] # num rays
196
- M = N * max_steps # init max points number in total
197
-
198
- # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
199
- # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
200
- if not force_all_rays and mean_count > 0:
201
- if align > 0:
202
- mean_count += align - mean_count % align
203
- M = mean_count
204
-
205
- xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
206
- dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
207
- deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
208
- rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
209
-
210
- if step_counter is None:
211
- step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
212
-
213
- if perturb:
214
- noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
215
- else:
216
- noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
217
-
218
- _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
219
-
220
- #print(step_counter, M)
221
-
222
- # only used at the first (few) epochs.
223
- if force_all_rays or mean_count <= 0:
224
- m = step_counter[0].item() # D2H copy
225
- if align > 0:
226
- m += align - m % align
227
- xyzs = xyzs[:m]
228
- dirs = dirs[:m]
229
- deltas = deltas[:m]
230
-
231
- torch.cuda.empty_cache()
232
-
233
- return xyzs, dirs, deltas, rays
234
-
235
- march_rays_train = _march_rays_train.apply
236
-
237
-
238
- class _composite_rays_train(Function):
239
- @staticmethod
240
- @custom_fwd(cast_inputs=torch.float32)
241
- def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
242
- ''' composite rays' rgbs, according to the ray marching formula.
243
- Args:
244
- rgbs: float, [M, 3]
245
- sigmas: float, [M,]
246
- deltas: float, [M, 2]
247
- rays: int32, [N, 3]
248
- Returns:
249
- weights_sum: float, [N,], the alpha channel
250
- depth: float, [N, ], the Depth
251
- image: float, [N, 3], the RGB channel (after multiplying alpha!)
252
- '''
253
-
254
- sigmas = sigmas.contiguous()
255
- rgbs = rgbs.contiguous()
256
-
257
- M = sigmas.shape[0]
258
- N = rays.shape[0]
259
-
260
- weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
261
- depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
262
- image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
263
-
264
- _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
265
-
266
- ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
267
- ctx.dims = [M, N, T_thresh]
268
-
269
- return weights_sum, depth, image
270
-
271
- @staticmethod
272
- @custom_bwd
273
- def backward(ctx, grad_weights_sum, grad_depth, grad_image):
274
-
275
- # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
276
-
277
- grad_weights_sum = grad_weights_sum.contiguous()
278
- grad_image = grad_image.contiguous()
279
-
280
- sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
281
- M, N, T_thresh = ctx.dims
282
-
283
- grad_sigmas = torch.zeros_like(sigmas)
284
- grad_rgbs = torch.zeros_like(rgbs)
285
-
286
- _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
287
-
288
- return grad_sigmas, grad_rgbs, None, None, None
289
-
290
-
291
- composite_rays_train = _composite_rays_train.apply
292
-
293
- # ----------------------------------------
294
- # infer functions
295
- # ----------------------------------------
296
-
297
- class _march_rays(Function):
298
- @staticmethod
299
- @custom_fwd(cast_inputs=torch.float32)
300
- def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
301
- ''' march rays to generate points (forward only, for inference)
302
- Args:
303
- n_alive: int, number of alive rays
304
- n_step: int, how many steps we march
305
- rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
306
- rays_t: float, [N], the alive rays' time, we only use the first n_alive.
307
- rays_o/d: float, [N, 3]
308
- bound: float, scalar
309
- density_bitfield: uint8: [CHHH // 8]
310
- C: int
311
- H: int
312
- nears/fars: float, [N]
313
- align: int, pad output so its size is dividable by align, set to -1 to disable.
314
- perturb: bool/int, int > 0 is used as the random seed.
315
- dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
316
- max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
317
- Returns:
318
- xyzs: float, [n_alive * n_step, 3], all generated points' coords
319
- dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
320
- deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
321
- '''
322
-
323
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
324
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
325
-
326
- rays_o = rays_o.contiguous().view(-1, 3)
327
- rays_d = rays_d.contiguous().view(-1, 3)
328
-
329
- M = n_alive * n_step
330
-
331
- if align > 0:
332
- M += align - (M % align)
333
-
334
- xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
335
- dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
336
- deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
337
-
338
- if perturb:
339
- # torch.manual_seed(perturb) # test_gui uses spp index as seed
340
- noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
341
- else:
342
- noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
343
-
344
- _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
345
-
346
- return xyzs, dirs, deltas
347
-
348
- march_rays = _march_rays.apply
349
-
350
-
351
- class _composite_rays(Function):
352
- @staticmethod
353
- @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
354
- def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
355
- ''' composite rays' rgbs, according to the ray marching formula. (for inference)
356
- Args:
357
- n_alive: int, number of alive rays
358
- n_step: int, how many steps we march
359
- rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
360
- rays_t: float, [N], the alive rays' time
361
- sigmas: float, [n_alive * n_step,]
362
- rgbs: float, [n_alive * n_step, 3]
363
- deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
364
- In-place Outputs:
365
- weights_sum: float, [N,], the alpha channel
366
- depth: float, [N,], the depth value
367
- image: float, [N, 3], the RGB channel (after multiplying alpha!)
368
- '''
369
- _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
370
- return tuple()
371
-
372
-
373
- composite_rays = _composite_rays.apply
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3drecon/raymarching/setup.py DELETED
@@ -1,62 +0,0 @@
1
- import os
2
- from setuptools import setup
3
- from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
-
5
- _src_path = os.path.dirname(os.path.abspath(__file__))
6
-
7
- nvcc_flags = [
8
- '-O3', '-std=c++14',
9
- '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
- ]
11
-
12
- if os.name == "posix":
13
- c_flags = ['-O3', '-std=c++14']
14
- elif os.name == "nt":
15
- c_flags = ['/O2', '/std:c++17']
16
-
17
- # find cl.exe
18
- def find_cl_path():
19
- import glob
20
- for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
- paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
- if paths:
23
- return paths[0]
24
-
25
- # If cl.exe is not on path, try to find it.
26
- if os.system("where cl.exe >nul 2>nul") != 0:
27
- cl_path = find_cl_path()
28
- if cl_path is None:
29
- raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
- os.environ["PATH"] += ";" + cl_path
31
-
32
- '''
33
- Usage:
34
-
35
- python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
36
-
37
- python setup.py install # build extensions and install (copy) to PATH.
38
- pip install . # ditto but better (e.g., dependency & metadata handling)
39
-
40
- python setup.py develop # build extensions and install (symbolic) to PATH.
41
- pip install -e . # ditto but better (e.g., dependency & metadata handling)
42
-
43
- '''
44
- setup(
45
- name='raymarching', # package name, import this to use python API
46
- ext_modules=[
47
- CUDAExtension(
48
- name='_raymarching', # extension name, import this to use CUDA API
49
- sources=[os.path.join(_src_path, 'src', f) for f in [
50
- 'raymarching.cu',
51
- 'bindings.cpp',
52
- ]],
53
- extra_compile_args={
54
- 'cxx': c_flags,
55
- 'nvcc': nvcc_flags,
56
- }
57
- ),
58
- ],
59
- cmdclass={
60
- 'build_ext': BuildExtension,
61
- }
62
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3drecon/raymarching/src/bindings.cpp DELETED
@@ -1,19 +0,0 @@
1
- #include <torch/extension.h>
2
-
3
- #include "raymarching.h"
4
-
5
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
- // utils
7
- m.def("packbits", &packbits, "packbits (CUDA)");
8
- m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
9
- m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
10
- m.def("morton3D", &morton3D, "morton3D (CUDA)");
11
- m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
12
- // train
13
- m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
14
- m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
15
- m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
16
- // infer
17
- m.def("march_rays", &march_rays, "march rays (CUDA)");
18
- m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
19
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3drecon/raymarching/src/raymarching.cu DELETED
@@ -1,914 +0,0 @@
1
- #include <cuda.h>
2
- #include <cuda_fp16.h>
3
- #include <cuda_runtime.h>
4
-
5
- #include <ATen/cuda/CUDAContext.h>
6
- #include <torch/torch.h>
7
-
8
- #include <cstdio>
9
- #include <stdint.h>
10
- #include <stdexcept>
11
- #include <limits>
12
-
13
- #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
14
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
15
- #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
16
- #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
17
-
18
-
19
- inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
20
- inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
21
- inline constexpr __device__ float PI() { return 3.141592653589793f; }
22
- inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
23
-
24
-
25
- template <typename T>
26
- inline __host__ __device__ T div_round_up(T val, T divisor) {
27
- return (val + divisor - 1) / divisor;
28
- }
29
-
30
- inline __host__ __device__ float signf(const float x) {
31
- return copysignf(1.0, x);
32
- }
33
-
34
- inline __host__ __device__ float clamp(const float x, const float min, const float max) {
35
- return fminf(max, fmaxf(min, x));
36
- }
37
-
38
- inline __host__ __device__ void swapf(float& a, float& b) {
39
- float c = a; a = b; b = c;
40
- }
41
-
42
- inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
43
- const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
44
- int exponent;
45
- frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
46
- return fminf(max_cascade - 1, fmaxf(0, exponent));
47
- }
48
-
49
- inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
50
- const float mx = dt * H * 0.5;
51
- int exponent;
52
- frexpf(mx, &exponent);
53
- return fminf(max_cascade - 1, fmaxf(0, exponent));
54
- }
55
-
56
- inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
57
- {
58
- v = (v * 0x00010001u) & 0xFF0000FFu;
59
- v = (v * 0x00000101u) & 0x0F00F00Fu;
60
- v = (v * 0x00000011u) & 0xC30C30C3u;
61
- v = (v * 0x00000005u) & 0x49249249u;
62
- return v;
63
- }
64
-
65
- inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
66
- {
67
- uint32_t xx = __expand_bits(x);
68
- uint32_t yy = __expand_bits(y);
69
- uint32_t zz = __expand_bits(z);
70
- return xx | (yy << 1) | (zz << 2);
71
- }
72
-
73
- inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
74
- {
75
- x = x & 0x49249249;
76
- x = (x | (x >> 2)) & 0xc30c30c3;
77
- x = (x | (x >> 4)) & 0x0f00f00f;
78
- x = (x | (x >> 8)) & 0xff0000ff;
79
- x = (x | (x >> 16)) & 0x0000ffff;
80
- return x;
81
- }
82
-
83
-
84
- ////////////////////////////////////////////////////
85
- ///////////// utils /////////////
86
- ////////////////////////////////////////////////////
87
-
88
- // rays_o/d: [N, 3]
89
- // nears/fars: [N]
90
- // scalar_t should always be float in use.
91
- template <typename scalar_t>
92
- __global__ void kernel_near_far_from_aabb(
93
- const scalar_t * __restrict__ rays_o,
94
- const scalar_t * __restrict__ rays_d,
95
- const scalar_t * __restrict__ aabb,
96
- const uint32_t N,
97
- const float min_near,
98
- scalar_t * nears, scalar_t * fars
99
- ) {
100
- // parallel per ray
101
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
102
- if (n >= N) return;
103
-
104
- // locate
105
- rays_o += n * 3;
106
- rays_d += n * 3;
107
-
108
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
109
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
110
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
111
-
112
- // get near far (assume cube scene)
113
- float near = (aabb[0] - ox) * rdx;
114
- float far = (aabb[3] - ox) * rdx;
115
- if (near > far) swapf(near, far);
116
-
117
- float near_y = (aabb[1] - oy) * rdy;
118
- float far_y = (aabb[4] - oy) * rdy;
119
- if (near_y > far_y) swapf(near_y, far_y);
120
-
121
- if (near > far_y || near_y > far) {
122
- nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
123
- return;
124
- }
125
-
126
- if (near_y > near) near = near_y;
127
- if (far_y < far) far = far_y;
128
-
129
- float near_z = (aabb[2] - oz) * rdz;
130
- float far_z = (aabb[5] - oz) * rdz;
131
- if (near_z > far_z) swapf(near_z, far_z);
132
-
133
- if (near > far_z || near_z > far) {
134
- nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
135
- return;
136
- }
137
-
138
- if (near_z > near) near = near_z;
139
- if (far_z < far) far = far_z;
140
-
141
- if (near < min_near) near = min_near;
142
-
143
- nears[n] = near;
144
- fars[n] = far;
145
- }
146
-
147
-
148
- void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
149
-
150
- static constexpr uint32_t N_THREAD = 128;
151
-
152
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
153
- rays_o.scalar_type(), "near_far_from_aabb", ([&] {
154
- kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
155
- }));
156
- }
157
-
158
-
159
- // rays_o/d: [N, 3]
160
- // radius: float
161
- // coords: [N, 2]
162
- template <typename scalar_t>
163
- __global__ void kernel_sph_from_ray(
164
- const scalar_t * __restrict__ rays_o,
165
- const scalar_t * __restrict__ rays_d,
166
- const float radius,
167
- const uint32_t N,
168
- scalar_t * coords
169
- ) {
170
- // parallel per ray
171
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
172
- if (n >= N) return;
173
-
174
- // locate
175
- rays_o += n * 3;
176
- rays_d += n * 3;
177
- coords += n * 2;
178
-
179
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
180
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
181
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
182
-
183
- // solve t from || o + td || = radius
184
- const float A = dx * dx + dy * dy + dz * dz;
185
- const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
186
- const float C = ox * ox + oy * oy + oz * oz - radius * radius;
187
-
188
- const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
189
-
190
- // solve theta, phi (assume y is the up axis)
191
- const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
192
- const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
193
- const float phi = atan2(z, x); // [-PI, PI)
194
-
195
- // normalize to [-1, 1]
196
- coords[0] = 2 * theta * RPI() - 1;
197
- coords[1] = phi * RPI();
198
- }
199
-
200
-
201
- void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
202
-
203
- static constexpr uint32_t N_THREAD = 128;
204
-
205
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
206
- rays_o.scalar_type(), "sph_from_ray", ([&] {
207
- kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
208
- }));
209
- }
210
-
211
-
212
- // coords: int32, [N, 3]
213
- // indices: int32, [N]
214
- __global__ void kernel_morton3D(
215
- const int * __restrict__ coords,
216
- const uint32_t N,
217
- int * indices
218
- ) {
219
- // parallel
220
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
221
- if (n >= N) return;
222
-
223
- // locate
224
- coords += n * 3;
225
- indices[n] = __morton3D(coords[0], coords[1], coords[2]);
226
- }
227
-
228
-
229
- void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
230
- static constexpr uint32_t N_THREAD = 128;
231
- kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
232
- }
233
-
234
-
235
- // indices: int32, [N]
236
- // coords: int32, [N, 3]
237
- __global__ void kernel_morton3D_invert(
238
- const int * __restrict__ indices,
239
- const uint32_t N,
240
- int * coords
241
- ) {
242
- // parallel
243
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
244
- if (n >= N) return;
245
-
246
- // locate
247
- coords += n * 3;
248
-
249
- const int ind = indices[n];
250
-
251
- coords[0] = __morton3D_invert(ind >> 0);
252
- coords[1] = __morton3D_invert(ind >> 1);
253
- coords[2] = __morton3D_invert(ind >> 2);
254
- }
255
-
256
-
257
- void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
258
- static constexpr uint32_t N_THREAD = 128;
259
- kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
260
- }
261
-
262
-
263
- // grid: float, [C, H, H, H]
264
- // N: int, C * H * H * H / 8
265
- // density_thresh: float
266
- // bitfield: uint8, [N]
267
- template <typename scalar_t>
268
- __global__ void kernel_packbits(
269
- const scalar_t * __restrict__ grid,
270
- const uint32_t N,
271
- const float density_thresh,
272
- uint8_t * bitfield
273
- ) {
274
- // parallel per byte
275
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
276
- if (n >= N) return;
277
-
278
- // locate
279
- grid += n * 8;
280
-
281
- uint8_t bits = 0;
282
-
283
- #pragma unroll
284
- for (uint8_t i = 0; i < 8; i++) {
285
- bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
286
- }
287
-
288
- bitfield[n] = bits;
289
- }
290
-
291
-
292
- void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
293
-
294
- static constexpr uint32_t N_THREAD = 128;
295
-
296
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
297
- grid.scalar_type(), "packbits", ([&] {
298
- kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
299
- }));
300
- }
301
-
302
- ////////////////////////////////////////////////////
303
- ///////////// training /////////////
304
- ////////////////////////////////////////////////////
305
-
306
- // rays_o/d: [N, 3]
307
- // grid: [CHHH / 8]
308
- // xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
309
- // dirs: [M, 3]
310
- // rays: [N, 3], idx, offset, num_steps
311
- template <typename scalar_t>
312
- __global__ void kernel_march_rays_train(
313
- const scalar_t * __restrict__ rays_o,
314
- const scalar_t * __restrict__ rays_d,
315
- const uint8_t * __restrict__ grid,
316
- const float bound,
317
- const float dt_gamma, const uint32_t max_steps,
318
- const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
319
- const scalar_t* __restrict__ nears,
320
- const scalar_t* __restrict__ fars,
321
- scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
322
- int * rays,
323
- int * counter,
324
- const scalar_t* __restrict__ noises
325
- ) {
326
- // parallel per ray
327
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
328
- if (n >= N) return;
329
-
330
- // locate
331
- rays_o += n * 3;
332
- rays_d += n * 3;
333
-
334
- // ray marching
335
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
336
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
337
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
338
- const float rH = 1 / (float)H;
339
- const float H3 = H * H * H;
340
-
341
- const float near = nears[n];
342
- const float far = fars[n];
343
- const float noise = noises[n];
344
-
345
- const float dt_min = 2 * SQRT3() / max_steps;
346
- const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
347
-
348
- float t0 = near;
349
-
350
- // perturb
351
- t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
352
-
353
- // first pass: estimation of num_steps
354
- float t = t0;
355
- uint32_t num_steps = 0;
356
-
357
- //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
358
-
359
- while (t < far && num_steps < max_steps) {
360
- // current point
361
- const float x = clamp(ox + t * dx, -bound, bound);
362
- const float y = clamp(oy + t * dy, -bound, bound);
363
- const float z = clamp(oz + t * dz, -bound, bound);
364
-
365
- const float dt = clamp(t * dt_gamma, dt_min, dt_max);
366
-
367
- // get mip level
368
- const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
369
-
370
- const float mip_bound = fminf(scalbnf(1.0f, level), bound);
371
- const float mip_rbound = 1 / mip_bound;
372
-
373
- // convert to nearest grid position
374
- const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
375
- const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
376
- const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
377
-
378
- const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
379
- const bool occ = grid[index / 8] & (1 << (index % 8));
380
-
381
- // if occpuied, advance a small step, and write to output
382
- //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
383
-
384
- if (occ) {
385
- num_steps++;
386
- t += dt;
387
- // else, skip a large step (basically skip a voxel grid)
388
- } else {
389
- // calc distance to next voxel
390
- const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
391
- const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
392
- const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
393
-
394
- const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
395
- // step until next voxel
396
- do {
397
- t += clamp(t * dt_gamma, dt_min, dt_max);
398
- } while (t < tt);
399
- }
400
- }
401
-
402
- //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
403
-
404
- // second pass: really locate and write points & dirs
405
- uint32_t point_index = atomicAdd(counter, num_steps);
406
- uint32_t ray_index = atomicAdd(counter + 1, 1);
407
-
408
- //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
409
-
410
- // write rays
411
- rays[ray_index * 3] = n;
412
- rays[ray_index * 3 + 1] = point_index;
413
- rays[ray_index * 3 + 2] = num_steps;
414
-
415
- if (num_steps == 0) return;
416
- if (point_index + num_steps > M) return;
417
-
418
- xyzs += point_index * 3;
419
- dirs += point_index * 3;
420
- deltas += point_index * 2;
421
-
422
- t = t0;
423
- uint32_t step = 0;
424
-
425
- float last_t = t;
426
-
427
- while (t < far && step < num_steps) {
428
- // current point
429
- const float x = clamp(ox + t * dx, -bound, bound);
430
- const float y = clamp(oy + t * dy, -bound, bound);
431
- const float z = clamp(oz + t * dz, -bound, bound);
432
-
433
- const float dt = clamp(t * dt_gamma, dt_min, dt_max);
434
-
435
- // get mip level
436
- const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
437
-
438
- const float mip_bound = fminf(scalbnf(1.0f, level), bound);
439
- const float mip_rbound = 1 / mip_bound;
440
-
441
- // convert to nearest grid position
442
- const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
443
- const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
444
- const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
445
-
446
- // query grid
447
- const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
448
- const bool occ = grid[index / 8] & (1 << (index % 8));
449
-
450
- // if occpuied, advance a small step, and write to output
451
- if (occ) {
452
- // write step
453
- xyzs[0] = x;
454
- xyzs[1] = y;
455
- xyzs[2] = z;
456
- dirs[0] = dx;
457
- dirs[1] = dy;
458
- dirs[2] = dz;
459
- t += dt;
460
- deltas[0] = dt;
461
- deltas[1] = t - last_t; // used to calc depth
462
- last_t = t;
463
- xyzs += 3;
464
- dirs += 3;
465
- deltas += 2;
466
- step++;
467
- // else, skip a large step (basically skip a voxel grid)
468
- } else {
469
- // calc distance to next voxel
470
- const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
471
- const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
472
- const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
473
- const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
474
- // step until next voxel
475
- do {
476
- t += clamp(t * dt_gamma, dt_min, dt_max);
477
- } while (t < tt);
478
- }
479
- }
480
- }
481
-
482
- void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
483
-
484
- static constexpr uint32_t N_THREAD = 128;
485
-
486
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
487
- rays_o.scalar_type(), "march_rays_train", ([&] {
488
- kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
489
- }));
490
- }
491
-
492
-
493
- // sigmas: [M]
494
- // rgbs: [M, 3]
495
- // deltas: [M, 2]
496
- // rays: [N, 3], idx, offset, num_steps
497
- // weights_sum: [N], final pixel alpha
498
- // depth: [N,]
499
- // image: [N, 3]
500
- template <typename scalar_t>
501
- __global__ void kernel_composite_rays_train_forward(
502
- const scalar_t * __restrict__ sigmas,
503
- const scalar_t * __restrict__ rgbs,
504
- const scalar_t * __restrict__ deltas,
505
- const int * __restrict__ rays,
506
- const uint32_t M, const uint32_t N, const float T_thresh,
507
- scalar_t * weights_sum,
508
- scalar_t * depth,
509
- scalar_t * image
510
- ) {
511
- // parallel per ray
512
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
513
- if (n >= N) return;
514
-
515
- // locate
516
- uint32_t index = rays[n * 3];
517
- uint32_t offset = rays[n * 3 + 1];
518
- uint32_t num_steps = rays[n * 3 + 2];
519
-
520
- // empty ray, or ray that exceed max step count.
521
- if (num_steps == 0 || offset + num_steps > M) {
522
- weights_sum[index] = 0;
523
- depth[index] = 0;
524
- image[index * 3] = 0;
525
- image[index * 3 + 1] = 0;
526
- image[index * 3 + 2] = 0;
527
- return;
528
- }
529
-
530
- sigmas += offset;
531
- rgbs += offset * 3;
532
- deltas += offset * 2;
533
-
534
- // accumulate
535
- uint32_t step = 0;
536
-
537
- scalar_t T = 1.0f;
538
- scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
539
-
540
- while (step < num_steps) {
541
-
542
- const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
543
- const scalar_t weight = alpha * T;
544
-
545
- r += weight * rgbs[0];
546
- g += weight * rgbs[1];
547
- b += weight * rgbs[2];
548
-
549
- t += deltas[1]; // real delta
550
- d += weight * t;
551
-
552
- ws += weight;
553
-
554
- T *= 1.0f - alpha;
555
-
556
- // minimal remained transmittence
557
- if (T < T_thresh) break;
558
-
559
- //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
560
-
561
- // locate
562
- sigmas++;
563
- rgbs += 3;
564
- deltas += 2;
565
-
566
- step++;
567
- }
568
-
569
- //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
570
-
571
- // write
572
- weights_sum[index] = ws; // weights_sum
573
- depth[index] = d;
574
- image[index * 3] = r;
575
- image[index * 3 + 1] = g;
576
- image[index * 3 + 2] = b;
577
- }
578
-
579
-
580
- void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
581
-
582
- static constexpr uint32_t N_THREAD = 128;
583
-
584
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
585
- sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
586
- kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
587
- }));
588
- }
589
-
590
-
591
- // grad_weights_sum: [N,]
592
- // grad: [N, 3]
593
- // sigmas: [M]
594
- // rgbs: [M, 3]
595
- // deltas: [M, 2]
596
- // rays: [N, 3], idx, offset, num_steps
597
- // weights_sum: [N,], weights_sum here
598
- // image: [N, 3]
599
- // grad_sigmas: [M]
600
- // grad_rgbs: [M, 3]
601
- template <typename scalar_t>
602
- __global__ void kernel_composite_rays_train_backward(
603
- const scalar_t * __restrict__ grad_weights_sum,
604
- const scalar_t * __restrict__ grad_image,
605
- const scalar_t * __restrict__ sigmas,
606
- const scalar_t * __restrict__ rgbs,
607
- const scalar_t * __restrict__ deltas,
608
- const int * __restrict__ rays,
609
- const scalar_t * __restrict__ weights_sum,
610
- const scalar_t * __restrict__ image,
611
- const uint32_t M, const uint32_t N, const float T_thresh,
612
- scalar_t * grad_sigmas,
613
- scalar_t * grad_rgbs
614
- ) {
615
- // parallel per ray
616
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
617
- if (n >= N) return;
618
-
619
- // locate
620
- uint32_t index = rays[n * 3];
621
- uint32_t offset = rays[n * 3 + 1];
622
- uint32_t num_steps = rays[n * 3 + 2];
623
-
624
- if (num_steps == 0 || offset + num_steps > M) return;
625
-
626
- grad_weights_sum += index;
627
- grad_image += index * 3;
628
- weights_sum += index;
629
- image += index * 3;
630
- sigmas += offset;
631
- rgbs += offset * 3;
632
- deltas += offset * 2;
633
- grad_sigmas += offset;
634
- grad_rgbs += offset * 3;
635
-
636
- // accumulate
637
- uint32_t step = 0;
638
-
639
- scalar_t T = 1.0f;
640
- const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
641
- scalar_t r = 0, g = 0, b = 0, ws = 0;
642
-
643
- while (step < num_steps) {
644
-
645
- const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
646
- const scalar_t weight = alpha * T;
647
-
648
- r += weight * rgbs[0];
649
- g += weight * rgbs[1];
650
- b += weight * rgbs[2];
651
- ws += weight;
652
-
653
- T *= 1.0f - alpha;
654
-
655
- // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
656
- // write grad_rgbs
657
- grad_rgbs[0] = grad_image[0] * weight;
658
- grad_rgbs[1] = grad_image[1] * weight;
659
- grad_rgbs[2] = grad_image[2] * weight;
660
-
661
- // write grad_sigmas
662
- grad_sigmas[0] = deltas[0] * (
663
- grad_image[0] * (T * rgbs[0] - (r_final - r)) +
664
- grad_image[1] * (T * rgbs[1] - (g_final - g)) +
665
- grad_image[2] * (T * rgbs[2] - (b_final - b)) +
666
- grad_weights_sum[0] * (1 - ws_final)
667
- );
668
-
669
- //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
670
- // minimal remained transmittence
671
- if (T < T_thresh) break;
672
-
673
- // locate
674
- sigmas++;
675
- rgbs += 3;
676
- deltas += 2;
677
- grad_sigmas++;
678
- grad_rgbs += 3;
679
-
680
- step++;
681
- }
682
- }
683
-
684
-
685
- void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
686
-
687
- static constexpr uint32_t N_THREAD = 128;
688
-
689
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
690
- grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
691
- kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
692
- }));
693
- }
694
-
695
-
696
- ////////////////////////////////////////////////////
697
- ///////////// infernce /////////////
698
- ////////////////////////////////////////////////////
699
-
700
- template <typename scalar_t>
701
- __global__ void kernel_march_rays(
702
- const uint32_t n_alive,
703
- const uint32_t n_step,
704
- const int* __restrict__ rays_alive,
705
- const scalar_t* __restrict__ rays_t,
706
- const scalar_t* __restrict__ rays_o,
707
- const scalar_t* __restrict__ rays_d,
708
- const float bound,
709
- const float dt_gamma, const uint32_t max_steps,
710
- const uint32_t C, const uint32_t H,
711
- const uint8_t * __restrict__ grid,
712
- const scalar_t* __restrict__ nears,
713
- const scalar_t* __restrict__ fars,
714
- scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
715
- const scalar_t* __restrict__ noises
716
- ) {
717
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
718
- if (n >= n_alive) return;
719
-
720
- const int index = rays_alive[n]; // ray id
721
- const float noise = noises[n];
722
-
723
- // locate
724
- rays_o += index * 3;
725
- rays_d += index * 3;
726
- xyzs += n * n_step * 3;
727
- dirs += n * n_step * 3;
728
- deltas += n * n_step * 2;
729
-
730
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
731
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
732
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
733
- const float rH = 1 / (float)H;
734
- const float H3 = H * H * H;
735
-
736
- float t = rays_t[index]; // current ray's t
737
- const float near = nears[index], far = fars[index];
738
-
739
- const float dt_min = 2 * SQRT3() / max_steps;
740
- const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
741
-
742
- // march for n_step steps, record points
743
- uint32_t step = 0;
744
-
745
- // introduce some randomness
746
- t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
747
-
748
- float last_t = t;
749
-
750
- while (t < far && step < n_step) {
751
- // current point
752
- const float x = clamp(ox + t * dx, -bound, bound);
753
- const float y = clamp(oy + t * dy, -bound, bound);
754
- const float z = clamp(oz + t * dz, -bound, bound);
755
-
756
- const float dt = clamp(t * dt_gamma, dt_min, dt_max);
757
-
758
- // get mip level
759
- const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
760
-
761
- const float mip_bound = fminf(scalbnf(1, level), bound);
762
- const float mip_rbound = 1 / mip_bound;
763
-
764
- // convert to nearest grid position
765
- const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
766
- const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
767
- const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
768
-
769
- const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
770
- const bool occ = grid[index / 8] & (1 << (index % 8));
771
-
772
- // if occpuied, advance a small step, and write to output
773
- if (occ) {
774
- // write step
775
- xyzs[0] = x;
776
- xyzs[1] = y;
777
- xyzs[2] = z;
778
- dirs[0] = dx;
779
- dirs[1] = dy;
780
- dirs[2] = dz;
781
- // calc dt
782
- t += dt;
783
- deltas[0] = dt;
784
- deltas[1] = t - last_t; // used to calc depth
785
- last_t = t;
786
- // step
787
- xyzs += 3;
788
- dirs += 3;
789
- deltas += 2;
790
- step++;
791
-
792
- // else, skip a large step (basically skip a voxel grid)
793
- } else {
794
- // calc distance to next voxel
795
- const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
796
- const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
797
- const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
798
- const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
799
- // step until next voxel
800
- do {
801
- t += clamp(t * dt_gamma, dt_min, dt_max);
802
- } while (t < tt);
803
- }
804
- }
805
- }
806
-
807
-
808
- void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
809
- static constexpr uint32_t N_THREAD = 128;
810
-
811
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
812
- rays_o.scalar_type(), "march_rays", ([&] {
813
- kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
814
- }));
815
- }
816
-
817
-
818
- template <typename scalar_t>
819
- __global__ void kernel_composite_rays(
820
- const uint32_t n_alive,
821
- const uint32_t n_step,
822
- const float T_thresh,
823
- int* rays_alive,
824
- scalar_t* rays_t,
825
- const scalar_t* __restrict__ sigmas,
826
- const scalar_t* __restrict__ rgbs,
827
- const scalar_t* __restrict__ deltas,
828
- scalar_t* weights_sum, scalar_t* depth, scalar_t* image
829
- ) {
830
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
831
- if (n >= n_alive) return;
832
-
833
- const int index = rays_alive[n]; // ray id
834
-
835
- // locate
836
- sigmas += n * n_step;
837
- rgbs += n * n_step * 3;
838
- deltas += n * n_step * 2;
839
-
840
- rays_t += index;
841
- weights_sum += index;
842
- depth += index;
843
- image += index * 3;
844
-
845
- scalar_t t = rays_t[0]; // current ray's t
846
-
847
- scalar_t weight_sum = weights_sum[0];
848
- scalar_t d = depth[0];
849
- scalar_t r = image[0];
850
- scalar_t g = image[1];
851
- scalar_t b = image[2];
852
-
853
- // accumulate
854
- uint32_t step = 0;
855
- while (step < n_step) {
856
-
857
- // ray is terminated if delta == 0
858
- if (deltas[0] == 0) break;
859
-
860
- const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
861
-
862
- /*
863
- T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
864
- w_i = alpha_i * T_i
865
- -->
866
- T_i = 1 - \sum_{j=0}^{i-1} w_j
867
- */
868
- const scalar_t T = 1 - weight_sum;
869
- const scalar_t weight = alpha * T;
870
- weight_sum += weight;
871
-
872
- t += deltas[1]; // real delta
873
- d += weight * t;
874
- r += weight * rgbs[0];
875
- g += weight * rgbs[1];
876
- b += weight * rgbs[2];
877
-
878
- //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
879
-
880
- // ray is terminated if T is too small
881
- // use a larger bound to further accelerate inference
882
- if (T < T_thresh) break;
883
-
884
- // locate
885
- sigmas++;
886
- rgbs += 3;
887
- deltas += 2;
888
- step++;
889
- }
890
-
891
- //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
892
-
893
- // rays_alive = -1 means ray is terminated early.
894
- if (step < n_step) {
895
- rays_alive[n] = -1;
896
- } else {
897
- rays_t[0] = t;
898
- }
899
-
900
- weights_sum[0] = weight_sum; // this is the thing I needed!
901
- depth[0] = d;
902
- image[0] = r;
903
- image[1] = g;
904
- image[2] = b;
905
- }
906
-
907
-
908
- void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
909
- static constexpr uint32_t N_THREAD = 128;
910
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
911
- image.scalar_type(), "composite_rays", ([&] {
912
- kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
913
- }));
914
- }
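Editor's note: the per-ray accumulation performed by kernel_composite_rays above is ordinary front-to-back alpha compositing. The sketch below is an illustrative PyTorch transcription for a single ray, not the compiled kernel's API; the function name, tensor layout, and zero-initialised accumulators are assumptions (the kernel resumes from previously accumulated values when called repeatedly).

import torch

def composite_ray_reference(sigmas, rgbs, deltas, T_thresh=1e-4):
    # sigmas: [S], rgbs: [S, 3], deltas: [S, 2]; deltas[:, 0] is the marching
    # step dt used for opacity, deltas[:, 1] the spacing used to advance depth.
    weight_sum, depth, t = 0.0, 0.0, 0.0
    color = torch.zeros(3)
    for s in range(sigmas.shape[0]):
        if deltas[s, 0] == 0:                      # ray already terminated upstream
            break
        alpha = 1.0 - torch.exp(-sigmas[s] * deltas[s, 0])
        T = 1.0 - weight_sum                       # remaining transmittance
        w = alpha * T
        weight_sum = weight_sum + w
        t = t + deltas[s, 1]
        depth = depth + w * t
        color = color + w * rgbs[s]
        if T < T_thresh:                           # early termination, as in the kernel
            break
    return weight_sum, depth, color

# e.g. composite_ray_reference(torch.rand(8), torch.rand(8, 3), torch.full((8, 2), 0.01))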
 
 
3drecon/raymarching/src/raymarching.h DELETED
@@ -1,18 +0,0 @@
1
- #pragma once
2
-
3
- #include <stdint.h>
4
- #include <torch/torch.h>
5
-
6
-
7
- void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8
- void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9
- void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10
- void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11
- void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12
-
13
- void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
14
- void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
15
- void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
16
-
17
- void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
18
- void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
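Editor's note: the header above only declares the CUDA entry points. As a rough reference for what near_far_from_aabb computes, here is a small PyTorch sketch of the standard slab test that intersects each ray with an axis-aligned box and clamps the entry distance to min_near. The function name, the handling of rays that miss the box, and the tensor layout are assumptions for illustration, not a description of the extension's exact corner-case behaviour.

import torch

def near_far_from_aabb_reference(rays_o, rays_d, aabb, min_near=0.2):
    # rays_o, rays_d: [N, 3]; aabb: [6] = (xmin, ymin, zmin, xmax, ymax, zmax)
    eps = 1e-15
    d = torch.where(rays_d.abs() < eps, torch.full_like(rays_d, eps), rays_d)
    t0 = (aabb[:3] - rays_o) / d           # distances to the three "min" planes
    t1 = (aabb[3:] - rays_o) / d           # distances to the three "max" planes
    tmin = torch.minimum(t0, t1).amax(-1)  # latest entry across the three slabs
    tmax = torch.maximum(t0, t1).amin(-1)  # earliest exit
    nears = tmin.clamp(min=min_near)
    fars = torch.maximum(tmax, nears)      # rays that miss get an empty interval
    return nears, fars

# e.g. near_far_from_aabb_reference(torch.zeros(4, 3), torch.randn(4, 3),
#                                   torch.tensor([-1., -1., -1., 1., 1., 1.]))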
 
 
3drecon/renderer/agg_net.py DELETED
@@ -1,83 +0,0 @@
1
- import torch.nn.functional as F
2
- import torch.nn as nn
3
- import torch
4
-
5
- def weights_init(m):
6
- if isinstance(m, nn.Linear):
7
- nn.init.kaiming_normal_(m.weight.data)
8
- if m.bias is not None:
9
- nn.init.zeros_(m.bias.data)
10
-
11
- class NeRF(nn.Module):
12
- def __init__(self, vol_n=8+8, feat_ch=8+16+32+3, hid_n=64):
13
- super(NeRF, self).__init__()
14
- self.hid_n = hid_n
15
- self.agg = Agg(feat_ch)
16
- self.lr0 = nn.Sequential(nn.Linear(vol_n+16, hid_n), nn.ReLU())
17
- self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
18
- self.color = nn.Sequential(
19
- nn.Linear(16+vol_n+feat_ch+hid_n+4, hid_n), # agg_feats+vox_feat+img_feat+lr0_feats+dir
20
- nn.ReLU(),
21
- nn.Linear(hid_n, 1)
22
- )
23
- self.lr0.apply(weights_init)
24
- self.sigma.apply(weights_init)
25
- self.color.apply(weights_init)
26
-
27
- def forward(self, vox_feat, img_feat_rgb_dir, source_img_mask):
28
- # assert torch.sum(torch.sum(source_img_mask,1)<2)==0
29
- b, d, n, _ = img_feat_rgb_dir.shape # b,d,n,f=8+16+32+3+4
30
- agg_feat = self.agg(img_feat_rgb_dir, source_img_mask) # b,d,f=16
31
- x = self.lr0(torch.cat((vox_feat, agg_feat), dim=-1)) # b,d,f=64
32
- sigma = self.sigma(x) # b,d,1
33
-
34
- x = torch.cat((x, vox_feat, agg_feat), dim=-1) # b,d,f=16+16+64
35
- x = x.view(b, d, 1, x.shape[-1]).repeat(1, 1, n, 1)
36
- x = torch.cat((x, img_feat_rgb_dir), dim=-1)
37
- logits = self.color(x)
38
- source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
39
- logits[source_img_mask_] = -1e7
40
- color_weight = F.softmax(logits, dim=-2)
41
- color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2)
42
- return color, sigma
43
-
44
- class Agg(nn.Module):
45
- def __init__(self, feat_ch):
46
- super(Agg, self).__init__()
47
- self.feat_ch = feat_ch
48
- self.view_fc = nn.Sequential(nn.Linear(4, feat_ch), nn.ReLU())
49
- self.view_fc.apply(weights_init)
50
- self.global_fc = nn.Sequential(nn.Linear(feat_ch*3, 32), nn.ReLU())
51
-
52
- self.agg_w_fc = nn.Linear(32, 1)
53
- self.fc = nn.Linear(32, 16)
54
- self.global_fc.apply(weights_init)
55
- self.agg_w_fc.apply(weights_init)
56
- self.fc.apply(weights_init)
57
-
58
- def masked_mean_var(self, img_feat_rgb, source_img_mask):
59
- # img_feat_rgb: b,d,n,f source_img_mask: b,n
60
- b, n = source_img_mask.shape
61
- source_img_mask = source_img_mask.view(b, 1, n, 1)
62
- mean = torch.sum(source_img_mask * img_feat_rgb, dim=-2)/ (torch.sum(source_img_mask, dim=-2) + 1e-5)
63
- var = torch.sum((img_feat_rgb - mean.unsqueeze(-2)) ** 2 * source_img_mask, dim=-2) / (torch.sum(source_img_mask, dim=-2) + 1e-5)
64
- return mean, var
65
-
66
- def forward(self, img_feat_rgb_dir, source_img_mask):
67
- # img_feat_rgb_dir b,d,n,f
68
- b, d, n, _ = img_feat_rgb_dir.shape
69
- view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) # b,d,n,f-4
70
- img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat
71
-
72
- mean_feat, var_feat = self.masked_mean_var(img_feat_rgb, source_img_mask)
73
- var_feat = var_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
74
- avg_feat = mean_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
75
-
76
- feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) # b,d,n,f
77
- global_feat = self.global_fc(feat) # b,d,n,f
78
- logits = self.agg_w_fc(global_feat) # b,d,n,1
79
- source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
80
- logits[source_img_mask_] = -1e7
81
- agg_w = F.softmax(logits, dim=-2)
82
- im_feat = (global_feat * agg_w).sum(dim=-2)
83
- return self.fc(im_feat)
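Editor's note: a quick shape check for the aggregation network above, assuming the default channel layout noted in its comments (8+16+32+3 per-view image features plus a 4-d direction encoding) and assuming 3drecon/ is on PYTHONPATH so the module imports as renderer.agg_net.

import torch
from renderer.agg_net import NeRF               # assumed import path within 3drecon/

b, d, n = 2, 64, 4                              # batch, depth samples, source views
net = NeRF(vol_n=16, feat_ch=8 + 16 + 32 + 3)   # matches the defaults above
vox_feat = torch.randn(b, d, 16)
img_feat_rgb_dir = torch.randn(b, d, n, 8 + 16 + 32 + 3 + 4)
source_img_mask = torch.ones(b, n)              # every source view valid
color, sigma = net(vox_feat, img_feat_rgb_dir, source_img_mask)
print(color.shape, sigma.shape)                 # torch.Size([2, 64, 3]) torch.Size([2, 64, 1])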
 
 
3drecon/renderer/cost_reg_net.py DELETED
@@ -1,95 +0,0 @@
1
- import torch.nn as nn
2
-
3
- class ConvBnReLU3D(nn.Module):
4
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm3d):
5
- super(ConvBnReLU3D, self).__init__()
6
- self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
7
- self.bn = norm_act(out_channels)
8
- self.relu = nn.ReLU(inplace=True)
9
-
10
- def forward(self, x):
11
- return self.relu(self.bn(self.conv(x)))
12
-
13
- class CostRegNet(nn.Module):
14
- def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
15
- super(CostRegNet, self).__init__()
16
- self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
17
-
18
- self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
19
- self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
20
-
21
- self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
22
- self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
23
-
24
- self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
25
- self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
26
-
27
- self.conv7 = nn.Sequential(
28
- nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, stride=2, bias=False),
29
- norm_act(32)
30
- )
31
-
32
- self.conv9 = nn.Sequential(
33
- nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, stride=2, bias=False),
34
- norm_act(16)
35
- )
36
-
37
- self.conv11 = nn.Sequential(
38
- nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,stride=2, bias=False),
39
- norm_act(8)
40
- )
41
- self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
42
- self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
43
-
44
- def forward(self, x):
45
- conv0 = self.conv0(x)
46
- conv2 = self.conv2(self.conv1(conv0))
47
- conv4 = self.conv4(self.conv3(conv2))
48
- x = self.conv6(self.conv5(conv4))
49
- x = conv4 + self.conv7(x)
50
- del conv4
51
- x = conv2 + self.conv9(x)
52
- del conv2
53
- x = conv0 + self.conv11(x)
54
- del conv0
55
- feat = self.feat_conv(x)
56
- depth = self.depth_conv(x)
57
- return feat, depth
58
-
59
-
60
- class MinCostRegNet(nn.Module):
61
- def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
62
- super(MinCostRegNet, self).__init__()
63
- self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
64
-
65
- self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
66
- self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
67
-
68
- self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
69
- self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
70
-
71
- self.conv9 = nn.Sequential(
72
- nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
73
- stride=2, bias=False),
74
- norm_act(16))
75
-
76
- self.conv11 = nn.Sequential(
77
- nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
78
- stride=2, bias=False),
79
- norm_act(8))
80
-
81
- self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
82
- self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
83
-
84
- def forward(self, x):
85
- conv0 = self.conv0(x)
86
- conv2 = self.conv2(self.conv1(conv0))
87
- conv4 = self.conv4(self.conv3(conv2))
88
- x = conv4
89
- x = conv2 + self.conv9(x)
90
- del conv2
91
- x = conv0 + self.conv11(x)
92
- del conv0
93
- feat = self.feat_conv(x)
94
- depth = self.depth_conv(x)
95
- return feat, depth
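Editor's note: both cost-regularisation networks above are small 3D U-Nets over a feature volume. CostRegNet downsamples three times, so its input needs spatial sizes divisible by 8, while MinCostRegNet only downsamples twice (divisible by 4); both return an 8-channel feature volume and a 1-channel volume at the input resolution. A minimal shape check, with the import path assumed as in the sketch above.

import torch
from renderer.cost_reg_net import CostRegNet, MinCostRegNet  # assumed import path

x = torch.randn(1, 8, 32, 32, 32)        # B, C, D, H, W with D/H/W divisible by 8
feat, depth = CostRegNet(8)(x)           # feat: (1, 8, 32, 32, 32), depth: (1, 1, 32, 32, 32)
feat_m, depth_m = MinCostRegNet(8)(x)    # same output shapes, shallower decoder
print(feat.shape, depth.shape, feat_m.shape, depth_m.shape)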
 
 
3drecon/renderer/dummy_dataset.py DELETED
@@ -1,40 +0,0 @@
1
- import pytorch_lightning as pl
2
- from torch.utils.data import Dataset
3
- import webdataset as wds
4
- from torch.utils.data.distributed import DistributedSampler
5
- class DummyDataset(pl.LightningDataModule):
6
- def __init__(self,seed):
7
- super().__init__()
8
-
9
- def setup(self, stage):
10
- if stage in ['fit']:
11
- self.train_dataset = DummyData(True)
12
- self.val_dataset = DummyData(False)
13
- else:
14
- raise NotImplementedError
15
-
16
- def train_dataloader(self):
17
- return wds.WebLoader(self.train_dataset, batch_size=1, num_workers=0, shuffle=False)
18
-
19
- def val_dataloader(self):
20
- return wds.WebLoader(self.val_dataset, batch_size=1, num_workers=0, shuffle=False)
21
-
22
- def test_dataloader(self):
23
- return wds.WebLoader(DummyData(False))
24
-
25
- class DummyData(Dataset):
26
- def __init__(self,is_train):
27
- self.is_train=is_train
28
-
29
- def __len__(self):
30
- if self.is_train:
31
- return 99999999
32
- else:
33
- return 1
34
-
35
- def __getitem__(self, index):
36
- return {}
37
-
38
-
39
-
40
-
 
 
3drecon/renderer/feature_net.py DELETED
@@ -1,42 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn.functional as F
3
-
4
- class ConvBnReLU(nn.Module):
5
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm2d):
6
- super(ConvBnReLU, self).__init__()
7
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
8
- self.bn = norm_act(out_channels)
9
- self.relu = nn.ReLU(inplace=True)
10
-
11
- def forward(self, x):
12
- return self.relu(self.bn(self.conv(x)))
13
-
14
- class FeatureNet(nn.Module):
15
- def __init__(self, norm_act=nn.BatchNorm2d):
16
- super(FeatureNet, self).__init__()
17
- self.conv0 = nn.Sequential(ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
18
- self.conv1 = nn.Sequential(ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
19
- self.conv2 = nn.Sequential(ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))
20
-
21
- self.toplayer = nn.Conv2d(32, 32, 1)
22
- self.lat1 = nn.Conv2d(16, 32, 1)
23
- self.lat0 = nn.Conv2d(8, 32, 1)
24
-
25
- self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
26
- self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
27
-
28
- def _upsample_add(self, x, y):
29
- return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y
30
-
31
- def forward(self, x):
32
- conv0 = self.conv0(x)
33
- conv1 = self.conv1(conv0)
34
- conv2 = self.conv2(conv1)
35
- feat2 = self.toplayer(conv2)
36
- feat1 = self._upsample_add(feat2, self.lat1(conv1))
37
- feat0 = self._upsample_add(feat1, self.lat0(conv0))
38
- feat1 = self.smooth1(feat1)
39
- feat0 = self.smooth0(feat0)
40
- return feat2, feat1, feat0
41
-
42
-
 
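Editor's note: FeatureNet above is a small feature-pyramid network. It returns three levels with 32, 16 and 8 channels at 1/4, 1/2 and full resolution, so the input height and width should be divisible by 4. A minimal shape check (import path assumed as above).

import torch
from renderer.feature_net import FeatureNet  # assumed import path within 3drecon/

img = torch.randn(1, 3, 64, 64)
feat2, feat1, feat0 = FeatureNet()(img)
print(feat2.shape)   # torch.Size([1, 32, 16, 16])  coarsest level
print(feat1.shape)   # torch.Size([1, 16, 32, 32])
print(feat0.shape)   # torch.Size([1, 8, 64, 64])   finest level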
 
3drecon/renderer/neus_networks.py DELETED
@@ -1,503 +0,0 @@
1
- import math
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- import tinycudann as tcnn
8
-
9
- # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
10
- class Embedder:
11
- def __init__(self, **kwargs):
12
- self.kwargs = kwargs
13
- self.create_embedding_fn()
14
-
15
- def create_embedding_fn(self):
16
- embed_fns = []
17
- d = self.kwargs['input_dims']
18
- out_dim = 0
19
- if self.kwargs['include_input']:
20
- embed_fns.append(lambda x: x)
21
- out_dim += d
22
-
23
- max_freq = self.kwargs['max_freq_log2']
24
- N_freqs = self.kwargs['num_freqs']
25
-
26
- if self.kwargs['log_sampling']:
27
- freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
28
- else:
29
- freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)
30
-
31
- for freq in freq_bands:
32
- for p_fn in self.kwargs['periodic_fns']:
33
- embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
34
- out_dim += d
35
-
36
- self.embed_fns = embed_fns
37
- self.out_dim = out_dim
38
-
39
- def embed(self, inputs):
40
- return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
41
-
42
-
43
- def get_embedder(multires, input_dims=3):
44
- embed_kwargs = {
45
- 'include_input': True,
46
- 'input_dims': input_dims,
47
- 'max_freq_log2': multires - 1,
48
- 'num_freqs': multires,
49
- 'log_sampling': True,
50
- 'periodic_fns': [torch.sin, torch.cos],
51
- }
52
-
53
- embedder_obj = Embedder(**embed_kwargs)
54
-
55
- def embed(x, eo=embedder_obj): return eo.embed(x)
56
-
57
- return embed, embedder_obj.out_dim
58
-
59
-
60
- class SDFNetwork(nn.Module):
61
- def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
62
- scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
63
- super(SDFNetwork, self).__init__()
64
-
65
- dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
66
-
67
- self.embed_fn_fine = None
68
-
69
- if multires > 0:
70
- embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
71
- self.embed_fn_fine = embed_fn
72
- dims[0] = input_ch
73
-
74
- self.num_layers = len(dims)
75
- self.skip_in = skip_in
76
- self.scale = scale
77
-
78
- for l in range(0, self.num_layers - 1):
79
- if l + 1 in self.skip_in:
80
- out_dim = dims[l + 1] - dims[0]
81
- else:
82
- out_dim = dims[l + 1]
83
-
84
- lin = nn.Linear(dims[l], out_dim)
85
-
86
- if geometric_init:
87
- if l == self.num_layers - 2:
88
- if not inside_outside:
89
- torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
90
- torch.nn.init.constant_(lin.bias, -bias)
91
- else:
92
- torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
93
- torch.nn.init.constant_(lin.bias, bias)
94
- elif multires > 0 and l == 0:
95
- torch.nn.init.constant_(lin.bias, 0.0)
96
- torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
97
- torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
98
- elif multires > 0 and l in self.skip_in:
99
- torch.nn.init.constant_(lin.bias, 0.0)
100
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
101
- torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
102
- else:
103
- torch.nn.init.constant_(lin.bias, 0.0)
104
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
105
-
106
- if weight_norm:
107
- lin = nn.utils.weight_norm(lin)
108
-
109
- setattr(self, "lin" + str(l), lin)
110
-
111
- self.activation = nn.Softplus(beta=100)
112
-
113
- def forward(self, inputs):
114
- inputs = inputs * self.scale
115
- if self.embed_fn_fine is not None:
116
- inputs = self.embed_fn_fine(inputs)
117
-
118
- x = inputs
119
- for l in range(0, self.num_layers - 1):
120
- lin = getattr(self, "lin" + str(l))
121
-
122
- if l in self.skip_in:
123
- x = torch.cat([x, inputs], -1) / np.sqrt(2)
124
-
125
- x = lin(x)
126
-
127
- if l < self.num_layers - 2:
128
- x = self.activation(x)
129
-
130
- return x
131
-
132
- def sdf(self, x):
133
- return self.forward(x)[..., :1]
134
-
135
- def sdf_hidden_appearance(self, x):
136
- return self.forward(x)
137
-
138
- def gradient(self, x):
139
- x.requires_grad_(True)
140
- with torch.enable_grad():
141
- y = self.sdf(x)
142
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
143
- gradients = torch.autograd.grad(
144
- outputs=y,
145
- inputs=x,
146
- grad_outputs=d_output,
147
- create_graph=True,
148
- retain_graph=True,
149
- only_inputs=True)[0]
150
- return gradients
151
-
152
- def sdf_normal(self, x):
153
- x.requires_grad_(True)
154
- with torch.enable_grad():
155
- y = self.sdf(x)
156
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
157
- gradients = torch.autograd.grad(
158
- outputs=y,
159
- inputs=x,
160
- grad_outputs=d_output,
161
- create_graph=True,
162
- retain_graph=True,
163
- only_inputs=True)[0]
164
- return y[..., :1].detach(), gradients.detach()
165
-
166
- class SDFNetworkWithFeature(nn.Module):
167
- def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
168
- scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
169
- super().__init__()
170
-
171
- self.register_buffer("cube", cube)
172
- self.cube_length = cube_length
173
- dims = [dp_in+df_in] + [d_hidden for _ in range(n_layers)] + [d_out]
174
-
175
- self.embed_fn_fine = None
176
-
177
- if multires > 0:
178
- embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
179
- self.embed_fn_fine = embed_fn
180
- dims[0] = input_ch + df_in
181
-
182
- self.num_layers = len(dims)
183
- self.skip_in = skip_in
184
- self.scale = scale
185
-
186
- for l in range(0, self.num_layers - 1):
187
- if l + 1 in self.skip_in:
188
- out_dim = dims[l + 1] - dims[0]
189
- else:
190
- out_dim = dims[l + 1]
191
-
192
- lin = nn.Linear(dims[l], out_dim)
193
-
194
- if geometric_init:
195
- if l == self.num_layers - 2:
196
- if not inside_outside:
197
- torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
198
- torch.nn.init.constant_(lin.bias, -bias)
199
- else:
200
- torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
201
- torch.nn.init.constant_(lin.bias, bias)
202
- elif multires > 0 and l == 0:
203
- torch.nn.init.constant_(lin.bias, 0.0)
204
- torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
205
- torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
206
- elif multires > 0 and l in self.skip_in:
207
- torch.nn.init.constant_(lin.bias, 0.0)
208
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
209
- torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
210
- else:
211
- torch.nn.init.constant_(lin.bias, 0.0)
212
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
213
-
214
- if weight_norm:
215
- lin = nn.utils.weight_norm(lin)
216
-
217
- setattr(self, "lin" + str(l), lin)
218
-
219
- self.activation = nn.Softplus(beta=100)
220
-
221
- def forward(self, points):
222
- points = points * self.scale
223
-
224
- # note: dividing by cube_length (0.5) scales points by 2, because the cube spans [-0.5, 0.5] while grid_sample expects [-1, 1]
225
- with torch.no_grad():
226
- feats = F.grid_sample(self.cube, points.view(1,-1,1,1,3)/self.cube_length, mode='bilinear', align_corners=True, padding_mode='zeros').detach()
227
- feats = feats.view(self.cube.shape[1], -1).permute(1,0).view(*points.shape[:-1], -1)
228
- if self.embed_fn_fine is not None:
229
- points = self.embed_fn_fine(points)
230
-
231
- x = torch.cat([points, feats], -1)
232
- for l in range(0, self.num_layers - 1):
233
- lin = getattr(self, "lin" + str(l))
234
-
235
- if l in self.skip_in:
236
- x = torch.cat([x, points, feats], -1) / np.sqrt(2)
237
-
238
- x = lin(x)
239
-
240
- if l < self.num_layers - 2:
241
- x = self.activation(x)
242
-
243
- # concat feats
244
- x = torch.cat([x, feats], -1)
245
- return x
246
-
247
- def sdf(self, x):
248
- return self.forward(x)[..., :1]
249
-
250
- def sdf_hidden_appearance(self, x):
251
- return self.forward(x)
252
-
253
- def gradient(self, x):
254
- x.requires_grad_(True)
255
- with torch.enable_grad():
256
- y = self.sdf(x)
257
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
258
- gradients = torch.autograd.grad(
259
- outputs=y,
260
- inputs=x,
261
- grad_outputs=d_output,
262
- create_graph=True,
263
- retain_graph=True,
264
- only_inputs=True)[0]
265
- return gradients
266
-
267
- def sdf_normal(self, x):
268
- x.requires_grad_(True)
269
- with torch.enable_grad():
270
- y = self.sdf(x)
271
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
272
- gradients = torch.autograd.grad(
273
- outputs=y,
274
- inputs=x,
275
- grad_outputs=d_output,
276
- create_graph=True,
277
- retain_graph=True,
278
- only_inputs=True)[0]
279
- return y[..., :1].detach(), gradients.detach()
280
-
281
-
282
- class VanillaMLP(nn.Module):
283
- def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
284
- super().__init__()
285
- self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
286
- self.sphere_init, self.weight_norm = True, True
287
- self.sphere_init_radius = 0.5
288
- self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
289
- for i in range(self.n_hidden_layers - 1):
290
- self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
291
- self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
292
- self.layers = nn.Sequential(*self.layers)
293
-
294
- @torch.cuda.amp.autocast(False)
295
- def forward(self, x):
296
- x = self.layers(x.float())
297
- return x
298
-
299
- def make_linear(self, dim_in, dim_out, is_first, is_last):
300
- layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
301
- if self.sphere_init:
302
- if is_last:
303
- torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
304
- torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
305
- elif is_first:
306
- torch.nn.init.constant_(layer.bias, 0.0)
307
- torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
308
- torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
309
- else:
310
- torch.nn.init.constant_(layer.bias, 0.0)
311
- torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
312
- else:
313
- torch.nn.init.constant_(layer.bias, 0.0)
314
- torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
315
-
316
- if self.weight_norm:
317
- layer = nn.utils.weight_norm(layer)
318
- return layer
319
-
320
- def make_activation(self):
321
- if self.sphere_init:
322
- return nn.Softplus(beta=100)
323
- else:
324
- return nn.ReLU(inplace=True)
325
-
326
-
327
- class SDFHashGridNetwork(nn.Module):
328
- def __init__(self, bound=0.5, feats_dim=13):
329
- super().__init__()
330
- self.bound = bound
331
- # max_resolution = 32
332
- # base_resolution = 16
333
- # n_levels = 4
334
- # log2_hashmap_size = 16
335
- # n_features_per_level = 8
336
- max_resolution = 2048
337
- base_resolution = 16
338
- n_levels = 16
339
- log2_hashmap_size = 19
340
- n_features_per_level = 2
341
-
342
- # max_res = base_res * t^(k-1)
343
- per_level_scale = (max_resolution / base_resolution)** (1 / (n_levels - 1))
344
-
345
- self.encoder = tcnn.Encoding(
346
- n_input_dims=3,
347
- encoding_config={
348
- "otype": "HashGrid",
349
- "n_levels": n_levels,
350
- "n_features_per_level": n_features_per_level,
351
- "log2_hashmap_size": log2_hashmap_size,
352
- "base_resolution": base_resolution,
353
- "per_level_scale": per_level_scale,
354
- },
355
- )
356
- self.sdf_mlp = VanillaMLP(n_levels*n_features_per_level+3,feats_dim,64,1)
357
-
358
- def forward(self, x):
359
- shape = x.shape[:-1]
360
- x = x.reshape(-1, 3)
361
- x_ = (x + self.bound) / (2 * self.bound)
362
- feats = self.encoder(x_)
363
- feats = torch.cat([x, feats], 1)
364
-
365
- feats = self.sdf_mlp(feats)
366
- feats = feats.reshape(*shape,-1)
367
- return feats
368
-
369
- def sdf(self, x):
370
- return self(x)[...,:1]
371
-
372
- def gradient(self, x):
373
- x.requires_grad_(True)
374
- with torch.enable_grad():
375
- y = self.sdf(x)
376
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
377
- gradients = torch.autograd.grad(
378
- outputs=y,
379
- inputs=x,
380
- grad_outputs=d_output,
381
- create_graph=True,
382
- retain_graph=True,
383
- only_inputs=True)[0]
384
- return gradients
385
-
386
- def sdf_normal(self, x):
387
- x.requires_grad_(True)
388
- with torch.enable_grad():
389
- y = self.sdf(x)
390
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
391
- gradients = torch.autograd.grad(
392
- outputs=y,
393
- inputs=x,
394
- grad_outputs=d_output,
395
- create_graph=True,
396
- retain_graph=True,
397
- only_inputs=True)[0]
398
- return y[..., :1].detach(), gradients.detach()
399
-
400
- class RenderingFFNetwork(nn.Module):
401
- def __init__(self, in_feats_dim=12):
402
- super().__init__()
403
- self.dir_encoder = tcnn.Encoding(
404
- n_input_dims=3,
405
- encoding_config={
406
- "otype": "SphericalHarmonics",
407
- "degree": 4,
408
- },
409
- )
410
- self.color_mlp = tcnn.Network(
411
- n_input_dims = in_feats_dim + 3 + self.dir_encoder.n_output_dims,
412
- n_output_dims = 3,
413
- network_config={
414
- "otype": "FullyFusedMLP",
415
- "activation": "ReLU",
416
- "output_activation": "none",
417
- "n_neurons": 64,
418
- "n_hidden_layers": 2,
419
- },
420
- )
421
-
422
- def forward(self, points, normals, view_dirs, feature_vectors):
423
- normals = F.normalize(normals, dim=-1)
424
- view_dirs = F.normalize(view_dirs, dim=-1)
425
- reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs
426
-
427
- x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
428
- colors = self.color_mlp(x).float()
429
- colors = F.sigmoid(colors)
430
- return colors
431
-
432
- # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
433
- class RenderingNetwork(nn.Module):
434
- def __init__(self, d_feature, d_in, d_out, d_hidden,
435
- n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
436
- super().__init__()
437
-
438
- self.squeeze_out = squeeze_out
439
- self.rgb_act=F.sigmoid
440
- self.use_view_dir=use_view_dir
441
-
442
- dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
443
-
444
- self.embedview_fn = None
445
- if multires_view > 0:
446
- embedview_fn, input_ch = get_embedder(multires_view)
447
- self.embedview_fn = embedview_fn
448
- dims[0] += (input_ch - 3)
449
-
450
- self.num_layers = len(dims)
451
-
452
- for l in range(0, self.num_layers - 1):
453
- out_dim = dims[l + 1]
454
- lin = nn.Linear(dims[l], out_dim)
455
-
456
- if weight_norm:
457
- lin = nn.utils.weight_norm(lin)
458
-
459
- setattr(self, "lin" + str(l), lin)
460
-
461
- self.relu = nn.ReLU()
462
-
463
- def forward(self, points, normals, view_dirs, feature_vectors):
464
- if self.use_view_dir:
465
- view_dirs = F.normalize(view_dirs, dim=-1)
466
- normals = F.normalize(normals, dim=-1)
467
- reflective = torch.sum(view_dirs*normals, -1, keepdim=True) * normals * 2 - view_dirs
468
- if self.embedview_fn is not None: reflective = self.embedview_fn(reflective)
469
- rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
470
- else:
471
- rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
472
-
473
- x = rendering_input
474
-
475
- for l in range(0, self.num_layers - 1):
476
- lin = getattr(self, "lin" + str(l))
477
-
478
- x = lin(x)
479
-
480
- if l < self.num_layers - 2:
481
- x = self.relu(x)
482
-
483
- if self.squeeze_out:
484
- x = self.rgb_act(x)
485
- return x
486
-
487
-
488
- class SingleVarianceNetwork(nn.Module):
489
- def __init__(self, init_val, activation='exp'):
490
- super(SingleVarianceNetwork, self).__init__()
491
- self.act = activation
492
- self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
493
-
494
- def forward(self, x):
495
- device = x.device
496
- if self.act=='exp':
497
- return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
498
- else:
499
- raise NotImplementedError
500
-
501
- def warp(self, x, inv_s):
502
- device = x.device
503
- return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
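Editor's note: two pieces of the module above are easy to sanity-check in isolation: the frequency embedder maps a 3-vector to 3 + 2*3*multires channels, and the SDF networks recover surface normals as the autograd gradient of the SDF value. A small sketch, assuming the same import-path convention as above and noting that importing neus_networks requires tinycudann to be installed (it is imported at the top of the file); the SDFNetwork hyper-parameters here are illustrative, not values taken from a config in this repo.

import torch
from renderer.neus_networks import get_embedder, SDFNetwork  # assumed import path; needs tinycudann

embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
x = 0.3 * torch.randn(4, 3)
print(out_dim, embed_fn(x).shape)    # 39, torch.Size([4, 39])

sdf_net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8, multires=6)
sdf, normal = sdf_net.sdf_normal(x)  # SDF value and its autograd gradient
print(sdf.shape, normal.shape)       # torch.Size([4, 1]) torch.Size([4, 3])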
 
 
3drecon/renderer/ngp_renderer.py DELETED
@@ -1,721 +0,0 @@
1
- import math
2
- import trimesh
3
- import numpy as np
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from packaging import version as pver
9
-
10
- import tinycudann as tcnn
11
- from torch.autograd import Function
12
-
13
- from torch.cuda.amp import custom_bwd, custom_fwd
14
-
15
- import raymarching
16
-
17
- def custom_meshgrid(*args):
18
- # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
19
- if pver.parse(torch.__version__) < pver.parse('1.10'):
20
- return torch.meshgrid(*args)
21
- else:
22
- return torch.meshgrid(*args, indexing='ij')
23
-
24
- def sample_pdf(bins, weights, n_samples, det=False):
25
- # This implementation is from NeRF
26
- # bins: [B, T], old_z_vals
27
- # weights: [B, T - 1], bin weights.
28
- # return: [B, n_samples], new_z_vals
29
-
30
- # Get pdf
31
- weights = weights + 1e-5 # prevent nans
32
- pdf = weights / torch.sum(weights, -1, keepdim=True)
33
- cdf = torch.cumsum(pdf, -1)
34
- cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
35
- # Take uniform samples
36
- if det:
37
- u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
38
- u = u.expand(list(cdf.shape[:-1]) + [n_samples])
39
- else:
40
- u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
41
-
42
- # Invert CDF
43
- u = u.contiguous()
44
- inds = torch.searchsorted(cdf, u, right=True)
45
- below = torch.max(torch.zeros_like(inds - 1), inds - 1)
46
- above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
47
- inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
48
-
49
- matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
50
- cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
51
- bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
52
-
53
- denom = (cdf_g[..., 1] - cdf_g[..., 0])
54
- denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
55
- t = (u - cdf_g[..., 0]) / denom
56
- samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
57
-
58
- return samples
59
-
60
-
61
- def plot_pointcloud(pc, color=None):
62
- # pc: [N, 3]
63
- # color: [N, 3/4]
64
- print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
65
- pc = trimesh.PointCloud(pc, color)
66
- # axis
67
- axes = trimesh.creation.axis(axis_length=4)
68
- # sphere
69
- sphere = trimesh.creation.icosphere(radius=1)
70
- trimesh.Scene([pc, axes, sphere]).show()
71
-
72
-
73
- class NGPRenderer(nn.Module):
74
- def __init__(self,
75
- bound=1,
76
- cuda_ray=True,
77
- density_scale=1, # scale up deltas (or sigmas) to make the density grid sharper; a value larger than 1 usually improves performance.
78
- min_near=0.2,
79
- density_thresh=0.01,
80
- bg_radius=-1,
81
- ):
82
- super().__init__()
83
-
84
- self.bound = bound
85
- self.cascade = 1
86
- self.grid_size = 128
87
- self.density_scale = density_scale
88
- self.min_near = min_near
89
- self.density_thresh = density_thresh
90
- self.bg_radius = bg_radius # radius of the background sphere.
91
-
92
- # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
93
- # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
94
- aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
95
- aabb_infer = aabb_train.clone()
96
- self.register_buffer('aabb_train', aabb_train)
97
- self.register_buffer('aabb_infer', aabb_infer)
98
-
99
- # extra state for cuda raymarching
100
- self.cuda_ray = cuda_ray
101
- if cuda_ray:
102
- # density grid
103
- density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
104
- density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
105
- self.register_buffer('density_grid', density_grid)
106
- self.register_buffer('density_bitfield', density_bitfield)
107
- self.mean_density = 0
108
- self.iter_density = 0
109
- # step counter
110
- step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
111
- self.register_buffer('step_counter', step_counter)
112
- self.mean_count = 0
113
- self.local_step = 0
114
-
115
- def forward(self, x, d):
116
- raise NotImplementedError()
117
-
118
- # separated density and color query (can accelerate non-cuda-ray mode.)
119
- def density(self, x):
120
- raise NotImplementedError()
121
-
122
- def color(self, x, d, mask=None, **kwargs):
123
- raise NotImplementedError()
124
-
125
- def reset_extra_state(self):
126
- if not self.cuda_ray:
127
- return
128
- # density grid
129
- self.density_grid.zero_()
130
- self.mean_density = 0
131
- self.iter_density = 0
132
- # step counter
133
- self.step_counter.zero_()
134
- self.mean_count = 0
135
- self.local_step = 0
136
-
137
- def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
138
- # rays_o, rays_d: [B, N, 3], assumes B == 1
139
- # bg_color: [3] in range [0, 1]
140
- # return: image: [B, N, 3], depth: [B, N]
141
-
142
- prefix = rays_o.shape[:-1]
143
- rays_o = rays_o.contiguous().view(-1, 3)
144
- rays_d = rays_d.contiguous().view(-1, 3)
145
-
146
- N = rays_o.shape[0] # N = B * N, in fact
147
- device = rays_o.device
148
-
149
- # choose aabb
150
- aabb = self.aabb_train if self.training else self.aabb_infer
151
-
152
- # sample steps
153
- nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
154
- nears.unsqueeze_(-1)
155
- fars.unsqueeze_(-1)
156
-
157
- #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
158
-
159
- z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
160
- z_vals = z_vals.expand((N, num_steps)) # [N, T]
161
- z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
162
-
163
- # perturb z_vals
164
- sample_dist = (fars - nears) / num_steps
165
- if perturb:
166
- z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
167
- #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
168
-
169
- # generate xyzs
170
- xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
171
- xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
172
-
173
- #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
174
-
175
- # query SDF and RGB
176
- density_outputs = self.density(xyzs.reshape(-1, 3))
177
-
178
- #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
179
- for k, v in density_outputs.items():
180
- density_outputs[k] = v.view(N, num_steps, -1)
181
-
182
- # upsample z_vals (nerf-like)
183
- if upsample_steps > 0:
184
- with torch.no_grad():
185
-
186
- deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
187
- deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
188
-
189
- alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T]
190
- alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
191
- weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
192
-
193
- # sample new z_vals
194
- z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
195
- new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]
196
-
197
- new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
198
- new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
199
-
200
- # only forward new points to save computation
201
- new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
202
- #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
203
- for k, v in new_density_outputs.items():
204
- new_density_outputs[k] = v.view(N, upsample_steps, -1)
205
-
206
- # re-order
207
- z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
208
- z_vals, z_index = torch.sort(z_vals, dim=1)
209
-
210
- xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
211
- xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
212
-
213
- for k in density_outputs:
214
- tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
215
- density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
216
-
217
- deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
218
- deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
219
- alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
220
- alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
221
- weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
222
-
223
- dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
224
- for k, v in density_outputs.items():
225
- density_outputs[k] = v.view(-1, v.shape[-1])
226
-
227
- mask = weights > 1e-4 # hard coded
228
- rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
229
- rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
230
-
231
- #print(xyzs.shape, 'valid_rgb:', mask.sum().item())
232
-
233
- # calculate weight_sum (mask)
234
- weights_sum = weights.sum(dim=-1) # [N]
235
-
236
- # calculate depth
237
- ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
238
- depth = torch.sum(weights * ori_z_vals, dim=-1)
239
-
240
- # calculate color
241
- image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
242
-
243
- # mix background color
244
- if self.bg_radius > 0:
245
- # use the bg model to calculate bg_color
246
- sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
247
- bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3]
248
- elif bg_color is None:
249
- bg_color = 1
250
-
251
- image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
252
-
253
- image = image.view(*prefix, 3)
254
- depth = depth.view(*prefix)
255
-
256
- # tmp: reg loss in mip-nerf 360
257
- # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
258
- # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
259
- # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()
260
-
261
- return {
262
- 'depth': depth,
263
- 'image': image,
264
- 'weights_sum': weights_sum,
265
- }
266
-
267
-
268
- def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
269
- # rays_o, rays_d: [B, N, 3], assumes B == 1
270
- # return: image: [B, N, 3], depth: [B, N]
271
-
272
- prefix = rays_o.shape[:-1]
273
- rays_o = rays_o.contiguous().view(-1, 3)
274
- rays_d = rays_d.contiguous().view(-1, 3)
275
-
276
- N = rays_o.shape[0] # N = B * N, in fact
277
- device = rays_o.device
278
-
279
- # pre-calculate near far
280
- nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
281
-
282
- # mix background color
283
- if self.bg_radius > 0:
284
- # use the bg model to calculate bg_color
285
- sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
286
- bg_color = self.background(sph, rays_d) # [N, 3]
287
- elif bg_color is None:
288
- bg_color = 1
289
-
290
- results = {}
291
-
292
- if self.training:
293
- # setup counter
294
- counter = self.step_counter[self.local_step % 16]
295
- counter.zero_() # set to 0
296
- self.local_step += 1
297
-
298
- xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
299
-
300
- #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
301
-
302
- sigmas, rgbs = self(xyzs, dirs)
303
- sigmas = self.density_scale * sigmas
304
-
305
- weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
306
- image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
307
- depth = torch.clamp(depth - nears, min=0) / (fars - nears)
308
- image = image.view(*prefix, 3)
309
- depth = depth.view(*prefix)
310
-
311
- else:
312
-
313
- # allocate outputs
314
- # if use autocast, must init as half so it won't be autocasted and lose reference.
315
- #dtype = torch.half if torch.is_autocast_enabled() else torch.float32
316
- # output should always be float32! only network inference uses half.
317
- dtype = torch.float32
318
-
319
- weights_sum = torch.zeros(N, dtype=dtype, device=device)
320
- depth = torch.zeros(N, dtype=dtype, device=device)
321
- image = torch.zeros(N, 3, dtype=dtype, device=device)
322
-
323
- n_alive = N
324
- rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
325
- rays_t = nears.clone() # [N]
326
-
327
- step = 0
328
-
329
- while step < max_steps:
330
-
331
- # count alive rays
332
- n_alive = rays_alive.shape[0]
333
-
334
- # exit loop
335
- if n_alive <= 0:
336
- break
337
-
338
- # decide compact_steps
339
- n_step = max(min(N // n_alive, 8), 1)
340
-
341
- xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
342
-
343
- sigmas, rgbs = self(xyzs, dirs)
344
- # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
345
- # sigmas = density_outputs['sigma']
346
- # rgbs = self.color(xyzs, dirs, **density_outputs)
347
- sigmas = self.density_scale * sigmas
348
-
349
- raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
350
-
351
- rays_alive = rays_alive[rays_alive >= 0]
352
-
353
- #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
354
-
355
- step += n_step
356
-
357
- image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
358
- depth = torch.clamp(depth - nears, min=0) / (fars - nears)
359
- image = image.view(*prefix, 3)
360
- depth = depth.view(*prefix)
361
-
362
- results['weights_sum'] = weights_sum
363
- results['depth'] = depth
364
- results['image'] = image
365
-
366
- return results
367
-
368
- @torch.no_grad()
369
- def mark_untrained_grid(self, poses, intrinsic, S=64):
370
- # poses: [B, 4, 4]
371
- # intrinsic: [3, 3]
372
-
373
- if not self.cuda_ray:
374
- return
375
-
376
- if isinstance(poses, np.ndarray):
377
- poses = torch.from_numpy(poses)
378
-
379
- B = poses.shape[0]
380
-
381
- fx, fy, cx, cy = intrinsic
382
-
383
- X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
384
- Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
385
- Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
386
-
387
- count = torch.zeros_like(self.density_grid)
388
- poses = poses.to(count.device)
389
-
390
- # 5-level loop, forgive me...
391
-
392
- for xs in X:
393
- for ys in Y:
394
- for zs in Z:
395
-
396
- # construct points
397
- xx, yy, zz = custom_meshgrid(xs, ys, zs)
398
- coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
399
- indices = raymarching.morton3D(coords).long() # [N]
400
- world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]
401
-
402
- # cascading
403
- for cas in range(self.cascade):
404
- bound = min(2 ** cas, self.bound)
405
- half_grid_size = bound / self.grid_size
406
- # scale to current cascade's resolution
407
- cas_world_xyzs = world_xyzs * (bound - half_grid_size)
408
-
409
- # split batch to avoid OOM
410
- head = 0
411
- while head < B:
412
- tail = min(head + S, B)
413
-
414
- # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
415
- cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
416
- cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
417
-
418
- # query if point is covered by any camera
419
- mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
420
- mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
421
- mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
422
- mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
423
-
424
- # update count
425
- count[cas, indices] += mask
426
- head += S
427
-
428
- # mark untrained grid as -1
429
- self.density_grid[count == 0] = -1
430
-
431
- print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')
432
-
433
- @torch.no_grad()
434
- def update_extra_state(self, decay=0.95, S=128):
435
- # call before each epoch to update extra states.
436
-
437
- if not self.cuda_ray:
438
- return
439
-
440
- ### update density grid
441
- tmp_grid = - torch.ones_like(self.density_grid)
442
-
443
- # full update.
444
- if self.iter_density < 16:
445
- #if True:
446
- X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
447
- Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
448
- Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
449
-
450
- for xs in X:
451
- for ys in Y:
452
- for zs in Z:
453
-
454
- # construct points
455
- xx, yy, zz = custom_meshgrid(xs, ys, zs)
456
- coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
457
- indices = raymarching.morton3D(coords).long() # [N]
458
- xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
459
-
460
- # cascading
461
- for cas in range(self.cascade):
462
- bound = min(2 ** cas, self.bound)
463
- half_grid_size = bound / self.grid_size
464
- # scale to current cascade's resolution
465
- cas_xyzs = xyzs * (bound - half_grid_size)
466
- # add noise in [-hgs, hgs]
467
- cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
468
- # query density
469
- sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
470
- sigmas *= self.density_scale
471
- # assign
472
- tmp_grid[cas, indices] = sigmas
473
-
474
- # partial update (half the computation)
475
- # TODO: why no need of maxpool ?
476
- else:
477
- N = self.grid_size ** 3 // 4 # H * H * H / 4
478
- for cas in range(self.cascade):
479
- # random sample some positions
480
- coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device) # [N, 3], in [0, 128)
481
- indices = raymarching.morton3D(coords).long() # [N]
482
- # random sample occupied positions
483
- occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1) # [Nz]
484
- rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device)
485
- occ_indices = occ_indices[rand_mask] # [Nz] --> [N], allow for duplication
486
- occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3]
487
- # concat
488
- indices = torch.cat([indices, occ_indices], dim=0)
489
- coords = torch.cat([coords, occ_coords], dim=0)
490
- # same below
491
- xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
492
- bound = min(2 ** cas, self.bound)
493
- half_grid_size = bound / self.grid_size
494
- # scale to current cascade's resolution
495
- cas_xyzs = xyzs * (bound - half_grid_size)
496
- # add noise in [-hgs, hgs]
497
- cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
498
- # query density
499
- sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
500
- sigmas *= self.density_scale
501
- # assign
502
- tmp_grid[cas, indices] = sigmas
503
-
504
- ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
505
- # invalid_mask = tmp_grid < 0
506
- # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
507
- # tmp_grid[invalid_mask] = -1
508
-
509
- # ema update
510
- valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
511
- self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
512
- self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 regions are viewed as 0 density.
513
- #self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions
514
- self.iter_density += 1
515
-
516
- # convert to bitfield
517
- density_thresh = min(self.mean_density, self.density_thresh)
518
- self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
519
-
520
- ### update step counter
521
- total_step = min(16, self.local_step)
522
- if total_step > 0:
523
- self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
524
- self.local_step = 0
525
-
526
- #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
527
-
528
-
529
- def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
530
- # rays_o, rays_d: [B, N, 3], assumes B == 1
531
- # return: pred_rgb: [B, N, 3]
532
-
533
- if self.cuda_ray:
534
- _run = self.run_cuda
535
- else:
536
- _run = self.run
537
-
538
- results = _run(rays_o, rays_d, **kwargs)
539
- return results
540
-
541
-
542
-
543
- class _trunc_exp(Function):
544
- @staticmethod
545
- @custom_fwd(cast_inputs=torch.float32) # cast to float32
546
- def forward(ctx, x):
547
- ctx.save_for_backward(x)
548
- return torch.exp(x)
549
-
550
- @staticmethod
551
- @custom_bwd
552
- def backward(ctx, g):
553
- x = ctx.saved_tensors[0]
554
- return g * torch.exp(x.clamp(-15, 15))
555
-
556
- trunc_exp = _trunc_exp.apply
557
-
558
- class NGPNetwork(NGPRenderer):
559
- def __init__(self,
560
- num_layers=2,
561
- hidden_dim=64,
562
- geo_feat_dim=15,
563
- num_layers_color=3,
564
- hidden_dim_color=64,
565
- bound=0.5,
566
- max_resolution=128,
567
- base_resolution=16,
568
- n_levels=16,
569
- **kwargs
570
- ):
571
- super().__init__(bound, **kwargs)
572
-
573
- # sigma network
574
- self.num_layers = num_layers
575
- self.hidden_dim = hidden_dim
576
- self.geo_feat_dim = geo_feat_dim
577
- self.bound = bound
578
-
579
- log2_hashmap_size = 19
580
- n_features_per_level = 2
581
-
582
-
583
- per_level_scale = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))
584
-
585
- self.encoder = tcnn.Encoding(
586
- n_input_dims=3,
587
- encoding_config={
588
- "otype": "HashGrid",
589
- "n_levels": n_levels,
590
- "n_features_per_level": n_features_per_level,
591
- "log2_hashmap_size": log2_hashmap_size,
592
- "base_resolution": base_resolution,
593
- "per_level_scale": per_level_scale,
594
- },
595
- )
596
-
597
- self.sigma_net = tcnn.Network(
598
- n_input_dims = n_levels * 2,
599
- n_output_dims=1 + self.geo_feat_dim,
600
- network_config={
601
- "otype": "FullyFusedMLP",
602
- "activation": "ReLU",
603
- "output_activation": "None",
604
- "n_neurons": hidden_dim,
605
- "n_hidden_layers": num_layers - 1,
606
- },
607
- )
608
-
609
- # color network
610
- self.num_layers_color = num_layers_color
611
- self.hidden_dim_color = hidden_dim_color
612
-
613
- self.encoder_dir = tcnn.Encoding(
614
- n_input_dims=3,
615
- encoding_config={
616
- "otype": "SphericalHarmonics",
617
- "degree": 4,
618
- },
619
- )
620
-
621
- self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim
622
-
623
- self.color_net = tcnn.Network(
624
- n_input_dims = self.in_dim_color,
625
- n_output_dims=3,
626
- network_config={
627
- "otype": "FullyFusedMLP",
628
- "activation": "ReLU",
629
- "output_activation": "None",
630
- "n_neurons": hidden_dim_color,
631
- "n_hidden_layers": num_layers_color - 1,
632
- },
633
- )
634
- self.density_scale, self.density_std = 10.0, 0.25
635
-
636
- def forward(self, x, d):
637
- # x: [N, 3], in [-bound, bound]
638
- # d: [N, 3], normalized in [-1, 1]
639
-
640
-
641
- # sigma
642
- x_raw = x
643
- x = (x + self.bound) / (2 * self.bound) # to [0, 1]
644
- x = self.encoder(x)
645
- h = self.sigma_net(x)
646
-
647
- # sigma = F.relu(h[..., 0])
648
- density = h[..., 0]
649
- # add density bias
650
- dist = torch.norm(x_raw, dim=-1)
651
- density_bias = (1 - dist / self.density_std) * self.density_scale
652
- density = density_bias + density
653
- sigma = F.softplus(density)
654
- geo_feat = h[..., 1:]
655
-
656
- # color
657
- d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
658
- d = self.encoder_dir(d)
659
-
660
- # p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
661
- h = torch.cat([d, geo_feat], dim=-1)
662
- h = self.color_net(h)
663
-
664
- # sigmoid activation for rgb
665
- color = torch.sigmoid(h)
666
-
667
- return sigma, color
668
-
669
- def density(self, x):
670
- # x: [N, 3], in [-bound, bound]
671
- x_raw = x
672
- x = (x + self.bound) / (2 * self.bound) # to [0, 1]
673
- x = self.encoder(x)
674
- h = self.sigma_net(x)
675
-
676
- # sigma = F.relu(h[..., 0])
677
- density = h[..., 0]
678
- # add density bias
679
- dist = torch.norm(x_raw, dim=-1)
680
- density_bias = (1 - dist / self.density_std) * self.density_scale
681
- density = density_bias + density
682
- sigma = F.softplus(density)
683
- geo_feat = h[..., 1:]
684
-
685
- return {
686
- 'sigma': sigma,
687
- 'geo_feat': geo_feat,
688
- }
689
-
690
- # allow masked inference
691
- def color(self, x, d, mask=None, geo_feat=None, **kwargs):
692
- # x: [N, 3] in [-bound, bound]
693
- # mask: [N,], bool, indicates where we actually need to compute rgb.
694
-
695
- x = (x + self.bound) / (2 * self.bound) # to [0, 1]
696
-
697
- if mask is not None:
698
- rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
699
- # in case of empty mask
700
- if not mask.any():
701
- return rgbs
702
- x = x[mask]
703
- d = d[mask]
704
- geo_feat = geo_feat[mask]
705
-
706
- # color
707
- d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
708
- d = self.encoder_dir(d)
709
-
710
- h = torch.cat([d, geo_feat], dim=-1)
711
- h = self.color_net(h)
712
-
713
- # sigmoid activation for rgb
714
- h = torch.sigmoid(h)
715
-
716
- if mask is not None:
717
- rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
718
- else:
719
- rgbs = h
720
-
721
- return rgbs
 
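
A quick sanity check of the _trunc_exp activation defined above (a sketch; the import path is an assumption based on this file's location): the forward pass is a plain exponential, while the backward pass clamps the saved input to [-15, 15] so the gradient stays finite under mixed-precision training.

import torch
from renderer.ngp_renderer import trunc_exp  # assumed import path for this sketch

x = torch.tensor([1.0, 20.0], requires_grad=True)
y = trunc_exp(x)          # forward: exp(x); exp(20) ~ 4.85e8 but still finite
y.sum().backward()
print(x.grad)             # second entry is exp(15) ~ 3.27e6, i.e. the clamped gradient
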
3drecon/renderer/renderer.py DELETED
@@ -1,640 +0,0 @@
1
- import abc
2
- import os
3
- from pathlib import Path
4
-
5
- import cv2
6
- import numpy as np
7
- import pytorch_lightning as pl
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from omegaconf import OmegaConf
12
-
13
- from skimage.io import imread, imsave
14
- from PIL import Image
15
- from torch.optim.lr_scheduler import LambdaLR
16
-
17
- from renderer.neus_networks import SDFNetwork, RenderingNetwork, SingleVarianceNetwork, SDFHashGridNetwork, RenderingFFNetwork
18
- from renderer.ngp_renderer import NGPNetwork
19
- from util import instantiate_from_config, read_pickle, concat_images_list
20
-
21
- DEFAULT_RADIUS = np.sqrt(3)/2
22
- DEFAULT_SIDE_LENGTH = 0.6
23
-
24
- def sample_pdf(bins, weights, n_samples, det=True):
25
- device = bins.device
26
- dtype = bins.dtype
27
- # This implementation is from NeRF
28
- # Get pdf
29
- weights = weights + 1e-5 # prevent nans
30
- pdf = weights / torch.sum(weights, -1, keepdim=True)
31
- cdf = torch.cumsum(pdf, -1)
32
- cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
33
- # Take uniform samples
34
- if det:
35
- u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples, dtype=dtype, device=device)
36
- u = u.expand(list(cdf.shape[:-1]) + [n_samples])
37
- else:
38
- u = torch.rand(list(cdf.shape[:-1]) + [n_samples], dtype=dtype, device=device)
39
-
40
- # Invert CDF
41
- u = u.contiguous()
42
- inds = torch.searchsorted(cdf, u, right=True)
43
- below = torch.max(torch.zeros_like(inds - 1), inds - 1)
44
- above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
45
- inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
46
-
47
- matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
48
- cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
49
- bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
50
-
51
- denom = (cdf_g[..., 1] - cdf_g[..., 0])
52
- denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
53
- t = (u - cdf_g[..., 0]) / denom
54
- samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
55
-
56
- return samples
57
-
58
- def near_far_from_sphere(rays_o, rays_d, radius=DEFAULT_RADIUS):
59
- a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)
60
- b = torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
61
- mid = -b / a
62
- near = mid - radius
63
- far = mid + radius
64
- return near, far
65
-
66
- class BackgroundRemoval:
67
- def __init__(self, device='cuda'):
68
- from carvekit.api.high import HiInterface
69
- self.interface = HiInterface(
70
- object_type="object", # Can be "object" or "hairs-like".
71
- batch_size_seg=5,
72
- batch_size_matting=1,
73
- device=device,
74
- seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
75
- matting_mask_size=2048,
76
- trimap_prob_threshold=231,
77
- trimap_dilation=30,
78
- trimap_erosion_iters=5,
79
- fp16=True,
80
- )
81
-
82
- @torch.no_grad()
83
- def __call__(self, image):
84
- # image: [H, W, 3] array in [0, 255].
85
- image = Image.fromarray(image)
86
- image = self.interface([image])[0]
87
- image = np.array(image)
88
- return image
89
-
90
-
91
- class BaseRenderer(nn.Module):
92
- def __init__(self, train_batch_num, test_batch_num):
93
- super().__init__()
94
- self.train_batch_num = train_batch_num
95
- self.test_batch_num = test_batch_num
96
-
97
- @abc.abstractmethod
98
- def render_impl(self, ray_batch, is_train, step):
99
- pass
100
-
101
- @abc.abstractmethod
102
- def render_with_loss(self, ray_batch, is_train, step):
103
- pass
104
-
105
- def render(self, ray_batch, is_train, step):
106
- batch_num = self.train_batch_num if is_train else self.test_batch_num
107
- ray_num = ray_batch['rays_o'].shape[0]
108
- outputs = {}
109
- for ri in range(0, ray_num, batch_num):
110
- cur_ray_batch = {}
111
- for k, v in ray_batch.items():
112
- cur_ray_batch[k] = v[ri:ri + batch_num]
113
- cur_outputs = self.render_impl(cur_ray_batch, is_train, step)
114
- for k, v in cur_outputs.items():
115
- if k not in outputs: outputs[k] = []
116
- outputs[k].append(v)
117
-
118
- for k, v in outputs.items():
119
- outputs[k] = torch.cat(v, 0)
120
- return outputs
121
-
122
-
123
- class NeuSRenderer(BaseRenderer):
124
- def __init__(self, train_batch_num, test_batch_num, lambda_eikonal_loss=0.1, use_mask=True,
125
- lambda_rgb_loss=1.0, lambda_mask_loss=0.0, rgb_loss='soft_l1', coarse_sn=64, fine_sn=64):
126
- super().__init__(train_batch_num, test_batch_num)
127
- self.n_samples = coarse_sn
128
- self.n_importance = fine_sn
129
- self.up_sample_steps = 4
130
- self.anneal_end = 200
131
- self.use_mask = use_mask
132
- self.lambda_eikonal_loss = lambda_eikonal_loss
133
- self.lambda_rgb_loss = lambda_rgb_loss
134
- self.lambda_mask_loss = lambda_mask_loss
135
- self.rgb_loss = rgb_loss
136
-
137
- self.sdf_network = SDFNetwork(d_out=257, d_in=3, d_hidden=256, n_layers=8, skip_in=[4], multires=6, bias=0.5, scale=1.0, geometric_init=True, weight_norm=True)
138
- self.color_network = RenderingNetwork(d_feature=256, d_in=9, d_out=3, d_hidden=256, n_layers=4, weight_norm=True, multires_view=4, squeeze_out=True)
139
- self.default_dtype = torch.float32
140
- self.deviation_network = SingleVarianceNetwork(0.3)
141
-
142
- @torch.no_grad()
143
- def get_vertex_colors(self, vertices):
144
- """
145
- @param vertices: n,3
146
- @return:
147
- """
148
- V = vertices.shape[0]
149
- bn = 20480
150
- verts_colors = []
151
- with torch.no_grad():
152
- for vi in range(0, V, bn):
153
- verts = torch.from_numpy(vertices[vi:vi+bn].astype(np.float32)).cuda()
154
- feats = self.sdf_network(verts)[..., 1:]
155
- gradients = self.sdf_network.gradient(verts) # ...,3
156
- gradients = F.normalize(gradients, dim=-1)
157
- colors = self.color_network(verts, gradients, gradients, feats)
158
- colors = torch.clamp(colors,min=0,max=1).cpu().numpy()
159
- verts_colors.append(colors)
160
-
161
- verts_colors = (np.concatenate(verts_colors, 0)*255).astype(np.uint8)
162
- return verts_colors
163
-
164
- def upsample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
165
- """
166
- Up sampling given a fixed inv_s
167
- """
168
- device = rays_o.device
169
- batch_size, n_samples = z_vals.shape
170
- pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
171
- inner_mask = self.get_inner_mask(pts)
172
- # radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
173
- inside_sphere = inner_mask[:, :-1] | inner_mask[:, 1:]
174
- sdf = sdf.reshape(batch_size, n_samples)
175
- prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
176
- prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
177
- mid_sdf = (prev_sdf + next_sdf) * 0.5
178
- cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
179
-
180
- prev_cos_val = torch.cat([torch.zeros([batch_size, 1], dtype=self.default_dtype, device=device), cos_val[:, :-1]], dim=-1)
181
- cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
182
- cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
183
- cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
184
-
185
- dist = (next_z_vals - prev_z_vals)
186
- prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
187
- next_esti_sdf = mid_sdf + cos_val * dist * 0.5
188
- prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
189
- next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
190
- alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
191
- weights = alpha * torch.cumprod(
192
- torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
193
-
194
- z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
195
- return z_samples
196
-
197
- def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
198
- batch_size, n_samples = z_vals.shape
199
- _, n_importance = new_z_vals.shape
200
- pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
201
- z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
202
- z_vals, index = torch.sort(z_vals, dim=-1)
203
-
204
- if not last:
205
- device = pts.device
206
- new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
207
- sdf = torch.cat([sdf, new_sdf], dim=-1)
208
- xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1).to(device)
209
- index = index.reshape(-1)
210
- sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
211
-
212
- return z_vals, sdf
213
-
214
- def sample_depth(self, rays_o, rays_d, near, far, perturb):
215
- n_samples = self.n_samples
216
- n_importance = self.n_importance
217
- up_sample_steps = self.up_sample_steps
218
- device = rays_o.device
219
-
220
- # sample points
221
- batch_size = len(rays_o)
222
- z_vals = torch.linspace(0.0, 1.0, n_samples, dtype=self.default_dtype, device=device) # sn
223
- z_vals = near + (far - near) * z_vals[None, :] # rn,sn
224
-
225
- if perturb > 0:
226
- t_rand = (torch.rand([batch_size, 1]).to(device) - 0.5)
227
- z_vals = z_vals + t_rand * 2.0 / n_samples
228
-
229
- # Up sample
230
- with torch.no_grad():
231
- pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
232
- sdf = self.sdf_network.sdf(pts).reshape(batch_size, n_samples)
233
-
234
- for i in range(up_sample_steps):
235
- rn, sn = z_vals.shape
236
- inv_s = torch.ones(rn, sn - 1, dtype=self.default_dtype, device=device) * 64 * 2 ** i
237
- new_z_vals = self.upsample(rays_o, rays_d, z_vals, sdf, n_importance // up_sample_steps, inv_s)
238
- z_vals, sdf = self.cat_z_vals(rays_o, rays_d, z_vals, new_z_vals, sdf, last=(i + 1 == up_sample_steps))
239
-
240
- return z_vals
241
-
242
- def compute_sdf_alpha(self, points, dists, dirs, cos_anneal_ratio, step):
243
- # points [...,3] dists [...] dirs[...,3]
244
- sdf_nn_output = self.sdf_network(points)
245
- sdf = sdf_nn_output[..., 0]
246
- feature_vector = sdf_nn_output[..., 1:]
247
-
248
- gradients = self.sdf_network.gradient(points) # ...,3
249
- inv_s = self.deviation_network(points).clip(1e-6, 1e6) # ...,1
250
- inv_s = inv_s[..., 0]
251
-
252
- true_cos = (dirs * gradients).sum(-1) # [...]
253
- iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
254
- F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
255
-
256
- # Estimate signed distances at section points
257
- estimated_next_sdf = sdf + iter_cos * dists * 0.5
258
- estimated_prev_sdf = sdf - iter_cos * dists * 0.5
259
-
260
- prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
261
- next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
262
-
263
- p = prev_cdf - next_cdf
264
- c = prev_cdf
265
-
266
- alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) # [...]
267
- return alpha, gradients, feature_vector, inv_s, sdf
268
-
269
- def get_anneal_val(self, step):
270
- if self.anneal_end < 0:
271
- return 1.0
272
- else:
273
- return np.min([1.0, step / self.anneal_end])
274
-
275
- def get_inner_mask(self, points):
276
- return torch.sum(torch.abs(points)<=DEFAULT_SIDE_LENGTH,-1)==3
277
-
278
- def render_impl(self, ray_batch, is_train, step):
279
- near, far = near_far_from_sphere(ray_batch['rays_o'], ray_batch['rays_d'])
280
- rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
281
- z_vals = self.sample_depth(rays_o, rays_d, near, far, is_train)
282
-
283
- batch_size, n_samples = z_vals.shape
284
-
285
- # section length in original space
286
- dists = z_vals[..., 1:] - z_vals[..., :-1] # rn,sn-1
287
- dists = torch.cat([dists, dists[..., -1:]], -1) # rn,sn
288
- mid_z_vals = z_vals + dists * 0.5
289
-
290
- points = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * mid_z_vals.unsqueeze(-1) # rn, sn, 3
291
- inner_mask = self.get_inner_mask(points)
292
-
293
- dirs = rays_d.unsqueeze(-2).expand(batch_size, n_samples, 3)
294
- dirs = F.normalize(dirs, dim=-1)
295
- device = rays_o.device
296
- alpha, sampled_color, gradient_error, normal = torch.zeros(batch_size, n_samples, dtype=self.default_dtype, device=device), \
297
- torch.zeros(batch_size, n_samples, 3, dtype=self.default_dtype, device=device), \
298
- torch.zeros([batch_size, n_samples], dtype=self.default_dtype, device=device), \
299
- torch.zeros([batch_size, n_samples, 3], dtype=self.default_dtype, device=device)
300
- if torch.sum(inner_mask) > 0:
301
- cos_anneal_ratio = self.get_anneal_val(step) if is_train else 1.0
302
- alpha[inner_mask], gradients, feature_vector, inv_s, sdf = self.compute_sdf_alpha(points[inner_mask], dists[inner_mask], dirs[inner_mask], cos_anneal_ratio, step)
303
- sampled_color[inner_mask] = self.color_network(points[inner_mask], gradients, -dirs[inner_mask], feature_vector)
304
- # Eikonal loss
305
- gradient_error[inner_mask] = (torch.linalg.norm(gradients, ord=2, dim=-1) - 1.0) ** 2 # rn,sn
306
- normal[inner_mask] = F.normalize(gradients, dim=-1)
307
-
308
- weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[..., :-1] # rn,sn
309
- mask = torch.sum(weights,dim=1).unsqueeze(-1) # rn,1
310
- color = (sampled_color * weights[..., None]).sum(dim=1) + (1 - mask) # add white background
311
- normal = (normal * weights[..., None]).sum(dim=1)
312
-
313
- outputs = {
314
- 'rgb': color, # rn,3
315
- 'gradient_error': gradient_error, # rn,sn
316
- 'inner_mask': inner_mask, # rn,sn
317
- 'normal': normal, # rn,3
318
- 'mask': mask, # rn,1
319
- }
320
- return outputs
321
-
322
- def render_with_loss(self, ray_batch, is_train, step):
323
- render_outputs = self.render(ray_batch, is_train, step)
324
-
325
- rgb_gt = ray_batch['rgb']
326
- rgb_pr = render_outputs['rgb']
327
- if self.rgb_loss == 'soft_l1':
328
- epsilon = 0.001
329
- rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
330
- elif self.rgb_loss =='mse':
331
- rgb_loss = F.mse_loss(rgb_pr, rgb_gt, reduction='none')
332
- else:
333
- raise NotImplementedError
334
- rgb_loss = torch.mean(rgb_loss)
335
-
336
- eikonal_loss = torch.sum(render_outputs['gradient_error'] * render_outputs['inner_mask']) / torch.sum(render_outputs['inner_mask'] + 1e-5)
337
- loss = rgb_loss * self.lambda_rgb_loss + eikonal_loss * self.lambda_eikonal_loss
338
- loss_batch = {
339
- 'eikonal': eikonal_loss,
340
- 'rendering': rgb_loss,
341
- # 'mask': mask_loss,
342
- }
343
- if self.lambda_mask_loss>0 and self.use_mask:
344
- mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none').mean()
345
- loss += mask_loss * self.lambda_mask_loss
346
- loss_batch['mask'] = mask_loss
347
- return loss, loss_batch
348
-
349
-
350
- class NeRFRenderer(BaseRenderer):
351
- def __init__(self, train_batch_num, test_batch_num, bound=0.5, use_mask=False, lambda_rgb_loss=1.0, lambda_mask_loss=0.0):
352
- super().__init__(train_batch_num, test_batch_num)
353
- self.train_batch_num = train_batch_num
354
- self.test_batch_num = test_batch_num
355
- self.use_mask = use_mask
356
- self.field = NGPNetwork(bound=bound)
357
-
358
- self.update_interval = 16
359
- self.fp16 = True
360
- self.lambda_rgb_loss = lambda_rgb_loss
361
- self.lambda_mask_loss = lambda_mask_loss
362
-
363
- def render_impl(self, ray_batch, is_train, step):
364
- rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
365
- with torch.cuda.amp.autocast(enabled=self.fp16):
366
- if step % self.update_interval==0:
367
- self.field.update_extra_state()
368
-
369
- outputs = self.field.render(rays_o, rays_d,)
370
-
371
- renderings={
372
- 'rgb': outputs['image'],
373
- 'depth': outputs['depth'],
374
- 'mask': outputs['weights_sum'].unsqueeze(-1),
375
- }
376
- return renderings
377
-
378
- def render_with_loss(self, ray_batch, is_train, step):
379
- render_outputs = self.render(ray_batch, is_train, step)
380
-
381
- rgb_gt = ray_batch['rgb']
382
- rgb_pr = render_outputs['rgb']
383
- epsilon = 0.001
384
- rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
385
- rgb_loss = torch.mean(rgb_loss)
386
- loss = rgb_loss * self.lambda_rgb_loss
387
- loss_batch = {'rendering': rgb_loss}
388
-
389
- if self.use_mask:
390
- mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none')
391
- mask_loss = torch.mean(mask_loss)
392
- loss = loss + mask_loss * self.lambda_mask_loss
393
- loss_batch['mask'] = mask_loss
394
- return loss, loss_batch
395
-
396
- def cartesian_to_spherical(xyz):
397
- ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
398
- xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
399
- z = np.sqrt(xy + xyz[:, 2] ** 2)
400
- theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
401
- # ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
402
- azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
403
- return np.array([theta, azimuth, z])
404
-
405
- def get_pose(target_RT):
406
- R, T = target_RT[:3, :3], target_RT[:, -1]
407
- T_target = -R.T @ T
408
- theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
409
- return theta_target, azimuth_target, z_target
410
-
411
-
412
- class RendererTrainer(pl.LightningModule):
413
- def __init__(self, image_path, data_path, total_steps, warm_up_steps, log_dir, train_batch_fg_num=0,
414
- use_cube_feats=False, cube_ckpt=None, cube_cfg=None, cube_bound=0.5,
415
- train_batch_num=4096, test_batch_num=8192, use_warm_up=True, use_mask=True,
416
- lambda_rgb_loss=1.0, lambda_mask_loss=0.0, renderer='neus',
417
- # used in neus
418
- lambda_eikonal_loss=0.1,
419
- coarse_sn=64, fine_sn=64):
420
- super().__init__()
421
- self.num_images = 36 # todo ours 36, syncdreamer 16
422
- self.image_size = 256
423
- self.log_dir = log_dir
424
- (Path(log_dir)/'images').mkdir(exist_ok=True, parents=True)
425
- self.train_batch_num = train_batch_num
426
- self.train_batch_fg_num = train_batch_fg_num
427
- self.test_batch_num = test_batch_num
428
- self.image_path = image_path
429
- self.data_path = data_path
430
- self.total_steps = total_steps
431
- self.warm_up_steps = warm_up_steps
432
- self.use_mask = use_mask
433
- self.lambda_eikonal_loss = lambda_eikonal_loss
434
- self.lambda_rgb_loss = lambda_rgb_loss
435
- self.lambda_mask_loss = lambda_mask_loss
436
- self.use_warm_up = use_warm_up
437
-
438
- self.use_cube_feats, self.cube_cfg, self.cube_ckpt = use_cube_feats, cube_cfg, cube_ckpt
439
-
440
- self._init_dataset()
441
- if renderer=='neus':
442
- self.renderer = NeuSRenderer(train_batch_num, test_batch_num,
443
- lambda_rgb_loss=lambda_rgb_loss,
444
- lambda_eikonal_loss=lambda_eikonal_loss,
445
- lambda_mask_loss=lambda_mask_loss,
446
- coarse_sn=coarse_sn, fine_sn=fine_sn)
447
- elif renderer=='ngp':
448
- self.renderer = NeRFRenderer(train_batch_num, test_batch_num, bound=cube_bound, use_mask=use_mask, lambda_mask_loss=lambda_mask_loss, lambda_rgb_loss=lambda_rgb_loss,)
449
- else:
450
- raise NotImplementedError
451
- self.validation_index = 0
452
-
453
- def _construct_ray_batch(self, images_info):
454
- image_num = images_info['images'].shape[0]
455
- _, h, w, _ = images_info['images'].shape
456
- coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
457
- coords = coords.float()[None, :, :, :].repeat(image_num, 1, 1, 1) # imn,h,w,2
458
- coords = coords.reshape(image_num, h * w, 2)
459
- coords = torch.cat([coords, torch.ones(image_num, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
460
-
461
- # imn,h*w,3 @ imn,3,3 => imn,h*w,3
462
- rays_d = coords @ torch.inverse(images_info['Ks']).permute(0, 2, 1)
463
- poses = images_info['poses'] # imn,3,4
464
- R, t = poses[:, :, :3], poses[:, :, 3:]
465
- rays_d = rays_d @ R
466
- rays_d = F.normalize(rays_d, dim=-1)
467
- rays_o = -R.permute(0,2,1) @ t # imn,3,3 @ imn,3,1
468
- rays_o = rays_o.permute(0, 2, 1).repeat(1, h*w, 1) # imn,h*w,3
469
-
470
- ray_batch = {
471
- 'rgb': images_info['images'].reshape(image_num*h*w,3),
472
- 'mask': images_info['masks'].reshape(image_num*h*w,1),
473
- 'rays_o': rays_o.reshape(image_num*h*w,3).float(),
474
- 'rays_d': rays_d.reshape(image_num*h*w,3).float(),
475
- }
476
- return ray_batch
477
-
478
- @staticmethod
479
- def load_model(cfg, ckpt):
480
- config = OmegaConf.load(cfg)
481
- model = instantiate_from_config(config.model)
482
- print(f'loading model from {ckpt} ...')
483
- ckpt = torch.load(ckpt)
484
- model.load_state_dict(ckpt['state_dict'])
485
- model = model.cuda().eval()
486
- return model
487
-
488
- def _init_dataset(self):
489
- mask_predictor = BackgroundRemoval()
490
- # syncdreamer fixed 16 views
491
- # self.K, self.azs, self.els, self.dists, self.poses = read_pickle(f'meta_info/camera-{self.num_images}.pkl')
492
- # for ours+NeuS, we pre-define 36 fixed views
493
- self.K = np.array([[280.,0.,128.],[0.,280.,128.],[0.,0.,1.]], dtype=np.float32)
494
- data_dir = os.path.join(self.data_path, "mario/render_sync_36_single/model/") # fixed 36 views
495
- # get all files .npy
496
- self.azs = []
497
- self.els = []
498
- self.dists = []
499
- self.poses = []
500
- for index in range(self.num_images):
501
- pose = np.load(os.path.join(data_dir, "%03d.npy"%index))[:3, :] # in blender
502
- self.poses.append(pose)
503
- theta, azimuth, radius = get_pose(pose)
504
- self.azs.append(azimuth)
505
- self.els.append(theta)
506
- self.dists.append(radius)
507
- # stack to numpy along axis 0
508
- self.azs = np.stack(self.azs, axis=0) # [num_images,]
509
- self.els = np.stack(self.els, axis=0) # [num_images,]
510
- self.dists = np.stack(self.dists, axis=0) # [num_images,]
511
- self.poses = np.stack(self.poses, axis=0) # [num_images, 3, 4]
512
-
513
- self.images_info = {'images': [] ,'masks': [], 'Ks': [], 'poses':[]}
514
-
515
- img = imread(self.image_path)
516
-
517
- for index in range(self.num_images):
518
- rgb = np.copy(img[:,index*self.image_size:(index+1)*self.image_size,:])
519
- # predict mask
520
- if self.use_mask:
521
- imsave(f'{self.log_dir}/input-{index}.png', rgb)
522
- masked_image = mask_predictor(rgb)
523
- imsave(f'{self.log_dir}/masked-{index}.png', masked_image)
524
- mask = masked_image[:,:,3].astype(np.float32)/255
525
- else:
526
- h, w, _ = rgb.shape
527
- mask = np.zeros([h,w], np.float32)
528
-
529
- rgb = rgb.astype(np.float32)/255
530
- K, pose = np.copy(self.K), self.poses[index]
531
- self.images_info['images'].append(torch.from_numpy(rgb.astype(np.float32))) # h,w,3
532
- self.images_info['masks'].append(torch.from_numpy(mask.astype(np.float32))) # h,w
533
- self.images_info['Ks'].append(torch.from_numpy(K.astype(np.float32)))
534
- self.images_info['poses'].append(torch.from_numpy(pose.astype(np.float32)))
535
-
536
- for k, v in self.images_info.items(): self.images_info[k] = torch.stack(v, 0) # stack all values
537
-
538
- self.train_batch = self._construct_ray_batch(self.images_info)
539
- self.train_batch_pseudo_fg = {}
540
- pseudo_fg_mask = torch.sum(self.train_batch['rgb']>0.99,1)!=3
541
- for k, v in self.train_batch.items():
542
- self.train_batch_pseudo_fg[k] = v[pseudo_fg_mask]
543
- self.train_ray_fg_num = int(torch.sum(pseudo_fg_mask).cpu().numpy())
544
- self.train_ray_num = self.num_images * self.image_size ** 2
545
- self._shuffle_train_batch()
546
- self._shuffle_train_fg_batch()
547
-
548
- def _shuffle_train_batch(self):
549
- self.train_batch_i = 0
550
- shuffle_idxs = torch.randperm(self.train_ray_num, device='cpu') # shuffle
551
- for k, v in self.train_batch.items():
552
- self.train_batch[k] = v[shuffle_idxs]
553
-
554
- def _shuffle_train_fg_batch(self):
555
- self.train_batch_fg_i = 0
556
- shuffle_idxs = torch.randperm(self.train_ray_fg_num, device='cpu') # shuffle
557
- for k, v in self.train_batch_pseudo_fg.items():
558
- self.train_batch_pseudo_fg[k] = v[shuffle_idxs]
559
-
560
-
561
- def training_step(self, batch, batch_idx):
562
- train_ray_batch = {k: v[self.train_batch_i:self.train_batch_i + self.train_batch_num].cuda() for k, v in self.train_batch.items()}
563
- self.train_batch_i += self.train_batch_num
564
- if self.train_batch_i + self.train_batch_num >= self.train_ray_num: self._shuffle_train_batch()
565
-
566
- if self.train_batch_fg_num>0:
567
- train_ray_batch_fg = {k: v[self.train_batch_fg_i:self.train_batch_fg_i+self.train_batch_fg_num].cuda() for k, v in self.train_batch_pseudo_fg.items()}
568
- self.train_batch_fg_i += self.train_batch_fg_num
569
- if self.train_batch_fg_i + self.train_batch_fg_num >= self.train_ray_fg_num: self._shuffle_train_fg_batch()
570
- for k, v in train_ray_batch_fg.items():
571
- train_ray_batch[k] = torch.cat([train_ray_batch[k], v], 0)
572
-
573
- loss, loss_batch = self.renderer.render_with_loss(train_ray_batch, is_train=True, step=self.global_step)
574
- self.log_dict(loss_batch, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
575
-
576
- self.log('step', self.global_step, prog_bar=True, on_step=True, on_epoch=False, logger=False, rank_zero_only=True)
577
- lr = self.optimizers().param_groups[0]['lr']
578
- self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
579
- return loss
580
-
581
- def _slice_images_info(self, index):
582
- return {k:v[index:index+1] for k, v in self.images_info.items()}
583
-
584
- @torch.no_grad()
585
- def validation_step(self, batch, batch_idx):
586
- with torch.no_grad():
587
- if self.global_rank==0:
588
- # we output a rendered image
589
- images_info = self._slice_images_info(self.validation_index)
590
- self.validation_index += 1
591
- self.validation_index %= self.num_images
592
-
593
- test_ray_batch = self._construct_ray_batch(images_info)
594
- test_ray_batch = {k: v.cuda() for k,v in test_ray_batch.items()}
595
- test_ray_batch['near'], test_ray_batch['far'] = near_far_from_sphere(test_ray_batch['rays_o'], test_ray_batch['rays_d'])
596
- render_outputs = self.renderer.render(test_ray_batch, False, self.global_step)
597
-
598
- process = lambda x: (x.cpu().numpy() * 255).astype(np.uint8)
599
- h, w = self.image_size, self.image_size
600
- rgb = torch.clamp(render_outputs['rgb'].reshape(h, w, 3), max=1.0, min=0.0)
601
- mask = torch.clamp(render_outputs['mask'].reshape(h, w, 1), max=1.0, min=0.0)
602
- mask_ = torch.repeat_interleave(mask, 3, dim=-1)
603
- output_image = concat_images_list(process(rgb), process(mask_))
604
- if 'normal' in render_outputs:
605
- normal = torch.clamp((render_outputs['normal'].reshape(h, w, 3) + 1) / 2, max=1.0, min=0.0)
606
- normal = normal * mask # we only show the foreground normal
607
- output_image = concat_images_list(output_image, process(normal))
608
-
609
- # save images
610
- imsave(f'{self.log_dir}/images/{self.global_step}.jpg', output_image)
611
-
612
- def configure_optimizers(self):
613
- lr = self.learning_rate
614
- opt = torch.optim.AdamW([{"params": self.renderer.parameters(), "lr": lr},], lr=lr)
615
-
616
- def schedule_fn(step):
617
- total_step = self.total_steps
618
- warm_up_step = self.warm_up_steps
619
- warm_up_init = 0.02
620
- warm_up_end = 1.0
621
- final_lr = 0.02
622
- interval = 1000
623
- times = total_step // interval
624
- ratio = np.power(final_lr, 1/times)
625
- if step<warm_up_step:
626
- learning_rate = (step / warm_up_step) * (warm_up_end - warm_up_init) + warm_up_init
627
- else:
628
- learning_rate = ratio ** (step // interval) * warm_up_end
629
- return learning_rate
630
-
631
- if self.use_warm_up:
632
- scheduler = [{
633
- 'scheduler': LambdaLR(opt, lr_lambda=schedule_fn),
634
- 'interval': 'step',
635
- 'frequency': 1
636
- }]
637
- else:
638
- scheduler = []
639
- return [opt], scheduler
640
-
 
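
For reference, a standalone sketch of the LambdaLR multiplier built in configure_optimizers above; total_steps=2000 and warm_up_steps=100 are placeholder values (the real ones come from the NeuS config). The multiplier warms up linearly from 0.02 to 1.0, then decays in steps of 1000 iterations, reaching roughly 0.02 of the base learning rate by the end of training.

import numpy as np

def schedule_fn(step, total_steps=2000, warm_up_steps=100):
    warm_up_init, warm_up_end, final_lr, interval = 0.02, 1.0, 0.02, 1000
    ratio = np.power(final_lr, 1 / (total_steps // interval))  # per-interval decay factor
    if step < warm_up_steps:
        return (step / warm_up_steps) * (warm_up_end - warm_up_init) + warm_up_init
    return ratio ** (step // interval) * warm_up_end

for s in (0, 50, 100, 999, 1000, 1999):
    print(s, round(schedule_fn(s), 4))  # 0.02, 0.51, 1.0, 1.0, 0.1414, 0.1414
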
3drecon/run_NeuS.py DELETED
@@ -1,32 +0,0 @@
1
- import os
2
- import numpy as np
3
- from tqdm import tqdm
4
-
5
- # ours + NeuS
6
- DATA_DIR = "/home/xin/data/EscherNet/Data/GSO30" # GSO
7
- exp_dir = "/home/xin/6DoF/GSO3D/"
8
-
9
- config = "configs/neus_36.yaml"
10
- exps = [1]
11
- # exps = [1, 2, 3, 5, 10]
12
-
13
- for exp in exps:
14
- OUTPUT_DIR = os.path.join(exp_dir, f"logs_GSO_T{exp}M36_99k")
15
- output_NeuS = f"ours_GSO_T{exp}"
16
- os.makedirs(output_NeuS, exist_ok=True)
17
- obj_names = os.listdir(DATA_DIR)
18
- for obj_name in tqdm(obj_names):
19
- if os.path.exists(os.path.join(output_NeuS, "NeuS", obj_name, "mesh.ply")):
20
- print("NeuS already trained for: ", obj_name)
21
- continue
22
- # remove the folder for new training
23
- os.system(f"rm -rf {output_NeuS}/NeuS/{obj_name}")
24
- print("Training NeuS for: ", obj_name)
25
- input_img = os.path.join(OUTPUT_DIR, obj_name, "0.png")
26
- # input_img = os.path.join(OUTPUT_DIR, obj_name, "gt.png") # ground truth image
27
- cmd = f"python train_renderer.py -i {input_img} \
28
- -d {DATA_DIR} \
29
- -n {obj_name} \
30
- -b {config} \
31
- -l {output_NeuS}/NeuS"
32
- os.system(cmd)
 
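
The loop above only assembles and launches a shell command per object. A minimal sketch of what that command looks like for a single hypothetical object (the name "alarm_clock" and the paths are placeholders); the flags map onto the argparse options declared in 3drecon/train_renderer.py below.

import os

obj_name = "alarm_clock"                                   # placeholder object name
input_img = os.path.join("/path/to/logs_GSO_T1M36_99k", obj_name, "0.png")
cmd = (f"python train_renderer.py -i {input_img} "
       f"-d /path/to/GSO30 -n {obj_name} "
       f"-b configs/neus_36.yaml -l ours_GSO_T1/NeuS")
print(cmd)
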
3drecon/train_renderer.py DELETED
@@ -1,188 +0,0 @@
1
- import argparse
2
-
3
- import imageio
4
- import numpy as np
5
- import torch
6
- import torch.nn.functional as F
7
- from pathlib import Path
8
-
9
- import trimesh
10
- from omegaconf import OmegaConf
11
- from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
12
- from pytorch_lightning.loggers import TensorBoardLogger
13
- from pytorch_lightning import Trainer
14
- from skimage.io import imsave
15
- from tqdm import tqdm
16
-
17
- import mcubes
18
-
19
- from renderer.renderer import NeuSRenderer, DEFAULT_SIDE_LENGTH
20
- from util import instantiate_from_config, read_pickle
21
-
22
- class ResumeCallBacks(Callback):
23
- def __init__(self):
24
- pass
25
-
26
- def on_train_start(self, trainer, pl_module):
27
- pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups
28
-
29
- def render_images(model, output,):
30
- # render from model
31
- n = 180
32
- azimuths = (np.arange(n) / n * np.pi * 2).astype(np.float32)
33
- elevations = np.deg2rad(np.asarray([30] * n).astype(np.float32))
34
- K, _, _, _, poses = read_pickle(f'meta_info/camera-16.pkl')
36
- h, w = 256, 256
37
- default_size = 256
38
- K = np.diag([w/default_size,h/default_size,1.0]) @ K
39
- imgs = []
40
- for ni in tqdm(range(n)):
41
- # R = euler2mat(azimuths[ni], elevations[ni], 0, 'szyx')
42
- # R = np.asarray([[0,-1,0],[0,0,-1],[1,0,0]]) @ R
43
- e, a = elevations[ni], azimuths[ni]
44
- row1 = np.asarray([np.sin(e)*np.cos(a),np.sin(e)*np.sin(a),-np.cos(e)])
45
- row0 = np.asarray([-np.sin(a),np.cos(a), 0])
46
- row2 = np.cross(row0, row1)
47
- R = np.stack([row0,row1,row2],0)
48
- t = np.asarray([0,0,1.5])
49
- pose = np.concatenate([R,t[:,None]],1)
50
- pose_ = torch.from_numpy(pose.astype(np.float32)).unsqueeze(0)
51
- K_ = torch.from_numpy(K.astype(np.float32)).unsqueeze(0) # [1,3,3]
52
-
53
- coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
54
- coords = coords.float()[None, :, :, :].repeat(1, 1, 1, 1) # imn,h,w,2
55
- coords = coords.reshape(1, h * w, 2)
56
- coords = torch.cat([coords, torch.ones(1, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
57
-
58
- # imn,h*w,3 @ imn,3,3 => imn,h*w,3
59
- rays_d = coords @ torch.inverse(K_).permute(0, 2, 1)
60
- R, t = pose_[:, :, :3], pose_[:, :, 3:]
61
- rays_d = rays_d @ R
62
- rays_d = F.normalize(rays_d, dim=-1)
63
- rays_o = -R.permute(0, 2, 1) @ t # imn,3,3 @ imn,3,1
64
- rays_o = rays_o.permute(0, 2, 1).repeat(1, h * w, 1) # imn,h*w,3
65
-
66
- ray_batch = {
67
- 'rays_o': rays_o.reshape(-1,3).cuda(),
68
- 'rays_d': rays_d.reshape(-1,3).cuda(),
69
- }
70
- with torch.no_grad():
71
- image = model.renderer.render(ray_batch,False,5000)['rgb'].reshape(h,w,3)
72
- image = (image.cpu().numpy() * 255).astype(np.uint8)
73
- imgs.append(image)
74
-
75
- imageio.mimsave(f'{output}/rendering.mp4', imgs, fps=30)
76
-
77
- def extract_fields(bound_min, bound_max, resolution, query_func, batch_size=64, outside_val=1.0):
78
- N = batch_size
79
- X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
80
- Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
81
- Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
82
-
83
- u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
84
- with torch.no_grad():
85
- for xi, xs in enumerate(X):
86
- for yi, ys in enumerate(Y):
87
- for zi, zs in enumerate(Z):
88
- xx, yy, zz = torch.meshgrid(xs, ys, zs)
89
- pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda()
90
- val = query_func(pts).detach()
91
- outside_mask = torch.norm(pts,dim=-1)>=1.0
92
- val[outside_mask]=outside_val
93
- val = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
94
- u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
95
- return u
96
-
97
- def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, color_func, outside_val=1.0):
98
- u = extract_fields(bound_min, bound_max, resolution, query_func, outside_val=outside_val)
99
- vertices, triangles = mcubes.marching_cubes(u, threshold)
100
- b_max_np = bound_max.detach().cpu().numpy()
101
- b_min_np = bound_min.detach().cpu().numpy()
102
-
103
- vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
104
- vertex_colors = color_func(vertices)
105
- return vertices, triangles, vertex_colors
106
-
107
- def extract_mesh(model, output, resolution=512):
108
- if not isinstance(model.renderer, NeuSRenderer): return
109
- bbox_min = -torch.ones(3)*DEFAULT_SIDE_LENGTH
110
- bbox_max = torch.ones(3)*DEFAULT_SIDE_LENGTH
111
- with torch.no_grad():
112
- vertices, triangles, vertex_colors = extract_geometry(bbox_min, bbox_max, resolution, 0, lambda x: model.renderer.sdf_network.sdf(x), lambda x: model.renderer.get_vertex_colors(x))
113
-
114
- # output geometry
115
- mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
116
- mesh.export(str(f'{output}/mesh.ply'))
117
-
118
- def main():
119
- parser = argparse.ArgumentParser()
120
- parser.add_argument('-i', '--image_path', type=str, required=True)
121
- parser.add_argument('-n', '--name', type=str, required=True)
122
- parser.add_argument('-b', '--base', type=str, default='configs/neus.yaml')
123
- parser.add_argument('-d', '--data_path', type=str, default='/data/GSO/')
124
- parser.add_argument('-l', '--log', type=str, default='output/renderer')
125
- parser.add_argument('-s', '--seed', type=int, default=6033)
126
- parser.add_argument('-g', '--gpus', type=str, default='0,')
127
- parser.add_argument('-r', '--resume', action='store_true', default=False, dest='resume')
128
- parser.add_argument('--fp16', action='store_true', default=False, dest='fp16')
129
- opt = parser.parse_args()
130
- # seed_everything(opt.seed)
131
-
132
- # configs
133
- cfg = OmegaConf.load(opt.base)
134
- name = opt.name
135
- log_dir, ckpt_dir = Path(opt.log) / name, Path(opt.log) / name / 'ckpt'
136
- cfg.model.params['image_path'] = opt.image_path
137
- cfg.model.params['log_dir'] = log_dir
138
- cfg.model.params['data_path'] = opt.data_path
139
-
140
- # setup
141
- log_dir.mkdir(exist_ok=True, parents=True)
142
- ckpt_dir.mkdir(exist_ok=True, parents=True)
143
- trainer_config = cfg.trainer
144
- callback_config = cfg.callbacks
145
- model_config = cfg.model
146
- data_config = cfg.data
147
-
148
- data_config.params.seed = opt.seed
149
- data = instantiate_from_config(data_config)
150
- data.prepare_data()
151
- data.setup('fit')
152
-
153
- model = instantiate_from_config(model_config,)
154
- model.cpu()
155
- model.learning_rate = model_config.base_lr
156
-
157
- # logger
158
- logger = TensorBoardLogger(save_dir=log_dir, name='tensorboard_logs')
159
- callbacks=[]
160
- callbacks.append(LearningRateMonitor(logging_interval='step'))
161
- callbacks.append(ModelCheckpoint(dirpath=ckpt_dir, filename="{epoch:06}", verbose=True, save_last=True, every_n_train_steps=callback_config.save_interval))
162
-
163
- # trainer
164
- trainer_config.update({
165
- "accelerator": "cuda", "check_val_every_n_epoch": None,
166
- "benchmark": True, "num_sanity_val_steps": 0,
167
- "devices": 1, "gpus": opt.gpus,
168
- })
169
- if opt.fp16:
170
- trainer_config['precision']=16
171
-
172
- if opt.resume:
173
- callbacks.append(ResumeCallBacks())
174
- trainer_config['resume_from_checkpoint'] = str(ckpt_dir / 'last.ckpt')
175
- else:
176
- if (ckpt_dir / 'last.ckpt').exists():
177
- raise RuntimeError(f"checkpoint {ckpt_dir / 'last.ckpt'} existing ...")
178
- trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config, logger=logger, callbacks=callbacks)
179
-
180
- trainer.fit(model, data)
181
-
182
- model = model.cuda().eval()
183
-
184
- # render_images(model, log_dir)
185
- extract_mesh(model, log_dir)
186
-
187
- if __name__=="__main__":
188
- main()
 
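
To illustrate the mesh-extraction path above without a trained checkpoint, here is a sketch that runs the same marching-cubes step on an analytic sphere SDF; resolution=64 and the 0.4 sphere radius are arbitrary choices for the example.

import numpy as np
import torch
import mcubes

side = 0.6                      # DEFAULT_SIDE_LENGTH used for the bounding box above
resolution = 64
xs = torch.linspace(-side, side, resolution)
grid = torch.stack(torch.meshgrid(xs, xs, xs, indexing='ij'), dim=-1)   # [R, R, R, 3]
sdf = torch.linalg.norm(grid, dim=-1) - 0.4                             # sphere of radius 0.4
vertices, triangles = mcubes.marching_cubes(sdf.numpy(), 0.0)           # zero level set
vertices = vertices / (resolution - 1.0) * (2 * side) - side            # grid indices -> world
print(vertices.shape, triangles.shape)
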
3drecon/util.py DELETED
@@ -1,54 +0,0 @@
1
- import importlib
2
- import pickle
3
- import numpy as np
4
- import cv2
5
-
6
- def instantiate_from_config(config):
7
- if not "target" in config:
8
- if config == '__is_first_stage__':
9
- return None
10
- elif config == "__is_unconditional__":
11
- return None
12
- raise KeyError("Expected key `target` to instantiate.")
13
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
14
-
15
-
16
- def get_obj_from_str(string, reload=False):
17
- module, cls = string.rsplit(".", 1)
18
- if reload:
19
- module_imp = importlib.import_module(module)
20
- importlib.reload(module_imp)
21
- return getattr(importlib.import_module(module, package=None), cls)
22
-
23
- def read_pickle(pkl_path):
24
- with open(pkl_path, 'rb') as f:
25
- return pickle.load(f)
26
-
27
- def output_points(fn,pts,colors=None):
28
- with open(fn, 'w') as f:
29
- for pi, pt in enumerate(pts):
30
- f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
31
- if colors is not None:
32
- f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}')
33
- f.write('\n')
34
-
35
- def concat_images(img0,img1,vert=False):
36
- if not vert:
37
- h0,h1=img0.shape[0],img1.shape[0],
38
- if h0<h1: img0=cv2.copyMakeBorder(img0,0,h1-h0,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
39
- if h1<h0: img1=cv2.copyMakeBorder(img1,0,h0-h1,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
40
- img = np.concatenate([img0, img1], axis=1)
41
- else:
42
- w0,w1=img0.shape[1],img1.shape[1]
43
- if w0<w1: img0=cv2.copyMakeBorder(img0,0,0,0,w1-w0,borderType=cv2.BORDER_CONSTANT,value=0)
44
- if w1<w0: img1=cv2.copyMakeBorder(img1,0,0,0,w0-w1,borderType=cv2.BORDER_CONSTANT,value=0)
45
- img = np.concatenate([img0, img1], axis=0)
46
-
47
- return img
48
-
49
- def concat_images_list(*args,vert=False):
50
- if len(args)==1: return args[0]
51
- img_out=args[0]
52
- for img in args[1:]:
53
- img_out=concat_images(img_out,img,vert)
54
- return img_out
 
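
A short usage sketch for the image-tiling helpers above (the import path matches how the other scripts in this repo import util): two dummy images of different heights are zero-padded so they can be tiled horizontally.

import numpy as np
from util import concat_images_list

img_a = np.full((4, 3, 3), 255, dtype=np.uint8)   # 4x3 white image
img_b = np.full((2, 5, 3), 128, dtype=np.uint8)   # 2x5 grey image
tiled = concat_images_list(img_a, img_b)          # img_b is padded to height 4
print(tiled.shape)                                # (4, 8, 3)
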
4DoF/CN_encoder.py DELETED
@@ -1,36 +0,0 @@
1
- from transformers import ConvNextV2Model
2
- import torch
3
- from typing import Optional
4
- import einops
5
-
6
- class CN_encoder(ConvNextV2Model):
7
- def __init__(self, config):
8
- super().__init__(config)
9
-
10
- def forward(
11
- self,
12
- pixel_values: torch.FloatTensor = None,
13
- output_hidden_states: Optional[bool] = None,
14
- return_dict: Optional[bool] = None,
15
- ):
16
- output_hidden_states = (
17
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
18
- )
19
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
20
-
21
- if pixel_values is None:
22
- raise ValueError("You have to specify pixel_values")
23
-
24
- embedding_output = self.embeddings(pixel_values)
25
-
26
- encoder_outputs = self.encoder(
27
- embedding_output,
28
- output_hidden_states=output_hidden_states,
29
- return_dict=return_dict,
30
- )
31
-
32
- last_hidden_state = encoder_outputs[0]
33
- image_embeddings = einops.rearrange(last_hidden_state, 'b c h w -> b (h w) c')
34
- image_embeddings = self.layernorm(image_embeddings)
35
-
36
- return image_embeddings
 
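
A sketch of how the encoder above can be instantiated (the ConvNeXt-V2 checkpoint name and the import path are assumptions; the training script decides which weights are actually loaded): the forward pass returns per-patch embeddings of shape [B, H*W, C] suitable for cross-attention.

import torch
from CN_encoder import CN_encoder  # assumed import path for this sketch

encoder = CN_encoder.from_pretrained("facebook/convnextv2-tiny-22k-224")  # assumed checkpoint
pixel_values = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    image_embeddings = encoder(pixel_values)
print(image_embeddings.shape)   # e.g. torch.Size([2, 64, 768]) for 32x downsampling
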
4DoF/dataset.py DELETED
@@ -1,228 +0,0 @@
1
- import os
2
- import math
3
- from pathlib import Path
4
- import torch
5
- import torchvision
6
- from torch.utils.data import Dataset, DataLoader
7
- from torchvision import transforms
8
- from PIL import Image
9
- import numpy as np
10
- import webdataset as wds
11
- from torch.utils.data.distributed import DistributedSampler
12
- import matplotlib.pyplot as plt
13
- import sys
14
-
15
- class ObjaverseDataLoader():
16
- def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
17
- self.root_dir = root_dir
18
- self.batch_size = batch_size
19
- self.num_workers = num_workers
20
- self.total_view = total_view
21
-
22
- image_transforms = [torchvision.transforms.Resize((256, 256)),
23
- transforms.ToTensor(),
24
- transforms.Normalize([0.5], [0.5])]
25
- self.image_transforms = torchvision.transforms.Compose(image_transforms)
26
-
27
- def train_dataloader(self):
28
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
29
- image_transforms=self.image_transforms)
30
- # sampler = DistributedSampler(dataset)
31
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
32
- # sampler=sampler)
33
-
34
- def val_dataloader(self):
35
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
36
- image_transforms=self.image_transforms)
37
- sampler = DistributedSampler(dataset)
38
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
39
-
40
- def cartesian_to_spherical(xyz):
41
- ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
42
- xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
43
- z = np.sqrt(xy + xyz[:, 2] ** 2)
44
- theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
45
- # ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
46
- azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
47
- return np.array([theta, azimuth, z])
48
-
49
- def get_pose(target_RT):
50
- target_RT = target_RT[:3, :]
51
- R, T = target_RT[:3, :3], target_RT[:, -1]
52
- T_target = -R.T @ T
53
- theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
54
- # assert if z_target is out of range
55
- if z_target.item() < 1.5 or z_target.item() > 2.2:
56
- # print('z_target out of range 1.5-2.2', z_target.item())
57
- z_target = np.clip(z_target.item(), 1.5, 2.2)
58
- # with log scale for radius
59
- target_T = torch.tensor([theta_target.item(), azimuth_target.item(), (np.log(z_target.item()) - np.log(1.5))/(np.log(2.2)-np.log(1.5)) * torch.pi, torch.tensor(0)])
60
- assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
61
- return target_T.numpy()
62
-
63
- class ObjaverseData(Dataset):
64
- def __init__(self,
65
- root_dir='.objaverse/hf-objaverse-v1/views',
66
- image_transforms=None,
67
- total_view=12,
68
- validation=False,
69
- T_in=1,
70
- T_out=1,
71
- fix_sample=False,
72
- ) -> None:
73
- """Create a dataset from a folder of images.
74
- If you pass in a root directory it will be searched for images
75
- ending in ext (ext can be a list)
76
- """
77
- self.root_dir = Path(root_dir)
78
- self.total_view = total_view
79
- self.T_in = T_in
80
- self.T_out = T_out
81
- self.fix_sample = fix_sample
82
-
83
- self.paths = []
84
- # # include all folders
85
- # for folder in os.listdir(self.root_dir):
86
- # if os.path.isdir(os.path.join(self.root_dir, folder)):
87
- # self.paths.append(folder)
88
- # load ids from .npy so we have exactly the same ids/order
89
- self.paths = np.load("../scripts/obj_ids.npy")
90
- # # only use 100K objects for ablation study
91
- # self.paths = self.paths[:100000]
92
- total_objects = len(self.paths)
93
- assert total_objects == 790152, 'total objects %d' % total_objects
94
- if validation:
95
- self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
96
- else:
97
- self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
98
- print('============= length of dataset %d =============' % len(self.paths))
99
- self.tform = image_transforms
100
-
101
- downscale = 512 / 256.
102
- self.fx = 560. / downscale
103
- self.fy = 560. / downscale
104
- self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
105
-
106
- def __len__(self):
107
- return len(self.paths)
108
-
109
- def cartesian_to_spherical(self, xyz):
110
- ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
111
- xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
112
- z = np.sqrt(xy + xyz[:, 2] ** 2)
113
- theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
114
- # ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
115
- azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
116
- return np.array([theta, azimuth, z])
117
-
118
- def get_T(self, target_RT, cond_RT):
119
- R, T = target_RT[:3, :3], target_RT[:, -1]
120
- T_target = -R.T @ T
121
-
122
- R, T = cond_RT[:3, :3], cond_RT[:, -1]
123
- T_cond = -R.T @ T
124
-
125
- theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
126
- theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
127
-
128
- d_theta = theta_target - theta_cond
129
- d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
130
- d_z = z_target - z_cond
131
-
132
- d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
133
- return d_T
134
-
135
- def get_pose(self, target_RT):
136
- R, T = target_RT[:3, :3], target_RT[:, -1]
137
- T_target = -R.T @ T
138
- theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
139
- # assert if z_target is out of range
140
- if z_target.item() < 1.5 or z_target.item() > 2.2:
141
- # print('z_target out of range 1.5-2.2', z_target.item())
142
- z_target = np.clip(z_target.item(), 1.5, 2.2)
143
- # with log scale for radius
144
- target_T = torch.tensor([theta_target.item(), azimuth_target.item(), (np.log(z_target.item()) - np.log(1.5))/(np.log(2.2)-np.log(1.5)) * torch.pi, torch.tensor(0)])
145
- assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
146
- return target_T
147
-
148
- def load_im(self, path, color):
149
- '''
150
- replace background pixel with random color in rendering
151
- '''
152
- try:
153
- img = plt.imread(path)
154
- except:
155
- print(path)
156
- sys.exit()
157
- img[img[:, :, -1] == 0.] = color
158
- img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
159
- return img
160
-
161
- def __getitem__(self, index):
162
- data = {}
163
- total_view = 12
164
-
165
- if self.fix_sample:
166
- if self.T_out > 1:
167
- indexes = range(total_view)
168
- index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
169
- index_inputs = indexes[1:self.T_in+1] # one overlap identity
170
- else:
171
- indexes = range(total_view)
172
- index_targets = indexes[:self.T_out]
173
- index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1] # one overlap identity
174
- else:
175
- assert self.T_in + self.T_out <= total_view
176
- # training with replace, including identity
177
- indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
178
- index_inputs = indexes[:self.T_in]
179
- index_targets = indexes[self.T_in:]
180
- filename = os.path.join(self.root_dir, self.paths[index])
181
-
182
- color = [1., 1., 1., 1.]
183
-
184
- try:
185
- input_ims = []
186
- target_ims = []
187
- target_Ts = []
188
- cond_Ts = []
189
- for i, index_input in enumerate(index_inputs):
190
- input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
191
- input_ims.append(input_im)
192
- input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
193
- cond_Ts.append(self.get_pose(input_RT))
194
- for i, index_target in enumerate(index_targets):
195
- target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
196
- target_ims.append(target_im)
197
- target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
198
- target_Ts.append(self.get_pose(target_RT))
199
- except:
200
- print('error loading data ', filename)
201
- filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8') # this one we know is valid
202
- input_ims = []
203
- target_ims = []
204
- target_Ts = []
205
- cond_Ts = []
206
- # very hacky solution, sorry about this
207
- for i, index_input in enumerate(index_inputs):
208
- input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
209
- input_ims.append(input_im)
210
- input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
211
- cond_Ts.append(self.get_pose(input_RT))
212
- for i, index_target in enumerate(index_targets):
213
- target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
214
- target_ims.append(target_im)
215
- target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
216
- target_Ts.append(self.get_pose(target_RT))
217
-
218
- # stack to batch
219
- data['image_input'] = torch.stack(input_ims, dim=0)
220
- data['image_target'] = torch.stack(target_ims, dim=0)
221
- data['pose_out'] = torch.stack(target_Ts, dim=0)
222
- data['pose_in'] = torch.stack(cond_Ts, dim=0)
223
-
224
- return data
225
-
226
- def process_im(self, im):
227
- im = im.convert("RGB")
228
- return self.tform(im)
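
Note on the deleted `get_pose` above: the camera radius is clamped to the 1.5 to 2.2 range used for rendering and then mapped onto [0, pi] on a log scale. A minimal standalone sketch of just that mapping (the bounds come from the code above; the function name is illustrative):

```python
import numpy as np
import torch

Z_MIN, Z_MAX = 1.5, 2.2  # camera-distance range assumed by the dataset code above

def normalize_radius(z: float) -> torch.Tensor:
    """Clamp z to [Z_MIN, Z_MAX] and map it onto [0, pi] on a log scale."""
    z = float(np.clip(z, Z_MIN, Z_MAX))
    scale = (np.log(z) - np.log(Z_MIN)) / (np.log(Z_MAX) - np.log(Z_MIN))
    return torch.tensor(scale * np.pi)

# normalize_radius(1.5) -> tensor(0.), normalize_radius(2.2) -> tensor(pi)
```
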
4DoF/diffusers/__init__.py DELETED
@@ -1,281 +0,0 @@
1
- __version__ = "0.18.2"
2
-
3
- from .configuration_utils import ConfigMixin
4
- from .utils import (
5
- OptionalDependencyNotAvailable,
6
- is_flax_available,
7
- is_inflect_available,
8
- is_invisible_watermark_available,
9
- is_k_diffusion_available,
10
- is_k_diffusion_version,
11
- is_librosa_available,
12
- is_note_seq_available,
13
- is_onnx_available,
14
- is_scipy_available,
15
- is_torch_available,
16
- is_torchsde_available,
17
- is_transformers_available,
18
- is_transformers_version,
19
- is_unidecode_available,
20
- logging,
21
- )
22
-
23
-
24
- try:
25
- if not is_onnx_available():
26
- raise OptionalDependencyNotAvailable()
27
- except OptionalDependencyNotAvailable:
28
- from .utils.dummy_onnx_objects import * # noqa F403
29
- else:
30
- from .pipelines import OnnxRuntimeModel
31
-
32
- try:
33
- if not is_torch_available():
34
- raise OptionalDependencyNotAvailable()
35
- except OptionalDependencyNotAvailable:
36
- from .utils.dummy_pt_objects import * # noqa F403
37
- else:
38
- from .models import (
39
- AutoencoderKL,
40
- ControlNetModel,
41
- ModelMixin,
42
- PriorTransformer,
43
- T5FilmDecoder,
44
- Transformer2DModel,
45
- UNet1DModel,
46
- UNet2DConditionModel,
47
- UNet2DModel,
48
- UNet3DConditionModel,
49
- VQModel,
50
- )
51
- from .optimization import (
52
- get_constant_schedule,
53
- get_constant_schedule_with_warmup,
54
- get_cosine_schedule_with_warmup,
55
- get_cosine_with_hard_restarts_schedule_with_warmup,
56
- get_linear_schedule_with_warmup,
57
- get_polynomial_decay_schedule_with_warmup,
58
- get_scheduler,
59
- )
60
- from .pipelines import (
61
- AudioPipelineOutput,
62
- ConsistencyModelPipeline,
63
- DanceDiffusionPipeline,
64
- DDIMPipeline,
65
- DDPMPipeline,
66
- DiffusionPipeline,
67
- DiTPipeline,
68
- ImagePipelineOutput,
69
- KarrasVePipeline,
70
- LDMPipeline,
71
- LDMSuperResolutionPipeline,
72
- PNDMPipeline,
73
- RePaintPipeline,
74
- ScoreSdeVePipeline,
75
- )
76
- from .schedulers import (
77
- CMStochasticIterativeScheduler,
78
- DDIMInverseScheduler,
79
- DDIMParallelScheduler,
80
- DDIMScheduler,
81
- DDPMParallelScheduler,
82
- DDPMScheduler,
83
- DEISMultistepScheduler,
84
- DPMSolverMultistepInverseScheduler,
85
- DPMSolverMultistepScheduler,
86
- DPMSolverSinglestepScheduler,
87
- EulerAncestralDiscreteScheduler,
88
- EulerDiscreteScheduler,
89
- HeunDiscreteScheduler,
90
- IPNDMScheduler,
91
- KarrasVeScheduler,
92
- KDPM2AncestralDiscreteScheduler,
93
- KDPM2DiscreteScheduler,
94
- PNDMScheduler,
95
- RePaintScheduler,
96
- SchedulerMixin,
97
- ScoreSdeVeScheduler,
98
- UnCLIPScheduler,
99
- UniPCMultistepScheduler,
100
- VQDiffusionScheduler,
101
- )
102
- from .training_utils import EMAModel
103
-
104
- try:
105
- if not (is_torch_available() and is_scipy_available()):
106
- raise OptionalDependencyNotAvailable()
107
- except OptionalDependencyNotAvailable:
108
- from .utils.dummy_torch_and_scipy_objects import * # noqa F403
109
- else:
110
- from .schedulers import LMSDiscreteScheduler
111
-
112
- try:
113
- if not (is_torch_available() and is_torchsde_available()):
114
- raise OptionalDependencyNotAvailable()
115
- except OptionalDependencyNotAvailable:
116
- from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
117
- else:
118
- from .schedulers import DPMSolverSDEScheduler
119
-
120
- try:
121
- if not (is_torch_available() and is_transformers_available()):
122
- raise OptionalDependencyNotAvailable()
123
- except OptionalDependencyNotAvailable:
124
- from .utils.dummy_torch_and_transformers_objects import * # noqa F403
125
- else:
126
- from .pipelines import (
127
- AltDiffusionImg2ImgPipeline,
128
- AltDiffusionPipeline,
129
- AudioLDMPipeline,
130
- CycleDiffusionPipeline,
131
- IFImg2ImgPipeline,
132
- IFImg2ImgSuperResolutionPipeline,
133
- IFInpaintingPipeline,
134
- IFInpaintingSuperResolutionPipeline,
135
- IFPipeline,
136
- IFSuperResolutionPipeline,
137
- ImageTextPipelineOutput,
138
- KandinskyImg2ImgPipeline,
139
- KandinskyInpaintPipeline,
140
- KandinskyPipeline,
141
- KandinskyPriorPipeline,
142
- KandinskyV22ControlnetImg2ImgPipeline,
143
- KandinskyV22ControlnetPipeline,
144
- KandinskyV22Img2ImgPipeline,
145
- KandinskyV22InpaintPipeline,
146
- KandinskyV22Pipeline,
147
- KandinskyV22PriorEmb2EmbPipeline,
148
- KandinskyV22PriorPipeline,
149
- LDMTextToImagePipeline,
150
- PaintByExamplePipeline,
151
- SemanticStableDiffusionPipeline,
152
- ShapEImg2ImgPipeline,
153
- ShapEPipeline,
154
- StableDiffusionAttendAndExcitePipeline,
155
- StableDiffusionControlNetImg2ImgPipeline,
156
- StableDiffusionControlNetInpaintPipeline,
157
- StableDiffusionControlNetPipeline,
158
- StableDiffusionDepth2ImgPipeline,
159
- StableDiffusionDiffEditPipeline,
160
- StableDiffusionImageVariationPipeline,
161
- StableDiffusionImg2ImgPipeline,
162
- StableDiffusionInpaintPipeline,
163
- StableDiffusionInpaintPipelineLegacy,
164
- StableDiffusionInstructPix2PixPipeline,
165
- StableDiffusionLatentUpscalePipeline,
166
- StableDiffusionLDM3DPipeline,
167
- StableDiffusionModelEditingPipeline,
168
- StableDiffusionPanoramaPipeline,
169
- StableDiffusionParadigmsPipeline,
170
- StableDiffusionPipeline,
171
- StableDiffusionPipelineSafe,
172
- StableDiffusionPix2PixZeroPipeline,
173
- StableDiffusionSAGPipeline,
174
- StableDiffusionUpscalePipeline,
175
- StableUnCLIPImg2ImgPipeline,
176
- StableUnCLIPPipeline,
177
- TextToVideoSDPipeline,
178
- TextToVideoZeroPipeline,
179
- UnCLIPImageVariationPipeline,
180
- UnCLIPPipeline,
181
- UniDiffuserModel,
182
- UniDiffuserPipeline,
183
- UniDiffuserTextDecoder,
184
- VersatileDiffusionDualGuidedPipeline,
185
- VersatileDiffusionImageVariationPipeline,
186
- VersatileDiffusionPipeline,
187
- VersatileDiffusionTextToImagePipeline,
188
- VideoToVideoSDPipeline,
189
- VQDiffusionPipeline,
190
- )
191
-
192
- try:
193
- if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
194
- raise OptionalDependencyNotAvailable()
195
- except OptionalDependencyNotAvailable:
196
- from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
197
- else:
198
- from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
199
-
200
- try:
201
- if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
202
- raise OptionalDependencyNotAvailable()
203
- except OptionalDependencyNotAvailable:
204
- from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
205
- else:
206
- from .pipelines import StableDiffusionKDiffusionPipeline
207
-
208
- try:
209
- if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
210
- raise OptionalDependencyNotAvailable()
211
- except OptionalDependencyNotAvailable:
212
- from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
213
- else:
214
- from .pipelines import (
215
- OnnxStableDiffusionImg2ImgPipeline,
216
- OnnxStableDiffusionInpaintPipeline,
217
- OnnxStableDiffusionInpaintPipelineLegacy,
218
- OnnxStableDiffusionPipeline,
219
- OnnxStableDiffusionUpscalePipeline,
220
- StableDiffusionOnnxPipeline,
221
- )
222
-
223
- try:
224
- if not (is_torch_available() and is_librosa_available()):
225
- raise OptionalDependencyNotAvailable()
226
- except OptionalDependencyNotAvailable:
227
- from .utils.dummy_torch_and_librosa_objects import * # noqa F403
228
- else:
229
- from .pipelines import AudioDiffusionPipeline, Mel
230
-
231
- try:
232
- if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
233
- raise OptionalDependencyNotAvailable()
234
- except OptionalDependencyNotAvailable:
235
- from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
236
- else:
237
- from .pipelines import SpectrogramDiffusionPipeline
238
-
239
- try:
240
- if not is_flax_available():
241
- raise OptionalDependencyNotAvailable()
242
- except OptionalDependencyNotAvailable:
243
- from .utils.dummy_flax_objects import * # noqa F403
244
- else:
245
- from .models.controlnet_flax import FlaxControlNetModel
246
- from .models.modeling_flax_utils import FlaxModelMixin
247
- from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
248
- from .models.vae_flax import FlaxAutoencoderKL
249
- from .pipelines import FlaxDiffusionPipeline
250
- from .schedulers import (
251
- FlaxDDIMScheduler,
252
- FlaxDDPMScheduler,
253
- FlaxDPMSolverMultistepScheduler,
254
- FlaxKarrasVeScheduler,
255
- FlaxLMSDiscreteScheduler,
256
- FlaxPNDMScheduler,
257
- FlaxSchedulerMixin,
258
- FlaxScoreSdeVeScheduler,
259
- )
260
-
261
-
262
- try:
263
- if not (is_flax_available() and is_transformers_available()):
264
- raise OptionalDependencyNotAvailable()
265
- except OptionalDependencyNotAvailable:
266
- from .utils.dummy_flax_and_transformers_objects import * # noqa F403
267
- else:
268
- from .pipelines import (
269
- FlaxStableDiffusionControlNetPipeline,
270
- FlaxStableDiffusionImg2ImgPipeline,
271
- FlaxStableDiffusionInpaintPipeline,
272
- FlaxStableDiffusionPipeline,
273
- )
274
-
275
- try:
276
- if not (is_note_seq_available()):
277
- raise OptionalDependencyNotAvailable()
278
- except OptionalDependencyNotAvailable:
279
- from .utils.dummy_note_seq_objects import * # noqa F403
280
- else:
281
- from .pipelines import MidiProcessor
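
Every import in the deleted `__init__.py` above follows the same guard: probe for an optional backend, fall back to dummy objects if it is missing, and only pull in the real classes otherwise. A self-contained sketch of that pattern against the public `diffusers` package (the scipy-gated `LMSDiscreteScheduler` is just one example of a guarded import; the availability helper here is a stand-in, not the library's own):

```python
import importlib.util

class OptionalDependencyNotAvailable(ImportError):
    """Raised when an optional backend needed by an import is missing."""

def is_scipy_available() -> bool:
    # stand-in for the library's own availability check
    return importlib.util.find_spec("scipy") is not None

try:
    if not is_scipy_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    LMSDiscreteScheduler = None  # the real package imports a dummy object that raises on use
else:
    from diffusers import LMSDiscreteScheduler  # real import only when scipy is installed
```
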
4DoF/diffusers/commands/__init__.py DELETED
@@ -1,27 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from abc import ABC, abstractmethod
16
- from argparse import ArgumentParser
17
-
18
-
19
- class BaseDiffusersCLICommand(ABC):
20
- @staticmethod
21
- @abstractmethod
22
- def register_subcommand(parser: ArgumentParser):
23
- raise NotImplementedError()
24
-
25
- @abstractmethod
26
- def run(self):
27
- raise NotImplementedError()
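
To show how this base class is meant to be extended, here is a hypothetical command (not part of the library) following the same `register_subcommand` / `run` contract used by `EnvironmentCommand` further below; `parser` is the sub-parsers object created in `diffusers_cli.main()`:

```python
from argparse import ArgumentParser

from diffusers.commands import BaseDiffusersCLICommand

class HelloCommand(BaseDiffusersCLICommand):  # hypothetical example command
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello")
        hello_parser.add_argument("--name", default="world")
        # main() later calls args.func(args) and then .run() on the result
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self.name = name

    def run(self):
        print(f"Hello, {self.name}!")
```
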
4DoF/diffusers/commands/diffusers_cli.py DELETED
@@ -1,41 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright 2023 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from argparse import ArgumentParser
17
-
18
- from .env import EnvironmentCommand
19
-
20
-
21
- def main():
22
- parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
23
- commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24
-
25
- # Register commands
26
- EnvironmentCommand.register_subcommand(commands_parser)
27
-
28
- # Let's go
29
- args = parser.parse_args()
30
-
31
- if not hasattr(args, "func"):
32
- parser.print_help()
33
- exit(1)
34
-
35
- # Run
36
- service = args.func(args)
37
- service.run()
38
-
39
-
40
- if __name__ == "__main__":
41
- main()
4DoF/diffusers/commands/env.py DELETED
@@ -1,84 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import platform
16
- from argparse import ArgumentParser
17
-
18
- import huggingface_hub
19
-
20
- from .. import __version__ as version
21
- from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
- from . import BaseDiffusersCLICommand
23
-
24
-
25
- def info_command_factory(_):
26
- return EnvironmentCommand()
27
-
28
-
29
- class EnvironmentCommand(BaseDiffusersCLICommand):
30
- @staticmethod
31
- def register_subcommand(parser: ArgumentParser):
32
- download_parser = parser.add_parser("env")
33
- download_parser.set_defaults(func=info_command_factory)
34
-
35
- def run(self):
36
- hub_version = huggingface_hub.__version__
37
-
38
- pt_version = "not installed"
39
- pt_cuda_available = "NA"
40
- if is_torch_available():
41
- import torch
42
-
43
- pt_version = torch.__version__
44
- pt_cuda_available = torch.cuda.is_available()
45
-
46
- transformers_version = "not installed"
47
- if is_transformers_available():
48
- import transformers
49
-
50
- transformers_version = transformers.__version__
51
-
52
- accelerate_version = "not installed"
53
- if is_accelerate_available():
54
- import accelerate
55
-
56
- accelerate_version = accelerate.__version__
57
-
58
- xformers_version = "not installed"
59
- if is_xformers_available():
60
- import xformers
61
-
62
- xformers_version = xformers.__version__
63
-
64
- info = {
65
- "`diffusers` version": version,
66
- "Platform": platform.platform(),
67
- "Python version": platform.python_version(),
68
- "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69
- "Huggingface_hub version": hub_version,
70
- "Transformers version": transformers_version,
71
- "Accelerate version": accelerate_version,
72
- "xFormers version": xformers_version,
73
- "Using GPU in script?": "<fill in>",
74
- "Using distributed or parallel set-up in script?": "<fill in>",
75
- }
76
-
77
- print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78
- print(self.format_dict(info))
79
-
80
- return info
81
-
82
- @staticmethod
83
- def format_dict(d):
84
- return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
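
As a usage note, this command is what backs `diffusers-cli env`; the same report can also be produced programmatically (a sketch, assuming the standard `diffusers` package layout rather than this vendored copy):

```python
from diffusers.commands.env import EnvironmentCommand

info = EnvironmentCommand().run()  # prints the formatted report and returns it as a dict
print(sorted(info))                # report keys such as "Platform", "Python version", ...
```
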
4DoF/diffusers/configuration_utils.py DELETED
@@ -1,664 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """ ConfigMixin base class and utilities."""
17
- import dataclasses
18
- import functools
19
- import importlib
20
- import inspect
21
- import json
22
- import os
23
- import re
24
- from collections import OrderedDict
25
- from pathlib import PosixPath
26
- from typing import Any, Dict, Tuple, Union
27
-
28
- import numpy as np
29
- from huggingface_hub import hf_hub_download
30
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
- from requests import HTTPError
32
-
33
- from . import __version__
34
- from .utils import (
35
- DIFFUSERS_CACHE,
36
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
- DummyObject,
38
- deprecate,
39
- extract_commit_hash,
40
- http_user_agent,
41
- logging,
42
- )
43
-
44
-
45
- logger = logging.get_logger(__name__)
46
-
47
- _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
-
49
-
50
- class FrozenDict(OrderedDict):
51
- def __init__(self, *args, **kwargs):
52
- super().__init__(*args, **kwargs)
53
-
54
- for key, value in self.items():
55
- setattr(self, key, value)
56
-
57
- self.__frozen = True
58
-
59
- def __delitem__(self, *args, **kwargs):
60
- raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
-
62
- def setdefault(self, *args, **kwargs):
63
- raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
-
65
- def pop(self, *args, **kwargs):
66
- raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
-
68
- def update(self, *args, **kwargs):
69
- raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
-
71
- def __setattr__(self, name, value):
72
- if hasattr(self, "__frozen") and self.__frozen:
73
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
- super().__setattr__(name, value)
75
-
76
- def __setitem__(self, name, value):
77
- if hasattr(self, "__frozen") and self.__frozen:
78
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
- super().__setitem__(name, value)
80
-
81
-
82
- class ConfigMixin:
83
- r"""
84
- Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
- provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
- saving classes that inherit from [`ConfigMixin`].
87
-
88
- Class attributes:
89
- - **config_name** (`str`) -- A filename under which the config should stored when calling
90
- [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
- - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
- overridden by subclass).
93
- - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
- - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
- should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
- subclass).
97
- """
98
- config_name = None
99
- ignore_for_config = []
100
- has_compatibles = False
101
-
102
- _deprecated_kwargs = []
103
-
104
- def register_to_config(self, **kwargs):
105
- if self.config_name is None:
106
- raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
- # Special case for `kwargs` used in deprecation warning added to schedulers
108
- # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
- # or solve in a more general way.
110
- kwargs.pop("kwargs", None)
111
-
112
- if not hasattr(self, "_internal_dict"):
113
- internal_dict = kwargs
114
- else:
115
- previous_dict = dict(self._internal_dict)
116
- internal_dict = {**self._internal_dict, **kwargs}
117
- logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
-
119
- self._internal_dict = FrozenDict(internal_dict)
120
-
121
- def __getattr__(self, name: str) -> Any:
122
- """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
- config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
-
125
-        This function is mostly copied from PyTorch's __getattr__ overwrite:
126
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
- """
128
-
129
- is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
- is_attribute = name in self.__dict__
131
-
132
- if is_in_config and not is_attribute:
133
- deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
- deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
- return self._internal_dict[name]
136
-
137
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
-
139
- def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
- """
141
- Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
- [`~ConfigMixin.from_config`] class method.
143
-
144
- Args:
145
- save_directory (`str` or `os.PathLike`):
146
- Directory where the configuration JSON file is saved (will be created if it does not exist).
147
- """
148
- if os.path.isfile(save_directory):
149
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
150
-
151
- os.makedirs(save_directory, exist_ok=True)
152
-
153
- # If we save using the predefined names, we can load using `from_config`
154
- output_config_file = os.path.join(save_directory, self.config_name)
155
-
156
- self.to_json_file(output_config_file)
157
- logger.info(f"Configuration saved in {output_config_file}")
158
-
159
- @classmethod
160
- def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
161
- r"""
162
- Instantiate a Python class from a config dictionary.
163
-
164
- Parameters:
165
- config (`Dict[str, Any]`):
166
- A config dictionary from which the Python class is instantiated. Make sure to only load configuration
167
- files of compatible classes.
168
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
169
- Whether kwargs that are not consumed by the Python class should be returned or not.
170
- kwargs (remaining dictionary of keyword arguments, *optional*):
171
- Can be used to update the configuration object (after it is loaded) and initiate the Python class.
172
- `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
173
- overwrite the same named arguments in `config`.
174
-
175
- Returns:
176
- [`ModelMixin`] or [`SchedulerMixin`]:
177
- A model or scheduler object instantiated from a config dictionary.
178
-
179
- Examples:
180
-
181
- ```python
182
- >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
183
-
184
- >>> # Download scheduler from huggingface.co and cache.
185
- >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
186
-
187
- >>> # Instantiate DDIM scheduler class with same config as DDPM
188
- >>> scheduler = DDIMScheduler.from_config(scheduler.config)
189
-
190
- >>> # Instantiate PNDM scheduler class with same config as DDPM
191
- >>> scheduler = PNDMScheduler.from_config(scheduler.config)
192
- ```
193
- """
194
- # <===== TO BE REMOVED WITH DEPRECATION
195
- # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
196
- if "pretrained_model_name_or_path" in kwargs:
197
- config = kwargs.pop("pretrained_model_name_or_path")
198
-
199
- if config is None:
200
- raise ValueError("Please make sure to provide a config as the first positional argument.")
201
- # ======>
202
-
203
- if not isinstance(config, dict):
204
- deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
205
- if "Scheduler" in cls.__name__:
206
- deprecation_message += (
207
- f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
208
- " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
209
- " be removed in v1.0.0."
210
- )
211
- elif "Model" in cls.__name__:
212
- deprecation_message += (
213
- f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
214
- f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
215
- " instead. This functionality will be removed in v1.0.0."
216
- )
217
- deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
218
- config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
219
-
220
- init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
221
-
222
- # Allow dtype to be specified on initialization
223
- if "dtype" in unused_kwargs:
224
- init_dict["dtype"] = unused_kwargs.pop("dtype")
225
-
226
- # add possible deprecated kwargs
227
- for deprecated_kwarg in cls._deprecated_kwargs:
228
- if deprecated_kwarg in unused_kwargs:
229
- init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
230
-
231
- # Return model and optionally state and/or unused_kwargs
232
- model = cls(**init_dict)
233
-
234
- # make sure to also save config parameters that might be used for compatible classes
235
- model.register_to_config(**hidden_dict)
236
-
237
- # add hidden kwargs of compatible classes to unused_kwargs
238
- unused_kwargs = {**unused_kwargs, **hidden_dict}
239
-
240
- if return_unused_kwargs:
241
- return (model, unused_kwargs)
242
- else:
243
- return model
244
-
245
- @classmethod
246
- def get_config_dict(cls, *args, **kwargs):
247
- deprecation_message = (
248
- f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
249
- " removed in version v1.0.0"
250
- )
251
- deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
252
- return cls.load_config(*args, **kwargs)
253
-
254
- @classmethod
255
- def load_config(
256
- cls,
257
- pretrained_model_name_or_path: Union[str, os.PathLike],
258
- return_unused_kwargs=False,
259
- return_commit_hash=False,
260
- **kwargs,
261
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
262
- r"""
263
- Load a model or scheduler configuration.
264
-
265
- Parameters:
266
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
267
- Can be either:
268
-
269
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
270
- the Hub.
271
- - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
272
- [`~ConfigMixin.save_config`].
273
-
274
- cache_dir (`Union[str, os.PathLike]`, *optional*):
275
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
276
- is not used.
277
- force_download (`bool`, *optional*, defaults to `False`):
278
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
279
- cached versions if they exist.
280
- resume_download (`bool`, *optional*, defaults to `False`):
281
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
282
- incompletely downloaded files are deleted.
283
- proxies (`Dict[str, str]`, *optional*):
284
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
285
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
286
- output_loading_info(`bool`, *optional*, defaults to `False`):
287
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
288
- local_files_only (`bool`, *optional*, defaults to `False`):
289
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
290
- won't be downloaded from the Hub.
291
- use_auth_token (`str` or *bool*, *optional*):
292
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
293
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
294
- revision (`str`, *optional*, defaults to `"main"`):
295
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
296
- allowed by Git.
297
- subfolder (`str`, *optional*, defaults to `""`):
298
- The subfolder location of a model file within a larger model repository on the Hub or locally.
299
-            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
300
- Whether unused keyword arguments of the config are returned.
301
-            return_commit_hash (`bool`, *optional*, defaults to `False`):
302
-                Whether the `commit_hash` of the loaded configuration is returned.
303
-
304
- Returns:
305
- `dict`:
306
- A dictionary of all the parameters stored in a JSON configuration file.
307
-
308
- """
309
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
310
- force_download = kwargs.pop("force_download", False)
311
- resume_download = kwargs.pop("resume_download", False)
312
- proxies = kwargs.pop("proxies", None)
313
- use_auth_token = kwargs.pop("use_auth_token", None)
314
- local_files_only = kwargs.pop("local_files_only", False)
315
- revision = kwargs.pop("revision", None)
316
- _ = kwargs.pop("mirror", None)
317
- subfolder = kwargs.pop("subfolder", None)
318
- user_agent = kwargs.pop("user_agent", {})
319
-
320
- user_agent = {**user_agent, "file_type": "config"}
321
- user_agent = http_user_agent(user_agent)
322
-
323
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
324
-
325
- if cls.config_name is None:
326
- raise ValueError(
327
- "`self.config_name` is not defined. Note that one should not load a config from "
328
- "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
329
- )
330
-
331
- if os.path.isfile(pretrained_model_name_or_path):
332
- config_file = pretrained_model_name_or_path
333
- elif os.path.isdir(pretrained_model_name_or_path):
334
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
335
- # Load from a PyTorch checkpoint
336
- config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
337
- elif subfolder is not None and os.path.isfile(
338
- os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
339
- ):
340
- config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
341
- else:
342
- raise EnvironmentError(
343
- f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
344
- )
345
- else:
346
- try:
347
- # Load from URL or cache if already cached
348
- config_file = hf_hub_download(
349
- pretrained_model_name_or_path,
350
- filename=cls.config_name,
351
- cache_dir=cache_dir,
352
- force_download=force_download,
353
- proxies=proxies,
354
- resume_download=resume_download,
355
- local_files_only=local_files_only,
356
- use_auth_token=use_auth_token,
357
- user_agent=user_agent,
358
- subfolder=subfolder,
359
- revision=revision,
360
- )
361
- except RepositoryNotFoundError:
362
- raise EnvironmentError(
363
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
364
- " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
365
- " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
366
- " login`."
367
- )
368
- except RevisionNotFoundError:
369
- raise EnvironmentError(
370
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
371
- " this model name. Check the model page at"
372
- f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
373
- )
374
- except EntryNotFoundError:
375
- raise EnvironmentError(
376
- f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
377
- )
378
- except HTTPError as err:
379
- raise EnvironmentError(
380
- "There was a specific connection error when trying to load"
381
- f" {pretrained_model_name_or_path}:\n{err}"
382
- )
383
- except ValueError:
384
- raise EnvironmentError(
385
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
386
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
387
- f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
388
- " run the library in offline mode at"
389
- " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
390
- )
391
- except EnvironmentError:
392
- raise EnvironmentError(
393
- f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
394
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
395
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
396
- f"containing a {cls.config_name} file"
397
- )
398
-
399
- try:
400
- # Load config dict
401
- config_dict = cls._dict_from_json_file(config_file)
402
-
403
- commit_hash = extract_commit_hash(config_file)
404
- except (json.JSONDecodeError, UnicodeDecodeError):
405
- raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
406
-
407
- if not (return_unused_kwargs or return_commit_hash):
408
- return config_dict
409
-
410
- outputs = (config_dict,)
411
-
412
- if return_unused_kwargs:
413
- outputs += (kwargs,)
414
-
415
- if return_commit_hash:
416
- outputs += (commit_hash,)
417
-
418
- return outputs
419
-
420
- @staticmethod
421
- def _get_init_keys(cls):
422
- return set(dict(inspect.signature(cls.__init__).parameters).keys())
423
-
424
- @classmethod
425
- def extract_init_dict(cls, config_dict, **kwargs):
426
- # Skip keys that were not present in the original config, so default __init__ values were used
427
- used_defaults = config_dict.get("_use_default_values", [])
428
- config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
429
-
430
- # 0. Copy origin config dict
431
- original_dict = dict(config_dict.items())
432
-
433
- # 1. Retrieve expected config attributes from __init__ signature
434
- expected_keys = cls._get_init_keys(cls)
435
- expected_keys.remove("self")
436
- # remove general kwargs if present in dict
437
- if "kwargs" in expected_keys:
438
- expected_keys.remove("kwargs")
439
- # remove flax internal keys
440
- if hasattr(cls, "_flax_internal_args"):
441
- for arg in cls._flax_internal_args:
442
- expected_keys.remove(arg)
443
-
444
- # 2. Remove attributes that cannot be expected from expected config attributes
445
- # remove keys to be ignored
446
- if len(cls.ignore_for_config) > 0:
447
- expected_keys = expected_keys - set(cls.ignore_for_config)
448
-
449
- # load diffusers library to import compatible and original scheduler
450
- diffusers_library = importlib.import_module(__name__.split(".")[0])
451
-
452
- if cls.has_compatibles:
453
- compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
454
- else:
455
- compatible_classes = []
456
-
457
- expected_keys_comp_cls = set()
458
- for c in compatible_classes:
459
- expected_keys_c = cls._get_init_keys(c)
460
- expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
461
- expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
462
- config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
463
-
464
- # remove attributes from orig class that cannot be expected
465
- orig_cls_name = config_dict.pop("_class_name", cls.__name__)
466
- if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
467
- orig_cls = getattr(diffusers_library, orig_cls_name)
468
- unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
469
- config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
470
-
471
- # remove private attributes
472
- config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
473
-
474
- # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
475
- init_dict = {}
476
- for key in expected_keys:
477
- # if config param is passed to kwarg and is present in config dict
478
- # it should overwrite existing config dict key
479
- if key in kwargs and key in config_dict:
480
- config_dict[key] = kwargs.pop(key)
481
-
482
- if key in kwargs:
483
- # overwrite key
484
- init_dict[key] = kwargs.pop(key)
485
- elif key in config_dict:
486
- # use value from config dict
487
- init_dict[key] = config_dict.pop(key)
488
-
489
- # 4. Give nice warning if unexpected values have been passed
490
- if len(config_dict) > 0:
491
- logger.warning(
492
- f"The config attributes {config_dict} were passed to {cls.__name__}, "
493
- "but are not expected and will be ignored. Please verify your "
494
- f"{cls.config_name} configuration file."
495
- )
496
-
497
- # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
498
- passed_keys = set(init_dict.keys())
499
- if len(expected_keys - passed_keys) > 0:
500
- logger.info(
501
- f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
502
- )
503
-
504
- # 6. Define unused keyword arguments
505
- unused_kwargs = {**config_dict, **kwargs}
506
-
507
- # 7. Define "hidden" config parameters that were saved for compatible classes
508
- hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
509
-
510
- return init_dict, unused_kwargs, hidden_config_dict
511
-
512
- @classmethod
513
- def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
514
- with open(json_file, "r", encoding="utf-8") as reader:
515
- text = reader.read()
516
- return json.loads(text)
517
-
518
- def __repr__(self):
519
- return f"{self.__class__.__name__} {self.to_json_string()}"
520
-
521
- @property
522
- def config(self) -> Dict[str, Any]:
523
- """
524
- Returns the config of the class as a frozen dictionary
525
-
526
- Returns:
527
- `Dict[str, Any]`: Config of the class.
528
- """
529
- return self._internal_dict
530
-
531
- def to_json_string(self) -> str:
532
- """
533
- Serializes the configuration instance to a JSON string.
534
-
535
- Returns:
536
- `str`:
537
- String containing all the attributes that make up the configuration instance in JSON format.
538
- """
539
- config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
540
- config_dict["_class_name"] = self.__class__.__name__
541
- config_dict["_diffusers_version"] = __version__
542
-
543
- def to_json_saveable(value):
544
- if isinstance(value, np.ndarray):
545
- value = value.tolist()
546
- elif isinstance(value, PosixPath):
547
- value = str(value)
548
- return value
549
-
550
- config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
551
- # Don't save "_ignore_files" or "_use_default_values"
552
- config_dict.pop("_ignore_files", None)
553
- config_dict.pop("_use_default_values", None)
554
-
555
- return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
556
-
557
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
558
- """
559
- Save the configuration instance's parameters to a JSON file.
560
-
561
- Args:
562
- json_file_path (`str` or `os.PathLike`):
563
- Path to the JSON file to save a configuration instance's parameters.
564
- """
565
- with open(json_file_path, "w", encoding="utf-8") as writer:
566
- writer.write(self.to_json_string())
567
-
568
-
569
- def register_to_config(init):
570
- r"""
571
- Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
572
- automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
573
- shouldn't be registered in the config, use the `ignore_for_config` class variable
574
-
575
- Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
576
- """
577
-
578
- @functools.wraps(init)
579
- def inner_init(self, *args, **kwargs):
580
- # Ignore private kwargs in the init.
581
- init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
582
- config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
583
- if not isinstance(self, ConfigMixin):
584
- raise RuntimeError(
585
- f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
586
- "not inherit from `ConfigMixin`."
587
- )
588
-
589
- ignore = getattr(self, "ignore_for_config", [])
590
- # Get positional arguments aligned with kwargs
591
- new_kwargs = {}
592
- signature = inspect.signature(init)
593
- parameters = {
594
- name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
595
- }
596
- for arg, name in zip(args, parameters.keys()):
597
- new_kwargs[name] = arg
598
-
599
- # Then add all kwargs
600
- new_kwargs.update(
601
- {
602
- k: init_kwargs.get(k, default)
603
- for k, default in parameters.items()
604
- if k not in ignore and k not in new_kwargs
605
- }
606
- )
607
-
608
- # Take note of the parameters that were not present in the loaded config
609
- if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
610
- new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
611
-
612
- new_kwargs = {**config_init_kwargs, **new_kwargs}
613
- getattr(self, "register_to_config")(**new_kwargs)
614
- init(self, *args, **init_kwargs)
615
-
616
- return inner_init
617
-
618
-
619
- def flax_register_to_config(cls):
620
- original_init = cls.__init__
621
-
622
- @functools.wraps(original_init)
623
- def init(self, *args, **kwargs):
624
- if not isinstance(self, ConfigMixin):
625
- raise RuntimeError(
626
- f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
627
- "not inherit from `ConfigMixin`."
628
- )
629
-
630
- # Ignore private kwargs in the init. Retrieve all passed attributes
631
- init_kwargs = dict(kwargs.items())
632
-
633
- # Retrieve default values
634
- fields = dataclasses.fields(self)
635
- default_kwargs = {}
636
- for field in fields:
637
- # ignore flax specific attributes
638
- if field.name in self._flax_internal_args:
639
- continue
640
- if type(field.default) == dataclasses._MISSING_TYPE:
641
- default_kwargs[field.name] = None
642
- else:
643
- default_kwargs[field.name] = getattr(self, field.name)
644
-
645
- # Make sure init_kwargs override default kwargs
646
- new_kwargs = {**default_kwargs, **init_kwargs}
647
- # dtype should be part of `init_kwargs`, but not `new_kwargs`
648
- if "dtype" in new_kwargs:
649
- new_kwargs.pop("dtype")
650
-
651
- # Get positional arguments aligned with kwargs
652
- for i, arg in enumerate(args):
653
- name = fields[i].name
654
- new_kwargs[name] = arg
655
-
656
- # Take note of the parameters that were not present in the loaded config
657
- if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
658
- new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
659
-
660
- getattr(self, "register_to_config")(**new_kwargs)
661
- original_init(self, *args, **kwargs)
662
-
663
- cls.__init__ = init
664
- return cls
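
A minimal sketch of how `ConfigMixin` and the `@register_to_config` decorator above fit together; the `MyScheduler` class and its arguments are illustrative and not part of the library, and the import assumes the standard `diffusers` package path rather than this vendored copy:

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config

class MyScheduler(ConfigMixin):            # illustrative class, not shipped with the library
    config_name = "scheduler_config.json"  # file written by save_config / read by load_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start

s = MyScheduler(num_train_timesteps=500)
print(s.config.num_train_timesteps)        # 500 -- captured automatically by the decorator
s.save_config("./my_scheduler")            # writes ./my_scheduler/scheduler_config.json
s2 = MyScheduler.from_config(MyScheduler.load_config("./my_scheduler"))
```
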
4DoF/diffusers/dependency_versions_check.py DELETED
@@ -1,47 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import sys
15
-
16
- from .dependency_versions_table import deps
17
- from .utils.versions import require_version, require_version_core
18
-
19
-
20
- # define which module versions we always want to check at run time
21
- # (usually the ones defined in `install_requires` in setup.py)
22
- #
23
- # order specific notes:
24
- # - tqdm must be checked before tokenizers
25
-
26
- pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
- if sys.version_info < (3, 7):
28
- pkgs_to_check_at_runtime.append("dataclasses")
29
- if sys.version_info < (3, 8):
30
- pkgs_to_check_at_runtime.append("importlib_metadata")
31
-
32
- for pkg in pkgs_to_check_at_runtime:
33
- if pkg in deps:
34
- if pkg == "tokenizers":
35
- # must be loaded here, or else tqdm check may fail
36
- from .utils import is_tokenizers_available
37
-
38
- if not is_tokenizers_available():
39
- continue # not required, check version only if installed
40
-
41
- require_version_core(deps[pkg])
42
- else:
43
- raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
-
45
-
46
- def dep_version_check(pkg, hint=None):
47
- require_version(deps[pkg], hint)
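
For reference, a sketch of how the helper above is typically called (again assuming the standard `diffusers` package path; the `"torch"` key comes from the dependency table that follows):

```python
from diffusers.dependency_versions_check import dep_version_check

# Raises if the installed torch does not satisfy the pinned requirement
# ("torch>=1.4" in the table below at the time of this copy).
dep_version_check("torch", hint="pip install -U torch")
```
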
4DoF/diffusers/dependency_versions_table.py DELETED
@@ -1,44 +0,0 @@
1
- # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
- # 1. modify the `_deps` dict in setup.py
3
- # 2. run `make deps_table_update``
4
- deps = {
5
- "Pillow": "Pillow",
6
- "accelerate": "accelerate>=0.11.0",
7
- "compel": "compel==0.1.8",
8
- "black": "black~=23.1",
9
- "datasets": "datasets",
10
- "filelock": "filelock",
11
- "flax": "flax>=0.4.1",
12
- "hf-doc-builder": "hf-doc-builder>=0.3.0",
13
- "huggingface-hub": "huggingface-hub>=0.13.2",
14
- "requests-mock": "requests-mock==1.10.0",
15
- "importlib_metadata": "importlib_metadata",
16
- "invisible-watermark": "invisible-watermark",
17
- "isort": "isort>=5.5.4",
18
- "jax": "jax>=0.2.8,!=0.3.2",
19
- "jaxlib": "jaxlib>=0.1.65",
20
- "Jinja2": "Jinja2",
21
- "k-diffusion": "k-diffusion>=0.0.12",
22
- "torchsde": "torchsde",
23
- "note_seq": "note_seq",
24
- "librosa": "librosa",
25
- "numpy": "numpy",
26
- "omegaconf": "omegaconf",
27
- "parameterized": "parameterized",
28
- "protobuf": "protobuf>=3.20.3,<4",
29
- "pytest": "pytest",
30
- "pytest-timeout": "pytest-timeout",
31
- "pytest-xdist": "pytest-xdist",
32
- "ruff": "ruff>=0.0.241",
33
- "safetensors": "safetensors",
34
- "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35
- "scipy": "scipy",
36
- "onnx": "onnx",
37
- "regex": "regex!=2019.12.17",
38
- "requests": "requests",
39
- "tensorboard": "tensorboard",
40
- "torch": "torch>=1.4",
41
- "torchvision": "torchvision",
42
- "transformers": "transformers>=4.25.1",
43
- "urllib3": "urllib3<=2.0.0",
44
- }
4DoF/diffusers/experimental/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .rl import ValueGuidedRLPipeline
4DoF/diffusers/experimental/rl/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .value_guided_sampling import ValueGuidedRLPipeline
4DoF/diffusers/experimental/rl/value_guided_sampling.py DELETED
@@ -1,152 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import numpy as np
16
- import torch
17
- import tqdm
18
-
19
- from ...models.unet_1d import UNet1DModel
20
- from ...pipelines import DiffusionPipeline
21
- from ...utils import randn_tensor
22
- from ...utils.dummy_pt_objects import DDPMScheduler
23
-
24
-
25
- class ValueGuidedRLPipeline(DiffusionPipeline):
26
- r"""
27
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
28
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
29
- Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
30
-
31
- Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
32
-
33
- Parameters:
34
- value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
35
- unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
36
- scheduler ([`SchedulerMixin`]):
37
- A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
38
- application is [`DDPMScheduler`].
39
- env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
40
- """
41
-
42
- def __init__(
43
- self,
44
- value_function: UNet1DModel,
45
- unet: UNet1DModel,
46
- scheduler: DDPMScheduler,
47
- env,
48
- ):
49
- super().__init__()
50
- self.value_function = value_function
51
- self.unet = unet
52
- self.scheduler = scheduler
53
- self.env = env
54
- self.data = env.get_dataset()
55
- self.means = {}
56
- for key in self.data.keys():
57
- try:
58
- self.means[key] = self.data[key].mean()
59
- except: # noqa: E722
60
- pass
61
- self.stds = {}
62
- for key in self.data.keys():
63
- try:
64
- self.stds[key] = self.data[key].std()
65
- except: # noqa: E722
66
- pass
67
- self.state_dim = env.observation_space.shape[0]
68
- self.action_dim = env.action_space.shape[0]
69
-
70
- def normalize(self, x_in, key):
71
- return (x_in - self.means[key]) / self.stds[key]
72
-
73
- def de_normalize(self, x_in, key):
74
- return x_in * self.stds[key] + self.means[key]
75
-
76
- def to_torch(self, x_in):
77
- if type(x_in) is dict:
78
- return {k: self.to_torch(v) for k, v in x_in.items()}
79
- elif torch.is_tensor(x_in):
80
- return x_in.to(self.unet.device)
81
- return torch.tensor(x_in, device=self.unet.device)
82
-
83
- def reset_x0(self, x_in, cond, act_dim):
84
- for key, val in cond.items():
85
- x_in[:, key, act_dim:] = val.clone()
86
- return x_in
87
-
88
- def run_diffusion(self, x, conditions, n_guide_steps, scale):
89
- batch_size = x.shape[0]
90
- y = None
91
- for i in tqdm.tqdm(self.scheduler.timesteps):
92
- # create batch of timesteps to pass into model
93
- timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
94
- for _ in range(n_guide_steps):
95
- with torch.enable_grad():
96
- x.requires_grad_()
97
-
98
- # permute to match dimension for pre-trained models
99
- y = self.value_function(x.permute(0, 2, 1), timesteps).sample
100
- grad = torch.autograd.grad([y.sum()], [x])[0]
101
-
102
- posterior_variance = self.scheduler._get_variance(i)
103
- model_std = torch.exp(0.5 * posterior_variance)
104
- grad = model_std * grad
105
-
106
- grad[timesteps < 2] = 0
107
- x = x.detach()
108
- x = x + scale * grad
109
- x = self.reset_x0(x, conditions, self.action_dim)
110
-
111
- prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
112
-
113
- # TODO: verify deprecation of this kwarg
114
- x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
115
-
116
- # apply conditions to the trajectory (set the initial state)
117
- x = self.reset_x0(x, conditions, self.action_dim)
118
- x = self.to_torch(x)
119
- return x, y
120
-
121
- def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
122
- # normalize the observations and create batch dimension
123
- obs = self.normalize(obs, "observations")
124
- obs = obs[None].repeat(batch_size, axis=0)
125
-
126
- conditions = {0: self.to_torch(obs)}
127
- shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
128
-
129
- # generate initial noise and apply our conditions (to make the trajectories start at current state)
130
- x1 = randn_tensor(shape, device=self.unet.device)
131
- x = self.reset_x0(x1, conditions, self.action_dim)
132
- x = self.to_torch(x)
133
-
134
- # run the diffusion process
135
- x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
136
-
137
- # sort output trajectories by value
138
- sorted_idx = y.argsort(0, descending=True).squeeze()
139
- sorted_values = x[sorted_idx]
140
- actions = sorted_values[:, :, : self.action_dim]
141
- actions = actions.detach().cpu().numpy()
142
- denorm_actions = self.de_normalize(actions, key="actions")
143
-
144
- # select the action with the highest value
145
- if y is not None:
146
- selected_index = 0
147
- else:
148
- # if we didn't run value guiding, select a random action
149
- selected_index = np.random.randint(0, batch_size)
150
-
151
- denorm_actions = denorm_actions[selected_index, 0]
152
- return denorm_actions
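
To make the conditioning step in this pipeline concrete, here is a small self-contained check of the `reset_x0` logic used throughout the sampler above; the tensor shapes are illustrative only:

```python
import torch

def reset_x0(x_in, cond, act_dim):
    # Same logic as the pipeline method above: overwrite the state slice of the
    # conditioned trajectory steps, leaving the action dimensions untouched.
    for key, val in cond.items():
        x_in[:, key, act_dim:] = val.clone()
    return x_in

batch, horizon, act_dim, state_dim = 2, 4, 3, 5
x = torch.zeros(batch, horizon, act_dim + state_dim)
cond = {0: torch.ones(batch, state_dim)}  # pin step 0 to the current observation
x = reset_x0(x, cond, act_dim)
assert torch.all(x[:, 0, act_dim:] == 1) and torch.all(x[:, 0, :act_dim] == 0)
```
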
4DoF/diffusers/image_processor.py DELETED
@@ -1,366 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import warnings
16
- from typing import List, Optional, Union
17
-
18
- import numpy as np
19
- import PIL
20
- import torch
21
- from PIL import Image
22
-
23
- from .configuration_utils import ConfigMixin, register_to_config
24
- from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
-
26
-
27
- class VaeImageProcessor(ConfigMixin):
28
- """
29
- Image processor for VAE.
30
-
31
- Args:
32
- do_resize (`bool`, *optional*, defaults to `True`):
33
- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
34
- `height` and `width` arguments from the [`image_processor.VaeImageProcessor.preprocess`] method.
35
- vae_scale_factor (`int`, *optional*, defaults to `8`):
36
- VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
37
- resample (`str`, *optional*, defaults to `lanczos`):
38
- Resampling filter to use when resizing the image.
39
- do_normalize (`bool`, *optional*, defaults to `True`):
40
- Whether to normalize the image to [-1,1].
41
- do_convert_rgb (`bool`, *optional*, defaults to `False`):
42
- Whether to convert the images to RGB format.
43
- """
44
-
45
- config_name = CONFIG_NAME
46
-
47
- @register_to_config
48
- def __init__(
49
- self,
50
- do_resize: bool = True,
51
- vae_scale_factor: int = 8,
52
- resample: str = "lanczos",
53
- do_normalize: bool = True,
54
- do_convert_rgb: bool = False,
55
- ):
56
- super().__init__()
57
-
58
- @staticmethod
59
- def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
60
- """
61
- Convert a numpy image or a batch of images to a PIL image.
62
- """
63
- if images.ndim == 3:
64
- images = images[None, ...]
65
- images = (images * 255).round().astype("uint8")
66
- if images.shape[-1] == 1:
67
- # special case for grayscale (single channel) images
68
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
69
- else:
70
- pil_images = [Image.fromarray(image) for image in images]
71
-
72
- return pil_images
73
-
74
- @staticmethod
75
- def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
76
- """
77
- Convert a PIL image or a list of PIL images to NumPy arrays.
78
- """
79
- if not isinstance(images, list):
80
- images = [images]
81
- images = [np.array(image).astype(np.float32) / 255.0 for image in images]
82
- images = np.stack(images, axis=0)
83
-
84
- return images
85
-
86
- @staticmethod
87
- def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
88
- """
89
- Convert a NumPy image to a PyTorch tensor.
90
- """
91
- if images.ndim == 3:
92
- images = images[..., None]
93
-
94
- images = torch.from_numpy(images.transpose(0, 3, 1, 2))
95
- return images
96
-
97
- @staticmethod
98
- def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
99
- """
100
- Convert a PyTorch tensor to a NumPy image.
101
- """
102
- images = images.cpu().permute(0, 2, 3, 1).float().numpy()
103
- return images
104
-
105
- @staticmethod
106
- def normalize(images):
107
- """
108
- Normalize an image array to [-1,1].
109
- """
110
- return 2.0 * images - 1.0
111
-
112
- @staticmethod
113
- def denormalize(images):
114
- """
115
- Denormalize an image array to [0,1].
116
- """
117
- return (images / 2 + 0.5).clamp(0, 1)
118
-
119
- @staticmethod
120
- def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
121
- """
122
- Converts an image to RGB format.
123
- """
124
- image = image.convert("RGB")
125
- return image
126
-
127
- def resize(
128
- self,
129
- image: PIL.Image.Image,
130
- height: Optional[int] = None,
131
- width: Optional[int] = None,
132
- ) -> PIL.Image.Image:
133
- """
134
- Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
135
- """
136
- if height is None:
137
- height = image.height
138
- if width is None:
139
- width = image.width
140
-
141
- width, height = (
142
- x - x % self.config.vae_scale_factor for x in (width, height)
143
- ) # resize to integer multiple of vae_scale_factor
144
- image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
145
- return image
146
-
147
- def preprocess(
148
- self,
149
- image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
150
- height: Optional[int] = None,
151
- width: Optional[int] = None,
152
- ) -> torch.Tensor:
153
- """
154
- Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
155
- """
156
- supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
157
- if isinstance(image, supported_formats):
158
- image = [image]
159
- elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
160
- raise ValueError(
161
- f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(str(t) for t in supported_formats)}"
162
- )
163
-
164
- if isinstance(image[0], PIL.Image.Image):
165
- if self.config.do_convert_rgb:
166
- image = [self.convert_to_rgb(i) for i in image]
167
- if self.config.do_resize:
168
- image = [self.resize(i, height, width) for i in image]
169
- image = self.pil_to_numpy(image) # to np
170
- image = self.numpy_to_pt(image) # to pt
171
-
172
- elif isinstance(image[0], np.ndarray):
173
- image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
174
- image = self.numpy_to_pt(image)
175
- _, _, height, width = image.shape
176
- if self.config.do_resize and (
177
- height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
178
- ):
179
- raise ValueError(
180
- f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}. "
181
- f"Currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VAEImageProcessor."
182
- )
183
-
184
- elif isinstance(image[0], torch.Tensor):
185
- image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
186
- _, channel, height, width = image.shape
187
-
188
- # don't need any preprocess if the image is latents
189
- if channel == 4:
190
- return image
191
-
192
- if self.config.do_resize and (
193
- height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
194
- ):
195
- raise ValueError(
196
- f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}. "
197
- f"Currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VAEImageProcessor."
198
- )
199
-
200
- # expected range [0,1], normalize to [-1,1]
201
- do_normalize = self.config.do_normalize
202
- if image.min() < 0:
203
- warnings.warn(
204
- "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
205
- f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
206
- FutureWarning,
207
- )
208
- do_normalize = False
209
-
210
- if do_normalize:
211
- image = self.normalize(image)
212
-
213
- return image
214
-
215
- def postprocess(
216
- self,
217
- image: torch.FloatTensor,
218
- output_type: str = "pil",
219
- do_denormalize: Optional[List[bool]] = None,
220
- ):
221
- if not isinstance(image, torch.Tensor):
222
- raise ValueError(
223
- f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
224
- )
225
- if output_type not in ["latent", "pt", "np", "pil"]:
226
- deprecation_message = (
227
- f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
228
- "`pil`, `np`, `pt`, `latent`"
229
- )
230
- deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
231
- output_type = "np"
232
-
233
- if output_type == "latent":
234
- return image
235
-
236
- if do_denormalize is None:
237
- do_denormalize = [self.config.do_normalize] * image.shape[0]
238
-
239
- image = torch.stack(
240
- [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
241
- )
242
-
243
- if output_type == "pt":
244
- return image
245
-
246
- image = self.pt_to_numpy(image)
247
-
248
- if output_type == "np":
249
- return image
250
-
251
- if output_type == "pil":
252
- return self.numpy_to_pil(image)
253
-
254
-
255
- class VaeImageProcessorLDM3D(VaeImageProcessor):
256
- """
257
- Image processor for VAE LDM3D.
258
-
259
- Args:
260
- do_resize (`bool`, *optional*, defaults to `True`):
261
- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
262
- vae_scale_factor (`int`, *optional*, defaults to `8`):
263
- VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
264
- resample (`str`, *optional*, defaults to `lanczos`):
265
- Resampling filter to use when resizing the image.
266
- do_normalize (`bool`, *optional*, defaults to `True`):
267
- Whether to normalize the image to [-1,1].
268
- """
269
-
270
- config_name = CONFIG_NAME
271
-
272
- @register_to_config
273
- def __init__(
274
- self,
275
- do_resize: bool = True,
276
- vae_scale_factor: int = 8,
277
- resample: str = "lanczos",
278
- do_normalize: bool = True,
279
- ):
280
- super().__init__()
281
-
282
- @staticmethod
283
- def numpy_to_pil(images):
284
- """
285
- Convert a NumPy image or a batch of images to a PIL image.
286
- """
287
- if images.ndim == 3:
288
- images = images[None, ...]
289
- images = (images * 255).round().astype("uint8")
290
- if images.shape[-1] == 1:
291
- # special case for grayscale (single channel) images
292
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
293
- else:
294
- pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
295
-
296
- return pil_images
297
-
298
- @staticmethod
299
- def rgblike_to_depthmap(image):
300
- """
301
- Args:
302
- image: RGB-like depth image
303
-
304
- Returns: depth map
305
-
306
- """
307
- return image[:, :, 1] * 2**8 + image[:, :, 2]
308
-
309
- def numpy_to_depth(self, images):
310
- """
311
- Convert a NumPy depth image or a batch of images to a PIL image.
312
- """
313
- if images.ndim == 3:
314
- images = images[None, ...]
315
- images_depth = images[:, :, :, 3:]
316
- if images.shape[-1] == 6:
317
- images_depth = (images_depth * 255).round().astype("uint8")
318
- pil_images = [
319
- Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
320
- ]
321
- elif images.shape[-1] == 4:
322
- images_depth = (images_depth * 65535.0).astype(np.uint16)
323
- pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
324
- else:
325
- raise Exception("Not supported")
326
-
327
- return pil_images
328
-
329
- def postprocess(
330
- self,
331
- image: torch.FloatTensor,
332
- output_type: str = "pil",
333
- do_denormalize: Optional[List[bool]] = None,
334
- ):
335
- if not isinstance(image, torch.Tensor):
336
- raise ValueError(
337
- f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
338
- )
339
- if output_type not in ["latent", "pt", "np", "pil"]:
340
- deprecation_message = (
341
- f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
342
- "`pil`, `np`, `pt`, `latent`"
343
- )
344
- deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
345
- output_type = "np"
346
-
347
- if do_denormalize is None:
348
- do_denormalize = [self.config.do_normalize] * image.shape[0]
349
-
350
- image = torch.stack(
351
- [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
352
- )
353
-
354
- image = self.pt_to_numpy(image)
355
-
356
- if output_type == "np":
357
- if image.shape[-1] == 6:
358
- image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
359
- else:
360
- image_depth = image[:, :, :, 3:]
361
- return image[:, :, :, :3], image_depth
362
-
363
- if output_type == "pil":
364
- return self.numpy_to_pil(image), self.numpy_to_depth(image)
365
- else:
366
- raise Exception(f"This type {output_type} is not supported")
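
The deleted `image_processor.py` above centers on `VaeImageProcessor`: it converts between PIL, NumPy, and PyTorch image formats, rounds spatial sizes down to multiples of `vae_scale_factor`, and maps pixel values between [0, 1] and [-1, 1] for VAE encoding and decoding. A short round-trip sketch follows; the import path refers to the upstream `diffusers` package rather than this vendored copy, and the image size is illustrative.

```py
# Hedged round-trip sketch for VaeImageProcessor (upstream diffusers import path).
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=True)

# preprocess: PIL -> float tensor in [-1, 1], height/width rounded down to multiples of 8
pil_image = Image.new("RGB", (130, 98), color=(255, 0, 0))  # arbitrary test image
pixel_values = processor.preprocess(pil_image)
print(pixel_values.shape)                                     # torch.Size([1, 3, 96, 128])
print(pixel_values.min().item(), pixel_values.max().item())   # -1.0 1.0

# postprocess: tensor in [-1, 1] -> list of PIL images (denormalized back to [0, 1])
images_out = processor.postprocess(pixel_values, output_type="pil")
print(images_out[0].size)                                      # (128, 96)
```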
 
4DoF/diffusers/loaders.py DELETED
@@ -1,1492 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import os
15
- import warnings
16
- from collections import defaultdict
17
- from pathlib import Path
18
- from typing import Callable, Dict, List, Optional, Union
19
-
20
- import torch
21
- import torch.nn.functional as F
22
- from huggingface_hub import hf_hub_download
23
-
24
- from .models.attention_processor import (
25
- AttnAddedKVProcessor,
26
- AttnAddedKVProcessor2_0,
27
- CustomDiffusionAttnProcessor,
28
- CustomDiffusionXFormersAttnProcessor,
29
- LoRAAttnAddedKVProcessor,
30
- LoRAAttnProcessor,
31
- LoRAAttnProcessor2_0,
32
- LoRAXFormersAttnProcessor,
33
- SlicedAttnAddedKVProcessor,
34
- XFormersAttnProcessor,
35
- )
36
- from .utils import (
37
- DIFFUSERS_CACHE,
38
- HF_HUB_OFFLINE,
39
- TEXT_ENCODER_ATTN_MODULE,
40
- _get_model_file,
41
- deprecate,
42
- is_safetensors_available,
43
- is_transformers_available,
44
- logging,
45
- )
46
-
47
-
48
- if is_safetensors_available():
49
- import safetensors
50
-
51
- if is_transformers_available():
52
- from transformers import PreTrainedModel, PreTrainedTokenizer
53
-
54
-
55
- logger = logging.get_logger(__name__)
56
-
57
- TEXT_ENCODER_NAME = "text_encoder"
58
- UNET_NAME = "unet"
59
-
60
- LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
61
- LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
62
-
63
- TEXT_INVERSION_NAME = "learned_embeds.bin"
64
- TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
65
-
66
- CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
67
- CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
68
-
69
-
70
- class AttnProcsLayers(torch.nn.Module):
71
- def __init__(self, state_dict: Dict[str, torch.Tensor]):
72
- super().__init__()
73
- self.layers = torch.nn.ModuleList(state_dict.values())
74
- self.mapping = dict(enumerate(state_dict.keys()))
75
- self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
76
-
77
- # .processor for unet, .self_attn for text encoder
78
- self.split_keys = [".processor", ".self_attn"]
79
-
80
- # we add a hook to state_dict() and load_state_dict() so that the
81
- # naming fits with `unet.attn_processors`
82
- def map_to(module, state_dict, *args, **kwargs):
83
- new_state_dict = {}
84
- for key, value in state_dict.items():
85
- num = int(key.split(".")[1]) # 0 is always "layers"
86
- new_key = key.replace(f"layers.{num}", module.mapping[num])
87
- new_state_dict[new_key] = value
88
-
89
- return new_state_dict
90
-
91
- def remap_key(key, state_dict):
92
- for k in self.split_keys:
93
- if k in key:
94
- return key.split(k)[0] + k
95
-
96
- raise ValueError(
97
- f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
98
- )
99
-
100
- def map_from(module, state_dict, *args, **kwargs):
101
- all_keys = list(state_dict.keys())
102
- for key in all_keys:
103
- replace_key = remap_key(key, state_dict)
104
- new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
105
- state_dict[new_key] = state_dict[key]
106
- del state_dict[key]
107
-
108
- self._register_state_dict_hook(map_to)
109
- self._register_load_state_dict_pre_hook(map_from, with_module=True)
110
-
111
-
112
- class UNet2DConditionLoadersMixin:
113
- text_encoder_name = TEXT_ENCODER_NAME
114
- unet_name = UNET_NAME
115
-
116
- def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
117
- r"""
118
- Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
119
- defined in
120
- [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
121
- and be a `torch.nn.Module` class.
122
-
123
- Parameters:
124
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
125
- Can be either:
126
-
127
- - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
128
- the Hub.
129
- - A path to a directory (for example `./my_model_directory`) containing the model weights saved
130
- with [`ModelMixin.save_pretrained`].
131
- - A [torch state
132
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
133
-
134
- cache_dir (`Union[str, os.PathLike]`, *optional*):
135
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
136
- is not used.
137
- force_download (`bool`, *optional*, defaults to `False`):
138
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
139
- cached versions if they exist.
140
- resume_download (`bool`, *optional*, defaults to `False`):
141
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
142
- incompletely downloaded files are deleted.
143
- proxies (`Dict[str, str]`, *optional*):
144
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
145
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
146
- local_files_only (`bool`, *optional*, defaults to `False`):
147
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
148
- won't be downloaded from the Hub.
149
- use_auth_token (`str` or *bool*, *optional*):
150
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
151
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
152
- revision (`str`, *optional*, defaults to `"main"`):
153
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
154
- allowed by Git.
155
- subfolder (`str`, *optional*, defaults to `""`):
156
- The subfolder location of a model file within a larger model repository on the Hub or locally.
157
- mirror (`str`, *optional*):
158
- Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
159
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
160
- information.
161
-
162
- """
163
-
164
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
165
- force_download = kwargs.pop("force_download", False)
166
- resume_download = kwargs.pop("resume_download", False)
167
- proxies = kwargs.pop("proxies", None)
168
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
169
- use_auth_token = kwargs.pop("use_auth_token", None)
170
- revision = kwargs.pop("revision", None)
171
- subfolder = kwargs.pop("subfolder", None)
172
- weight_name = kwargs.pop("weight_name", None)
173
- use_safetensors = kwargs.pop("use_safetensors", None)
174
- # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
175
- # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
176
- network_alpha = kwargs.pop("network_alpha", None)
177
-
178
- if use_safetensors and not is_safetensors_available():
179
- raise ValueError(
180
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
181
- )
182
-
183
- allow_pickle = False
184
- if use_safetensors is None:
185
- use_safetensors = is_safetensors_available()
186
- allow_pickle = True
187
-
188
- user_agent = {
189
- "file_type": "attn_procs_weights",
190
- "framework": "pytorch",
191
- }
192
-
193
- model_file = None
194
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
195
- # Let's first try to load .safetensors weights
196
- if (use_safetensors and weight_name is None) or (
197
- weight_name is not None and weight_name.endswith(".safetensors")
198
- ):
199
- try:
200
- model_file = _get_model_file(
201
- pretrained_model_name_or_path_or_dict,
202
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
203
- cache_dir=cache_dir,
204
- force_download=force_download,
205
- resume_download=resume_download,
206
- proxies=proxies,
207
- local_files_only=local_files_only,
208
- use_auth_token=use_auth_token,
209
- revision=revision,
210
- subfolder=subfolder,
211
- user_agent=user_agent,
212
- )
213
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
214
- except IOError as e:
215
- if not allow_pickle:
216
- raise e
217
- # try loading non-safetensors weights
218
- pass
219
- if model_file is None:
220
- model_file = _get_model_file(
221
- pretrained_model_name_or_path_or_dict,
222
- weights_name=weight_name or LORA_WEIGHT_NAME,
223
- cache_dir=cache_dir,
224
- force_download=force_download,
225
- resume_download=resume_download,
226
- proxies=proxies,
227
- local_files_only=local_files_only,
228
- use_auth_token=use_auth_token,
229
- revision=revision,
230
- subfolder=subfolder,
231
- user_agent=user_agent,
232
- )
233
- state_dict = torch.load(model_file, map_location="cpu")
234
- else:
235
- state_dict = pretrained_model_name_or_path_or_dict
236
-
237
- # fill attn processors
238
- attn_processors = {}
239
-
240
- is_lora = all("lora" in k for k in state_dict.keys())
241
- is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
242
-
243
- if is_lora:
244
- is_new_lora_format = all(
245
- key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
246
- )
247
- if is_new_lora_format:
248
- # Strip the `"unet"` prefix.
249
- is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
250
- if is_text_encoder_present:
251
- warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
252
- warnings.warn(warn_message)
253
- unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
254
- state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
255
-
256
- lora_grouped_dict = defaultdict(dict)
257
- for key, value in state_dict.items():
258
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
259
- lora_grouped_dict[attn_processor_key][sub_key] = value
260
-
261
- for key, value_dict in lora_grouped_dict.items():
262
- rank = value_dict["to_k_lora.down.weight"].shape[0]
263
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
264
-
265
- attn_processor = self
266
- for sub_key in key.split("."):
267
- attn_processor = getattr(attn_processor, sub_key)
268
-
269
- if isinstance(
270
- attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
271
- ):
272
- cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
273
- attn_processor_class = LoRAAttnAddedKVProcessor
274
- else:
275
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
276
- if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
277
- attn_processor_class = LoRAXFormersAttnProcessor
278
- else:
279
- attn_processor_class = (
280
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
281
- )
282
-
283
- attn_processors[key] = attn_processor_class(
284
- hidden_size=hidden_size,
285
- cross_attention_dim=cross_attention_dim,
286
- rank=rank,
287
- network_alpha=network_alpha,
288
- )
289
- attn_processors[key].load_state_dict(value_dict)
290
- elif is_custom_diffusion:
291
- custom_diffusion_grouped_dict = defaultdict(dict)
292
- for key, value in state_dict.items():
293
- if len(value) == 0:
294
- custom_diffusion_grouped_dict[key] = {}
295
- else:
296
- if "to_out" in key:
297
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
298
- else:
299
- attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
300
- custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
301
-
302
- for key, value_dict in custom_diffusion_grouped_dict.items():
303
- if len(value_dict) == 0:
304
- attn_processors[key] = CustomDiffusionAttnProcessor(
305
- train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
306
- )
307
- else:
308
- cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
309
- hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
310
- train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
311
- attn_processors[key] = CustomDiffusionAttnProcessor(
312
- train_kv=True,
313
- train_q_out=train_q_out,
314
- hidden_size=hidden_size,
315
- cross_attention_dim=cross_attention_dim,
316
- )
317
- attn_processors[key].load_state_dict(value_dict)
318
- else:
319
- raise ValueError(
320
- f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
321
- )
322
-
323
- # set correct dtype & device
324
- attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
325
-
326
- # set layers
327
- self.set_attn_processor(attn_processors)
328
-
329
- def save_attn_procs(
330
- self,
331
- save_directory: Union[str, os.PathLike],
332
- is_main_process: bool = True,
333
- weight_name: str = None,
334
- save_function: Callable = None,
335
- safe_serialization: bool = False,
336
- **kwargs,
337
- ):
338
- r"""
339
- Save an attention processor to a directory so that it can be reloaded using the
340
- [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
341
-
342
- Arguments:
343
- save_directory (`str` or `os.PathLike`):
344
- Directory to save an attention processor to. Will be created if it doesn't exist.
345
- is_main_process (`bool`, *optional*, defaults to `True`):
346
- Whether the process calling this is the main process or not. Useful during distributed training and you
347
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
348
- process to avoid race conditions.
349
- save_function (`Callable`):
350
- The function to use to save the state dictionary. Useful during distributed training when you need to
351
- replace `torch.save` with another method. Can be configured with the environment variable
352
- `DIFFUSERS_SAVE_MODE`.
353
-
354
- """
355
- weight_name = weight_name or deprecate(
356
- "weights_name",
357
- "0.20.0",
358
- "`weights_name` is deprecated, please use `weight_name` instead.",
359
- take_from=kwargs,
360
- )
361
- if os.path.isfile(save_directory):
362
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
363
- return
364
-
365
- if save_function is None:
366
- if safe_serialization:
367
-
368
- def save_function(weights, filename):
369
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
370
-
371
- else:
372
- save_function = torch.save
373
-
374
- os.makedirs(save_directory, exist_ok=True)
375
-
376
- is_custom_diffusion = any(
377
- isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
378
- for (_, x) in self.attn_processors.items()
379
- )
380
- if is_custom_diffusion:
381
- model_to_save = AttnProcsLayers(
382
- {
383
- y: x
384
- for (y, x) in self.attn_processors.items()
385
- if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
386
- }
387
- )
388
- state_dict = model_to_save.state_dict()
389
- for name, attn in self.attn_processors.items():
390
- if len(attn.state_dict()) == 0:
391
- state_dict[name] = {}
392
- else:
393
- model_to_save = AttnProcsLayers(self.attn_processors)
394
- state_dict = model_to_save.state_dict()
395
-
396
- if weight_name is None:
397
- if safe_serialization:
398
- weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
399
- else:
400
- weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
401
-
402
- # Save the model
403
- save_function(state_dict, os.path.join(save_directory, weight_name))
404
- logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
405
-
406
-
407
- class TextualInversionLoaderMixin:
408
- r"""
409
- Load textual inversion tokens and embeddings to the tokenizer and text encoder.
410
- """
411
-
412
- def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
413
- r"""
414
- Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
415
- be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
416
- inversion token or if the textual inversion token is a single vector, the input prompt is returned.
417
-
418
- Parameters:
419
- prompt (`str` or list of `str`):
420
- The prompt or prompts to guide the image generation.
421
- tokenizer (`PreTrainedTokenizer`):
422
- The tokenizer responsible for encoding the prompt into input tokens.
423
-
424
- Returns:
425
- `str` or list of `str`: The converted prompt
426
- """
427
- if not isinstance(prompt, List):
428
- prompts = [prompt]
429
- else:
430
- prompts = prompt
431
-
432
- prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
433
-
434
- if not isinstance(prompt, List):
435
- return prompts[0]
436
-
437
- return prompts
438
-
439
- def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
440
- r"""
441
- Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
442
- to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
443
- is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
444
- inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
445
-
446
- Parameters:
447
- prompt (`str`):
448
- The prompt to guide the image generation.
449
- tokenizer (`PreTrainedTokenizer`):
450
- The tokenizer responsible for encoding the prompt into input tokens.
451
-
452
- Returns:
453
- `str`: The converted prompt
454
- """
455
- tokens = tokenizer.tokenize(prompt)
456
- unique_tokens = set(tokens)
457
- for token in unique_tokens:
458
- if token in tokenizer.added_tokens_encoder:
459
- replacement = token
460
- i = 1
461
- while f"{token}_{i}" in tokenizer.added_tokens_encoder:
462
- replacement += f" {token}_{i}"
463
- i += 1
464
-
465
- prompt = prompt.replace(token, replacement)
466
-
467
- return prompt
468
-
469
- def load_textual_inversion(
470
- self,
471
- pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
472
- token: Optional[Union[str, List[str]]] = None,
473
- **kwargs,
474
- ):
475
- r"""
476
- Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
477
- Automatic1111 formats are supported).
478
-
479
- Parameters:
480
- pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
481
- Can be either one of the following or a list of them:
482
-
483
- - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
484
- pretrained model hosted on the Hub.
485
- - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
486
- inversion weights.
487
- - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
488
- - A [torch state
489
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
490
-
491
- token (`str` or `List[str]`, *optional*):
492
- Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
493
- list, then `token` must also be a list of equal length.
494
- weight_name (`str`, *optional*):
495
- Name of a custom weight file. This should be used when:
496
-
497
- - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
498
- name such as `text_inv.bin`.
499
- - The saved textual inversion file is in the Automatic1111 format.
500
- cache_dir (`Union[str, os.PathLike]`, *optional*):
501
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
502
- is not used.
503
- force_download (`bool`, *optional*, defaults to `False`):
504
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
505
- cached versions if they exist.
506
- resume_download (`bool`, *optional*, defaults to `False`):
507
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
508
- incompletely downloaded files are deleted.
509
- proxies (`Dict[str, str]`, *optional*):
510
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
511
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
512
- local_files_only (`bool`, *optional*, defaults to `False`):
513
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
514
- won't be downloaded from the Hub.
515
- use_auth_token (`str` or *bool*, *optional*):
516
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
517
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
518
- revision (`str`, *optional*, defaults to `"main"`):
519
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
520
- allowed by Git.
521
- subfolder (`str`, *optional*, defaults to `""`):
522
- The subfolder location of a model file within a larger model repository on the Hub or locally.
523
- mirror (`str`, *optional*):
524
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
525
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
526
- information.
527
-
528
- Example:
529
-
530
- To load a textual inversion embedding vector in 🤗 Diffusers format:
531
-
532
- ```py
533
- from diffusers import StableDiffusionPipeline
534
- import torch
535
-
536
- model_id = "runwayml/stable-diffusion-v1-5"
537
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
538
-
539
- pipe.load_textual_inversion("sd-concepts-library/cat-toy")
540
-
541
- prompt = "A <cat-toy> backpack"
542
-
543
- image = pipe(prompt, num_inference_steps=50).images[0]
544
- image.save("cat-backpack.png")
545
- ```
546
-
547
- To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
548
- (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
549
- locally:
550
-
551
- ```py
552
- from diffusers import StableDiffusionPipeline
553
- import torch
554
-
555
- model_id = "runwayml/stable-diffusion-v1-5"
556
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
557
-
558
- pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
559
-
560
- prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
561
-
562
- image = pipe(prompt, num_inference_steps=50).images[0]
563
- image.save("character.png")
564
- ```
565
-
566
- """
567
- if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
568
- raise ValueError(
569
- f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
570
- f" `{self.load_textual_inversion.__name__}`"
571
- )
572
-
573
- if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
574
- raise ValueError(
575
- f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
576
- f" `{self.load_textual_inversion.__name__}`"
577
- )
578
-
579
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
580
- force_download = kwargs.pop("force_download", False)
581
- resume_download = kwargs.pop("resume_download", False)
582
- proxies = kwargs.pop("proxies", None)
583
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
584
- use_auth_token = kwargs.pop("use_auth_token", None)
585
- revision = kwargs.pop("revision", None)
586
- subfolder = kwargs.pop("subfolder", None)
587
- weight_name = kwargs.pop("weight_name", None)
588
- use_safetensors = kwargs.pop("use_safetensors", None)
589
-
590
- if use_safetensors and not is_safetensors_available():
591
- raise ValueError(
592
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
593
- )
594
-
595
- allow_pickle = False
596
- if use_safetensors is None:
597
- use_safetensors = is_safetensors_available()
598
- allow_pickle = True
599
-
600
- user_agent = {
601
- "file_type": "text_inversion",
602
- "framework": "pytorch",
603
- }
604
-
605
- if not isinstance(pretrained_model_name_or_path, list):
606
- pretrained_model_name_or_paths = [pretrained_model_name_or_path]
607
- else:
608
- pretrained_model_name_or_paths = pretrained_model_name_or_path
609
-
610
- if isinstance(token, str):
611
- tokens = [token]
612
- elif token is None:
613
- tokens = [None] * len(pretrained_model_name_or_paths)
614
- else:
615
- tokens = token
616
-
617
- if len(pretrained_model_name_or_paths) != len(tokens):
618
- raise ValueError(
619
- f"You have passed a list of models of length {len(pretrained_model_name_or_paths)} and a list of tokens of length {len(tokens)}. "
620
- f"Make sure both lists have the same length."
621
- )
622
-
623
- valid_tokens = [t for t in tokens if t is not None]
624
- if len(set(valid_tokens)) < len(valid_tokens):
625
- raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
626
-
627
- token_ids_and_embeddings = []
628
-
629
- for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
630
- if not isinstance(pretrained_model_name_or_path, dict):
631
- # 1. Load textual inversion file
632
- model_file = None
633
- # Let's first try to load .safetensors weights
634
- if (use_safetensors and weight_name is None) or (
635
- weight_name is not None and weight_name.endswith(".safetensors")
636
- ):
637
- try:
638
- model_file = _get_model_file(
639
- pretrained_model_name_or_path,
640
- weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
641
- cache_dir=cache_dir,
642
- force_download=force_download,
643
- resume_download=resume_download,
644
- proxies=proxies,
645
- local_files_only=local_files_only,
646
- use_auth_token=use_auth_token,
647
- revision=revision,
648
- subfolder=subfolder,
649
- user_agent=user_agent,
650
- )
651
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
652
- except Exception as e:
653
- if not allow_pickle:
654
- raise e
655
-
656
- model_file = None
657
-
658
- if model_file is None:
659
- model_file = _get_model_file(
660
- pretrained_model_name_or_path,
661
- weights_name=weight_name or TEXT_INVERSION_NAME,
662
- cache_dir=cache_dir,
663
- force_download=force_download,
664
- resume_download=resume_download,
665
- proxies=proxies,
666
- local_files_only=local_files_only,
667
- use_auth_token=use_auth_token,
668
- revision=revision,
669
- subfolder=subfolder,
670
- user_agent=user_agent,
671
- )
672
- state_dict = torch.load(model_file, map_location="cpu")
673
- else:
674
- state_dict = pretrained_model_name_or_path
675
-
676
- # 2. Load token and embedding correctly from file
677
- loaded_token = None
678
- if isinstance(state_dict, torch.Tensor):
679
- if token is None:
680
- raise ValueError(
681
- "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
682
- )
683
- embedding = state_dict
684
- elif len(state_dict) == 1:
685
- # diffusers
686
- loaded_token, embedding = next(iter(state_dict.items()))
687
- elif "string_to_param" in state_dict:
688
- # A1111
689
- loaded_token = state_dict["name"]
690
- embedding = state_dict["string_to_param"]["*"]
691
-
692
- if token is not None and loaded_token != token:
693
- logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
694
- else:
695
- token = loaded_token
696
-
697
- embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
698
-
699
- # 3. Make sure we don't mess up the tokenizer or text encoder
700
- vocab = self.tokenizer.get_vocab()
701
- if token in vocab:
702
- raise ValueError(
703
- f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
704
- )
705
- elif f"{token}_1" in vocab:
706
- multi_vector_tokens = [token]
707
- i = 1
708
- while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
709
- multi_vector_tokens.append(f"{token}_{i}")
710
- i += 1
711
-
712
- raise ValueError(
713
- f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
714
- )
715
-
716
- is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
717
-
718
- if is_multi_vector:
719
- tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
720
- embeddings = [e for e in embedding] # noqa: C416
721
- else:
722
- tokens = [token]
723
- embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
724
-
725
- # add tokens and get ids
726
- self.tokenizer.add_tokens(tokens)
727
- token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
728
- token_ids_and_embeddings += zip(token_ids, embeddings)
729
-
730
- logger.info(f"Loaded textual inversion embedding for {token}.")
731
-
732
- # resize token embeddings and set all new embeddings
733
- self.text_encoder.resize_token_embeddings(len(self.tokenizer))
734
- for token_id, embedding in token_ids_and_embeddings:
735
- self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
736
-
737
-
738
- class LoraLoaderMixin:
739
- r"""
740
- Load LoRA layers into [`UNet2DConditionModel`] and
741
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
742
- """
743
- text_encoder_name = TEXT_ENCODER_NAME
744
- unet_name = UNET_NAME
745
-
746
- def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
747
- r"""
748
- Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
749
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
750
-
751
- Parameters:
752
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
753
- Can be either:
754
-
755
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
756
- the Hub.
757
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
758
- with [`ModelMixin.save_pretrained`].
759
- - A [torch state
760
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
761
-
762
- cache_dir (`Union[str, os.PathLike]`, *optional*):
763
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
764
- is not used.
765
- force_download (`bool`, *optional*, defaults to `False`):
766
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
767
- cached versions if they exist.
768
- resume_download (`bool`, *optional*, defaults to `False`):
769
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
770
- incompletely downloaded files are deleted.
771
- proxies (`Dict[str, str]`, *optional*):
772
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
773
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
774
- local_files_only (`bool`, *optional*, defaults to `False`):
775
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
776
- won't be downloaded from the Hub.
777
- use_auth_token (`str` or *bool*, *optional*):
778
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
779
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
780
- revision (`str`, *optional*, defaults to `"main"`):
781
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
782
- allowed by Git.
783
- subfolder (`str`, *optional*, defaults to `""`):
784
- The subfolder location of a model file within a larger model repository on the Hub or locally.
785
- mirror (`str`, *optional*):
786
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
787
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
788
- information.
789
-
790
- """
791
- # Load the main state dict first which has the LoRA layers for either of
792
- # UNet and text encoder or both.
793
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
794
- force_download = kwargs.pop("force_download", False)
795
- resume_download = kwargs.pop("resume_download", False)
796
- proxies = kwargs.pop("proxies", None)
797
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
798
- use_auth_token = kwargs.pop("use_auth_token", None)
799
- revision = kwargs.pop("revision", None)
800
- subfolder = kwargs.pop("subfolder", None)
801
- weight_name = kwargs.pop("weight_name", None)
802
- use_safetensors = kwargs.pop("use_safetensors", None)
803
-
804
- # set lora scale to a reasonable default
805
- self._lora_scale = 1.0
806
-
807
- if use_safetensors and not is_safetensors_available():
808
- raise ValueError(
809
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
810
- )
811
-
812
- allow_pickle = False
813
- if use_safetensors is None:
814
- use_safetensors = is_safetensors_available()
815
- allow_pickle = True
816
-
817
- user_agent = {
818
- "file_type": "attn_procs_weights",
819
- "framework": "pytorch",
820
- }
821
-
822
- model_file = None
823
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
824
- # Let's first try to load .safetensors weights
825
- if (use_safetensors and weight_name is None) or (
826
- weight_name is not None and weight_name.endswith(".safetensors")
827
- ):
828
- try:
829
- model_file = _get_model_file(
830
- pretrained_model_name_or_path_or_dict,
831
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
832
- cache_dir=cache_dir,
833
- force_download=force_download,
834
- resume_download=resume_download,
835
- proxies=proxies,
836
- local_files_only=local_files_only,
837
- use_auth_token=use_auth_token,
838
- revision=revision,
839
- subfolder=subfolder,
840
- user_agent=user_agent,
841
- )
842
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
843
- except IOError as e:
844
- if not allow_pickle:
845
- raise e
846
- # try loading non-safetensors weights
847
- pass
848
- if model_file is None:
849
- model_file = _get_model_file(
850
- pretrained_model_name_or_path_or_dict,
851
- weights_name=weight_name or LORA_WEIGHT_NAME,
852
- cache_dir=cache_dir,
853
- force_download=force_download,
854
- resume_download=resume_download,
855
- proxies=proxies,
856
- local_files_only=local_files_only,
857
- use_auth_token=use_auth_token,
858
- revision=revision,
859
- subfolder=subfolder,
860
- user_agent=user_agent,
861
- )
862
- state_dict = torch.load(model_file, map_location="cpu")
863
- else:
864
- state_dict = pretrained_model_name_or_path_or_dict
865
-
866
- # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
867
- network_alpha = None
868
- if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
869
- state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
870
-
871
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
872
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
873
- # their prefixes.
874
- keys = list(state_dict.keys())
875
- if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
876
- # Load the layers corresponding to UNet.
877
- unet_keys = [k for k in keys if k.startswith(self.unet_name)]
878
- logger.info(f"Loading {self.unet_name}.")
879
- unet_lora_state_dict = {
880
- k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
881
- }
882
- self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
883
-
884
- # Load the layers corresponding to text encoder and make necessary adjustments.
885
- text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
886
- text_encoder_lora_state_dict = {
887
- k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
888
- }
889
- if len(text_encoder_lora_state_dict) > 0:
890
- logger.info(f"Loading {self.text_encoder_name}.")
891
- attn_procs_text_encoder = self._load_text_encoder_attn_procs(
892
- text_encoder_lora_state_dict, network_alpha=network_alpha
893
- )
894
- self._modify_text_encoder(attn_procs_text_encoder)
895
-
896
- # save lora attn procs of text encoder so that it can be easily retrieved
897
- self._text_encoder_lora_attn_procs = attn_procs_text_encoder
898
-
899
- # Otherwise, we're dealing with the old format. This means the `state_dict` should only
900
- # contain the module names of the `unet` as its keys WITHOUT any prefix.
901
- elif not all(
902
- key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
903
- ):
904
- self.unet.load_attn_procs(state_dict)
905
- warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
906
- warnings.warn(warn_message)
907
-
908
- @property
909
- def lora_scale(self) -> float:
910
- # property function that returns the lora scale which can be set at run time by the pipeline.
911
- # if _lora_scale has not been set, return 1
912
- return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
913
-
914
- @property
915
- def text_encoder_lora_attn_procs(self):
916
- if hasattr(self, "_text_encoder_lora_attn_procs"):
917
- return self._text_encoder_lora_attn_procs
918
- return
919
-
920
- def _remove_text_encoder_monkey_patch(self):
921
- # Loop over the CLIPAttention module of text_encoder
922
- for name, attn_module in self.text_encoder.named_modules():
923
- if name.endswith(TEXT_ENCODER_ATTN_MODULE):
924
- # Loop over the LoRA layers
925
- for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
926
- # Retrieve the q/k/v/out projection of CLIPAttention
927
- module = attn_module.get_submodule(text_encoder_attr)
928
- if hasattr(module, "old_forward"):
929
- # restore original `forward` to remove monkey-patch
930
- module.forward = module.old_forward
931
- delattr(module, "old_forward")
932
-
933
- def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
934
- r"""
935
- Monkey-patches the forward passes of attention modules of the text encoder.
936
-
937
- Parameters:
938
- attn_processors: Dict[str, `LoRAAttnProcessor`]:
939
- A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
940
- """
941
-
942
- # First, remove any monkey-patch that might have been applied before
943
- self._remove_text_encoder_monkey_patch()
944
-
945
- # Loop over the CLIPAttention module of text_encoder
946
- for name, attn_module in self.text_encoder.named_modules():
947
- if name.endswith(TEXT_ENCODER_ATTN_MODULE):
948
- # Loop over the LoRA layers
949
- for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
950
- # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
951
- module = attn_module.get_submodule(text_encoder_attr)
952
- lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
953
-
954
- # save old_forward to module that can be used to remove monkey-patch
955
- old_forward = module.old_forward = module.forward
956
-
957
- # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
958
- # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
959
- def make_new_forward(old_forward, lora_layer):
960
- def new_forward(x):
961
- result = old_forward(x) + self.lora_scale * lora_layer(x)
962
- return result
963
-
964
- return new_forward
965
-
966
- # Monkey-patch.
967
- module.forward = make_new_forward(old_forward, lora_layer)
968
-
969
- @property
970
- def _lora_attn_processor_attr_to_text_encoder_attr(self):
971
- return {
972
- "to_q_lora": "q_proj",
973
- "to_k_lora": "k_proj",
974
- "to_v_lora": "v_proj",
975
- "to_out_lora": "out_proj",
976
- }
977
-
978
- def _load_text_encoder_attn_procs(
979
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
980
- ):
981
- r"""
982
- Load pretrained attention processor layers for
983
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
984
-
985
- <Tip warning={true}>
986
-
987
- This function is experimental and might change in the future.
988
-
989
- </Tip>
990
-
991
- Parameters:
992
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
993
- Can be either:
994
-
995
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
996
- Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
997
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
998
- `./my_model_directory/`.
999
- - A [torch state
1000
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1001
-
1002
- cache_dir (`Union[str, os.PathLike]`, *optional*):
1003
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
1004
- standard cache should not be used.
1005
- force_download (`bool`, *optional*, defaults to `False`):
1006
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1007
- cached versions if they exist.
1008
- resume_download (`bool`, *optional*, defaults to `False`):
1009
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
1010
- file exists.
1011
- proxies (`Dict[str, str]`, *optional*):
1012
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1013
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1014
- local_files_only (`bool`, *optional*, defaults to `False`):
1015
- Whether or not to only look at local files (i.e., do not try to download the model).
1016
- use_auth_token (`str` or *bool*, *optional*):
1017
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
1018
- when running `diffusers-cli login` (stored in `~/.huggingface`).
1019
- revision (`str`, *optional*, defaults to `"main"`):
1020
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1021
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1022
- identifier allowed by git.
1023
- subfolder (`str`, *optional*, defaults to `""`):
1024
- In case the relevant files are located inside a subfolder of the model repo (either remote in
1025
- huggingface.co or downloaded locally), you can specify the folder name here.
1026
- mirror (`str`, *optional*):
1027
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
1028
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
1029
- Please refer to the mirror site for more information.
1030
-
1031
- Returns:
1032
- `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
1033
- [`LoRAAttnProcessor`].
1034
-
1035
- <Tip>
1036
-
1037
- It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
1038
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
1039
-
1040
- </Tip>
1041
- """
1042
-
1043
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1044
- force_download = kwargs.pop("force_download", False)
1045
- resume_download = kwargs.pop("resume_download", False)
1046
- proxies = kwargs.pop("proxies", None)
1047
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1048
- use_auth_token = kwargs.pop("use_auth_token", None)
1049
- revision = kwargs.pop("revision", None)
1050
- subfolder = kwargs.pop("subfolder", None)
1051
- weight_name = kwargs.pop("weight_name", None)
1052
- use_safetensors = kwargs.pop("use_safetensors", None)
1053
- network_alpha = kwargs.pop("network_alpha", None)
1054
-
1055
- if use_safetensors and not is_safetensors_available():
1056
- raise ValueError(
1057
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1058
- )
1059
-
1060
- allow_pickle = False
1061
- if use_safetensors is None:
1062
- use_safetensors = is_safetensors_available()
1063
- allow_pickle = True
1064
-
1065
- user_agent = {
1066
- "file_type": "attn_procs_weights",
1067
- "framework": "pytorch",
1068
- }
1069
-
1070
- model_file = None
1071
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
1072
- # Let's first try to load .safetensors weights
1073
- if (use_safetensors and weight_name is None) or (
1074
- weight_name is not None and weight_name.endswith(".safetensors")
1075
- ):
1076
- try:
1077
- model_file = _get_model_file(
1078
- pretrained_model_name_or_path_or_dict,
1079
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
1080
- cache_dir=cache_dir,
1081
- force_download=force_download,
1082
- resume_download=resume_download,
1083
- proxies=proxies,
1084
- local_files_only=local_files_only,
1085
- use_auth_token=use_auth_token,
1086
- revision=revision,
1087
- subfolder=subfolder,
1088
- user_agent=user_agent,
1089
- )
1090
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
1091
- except IOError as e:
1092
- if not allow_pickle:
1093
- raise e
1094
- # try loading non-safetensors weights
1095
- pass
1096
- if model_file is None:
1097
- model_file = _get_model_file(
1098
- pretrained_model_name_or_path_or_dict,
1099
- weights_name=weight_name or LORA_WEIGHT_NAME,
1100
- cache_dir=cache_dir,
1101
- force_download=force_download,
1102
- resume_download=resume_download,
1103
- proxies=proxies,
1104
- local_files_only=local_files_only,
1105
- use_auth_token=use_auth_token,
1106
- revision=revision,
1107
- subfolder=subfolder,
1108
- user_agent=user_agent,
1109
- )
1110
- state_dict = torch.load(model_file, map_location="cpu")
1111
- else:
1112
- state_dict = pretrained_model_name_or_path_or_dict
1113
-
1114
- # fill attn processors
1115
- attn_processors = {}
1116
-
1117
- is_lora = all("lora" in k for k in state_dict.keys())
1118
-
1119
- if is_lora:
1120
- lora_grouped_dict = defaultdict(dict)
1121
- for key, value in state_dict.items():
1122
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
1123
- lora_grouped_dict[attn_processor_key][sub_key] = value
1124
-
1125
- for key, value_dict in lora_grouped_dict.items():
1126
- rank = value_dict["to_k_lora.down.weight"].shape[0]
1127
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
1128
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
1129
-
1130
- attn_processor_class = (
1131
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
1132
- )
1133
- attn_processors[key] = attn_processor_class(
1134
- hidden_size=hidden_size,
1135
- cross_attention_dim=cross_attention_dim,
1136
- rank=rank,
1137
- network_alpha=network_alpha,
1138
- )
1139
- attn_processors[key].load_state_dict(value_dict)
1140
-
1141
- else:
1142
- raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
1143
-
1144
- # set correct dtype & device
1145
- attn_processors = {
1146
- k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
1147
- }
1148
- return attn_processors
1149
-
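
As a side note on the grouping logic above: the flat checkpoint keys are split three path components from the end, and `rank`, `cross_attention_dim` and `hidden_size` are read straight off the LoRA weight shapes. A small self-contained sketch (key names and shapes are made up for illustration; torch required):

```py
from collections import defaultdict
import torch

# hypothetical flat LoRA entries; shapes follow the (rank x in_dim) / (out_dim x rank) convention
state_dict = {
    "down_blocks.0.attn1.processor.to_k_lora.down.weight": torch.zeros(4, 320),
    "down_blocks.0.attn1.processor.to_k_lora.up.weight": torch.zeros(320, 4),
}

grouped = defaultdict(dict)
for key, value in state_dict.items():
    proc_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    grouped[proc_key][sub_key] = value

for proc_key, weights in grouped.items():
    rank = weights["to_k_lora.down.weight"].shape[0]                  # 4
    cross_attention_dim = weights["to_k_lora.down.weight"].shape[1]   # 320
    hidden_size = weights["to_k_lora.up.weight"].shape[0]             # 320
    print(proc_key, rank, cross_attention_dim, hidden_size)
```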
1150
- @classmethod
1151
- def save_lora_weights(
1152
- self,
1153
- save_directory: Union[str, os.PathLike],
1154
- unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1155
- text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
1156
- is_main_process: bool = True,
1157
- weight_name: str = None,
1158
- save_function: Callable = None,
1159
- safe_serialization: bool = False,
1160
- ):
1161
- r"""
1162
- Save the LoRA parameters corresponding to the UNet and text encoder.
1163
-
1164
- Arguments:
1165
- save_directory (`str` or `os.PathLike`):
1166
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
1167
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1168
- State dict of the LoRA layers corresponding to the UNet.
1169
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1170
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1171
- encoder LoRA state dict because it comes from 🤗 Transformers.
1172
- is_main_process (`bool`, *optional*, defaults to `True`):
1173
- Whether the process calling this is the main process or not. Useful during distributed training when you
1174
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1175
- process to avoid race conditions.
1176
- save_function (`Callable`):
1177
- The function to use to save the state dictionary. Useful during distributed training when you need to
1178
- replace `torch.save` with another method. Can be configured with the environment variable
1179
- `DIFFUSERS_SAVE_MODE`.
1180
- """
1181
- if os.path.isfile(save_directory):
1182
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
1183
- return
1184
-
1185
- if save_function is None:
1186
- if safe_serialization:
1187
-
1188
- def save_function(weights, filename):
1189
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
1190
-
1191
- else:
1192
- save_function = torch.save
1193
-
1194
- os.makedirs(save_directory, exist_ok=True)
1195
-
1196
- # Create a flat dictionary.
1197
- state_dict = {}
1198
- if unet_lora_layers is not None:
1199
- weights = (
1200
- unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
1201
- )
1202
-
1203
- unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
1204
- state_dict.update(unet_lora_state_dict)
1205
-
1206
- if text_encoder_lora_layers is not None:
1207
- weights = (
1208
- text_encoder_lora_layers.state_dict()
1209
- if isinstance(text_encoder_lora_layers, torch.nn.Module)
1210
- else text_encoder_lora_layers
1211
- )
1212
-
1213
- text_encoder_lora_state_dict = {
1214
- f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
1215
- }
1216
- state_dict.update(text_encoder_lora_state_dict)
1217
-
1218
- # Save the model
1219
- if weight_name is None:
1220
- if safe_serialization:
1221
- weight_name = LORA_WEIGHT_NAME_SAFE
1222
- else:
1223
- weight_name = LORA_WEIGHT_NAME
1224
-
1225
- save_function(state_dict, os.path.join(save_directory, weight_name))
1226
- logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
1227
-
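
The flattening performed by `save_lora_weights` boils down to namespacing each sub-model's entries before a single file is written. A toy sketch of that step only (keys are placeholders; real values are LoRA weight tensors):

```py
# placeholder keys; real values are LoRA weight tensors
unet_lora_layers = {"mid_block.attn1.processor.to_q_lora.down.weight": "w0"}
text_encoder_lora_layers = {"text_model.encoder.layers.0.self_attn.to_q_lora.down.weight": "w1"}

state_dict = {}
state_dict.update({f"unet.{name}": w for name, w in unet_lora_layers.items()})
state_dict.update({f"text_encoder.{name}": w for name, w in text_encoder_lora_layers.items()})

for key in state_dict:
    print(key)
# unet.mid_block.attn1.processor.to_q_lora.down.weight
# text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight
```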
1228
- def _convert_kohya_lora_to_diffusers(self, state_dict):
1229
- unet_state_dict = {}
1230
- te_state_dict = {}
1231
- network_alpha = None
1232
-
1233
- for key, value in state_dict.items():
1234
- if "lora_down" in key:
1235
- lora_name = key.split(".")[0]
1236
- lora_name_up = lora_name + ".lora_up.weight"
1237
- lora_name_alpha = lora_name + ".alpha"
1238
- if lora_name_alpha in state_dict:
1239
- alpha = state_dict[lora_name_alpha].item()
1240
- if network_alpha is None:
1241
- network_alpha = alpha
1242
- elif network_alpha != alpha:
1243
- raise ValueError("Network alpha is not consistent")
1244
-
1245
- if lora_name.startswith("lora_unet_"):
1246
- diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
1247
- diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
1248
- diffusers_name = diffusers_name.replace("mid.block", "mid_block")
1249
- diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
1250
- diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
1251
- diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
1252
- diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
1253
- diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
1254
- diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
1255
- if "transformer_blocks" in diffusers_name:
1256
- if "attn1" in diffusers_name or "attn2" in diffusers_name:
1257
- diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
1258
- diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
1259
- unet_state_dict[diffusers_name] = value
1260
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1261
- elif lora_name.startswith("lora_te_"):
1262
- diffusers_name = key.replace("lora_te_", "").replace("_", ".")
1263
- diffusers_name = diffusers_name.replace("text.model", "text_model")
1264
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
1265
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
1266
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
1267
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
1268
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
1269
- if "self_attn" in diffusers_name:
1270
- te_state_dict[diffusers_name] = value
1271
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1272
-
1273
- unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
1274
- te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
1275
- new_state_dict = {**unet_state_dict, **te_state_dict}
1276
- return new_state_dict, network_alpha
1277
-
1278
-
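
To make the renaming above concrete, here is the same replacement chain applied by hand to one hypothetical Kohya-style UNet key; the printed result is the diffusers-style key that ends up in the converted state dict:

```py
kohya_key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"

name = kohya_key.replace("lora_unet_", "").replace("_", ".")
name = name.replace("down.blocks", "down_blocks").replace("mid.block", "mid_block")
name = name.replace("up.blocks", "up_blocks").replace("transformer.blocks", "transformer_blocks")
name = name.replace("to.q.lora", "to_q_lora")
if "transformer_blocks" in name and ("attn1" in name or "attn2" in name):
    name = name.replace("attn1", "attn1.processor").replace("attn2", "attn2.processor")

print(f"unet.{name}")
# unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight
```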
1279
- class FromSingleFileMixin:
1280
- """
1281
- Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
1282
- """
1283
-
1284
- @classmethod
1285
- def from_ckpt(cls, *args, **kwargs):
1286
- deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
1287
- deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
1288
- return cls.from_single_file(*args, **kwargs)
1289
-
1290
- @classmethod
1291
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
1292
- r"""
1293
- Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
1294
- is set in evaluation mode (`model.eval()`) by default.
1295
-
1296
- Parameters:
1297
- pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
1298
- Can be either:
1299
- - A link to the `.ckpt` file (for example
1300
- `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
1301
- - A path to a *file* containing all pipeline weights.
1302
- torch_dtype (`str` or `torch.dtype`, *optional*):
1303
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1304
- dtype is automatically derived from the model's weights.
1305
- force_download (`bool`, *optional*, defaults to `False`):
1306
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1307
- cached versions if they exist.
1308
- cache_dir (`Union[str, os.PathLike]`, *optional*):
1309
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1310
- is not used.
1311
- resume_download (`bool`, *optional*, defaults to `False`):
1312
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1313
- incompletely downloaded files are deleted.
1314
- proxies (`Dict[str, str]`, *optional*):
1315
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1316
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1317
- local_files_only (`bool`, *optional*, defaults to `False`):
1318
- Whether to only load local model weights and configuration files or not. If set to True, the model
1319
- won't be downloaded from the Hub.
1320
- use_auth_token (`str` or *bool*, *optional*):
1321
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1322
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
1323
- revision (`str`, *optional*, defaults to `"main"`):
1324
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1325
- allowed by Git.
1326
- use_safetensors (`bool`, *optional*, defaults to `None`):
1327
- If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1328
- safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1329
- weights. If set to `False`, safetensors weights are not loaded.
1330
- extract_ema (`bool`, *optional*, defaults to `False`):
1331
- Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
1332
- higher quality images for inference. Non-EMA weights are usually better for continuing fine-tuning.
1333
- upcast_attention (`bool`, *optional*, defaults to `None`):
1334
- Whether the attention computation should always be upcasted.
1335
- image_size (`int`, *optional*, defaults to 512):
1336
- The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
1337
- Diffusion v2 base model. Use 768 for Stable Diffusion v2.
1338
- prediction_type (`str`, *optional*):
1339
- The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
1340
- the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
1341
- num_in_channels (`int`, *optional*, defaults to `None`):
1342
- The number of input channels. If `None`, it will be automatically inferred.
1343
- scheduler_type (`str`, *optional*, defaults to `"pndm"`):
1344
- Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
1345
- "ddim"]`.
1346
- load_safety_checker (`bool`, *optional*, defaults to `True`):
1347
- Whether to load the safety checker or not.
1348
- text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
1349
- An instance of
1350
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
1351
- specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
1352
- variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
1353
- needed.
1354
- tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
1355
- An instance of
1356
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
1357
- to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
1358
- itself, if needed.
1359
- kwargs (remaining dictionary of keyword arguments, *optional*):
1360
- Can be used to overwrite loadable and saveable variables (for example the pipeline components of the
1361
- specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__`
1362
- method. See example below for more information.
1363
-
1364
- Examples:
1365
-
1366
- ```py
1367
- >>> from diffusers import StableDiffusionPipeline
1368
-
1369
- >>> # Download pipeline from huggingface.co and cache.
1370
- >>> pipeline = StableDiffusionPipeline.from_single_file(
1371
- ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
1372
- ... )
1373
-
1374
- >>> # Download pipeline from local file
1375
- >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
1376
- >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
1377
-
1378
- >>> # Enable float16 and move to GPU
1379
- >>> pipeline = StableDiffusionPipeline.from_single_file(
1380
- ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
1381
- ... torch_dtype=torch.float16,
1382
- ... )
1383
- >>> pipeline.to("cuda")
1384
- ```
1385
- """
1386
- # import here to avoid circular dependency
1387
- from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
1388
-
1389
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1390
- resume_download = kwargs.pop("resume_download", False)
1391
- force_download = kwargs.pop("force_download", False)
1392
- proxies = kwargs.pop("proxies", None)
1393
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1394
- use_auth_token = kwargs.pop("use_auth_token", None)
1395
- revision = kwargs.pop("revision", None)
1396
- extract_ema = kwargs.pop("extract_ema", False)
1397
- image_size = kwargs.pop("image_size", None)
1398
- scheduler_type = kwargs.pop("scheduler_type", "pndm")
1399
- num_in_channels = kwargs.pop("num_in_channels", None)
1400
- upcast_attention = kwargs.pop("upcast_attention", None)
1401
- load_safety_checker = kwargs.pop("load_safety_checker", True)
1402
- prediction_type = kwargs.pop("prediction_type", None)
1403
- text_encoder = kwargs.pop("text_encoder", None)
1404
- tokenizer = kwargs.pop("tokenizer", None)
1405
-
1406
- torch_dtype = kwargs.pop("torch_dtype", None)
1407
-
1408
- use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
1409
-
1410
- pipeline_name = cls.__name__
1411
- file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
1412
- from_safetensors = file_extension == "safetensors"
1413
-
1414
- if from_safetensors and use_safetensors is False:
1415
- raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
1416
-
1417
- # TODO: For now we only support stable diffusion
1418
- stable_unclip = None
1419
- model_type = None
1420
- controlnet = False
1421
-
1422
- if pipeline_name == "StableDiffusionControlNetPipeline":
1423
- # Model type will be inferred from the checkpoint.
1424
- controlnet = True
1425
- elif "StableDiffusion" in pipeline_name:
1426
- # Model type will be inferred from the checkpoint.
1427
- pass
1428
- elif pipeline_name == "StableUnCLIPPipeline":
1429
- model_type = "FrozenOpenCLIPEmbedder"
1430
- stable_unclip = "txt2img"
1431
- elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
1432
- model_type = "FrozenOpenCLIPEmbedder"
1433
- stable_unclip = "img2img"
1434
- elif pipeline_name == "PaintByExamplePipeline":
1435
- model_type = "PaintByExample"
1436
- elif pipeline_name == "LDMTextToImagePipeline":
1437
- model_type = "LDMTextToImage"
1438
- else:
1439
- raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
1440
-
1441
- # remove huggingface url
1442
- for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
1443
- if pretrained_model_link_or_path.startswith(prefix):
1444
- pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
1445
-
1446
- # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
1447
- ckpt_path = Path(pretrained_model_link_or_path)
1448
- if not ckpt_path.is_file():
1449
- # get repo_id and (potentially nested) file path of ckpt in repo
1450
- repo_id = "/".join(ckpt_path.parts[:2])
1451
- file_path = "/".join(ckpt_path.parts[2:])
1452
-
1453
- if file_path.startswith("blob/"):
1454
- file_path = file_path[len("blob/") :]
1455
-
1456
- if file_path.startswith("main/"):
1457
- file_path = file_path[len("main/") :]
1458
-
1459
- pretrained_model_link_or_path = hf_hub_download(
1460
- repo_id,
1461
- filename=file_path,
1462
- cache_dir=cache_dir,
1463
- resume_download=resume_download,
1464
- proxies=proxies,
1465
- local_files_only=local_files_only,
1466
- use_auth_token=use_auth_token,
1467
- revision=revision,
1468
- force_download=force_download,
1469
- )
1470
-
1471
- pipe = download_from_original_stable_diffusion_ckpt(
1472
- pretrained_model_link_or_path,
1473
- pipeline_class=cls,
1474
- model_type=model_type,
1475
- stable_unclip=stable_unclip,
1476
- controlnet=controlnet,
1477
- from_safetensors=from_safetensors,
1478
- extract_ema=extract_ema,
1479
- image_size=image_size,
1480
- scheduler_type=scheduler_type,
1481
- num_in_channels=num_in_channels,
1482
- upcast_attention=upcast_attention,
1483
- load_safety_checker=load_safety_checker,
1484
- prediction_type=prediction_type,
1485
- text_encoder=text_encoder,
1486
- tokenizer=tokenizer,
1487
- )
1488
-
1489
- if torch_dtype is not None:
1490
- pipe.to(torch_dtype=torch_dtype)
1491
-
1492
- return pipe
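
For reference, the Hub-URL handling in `from_single_file` reduces a blob URL to a `(repo_id, filename)` pair before calling `hf_hub_download`. A standalone sketch of just that parsing step (URL taken from the docstring example above):

```py
from pathlib import Path

link = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
    if link.startswith(prefix):
        link = link[len(prefix):]

parts = Path(link).parts
repo_id, file_path = "/".join(parts[:2]), "/".join(parts[2:])
if file_path.startswith("blob/"):
    file_path = file_path[len("blob/"):]
if file_path.startswith("main/"):
    file_path = file_path[len("main/"):]

print(repo_id)    # runwayml/stable-diffusion-v1-5
print(file_path)  # v1-5-pruned-emaonly.ckpt
```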
4DoF/diffusers/models/__init__.py DELETED
@@ -1,35 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from ..utils import is_flax_available, is_torch_available
16
-
17
-
18
- if is_torch_available():
19
- from .autoencoder_kl import AutoencoderKL
20
- from .controlnet import ControlNetModel
21
- from .dual_transformer_2d import DualTransformer2DModel
22
- from .modeling_utils import ModelMixin
23
- from .prior_transformer import PriorTransformer
24
- from .t5_film_transformer import T5FilmDecoder
25
- from .transformer_2d import Transformer2DModel
26
- from .unet_1d import UNet1DModel
27
- from .unet_2d import UNet2DModel
28
- from .unet_2d_condition import UNet2DConditionModel
29
- from .unet_3d_condition import UNet3DConditionModel
30
- from .vq_model import VQModel
31
-
32
- if is_flax_available():
33
- from .controlnet_flax import FlaxControlNetModel
34
- from .unet_2d_condition_flax import FlaxUNet2DConditionModel
35
- from .vae_flax import FlaxAutoencoderKL
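
The `is_torch_available()` / `is_flax_available()` guards above follow the usual optional-backend pattern: probe for the framework before importing modules that need it. A minimal sketch of that pattern (the helper name mirrors the one used here, but this implementation is an assumption, not the library's):

```py
import importlib.util

def is_torch_available() -> bool:  # illustrative stand-in for the real utility
    return importlib.util.find_spec("torch") is not None

if is_torch_available():
    from torch import nn  # torch-backed models would be imported here
    print("torch backend enabled:", nn.Linear is not None)
else:
    print("torch not installed; torch-backed models are simply not exported")
```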
4DoF/diffusers/models/activations.py DELETED
@@ -1,12 +0,0 @@
1
- from torch import nn
2
-
3
-
4
- def get_activation(act_fn):
5
- if act_fn in ["swish", "silu"]:
6
- return nn.SiLU()
7
- elif act_fn == "mish":
8
- return nn.Mish()
9
- elif act_fn == "gelu":
10
- return nn.GELU()
11
- else:
12
- raise ValueError(f"Unsupported activation function: {act_fn}")
4DoF/diffusers/models/attention.py DELETED
@@ -1,392 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, Optional
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from ..utils import maybe_allow_in_graph
21
- from .activations import get_activation
22
- from .attention_processor import Attention
23
- from .embeddings import CombinedTimestepLabelEmbeddings
24
-
25
-
26
- @maybe_allow_in_graph
27
- class BasicTransformerBlock(nn.Module):
28
- r"""
29
- A basic Transformer block.
30
-
31
- Parameters:
32
- dim (`int`): The number of channels in the input and output.
33
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
34
- attention_head_dim (`int`): The number of channels in each head.
35
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
36
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
37
- only_cross_attention (`bool`, *optional*):
38
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
39
- double_self_attention (`bool`, *optional*):
40
- Whether to use two self-attention layers. In this case no cross attention layers are used.
41
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
42
- num_embeds_ada_norm (`int`, *optional*):
43
- The number of diffusion steps used during training. See `Transformer2DModel`.
44
- attention_bias (`bool`, *optional*, defaults to `False`):
45
- Configure if the attentions should contain a bias parameter.
46
- """
47
-
48
- def __init__(
49
- self,
50
- dim: int,
51
- num_attention_heads: int,
52
- attention_head_dim: int,
53
- dropout=0.0,
54
- cross_attention_dim: Optional[int] = None,
55
- activation_fn: str = "geglu",
56
- num_embeds_ada_norm: Optional[int] = None,
57
- attention_bias: bool = False,
58
- only_cross_attention: bool = False,
59
- double_self_attention: bool = False,
60
- upcast_attention: bool = False,
61
- norm_elementwise_affine: bool = True,
62
- norm_type: str = "layer_norm",
63
- final_dropout: bool = False,
64
- ):
65
- super().__init__()
66
- self.only_cross_attention = only_cross_attention
67
-
68
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
69
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
70
-
71
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
72
- raise ValueError(
73
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
74
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
75
- )
76
-
77
- # Define 3 blocks. Each block has its own normalization layer.
78
- # 1. Self-Attn
79
- if self.use_ada_layer_norm:
80
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
81
- elif self.use_ada_layer_norm_zero:
82
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
83
- else:
84
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
85
- self.attn1 = Attention(
86
- query_dim=dim,
87
- heads=num_attention_heads,
88
- dim_head=attention_head_dim,
89
- dropout=dropout,
90
- bias=attention_bias,
91
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
92
- upcast_attention=upcast_attention,
93
- )
94
-
95
- # 2. Cross-Attn
96
- if cross_attention_dim is not None or double_self_attention:
97
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
98
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
99
- # the second cross attention block.
100
- self.norm2 = (
101
- AdaLayerNorm(dim, num_embeds_ada_norm)
102
- if self.use_ada_layer_norm
103
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
104
- )
105
- self.attn2 = Attention(
106
- query_dim=dim,
107
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
108
- heads=num_attention_heads,
109
- dim_head=attention_head_dim,
110
- dropout=dropout,
111
- bias=attention_bias,
112
- upcast_attention=upcast_attention,
113
- ) # is self-attn if encoder_hidden_states is none
114
- else:
115
- self.norm2 = None
116
- self.attn2 = None
117
-
118
- # 3. Feed-forward
119
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
120
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
121
-
122
- # let chunk size default to None
123
- self._chunk_size = None
124
- self._chunk_dim = 0
125
-
126
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
127
- # Sets chunk feed-forward
128
- self._chunk_size = chunk_size
129
- self._chunk_dim = dim
130
-
131
- def forward(
132
- self,
133
- hidden_states: torch.FloatTensor,
134
- attention_mask: Optional[torch.FloatTensor] = None,
135
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
136
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
137
- timestep: Optional[torch.LongTensor] = None,
138
- posemb: Optional = None,
139
- cross_attention_kwargs: Dict[str, Any] = None,
140
- class_labels: Optional[torch.LongTensor] = None,
141
- ):
142
- # Notice that normalization is always applied before the real computation in the following blocks.
143
- # 1. Self-Attention
144
- if self.use_ada_layer_norm:
145
- norm_hidden_states = self.norm1(hidden_states, timestep)
146
- elif self.use_ada_layer_norm_zero:
147
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
148
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
149
- )
150
- else:
151
- norm_hidden_states = self.norm1(hidden_states)
152
-
153
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
154
-
155
- attn_output = self.attn1(
156
- norm_hidden_states,
157
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
158
- attention_mask=attention_mask,
159
- posemb=posemb, # todo in self attn, posemb should be [pose_in, pose_in]?
160
- **cross_attention_kwargs,
161
- )
162
- if self.use_ada_layer_norm_zero:
163
- attn_output = gate_msa.unsqueeze(1) * attn_output
164
- hidden_states = attn_output + hidden_states
165
-
166
- # 2. Cross-Attention
167
- if self.attn2 is not None:
168
- norm_hidden_states = (
169
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
170
- )
171
-
172
- attn_output = self.attn2(
173
- norm_hidden_states,
174
- encoder_hidden_states=encoder_hidden_states,
175
- attention_mask=encoder_attention_mask,
176
- posemb=posemb,
177
- **cross_attention_kwargs,
178
- )
179
- hidden_states = attn_output + hidden_states
180
-
181
- # 3. Feed-forward
182
- norm_hidden_states = self.norm3(hidden_states)
183
-
184
- if self.use_ada_layer_norm_zero:
185
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
186
-
187
- if self._chunk_size is not None:
188
- # "feed_forward_chunk_size" can be used to save memory
189
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
190
- raise ValueError(
191
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
192
- )
193
-
194
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
195
- ff_output = torch.cat(
196
- [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
197
- dim=self._chunk_dim,
198
- )
199
- else:
200
- ff_output = self.ff(norm_hidden_states)
201
-
202
- if self.use_ada_layer_norm_zero:
203
- ff_output = gate_mlp.unsqueeze(1) * ff_output
204
-
205
- hidden_states = ff_output + hidden_states
206
-
207
- return hidden_states
208
-
209
-
210
- class FeedForward(nn.Module):
211
- r"""
212
- A feed-forward layer.
213
-
214
- Parameters:
215
- dim (`int`): The number of channels in the input.
216
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
217
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
218
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
219
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
220
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
221
- """
222
-
223
- def __init__(
224
- self,
225
- dim: int,
226
- dim_out: Optional[int] = None,
227
- mult: int = 4,
228
- dropout: float = 0.0,
229
- activation_fn: str = "geglu",
230
- final_dropout: bool = False,
231
- ):
232
- super().__init__()
233
- inner_dim = int(dim * mult)
234
- dim_out = dim_out if dim_out is not None else dim
235
-
236
- if activation_fn == "gelu":
237
- act_fn = GELU(dim, inner_dim)
238
- if activation_fn == "gelu-approximate":
239
- act_fn = GELU(dim, inner_dim, approximate="tanh")
240
- elif activation_fn == "geglu":
241
- act_fn = GEGLU(dim, inner_dim)
242
- elif activation_fn == "geglu-approximate":
243
- act_fn = ApproximateGELU(dim, inner_dim)
244
-
245
- self.net = nn.ModuleList([])
246
- # project in
247
- self.net.append(act_fn)
248
- # project dropout
249
- self.net.append(nn.Dropout(dropout))
250
- # project out
251
- self.net.append(nn.Linear(inner_dim, dim_out))
252
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253
- if final_dropout:
254
- self.net.append(nn.Dropout(dropout))
255
-
256
- def forward(self, hidden_states):
257
- for module in self.net:
258
- hidden_states = module(hidden_states)
259
- return hidden_states
260
-
261
-
262
- class GELU(nn.Module):
263
- r"""
264
- GELU activation function with tanh approximation support with `approximate="tanh"`.
265
- """
266
-
267
- def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
268
- super().__init__()
269
- self.proj = nn.Linear(dim_in, dim_out)
270
- self.approximate = approximate
271
-
272
- def gelu(self, gate):
273
- if gate.device.type != "mps":
274
- return F.gelu(gate, approximate=self.approximate)
275
- # mps: gelu is not implemented for float16
276
- return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
277
-
278
- def forward(self, hidden_states):
279
- hidden_states = self.proj(hidden_states)
280
- hidden_states = self.gelu(hidden_states)
281
- return hidden_states
282
-
283
-
284
- class GEGLU(nn.Module):
285
- r"""
286
- A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
287
-
288
- Parameters:
289
- dim_in (`int`): The number of channels in the input.
290
- dim_out (`int`): The number of channels in the output.
291
- """
292
-
293
- def __init__(self, dim_in: int, dim_out: int):
294
- super().__init__()
295
- self.proj = nn.Linear(dim_in, dim_out * 2)
296
-
297
- def gelu(self, gate):
298
- if gate.device.type != "mps":
299
- return F.gelu(gate)
300
- # mps: gelu is not implemented for float16
301
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
302
-
303
- def forward(self, hidden_states):
304
- hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
305
- return hidden_states * self.gelu(gate)
306
-
307
-
308
- class ApproximateGELU(nn.Module):
309
- """
310
- The approximate form of Gaussian Error Linear Unit (GELU)
311
-
312
- For more details, see section 2: https://arxiv.org/abs/1606.08415
313
- """
314
-
315
- def __init__(self, dim_in: int, dim_out: int):
316
- super().__init__()
317
- self.proj = nn.Linear(dim_in, dim_out)
318
-
319
- def forward(self, x):
320
- x = self.proj(x)
321
- return x * torch.sigmoid(1.702 * x)
322
-
323
-
324
- class AdaLayerNorm(nn.Module):
325
- """
326
- Norm layer modified to incorporate timestep embeddings.
327
- """
328
-
329
- def __init__(self, embedding_dim, num_embeddings):
330
- super().__init__()
331
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
332
- self.silu = nn.SiLU()
333
- self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
334
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
335
-
336
- def forward(self, x, timestep):
337
- emb = self.linear(self.silu(self.emb(timestep)))
338
- scale, shift = torch.chunk(emb, 2)
339
- x = self.norm(x) * (1 + scale) + shift
340
- return x
341
-
342
-
343
- class AdaLayerNormZero(nn.Module):
344
- """
345
- Norm layer adaptive layer norm zero (adaLN-Zero).
346
- """
347
-
348
- def __init__(self, embedding_dim, num_embeddings):
349
- super().__init__()
350
-
351
- self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
352
-
353
- self.silu = nn.SiLU()
354
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
355
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
356
-
357
- def forward(self, x, timestep, class_labels, hidden_dtype=None):
358
- emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
359
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
360
- x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
361
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
362
-
363
-
364
- class AdaGroupNorm(nn.Module):
365
- """
366
- GroupNorm layer modified to incorporate timestep embeddings.
367
- """
368
-
369
- def __init__(
370
- self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
371
- ):
372
- super().__init__()
373
- self.num_groups = num_groups
374
- self.eps = eps
375
-
376
- if act_fn is None:
377
- self.act = None
378
- else:
379
- self.act = get_activation(act_fn)
380
-
381
- self.linear = nn.Linear(embedding_dim, out_dim * 2)
382
-
383
- def forward(self, x, emb):
384
- if self.act:
385
- emb = self.act(emb)
386
- emb = self.linear(emb)
387
- emb = emb[:, :, None, None]
388
- scale, shift = emb.chunk(2, dim=1)
389
-
390
- x = F.group_norm(x, self.num_groups, eps=self.eps)
391
- x = x * (1 + scale) + shift
392
- return x
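
As a quick numeric check of the chunked feed-forward path in `BasicTransformerBlock` above: running the feed-forward per chunk along the sequence dimension and concatenating matches a single full pass up to floating-point error, so chunking only trades peak memory for extra kernel launches. A self-contained sketch with a stand-in MLP (torch required):

```py
import torch
from torch import nn

torch.manual_seed(0)
ff = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))  # stand-in feed-forward
hidden_states = torch.randn(2, 8, 16)  # (batch, sequence, dim)

chunk_size, chunk_dim = 2, 1
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
chunked_out = torch.cat(
    [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim
)

print(torch.allclose(chunked_out, ff(hidden_states), atol=1e-6))  # True
```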
4DoF/diffusers/models/attention_flax.py DELETED
@@ -1,446 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import functools
16
- import math
17
-
18
- import flax.linen as nn
19
- import jax
20
- import jax.numpy as jnp
21
-
22
-
23
- def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
24
- """Multi-head dot product attention with a limited number of queries."""
25
- num_kv, num_heads, k_features = key.shape[-3:]
26
- v_features = value.shape[-1]
27
- key_chunk_size = min(key_chunk_size, num_kv)
28
- query = query / jnp.sqrt(k_features)
29
-
30
- @functools.partial(jax.checkpoint, prevent_cse=False)
31
- def summarize_chunk(query, key, value):
32
- attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
33
-
34
- max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
35
- max_score = jax.lax.stop_gradient(max_score)
36
- exp_weights = jnp.exp(attn_weights - max_score)
37
-
38
- exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
39
- max_score = jnp.einsum("...qhk->...qh", max_score)
40
-
41
- return (exp_values, exp_weights.sum(axis=-1), max_score)
42
-
43
- def chunk_scanner(chunk_idx):
44
- # julienne key array
45
- key_chunk = jax.lax.dynamic_slice(
46
- operand=key,
47
- start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
48
- slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
49
- )
50
-
51
- # julienne value array
52
- value_chunk = jax.lax.dynamic_slice(
53
- operand=value,
54
- start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
55
- slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
56
- )
57
-
58
- return summarize_chunk(query, key_chunk, value_chunk)
59
-
60
- chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
61
-
62
- global_max = jnp.max(chunk_max, axis=0, keepdims=True)
63
- max_diffs = jnp.exp(chunk_max - global_max)
64
-
65
- chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
66
- chunk_weights *= max_diffs
67
-
68
- all_values = chunk_values.sum(axis=0)
69
- all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
70
-
71
- return all_values / all_weights
72
-
73
-
74
- def jax_memory_efficient_attention(
75
- query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
76
- ):
77
- r"""
78
- Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
79
- https://github.com/AminRezaei0x443/memory-efficient-attention
80
-
81
- Args:
82
- query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
83
- key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
84
- value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
85
- precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
86
- numerical precision for computation
87
- query_chunk_size (`int`, *optional*, defaults to 1024):
88
- chunk size to divide query array value must divide query_length equally without remainder
89
- key_chunk_size (`int`, *optional*, defaults to 4096):
90
- chunk size to divide key and value array value must divide key_value_length equally without remainder
91
-
92
- Returns:
93
- (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
94
- """
95
- num_q, num_heads, q_features = query.shape[-3:]
96
-
97
- def chunk_scanner(chunk_idx, _):
98
- # julienne query array
99
- query_chunk = jax.lax.dynamic_slice(
100
- operand=query,
101
- start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
102
- slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
103
- )
104
-
105
- return (
106
- chunk_idx + query_chunk_size, # unused ignore it
107
- _query_chunk_attention(
108
- query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
109
- ),
110
- )
111
-
112
- _, res = jax.lax.scan(
113
- f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
114
- )
115
-
116
- return jnp.concatenate(res, axis=-3) # fuse the chunked result back
117
-
118
-
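
The per-chunk `max_score` bookkeeping in `_query_chunk_attention` above is a streaming softmax: each chunk's exponentials are later rescaled by `exp(chunk_max - global_max)`, so the combined result equals the full softmax-weighted average. A 1-D numpy sketch of that identity (toy numbers, numpy only):

```py
import numpy as np

scores = np.array([0.3, 2.0, -1.0, 0.7, 1.5, -0.2])
values = np.arange(6, dtype=np.float64)

# full computation: softmax-weighted average over all positions
full = np.exp(scores - scores.max())
expected = (full * values).sum() / full.sum()

# chunked computation with per-chunk maxima, recombined afterwards
chunk_vals, chunk_weights, chunk_max = [], [], []
for s, v in zip(np.split(scores, 3), np.split(values, 3)):
    m = s.max()
    e = np.exp(s - m)
    chunk_vals.append((e * v).sum())
    chunk_weights.append(e.sum())
    chunk_max.append(m)

chunk_vals, chunk_weights, chunk_max = map(np.array, (chunk_vals, chunk_weights, chunk_max))
scale = np.exp(chunk_max - chunk_max.max())
result = (chunk_vals * scale).sum() / (chunk_weights * scale).sum()
print(np.isclose(result, expected))  # True
```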
119
- class FlaxAttention(nn.Module):
120
- r"""
121
- A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
122
-
123
- Parameters:
124
- query_dim (:obj:`int`):
125
- Input hidden states dimension
126
- heads (:obj:`int`, *optional*, defaults to 8):
127
- Number of heads
128
- dim_head (:obj:`int`, *optional*, defaults to 64):
129
- Hidden states dimension inside each head
130
- dropout (:obj:`float`, *optional*, defaults to 0.0):
131
- Dropout rate
132
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
133
- enable memory efficient attention https://arxiv.org/abs/2112.05682
134
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
135
- Parameters `dtype`
136
-
137
- """
138
- query_dim: int
139
- heads: int = 8
140
- dim_head: int = 64
141
- dropout: float = 0.0
142
- use_memory_efficient_attention: bool = False
143
- dtype: jnp.dtype = jnp.float32
144
-
145
- def setup(self):
146
- inner_dim = self.dim_head * self.heads
147
- self.scale = self.dim_head**-0.5
148
-
149
- # Weights were exported with old names {to_q, to_k, to_v, to_out}
150
- self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
151
- self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
152
- self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
153
-
154
- self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
155
- self.dropout_layer = nn.Dropout(rate=self.dropout)
156
-
157
- def reshape_heads_to_batch_dim(self, tensor):
158
- batch_size, seq_len, dim = tensor.shape
159
- head_size = self.heads
160
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
161
- tensor = jnp.transpose(tensor, (0, 2, 1, 3))
162
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
163
- return tensor
164
-
165
- def reshape_batch_dim_to_heads(self, tensor):
166
- batch_size, seq_len, dim = tensor.shape
167
- head_size = self.heads
168
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
169
- tensor = jnp.transpose(tensor, (0, 2, 1, 3))
170
- tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
171
- return tensor
172
-
173
- def __call__(self, hidden_states, context=None, deterministic=True):
174
- context = hidden_states if context is None else context
175
-
176
- query_proj = self.query(hidden_states)
177
- key_proj = self.key(context)
178
- value_proj = self.value(context)
179
-
180
- query_states = self.reshape_heads_to_batch_dim(query_proj)
181
- key_states = self.reshape_heads_to_batch_dim(key_proj)
182
- value_states = self.reshape_heads_to_batch_dim(value_proj)
183
-
184
- if self.use_memory_efficient_attention:
185
- query_states = query_states.transpose(1, 0, 2)
186
- key_states = key_states.transpose(1, 0, 2)
187
- value_states = value_states.transpose(1, 0, 2)
188
-
189
- # this if statement create a chunk size for each layer of the unet
190
- # the chunk size is equal to the query_length dimension of the deepest layer of the unet
191
-
192
- flatten_latent_dim = query_states.shape[-3]
193
- if flatten_latent_dim % 64 == 0:
194
- query_chunk_size = int(flatten_latent_dim / 64)
195
- elif flatten_latent_dim % 16 == 0:
196
- query_chunk_size = int(flatten_latent_dim / 16)
197
- elif flatten_latent_dim % 4 == 0:
198
- query_chunk_size = int(flatten_latent_dim / 4)
199
- else:
200
- query_chunk_size = int(flatten_latent_dim)
201
-
202
- hidden_states = jax_memory_efficient_attention(
203
- query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
204
- )
205
-
206
- hidden_states = hidden_states.transpose(1, 0, 2)
207
- else:
208
- # compute attentions
209
- attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
210
- attention_scores = attention_scores * self.scale
211
- attention_probs = nn.softmax(attention_scores, axis=2)
212
-
213
- # attend to values
214
- hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
215
-
216
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
217
- hidden_states = self.proj_attn(hidden_states)
218
- return self.dropout_layer(hidden_states, deterministic=deterministic)
219
-
220
-
221
- class FlaxBasicTransformerBlock(nn.Module):
222
- r"""
223
- A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
224
- https://arxiv.org/abs/1706.03762
225
-
226
-
227
- Parameters:
228
- dim (:obj:`int`):
229
- Inner hidden states dimension
230
- n_heads (:obj:`int`):
231
- Number of heads
232
- d_head (:obj:`int`):
233
- Hidden states dimension inside each head
234
- dropout (:obj:`float`, *optional*, defaults to 0.0):
235
- Dropout rate
236
- only_cross_attention (`bool`, defaults to `False`):
237
- Whether to only apply cross attention.
238
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
239
- Parameters `dtype`
240
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
241
- enable memory efficient attention https://arxiv.org/abs/2112.05682
242
- """
243
- dim: int
244
- n_heads: int
245
- d_head: int
246
- dropout: float = 0.0
247
- only_cross_attention: bool = False
248
- dtype: jnp.dtype = jnp.float32
249
- use_memory_efficient_attention: bool = False
250
-
251
- def setup(self):
252
- # self attention (or cross_attention if only_cross_attention is True)
253
- self.attn1 = FlaxAttention(
254
- self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
255
- )
256
- # cross attention
257
- self.attn2 = FlaxAttention(
258
- self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
259
- )
260
- self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
261
- self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
262
- self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
263
- self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
264
- self.dropout_layer = nn.Dropout(rate=self.dropout)
265
-
266
- def __call__(self, hidden_states, context, deterministic=True):
267
- # self attention
268
- residual = hidden_states
269
- if self.only_cross_attention:
270
- hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
271
- else:
272
- hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
273
- hidden_states = hidden_states + residual
274
-
275
- # cross attention
276
- residual = hidden_states
277
- hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
278
- hidden_states = hidden_states + residual
279
-
280
- # feed forward
281
- residual = hidden_states
282
- hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
283
- hidden_states = hidden_states + residual
284
-
285
- return self.dropout_layer(hidden_states, deterministic=deterministic)
286
-
287
-
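A minimal usage sketch of `FlaxBasicTransformerBlock` as defined above, assuming `jax` and `flax` are installed; all shapes and hyperparameters are illustrative, not values used by the pipeline:

```python
import jax
import jax.numpy as jnp

# Toy sizes: inner dim 64 split over 4 heads of 16 channels each.
block = FlaxBasicTransformerBlock(dim=64, n_heads=4, d_head=16)

hidden_states = jnp.ones((2, 16, 64))  # (batch, image tokens, dim)
context = jnp.ones((2, 8, 64))         # (batch, context tokens, context dim)

# deterministic=True by default, so only the 'params' RNG is needed.
params = block.init(jax.random.PRNGKey(0), hidden_states, context)
out = block.apply(params, hidden_states, context)
assert out.shape == hidden_states.shape  # residual sub-blocks preserve the shape
```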
288
- class FlaxTransformer2DModel(nn.Module):
289
- r"""
290
- A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
291
- https://arxiv.org/pdf/1506.02025.pdf
292
-
293
-
294
- Parameters:
295
- in_channels (:obj:`int`):
296
- Input number of channels
297
- n_heads (:obj:`int`):
298
- Number of heads
299
- d_head (:obj:`int`):
300
- Hidden states dimension inside each head
301
- depth (:obj:`int`, *optional*, defaults to 1):
302
- Number of transformers block
303
- dropout (:obj:`float`, *optional*, defaults to 0.0):
304
- Dropout rate
305
- use_linear_projection (`bool`, defaults to `False`): Whether to use a linear (`nn.Dense`) projection instead of a 1x1 convolution for the input and output projections.
306
- only_cross_attention (`bool`, defaults to `False`): Whether the transformer blocks should only apply cross attention.
307
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
308
- Parameters `dtype`
309
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
310
- enable memory efficient attention https://arxiv.org/abs/2112.05682
311
- """
312
- in_channels: int
313
- n_heads: int
314
- d_head: int
315
- depth: int = 1
316
- dropout: float = 0.0
317
- use_linear_projection: bool = False
318
- only_cross_attention: bool = False
319
- dtype: jnp.dtype = jnp.float32
320
- use_memory_efficient_attention: bool = False
321
-
322
- def setup(self):
323
- self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
324
-
325
- inner_dim = self.n_heads * self.d_head
326
- if self.use_linear_projection:
327
- self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
328
- else:
329
- self.proj_in = nn.Conv(
330
- inner_dim,
331
- kernel_size=(1, 1),
332
- strides=(1, 1),
333
- padding="VALID",
334
- dtype=self.dtype,
335
- )
336
-
337
- self.transformer_blocks = [
338
- FlaxBasicTransformerBlock(
339
- inner_dim,
340
- self.n_heads,
341
- self.d_head,
342
- dropout=self.dropout,
343
- only_cross_attention=self.only_cross_attention,
344
- dtype=self.dtype,
345
- use_memory_efficient_attention=self.use_memory_efficient_attention,
346
- )
347
- for _ in range(self.depth)
348
- ]
349
-
350
- if self.use_linear_projection:
351
- self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
352
- else:
353
- self.proj_out = nn.Conv(
354
- inner_dim,
355
- kernel_size=(1, 1),
356
- strides=(1, 1),
357
- padding="VALID",
358
- dtype=self.dtype,
359
- )
360
-
361
- self.dropout_layer = nn.Dropout(rate=self.dropout)
362
-
363
- def __call__(self, hidden_states, context, deterministic=True):
364
- batch, height, width, channels = hidden_states.shape
365
- residual = hidden_states
366
- hidden_states = self.norm(hidden_states)
367
- if self.use_linear_projection:
368
- hidden_states = hidden_states.reshape(batch, height * width, channels)
369
- hidden_states = self.proj_in(hidden_states)
370
- else:
371
- hidden_states = self.proj_in(hidden_states)
372
- hidden_states = hidden_states.reshape(batch, height * width, channels)
373
-
374
- for transformer_block in self.transformer_blocks:
375
- hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
376
-
377
- if self.use_linear_projection:
378
- hidden_states = self.proj_out(hidden_states)
379
- hidden_states = hidden_states.reshape(batch, height, width, channels)
380
- else:
381
- hidden_states = hidden_states.reshape(batch, height, width, channels)
382
- hidden_states = self.proj_out(hidden_states)
383
-
384
- hidden_states = hidden_states + residual
385
- return self.dropout_layer(hidden_states, deterministic=deterministic)
386
-
387
-
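Likewise, a hedged sketch of how `FlaxTransformer2DModel` flattens an NHWC feature map into tokens and folds it back. Sizes are illustrative, chosen so that `in_channels == n_heads * d_head` (needed for the residual add) and the 32 GroupNorm groups divide the channel count:

```python
import jax
import jax.numpy as jnp

model = FlaxTransformer2DModel(in_channels=64, n_heads=4, d_head=16, depth=1)

feature_map = jnp.ones((2, 8, 8, 64))  # (batch, height, width, channels), NHWC
context = jnp.ones((2, 8, 64))         # (batch, context tokens, context dim)

params = model.init(jax.random.PRNGKey(0), feature_map, context)
out = model.apply(params, feature_map, context)
assert out.shape == feature_map.shape  # 8*8 tokens are reshaped back to 8x8
```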
388
- class FlaxFeedForward(nn.Module):
389
- r"""
390
- Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
391
- [`FeedForward`] class, with the following simplifications:
392
- - The activation function is currently hardcoded to a gated linear unit from:
393
- https://arxiv.org/abs/2002.05202
394
- - `dim_out` is equal to `dim`.
395
- - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
396
-
397
- Parameters:
398
- dim (:obj:`int`):
399
- Inner hidden states dimension
400
- dropout (:obj:`float`, *optional*, defaults to 0.0):
401
- Dropout rate
402
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
403
- Parameters `dtype`
404
- """
405
- dim: int
406
- dropout: float = 0.0
407
- dtype: jnp.dtype = jnp.float32
408
-
409
- def setup(self):
410
- # The second linear layer needs to be called
411
- # net_2 for now to match the index of the Sequential layer
412
- self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
413
- self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
414
-
415
- def __call__(self, hidden_states, deterministic=True):
416
- hidden_states = self.net_0(hidden_states, deterministic=deterministic)
417
- hidden_states = self.net_2(hidden_states)
418
- return hidden_states
419
-
420
-
421
- class FlaxGEGLU(nn.Module):
422
- r"""
423
- Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
424
- https://arxiv.org/abs/2002.05202.
425
-
426
- Parameters:
427
- dim (:obj:`int`):
428
- Input hidden states dimension
429
- dropout (:obj:`float`, *optional*, defaults to 0.0):
430
- Dropout rate
431
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
432
- Parameters `dtype`
433
- """
434
- dim: int
435
- dropout: float = 0.0
436
- dtype: jnp.dtype = jnp.float32
437
-
438
- def setup(self):
439
- inner_dim = self.dim * 4
440
- self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
441
- self.dropout_layer = nn.Dropout(rate=self.dropout)
442
-
443
- def __call__(self, hidden_states, deterministic=True):
444
- hidden_states = self.proj(hidden_states)
445
- hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
446
- return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
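The gating in `FlaxGEGLU` boils down to a split followed by a GELU-gated product; a small sketch of just that computation outside the module, with toy shapes:

```python
import jax.numpy as jnp
from flax import linen as nn

# GEGLU on a (batch, tokens, 2 * inner) projection: one half passes through,
# the other half gates it via GELU, giving a (batch, tokens, inner) output.
projected = jnp.ones((2, 16, 2 * 256))                      # after self.proj
hidden_linear, hidden_gelu = jnp.split(projected, 2, axis=2)
gated = hidden_linear * nn.gelu(hidden_gelu)
assert gated.shape == (2, 16, 256)
```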
4DoF/diffusers/models/attention_processor.py DELETED
@@ -1,1714 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Callable, Optional, Union
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from ..utils import deprecate, logging, maybe_allow_in_graph
21
- from ..utils.import_utils import is_xformers_available
22
-
23
-
24
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
-
26
-
27
- if is_xformers_available():
28
- import xformers
29
- import xformers.ops
30
- else:
31
- xformers = None
32
-
33
-
34
- # 4DoF CaPE
35
- import einops
36
- def rotate_every_two(x):
37
- x = einops.rearrange(x, '... (d j) -> ... d j', j=2)
38
- x1, x2 = x.unbind(dim=-1)
39
- x = torch.stack((-x2, x1), dim=-1)
40
- return einops.rearrange(x, '... d j -> ... (d j)')
41
-
42
- def cape(x, p):
43
- d, l, n = x.shape[-1], p.shape[-2], p.shape[-1]
44
- assert d % (2 * n) == 0
45
- m = einops.repeat(p, 'b l n -> b l (n k)', k=d // n)
46
- return m
47
-
48
- def cape_embed(p1, p2, qq, kk):
49
- """
50
- Embed camera position encoding into attention map
51
- Args:
52
- p1: query pose b, l_q, pose_dim
53
- p2: key pose b, l_k, pose_dim
54
- qq: query feature map b, l_q, feature_dim
55
- kk: key feature map b, l_k, feature_dim
56
-
57
- Returns: CaPE-rotated query and key, b, l_q, feature_dim and b, l_k, feature_dim; their dot product gives the pose-embedded attention map b, l_q, l_k
58
-
59
- """
60
- assert p1.shape[-1] == p2.shape[-1]
61
- assert qq.shape[-1] == kk.shape[-1]
62
- assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
63
- assert p1.shape[1] == qq.shape[1]
64
- assert p2.shape[1] == kk.shape[1]
65
-
66
- m1 = cape(qq, p1)
67
- m2 = cape(kk, p2)
68
-
69
- q = (qq * m1.cos()) + (rotate_every_two(qq) * m1.sin())
70
- k = (kk * m2.cos()) + (rotate_every_two(kk) * m2.sin())
71
-
72
- return q, k
73
-
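To make the tensor shapes of the 4DoF CaPE helpers above concrete, a hedged sketch with dummy tensors; the sizes are illustrative and only need to satisfy the `d % (2 * n) == 0` assertion inside `cape`:

```python
import torch

b, l_q, l_k, feat_dim, pose_dim = 2, 16, 8, 64, 4   # 64 % (2 * 4) == 0

p_q = torch.randn(b, l_q, pose_dim)   # query poses
p_k = torch.randn(b, l_k, pose_dim)   # key poses
q = torch.randn(b, l_q, feat_dim)     # query features
k = torch.randn(b, l_k, feat_dim)     # key features

q_rot, k_rot = cape_embed(p_q, p_k, q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

# Each feature pair is rotated by a pose-dependent angle, so norms are
# preserved and only relative pose enters the attention logits q_rot @ k_rot^T.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
```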
74
- @maybe_allow_in_graph
75
- class Attention(nn.Module):
76
- r"""
77
- A cross attention layer.
78
-
79
- Parameters:
80
- query_dim (`int`): The number of channels in the query.
81
- cross_attention_dim (`int`, *optional*):
82
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
83
- heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
84
- dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
85
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
86
- bias (`bool`, *optional*, defaults to False):
87
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
88
- """
89
-
90
- def __init__(
91
- self,
92
- query_dim: int,
93
- cross_attention_dim: Optional[int] = None,
94
- heads: int = 8,
95
- dim_head: int = 64,
96
- dropout: float = 0.0,
97
- bias=False,
98
- upcast_attention: bool = False,
99
- upcast_softmax: bool = False,
100
- cross_attention_norm: Optional[str] = None,
101
- cross_attention_norm_num_groups: int = 32,
102
- added_kv_proj_dim: Optional[int] = None,
103
- norm_num_groups: Optional[int] = None,
104
- spatial_norm_dim: Optional[int] = None,
105
- out_bias: bool = True,
106
- scale_qk: bool = True,
107
- only_cross_attention: bool = False,
108
- eps: float = 1e-5,
109
- rescale_output_factor: float = 1.0,
110
- residual_connection: bool = False,
111
- _from_deprecated_attn_block=False,
112
- processor: Optional["AttnProcessor"] = None,
113
- ):
114
- super().__init__()
115
- inner_dim = dim_head * heads
116
- cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
117
- self.upcast_attention = upcast_attention
118
- self.upcast_softmax = upcast_softmax
119
- self.rescale_output_factor = rescale_output_factor
120
- self.residual_connection = residual_connection
121
- self.dropout = dropout
122
-
123
- # we make use of this private variable to know whether this class is loaded
124
- # with a deprecated state dict so that we can convert it on the fly
125
- self._from_deprecated_attn_block = _from_deprecated_attn_block
126
-
127
- self.scale_qk = scale_qk
128
- self.scale = dim_head**-0.5 if self.scale_qk else 1.0
129
-
130
- self.heads = heads
131
- # for slice_size > 0 the attention score computation
132
- # is split across the batch axis to save memory
133
- # You can set slice_size with `set_attention_slice`
134
- self.sliceable_head_dim = heads
135
-
136
- self.added_kv_proj_dim = added_kv_proj_dim
137
- self.only_cross_attention = only_cross_attention
138
-
139
- if self.added_kv_proj_dim is None and self.only_cross_attention:
140
- raise ValueError(
141
- "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
142
- )
143
-
144
- if norm_num_groups is not None:
145
- self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
146
- else:
147
- self.group_norm = None
148
-
149
- if spatial_norm_dim is not None:
150
- self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
151
- else:
152
- self.spatial_norm = None
153
-
154
- if cross_attention_norm is None:
155
- self.norm_cross = None
156
- elif cross_attention_norm == "layer_norm":
157
- self.norm_cross = nn.LayerNorm(cross_attention_dim)
158
- elif cross_attention_norm == "group_norm":
159
- if self.added_kv_proj_dim is not None:
160
- # The given `encoder_hidden_states` are initially of shape
161
- # (batch_size, seq_len, added_kv_proj_dim) before being projected
162
- # to (batch_size, seq_len, cross_attention_dim). The norm is applied
163
- # before the projection, so we need to use `added_kv_proj_dim` as
164
- # the number of channels for the group norm.
165
- norm_cross_num_channels = added_kv_proj_dim
166
- else:
167
- norm_cross_num_channels = cross_attention_dim
168
-
169
- self.norm_cross = nn.GroupNorm(
170
- num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
171
- )
172
- else:
173
- raise ValueError(
174
- f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
175
- )
176
-
177
- self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
178
-
179
- if not self.only_cross_attention:
180
- # only relevant for the `AddedKVProcessor` classes
181
- self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
182
- self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
183
- else:
184
- self.to_k = None
185
- self.to_v = None
186
-
187
- if self.added_kv_proj_dim is not None:
188
- self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
189
- self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
190
-
191
- self.to_out = nn.ModuleList([])
192
- self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
193
- self.to_out.append(nn.Dropout(dropout))
194
-
195
- # set attention processor
196
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
197
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
198
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
199
- if processor is None:
200
- processor = (
201
- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
202
- )
203
- self.set_processor(processor)
204
-
205
- def set_use_memory_efficient_attention_xformers(
206
- self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
207
- ):
208
- is_lora = hasattr(self, "processor") and isinstance(
209
- self.processor,
210
- (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
211
- )
212
- is_custom_diffusion = hasattr(self, "processor") and isinstance(
213
- self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
214
- )
215
- is_added_kv_processor = hasattr(self, "processor") and isinstance(
216
- self.processor,
217
- (
218
- AttnAddedKVProcessor,
219
- AttnAddedKVProcessor2_0,
220
- SlicedAttnAddedKVProcessor,
221
- XFormersAttnAddedKVProcessor,
222
- LoRAAttnAddedKVProcessor,
223
- ),
224
- )
225
-
226
- if use_memory_efficient_attention_xformers:
227
- if is_added_kv_processor and (is_lora or is_custom_diffusion):
228
- raise NotImplementedError(
229
- f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
230
- )
231
- if not is_xformers_available():
232
- raise ModuleNotFoundError(
233
- (
234
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
235
- " xformers"
236
- ),
237
- name="xformers",
238
- )
239
- elif not torch.cuda.is_available():
240
- raise ValueError(
241
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
242
- " only available for GPU "
243
- )
244
- else:
245
- try:
246
- # Make sure we can run the memory efficient attention
247
- _ = xformers.ops.memory_efficient_attention(
248
- torch.randn((1, 2, 40), device="cuda"),
249
- torch.randn((1, 2, 40), device="cuda"),
250
- torch.randn((1, 2, 40), device="cuda"),
251
- )
252
- except Exception as e:
253
- raise e
254
-
255
- if is_lora:
256
- # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
257
- # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
258
- processor = LoRAXFormersAttnProcessor(
259
- hidden_size=self.processor.hidden_size,
260
- cross_attention_dim=self.processor.cross_attention_dim,
261
- rank=self.processor.rank,
262
- attention_op=attention_op,
263
- )
264
- processor.load_state_dict(self.processor.state_dict())
265
- processor.to(self.processor.to_q_lora.up.weight.device)
266
- elif is_custom_diffusion:
267
- processor = CustomDiffusionXFormersAttnProcessor(
268
- train_kv=self.processor.train_kv,
269
- train_q_out=self.processor.train_q_out,
270
- hidden_size=self.processor.hidden_size,
271
- cross_attention_dim=self.processor.cross_attention_dim,
272
- attention_op=attention_op,
273
- )
274
- processor.load_state_dict(self.processor.state_dict())
275
- if hasattr(self.processor, "to_k_custom_diffusion"):
276
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
277
- elif is_added_kv_processor:
278
- # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
279
- # which uses this type of cross attention ONLY because the attention mask of format
280
- # [0, ..., -10.000, ..., 0, ...,] is not supported
281
- # throw warning
282
- logger.info(
283
- "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
284
- )
285
- processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
286
- else:
287
- processor = XFormersAttnProcessor(attention_op=attention_op)
288
- else:
289
- if is_lora:
290
- attn_processor_class = (
291
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
292
- )
293
- processor = attn_processor_class(
294
- hidden_size=self.processor.hidden_size,
295
- cross_attention_dim=self.processor.cross_attention_dim,
296
- rank=self.processor.rank,
297
- )
298
- processor.load_state_dict(self.processor.state_dict())
299
- processor.to(self.processor.to_q_lora.up.weight.device)
300
- elif is_custom_diffusion:
301
- processor = CustomDiffusionAttnProcessor(
302
- train_kv=self.processor.train_kv,
303
- train_q_out=self.processor.train_q_out,
304
- hidden_size=self.processor.hidden_size,
305
- cross_attention_dim=self.processor.cross_attention_dim,
306
- )
307
- processor.load_state_dict(self.processor.state_dict())
308
- if hasattr(self.processor, "to_k_custom_diffusion"):
309
- processor.to(self.processor.to_k_custom_diffusion.weight.device)
310
- else:
311
- # set attention processor
312
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
313
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
314
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
315
- processor = (
316
- AttnProcessor2_0()
317
- if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
318
- else AttnProcessor()
319
- )
320
-
321
- self.set_processor(processor)
322
-
323
- def set_attention_slice(self, slice_size):
324
- if slice_size is not None and slice_size > self.sliceable_head_dim:
325
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
326
-
327
- if slice_size is not None and self.added_kv_proj_dim is not None:
328
- processor = SlicedAttnAddedKVProcessor(slice_size)
329
- elif slice_size is not None:
330
- processor = SlicedAttnProcessor(slice_size)
331
- elif self.added_kv_proj_dim is not None:
332
- processor = AttnAddedKVProcessor()
333
- else:
334
- # set attention processor
335
- # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
336
- # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
337
- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
338
- processor = (
339
- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
340
- )
341
-
342
- self.set_processor(processor)
343
-
344
- def set_processor(self, processor: "AttnProcessor"):
345
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
346
- # pop `processor` from `self._modules`
347
- if (
348
- hasattr(self, "processor")
349
- and isinstance(self.processor, torch.nn.Module)
350
- and not isinstance(processor, torch.nn.Module)
351
- ):
352
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
353
- self._modules.pop("processor")
354
-
355
- self.processor = processor
356
-
357
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
358
- # The `Attention` class can call different attention processors / attention functions
359
- # here we simply pass along all tensors to the selected processor class
360
- # For standard processors that are defined here, `**cross_attention_kwargs` is empty
361
- return self.processor(
362
- self,
363
- hidden_states,
364
- encoder_hidden_states=encoder_hidden_states,
365
- attention_mask=attention_mask,
366
- **cross_attention_kwargs,
367
- )
368
-
369
- def batch_to_head_dim(self, tensor):
370
- head_size = self.heads
371
- batch_size, seq_len, dim = tensor.shape
372
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
373
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
374
- return tensor
375
-
376
- def head_to_batch_dim(self, tensor, out_dim=3):
377
- head_size = self.heads
378
- batch_size, seq_len, dim = tensor.shape
379
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
380
- tensor = tensor.permute(0, 2, 1, 3)
381
-
382
- if out_dim == 3:
383
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
384
-
385
- return tensor
386
-
387
- def get_attention_scores(self, query, key, attention_mask=None):
388
- dtype = query.dtype
389
- if self.upcast_attention:
390
- query = query.float()
391
- key = key.float()
392
-
393
- if attention_mask is None:
394
- baddbmm_input = torch.empty(
395
- query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
396
- )
397
- beta = 0
398
- else:
399
- baddbmm_input = attention_mask
400
- beta = 1
401
-
402
- attention_scores = torch.baddbmm(
403
- baddbmm_input,
404
- query,
405
- key.transpose(-1, -2),
406
- beta=beta,
407
- alpha=self.scale,
408
- )
409
- del baddbmm_input
410
-
411
- if self.upcast_softmax:
412
- attention_scores = attention_scores.float()
413
-
414
- attention_probs = attention_scores.softmax(dim=-1)
415
- del attention_scores
416
-
417
- attention_probs = attention_probs.to(dtype)
418
-
419
- return attention_probs
420
-
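For reference, `get_attention_scores` above is a fused way of computing `softmax(Q @ K^T * scale + mask)` via `torch.baddbmm`; a small numerical sanity check with illustrative shapes:

```python
import torch

q = torch.randn(2, 5, 8)     # (batch*heads, query tokens, head_dim)
k = torch.randn(2, 7, 8)     # (batch*heads, key tokens, head_dim)
mask = torch.zeros(2, 5, 7)  # additive attention bias
scale = 8 ** -0.5

scores = torch.baddbmm(mask, q, k.transpose(-1, -2), beta=1, alpha=scale)
probs = scores.softmax(dim=-1)

reference = ((q @ k.transpose(-1, -2)) * scale + mask).softmax(dim=-1)
assert torch.allclose(probs, reference, atol=1e-6)
```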
421
- def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
422
- if batch_size is None:
423
- deprecate(
424
- "batch_size=None",
425
- "0.0.15",
426
- (
427
- "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
428
- " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
429
- " `prepare_attention_mask` when preparing the attention_mask."
430
- ),
431
- )
432
- batch_size = 1
433
-
434
- head_size = self.heads
435
- if attention_mask is None:
436
- return attention_mask
437
-
438
- current_length: int = attention_mask.shape[-1]
439
- if current_length != target_length:
440
- if attention_mask.device.type == "mps":
441
- # HACK: MPS: Does not support padding by greater than dimension of input tensor.
442
- # Instead, we can manually construct the padding tensor.
443
- padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
444
- padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
445
- attention_mask = torch.cat([attention_mask, padding], dim=2)
446
- else:
447
- # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
448
- # we want to instead pad by (0, remaining_length), where remaining_length is:
449
- # remaining_length: int = target_length - current_length
450
- # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
451
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
452
-
453
- if out_dim == 3:
454
- if attention_mask.shape[0] < batch_size * head_size:
455
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
456
- elif out_dim == 4:
457
- attention_mask = attention_mask.unsqueeze(1)
458
- attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
459
-
460
- return attention_mask
461
-
462
- def norm_encoder_hidden_states(self, encoder_hidden_states):
463
- assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
464
-
465
- if isinstance(self.norm_cross, nn.LayerNorm):
466
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
467
- elif isinstance(self.norm_cross, nn.GroupNorm):
468
- # Group norm norms along the channels dimension and expects
469
- # input to be in the shape of (N, C, *). In this case, we want
470
- # to norm along the hidden dimension, so we need to move
471
- # (batch_size, sequence_length, hidden_size) ->
472
- # (batch_size, hidden_size, sequence_length)
473
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
474
- encoder_hidden_states = self.norm_cross(encoder_hidden_states)
475
- encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
476
- else:
477
- assert False
478
-
479
- return encoder_hidden_states
480
-
481
-
482
- class AttnProcessor:
483
- r"""
484
- Default processor for performing attention-related computations.
485
- """
486
-
487
- def __call__(
488
- self,
489
- attn: Attention,
490
- hidden_states,
491
- encoder_hidden_states=None,
492
- attention_mask=None,
493
- temb=None,
494
- ):
495
- residual = hidden_states
496
-
497
- if attn.spatial_norm is not None:
498
- hidden_states = attn.spatial_norm(hidden_states, temb)
499
-
500
- input_ndim = hidden_states.ndim
501
-
502
- if input_ndim == 4:
503
- batch_size, channel, height, width = hidden_states.shape
504
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
505
-
506
- batch_size, sequence_length, _ = (
507
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
508
- )
509
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
510
-
511
- if attn.group_norm is not None:
512
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
513
-
514
- query = attn.to_q(hidden_states)
515
-
516
- if encoder_hidden_states is None:
517
- encoder_hidden_states = hidden_states
518
- elif attn.norm_cross:
519
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
520
-
521
- key = attn.to_k(encoder_hidden_states)
522
- value = attn.to_v(encoder_hidden_states)
523
-
524
- query = attn.head_to_batch_dim(query)
525
- key = attn.head_to_batch_dim(key)
526
- value = attn.head_to_batch_dim(value)
527
-
528
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
529
- hidden_states = torch.bmm(attention_probs, value)
530
- hidden_states = attn.batch_to_head_dim(hidden_states)
531
-
532
- # linear proj
533
- hidden_states = attn.to_out[0](hidden_states)
534
- # dropout
535
- hidden_states = attn.to_out[1](hidden_states)
536
-
537
- if input_ndim == 4:
538
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
539
-
540
- if attn.residual_connection:
541
- hidden_states = hidden_states + residual
542
-
543
- hidden_states = hidden_states / attn.rescale_output_factor
544
-
545
- return hidden_states
546
-
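A hedged end-to-end sketch of the default path: build an `Attention` module, pin the plain `AttnProcessor`, and run one cross-attention call. All sizes are illustrative:

```python
import torch

attn = Attention(query_dim=64, cross_attention_dim=32, heads=4, dim_head=16)
attn.set_processor(AttnProcessor())  # bypass the PyTorch-2.0 default for clarity

hidden_states = torch.randn(2, 16, 64)         # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 8, 32)  # (batch, key tokens, cross_attention_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == hidden_states.shape  # output is projected back to query_dim
```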
547
-
548
- class LoRALinearLayer(nn.Module):
549
- def __init__(self, in_features, out_features, rank=4, network_alpha=None):
550
- super().__init__()
551
-
552
- if rank > min(in_features, out_features):
553
- raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
554
-
555
- self.down = nn.Linear(in_features, rank, bias=False)
556
- self.up = nn.Linear(rank, out_features, bias=False)
557
- # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
558
- # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
559
- self.network_alpha = network_alpha
560
- self.rank = rank
561
-
562
- nn.init.normal_(self.down.weight, std=1 / rank)
563
- nn.init.zeros_(self.up.weight)
564
-
565
- def forward(self, hidden_states):
566
- orig_dtype = hidden_states.dtype
567
- dtype = self.down.weight.dtype
568
-
569
- down_hidden_states = self.down(hidden_states.to(dtype))
570
- up_hidden_states = self.up(down_hidden_states)
571
-
572
- if self.network_alpha is not None:
573
- up_hidden_states *= self.network_alpha / self.rank
574
-
575
- return up_hidden_states.to(orig_dtype)
576
-
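The LoRA layers add a scaled low-rank correction on top of each frozen projection, exactly the `attn.to_q(x) + scale * self.to_q_lora(x)` pattern used in the processors below; a minimal sketch with toy tensors:

```python
import torch
from torch import nn

base = nn.Linear(64, 64)                # frozen base projection, e.g. attn.to_q
lora = LoRALinearLayer(64, 64, rank=4)  # trainable low-rank update
x = torch.randn(2, 16, 64)
scale = 1.0

out = base(x) + scale * lora(x)         # LoRA-adapted projection
assert out.shape == (2, 16, 64)

# lora.up is zero-initialized, so at initialization the adapted output
# equals the base output and training starts from the frozen behavior.
assert torch.allclose(out, base(x))
```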
577
-
578
- class LoRAAttnProcessor(nn.Module):
579
- r"""
580
- Processor for implementing the LoRA attention mechanism.
581
-
582
- Args:
583
- hidden_size (`int`, *optional*):
584
- The hidden size of the attention layer.
585
- cross_attention_dim (`int`, *optional*):
586
- The number of channels in the `encoder_hidden_states`.
587
- rank (`int`, defaults to 4):
588
- The dimension of the LoRA update matrices.
589
- network_alpha (`int`, *optional*):
590
- Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
591
- """
592
-
593
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
594
- super().__init__()
595
-
596
- self.hidden_size = hidden_size
597
- self.cross_attention_dim = cross_attention_dim
598
- self.rank = rank
599
-
600
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
601
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
602
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
603
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
604
-
605
- def __call__(
606
- self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
607
- ):
608
- residual = hidden_states
609
-
610
- if attn.spatial_norm is not None:
611
- hidden_states = attn.spatial_norm(hidden_states, temb)
612
-
613
- input_ndim = hidden_states.ndim
614
-
615
- if input_ndim == 4:
616
- batch_size, channel, height, width = hidden_states.shape
617
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
618
-
619
- batch_size, sequence_length, _ = (
620
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
621
- )
622
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
623
-
624
- if attn.group_norm is not None:
625
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
626
-
627
- query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
628
- query = attn.head_to_batch_dim(query)
629
-
630
- if encoder_hidden_states is None:
631
- encoder_hidden_states = hidden_states
632
- elif attn.norm_cross:
633
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
634
-
635
- key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
636
- value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
637
-
638
- key = attn.head_to_batch_dim(key)
639
- value = attn.head_to_batch_dim(value)
640
-
641
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
642
- hidden_states = torch.bmm(attention_probs, value)
643
- hidden_states = attn.batch_to_head_dim(hidden_states)
644
-
645
- # linear proj
646
- hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
647
- # dropout
648
- hidden_states = attn.to_out[1](hidden_states)
649
-
650
- if input_ndim == 4:
651
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
652
-
653
- if attn.residual_connection:
654
- hidden_states = hidden_states + residual
655
-
656
- hidden_states = hidden_states / attn.rescale_output_factor
657
-
658
- return hidden_states
659
-
660
-
661
- class CustomDiffusionAttnProcessor(nn.Module):
662
- r"""
663
- Processor for implementing attention for the Custom Diffusion method.
664
-
665
- Args:
666
- train_kv (`bool`, defaults to `True`):
667
- Whether to newly train the key and value matrices corresponding to the text features.
668
- train_q_out (`bool`, defaults to `True`):
669
- Whether to newly train query matrices corresponding to the latent image features.
670
- hidden_size (`int`, *optional*, defaults to `None`):
671
- The hidden size of the attention layer.
672
- cross_attention_dim (`int`, *optional*, defaults to `None`):
673
- The number of channels in the `encoder_hidden_states`.
674
- out_bias (`bool`, defaults to `True`):
675
- Whether to include the bias parameter in `train_q_out`.
676
- dropout (`float`, *optional*, defaults to 0.0):
677
- The dropout probability to use.
678
- """
679
-
680
- def __init__(
681
- self,
682
- train_kv=True,
683
- train_q_out=True,
684
- hidden_size=None,
685
- cross_attention_dim=None,
686
- out_bias=True,
687
- dropout=0.0,
688
- ):
689
- super().__init__()
690
- self.train_kv = train_kv
691
- self.train_q_out = train_q_out
692
-
693
- self.hidden_size = hidden_size
694
- self.cross_attention_dim = cross_attention_dim
695
-
696
- # `_custom_diffusion` id for easy serialization and loading.
697
- if self.train_kv:
698
- self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
699
- self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
700
- if self.train_q_out:
701
- self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
702
- self.to_out_custom_diffusion = nn.ModuleList([])
703
- self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
704
- self.to_out_custom_diffusion.append(nn.Dropout(dropout))
705
-
706
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
707
- batch_size, sequence_length, _ = hidden_states.shape
708
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
709
- if self.train_q_out:
710
- query = self.to_q_custom_diffusion(hidden_states)
711
- else:
712
- query = attn.to_q(hidden_states)
713
-
714
- if encoder_hidden_states is None:
715
- crossattn = False
716
- encoder_hidden_states = hidden_states
717
- else:
718
- crossattn = True
719
- if attn.norm_cross:
720
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
721
-
722
- if self.train_kv:
723
- key = self.to_k_custom_diffusion(encoder_hidden_states)
724
- value = self.to_v_custom_diffusion(encoder_hidden_states)
725
- else:
726
- key = attn.to_k(encoder_hidden_states)
727
- value = attn.to_v(encoder_hidden_states)
728
-
729
- if crossattn:
730
- detach = torch.ones_like(key)
731
- detach[:, :1, :] = detach[:, :1, :] * 0.0
732
- key = detach * key + (1 - detach) * key.detach()
733
- value = detach * value + (1 - detach) * value.detach()
734
-
735
- query = attn.head_to_batch_dim(query)
736
- key = attn.head_to_batch_dim(key)
737
- value = attn.head_to_batch_dim(value)
738
-
739
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
740
- hidden_states = torch.bmm(attention_probs, value)
741
- hidden_states = attn.batch_to_head_dim(hidden_states)
742
-
743
- if self.train_q_out:
744
- # linear proj
745
- hidden_states = self.to_out_custom_diffusion[0](hidden_states)
746
- # dropout
747
- hidden_states = self.to_out_custom_diffusion[1](hidden_states)
748
- else:
749
- # linear proj
750
- hidden_states = attn.to_out[0](hidden_states)
751
- # dropout
752
- hidden_states = attn.to_out[1](hidden_states)
753
-
754
- return hidden_states
755
-
756
-
757
- class AttnAddedKVProcessor:
758
- r"""
759
- Processor for performing attention-related computations with extra learnable key and value matrices for the text
760
- encoder.
761
- """
762
-
763
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
764
- residual = hidden_states
765
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
766
- batch_size, sequence_length, _ = hidden_states.shape
767
-
768
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
769
-
770
- if encoder_hidden_states is None:
771
- encoder_hidden_states = hidden_states
772
- elif attn.norm_cross:
773
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
774
-
775
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
776
-
777
- query = attn.to_q(hidden_states)
778
- query = attn.head_to_batch_dim(query)
779
-
780
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
781
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
782
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
783
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
784
-
785
- if not attn.only_cross_attention:
786
- key = attn.to_k(hidden_states)
787
- value = attn.to_v(hidden_states)
788
- key = attn.head_to_batch_dim(key)
789
- value = attn.head_to_batch_dim(value)
790
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
791
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
792
- else:
793
- key = encoder_hidden_states_key_proj
794
- value = encoder_hidden_states_value_proj
795
-
796
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
797
- hidden_states = torch.bmm(attention_probs, value)
798
- hidden_states = attn.batch_to_head_dim(hidden_states)
799
-
800
- # linear proj
801
- hidden_states = attn.to_out[0](hidden_states)
802
- # dropout
803
- hidden_states = attn.to_out[1](hidden_states)
804
-
805
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
806
- hidden_states = hidden_states + residual
807
-
808
- return hidden_states
809
-
810
-
811
- class AttnAddedKVProcessor2_0:
812
- r"""
813
- Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
814
- learnable key and value matrices for the text encoder.
815
- """
816
-
817
- def __init__(self):
818
- if not hasattr(F, "scaled_dot_product_attention"):
819
- raise ImportError(
820
- "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
821
- )
822
-
823
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
824
- residual = hidden_states
825
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
826
- batch_size, sequence_length, _ = hidden_states.shape
827
-
828
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
829
-
830
- if encoder_hidden_states is None:
831
- encoder_hidden_states = hidden_states
832
- elif attn.norm_cross:
833
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
834
-
835
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
836
-
837
- query = attn.to_q(hidden_states)
838
- query = attn.head_to_batch_dim(query, out_dim=4)
839
-
840
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
841
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
842
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
843
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
844
-
845
- if not attn.only_cross_attention:
846
- key = attn.to_k(hidden_states)
847
- value = attn.to_v(hidden_states)
848
- key = attn.head_to_batch_dim(key, out_dim=4)
849
- value = attn.head_to_batch_dim(value, out_dim=4)
850
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
851
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
852
- else:
853
- key = encoder_hidden_states_key_proj
854
- value = encoder_hidden_states_value_proj
855
-
856
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
857
- # TODO: add support for attn.scale when we move to Torch 2.1
858
- hidden_states = F.scaled_dot_product_attention(
859
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
860
- )
861
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
862
-
863
- # linear proj
864
- hidden_states = attn.to_out[0](hidden_states)
865
- # dropout
866
- hidden_states = attn.to_out[1](hidden_states)
867
-
868
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
869
- hidden_states = hidden_states + residual
870
-
871
- return hidden_states
872
-
873
-
874
- class LoRAAttnAddedKVProcessor(nn.Module):
875
- r"""
876
- Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
877
- encoder.
878
-
879
- Args:
880
- hidden_size (`int`, *optional*):
881
- The hidden size of the attention layer.
882
- cross_attention_dim (`int`, *optional*, defaults to `None`):
883
- The number of channels in the `encoder_hidden_states`.
884
- rank (`int`, defaults to 4):
885
- The dimension of the LoRA update matrices.
886
-
887
- """
888
-
889
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
890
- super().__init__()
891
-
892
- self.hidden_size = hidden_size
893
- self.cross_attention_dim = cross_attention_dim
894
- self.rank = rank
895
-
896
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
897
- self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
898
- self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
899
- self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
900
- self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
901
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
902
-
903
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
904
- residual = hidden_states
905
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
906
- batch_size, sequence_length, _ = hidden_states.shape
907
-
908
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
909
-
910
- if encoder_hidden_states is None:
911
- encoder_hidden_states = hidden_states
912
- elif attn.norm_cross:
913
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
914
-
915
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
916
-
917
- query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
918
- query = attn.head_to_batch_dim(query)
919
-
920
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
921
- encoder_hidden_states
922
- )
923
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
924
- encoder_hidden_states
925
- )
926
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
927
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
928
-
929
- if not attn.only_cross_attention:
930
- key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
931
- value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
932
- key = attn.head_to_batch_dim(key)
933
- value = attn.head_to_batch_dim(value)
934
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
935
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
936
- else:
937
- key = encoder_hidden_states_key_proj
938
- value = encoder_hidden_states_value_proj
939
-
940
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
941
- hidden_states = torch.bmm(attention_probs, value)
942
- hidden_states = attn.batch_to_head_dim(hidden_states)
943
-
944
- # linear proj
945
- hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
946
- # dropout
947
- hidden_states = attn.to_out[1](hidden_states)
948
-
949
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
950
- hidden_states = hidden_states + residual
951
-
952
- return hidden_states
953
-
954
-
955
- class XFormersAttnAddedKVProcessor:
956
- r"""
957
- Processor for implementing memory efficient attention using xFormers.
958
-
959
- Args:
960
- attention_op (`Callable`, *optional*, defaults to `None`):
961
- The base
962
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
963
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
964
- operator.
965
- """
966
-
967
- def __init__(self, attention_op: Optional[Callable] = None):
968
- self.attention_op = attention_op
969
-
970
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
971
- residual = hidden_states
972
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
973
- batch_size, sequence_length, _ = hidden_states.shape
974
-
975
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
976
-
977
- if encoder_hidden_states is None:
978
- encoder_hidden_states = hidden_states
979
- elif attn.norm_cross:
980
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
981
-
982
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
983
-
984
- query = attn.to_q(hidden_states)
985
- query = attn.head_to_batch_dim(query)
986
-
987
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
988
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
989
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
990
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
991
-
992
- if not attn.only_cross_attention:
993
- key = attn.to_k(hidden_states)
994
- value = attn.to_v(hidden_states)
995
- key = attn.head_to_batch_dim(key)
996
- value = attn.head_to_batch_dim(value)
997
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
998
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
999
- else:
1000
- key = encoder_hidden_states_key_proj
1001
- value = encoder_hidden_states_value_proj
1002
-
1003
- hidden_states = xformers.ops.memory_efficient_attention(
1004
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1005
- )
1006
- hidden_states = hidden_states.to(query.dtype)
1007
- hidden_states = attn.batch_to_head_dim(hidden_states)
1008
-
1009
- # linear proj
1010
- hidden_states = attn.to_out[0](hidden_states)
1011
- # dropout
1012
- hidden_states = attn.to_out[1](hidden_states)
1013
-
1014
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1015
- hidden_states = hidden_states + residual
1016
-
1017
- return hidden_states
1018
-
1019
-
1020
- class XFormersAttnProcessor:
1021
- r"""
1022
- Processor for implementing memory efficient attention using xFormers.
1023
-
1024
- Args:
1025
- attention_op (`Callable`, *optional*, defaults to `None`):
1026
- The base
1027
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1028
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1029
- operator.
1030
- """
1031
-
1032
- def __init__(self, attention_op: Optional[Callable] = None):
1033
- self.attention_op = attention_op
1034
-
1035
- def __call__(
1036
- self,
1037
- attn: Attention,
1038
- hidden_states: torch.FloatTensor,
1039
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1040
- attention_mask: Optional[torch.FloatTensor] = None,
1041
- temb: Optional[torch.FloatTensor] = None,
1042
- posemb: Optional = None,
1043
- ):
1044
- residual = hidden_states
1045
-
1046
- if attn.spatial_norm is not None:
1047
- hidden_states = attn.spatial_norm(hidden_states, temb)
1048
-
1049
- input_ndim = hidden_states.ndim
1050
-
1051
- if input_ndim == 4:
1052
- batch_size, channel, height, width = hidden_states.shape
1053
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1054
-
1055
- if posemb is not None:
1056
- # turn 2d attention into multiview attention
1057
- self_attn = encoder_hidden_states is None # check if self attn or cross attn
1058
- p_out, p_in = posemb
1059
- t_out, t_in = p_out.shape[1], p_in.shape[1] # t size
1060
- hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)
1061
-
1062
- batch_size, key_tokens, _ = (
1063
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1064
- )
1065
-
1066
- attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1067
- if attention_mask is not None:
1068
- # expand our mask's singleton query_tokens dimension:
1069
- # [batch*heads, 1, key_tokens] ->
1070
- # [batch*heads, query_tokens, key_tokens]
1071
- # so that it can be added as a bias onto the attention scores that xformers computes:
1072
- # [batch*heads, query_tokens, key_tokens]
1073
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1074
- _, query_tokens, _ = hidden_states.shape
1075
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
1076
-
1077
- if attn.group_norm is not None:
1078
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1079
-
1080
- query = attn.to_q(hidden_states)
1081
- if encoder_hidden_states is None:
1082
- encoder_hidden_states = hidden_states
1083
- elif attn.norm_cross:
1084
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1085
-
1086
- key = attn.to_k(encoder_hidden_states)
1087
- value = attn.to_v(encoder_hidden_states)
1088
-
1089
-
1090
- # apply 4DoF CaPE; TODO: currently only implemented for the xFormers processor
1091
- if posemb is not None:
1092
- p_out = einops.repeat(p_out, 'b t_out d -> b (t_out l) d', l=query.shape[1]//t_out) # query shape
1093
- if self_attn:
1094
- p_in = p_out
1095
- else:
1096
- p_in = einops.repeat(p_in, 'b t_in d -> b (t_in l) d', l=key.shape[1] // t_in) # key shape
1097
- query, key = cape_embed(p_out, p_in, query, key)
1098
-
1099
- query = attn.head_to_batch_dim(query).contiguous()
1100
- key = attn.head_to_batch_dim(key).contiguous()
1101
- value = attn.head_to_batch_dim(value).contiguous()
1102
-
1103
- # self-attn (bm) l c x (bm) l c -> (bm) l c
1104
- # cross-attn (bm) l c x b (nl) c -> (bm) l c
1105
- # reuse 2d attention for multiview attention
1106
- # self-attn b (ml) c x b (ml) c -> b (ml) c
1107
- # cross-attn b (ml) c x b (nl) c -> b (ml) c
1108
- hidden_states = xformers.ops.memory_efficient_attention( # query: (bm) l c -> b (ml) c; key: b (nl) c
1109
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1110
- )
1111
- hidden_states = hidden_states.to(query.dtype)
1112
- hidden_states = attn.batch_to_head_dim(hidden_states)
1113
-
1114
- # linear proj
1115
- hidden_states = attn.to_out[0](hidden_states)
1116
- # dropout
1117
- hidden_states = attn.to_out[1](hidden_states)
1118
-
1119
- if posemb is not None:
1120
- # reshape back
1121
- hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)
1122
-
1123
- if input_ndim == 4:
1124
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1125
-
1126
- if attn.residual_connection:
1127
- hidden_states = hidden_states + residual
1128
-
1129
- hidden_states = hidden_states / attn.rescale_output_factor
1130
-
1131
-
1132
- return hidden_states
1133
-
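The `posemb` branch above turns per-view 2D attention into multiview attention purely by reshaping: the views folded into the batch axis are moved into the token axis, attention runs over all views at once with CaPE-rotated queries and keys, and the layout is restored afterwards. A hedged einops-only sketch of that round trip (sizes are illustrative):

```python
import einops
import torch

b, m, l, c = 2, 4, 16, 64            # batch, views, tokens per view, channels
per_view = torch.randn(b * m, l, c)  # how the UNet hands tokens to the processor

# fold views into the token axis: (b m) l c -> b (m l) c
multiview = einops.rearrange(per_view, '(b m) l c -> b (m l) c', m=m)
# ... attention over all m*l tokens happens here, with CaPE-rotated q/k ...
# unfold back: b (m l) c -> (b m) l c
restored = einops.rearrange(multiview, 'b (m l) c -> (b m) l c', m=m)

assert torch.equal(per_view, restored)  # pure reshape, exact round trip
```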
1134
-
1135
- class AttnProcessor2_0:
1136
- r"""
1137
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1138
- """
1139
-
1140
- def __init__(self):
1141
- if not hasattr(F, "scaled_dot_product_attention"):
1142
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1143
-
1144
- def __call__(
1145
- self,
1146
- attn: Attention,
1147
- hidden_states,
1148
- encoder_hidden_states=None,
1149
- attention_mask=None,
1150
- temb=None,
1151
- ):
1152
- residual = hidden_states
1153
-
1154
- if attn.spatial_norm is not None:
1155
- hidden_states = attn.spatial_norm(hidden_states, temb)
1156
-
1157
- input_ndim = hidden_states.ndim
1158
-
1159
- if input_ndim == 4:
1160
- batch_size, channel, height, width = hidden_states.shape
1161
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1162
-
1163
- batch_size, sequence_length, _ = (
1164
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1165
- )
1166
- inner_dim = hidden_states.shape[-1]
1167
-
1168
- if attention_mask is not None:
1169
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1170
- # scaled_dot_product_attention expects attention_mask shape to be
1171
- # (batch, heads, source_length, target_length)
1172
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1173
-
1174
- if attn.group_norm is not None:
1175
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1176
-
1177
- query = attn.to_q(hidden_states)
1178
-
1179
- if encoder_hidden_states is None:
1180
- encoder_hidden_states = hidden_states
1181
- elif attn.norm_cross:
1182
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1183
-
1184
- key = attn.to_k(encoder_hidden_states)
1185
- value = attn.to_v(encoder_hidden_states)
1186
-
1187
- head_dim = inner_dim // attn.heads
1188
-
1189
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1190
-
1191
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1192
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1193
-
1194
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
1195
- # TODO: add support for attn.scale when we move to Torch 2.1
1196
- hidden_states = F.scaled_dot_product_attention(
1197
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1198
- )
1199
-
1200
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1201
- hidden_states = hidden_states.to(query.dtype)
1202
-
1203
- # linear proj
1204
- hidden_states = attn.to_out[0](hidden_states)
1205
- # dropout
1206
- hidden_states = attn.to_out[1](hidden_states)
1207
-
1208
- if input_ndim == 4:
1209
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1210
-
1211
- if attn.residual_connection:
1212
- hidden_states = hidden_states + residual
1213
-
1214
- hidden_states = hidden_states / attn.rescale_output_factor
1215
-
1216
- return hidden_states
1217
-
1218
-
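The processor above is built around `torch.nn.functional.scaled_dot_product_attention`, which expects `(batch, heads, seq_len, head_dim)` inputs; on PyTorch 2.0 it is typically installed with `model.set_attn_processor(AttnProcessor2_0())`. A minimal, self-contained sketch of the reshape-and-attend pattern (plain tensors and illustrative shapes only, no `Attention` module):

import torch
import torch.nn.functional as F

def sdpa_reshape_sketch(hidden_states: torch.Tensor, heads: int = 8) -> torch.Tensor:
    # hidden_states: (batch, seq_len, inner_dim) with inner_dim divisible by heads.
    batch, seq_len, inner_dim = hidden_states.shape
    head_dim = inner_dim // heads

    # Identity projections: this sketch only demonstrates the reshape pattern.
    query = key = value = hidden_states

    # (batch, seq_len, inner_dim) -> (batch, heads, seq_len, head_dim)
    query = query.view(batch, -1, heads, head_dim).transpose(1, 2)
    key = key.view(batch, -1, heads, head_dim).transpose(1, 2)
    value = value.view(batch, -1, heads, head_dim).transpose(1, 2)

    out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

    # (batch, heads, seq_len, head_dim) -> (batch, seq_len, inner_dim)
    return out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

x = torch.randn(2, 16, 64)
print(sdpa_reshape_sketch(x).shape)  # torch.Size([2, 16, 64])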
1219
- class LoRAXFormersAttnProcessor(nn.Module):
1220
- r"""
1221
- Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1222
-
1223
- Args:
1224
- hidden_size (`int`, *optional*):
1225
- The hidden size of the attention layer.
1226
- cross_attention_dim (`int`, *optional*):
1227
- The number of channels in the `encoder_hidden_states`.
1228
- rank (`int`, defaults to 4):
1229
- The dimension of the LoRA update matrices.
1230
- attention_op (`Callable`, *optional*, defaults to `None`):
1231
- The base
1232
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1233
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1234
- operator.
1235
- network_alpha (`int`, *optional*):
1236
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1237
-
1238
- """
1239
-
1240
- def __init__(
1241
- self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
1242
- ):
1243
- super().__init__()
1244
-
1245
- self.hidden_size = hidden_size
1246
- self.cross_attention_dim = cross_attention_dim
1247
- self.rank = rank
1248
- self.attention_op = attention_op
1249
-
1250
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1251
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1252
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1253
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1254
-
1255
- def __call__(
1256
- self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1257
- ):
1258
- residual = hidden_states
1259
-
1260
- if attn.spatial_norm is not None:
1261
- hidden_states = attn.spatial_norm(hidden_states, temb)
1262
-
1263
- input_ndim = hidden_states.ndim
1264
-
1265
- if input_ndim == 4:
1266
- batch_size, channel, height, width = hidden_states.shape
1267
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1268
-
1269
- batch_size, sequence_length, _ = (
1270
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1271
- )
1272
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1273
-
1274
- if attn.group_norm is not None:
1275
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1276
-
1277
- query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1278
- query = attn.head_to_batch_dim(query).contiguous()
1279
-
1280
- if encoder_hidden_states is None:
1281
- encoder_hidden_states = hidden_states
1282
- elif attn.norm_cross:
1283
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1284
-
1285
- key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1286
- value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1287
-
1288
- key = attn.head_to_batch_dim(key).contiguous()
1289
- value = attn.head_to_batch_dim(value).contiguous()
1290
-
1291
- hidden_states = xformers.ops.memory_efficient_attention(
1292
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1293
- )
1294
- hidden_states = attn.batch_to_head_dim(hidden_states)
1295
-
1296
- # linear proj
1297
- hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1298
- # dropout
1299
- hidden_states = attn.to_out[1](hidden_states)
1300
-
1301
- if input_ndim == 4:
1302
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1303
-
1304
- if attn.residual_connection:
1305
- hidden_states = hidden_states + residual
1306
-
1307
- hidden_states = hidden_states / attn.rescale_output_factor
1308
-
1309
- return hidden_states
1310
-
1311
-
1312
- class LoRAAttnProcessor2_0(nn.Module):
1313
- r"""
1314
- Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1315
- attention.
1316
-
1317
- Args:
1318
- hidden_size (`int`):
1319
- The hidden size of the attention layer.
1320
- cross_attention_dim (`int`, *optional*):
1321
- The number of channels in the `encoder_hidden_states`.
1322
- rank (`int`, defaults to 4):
1323
- The dimension of the LoRA update matrices.
1324
- network_alpha (`int`, *optional*):
1325
- Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1326
- """
1327
-
1328
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1329
- super().__init__()
1330
- if not hasattr(F, "scaled_dot_product_attention"):
1331
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1332
-
1333
- self.hidden_size = hidden_size
1334
- self.cross_attention_dim = cross_attention_dim
1335
- self.rank = rank
1336
-
1337
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1338
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1339
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1340
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1341
-
1342
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1343
- residual = hidden_states
1344
-
1345
- input_ndim = hidden_states.ndim
1346
-
1347
- if input_ndim == 4:
1348
- batch_size, channel, height, width = hidden_states.shape
1349
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1350
-
1351
- batch_size, sequence_length, _ = (
1352
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1353
- )
1354
- inner_dim = hidden_states.shape[-1]
1355
-
1356
- if attention_mask is not None:
1357
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1358
- # scaled_dot_product_attention expects attention_mask shape to be
1359
- # (batch, heads, source_length, target_length)
1360
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1361
-
1362
- if attn.group_norm is not None:
1363
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1364
-
1365
- query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1366
-
1367
- if encoder_hidden_states is None:
1368
- encoder_hidden_states = hidden_states
1369
- elif attn.norm_cross:
1370
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1371
-
1372
- key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1373
- value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1374
-
1375
- head_dim = inner_dim // attn.heads
1376
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1377
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1378
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1379
-
1380
- # TODO: add support for attn.scale when we move to Torch 2.1
1381
- hidden_states = F.scaled_dot_product_attention(
1382
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1383
- )
1384
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1385
- hidden_states = hidden_states.to(query.dtype)
1386
-
1387
- # linear proj
1388
- hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1389
- # dropout
1390
- hidden_states = attn.to_out[1](hidden_states)
1391
-
1392
- if input_ndim == 4:
1393
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1394
-
1395
- if attn.residual_connection:
1396
- hidden_states = hidden_states + residual
1397
-
1398
- hidden_states = hidden_states / attn.rescale_output_factor
1399
-
1400
- return hidden_states
1401
-
1402
-
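Both LoRA processors above add a scaled low-rank correction, `attn.to_q(x) + scale * self.to_q_lora(x)` and so on, on top of the frozen projections. A small illustrative re-implementation of that idea (not the library's `LoRALinearLayer`, just a sketch of the same down-project/up-project structure):

import torch
import torch.nn as nn

class LowRankUpdate(nn.Module):
    """Rank-`rank` correction B(A(x)), optionally rescaled by network_alpha / rank."""

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # A
        self.up = nn.Linear(rank, out_features, bias=False)   # B
        self.rank = rank
        self.network_alpha = network_alpha
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)  # zero init: the correction starts as a no-op

    def forward(self, x):
        out = self.up(self.down(x))
        if self.network_alpha is not None:
            out = out * (self.network_alpha / self.rank)
        return out

# Frozen projection plus the scaled LoRA term, mirroring the query path above.
frozen_to_q = nn.Linear(64, 64)
to_q_lora = LowRankUpdate(64, 64, rank=4)
x = torch.randn(2, 16, 64)
query = frozen_to_q(x) + 1.0 * to_q_lora(x)  # scale = 1.0
print(query.shape)  # torch.Size([2, 16, 64])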
1403
- class CustomDiffusionXFormersAttnProcessor(nn.Module):
1404
- r"""
1405
- Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1406
-
1407
- Args:
1408
- train_kv (`bool`, defaults to `True`):
1409
- Whether to newly train the key and value matrices corresponding to the text features.
1410
- train_q_out (`bool`, defaults to `True`):
1411
- Whether to newly train query matrices corresponding to the latent image features.
1412
- hidden_size (`int`, *optional*, defaults to `None`):
1413
- The hidden size of the attention layer.
1414
- cross_attention_dim (`int`, *optional*, defaults to `None`):
1415
- The number of channels in the `encoder_hidden_states`.
1416
- out_bias (`bool`, defaults to `True`):
1417
- Whether to include the bias parameter in `train_q_out`.
1418
- dropout (`float`, *optional*, defaults to 0.0):
1419
- The dropout probability to use.
1420
- attention_op (`Callable`, *optional*, defaults to `None`):
1421
- The base
1422
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1423
- as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
1424
- """
1425
-
1426
- def __init__(
1427
- self,
1428
- train_kv=True,
1429
- train_q_out=False,
1430
- hidden_size=None,
1431
- cross_attention_dim=None,
1432
- out_bias=True,
1433
- dropout=0.0,
1434
- attention_op: Optional[Callable] = None,
1435
- ):
1436
- super().__init__()
1437
- self.train_kv = train_kv
1438
- self.train_q_out = train_q_out
1439
-
1440
- self.hidden_size = hidden_size
1441
- self.cross_attention_dim = cross_attention_dim
1442
- self.attention_op = attention_op
1443
-
1444
- # `_custom_diffusion` id for easy serialization and loading.
1445
- if self.train_kv:
1446
- self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1447
- self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1448
- if self.train_q_out:
1449
- self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1450
- self.to_out_custom_diffusion = nn.ModuleList([])
1451
- self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1452
- self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1453
-
1454
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1455
- batch_size, sequence_length, _ = (
1456
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1457
- )
1458
-
1459
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1460
-
1461
- if self.train_q_out:
1462
- query = self.to_q_custom_diffusion(hidden_states)
1463
- else:
1464
- query = attn.to_q(hidden_states)
1465
-
1466
- if encoder_hidden_states is None:
1467
- crossattn = False
1468
- encoder_hidden_states = hidden_states
1469
- else:
1470
- crossattn = True
1471
- if attn.norm_cross:
1472
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1473
-
1474
- if self.train_kv:
1475
- key = self.to_k_custom_diffusion(encoder_hidden_states)
1476
- value = self.to_v_custom_diffusion(encoder_hidden_states)
1477
- else:
1478
- key = attn.to_k(encoder_hidden_states)
1479
- value = attn.to_v(encoder_hidden_states)
1480
-
1481
- if crossattn:
1482
- detach = torch.ones_like(key)
1483
- detach[:, :1, :] = detach[:, :1, :] * 0.0
1484
- key = detach * key + (1 - detach) * key.detach()
1485
- value = detach * value + (1 - detach) * value.detach()
1486
-
1487
- query = attn.head_to_batch_dim(query).contiguous()
1488
- key = attn.head_to_batch_dim(key).contiguous()
1489
- value = attn.head_to_batch_dim(value).contiguous()
1490
-
1491
- hidden_states = xformers.ops.memory_efficient_attention(
1492
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1493
- )
1494
- hidden_states = hidden_states.to(query.dtype)
1495
- hidden_states = attn.batch_to_head_dim(hidden_states)
1496
-
1497
- if self.train_q_out:
1498
- # linear proj
1499
- hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1500
- # dropout
1501
- hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1502
- else:
1503
- # linear proj
1504
- hidden_states = attn.to_out[0](hidden_states)
1505
- # dropout
1506
- hidden_states = attn.to_out[1](hidden_states)
1507
- return hidden_states
1508
-
1509
-
1510
- class SlicedAttnProcessor:
1511
- r"""
1512
- Processor for implementing sliced attention.
1513
-
1514
- Args:
1515
- slice_size (`int`, *optional*):
1516
- The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1517
- `attention_head_dim` must be a multiple of the `slice_size`.
1518
- """
1519
-
1520
- def __init__(self, slice_size):
1521
- self.slice_size = slice_size
1522
-
1523
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1524
- residual = hidden_states
1525
-
1526
- input_ndim = hidden_states.ndim
1527
-
1528
- if input_ndim == 4:
1529
- batch_size, channel, height, width = hidden_states.shape
1530
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1531
-
1532
- batch_size, sequence_length, _ = (
1533
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1534
- )
1535
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1536
-
1537
- if attn.group_norm is not None:
1538
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1539
-
1540
- query = attn.to_q(hidden_states)
1541
- dim = query.shape[-1]
1542
- query = attn.head_to_batch_dim(query)
1543
-
1544
- if encoder_hidden_states is None:
1545
- encoder_hidden_states = hidden_states
1546
- elif attn.norm_cross:
1547
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1548
-
1549
- key = attn.to_k(encoder_hidden_states)
1550
- value = attn.to_v(encoder_hidden_states)
1551
- key = attn.head_to_batch_dim(key)
1552
- value = attn.head_to_batch_dim(value)
1553
-
1554
- batch_size_attention, query_tokens, _ = query.shape
1555
- hidden_states = torch.zeros(
1556
- (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1557
- )
1558
-
1559
- for i in range(batch_size_attention // self.slice_size):
1560
- start_idx = i * self.slice_size
1561
- end_idx = (i + 1) * self.slice_size
1562
-
1563
- query_slice = query[start_idx:end_idx]
1564
- key_slice = key[start_idx:end_idx]
1565
- attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1566
-
1567
- attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1568
-
1569
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1570
-
1571
- hidden_states[start_idx:end_idx] = attn_slice
1572
-
1573
- hidden_states = attn.batch_to_head_dim(hidden_states)
1574
-
1575
- # linear proj
1576
- hidden_states = attn.to_out[0](hidden_states)
1577
- # dropout
1578
- hidden_states = attn.to_out[1](hidden_states)
1579
-
1580
- if input_ndim == 4:
1581
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1582
-
1583
- if attn.residual_connection:
1584
- hidden_states = hidden_states + residual
1585
-
1586
- hidden_states = hidden_states / attn.rescale_output_factor
1587
-
1588
- return hidden_states
1589
-
1590
-
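The slicing loop above bounds memory: the `(batch * heads)` dimension is processed `slice_size` rows at a time, so only one slice's attention matrix exists at once. A rough standalone sketch of the same idea with plain tensors (toy shapes, softmax attention written out by hand):

import torch

def sliced_attention(query, key, value, slice_size):
    # query / key / value: (batch_heads, tokens, head_dim), already split into heads.
    batch_heads, tokens, head_dim = query.shape
    out = torch.zeros_like(query)
    scale = head_dim ** -0.5
    for start in range(0, batch_heads, slice_size):
        end = start + slice_size
        # Only this slice's (tokens x tokens) attention matrix is materialised.
        scores = torch.bmm(query[start:end], key[start:end].transpose(-1, -2)) * scale
        probs = scores.softmax(dim=-1)
        out[start:end] = torch.bmm(probs, value[start:end])
    return out

q = k = v = torch.randn(16, 77, 40)  # e.g. 2 samples x 8 heads, 77 tokens, head_dim 40
print(sliced_attention(q, k, v, slice_size=4).shape)  # torch.Size([16, 77, 40])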
1591
- class SlicedAttnAddedKVProcessor:
1592
- r"""
1593
- Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1594
-
1595
- Args:
1596
- slice_size (`int`, *optional*):
1597
- The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1598
- `attention_head_dim` must be a multiple of the `slice_size`.
1599
- """
1600
-
1601
- def __init__(self, slice_size):
1602
- self.slice_size = slice_size
1603
-
1604
- def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1605
- residual = hidden_states
1606
-
1607
- if attn.spatial_norm is not None:
1608
- hidden_states = attn.spatial_norm(hidden_states, temb)
1609
-
1610
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1611
-
1612
- batch_size, sequence_length, _ = hidden_states.shape
1613
-
1614
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1615
-
1616
- if encoder_hidden_states is None:
1617
- encoder_hidden_states = hidden_states
1618
- elif attn.norm_cross:
1619
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1620
-
1621
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1622
-
1623
- query = attn.to_q(hidden_states)
1624
- dim = query.shape[-1]
1625
- query = attn.head_to_batch_dim(query)
1626
-
1627
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1628
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1629
-
1630
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1631
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1632
-
1633
- if not attn.only_cross_attention:
1634
- key = attn.to_k(hidden_states)
1635
- value = attn.to_v(hidden_states)
1636
- key = attn.head_to_batch_dim(key)
1637
- value = attn.head_to_batch_dim(value)
1638
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1639
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1640
- else:
1641
- key = encoder_hidden_states_key_proj
1642
- value = encoder_hidden_states_value_proj
1643
-
1644
- batch_size_attention, query_tokens, _ = query.shape
1645
- hidden_states = torch.zeros(
1646
- (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1647
- )
1648
-
1649
- for i in range(batch_size_attention // self.slice_size):
1650
- start_idx = i * self.slice_size
1651
- end_idx = (i + 1) * self.slice_size
1652
-
1653
- query_slice = query[start_idx:end_idx]
1654
- key_slice = key[start_idx:end_idx]
1655
- attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1656
-
1657
- attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1658
-
1659
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1660
-
1661
- hidden_states[start_idx:end_idx] = attn_slice
1662
-
1663
- hidden_states = attn.batch_to_head_dim(hidden_states)
1664
-
1665
- # linear proj
1666
- hidden_states = attn.to_out[0](hidden_states)
1667
- # dropout
1668
- hidden_states = attn.to_out[1](hidden_states)
1669
-
1670
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1671
- hidden_states = hidden_states + residual
1672
-
1673
- return hidden_states
1674
-
1675
-
1676
- AttentionProcessor = Union[
1677
- AttnProcessor,
1678
- AttnProcessor2_0,
1679
- XFormersAttnProcessor,
1680
- SlicedAttnProcessor,
1681
- AttnAddedKVProcessor,
1682
- SlicedAttnAddedKVProcessor,
1683
- AttnAddedKVProcessor2_0,
1684
- XFormersAttnAddedKVProcessor,
1685
- LoRAAttnProcessor,
1686
- LoRAXFormersAttnProcessor,
1687
- LoRAAttnProcessor2_0,
1688
- LoRAAttnAddedKVProcessor,
1689
- CustomDiffusionAttnProcessor,
1690
- CustomDiffusionXFormersAttnProcessor,
1691
- ]
1692
-
1693
-
1694
- class SpatialNorm(nn.Module):
1695
- """
1696
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1697
- """
1698
-
1699
- def __init__(
1700
- self,
1701
- f_channels,
1702
- zq_channels,
1703
- ):
1704
- super().__init__()
1705
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1706
- self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1707
- self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1708
-
1709
- def forward(self, f, zq):
1710
- f_size = f.shape[-2:]
1711
- zq = F.interpolate(zq, size=f_size, mode="nearest")
1712
- norm_f = self.norm_layer(f)
1713
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1714
- return new_f
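`SpatialNorm` rescales and shifts group-normalised features with a spatially varying condition: `zq` is resized to the feature resolution and mapped through two 1x1 convolutions, giving roughly `GroupNorm(f) * conv_y(zq) + conv_b(zq)`. A quick shape check with toy channel counts (the sizes below are illustrative, not taken from any particular model):

import torch
import torch.nn as nn
import torch.nn.functional as F

f = torch.randn(1, 64, 32, 32)   # decoder feature map
zq = torch.randn(1, 4, 8, 8)     # quantised latent used as the spatial condition

norm = nn.GroupNorm(num_channels=64, num_groups=32, eps=1e-6, affine=True)
conv_y = nn.Conv2d(4, 64, kernel_size=1)
conv_b = nn.Conv2d(4, 64, kernel_size=1)

zq_up = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
out = norm(f) * conv_y(zq_up) + conv_b(zq_up)
print(out.shape)  # torch.Size([1, 64, 32, 32])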
4DoF/diffusers/models/autoencoder_kl.py DELETED
@@ -1,411 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Dict, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from ..configuration_utils import ConfigMixin, register_to_config
21
- from ..utils import BaseOutput, apply_forward_hook
22
- from .attention_processor import AttentionProcessor, AttnProcessor
23
- from .modeling_utils import ModelMixin
24
- from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
25
-
26
-
27
- @dataclass
28
- class AutoencoderKLOutput(BaseOutput):
29
- """
30
- Output of AutoencoderKL encoding method.
31
-
32
- Args:
33
- latent_dist (`DiagonalGaussianDistribution`):
34
- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
35
- `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
36
- """
37
-
38
- latent_dist: "DiagonalGaussianDistribution"
39
-
40
-
41
- class AutoencoderKL(ModelMixin, ConfigMixin):
42
- r"""
43
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
44
-
45
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
46
- for all models (such as downloading or saving).
47
-
48
- Parameters:
49
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
50
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
51
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
52
- Tuple of downsample block types.
53
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
54
- Tuple of upsample block types.
55
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
56
- Tuple of block output channels.
57
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
58
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
59
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
60
- scaling_factor (`float`, *optional*, defaults to 0.18215):
61
- The component-wise standard deviation of the trained latent space computed using the first batch of the
62
- training set. This is used to scale the latent space to have unit variance when training the diffusion
63
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
64
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
65
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
66
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
67
- """
68
-
69
- _supports_gradient_checkpointing = True
70
-
71
- @register_to_config
72
- def __init__(
73
- self,
74
- in_channels: int = 3,
75
- out_channels: int = 3,
76
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
77
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
78
- block_out_channels: Tuple[int] = (64,),
79
- layers_per_block: int = 1,
80
- act_fn: str = "silu",
81
- latent_channels: int = 4,
82
- norm_num_groups: int = 32,
83
- sample_size: int = 32,
84
- scaling_factor: float = 0.18215,
85
- ):
86
- super().__init__()
87
-
88
- # pass init params to Encoder
89
- self.encoder = Encoder(
90
- in_channels=in_channels,
91
- out_channels=latent_channels,
92
- down_block_types=down_block_types,
93
- block_out_channels=block_out_channels,
94
- layers_per_block=layers_per_block,
95
- act_fn=act_fn,
96
- norm_num_groups=norm_num_groups,
97
- double_z=True,
98
- )
99
-
100
- # pass init params to Decoder
101
- self.decoder = Decoder(
102
- in_channels=latent_channels,
103
- out_channels=out_channels,
104
- up_block_types=up_block_types,
105
- block_out_channels=block_out_channels,
106
- layers_per_block=layers_per_block,
107
- norm_num_groups=norm_num_groups,
108
- act_fn=act_fn,
109
- )
110
-
111
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
112
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
113
-
114
- self.use_slicing = False
115
- self.use_tiling = False
116
-
117
- # only relevant if vae tiling is enabled
118
- self.tile_sample_min_size = self.config.sample_size
119
- sample_size = (
120
- self.config.sample_size[0]
121
- if isinstance(self.config.sample_size, (list, tuple))
122
- else self.config.sample_size
123
- )
124
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
125
- self.tile_overlap_factor = 0.25
126
-
127
- def _set_gradient_checkpointing(self, module, value=False):
128
- if isinstance(module, (Encoder, Decoder)):
129
- module.gradient_checkpointing = value
130
-
131
- def enable_tiling(self, use_tiling: bool = True):
132
- r"""
133
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
134
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
135
- processing larger images.
136
- """
137
- self.use_tiling = use_tiling
138
-
139
- def disable_tiling(self):
140
- r"""
141
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
142
- decoding in one step.
143
- """
144
- self.enable_tiling(False)
145
-
146
- def enable_slicing(self):
147
- r"""
148
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
149
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
150
- """
151
- self.use_slicing = True
152
-
153
- def disable_slicing(self):
154
- r"""
155
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
156
- decoding in one step.
157
- """
158
- self.use_slicing = False
159
-
160
- @property
161
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
162
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
163
- r"""
164
- Returns:
165
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
166
- indexed by its weight name.
167
- """
168
- # set recursively
169
- processors = {}
170
-
171
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
172
- if hasattr(module, "set_processor"):
173
- processors[f"{name}.processor"] = module.processor
174
-
175
- for sub_name, child in module.named_children():
176
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
177
-
178
- return processors
179
-
180
- for name, module in self.named_children():
181
- fn_recursive_add_processors(name, module, processors)
182
-
183
- return processors
184
-
185
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
186
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
187
- r"""
188
- Sets the attention processor to use to compute attention.
189
-
190
- Parameters:
191
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
192
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
193
- for **all** `Attention` layers.
194
-
195
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
196
- processor. This is strongly recommended when setting trainable attention processors.
197
-
198
- """
199
- count = len(self.attn_processors.keys())
200
-
201
- if isinstance(processor, dict) and len(processor) != count:
202
- raise ValueError(
203
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
204
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
205
- )
206
-
207
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
208
- if hasattr(module, "set_processor"):
209
- if not isinstance(processor, dict):
210
- module.set_processor(processor)
211
- else:
212
- module.set_processor(processor.pop(f"{name}.processor"))
213
-
214
- for sub_name, child in module.named_children():
215
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
216
-
217
- for name, module in self.named_children():
218
- fn_recursive_attn_processor(name, module, processor)
219
-
220
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
221
- def set_default_attn_processor(self):
222
- """
223
- Disables custom attention processors and sets the default attention implementation.
224
- """
225
- self.set_attn_processor(AttnProcessor())
226
-
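`attn_processors` and `set_attn_processor` are meant to be used together: read the dict of processors keyed by weight name, then pass back a single processor for every layer or a dict reusing the same keys. A hedged sketch, assuming the upstream `diffusers` package (rather than this vendored copy) is importable and PyTorch 2.0 is installed:

from diffusers import AutoencoderKL
from diffusers.models.attention_processor import AttnProcessor2_0

vae = AutoencoderKL()  # randomly initialised, default config; enough to have attention layers

processors = vae.attn_processors
print(len(processors), "attention layers:", list(processors))

# One processor instance applied to every attention layer...
vae.set_attn_processor(AttnProcessor2_0())

# ...or a dict keyed by the same names for per-layer control.
vae.set_attn_processor({name: AttnProcessor2_0() for name in vae.attn_processors})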
227
- @apply_forward_hook
228
- def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
229
- if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
230
- return self.tiled_encode(x, return_dict=return_dict)
231
-
232
- if self.use_slicing and x.shape[0] > 1:
233
- encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
234
- h = torch.cat(encoded_slices)
235
- else:
236
- h = self.encoder(x)
237
-
238
- moments = self.quant_conv(h)
239
- posterior = DiagonalGaussianDistribution(moments)
240
-
241
- if not return_dict:
242
- return (posterior,)
243
-
244
- return AutoencoderKLOutput(latent_dist=posterior)
245
-
246
- def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
247
- if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
248
- return self.tiled_decode(z, return_dict=return_dict)
249
-
250
- z = self.post_quant_conv(z)
251
- dec = self.decoder(z)
252
-
253
- if not return_dict:
254
- return (dec,)
255
-
256
- return DecoderOutput(sample=dec)
257
-
258
- @apply_forward_hook
259
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
260
- if self.use_slicing and z.shape[0] > 1:
261
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
262
- decoded = torch.cat(decoded_slices)
263
- else:
264
- decoded = self._decode(z).sample
265
-
266
- if not return_dict:
267
- return (decoded,)
268
-
269
- return DecoderOutput(sample=decoded)
270
-
271
- def blend_v(self, a, b, blend_extent):
272
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
273
- for y in range(blend_extent):
274
- b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
275
- return b
276
-
277
- def blend_h(self, a, b, blend_extent):
278
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
279
- for x in range(blend_extent):
280
- b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
281
- return b
282
-
283
- def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
284
- r"""Encode a batch of images using a tiled encoder.
285
-
286
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
287
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
288
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
289
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
290
- output, but they should be much less noticeable.
291
-
292
- Args:
293
- x (`torch.FloatTensor`): Input batch of images.
294
- return_dict (`bool`, *optional*, defaults to `True`):
295
- Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
296
-
297
- Returns:
298
- [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
299
- If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
300
- `tuple` is returned.
301
- """
302
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
303
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
304
- row_limit = self.tile_latent_min_size - blend_extent
305
-
306
- # Split the image into 512x512 tiles and encode them separately.
307
- rows = []
308
- for i in range(0, x.shape[2], overlap_size):
309
- row = []
310
- for j in range(0, x.shape[3], overlap_size):
311
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
312
- tile = self.encoder(tile)
313
- tile = self.quant_conv(tile)
314
- row.append(tile)
315
- rows.append(row)
316
- result_rows = []
317
- for i, row in enumerate(rows):
318
- result_row = []
319
- for j, tile in enumerate(row):
320
- # blend the above tile and the left tile
321
- # to the current tile and add the current tile to the result row
322
- if i > 0:
323
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
324
- if j > 0:
325
- tile = self.blend_h(row[j - 1], tile, blend_extent)
326
- result_row.append(tile[:, :, :row_limit, :row_limit])
327
- result_rows.append(torch.cat(result_row, dim=3))
328
-
329
- moments = torch.cat(result_rows, dim=2)
330
- posterior = DiagonalGaussianDistribution(moments)
331
-
332
- if not return_dict:
333
- return (posterior,)
334
-
335
- return AutoencoderKLOutput(latent_dist=posterior)
336
-
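The seam hiding relies on `blend_v`/`blend_h`: over the `blend_extent` overlap the previous tile is faded out linearly while the current tile is faded in. A tiny numeric illustration of the horizontal blend (the helper is restated here so the snippet runs on its own):

import torch

def blend_h(a, b, blend_extent):
    # Same linear cross-fade as the method above, on (batch, channels, height, width) tensors.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left = torch.zeros(1, 1, 1, 8)   # right edge of the previous tile (all zeros)
right = torch.ones(1, 1, 1, 8)   # left edge of the current tile (all ones)
blended = blend_h(left, right.clone(), blend_extent=4)
print(blended[0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])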
337
- def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
338
- r"""
339
- Decode a batch of images using a tiled decoder.
340
-
341
- Args:
342
- z (`torch.FloatTensor`): Input batch of latent vectors.
343
- return_dict (`bool`, *optional*, defaults to `True`):
344
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
345
-
346
- Returns:
347
- [`~models.vae.DecoderOutput`] or `tuple`:
348
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
349
- returned.
350
- """
351
- overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
352
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
353
- row_limit = self.tile_sample_min_size - blend_extent
354
-
355
- # Split z into overlapping 64x64 tiles and decode them separately.
356
- # The tiles have an overlap to avoid seams between tiles.
357
- rows = []
358
- for i in range(0, z.shape[2], overlap_size):
359
- row = []
360
- for j in range(0, z.shape[3], overlap_size):
361
- tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
362
- tile = self.post_quant_conv(tile)
363
- decoded = self.decoder(tile)
364
- row.append(decoded)
365
- rows.append(row)
366
- result_rows = []
367
- for i, row in enumerate(rows):
368
- result_row = []
369
- for j, tile in enumerate(row):
370
- # blend the above tile and the left tile
371
- # to the current tile and add the current tile to the result row
372
- if i > 0:
373
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
374
- if j > 0:
375
- tile = self.blend_h(row[j - 1], tile, blend_extent)
376
- result_row.append(tile[:, :, :row_limit, :row_limit])
377
- result_rows.append(torch.cat(result_row, dim=3))
378
-
379
- dec = torch.cat(result_rows, dim=2)
380
- if not return_dict:
381
- return (dec,)
382
-
383
- return DecoderOutput(sample=dec)
384
-
385
- def forward(
386
- self,
387
- sample: torch.FloatTensor,
388
- sample_posterior: bool = False,
389
- return_dict: bool = True,
390
- generator: Optional[torch.Generator] = None,
391
- ) -> Union[DecoderOutput, torch.FloatTensor]:
392
- r"""
393
- Args:
394
- sample (`torch.FloatTensor`): Input sample.
395
- sample_posterior (`bool`, *optional*, defaults to `False`):
396
- Whether to sample from the posterior.
397
- return_dict (`bool`, *optional*, defaults to `True`):
398
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
399
- """
400
- x = sample
401
- posterior = self.encode(x).latent_dist
402
- if sample_posterior:
403
- z = posterior.sample(generator=generator)
404
- else:
405
- z = posterior.mode()
406
- dec = self.decode(z).sample
407
-
408
- if not return_dict:
409
- return (dec,)
410
-
411
- return DecoderOutput(sample=dec)
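Putting the pieces together, typical usage follows the `scaling_factor` convention from the class docstring: multiply latents by it after encoding and divide before decoding. A hedged sketch, assuming the upstream `diffusers` package is importable and a VAE checkpoint such as `stabilityai/sd-vae-ft-mse` (an assumed example id) can be downloaded:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_tiling()    # optional: tile large inputs to keep memory bounded
vae.enable_slicing()   # optional: decode one sample at a time

image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    recon = vae.decode(latents / vae.config.scaling_factor).sample

print(latents.shape, recon.shape)  # torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])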
4DoF/diffusers/models/controlnet.py DELETED
@@ -1,705 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- from torch import nn
19
- from torch.nn import functional as F
20
-
21
- from ..configuration_utils import ConfigMixin, register_to_config
22
- from ..utils import BaseOutput, logging
23
- from .attention_processor import AttentionProcessor, AttnProcessor
24
- from .embeddings import TimestepEmbedding, Timesteps
25
- from .modeling_utils import ModelMixin
26
- from .unet_2d_blocks import (
27
- CrossAttnDownBlock2D,
28
- DownBlock2D,
29
- UNetMidBlock2DCrossAttn,
30
- get_down_block,
31
- )
32
- from .unet_2d_condition import UNet2DConditionModel
33
-
34
-
35
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
-
37
-
38
- @dataclass
39
- class ControlNetOutput(BaseOutput):
40
- """
41
- The output of [`ControlNetModel`].
42
-
43
- Args:
44
- down_block_res_samples (`tuple[torch.Tensor]`):
45
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
46
- be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
47
- used to condition the original UNet's downsampling activations.
48
- mid_down_block_re_sample (`torch.Tensor`):
49
- The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
50
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
51
- Output can be used to condition the original UNet's middle block activation.
52
- """
53
-
54
- down_block_res_samples: Tuple[torch.Tensor]
55
- mid_block_res_sample: torch.Tensor
56
-
57
-
58
- class ControlNetConditioningEmbedding(nn.Module):
59
- """
60
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
61
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
62
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
63
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
64
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
65
- model) to encode image-space conditions ... into feature maps ..."
66
- """
67
-
68
- def __init__(
69
- self,
70
- conditioning_embedding_channels: int,
71
- conditioning_channels: int = 3,
72
- block_out_channels: Tuple[int] = (16, 32, 96, 256),
73
- ):
74
- super().__init__()
75
-
76
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
77
-
78
- self.blocks = nn.ModuleList([])
79
-
80
- for i in range(len(block_out_channels) - 1):
81
- channel_in = block_out_channels[i]
82
- channel_out = block_out_channels[i + 1]
83
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
84
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
85
-
86
- self.conv_out = zero_module(
87
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
88
- )
89
-
90
- def forward(self, conditioning):
91
- embedding = self.conv_in(conditioning)
92
- embedding = F.silu(embedding)
93
-
94
- for block in self.blocks:
95
- embedding = block(embedding)
96
- embedding = F.silu(embedding)
97
-
98
- embedding = self.conv_out(embedding)
99
-
100
- return embedding
101
-
102
-
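With the default `block_out_channels=(16, 32, 96, 256)` there are three stride-2 convolutions, so a 512x512 conditioning image comes out as a 64x64 feature map, matching the 64x64 latent space mentioned in the docstring. A flat shape-check sketch of the same layer pattern (the 320 output channels are an assumption matching the first UNet block):

import torch
import torch.nn as nn
import torch.nn.functional as F

channels = (16, 32, 96, 256)
conv_in = nn.Conv2d(3, channels[0], kernel_size=3, padding=1)
blocks = nn.ModuleList()
for c_in, c_out in zip(channels[:-1], channels[1:]):
    blocks.append(nn.Conv2d(c_in, c_in, kernel_size=3, padding=1))
    blocks.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2))
conv_out = nn.Conv2d(channels[-1], 320, kernel_size=3, padding=1)

cond = torch.randn(1, 3, 512, 512)  # conditioning image, e.g. canny edges or depth
x = F.silu(conv_in(cond))
for block in blocks:
    x = F.silu(block(x))
print(conv_out(x).shape)  # torch.Size([1, 320, 64, 64])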
103
- class ControlNetModel(ModelMixin, ConfigMixin):
104
- """
105
- A ControlNet model.
106
-
107
- Args:
108
- in_channels (`int`, defaults to 4):
109
- The number of channels in the input sample.
110
- flip_sin_to_cos (`bool`, defaults to `True`):
111
- Whether to flip the sin to cos in the time embedding.
112
- freq_shift (`int`, defaults to 0):
113
- The frequency shift to apply to the time embedding.
114
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
115
- The tuple of downsample blocks to use.
116
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
117
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
118
- The tuple of output channels for each block.
119
- layers_per_block (`int`, defaults to 2):
120
- The number of layers per block.
121
- downsample_padding (`int`, defaults to 1):
122
- The padding to use for the downsampling convolution.
123
- mid_block_scale_factor (`float`, defaults to 1):
124
- The scale factor to use for the mid block.
125
- act_fn (`str`, defaults to "silu"):
126
- The activation function to use.
127
- norm_num_groups (`int`, *optional*, defaults to 32):
128
- The number of groups to use for the normalization. If None, normalization and activation layers are skipped
129
- in post-processing.
130
- norm_eps (`float`, defaults to 1e-5):
131
- The epsilon to use for the normalization.
132
- cross_attention_dim (`int`, defaults to 1280):
133
- The dimension of the cross attention features.
134
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
135
- The dimension of the attention heads.
136
- use_linear_projection (`bool`, defaults to `False`):
137
- class_embed_type (`str`, *optional*, defaults to `None`):
138
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
139
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
140
- num_class_embeds (`int`, *optional*, defaults to 0):
141
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
142
- class conditioning with `class_embed_type` equal to `None`.
143
- upcast_attention (`bool`, defaults to `False`):
144
- resnet_time_scale_shift (`str`, defaults to `"default"`):
145
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
146
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
147
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
148
- `class_embed_type="projection"`.
149
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
150
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
151
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
152
- The tuple of output channel for each block in the `conditioning_embedding` layer.
153
- global_pool_conditions (`bool`, defaults to `False`):
154
- """
155
-
156
- _supports_gradient_checkpointing = True
157
-
158
- @register_to_config
159
- def __init__(
160
- self,
161
- in_channels: int = 4,
162
- conditioning_channels: int = 3,
163
- flip_sin_to_cos: bool = True,
164
- freq_shift: int = 0,
165
- down_block_types: Tuple[str] = (
166
- "CrossAttnDownBlock2D",
167
- "CrossAttnDownBlock2D",
168
- "CrossAttnDownBlock2D",
169
- "DownBlock2D",
170
- ),
171
- only_cross_attention: Union[bool, Tuple[bool]] = False,
172
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
173
- layers_per_block: int = 2,
174
- downsample_padding: int = 1,
175
- mid_block_scale_factor: float = 1,
176
- act_fn: str = "silu",
177
- norm_num_groups: Optional[int] = 32,
178
- norm_eps: float = 1e-5,
179
- cross_attention_dim: int = 1280,
180
- attention_head_dim: Union[int, Tuple[int]] = 8,
181
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
182
- use_linear_projection: bool = False,
183
- class_embed_type: Optional[str] = None,
184
- num_class_embeds: Optional[int] = None,
185
- upcast_attention: bool = False,
186
- resnet_time_scale_shift: str = "default",
187
- projection_class_embeddings_input_dim: Optional[int] = None,
188
- controlnet_conditioning_channel_order: str = "rgb",
189
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
190
- global_pool_conditions: bool = False,
191
- ):
192
- super().__init__()
193
-
194
- # If `num_attention_heads` is not defined (which is the case for most models)
195
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
196
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
197
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
198
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
199
- # which is why we correct for the naming here.
200
- num_attention_heads = num_attention_heads or attention_head_dim
201
-
202
- # Check inputs
203
- if len(block_out_channels) != len(down_block_types):
204
- raise ValueError(
205
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
206
- )
207
-
208
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
209
- raise ValueError(
210
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
211
- )
212
-
213
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
214
- raise ValueError(
215
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
216
- )
217
-
218
- # input
219
- conv_in_kernel = 3
220
- conv_in_padding = (conv_in_kernel - 1) // 2
221
- self.conv_in = nn.Conv2d(
222
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
223
- )
224
-
225
- # time
226
- time_embed_dim = block_out_channels[0] * 4
227
-
228
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
229
- timestep_input_dim = block_out_channels[0]
230
-
231
- self.time_embedding = TimestepEmbedding(
232
- timestep_input_dim,
233
- time_embed_dim,
234
- act_fn=act_fn,
235
- )
236
-
237
- # class embedding
238
- if class_embed_type is None and num_class_embeds is not None:
239
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
240
- elif class_embed_type == "timestep":
241
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
242
- elif class_embed_type == "identity":
243
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
244
- elif class_embed_type == "projection":
245
- if projection_class_embeddings_input_dim is None:
246
- raise ValueError(
247
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
248
- )
249
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
250
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
251
- # 2. it projects from an arbitrary input dimension.
252
- #
253
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
254
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
255
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
256
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
257
- else:
258
- self.class_embedding = None
259
-
260
- # control net conditioning embedding
261
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
262
- conditioning_embedding_channels=block_out_channels[0],
263
- block_out_channels=conditioning_embedding_out_channels,
264
- conditioning_channels=conditioning_channels,
265
- )
266
-
267
- self.down_blocks = nn.ModuleList([])
268
- self.controlnet_down_blocks = nn.ModuleList([])
269
-
270
- if isinstance(only_cross_attention, bool):
271
- only_cross_attention = [only_cross_attention] * len(down_block_types)
272
-
273
- if isinstance(attention_head_dim, int):
274
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
275
-
276
- if isinstance(num_attention_heads, int):
277
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
278
-
279
- # down
280
- output_channel = block_out_channels[0]
281
-
282
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
283
- controlnet_block = zero_module(controlnet_block)
284
- self.controlnet_down_blocks.append(controlnet_block)
285
-
286
- for i, down_block_type in enumerate(down_block_types):
287
- input_channel = output_channel
288
- output_channel = block_out_channels[i]
289
- is_final_block = i == len(block_out_channels) - 1
290
-
291
- down_block = get_down_block(
292
- down_block_type,
293
- num_layers=layers_per_block,
294
- in_channels=input_channel,
295
- out_channels=output_channel,
296
- temb_channels=time_embed_dim,
297
- add_downsample=not is_final_block,
298
- resnet_eps=norm_eps,
299
- resnet_act_fn=act_fn,
300
- resnet_groups=norm_num_groups,
301
- cross_attention_dim=cross_attention_dim,
302
- num_attention_heads=num_attention_heads[i],
303
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
304
- downsample_padding=downsample_padding,
305
- use_linear_projection=use_linear_projection,
306
- only_cross_attention=only_cross_attention[i],
307
- upcast_attention=upcast_attention,
308
- resnet_time_scale_shift=resnet_time_scale_shift,
309
- )
310
- self.down_blocks.append(down_block)
311
-
312
- for _ in range(layers_per_block):
313
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
314
- controlnet_block = zero_module(controlnet_block)
315
- self.controlnet_down_blocks.append(controlnet_block)
316
-
317
- if not is_final_block:
318
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
319
- controlnet_block = zero_module(controlnet_block)
320
- self.controlnet_down_blocks.append(controlnet_block)
321
-
322
- # mid
323
- mid_block_channel = block_out_channels[-1]
324
-
325
- controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
326
- controlnet_block = zero_module(controlnet_block)
327
- self.controlnet_mid_block = controlnet_block
328
-
329
- self.mid_block = UNetMidBlock2DCrossAttn(
330
- in_channels=mid_block_channel,
331
- temb_channels=time_embed_dim,
332
- resnet_eps=norm_eps,
333
- resnet_act_fn=act_fn,
334
- output_scale_factor=mid_block_scale_factor,
335
- resnet_time_scale_shift=resnet_time_scale_shift,
336
- cross_attention_dim=cross_attention_dim,
337
- num_attention_heads=num_attention_heads[-1],
338
- resnet_groups=norm_num_groups,
339
- use_linear_projection=use_linear_projection,
340
- upcast_attention=upcast_attention,
341
- )
342
-
343
- @classmethod
344
- def from_unet(
345
- cls,
346
- unet: UNet2DConditionModel,
347
- controlnet_conditioning_channel_order: str = "rgb",
348
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
349
- load_weights_from_unet: bool = True,
350
- ):
351
- r"""
352
- Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
353
-
354
- Parameters:
355
- unet (`UNet2DConditionModel`):
356
- The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
357
- where applicable.
358
- """
359
- controlnet = cls(
360
- in_channels=unet.config.in_channels,
361
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
362
- freq_shift=unet.config.freq_shift,
363
- down_block_types=unet.config.down_block_types,
364
- only_cross_attention=unet.config.only_cross_attention,
365
- block_out_channels=unet.config.block_out_channels,
366
- layers_per_block=unet.config.layers_per_block,
367
- downsample_padding=unet.config.downsample_padding,
368
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
369
- act_fn=unet.config.act_fn,
370
- norm_num_groups=unet.config.norm_num_groups,
371
- norm_eps=unet.config.norm_eps,
372
- cross_attention_dim=unet.config.cross_attention_dim,
373
- attention_head_dim=unet.config.attention_head_dim,
374
- num_attention_heads=unet.config.num_attention_heads,
375
- use_linear_projection=unet.config.use_linear_projection,
376
- class_embed_type=unet.config.class_embed_type,
377
- num_class_embeds=unet.config.num_class_embeds,
378
- upcast_attention=unet.config.upcast_attention,
379
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
380
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
381
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
382
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
383
- )
384
-
385
- if load_weights_from_unet:
386
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
387
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
388
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
389
-
390
- if controlnet.class_embedding:
391
- controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
392
-
393
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
394
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
395
-
396
- return controlnet
397
-
398
- @property
399
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
400
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
401
- r"""
402
- Returns:
403
-             `dict` of attention processors: A dictionary containing all attention processors used in the model,
404
-             indexed by their weight names.
405
- """
406
- # set recursively
407
- processors = {}
408
-
409
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
410
- if hasattr(module, "set_processor"):
411
- processors[f"{name}.processor"] = module.processor
412
-
413
- for sub_name, child in module.named_children():
414
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
415
-
416
- return processors
417
-
418
- for name, module in self.named_children():
419
- fn_recursive_add_processors(name, module, processors)
420
-
421
- return processors
422
-
423
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
424
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
425
- r"""
426
- Sets the attention processor to use to compute attention.
427
-
428
- Parameters:
429
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
430
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
431
- for **all** `Attention` layers.
432
-
433
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
434
- processor. This is strongly recommended when setting trainable attention processors.
435
-
436
- """
437
- count = len(self.attn_processors.keys())
438
-
439
- if isinstance(processor, dict) and len(processor) != count:
440
- raise ValueError(
441
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
442
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
443
- )
444
-
445
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
446
- if hasattr(module, "set_processor"):
447
- if not isinstance(processor, dict):
448
- module.set_processor(processor)
449
- else:
450
- module.set_processor(processor.pop(f"{name}.processor"))
451
-
452
- for sub_name, child in module.named_children():
453
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
454
-
455
- for name, module in self.named_children():
456
- fn_recursive_attn_processor(name, module, processor)
457
-
458
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
459
- def set_default_attn_processor(self):
460
- """
461
- Disables custom attention processors and sets the default attention implementation.
462
- """
463
- self.set_attn_processor(AttnProcessor())
464
-
465
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
466
- def set_attention_slice(self, slice_size):
467
- r"""
468
- Enable sliced attention computation.
469
-
470
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
471
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
472
-
473
- Args:
474
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
475
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
476
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
477
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
478
- must be a multiple of `slice_size`.
479
- """
480
- sliceable_head_dims = []
481
-
482
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
483
- if hasattr(module, "set_attention_slice"):
484
- sliceable_head_dims.append(module.sliceable_head_dim)
485
-
486
- for child in module.children():
487
- fn_recursive_retrieve_sliceable_dims(child)
488
-
489
- # retrieve number of attention layers
490
- for module in self.children():
491
- fn_recursive_retrieve_sliceable_dims(module)
492
-
493
- num_sliceable_layers = len(sliceable_head_dims)
494
-
495
- if slice_size == "auto":
496
- # half the attention head size is usually a good trade-off between
497
- # speed and memory
498
- slice_size = [dim // 2 for dim in sliceable_head_dims]
499
- elif slice_size == "max":
500
- # make smallest slice possible
501
- slice_size = num_sliceable_layers * [1]
502
-
503
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
504
-
505
- if len(slice_size) != len(sliceable_head_dims):
506
- raise ValueError(
507
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
508
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
509
- )
510
-
511
- for i in range(len(slice_size)):
512
- size = slice_size[i]
513
- dim = sliceable_head_dims[i]
514
- if size is not None and size > dim:
515
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
516
-
517
- # Recursively walk through all the children.
518
- # Any children which exposes the set_attention_slice method
519
- # gets the message
520
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
521
- if hasattr(module, "set_attention_slice"):
522
- module.set_attention_slice(slice_size.pop())
523
-
524
- for child in module.children():
525
- fn_recursive_set_attention_slice(child, slice_size)
526
-
527
- reversed_slice_size = list(reversed(slice_size))
528
- for module in self.children():
529
- fn_recursive_set_attention_slice(module, reversed_slice_size)
530
-
531
- def _set_gradient_checkpointing(self, module, value=False):
532
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
533
- module.gradient_checkpointing = value
534
-
535
- def forward(
536
- self,
537
- sample: torch.FloatTensor,
538
- timestep: Union[torch.Tensor, float, int],
539
- encoder_hidden_states: torch.Tensor,
540
- controlnet_cond: torch.FloatTensor,
541
- conditioning_scale: float = 1.0,
542
- class_labels: Optional[torch.Tensor] = None,
543
- timestep_cond: Optional[torch.Tensor] = None,
544
- attention_mask: Optional[torch.Tensor] = None,
545
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
546
- guess_mode: bool = False,
547
- return_dict: bool = True,
548
- ) -> Union[ControlNetOutput, Tuple]:
549
- """
550
- The [`ControlNetModel`] forward method.
551
-
552
- Args:
553
- sample (`torch.FloatTensor`):
554
- The noisy input tensor.
555
- timestep (`Union[torch.Tensor, float, int]`):
556
- The number of timesteps to denoise an input.
557
- encoder_hidden_states (`torch.Tensor`):
558
- The encoder hidden states.
559
- controlnet_cond (`torch.FloatTensor`):
560
-                 The conditional input tensor of shape `(batch_size, channels, height, width)`.
561
- conditioning_scale (`float`, defaults to `1.0`):
562
- The scale factor for ControlNet outputs.
563
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
564
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
565
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
566
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
567
- cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
568
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
569
- guess_mode (`bool`, defaults to `False`):
570
-                 In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
571
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
572
- return_dict (`bool`, defaults to `True`):
573
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
574
-
575
- Returns:
576
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
577
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
578
- returned where the first element is the sample tensor.
579
- """
580
- # check channel order
581
- channel_order = self.config.controlnet_conditioning_channel_order
582
-
583
- if channel_order == "rgb":
584
- # in rgb order by default
585
- ...
586
- elif channel_order == "bgr":
587
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
588
- else:
589
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
590
-
591
- # prepare attention_mask
592
- if attention_mask is not None:
593
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
594
- attention_mask = attention_mask.unsqueeze(1)
595
-
596
- # 1. time
597
- timesteps = timestep
598
- if not torch.is_tensor(timesteps):
599
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
600
- # This would be a good case for the `match` statement (Python 3.10+)
601
- is_mps = sample.device.type == "mps"
602
- if isinstance(timestep, float):
603
- dtype = torch.float32 if is_mps else torch.float64
604
- else:
605
- dtype = torch.int32 if is_mps else torch.int64
606
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
607
- elif len(timesteps.shape) == 0:
608
- timesteps = timesteps[None].to(sample.device)
609
-
610
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
611
- timesteps = timesteps.expand(sample.shape[0])
612
-
613
- t_emb = self.time_proj(timesteps)
614
-
615
- # timesteps does not contain any weights and will always return f32 tensors
616
- # but time_embedding might actually be running in fp16. so we need to cast here.
617
- # there might be better ways to encapsulate this.
618
- t_emb = t_emb.to(dtype=sample.dtype)
619
-
620
- emb = self.time_embedding(t_emb, timestep_cond)
621
-
622
- if self.class_embedding is not None:
623
- if class_labels is None:
624
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
625
-
626
- if self.config.class_embed_type == "timestep":
627
- class_labels = self.time_proj(class_labels)
628
-
629
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
630
- emb = emb + class_emb
631
-
632
- # 2. pre-process
633
- sample = self.conv_in(sample)
634
-
635
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
636
-
637
- sample = sample + controlnet_cond
638
-
639
- # 3. down
640
- down_block_res_samples = (sample,)
641
- for downsample_block in self.down_blocks:
642
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
643
- sample, res_samples = downsample_block(
644
- hidden_states=sample,
645
- temb=emb,
646
- encoder_hidden_states=encoder_hidden_states,
647
- attention_mask=attention_mask,
648
- cross_attention_kwargs=cross_attention_kwargs,
649
- )
650
- else:
651
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
652
-
653
- down_block_res_samples += res_samples
654
-
655
- # 4. mid
656
- if self.mid_block is not None:
657
- sample = self.mid_block(
658
- sample,
659
- emb,
660
- encoder_hidden_states=encoder_hidden_states,
661
- attention_mask=attention_mask,
662
- cross_attention_kwargs=cross_attention_kwargs,
663
- )
664
-
665
- # 5. Control net blocks
666
-
667
- controlnet_down_block_res_samples = ()
668
-
669
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
670
- down_block_res_sample = controlnet_block(down_block_res_sample)
671
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
672
-
673
- down_block_res_samples = controlnet_down_block_res_samples
674
-
675
- mid_block_res_sample = self.controlnet_mid_block(sample)
676
-
677
- # 6. scaling
678
- if guess_mode and not self.config.global_pool_conditions:
679
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
680
-
681
- scales = scales * conditioning_scale
682
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
683
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
684
- else:
685
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
686
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
687
-
688
- if self.config.global_pool_conditions:
689
- down_block_res_samples = [
690
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
691
- ]
692
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
693
-
694
- if not return_dict:
695
- return (down_block_res_samples, mid_block_res_sample)
696
-
697
- return ControlNetOutput(
698
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
699
- )
700
-
701
-
702
- def zero_module(module):
703
- for p in module.parameters():
704
- nn.init.zeros_(p)
705
- return module
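
The guess-mode branch of `ControlNetModel.forward` (step 6 above) is the least obvious part of this file: each down-block residual and the mid-block residual get a geometric weight between 0.1 and 1.0 so deeper features dominate, while the default branch applies one flat `conditioning_scale`. A minimal, self-contained sketch of that weighting in plain PyTorch; the residual count and shapes below are made-up placeholders, not values taken from the file:

import torch

# Stand-ins for the residuals a ControlNet forward pass would produce;
# the count (12 down-block residuals) and shapes are illustrative only.
down_block_res_samples = [torch.randn(1, 320, 64, 64) for _ in range(12)]
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

conditioning_scale = 1.0
guess_mode = True

if guess_mode:
    # One weight per down-block residual plus one for the mid block,
    # spaced geometrically from 0.1 to 1.0 (torch.logspace is base 10).
    scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) * conditioning_scale
    down_block_res_samples = [s * w for s, w in zip(down_block_res_samples, scales)]
    mid_block_res_sample = mid_block_res_sample * scales[-1]
else:
    down_block_res_samples = [s * conditioning_scale for s in down_block_res_samples]
    mid_block_res_sample = mid_block_res_sample * conditioning_scale

print(torch.logspace(-1, 0, 13)[:3])  # tensor([0.1000, 0.1212, 0.1468])
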
4DoF/diffusers/models/controlnet_flax.py DELETED
@@ -1,394 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Optional, Tuple, Union
15
-
16
- import flax
17
- import flax.linen as nn
18
- import jax
19
- import jax.numpy as jnp
20
- from flax.core.frozen_dict import FrozenDict
21
-
22
- from ..configuration_utils import ConfigMixin, flax_register_to_config
23
- from ..utils import BaseOutput
24
- from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
- from .modeling_flax_utils import FlaxModelMixin
26
- from .unet_2d_blocks_flax import (
27
- FlaxCrossAttnDownBlock2D,
28
- FlaxDownBlock2D,
29
- FlaxUNetMidBlock2DCrossAttn,
30
- )
31
-
32
-
33
- @flax.struct.dataclass
34
- class FlaxControlNetOutput(BaseOutput):
35
- """
36
- The output of [`FlaxControlNetModel`].
37
-
38
- Args:
39
- down_block_res_samples (`jnp.ndarray`):
40
- mid_block_res_sample (`jnp.ndarray`):
41
- """
42
-
43
- down_block_res_samples: jnp.ndarray
44
- mid_block_res_sample: jnp.ndarray
45
-
46
-
47
- class FlaxControlNetConditioningEmbedding(nn.Module):
48
- conditioning_embedding_channels: int
49
- block_out_channels: Tuple[int] = (16, 32, 96, 256)
50
- dtype: jnp.dtype = jnp.float32
51
-
52
- def setup(self):
53
- self.conv_in = nn.Conv(
54
- self.block_out_channels[0],
55
- kernel_size=(3, 3),
56
- padding=((1, 1), (1, 1)),
57
- dtype=self.dtype,
58
- )
59
-
60
- blocks = []
61
- for i in range(len(self.block_out_channels) - 1):
62
- channel_in = self.block_out_channels[i]
63
- channel_out = self.block_out_channels[i + 1]
64
- conv1 = nn.Conv(
65
- channel_in,
66
- kernel_size=(3, 3),
67
- padding=((1, 1), (1, 1)),
68
- dtype=self.dtype,
69
- )
70
- blocks.append(conv1)
71
- conv2 = nn.Conv(
72
- channel_out,
73
- kernel_size=(3, 3),
74
- strides=(2, 2),
75
- padding=((1, 1), (1, 1)),
76
- dtype=self.dtype,
77
- )
78
- blocks.append(conv2)
79
- self.blocks = blocks
80
-
81
- self.conv_out = nn.Conv(
82
- self.conditioning_embedding_channels,
83
- kernel_size=(3, 3),
84
- padding=((1, 1), (1, 1)),
85
- kernel_init=nn.initializers.zeros_init(),
86
- bias_init=nn.initializers.zeros_init(),
87
- dtype=self.dtype,
88
- )
89
-
90
- def __call__(self, conditioning):
91
- embedding = self.conv_in(conditioning)
92
- embedding = nn.silu(embedding)
93
-
94
- for block in self.blocks:
95
- embedding = block(embedding)
96
- embedding = nn.silu(embedding)
97
-
98
- embedding = self.conv_out(embedding)
99
-
100
- return embedding
101
-
102
-
103
- @flax_register_to_config
104
- class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
105
- r"""
106
- A ControlNet model.
107
-
108
-     This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
109
- implemented for all models (such as downloading or saving).
110
-
111
- This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
- subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
113
- general usage and behavior.
114
-
115
- Inherent JAX features such as the following are supported:
116
-
117
- - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
118
- - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
119
- - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
120
- - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
121
-
122
- Parameters:
123
- sample_size (`int`, *optional*):
124
- The size of the input sample.
125
- in_channels (`int`, *optional*, defaults to 4):
126
- The number of channels in the input sample.
127
- down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
- The tuple of downsample blocks to use.
129
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
130
- The tuple of output channels for each block.
131
- layers_per_block (`int`, *optional*, defaults to 2):
132
- The number of layers per block.
133
- attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
134
- The dimension of the attention heads.
135
- num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
- The number of attention heads.
137
- cross_attention_dim (`int`, *optional*, defaults to 768):
138
- The dimension of the cross attention features.
139
- dropout (`float`, *optional*, defaults to 0):
140
- Dropout probability for down, up and bottleneck blocks.
141
- flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
142
- Whether to flip the sin to cos in the time embedding.
143
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
144
- controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
145
-             The channel order of the conditional image. Will convert to `rgb` if it's `bgr`.
146
- conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
- The tuple of output channel for each block in the `conditioning_embedding` layer.
148
- """
149
- sample_size: int = 32
150
- in_channels: int = 4
151
- down_block_types: Tuple[str] = (
152
- "CrossAttnDownBlock2D",
153
- "CrossAttnDownBlock2D",
154
- "CrossAttnDownBlock2D",
155
- "DownBlock2D",
156
- )
157
- only_cross_attention: Union[bool, Tuple[bool]] = False
158
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
159
- layers_per_block: int = 2
160
- attention_head_dim: Union[int, Tuple[int]] = 8
161
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None
162
- cross_attention_dim: int = 1280
163
- dropout: float = 0.0
164
- use_linear_projection: bool = False
165
- dtype: jnp.dtype = jnp.float32
166
- flip_sin_to_cos: bool = True
167
- freq_shift: int = 0
168
- controlnet_conditioning_channel_order: str = "rgb"
169
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170
-
171
- def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
172
- # init input tensors
173
- sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
174
- sample = jnp.zeros(sample_shape, dtype=jnp.float32)
175
- timesteps = jnp.ones((1,), dtype=jnp.int32)
176
- encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
177
- controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
178
- controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
179
-
180
- params_rng, dropout_rng = jax.random.split(rng)
181
- rngs = {"params": params_rng, "dropout": dropout_rng}
182
-
183
- return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
184
-
185
- def setup(self):
186
- block_out_channels = self.block_out_channels
187
- time_embed_dim = block_out_channels[0] * 4
188
-
189
- # If `num_attention_heads` is not defined (which is the case for most models)
190
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
191
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
192
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
193
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
194
- # which is why we correct for the naming here.
195
- num_attention_heads = self.num_attention_heads or self.attention_head_dim
196
-
197
- # input
198
- self.conv_in = nn.Conv(
199
- block_out_channels[0],
200
- kernel_size=(3, 3),
201
- strides=(1, 1),
202
- padding=((1, 1), (1, 1)),
203
- dtype=self.dtype,
204
- )
205
-
206
- # time
207
- self.time_proj = FlaxTimesteps(
208
- block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
209
- )
210
- self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
211
-
212
- self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
213
- conditioning_embedding_channels=block_out_channels[0],
214
- block_out_channels=self.conditioning_embedding_out_channels,
215
- )
216
-
217
- only_cross_attention = self.only_cross_attention
218
- if isinstance(only_cross_attention, bool):
219
- only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
220
-
221
- if isinstance(num_attention_heads, int):
222
- num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
223
-
224
- # down
225
- down_blocks = []
226
- controlnet_down_blocks = []
227
-
228
- output_channel = block_out_channels[0]
229
-
230
- controlnet_block = nn.Conv(
231
- output_channel,
232
- kernel_size=(1, 1),
233
- padding="VALID",
234
- kernel_init=nn.initializers.zeros_init(),
235
- bias_init=nn.initializers.zeros_init(),
236
- dtype=self.dtype,
237
- )
238
- controlnet_down_blocks.append(controlnet_block)
239
-
240
- for i, down_block_type in enumerate(self.down_block_types):
241
- input_channel = output_channel
242
- output_channel = block_out_channels[i]
243
- is_final_block = i == len(block_out_channels) - 1
244
-
245
- if down_block_type == "CrossAttnDownBlock2D":
246
- down_block = FlaxCrossAttnDownBlock2D(
247
- in_channels=input_channel,
248
- out_channels=output_channel,
249
- dropout=self.dropout,
250
- num_layers=self.layers_per_block,
251
- num_attention_heads=num_attention_heads[i],
252
- add_downsample=not is_final_block,
253
- use_linear_projection=self.use_linear_projection,
254
- only_cross_attention=only_cross_attention[i],
255
- dtype=self.dtype,
256
- )
257
- else:
258
- down_block = FlaxDownBlock2D(
259
- in_channels=input_channel,
260
- out_channels=output_channel,
261
- dropout=self.dropout,
262
- num_layers=self.layers_per_block,
263
- add_downsample=not is_final_block,
264
- dtype=self.dtype,
265
- )
266
-
267
- down_blocks.append(down_block)
268
-
269
- for _ in range(self.layers_per_block):
270
- controlnet_block = nn.Conv(
271
- output_channel,
272
- kernel_size=(1, 1),
273
- padding="VALID",
274
- kernel_init=nn.initializers.zeros_init(),
275
- bias_init=nn.initializers.zeros_init(),
276
- dtype=self.dtype,
277
- )
278
- controlnet_down_blocks.append(controlnet_block)
279
-
280
- if not is_final_block:
281
- controlnet_block = nn.Conv(
282
- output_channel,
283
- kernel_size=(1, 1),
284
- padding="VALID",
285
- kernel_init=nn.initializers.zeros_init(),
286
- bias_init=nn.initializers.zeros_init(),
287
- dtype=self.dtype,
288
- )
289
- controlnet_down_blocks.append(controlnet_block)
290
-
291
- self.down_blocks = down_blocks
292
- self.controlnet_down_blocks = controlnet_down_blocks
293
-
294
- # mid
295
- mid_block_channel = block_out_channels[-1]
296
- self.mid_block = FlaxUNetMidBlock2DCrossAttn(
297
- in_channels=mid_block_channel,
298
- dropout=self.dropout,
299
- num_attention_heads=num_attention_heads[-1],
300
- use_linear_projection=self.use_linear_projection,
301
- dtype=self.dtype,
302
- )
303
-
304
- self.controlnet_mid_block = nn.Conv(
305
- mid_block_channel,
306
- kernel_size=(1, 1),
307
- padding="VALID",
308
- kernel_init=nn.initializers.zeros_init(),
309
- bias_init=nn.initializers.zeros_init(),
310
- dtype=self.dtype,
311
- )
312
-
313
- def __call__(
314
- self,
315
- sample,
316
- timesteps,
317
- encoder_hidden_states,
318
- controlnet_cond,
319
- conditioning_scale: float = 1.0,
320
- return_dict: bool = True,
321
- train: bool = False,
322
- ) -> Union[FlaxControlNetOutput, Tuple]:
323
- r"""
324
- Args:
325
- sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
326
- timestep (`jnp.ndarray` or `float` or `int`): timesteps
327
- encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
328
- controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
329
- conditioning_scale: (`float`) the scale factor for controlnet outputs
330
- return_dict (`bool`, *optional*, defaults to `True`):
331
-                 Whether or not to return a [`models.controlnet_flax.FlaxControlNetOutput`] instead of a
332
- plain tuple.
333
- train (`bool`, *optional*, defaults to `False`):
334
- Use deterministic functions and disable dropout when not training.
335
-
336
- Returns:
337
-             [`~models.controlnet_flax.FlaxControlNetOutput`] or `tuple`:
338
-             [`~models.controlnet_flax.FlaxControlNetOutput`] if `return_dict` is True, otherwise a `tuple`.
339
- When returning a tuple, the first element is the sample tensor.
340
- """
341
- channel_order = self.controlnet_conditioning_channel_order
342
- if channel_order == "bgr":
343
- controlnet_cond = jnp.flip(controlnet_cond, axis=1)
344
-
345
- # 1. time
346
- if not isinstance(timesteps, jnp.ndarray):
347
- timesteps = jnp.array([timesteps], dtype=jnp.int32)
348
- elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
349
- timesteps = timesteps.astype(dtype=jnp.float32)
350
- timesteps = jnp.expand_dims(timesteps, 0)
351
-
352
- t_emb = self.time_proj(timesteps)
353
- t_emb = self.time_embedding(t_emb)
354
-
355
- # 2. pre-process
356
- sample = jnp.transpose(sample, (0, 2, 3, 1))
357
- sample = self.conv_in(sample)
358
-
359
- controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
360
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
361
- sample += controlnet_cond
362
-
363
- # 3. down
364
- down_block_res_samples = (sample,)
365
- for down_block in self.down_blocks:
366
- if isinstance(down_block, FlaxCrossAttnDownBlock2D):
367
- sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
368
- else:
369
- sample, res_samples = down_block(sample, t_emb, deterministic=not train)
370
- down_block_res_samples += res_samples
371
-
372
- # 4. mid
373
- sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
374
-
375
-         # 5. controlnet blocks
376
- controlnet_down_block_res_samples = ()
377
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
378
- down_block_res_sample = controlnet_block(down_block_res_sample)
379
- controlnet_down_block_res_samples += (down_block_res_sample,)
380
-
381
- down_block_res_samples = controlnet_down_block_res_samples
382
-
383
- mid_block_res_sample = self.controlnet_mid_block(sample)
384
-
385
- # 6. scaling
386
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
387
- mid_block_res_sample *= conditioning_scale
388
-
389
- if not return_dict:
390
- return (down_block_res_samples, mid_block_res_sample)
391
-
392
- return FlaxControlNetOutput(
393
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
394
- )
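
For orientation, `init_weights` above builds dummy inputs whose shapes encode an important convention: `sample` lives in latent space while `controlnet_cond` is the pixel-space conditioning image, 8x larger per side (the VAE downsampling factor). A standalone shape sketch using only jax/jax.numpy; the sizes are the class defaults and nothing here touches the model itself:

import jax
import jax.numpy as jnp

sample_size, in_channels, cross_attention_dim = 32, 4, 1280  # class defaults

# Latent-space UNet input vs. pixel-space conditioning image (8x per side).
sample = jnp.zeros((1, in_channels, sample_size, sample_size), dtype=jnp.float32)
controlnet_cond = jnp.zeros((1, 3, sample_size * 8, sample_size * 8), dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 1, cross_attention_dim), dtype=jnp.float32)

# init_weights splits one RNG key into the two streams Flax expects.
params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": params_rng, "dropout": dropout_rng}

print(sample.shape, controlnet_cond.shape)  # (1, 4, 32, 32) (1, 3, 256, 256)
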
4DoF/diffusers/models/cross_attention.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from ..utils import deprecate
15
- from .attention_processor import ( # noqa: F401
16
- Attention,
17
- AttentionProcessor,
18
- AttnAddedKVProcessor,
19
- AttnProcessor2_0,
20
- LoRAAttnProcessor,
21
- LoRALinearLayer,
22
- LoRAXFormersAttnProcessor,
23
- SlicedAttnAddedKVProcessor,
24
- SlicedAttnProcessor,
25
- XFormersAttnProcessor,
26
- )
27
- from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
28
-
29
-
30
- deprecate(
31
- "cross_attention",
32
- "0.20.0",
33
- "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
34
- standard_warn=False,
35
- )
36
-
37
-
38
- AttnProcessor = AttentionProcessor
39
-
40
-
41
- class CrossAttention(Attention):
42
- def __init__(self, *args, **kwargs):
43
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
44
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
45
- super().__init__(*args, **kwargs)
46
-
47
-
48
- class CrossAttnProcessor(AttnProcessorRename):
49
- def __init__(self, *args, **kwargs):
50
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
51
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
52
- super().__init__(*args, **kwargs)
53
-
54
-
55
- class LoRACrossAttnProcessor(LoRAAttnProcessor):
56
- def __init__(self, *args, **kwargs):
57
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
58
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
59
- super().__init__(*args, **kwargs)
60
-
61
-
62
- class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
63
- def __init__(self, *args, **kwargs):
64
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
65
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
66
- super().__init__(*args, **kwargs)
67
-
68
-
69
- class XFormersCrossAttnProcessor(XFormersAttnProcessor):
70
- def __init__(self, *args, **kwargs):
71
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
72
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
73
- super().__init__(*args, **kwargs)
74
-
75
-
76
- class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
77
- def __init__(self, *args, **kwargs):
78
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
79
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
80
- super().__init__(*args, **kwargs)
81
-
82
-
83
- class SlicedCrossAttnProcessor(SlicedAttnProcessor):
84
- def __init__(self, *args, **kwargs):
85
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
86
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
87
- super().__init__(*args, **kwargs)
88
-
89
-
90
- class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
91
- def __init__(self, *args, **kwargs):
92
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
93
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
94
- super().__init__(*args, **kwargs)
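
Every class in this shim follows the same pattern: subclass the renamed processor, emit a deprecation warning in `__init__`, then defer to the parent so behavior is unchanged. A generic sketch of that pattern using only the standard library; the class names here are hypothetical stand-ins, not diffusers APIs:

import warnings


class NewAttnProcessor:
    """Stand-in for the class that replaced the deprecated name."""

    def __call__(self, *args, **kwargs):
        return "processed"


class OldCrossAttnProcessor(NewAttnProcessor):
    """Deprecated alias: warns on construction, then behaves exactly like the parent."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            f"{self.__class__.__name__} is deprecated; use NewAttnProcessor instead.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


proc = OldCrossAttnProcessor()  # emits a FutureWarning once per construction
print(proc())                   # "processed" -- identical behavior to the parent
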
4DoF/diffusers/models/dual_transformer_2d.py DELETED
@@ -1,151 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Optional
15
-
16
- from torch import nn
17
-
18
- from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
-
20
-
21
- class DualTransformer2DModel(nn.Module):
22
- """
23
- Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
-
25
- Parameters:
26
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
- in_channels (`int`, *optional*):
29
- Pass if the input is continuous. The number of channels in the input and output.
30
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
- dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
- cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
- `ImagePositionalEmbeddings`.
36
- num_vector_embeds (`int`, *optional*):
37
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
- Includes the class for the masked latent pixel.
39
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
- up to but not more than steps than `num_embeds_ada_norm`.
44
- attention_bias (`bool`, *optional*):
45
- Configure if the TransformerBlocks' attention should contain a bias parameter.
46
- """
47
-
48
- def __init__(
49
- self,
50
- num_attention_heads: int = 16,
51
- attention_head_dim: int = 88,
52
- in_channels: Optional[int] = None,
53
- num_layers: int = 1,
54
- dropout: float = 0.0,
55
- norm_num_groups: int = 32,
56
- cross_attention_dim: Optional[int] = None,
57
- attention_bias: bool = False,
58
- sample_size: Optional[int] = None,
59
- num_vector_embeds: Optional[int] = None,
60
- activation_fn: str = "geglu",
61
- num_embeds_ada_norm: Optional[int] = None,
62
- ):
63
- super().__init__()
64
- self.transformers = nn.ModuleList(
65
- [
66
- Transformer2DModel(
67
- num_attention_heads=num_attention_heads,
68
- attention_head_dim=attention_head_dim,
69
- in_channels=in_channels,
70
- num_layers=num_layers,
71
- dropout=dropout,
72
- norm_num_groups=norm_num_groups,
73
- cross_attention_dim=cross_attention_dim,
74
- attention_bias=attention_bias,
75
- sample_size=sample_size,
76
- num_vector_embeds=num_vector_embeds,
77
- activation_fn=activation_fn,
78
- num_embeds_ada_norm=num_embeds_ada_norm,
79
- )
80
- for _ in range(2)
81
- ]
82
- )
83
-
84
- # Variables that can be set by a pipeline:
85
-
86
- # The ratio of transformer1 to transformer2's output states to be combined during inference
87
- self.mix_ratio = 0.5
88
-
89
- # The shape of `encoder_hidden_states` is expected to be
90
- # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
- self.condition_lengths = [77, 257]
92
-
93
- # Which transformer to use to encode which condition.
94
- # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
- self.transformer_index_for_condition = [1, 0]
96
-
97
- def forward(
98
- self,
99
- hidden_states,
100
- encoder_hidden_states,
101
- timestep=None,
102
- attention_mask=None,
103
- cross_attention_kwargs=None,
104
- return_dict: bool = True,
105
- ):
106
- """
107
- Args:
108
- hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
- When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
- hidden_states
111
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
- self-attention.
114
- timestep ( `torch.long`, *optional*):
115
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
- attention_mask (`torch.FloatTensor`, *optional*):
117
- Optional attention mask to be applied in Attention
118
- return_dict (`bool`, *optional*, defaults to `True`):
119
-                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
120
-
121
- Returns:
122
- [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
- [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
- returning a tuple, the first element is the sample tensor.
125
- """
126
- input_states = hidden_states
127
-
128
- encoded_states = []
129
- tokens_start = 0
130
- # attention_mask is not used yet
131
- for i in range(2):
132
- # for each of the two transformers, pass the corresponding condition tokens
133
- condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
- transformer_index = self.transformer_index_for_condition[i]
135
- encoded_state = self.transformers[transformer_index](
136
- input_states,
137
- encoder_hidden_states=condition_state,
138
- timestep=timestep,
139
- cross_attention_kwargs=cross_attention_kwargs,
140
- return_dict=False,
141
- )[0]
142
- encoded_states.append(encoded_state - input_states)
143
- tokens_start += self.condition_lengths[i]
144
-
145
- output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
- output_states = output_states + input_states
147
-
148
- if not return_dict:
149
- return (output_states,)
150
-
151
- return Transformer2DModelOutput(sample=output_states)
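
The interesting part of `DualTransformer2DModel.forward` is the mixing rule: the token axis of `encoder_hidden_states` is split by `condition_lengths`, each slice is routed to the transformer picked by `transformer_index_for_condition`, and the two residuals are blended with `mix_ratio` before the input is added back. A toy sketch of just that bookkeeping; the two `nn.Linear` branches stand in for the real transformers and ignore the condition slice:

import torch

torch.manual_seed(0)
# Toy stand-ins for the two Transformer2DModel branches.
transformers = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]

mix_ratio = 0.5
condition_lengths = [77, 257]               # e.g. text tokens, image tokens
transformer_index_for_condition = [1, 0]    # condition 0 -> branch 1, condition 1 -> branch 0

hidden_states = torch.randn(2, 16, 8)                             # (batch, tokens, features)
encoder_hidden_states = torch.randn(2, sum(condition_lengths), 8)

deltas, start = [], 0
for i in range(2):
    cond = encoder_hidden_states[:, start : start + condition_lengths[i]]
    branch = transformers[transformer_index_for_condition[i]]
    encoded = branch(hidden_states)         # real branches cross-attend to `cond`; the toy ones ignore it
    deltas.append(encoded - hidden_states)  # keep only the residual, as in forward()
    start += condition_lengths[i]

output = deltas[0] * mix_ratio + deltas[1] * (1 - mix_ratio) + hidden_states
print(output.shape)  # torch.Size([2, 16, 8])
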
4DoF/diffusers/models/embeddings.py DELETED
@@ -1,546 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
- from typing import Optional
16
-
17
- import numpy as np
18
- import torch
19
- from torch import nn
20
-
21
- from .activations import get_activation
22
-
23
-
24
- def get_timestep_embedding(
25
- timesteps: torch.Tensor,
26
- embedding_dim: int,
27
- flip_sin_to_cos: bool = False,
28
- downscale_freq_shift: float = 1,
29
- scale: float = 1,
30
- max_period: int = 10000,
31
- ):
32
- """
33
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
34
-
35
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
36
- These may be fractional.
37
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
38
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
39
- """
40
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
41
-
42
- half_dim = embedding_dim // 2
43
- exponent = -math.log(max_period) * torch.arange(
44
- start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
45
- )
46
- exponent = exponent / (half_dim - downscale_freq_shift)
47
-
48
- emb = torch.exp(exponent)
49
- emb = timesteps[:, None].float() * emb[None, :]
50
-
51
- # scale embeddings
52
- emb = scale * emb
53
-
54
- # concat sine and cosine embeddings
55
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
56
-
57
- # flip sine and cosine embeddings
58
- if flip_sin_to_cos:
59
- emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
60
-
61
- # zero pad
62
- if embedding_dim % 2 == 1:
63
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
64
- return emb
65
-
66
-
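
To make the shape conventions of `get_timestep_embedding` concrete, here is a standalone recomputation of the same sinusoid for two timesteps with `embedding_dim=8` and the default `downscale_freq_shift=1`, `flip_sin_to_cos=False`; the values are illustrative only:

import math
import torch

timesteps = torch.tensor([0.0, 500.0])                 # N = 2
embedding_dim, max_period, downscale_freq_shift = 8, 10000, 1

half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
freqs = torch.exp(exponent)                            # (half_dim,) decreasing frequencies
args = timesteps[:, None] * freqs[None, :]             # (N, half_dim)
emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

print(emb.shape)   # torch.Size([2, 8]) -> [N, embedding_dim]
print(emb[0])      # t=0: all sines are 0, all cosines are 1
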
67
- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
68
- """
69
- grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
70
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71
- """
72
- grid_h = np.arange(grid_size, dtype=np.float32)
73
- grid_w = np.arange(grid_size, dtype=np.float32)
74
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
75
- grid = np.stack(grid, axis=0)
76
-
77
- grid = grid.reshape([2, 1, grid_size, grid_size])
78
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79
- if cls_token and extra_tokens > 0:
80
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
81
- return pos_embed
82
-
83
-
84
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
85
- if embed_dim % 2 != 0:
86
- raise ValueError("embed_dim must be divisible by 2")
87
-
88
- # use half of dimensions to encode grid_h
89
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
90
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
91
-
92
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
93
- return emb
94
-
95
-
96
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
97
- """
98
- embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
99
- """
100
- if embed_dim % 2 != 0:
101
- raise ValueError("embed_dim must be divisible by 2")
102
-
103
- omega = np.arange(embed_dim // 2, dtype=np.float64)
104
- omega /= embed_dim / 2.0
105
- omega = 1.0 / 10000**omega # (D/2,)
106
-
107
- pos = pos.reshape(-1) # (M,)
108
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
109
-
110
- emb_sin = np.sin(out) # (M, D/2)
111
- emb_cos = np.cos(out) # (M, D/2)
112
-
113
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
114
- return emb
115
-
116
-
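
The three helpers above compose as follows: the 2-D embedding runs the 1-D sin/cos embedding over the height grid and the width grid separately, using half of `embed_dim` for each, then concatenates the halves. A compact numpy sketch of that flow; the sizes are arbitrary and chosen only for the printed shape:

import numpy as np

def sincos_1d(embed_dim, pos):
    # Mirror of get_1d_sincos_pos_embed_from_grid above, kept minimal.
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega                       # (D/2,)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)

embed_dim, grid_size = 16, 4
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)             # (2, H, W), w goes first

emb_h = sincos_1d(embed_dim // 2, grid[0])           # (H*W, D/2)
emb_w = sincos_1d(embed_dim // 2, grid[1])           # (H*W, D/2)
pos_embed = np.concatenate([emb_h, emb_w], axis=1)

print(pos_embed.shape)   # (16, 16) == (grid_size * grid_size, embed_dim)
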
117
- class PatchEmbed(nn.Module):
118
- """2D Image to Patch Embedding"""
119
-
120
- def __init__(
121
- self,
122
- height=224,
123
- width=224,
124
- patch_size=16,
125
- in_channels=3,
126
- embed_dim=768,
127
- layer_norm=False,
128
- flatten=True,
129
- bias=True,
130
- ):
131
- super().__init__()
132
-
133
- num_patches = (height // patch_size) * (width // patch_size)
134
- self.flatten = flatten
135
- self.layer_norm = layer_norm
136
-
137
- self.proj = nn.Conv2d(
138
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
139
- )
140
- if layer_norm:
141
- self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
142
- else:
143
- self.norm = None
144
-
145
- pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
146
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
147
-
148
- def forward(self, latent):
149
- latent = self.proj(latent)
150
- if self.flatten:
151
- latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
152
- if self.layer_norm:
153
- latent = self.norm(latent)
154
- return latent + self.pos_embed
155
-
156
-
157
- class TimestepEmbedding(nn.Module):
158
- def __init__(
159
- self,
160
- in_channels: int,
161
- time_embed_dim: int,
162
- act_fn: str = "silu",
163
- out_dim: int = None,
164
- post_act_fn: Optional[str] = None,
165
- cond_proj_dim=None,
166
- ):
167
- super().__init__()
168
-
169
- self.linear_1 = nn.Linear(in_channels, time_embed_dim)
170
-
171
- if cond_proj_dim is not None:
172
- self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
173
- else:
174
- self.cond_proj = None
175
-
176
- self.act = get_activation(act_fn)
177
-
178
- if out_dim is not None:
179
- time_embed_dim_out = out_dim
180
- else:
181
- time_embed_dim_out = time_embed_dim
182
- self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
183
-
184
- if post_act_fn is None:
185
- self.post_act = None
186
- else:
187
- self.post_act = get_activation(post_act_fn)
188
-
189
- def forward(self, sample, condition=None):
190
- if condition is not None:
191
- sample = sample + self.cond_proj(condition)
192
- sample = self.linear_1(sample)
193
-
194
- if self.act is not None:
195
- sample = self.act(sample)
196
-
197
- sample = self.linear_2(sample)
198
-
199
- if self.post_act is not None:
200
- sample = self.post_act(sample)
201
- return sample
202
-
203
-
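
`TimestepEmbedding` above is, in the common case (no `cond_proj_dim`, no `post_act_fn`), just a two-layer MLP around an activation. A shape walk-through with illustrative Stable-Diffusion-like sizes (320 sinusoidal channels projected to a 1280-dim embedding; the numbers are assumptions, not read from this file):

import torch
from torch import nn

in_channels, time_embed_dim = 320, 1280   # illustrative sizes

# Equivalent of linear_1 -> act("silu") -> linear_2 in the default configuration.
mlp = nn.Sequential(
    nn.Linear(in_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)

t_emb = torch.randn(2, in_channels)       # what Timesteps / get_timestep_embedding would produce
print(mlp(t_emb).shape)                   # torch.Size([2, 1280])
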
204
- class Timesteps(nn.Module):
205
- def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
206
- super().__init__()
207
- self.num_channels = num_channels
208
- self.flip_sin_to_cos = flip_sin_to_cos
209
- self.downscale_freq_shift = downscale_freq_shift
210
-
211
- def forward(self, timesteps):
212
- t_emb = get_timestep_embedding(
213
- timesteps,
214
- self.num_channels,
215
- flip_sin_to_cos=self.flip_sin_to_cos,
216
- downscale_freq_shift=self.downscale_freq_shift,
217
- )
218
- return t_emb
219
-
220
-
221
- class GaussianFourierProjection(nn.Module):
222
- """Gaussian Fourier embeddings for noise levels."""
223
-
224
- def __init__(
225
- self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
226
- ):
227
- super().__init__()
228
- self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
229
- self.log = log
230
- self.flip_sin_to_cos = flip_sin_to_cos
231
-
232
- if set_W_to_weight:
233
- # to delete later
234
- self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
235
-
236
- self.weight = self.W
237
-
238
- def forward(self, x):
239
- if self.log:
240
- x = torch.log(x)
241
-
242
- x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
243
-
244
- if self.flip_sin_to_cos:
245
- out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
246
- else:
247
- out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
248
- return out
249
-
250
-
251
- class ImagePositionalEmbeddings(nn.Module):
252
- """
253
- Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
254
- height and width of the latent space.
255
-
256
- For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
257
-
258
- For VQ-diffusion:
259
-
260
- Output vector embeddings are used as input for the transformer.
261
-
262
- Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
263
-
264
- Args:
265
- num_embed (`int`):
266
- Number of embeddings for the latent pixels embeddings.
267
- height (`int`):
268
- Height of the latent image i.e. the number of height embeddings.
269
- width (`int`):
270
- Width of the latent image i.e. the number of width embeddings.
271
- embed_dim (`int`):
272
- Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
273
- """
274
-
275
- def __init__(
276
- self,
277
- num_embed: int,
278
- height: int,
279
- width: int,
280
- embed_dim: int,
281
- ):
282
- super().__init__()
283
-
284
- self.height = height
285
- self.width = width
286
- self.num_embed = num_embed
287
- self.embed_dim = embed_dim
288
-
289
- self.emb = nn.Embedding(self.num_embed, embed_dim)
290
- self.height_emb = nn.Embedding(self.height, embed_dim)
291
- self.width_emb = nn.Embedding(self.width, embed_dim)
292
-
293
- def forward(self, index):
294
- emb = self.emb(index)
295
-
296
- height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
297
-
298
- # 1 x H x D -> 1 x H x 1 x D
299
- height_emb = height_emb.unsqueeze(2)
300
-
301
- width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
302
-
303
- # 1 x W x D -> 1 x 1 x W x D
304
- width_emb = width_emb.unsqueeze(1)
305
-
306
- pos_emb = height_emb + width_emb
307
-
308
- # 1 x H x W x D -> 1 x L x D
309
- pos_emb = pos_emb.view(1, self.height * self.width, -1)
310
-
311
- emb = emb + pos_emb[:, : emb.shape[1], :]
312
-
313
- return emb
314
-
315
-
316
- class LabelEmbedding(nn.Module):
317
- """
318
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
319
-
320
- Args:
321
- num_classes (`int`): The number of classes.
322
- hidden_size (`int`): The size of the vector embeddings.
323
- dropout_prob (`float`): The probability of dropping a label.
324
- """
325
-
326
- def __init__(self, num_classes, hidden_size, dropout_prob):
327
- super().__init__()
328
- use_cfg_embedding = dropout_prob > 0
329
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
330
- self.num_classes = num_classes
331
- self.dropout_prob = dropout_prob
332
-
333
- def token_drop(self, labels, force_drop_ids=None):
334
- """
335
- Drops labels to enable classifier-free guidance.
336
- """
337
- if force_drop_ids is None:
338
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
339
- else:
340
- drop_ids = torch.tensor(force_drop_ids == 1)
341
- labels = torch.where(drop_ids, self.num_classes, labels)
342
- return labels
343
-
344
- def forward(self, labels: torch.LongTensor, force_drop_ids=None):
345
- use_dropout = self.dropout_prob > 0
346
- if (self.training and use_dropout) or (force_drop_ids is not None):
347
- labels = self.token_drop(labels, force_drop_ids)
348
- embeddings = self.embedding_table(labels)
349
- return embeddings
350
-
351
-
352
- class TextImageProjection(nn.Module):
353
- def __init__(
354
- self,
355
- text_embed_dim: int = 1024,
356
- image_embed_dim: int = 768,
357
- cross_attention_dim: int = 768,
358
- num_image_text_embeds: int = 10,
359
- ):
360
- super().__init__()
361
-
362
- self.num_image_text_embeds = num_image_text_embeds
363
- self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
364
- self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
365
-
366
- def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
367
- batch_size = text_embeds.shape[0]
368
-
369
- # image
370
- image_text_embeds = self.image_embeds(image_embeds)
371
- image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
372
-
373
- # text
374
- text_embeds = self.text_proj(text_embeds)
375
-
376
- return torch.cat([image_text_embeds, text_embeds], dim=1)
377
-
378
-
379
- class ImageProjection(nn.Module):
380
- def __init__(
381
- self,
382
- image_embed_dim: int = 768,
383
- cross_attention_dim: int = 768,
384
- num_image_text_embeds: int = 32,
385
- ):
386
- super().__init__()
387
-
388
- self.num_image_text_embeds = num_image_text_embeds
389
- self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
390
- self.norm = nn.LayerNorm(cross_attention_dim)
391
-
392
- def forward(self, image_embeds: torch.FloatTensor):
393
- batch_size = image_embeds.shape[0]
394
-
395
- # image
396
- image_embeds = self.image_embeds(image_embeds)
397
- image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
398
- image_embeds = self.norm(image_embeds)
399
- return image_embeds
400
-
401
-
402
- class CombinedTimestepLabelEmbeddings(nn.Module):
403
- def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
404
- super().__init__()
405
-
406
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
407
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
408
- self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
409
-
410
- def forward(self, timestep, class_labels, hidden_dtype=None):
411
- timesteps_proj = self.time_proj(timestep)
412
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
413
-
414
- class_labels = self.class_embedder(class_labels) # (N, D)
415
-
416
- conditioning = timesteps_emb + class_labels # (N, D)
417
-
418
- return conditioning
419
-
420
-
421
- class TextTimeEmbedding(nn.Module):
422
- def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
423
- super().__init__()
424
- self.norm1 = nn.LayerNorm(encoder_dim)
425
- self.pool = AttentionPooling(num_heads, encoder_dim)
426
- self.proj = nn.Linear(encoder_dim, time_embed_dim)
427
- self.norm2 = nn.LayerNorm(time_embed_dim)
428
-
429
- def forward(self, hidden_states):
430
- hidden_states = self.norm1(hidden_states)
431
- hidden_states = self.pool(hidden_states)
432
- hidden_states = self.proj(hidden_states)
433
- hidden_states = self.norm2(hidden_states)
434
- return hidden_states
435
-
436
-
437
- class TextImageTimeEmbedding(nn.Module):
438
- def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
439
- super().__init__()
440
- self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
441
- self.text_norm = nn.LayerNorm(time_embed_dim)
442
- self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
443
-
444
- def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
445
- # text
446
- time_text_embeds = self.text_proj(text_embeds)
447
- time_text_embeds = self.text_norm(time_text_embeds)
448
-
449
- # image
450
- time_image_embeds = self.image_proj(image_embeds)
451
-
452
- return time_image_embeds + time_text_embeds
453
-
454
-
455
- class ImageTimeEmbedding(nn.Module):
456
- def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
457
- super().__init__()
458
- self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
459
- self.image_norm = nn.LayerNorm(time_embed_dim)
460
-
461
- def forward(self, image_embeds: torch.FloatTensor):
462
- # image
463
- time_image_embeds = self.image_proj(image_embeds)
464
- time_image_embeds = self.image_norm(time_image_embeds)
465
- return time_image_embeds
466
-
467
-
468
- class ImageHintTimeEmbedding(nn.Module):
469
- def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
470
- super().__init__()
471
- self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
472
- self.image_norm = nn.LayerNorm(time_embed_dim)
473
- self.input_hint_block = nn.Sequential(
474
- nn.Conv2d(3, 16, 3, padding=1),
475
- nn.SiLU(),
476
- nn.Conv2d(16, 16, 3, padding=1),
477
- nn.SiLU(),
478
- nn.Conv2d(16, 32, 3, padding=1, stride=2),
479
- nn.SiLU(),
480
- nn.Conv2d(32, 32, 3, padding=1),
481
- nn.SiLU(),
482
- nn.Conv2d(32, 96, 3, padding=1, stride=2),
483
- nn.SiLU(),
484
- nn.Conv2d(96, 96, 3, padding=1),
485
- nn.SiLU(),
486
- nn.Conv2d(96, 256, 3, padding=1, stride=2),
487
- nn.SiLU(),
488
- nn.Conv2d(256, 4, 3, padding=1),
489
- )
490
-
491
- def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
492
- # image
493
- time_image_embeds = self.image_proj(image_embeds)
494
- time_image_embeds = self.image_norm(time_image_embeds)
495
- hint = self.input_hint_block(hint)
496
- return time_image_embeds, hint
497
-
498
-
499
- class AttentionPooling(nn.Module):
500
- # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
501
-
502
- def __init__(self, num_heads, embed_dim, dtype=None):
503
- super().__init__()
504
- self.dtype = dtype
505
- self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
506
- self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
507
- self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
508
- self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
509
- self.num_heads = num_heads
510
- self.dim_per_head = embed_dim // self.num_heads
511
-
512
- def forward(self, x):
513
- bs, length, width = x.size()
514
-
515
- def shape(x):
516
- # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
517
- x = x.view(bs, -1, self.num_heads, self.dim_per_head)
518
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
519
- x = x.transpose(1, 2)
520
- # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
521
- x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
522
- # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
523
- x = x.transpose(1, 2)
524
- return x
525
-
526
- class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
527
- x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
528
-
529
- # (bs*n_heads, class_token_length, dim_per_head)
530
- q = shape(self.q_proj(class_token))
531
- # (bs*n_heads, length+class_token_length, dim_per_head)
532
- k = shape(self.k_proj(x))
533
- v = shape(self.v_proj(x))
534
-
535
- # (bs*n_heads, class_token_length, length+class_token_length):
536
- scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
537
- weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
538
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
539
-
540
- # (bs*n_heads, dim_per_head, class_token_length)
541
- a = torch.einsum("bts,bcs->bct", weight, v)
542
-
543
- # (bs, length+1, width)
544
- a = a.reshape(bs, -1, 1).transpose(1, 2)
545
-
546
- return a[:, 0, :] # cls_token
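
For reference, the `AttentionPooling` module above collapses a sequence of encoder states into a single vector: the sequence mean (plus a learned positional embedding) acts as a query token that attends over the full sequence, and the attended result is returned as the pooled "class token". The snippet below is a minimal, self-contained sketch of that idea with assumed toy shapes; it is illustrative only and is not the deleted module's API.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration only).
bs, length, width, num_heads = 2, 77, 768, 8
dim_per_head = width // num_heads

x = torch.randn(bs, length, width)                  # encoder hidden states
q_proj, k_proj, v_proj = (nn.Linear(width, width) for _ in range(3))

query = x.mean(dim=1, keepdim=True)                 # "class token" = sequence mean
tokens = torch.cat([query, x], dim=1)               # prepend it to the sequence

def split_heads(t):
    # (bs, L, width) -> (bs, heads, L, dim_per_head)
    return t.view(bs, -1, num_heads, dim_per_head).transpose(1, 2)

q = split_heads(q_proj(query))                      # (bs, heads, 1, d)
k = split_heads(k_proj(tokens))                     # (bs, heads, L+1, d)
v = split_heads(v_proj(tokens))

attn = torch.softmax(q @ k.transpose(-1, -2) / dim_per_head**0.5, dim=-1)
pooled = (attn @ v).transpose(1, 2).reshape(bs, width)   # (bs, width)
```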
 
4DoF/diffusers/models/embeddings_flax.py DELETED
@@ -1,95 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
-
16
- import flax.linen as nn
17
- import jax.numpy as jnp
18
-
19
-
20
- def get_sinusoidal_embeddings(
21
- timesteps: jnp.ndarray,
22
- embedding_dim: int,
23
- freq_shift: float = 1,
24
- min_timescale: float = 1,
25
- max_timescale: float = 1.0e4,
26
- flip_sin_to_cos: bool = False,
27
- scale: float = 1.0,
28
- ) -> jnp.ndarray:
29
- """Returns the positional encoding (same as Tensor2Tensor).
30
-
31
- Args:
32
- timesteps: a 1-D Tensor of N indices, one per batch element.
33
- These may be fractional.
34
- embedding_dim: The number of output channels.
35
- min_timescale: The smallest time unit (should probably be 0.0).
36
- max_timescale: The largest time unit.
37
- Returns:
38
- a Tensor of timing signals [N, num_channels]
39
- """
40
- assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
41
- assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
42
- num_timescales = float(embedding_dim // 2)
43
- log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
44
- inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
45
- emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
46
-
47
- # scale embeddings
48
- scaled_time = scale * emb
49
-
50
- if flip_sin_to_cos:
51
- signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
52
- else:
53
- signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
54
- signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
55
- return signal
56
-
57
-
58
- class FlaxTimestepEmbedding(nn.Module):
59
- r"""
60
- Time step Embedding Module. Learns embeddings for input time steps.
61
-
62
- Args:
63
- time_embed_dim (`int`, *optional*, defaults to `32`):
64
- Time step embedding dimension
65
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66
- Parameters `dtype`
67
- """
68
- time_embed_dim: int = 32
69
- dtype: jnp.dtype = jnp.float32
70
-
71
- @nn.compact
72
- def __call__(self, temb):
73
- temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
74
- temb = nn.silu(temb)
75
- temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
76
- return temb
77
-
78
-
79
- class FlaxTimesteps(nn.Module):
80
- r"""
81
- Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
82
-
83
- Args:
84
- dim (`int`, *optional*, defaults to `32`):
85
- Time step embedding dimension
86
- """
87
- dim: int = 32
88
- flip_sin_to_cos: bool = False
89
- freq_shift: float = 1
90
-
91
- @nn.compact
92
- def __call__(self, timesteps):
93
- return get_sinusoidal_embeddings(
94
- timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
95
- )
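
As a quick check on what `get_sinusoidal_embeddings` computes, the NumPy sketch below reproduces the same geometric frequency schedule for a handful of timesteps (toy values chosen for illustration; `flip_sin_to_cos=False` and `scale=1.0` assumed):

```python
import numpy as np

# Toy inputs (assumptions for illustration only).
timesteps = np.array([0.0, 10.0, 500.0])
embedding_dim, min_timescale, max_timescale, freq_shift = 8, 1.0, 1.0e4, 1.0

half = embedding_dim // 2
log_increment = np.log(max_timescale / min_timescale) / (half - freq_shift)
inv_timescales = min_timescale * np.exp(np.arange(half) * -log_increment)

angles = timesteps[:, None] * inv_timescales[None, :]        # (N, half)
signal = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
print(signal.shape)                                          # (3, 8)
```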
 
4DoF/diffusers/models/modeling_flax_pytorch_utils.py DELETED
@@ -1,118 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch - Flax general utilities."""
16
- import re
17
-
18
- import jax.numpy as jnp
19
- from flax.traverse_util import flatten_dict, unflatten_dict
20
- from jax.random import PRNGKey
21
-
22
- from ..utils import logging
23
-
24
-
25
- logger = logging.get_logger(__name__)
26
-
27
-
28
- def rename_key(key):
29
- regex = r"\w+[.]\d+"
30
- pats = re.findall(regex, key)
31
- for pat in pats:
32
- key = key.replace(pat, "_".join(pat.split(".")))
33
- return key
34
-
35
-
36
- #####################
37
- # PyTorch => Flax #
38
- #####################
39
-
40
-
41
- # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42
- # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43
- def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
44
- """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45
-
46
- # conv norm or layer norm
47
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
48
- if (
49
- any("norm" in str_ for str_ in pt_tuple_key)
50
- and (pt_tuple_key[-1] == "bias")
51
- and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
52
- and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
53
- ):
54
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
55
- return renamed_pt_tuple_key, pt_tensor
56
- elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
57
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
58
- return renamed_pt_tuple_key, pt_tensor
59
-
60
- # embedding
61
- if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
62
- pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
63
- return renamed_pt_tuple_key, pt_tensor
64
-
65
- # conv layer
66
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
67
- if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
68
- pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
69
- return renamed_pt_tuple_key, pt_tensor
70
-
71
- # linear layer
72
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
73
- if pt_tuple_key[-1] == "weight":
74
- pt_tensor = pt_tensor.T
75
- return renamed_pt_tuple_key, pt_tensor
76
-
77
- # old PyTorch layer norm weight
78
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
79
- if pt_tuple_key[-1] == "gamma":
80
- return renamed_pt_tuple_key, pt_tensor
81
-
82
- # old PyTorch layer norm bias
83
- renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
84
- if pt_tuple_key[-1] == "beta":
85
- return renamed_pt_tuple_key, pt_tensor
86
-
87
- return pt_tuple_key, pt_tensor
88
-
89
-
90
- def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
91
- # Step 1: Convert pytorch tensor to numpy
92
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
93
-
94
- # Step 2: Since the model is stateless, get random Flax params
95
- random_flax_params = flax_model.init_weights(PRNGKey(init_key))
96
-
97
- random_flax_state_dict = flatten_dict(random_flax_params)
98
- flax_state_dict = {}
99
-
100
- # Need to change some parameters name to match Flax names
101
- for pt_key, pt_tensor in pt_state_dict.items():
102
- renamed_pt_key = rename_key(pt_key)
103
- pt_tuple_key = tuple(renamed_pt_key.split("."))
104
-
105
- # Correctly rename weight parameters
106
- flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
107
-
108
- if flax_key in random_flax_state_dict:
109
- if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
110
- raise ValueError(
111
- f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
112
- f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
113
- )
114
-
115
- # also add unexpected weight so that warning is thrown
116
- flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
117
-
118
- return unflatten_dict(flax_state_dict)
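
The renaming logic above mostly encodes two layout conventions: a PyTorch `Linear.weight` of shape `(out, in)` becomes a Flax `Dense` kernel of shape `(in, out)`, and a PyTorch `Conv2d.weight` of shape `(out, in, kh, kw)` becomes a Flax `Conv` kernel of shape `(kh, kw, in, out)`. A small NumPy sketch of just those two reshapes (illustrative only):

```python
import numpy as np

# PyTorch Linear.weight (out, in) -> Flax Dense kernel (in, out)
pt_linear = np.random.randn(32, 16)
flax_dense_kernel = pt_linear.T                      # (16, 32)

# PyTorch Conv2d.weight (out, in, kh, kw) -> Flax Conv kernel (kh, kw, in, out)
pt_conv = np.random.randn(64, 3, 3, 3)
flax_conv_kernel = pt_conv.transpose(2, 3, 1, 0)     # (3, 3, 3, 64)
```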
 
4DoF/diffusers/models/modeling_flax_utils.py DELETED
@@ -1,534 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- from pickle import UnpicklingError
18
- from typing import Any, Dict, Union
19
-
20
- import jax
21
- import jax.numpy as jnp
22
- import msgpack.exceptions
23
- from flax.core.frozen_dict import FrozenDict, unfreeze
24
- from flax.serialization import from_bytes, to_bytes
25
- from flax.traverse_util import flatten_dict, unflatten_dict
26
- from huggingface_hub import hf_hub_download
27
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
28
- from requests import HTTPError
29
-
30
- from .. import __version__, is_torch_available
31
- from ..utils import (
32
- CONFIG_NAME,
33
- DIFFUSERS_CACHE,
34
- FLAX_WEIGHTS_NAME,
35
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
- WEIGHTS_NAME,
37
- logging,
38
- )
39
- from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
40
-
41
-
42
- logger = logging.get_logger(__name__)
43
-
44
-
45
- class FlaxModelMixin:
46
- r"""
47
- Base class for all Flax models.
48
-
49
- [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
50
- saving models.
51
-
52
- - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
53
- """
54
- config_name = CONFIG_NAME
55
- _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
56
- _flax_internal_args = ["name", "parent", "dtype"]
57
-
58
- @classmethod
59
- def _from_config(cls, config, **kwargs):
60
- """
61
- All context managers that the model should be initialized under go here.
62
- """
63
- return cls(config, **kwargs)
64
-
65
- def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
66
- """
67
- Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
68
- """
69
-
70
- # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
71
- def conditional_cast(param):
72
- if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
73
- param = param.astype(dtype)
74
- return param
75
-
76
- if mask is None:
77
- return jax.tree_map(conditional_cast, params)
78
-
79
- flat_params = flatten_dict(params)
80
- flat_mask, _ = jax.tree_flatten(mask)
81
-
82
- for masked, key in zip(flat_mask, flat_params.keys()):
83
- if masked:
84
- param = flat_params[key]
85
- flat_params[key] = conditional_cast(param)
86
-
87
- return unflatten_dict(flat_params)
88
-
89
- def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
90
- r"""
91
- Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
92
- the `params` in place.
93
-
94
- This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
95
- half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
96
-
97
- Arguments:
98
- params (`Union[Dict, FrozenDict]`):
99
- A `PyTree` of model parameters.
100
- mask (`Union[Dict, FrozenDict]`):
101
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
102
- for params you want to cast, and `False` for those you want to skip.
103
-
104
- Examples:
105
-
106
- ```python
107
- >>> from diffusers import FlaxUNet2DConditionModel
108
-
109
- >>> # load model
110
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
111
- >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
112
- >>> params = model.to_bf16(params)
113
- >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
114
- >>> # then pass the mask as follows
115
- >>> from flax import traverse_util
116
-
117
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
118
- >>> flat_params = traverse_util.flatten_dict(params)
119
- >>> mask = {
120
- ... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
121
- ... for path in flat_params
122
- ... }
123
- >>> mask = traverse_util.unflatten_dict(mask)
124
- >>> params = model.to_bf16(params, mask)
125
- ```"""
126
- return self._cast_floating_to(params, jnp.bfloat16, mask)
127
-
128
- def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
129
- r"""
130
- Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
131
- model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
132
-
133
- Arguments:
134
- params (`Union[Dict, FrozenDict]`):
135
- A `PyTree` of model parameters.
136
- mask (`Union[Dict, FrozenDict]`):
137
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
138
- for params you want to cast, and `False` for those you want to skip.
139
-
140
- Examples:
141
-
142
- ```python
143
- >>> from diffusers import FlaxUNet2DConditionModel
144
-
145
- >>> # Download model and configuration from huggingface.co
146
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
147
- >>> # By default, the model params will be in fp32, to illustrate the use of this method,
148
- >>> # we'll first cast to fp16 and back to fp32
149
- >>> params = model.to_fp16(params)
150
- >>> # now cast back to fp32
151
- >>> params = model.to_fp32(params)
152
- ```"""
153
- return self._cast_floating_to(params, jnp.float32, mask)
154
-
155
- def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
156
- r"""
157
- Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
158
- `params` in place.
159
-
160
- This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
161
- half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
162
-
163
- Arguments:
164
- params (`Union[Dict, FrozenDict]`):
165
- A `PyTree` of model parameters.
166
- mask (`Union[Dict, FrozenDict]`):
167
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
168
- for params you want to cast, and `False` for those you want to skip.
169
-
170
- Examples:
171
-
172
- ```python
173
- >>> from diffusers import FlaxUNet2DConditionModel
174
-
175
- >>> # load model
176
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
177
- >>> # By default, the model params will be in fp32, to cast these to float16
178
- >>> params = model.to_fp16(params)
179
- >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
180
- >>> # then pass the mask as follows
181
- >>> from flax import traverse_util
182
-
183
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
184
- >>> flat_params = traverse_util.flatten_dict(params)
185
- >>> mask = {
186
- ... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
187
- ... for path in flat_params
188
- ... }
189
- >>> mask = traverse_util.unflatten_dict(mask)
190
- >>> params = model.to_fp16(params, mask)
191
- ```"""
192
- return self._cast_floating_to(params, jnp.float16, mask)
193
-
194
- def init_weights(self, rng: jax.random.KeyArray) -> Dict:
195
- raise NotImplementedError(f"init_weights method has to be implemented for {self}")
196
-
197
- @classmethod
198
- def from_pretrained(
199
- cls,
200
- pretrained_model_name_or_path: Union[str, os.PathLike],
201
- dtype: jnp.dtype = jnp.float32,
202
- *model_args,
203
- **kwargs,
204
- ):
205
- r"""
206
- Instantiate a pretrained Flax model from a pretrained model configuration.
207
-
208
- Parameters:
209
- pretrained_model_name_or_path (`str` or `os.PathLike`):
210
- Can be either:
211
-
212
- - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
213
- hosted on the Hub.
214
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
215
- using [`~FlaxModelMixin.save_pretrained`].
216
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
217
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
218
- `jax.numpy.bfloat16` (on TPUs).
219
-
220
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
221
- specified, all the computation will be performed with the given `dtype`.
222
-
223
- <Tip>
224
-
225
- This only specifies the dtype of the *computation* and does not influence the dtype of model
226
- parameters.
227
-
228
- If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
229
- [`~FlaxModelMixin.to_bf16`].
230
-
231
- </Tip>
232
-
233
- model_args (sequence of positional arguments, *optional*):
234
- All remaining positional arguments are passed to the underlying model's `__init__` method.
235
- cache_dir (`Union[str, os.PathLike]`, *optional*):
236
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
237
- is not used.
238
- force_download (`bool`, *optional*, defaults to `False`):
239
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
240
- cached versions if they exist.
241
- resume_download (`bool`, *optional*, defaults to `False`):
242
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
243
- incompletely downloaded files are deleted.
244
- proxies (`Dict[str, str]`, *optional*):
245
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
246
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
247
- local_files_only(`bool`, *optional*, defaults to `False`):
248
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
249
- won't be downloaded from the Hub.
250
- revision (`str`, *optional*, defaults to `"main"`):
251
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
252
- allowed by Git.
253
- from_pt (`bool`, *optional*, defaults to `False`):
254
- Load the model weights from a PyTorch checkpoint save file.
255
- kwargs (remaining dictionary of keyword arguments, *optional*):
256
- Can be used to update the configuration object (after it is loaded) and initiate the model (for
257
- example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
258
- automatically loaded:
259
-
260
- - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
261
- model's `__init__` method (we assume all relevant updates to the configuration have already been
262
- done).
263
- - If a configuration is not provided, `kwargs` are first passed to the configuration class
264
- initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
265
- to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
266
- Remaining keys that do not correspond to any configuration attribute are passed to the underlying
267
- model's `__init__` function.
268
-
269
- Examples:
270
-
271
- ```python
272
- >>> from diffusers import FlaxUNet2DConditionModel
273
-
274
- >>> # Download model and configuration from huggingface.co and cache.
275
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
276
- >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
277
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
278
- ```
279
-
280
- If you get the error message below, you need to finetune the weights for your downstream task:
281
-
282
- ```bash
283
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
284
- - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
285
- You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
286
- ```
287
- """
288
- config = kwargs.pop("config", None)
289
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
290
- force_download = kwargs.pop("force_download", False)
291
- from_pt = kwargs.pop("from_pt", False)
292
- resume_download = kwargs.pop("resume_download", False)
293
- proxies = kwargs.pop("proxies", None)
294
- local_files_only = kwargs.pop("local_files_only", False)
295
- use_auth_token = kwargs.pop("use_auth_token", None)
296
- revision = kwargs.pop("revision", None)
297
- subfolder = kwargs.pop("subfolder", None)
298
-
299
- user_agent = {
300
- "diffusers": __version__,
301
- "file_type": "model",
302
- "framework": "flax",
303
- }
304
-
305
- # Load config if we don't provide a configuration
306
- config_path = config if config is not None else pretrained_model_name_or_path
307
- model, model_kwargs = cls.from_config(
308
- config_path,
309
- cache_dir=cache_dir,
310
- return_unused_kwargs=True,
311
- force_download=force_download,
312
- resume_download=resume_download,
313
- proxies=proxies,
314
- local_files_only=local_files_only,
315
- use_auth_token=use_auth_token,
316
- revision=revision,
317
- subfolder=subfolder,
318
- # model args
319
- dtype=dtype,
320
- **kwargs,
321
- )
322
-
323
- # Load model
324
- pretrained_path_with_subfolder = (
325
- pretrained_model_name_or_path
326
- if subfolder is None
327
- else os.path.join(pretrained_model_name_or_path, subfolder)
328
- )
329
- if os.path.isdir(pretrained_path_with_subfolder):
330
- if from_pt:
331
- if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
332
- raise EnvironmentError(
333
- f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
334
- )
335
- model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
336
- elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
337
- # Load from a Flax checkpoint
338
- model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
339
- # Check if pytorch weights exist instead
340
- elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
341
- raise EnvironmentError(
342
- f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
343
- " using `from_pt=True`."
344
- )
345
- else:
346
- raise EnvironmentError(
347
- f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
348
- f"{pretrained_path_with_subfolder}."
349
- )
350
- else:
351
- try:
352
- model_file = hf_hub_download(
353
- pretrained_model_name_or_path,
354
- filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
355
- cache_dir=cache_dir,
356
- force_download=force_download,
357
- proxies=proxies,
358
- resume_download=resume_download,
359
- local_files_only=local_files_only,
360
- use_auth_token=use_auth_token,
361
- user_agent=user_agent,
362
- subfolder=subfolder,
363
- revision=revision,
364
- )
365
-
366
- except RepositoryNotFoundError:
367
- raise EnvironmentError(
368
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
369
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
370
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
371
- "login`."
372
- )
373
- except RevisionNotFoundError:
374
- raise EnvironmentError(
375
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
376
- "this model name. Check the model page at "
377
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
378
- )
379
- except EntryNotFoundError:
380
- raise EnvironmentError(
381
- f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
382
- )
383
- except HTTPError as err:
384
- raise EnvironmentError(
385
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
386
- f"{err}"
387
- )
388
- except ValueError:
389
- raise EnvironmentError(
390
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
391
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
392
- f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
393
- " internet connection or see how to run the library in offline mode at"
394
- " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
395
- )
396
- except EnvironmentError:
397
- raise EnvironmentError(
398
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
399
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
400
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
401
- f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
402
- )
403
-
404
- if from_pt:
405
- if is_torch_available():
406
- from .modeling_utils import load_state_dict
407
- else:
408
- raise EnvironmentError(
409
- "Can't load the model in PyTorch format because PyTorch is not installed. "
410
- "Please, install PyTorch or use native Flax weights."
411
- )
412
-
413
- # Step 1: Get the pytorch file
414
- pytorch_model_file = load_state_dict(model_file)
415
-
416
- # Step 2: Convert the weights
417
- state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
418
- else:
419
- try:
420
- with open(model_file, "rb") as state_f:
421
- state = from_bytes(cls, state_f.read())
422
- except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
423
- try:
424
- with open(model_file) as f:
425
- if f.read().startswith("version"):
426
- raise OSError(
427
- "You seem to have cloned a repository without having git-lfs installed. Please"
428
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
429
- " folder you cloned."
430
- )
431
- else:
432
- raise ValueError from e
433
- except (UnicodeDecodeError, ValueError):
434
- raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
435
- # make sure all arrays are stored as jnp.ndarray
436
- # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
437
- # https://github.com/google/flax/issues/1261
438
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
439
-
440
- # flatten dicts
441
- state = flatten_dict(state)
442
-
443
- params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
444
- required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
445
-
446
- shape_state = flatten_dict(unfreeze(params_shape_tree))
447
-
448
- missing_keys = required_params - set(state.keys())
449
- unexpected_keys = set(state.keys()) - required_params
450
-
451
- if missing_keys:
452
- logger.warning(
453
- f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
454
- "Make sure to call model.init_weights to initialize the missing weights."
455
- )
456
- cls._missing_keys = missing_keys
457
-
458
- for key in state.keys():
459
- if key in shape_state and state[key].shape != shape_state[key].shape:
460
- raise ValueError(
461
- f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
462
- f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
463
- )
464
-
465
- # remove unexpected keys to not be saved again
466
- for unexpected_key in unexpected_keys:
467
- del state[unexpected_key]
468
-
469
- if len(unexpected_keys) > 0:
470
- logger.warning(
471
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
472
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
473
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
474
- " with another architecture."
475
- )
476
- else:
477
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
478
-
479
- if len(missing_keys) > 0:
480
- logger.warning(
481
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
482
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
483
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
484
- )
485
- else:
486
- logger.info(
487
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
488
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
489
- f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
490
- " training."
491
- )
492
-
493
- return model, unflatten_dict(state)
494
-
495
- def save_pretrained(
496
- self,
497
- save_directory: Union[str, os.PathLike],
498
- params: Union[Dict, FrozenDict],
499
- is_main_process: bool = True,
500
- ):
501
- """
502
- Save a model and its configuration file to a directory so that it can be reloaded using the
503
- [`~FlaxModelMixin.from_pretrained`] class method.
504
-
505
- Arguments:
506
- save_directory (`str` or `os.PathLike`):
507
- Directory to save a model and its configuration file to. Will be created if it doesn't exist.
508
- params (`Union[Dict, FrozenDict]`):
509
- A `PyTree` of model parameters.
510
- is_main_process (`bool`, *optional*, defaults to `True`):
511
- Whether the process calling this is the main process or not. Useful during distributed training when you
512
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
513
- process to avoid race conditions.
514
- """
515
- if os.path.isfile(save_directory):
516
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
517
- return
518
-
519
- os.makedirs(save_directory, exist_ok=True)
520
-
521
- model_to_save = self
522
-
523
- # Attach architecture to the config
524
- # Save the config
525
- if is_main_process:
526
- model_to_save.save_config(save_directory)
527
-
528
- # save model
529
- output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
530
- with open(output_model_file, "wb") as f:
531
- model_bytes = to_bytes(params)
532
- f.write(model_bytes)
533
-
534
- logger.info(f"Model weights saved in {output_model_file}")
 
4DoF/diffusers/models/modeling_pytorch_flax_utils.py DELETED
@@ -1,161 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch - Flax general utilities."""
16
-
17
- from pickle import UnpicklingError
18
-
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
- from flax.serialization import from_bytes
23
- from flax.traverse_util import flatten_dict
24
-
25
- from ..utils import logging
26
-
27
-
28
- logger = logging.get_logger(__name__)
29
-
30
-
31
- #####################
32
- # Flax => PyTorch #
33
- #####################
34
-
35
-
36
- # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
37
- def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
38
- try:
39
- with open(model_file, "rb") as flax_state_f:
40
- flax_state = from_bytes(None, flax_state_f.read())
41
- except UnpicklingError as e:
42
- try:
43
- with open(model_file) as f:
44
- if f.read().startswith("version"):
45
- raise OSError(
46
- "You seem to have cloned a repository without having git-lfs installed. Please"
47
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
48
- " folder you cloned."
49
- )
50
- else:
51
- raise ValueError from e
52
- except (UnicodeDecodeError, ValueError):
53
- raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
54
-
55
- return load_flax_weights_in_pytorch_model(pt_model, flax_state)
56
-
57
-
58
- def load_flax_weights_in_pytorch_model(pt_model, flax_state):
59
- """Load flax checkpoints in a PyTorch model"""
60
-
61
- try:
62
- import torch # noqa: F401
63
- except ImportError:
64
- logger.error(
65
- "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
66
- " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
67
- " instructions."
68
- )
69
- raise
70
-
71
- # check if we have bf16 weights
72
- is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
73
- if any(is_type_bf16):
74
- # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
75
-
76
- # and bf16 is not fully supported in PT yet.
77
- logger.warning(
78
- "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
79
- "before loading those in PyTorch model."
80
- )
81
- flax_state = jax.tree_util.tree_map(
82
- lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
83
- )
84
-
85
- pt_model.base_model_prefix = ""
86
-
87
- flax_state_dict = flatten_dict(flax_state, sep=".")
88
- pt_model_dict = pt_model.state_dict()
89
-
90
- # keep track of unexpected & missing keys
91
- unexpected_keys = []
92
- missing_keys = set(pt_model_dict.keys())
93
-
94
- for flax_key_tuple, flax_tensor in flax_state_dict.items():
95
- flax_key_tuple_array = flax_key_tuple.split(".")
96
-
97
- if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
98
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
99
- flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
100
- elif flax_key_tuple_array[-1] == "kernel":
101
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
102
- flax_tensor = flax_tensor.T
103
- elif flax_key_tuple_array[-1] == "scale":
104
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
105
-
106
- if "time_embedding" not in flax_key_tuple_array:
107
- for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
108
- flax_key_tuple_array[i] = (
109
- flax_key_tuple_string.replace("_0", ".0")
110
- .replace("_1", ".1")
111
- .replace("_2", ".2")
112
- .replace("_3", ".3")
113
- .replace("_4", ".4")
114
- .replace("_5", ".5")
115
- .replace("_6", ".6")
116
- .replace("_7", ".7")
117
- .replace("_8", ".8")
118
- .replace("_9", ".9")
119
- )
120
-
121
- flax_key = ".".join(flax_key_tuple_array)
122
-
123
- if flax_key in pt_model_dict:
124
- if flax_tensor.shape != pt_model_dict[flax_key].shape:
125
- raise ValueError(
126
- f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
127
- f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
128
- )
129
- else:
130
- # add weight to pytorch dict
131
- flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
132
- pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
133
- # remove from missing keys
134
- missing_keys.remove(flax_key)
135
- else:
136
- # weight is not expected by PyTorch model
137
- unexpected_keys.append(flax_key)
138
-
139
- pt_model.load_state_dict(pt_model_dict)
140
-
141
- # re-transform missing_keys to list
142
- missing_keys = list(missing_keys)
143
-
144
- if len(unexpected_keys) > 0:
145
- logger.warning(
146
- "Some weights of the Flax model were not used when initializing the PyTorch model"
147
- f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
148
- f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
149
- " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
150
- f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
151
- " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
152
- " FlaxBertForSequenceClassification model)."
153
- )
154
- if len(missing_keys) > 0:
155
- logger.warning(
156
- f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157
- f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158
- " use it for predictions and inference."
159
- )
160
-
161
- return pt_model
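
Going the other way, `load_flax_weights_in_pytorch_model` undoes the same layout conventions before copying tensors into the PyTorch state dict: a 4-D Flax `kernel` is transposed back to `(out, in, kh, kw)` and a 2-D kernel back to `(out, in)`. A NumPy sketch of those inverse transposes (illustrative only):

```python
import numpy as np

# Flax Conv kernel (kh, kw, in, out) -> PyTorch Conv2d.weight (out, in, kh, kw)
flax_conv = np.random.randn(3, 3, 3, 64)
pt_conv_weight = flax_conv.transpose(3, 2, 0, 1)     # (64, 3, 3, 3)

# Flax Dense kernel (in, out) -> PyTorch Linear.weight (out, in)
flax_dense = np.random.randn(16, 32)
pt_linear_weight = flax_dense.T                      # (32, 16)
```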
 
4DoF/diffusers/models/modeling_utils.py DELETED
@@ -1,980 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import inspect
18
- import itertools
19
- import os
20
- import re
21
- from functools import partial
22
- from typing import Any, Callable, List, Optional, Tuple, Union
23
-
24
- import torch
25
- from torch import Tensor, device, nn
26
-
27
- from .. import __version__
28
- from ..utils import (
29
- CONFIG_NAME,
30
- DIFFUSERS_CACHE,
31
- FLAX_WEIGHTS_NAME,
32
- HF_HUB_OFFLINE,
33
- SAFETENSORS_WEIGHTS_NAME,
34
- WEIGHTS_NAME,
35
- _add_variant,
36
- _get_model_file,
37
- deprecate,
38
- is_accelerate_available,
39
- is_safetensors_available,
40
- is_torch_version,
41
- logging,
42
- )
43
-
44
-
45
- logger = logging.get_logger(__name__)
46
-
47
-
48
- if is_torch_version(">=", "1.9.0"):
49
- _LOW_CPU_MEM_USAGE_DEFAULT = True
50
- else:
51
- _LOW_CPU_MEM_USAGE_DEFAULT = False
52
-
53
-
54
- if is_accelerate_available():
55
- import accelerate
56
- from accelerate.utils import set_module_tensor_to_device
57
- from accelerate.utils.versions import is_torch_version
58
-
59
- if is_safetensors_available():
60
- import safetensors
61
-
62
-
63
- def get_parameter_device(parameter: torch.nn.Module):
64
- try:
65
- parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
66
- return next(parameters_and_buffers).device
67
- except StopIteration:
68
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
69
-
70
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
71
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
72
- return tuples
73
-
74
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
75
- first_tuple = next(gen)
76
- return first_tuple[1].device
77
-
78
-
79
- def get_parameter_dtype(parameter: torch.nn.Module):
80
- try:
81
- params = tuple(parameter.parameters())
82
- if len(params) > 0:
83
- return params[0].dtype
84
-
85
- buffers = tuple(parameter.buffers())
86
- if len(buffers) > 0:
87
- return buffers[0].dtype
88
-
89
- except StopIteration:
90
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
91
-
92
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
93
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
94
- return tuples
95
-
96
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
97
- first_tuple = next(gen)
98
- return first_tuple[1].dtype
99
-
100
-
101
- def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
102
- """
103
- Reads a checkpoint file, returning properly formatted errors if they arise.
104
- """
105
- try:
106
- if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
107
- return torch.load(checkpoint_file, map_location="cpu")
108
- else:
109
- return safetensors.torch.load_file(checkpoint_file, device="cpu")
110
- except Exception as e:
111
- try:
112
- with open(checkpoint_file) as f:
113
- if f.read().startswith("version"):
114
- raise OSError(
115
- "You seem to have cloned a repository without having git-lfs installed. Please install "
116
- "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
117
- "you cloned."
118
- )
119
- else:
120
- raise ValueError(
121
- f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
122
- "model. Make sure you have saved the model properly."
123
- ) from e
124
- except (UnicodeDecodeError, ValueError):
125
- raise OSError(
126
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
127
- f"at '{checkpoint_file}'. "
128
- "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
129
- )
130
-
131
-
132
- def _load_state_dict_into_model(model_to_load, state_dict):
133
- # Convert old format to new format if needed from a PyTorch state_dict
134
- # copy state_dict so _load_from_state_dict can modify it
135
- state_dict = state_dict.copy()
136
- error_msgs = []
137
-
138
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
139
- # so we need to apply the function recursively.
140
- def load(module: torch.nn.Module, prefix=""):
141
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
142
- module._load_from_state_dict(*args)
143
-
144
- for name, child in module._modules.items():
145
- if child is not None:
146
- load(child, prefix + name + ".")
147
-
148
- load(model_to_load)
149
-
150
- return error_msgs
151
-
152
-
153
- class ModelMixin(torch.nn.Module):
154
- r"""
155
- Base class for all models.
156
-
157
- [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
158
- saving models.
159
-
160
- - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
161
- """
162
- config_name = CONFIG_NAME
163
- _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
164
- _supports_gradient_checkpointing = False
165
- _keys_to_ignore_on_load_unexpected = None
166
-
167
- def __init__(self):
168
- super().__init__()
169
-
170
- def __getattr__(self, name: str) -> Any:
171
- """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
172
- config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
173
- `__getattr__` here in addition so that we don't trigger `torch.nn.Module`'s `__getattr__`:
174
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
175
- """
176
-
177
- is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
178
- is_attribute = name in self.__dict__
179
-
180
- if is_in_config and not is_attribute:
181
- deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' through '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
182
- deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
183
- return self._internal_dict[name]
184
-
185
- # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
186
- return super().__getattr__(name)
187
-
188
- @property
189
- def is_gradient_checkpointing(self) -> bool:
190
- """
191
- Whether gradient checkpointing is activated for this model or not.
192
- """
193
- return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
194
-
195
- def enable_gradient_checkpointing(self):
196
- """
197
- Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
198
- *checkpoint activations* in other frameworks).
199
- """
200
- if not self._supports_gradient_checkpointing:
201
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
202
- self.apply(partial(self._set_gradient_checkpointing, value=True))
203
-
204
- def disable_gradient_checkpointing(self):
205
- """
206
- Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
207
- *checkpoint activations* in other frameworks).
208
- """
209
- if self._supports_gradient_checkpointing:
210
- self.apply(partial(self._set_gradient_checkpointing, value=False))
211
-
212
- def set_use_memory_efficient_attention_xformers(
213
- self, valid: bool, attention_op: Optional[Callable] = None
214
- ) -> None:
215
- # Recursively walk through all the children.
216
- # Any children which exposes the set_use_memory_efficient_attention_xformers method
217
- # gets the message
218
- def fn_recursive_set_mem_eff(module: torch.nn.Module):
219
- if hasattr(module, "set_use_memory_efficient_attention_xformers"):
220
- module.set_use_memory_efficient_attention_xformers(valid, attention_op)
221
-
222
- for child in module.children():
223
- fn_recursive_set_mem_eff(child)
224
-
225
- for module in self.children():
226
- if isinstance(module, torch.nn.Module):
227
- fn_recursive_set_mem_eff(module)
228
-
229
- def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
230
- r"""
231
- Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
232
-
233
- When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
234
- inference. Speed up during training is not guaranteed.
235
-
236
- <Tip warning={true}>
237
-
238
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
239
- precedence.
240
-
241
- </Tip>
242
-
243
- Parameters:
244
- attention_op (`Callable`, *optional*):
245
- Override the default `None` operator for use as `op` argument to the
246
- [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
247
- function of xFormers.
248
-
249
- Examples:
250
-
251
- ```py
252
- >>> import torch
253
- >>> from diffusers import UNet2DConditionModel
254
- >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
255
-
256
- >>> model = UNet2DConditionModel.from_pretrained(
257
- ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
258
- ... )
259
- >>> model = model.to("cuda")
260
- >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
261
- ```
262
- """
263
- self.set_use_memory_efficient_attention_xformers(True, attention_op)
264
-
265
- def disable_xformers_memory_efficient_attention(self):
266
- r"""
267
- Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
268
- """
269
- self.set_use_memory_efficient_attention_xformers(False)
270
-
271
- def save_pretrained(
272
- self,
273
- save_directory: Union[str, os.PathLike],
274
- is_main_process: bool = True,
275
- save_function: Callable = None,
276
- safe_serialization: bool = False,
277
- variant: Optional[str] = None,
278
- ):
279
- """
280
- Save a model and its configuration file to a directory so that it can be reloaded using the
281
- [`~models.ModelMixin.from_pretrained`] class method.
282
-
283
- Arguments:
284
- save_directory (`str` or `os.PathLike`):
285
- Directory to save a model and its configuration file to. Will be created if it doesn't exist.
286
- is_main_process (`bool`, *optional*, defaults to `True`):
287
- Whether the process calling this is the main process or not. Useful during distributed training when you
288
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
289
- process to avoid race conditions.
290
- save_function (`Callable`):
291
- The function to use to save the state dictionary. Useful during distributed training when you need to
292
- replace `torch.save` with another method. Can be configured with the environment variable
293
- `DIFFUSERS_SAVE_MODE`.
294
- safe_serialization (`bool`, *optional*, defaults to `False`):
295
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
296
- variant (`str`, *optional*):
297
- If specified, weights are saved in the format `diffusion_pytorch_model.<variant>.bin`.
298
- """
299
- if safe_serialization and not is_safetensors_available():
300
- raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
301
-
302
- if os.path.isfile(save_directory):
303
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
304
- return
305
-
306
- os.makedirs(save_directory, exist_ok=True)
307
-
308
- model_to_save = self
309
-
310
- # Attach architecture to the config
311
- # Save the config
312
- if is_main_process:
313
- model_to_save.save_config(save_directory)
314
-
315
- # Save the model
316
- state_dict = model_to_save.state_dict()
317
-
318
- weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
319
- weights_name = _add_variant(weights_name, variant)
320
-
321
- # Save the model
322
- if safe_serialization:
323
- safetensors.torch.save_file(
324
- state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
325
- )
326
- else:
327
- torch.save(state_dict, os.path.join(save_directory, weights_name))
328
-
329
- logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
330
-
331
- @classmethod
332
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
333
- r"""
334
- Instantiate a pretrained PyTorch model from a pretrained model configuration.
335
-
336
- The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
337
- train the model, set it back in training mode with `model.train()`.
338
-
339
- Parameters:
340
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
341
- Can be either:
342
-
343
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
344
- the Hub.
345
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
346
- with [`~ModelMixin.save_pretrained`].
347
-
348
- cache_dir (`Union[str, os.PathLike]`, *optional*):
349
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
350
- is not used.
351
- torch_dtype (`str` or `torch.dtype`, *optional*):
352
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
353
- dtype is automatically derived from the model's weights.
354
- force_download (`bool`, *optional*, defaults to `False`):
355
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
356
- cached versions if they exist.
357
- resume_download (`bool`, *optional*, defaults to `False`):
358
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
359
- incompletely downloaded files are deleted.
360
- proxies (`Dict[str, str]`, *optional*):
361
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
362
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
363
- output_loading_info (`bool`, *optional*, defaults to `False`):
364
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
365
- local_files_only(`bool`, *optional*, defaults to `False`):
366
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
367
- won't be downloaded from the Hub.
368
- use_auth_token (`str` or *bool*, *optional*):
369
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
370
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
371
- revision (`str`, *optional*, defaults to `"main"`):
372
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
373
- allowed by Git.
374
- from_flax (`bool`, *optional*, defaults to `False`):
375
- Load the model weights from a Flax checkpoint save file.
376
- subfolder (`str`, *optional*, defaults to `""`):
377
- The subfolder location of a model file within a larger model repository on the Hub or locally.
378
- mirror (`str`, *optional*):
379
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
380
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
381
- information.
382
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
383
- A map that specifies where each submodule should go. It doesn't need to be defined for each
384
- parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
385
- same device.
386
-
387
- Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
388
- more information about each option see [designing a device
389
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
390
- max_memory (`Dict`, *optional*):
391
- A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
392
- each GPU and the available CPU RAM if unset.
393
- offload_folder (`str` or `os.PathLike`, *optional*):
394
- The path to offload weights if `device_map` contains the value `"disk"`.
395
- offload_state_dict (`bool`, *optional*):
396
- If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
397
- the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
398
- when there is some disk offload.
399
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
400
- Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
401
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
402
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
403
- argument to `True` will raise an error.
404
- variant (`str`, *optional*):
405
- Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
406
- loading `from_flax`.
407
- use_safetensors (`bool`, *optional*, defaults to `None`):
408
- If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
409
- `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
410
- weights. If set to `False`, `safetensors` weights are not loaded.
411
-
412
- <Tip>
413
-
414
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
415
- `huggingface-cli login`. You can also activate the special
416
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
417
- firewalled environment.
418
-
419
- </Tip>
420
-
421
- Example:
422
-
423
- ```py
424
- from diffusers import UNet2DConditionModel
425
-
426
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
427
- ```
428
-
429
- If you get the error message below, you need to finetune the weights for your downstream task:
430
-
431
- ```bash
432
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
433
- - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
434
- You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
435
- ```
436
- """
437
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
438
- ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
439
- force_download = kwargs.pop("force_download", False)
440
- from_flax = kwargs.pop("from_flax", False)
441
- resume_download = kwargs.pop("resume_download", False)
442
- proxies = kwargs.pop("proxies", None)
443
- output_loading_info = kwargs.pop("output_loading_info", False)
444
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
445
- use_auth_token = kwargs.pop("use_auth_token", None)
446
- revision = kwargs.pop("revision", None)
447
- torch_dtype = kwargs.pop("torch_dtype", None)
448
- subfolder = kwargs.pop("subfolder", None)
449
- device_map = kwargs.pop("device_map", None)
450
- max_memory = kwargs.pop("max_memory", None)
451
- offload_folder = kwargs.pop("offload_folder", None)
452
- offload_state_dict = kwargs.pop("offload_state_dict", False)
453
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
454
- variant = kwargs.pop("variant", None)
455
- use_safetensors = kwargs.pop("use_safetensors", None)
456
-
457
- if use_safetensors and not is_safetensors_available():
458
- raise ValueError(
459
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
460
- )
461
-
462
- allow_pickle = False
463
- if use_safetensors is None:
464
- use_safetensors = is_safetensors_available()
465
- allow_pickle = True
466
-
467
- if low_cpu_mem_usage and not is_accelerate_available():
468
- low_cpu_mem_usage = False
469
- logger.warning(
470
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
471
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
472
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
473
- " install accelerate\n```\n."
474
- )
475
-
476
- if device_map is not None and not is_accelerate_available():
477
- raise NotImplementedError(
478
- "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
479
- " `device_map=None`. You can install accelerate with `pip install accelerate`."
480
- )
481
-
482
- # Check if we can handle device_map and dispatching the weights
483
- if device_map is not None and not is_torch_version(">=", "1.9.0"):
484
- raise NotImplementedError(
485
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
486
- " `device_map=None`."
487
- )
488
-
489
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
490
- raise NotImplementedError(
491
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
492
- " `low_cpu_mem_usage=False`."
493
- )
494
-
495
- if low_cpu_mem_usage is False and device_map is not None:
496
- raise ValueError(
497
- f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
498
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
499
- )
500
-
501
- # Load config if we don't provide a configuration
502
- config_path = pretrained_model_name_or_path
503
-
504
- user_agent = {
505
- "diffusers": __version__,
506
- "file_type": "model",
507
- "framework": "pytorch",
508
- }
509
-
510
- # load config
511
- config, unused_kwargs, commit_hash = cls.load_config(
512
- config_path,
513
- cache_dir=cache_dir,
514
- return_unused_kwargs=True,
515
- return_commit_hash=True,
516
- force_download=force_download,
517
- resume_download=resume_download,
518
- proxies=proxies,
519
- local_files_only=local_files_only,
520
- use_auth_token=use_auth_token,
521
- revision=revision,
522
- subfolder=subfolder,
523
- device_map=device_map,
524
- max_memory=max_memory,
525
- offload_folder=offload_folder,
526
- offload_state_dict=offload_state_dict,
527
- user_agent=user_agent,
528
- **kwargs,
529
- )
530
-
531
- # load model
532
- model_file = None
533
- if from_flax:
534
- model_file = _get_model_file(
535
- pretrained_model_name_or_path,
536
- weights_name=FLAX_WEIGHTS_NAME,
537
- cache_dir=cache_dir,
538
- force_download=force_download,
539
- resume_download=resume_download,
540
- proxies=proxies,
541
- local_files_only=local_files_only,
542
- use_auth_token=use_auth_token,
543
- revision=revision,
544
- subfolder=subfolder,
545
- user_agent=user_agent,
546
- commit_hash=commit_hash,
547
- )
548
- model = cls.from_config(config, **unused_kwargs)
549
-
550
- # Convert the weights
551
- from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
552
-
553
- model = load_flax_checkpoint_in_pytorch_model(model, model_file)
554
- else:
555
- if use_safetensors:
556
- try:
557
- model_file = _get_model_file(
558
- pretrained_model_name_or_path,
559
- weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
560
- cache_dir=cache_dir,
561
- force_download=force_download,
562
- resume_download=resume_download,
563
- proxies=proxies,
564
- local_files_only=local_files_only,
565
- use_auth_token=use_auth_token,
566
- revision=revision,
567
- subfolder=subfolder,
568
- user_agent=user_agent,
569
- commit_hash=commit_hash,
570
- )
571
- except IOError as e:
572
- if not allow_pickle:
573
- raise e
574
- pass
575
- if model_file is None:
576
- model_file = _get_model_file(
577
- pretrained_model_name_or_path,
578
- weights_name=_add_variant(WEIGHTS_NAME, variant),
579
- cache_dir=cache_dir,
580
- force_download=force_download,
581
- resume_download=resume_download,
582
- proxies=proxies,
583
- local_files_only=local_files_only,
584
- use_auth_token=use_auth_token,
585
- revision=revision,
586
- subfolder=subfolder,
587
- user_agent=user_agent,
588
- commit_hash=commit_hash,
589
- )
590
-
591
- if low_cpu_mem_usage:
592
- # Instantiate model with empty weights
593
- with accelerate.init_empty_weights():
594
- model = cls.from_config(config, **unused_kwargs)
595
-
596
- # if device_map is None, load the state dict and move the params from meta device to the cpu
597
- if device_map is None:
598
- param_device = "cpu"
599
- state_dict = load_state_dict(model_file, variant=variant)
600
- model._convert_deprecated_attention_blocks(state_dict)
601
- # move the params from meta device to cpu
602
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
603
- if len(missing_keys) > 0:
604
- raise ValueError(
605
- f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
606
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
607
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
608
- " those weights or else make sure your checkpoint file is correct."
609
- )
610
- unexpected_keys = []
611
-
612
- empty_state_dict = model.state_dict()
613
- for param_name, param in state_dict.items():
614
- accepts_dtype = "dtype" in set(
615
- inspect.signature(set_module_tensor_to_device).parameters.keys()
616
- )
617
-
618
- if param_name not in empty_state_dict:
619
- unexpected_keys.append(param_name)
620
- continue
621
-
622
- if empty_state_dict[param_name].shape != param.shape:
623
- raise ValueError(
624
- f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
625
- )
626
-
627
- if accepts_dtype:
628
- set_module_tensor_to_device(
629
- model, param_name, param_device, value=param, dtype=torch_dtype
630
- )
631
- else:
632
- set_module_tensor_to_device(model, param_name, param_device, value=param)
633
-
634
- if cls._keys_to_ignore_on_load_unexpected is not None:
635
- for pat in cls._keys_to_ignore_on_load_unexpected:
636
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
637
-
638
- if len(unexpected_keys) > 0:
639
- logger.warn(
640
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
641
- )
642
-
643
- else: # else let accelerate handle loading and dispatching.
644
- # Load weights and dispatch according to the device_map
645
- # by default the device_map is None and the weights are loaded on the CPU
646
- try:
647
- accelerate.load_checkpoint_and_dispatch(
648
- model,
649
- model_file,
650
- device_map,
651
- max_memory=max_memory,
652
- offload_folder=offload_folder,
653
- offload_state_dict=offload_state_dict,
654
- dtype=torch_dtype,
655
- )
656
- except AttributeError as e:
657
- # When using accelerate loading, we do not have the ability to load the state
658
- # dict and rename the weight names manually. Additionally, accelerate skips
659
- # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
660
- # (which look like they should be private variables?), so we can't use the standard hooks
661
- # to rename parameters on load. We need to mimic the original weight names so the correct
662
- # attributes are available. After we have loaded the weights, we convert the deprecated
663
- # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
664
- # the weights so we don't have to do this again.
665
-
666
- if "'Attention' object has no attribute" in str(e):
667
- logger.warn(
668
- f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
669
- " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
670
- " so we don't have to do the on-the-fly renaming in the future. If the model is from a hub checkpoint,"
671
- " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
672
- " please also re-upload it or open a PR on the original repository."
673
- )
674
- model._temp_convert_self_to_deprecated_attention_blocks()
675
- accelerate.load_checkpoint_and_dispatch(
676
- model,
677
- model_file,
678
- device_map,
679
- max_memory=max_memory,
680
- offload_folder=offload_folder,
681
- offload_state_dict=offload_state_dict,
682
- dtype=torch_dtype,
683
- )
684
- model._undo_temp_convert_self_to_deprecated_attention_blocks()
685
- else:
686
- raise e
687
-
688
- loading_info = {
689
- "missing_keys": [],
690
- "unexpected_keys": [],
691
- "mismatched_keys": [],
692
- "error_msgs": [],
693
- }
694
- else:
695
- model = cls.from_config(config, **unused_kwargs)
696
-
697
- state_dict = load_state_dict(model_file, variant=variant)
698
- model._convert_deprecated_attention_blocks(state_dict)
699
-
700
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
701
- model,
702
- state_dict,
703
- model_file,
704
- pretrained_model_name_or_path,
705
- ignore_mismatched_sizes=ignore_mismatched_sizes,
706
- )
707
-
708
- loading_info = {
709
- "missing_keys": missing_keys,
710
- "unexpected_keys": unexpected_keys,
711
- "mismatched_keys": mismatched_keys,
712
- "error_msgs": error_msgs,
713
- }
714
-
715
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
716
- raise ValueError(
717
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
718
- )
719
- elif torch_dtype is not None:
720
- model = model.to(torch_dtype)
721
-
722
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
723
-
724
- # Set model in evaluation mode to deactivate DropOut modules by default
725
- model.eval()
726
- if output_loading_info:
727
- return model, loading_info
728
-
729
- return model
730
-
731
- @classmethod
732
- def _load_pretrained_model(
733
- cls,
734
- model,
735
- state_dict,
736
- resolved_archive_file,
737
- pretrained_model_name_or_path,
738
- ignore_mismatched_sizes=False,
739
- ):
740
- # Retrieve missing & unexpected_keys
741
- model_state_dict = model.state_dict()
742
- loaded_keys = list(state_dict.keys())
743
-
744
- expected_keys = list(model_state_dict.keys())
745
-
746
- original_loaded_keys = loaded_keys
747
-
748
- missing_keys = list(set(expected_keys) - set(loaded_keys))
749
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
750
-
751
- # Make sure we are able to load base models as well as derived models (with heads)
752
- model_to_load = model
753
-
754
- def _find_mismatched_keys(
755
- state_dict,
756
- model_state_dict,
757
- loaded_keys,
758
- ignore_mismatched_sizes,
759
- ):
760
- mismatched_keys = []
761
- if ignore_mismatched_sizes:
762
- for checkpoint_key in loaded_keys:
763
- model_key = checkpoint_key
764
-
765
- if (
766
- model_key in model_state_dict
767
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
768
- ):
769
- mismatched_keys.append(
770
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
771
- )
772
- del state_dict[checkpoint_key]
773
- return mismatched_keys
774
-
775
- if state_dict is not None:
776
- # Whole checkpoint
777
- mismatched_keys = _find_mismatched_keys(
778
- state_dict,
779
- model_state_dict,
780
- original_loaded_keys,
781
- ignore_mismatched_sizes,
782
- )
783
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
784
-
785
- if len(error_msgs) > 0:
786
- error_msg = "\n\t".join(error_msgs)
787
- if "size mismatch" in error_msg:
788
- error_msg += (
789
- "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
790
- )
791
- raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
792
-
793
- if len(unexpected_keys) > 0:
794
- logger.warning(
795
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
796
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
797
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
798
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
799
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
800
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
801
- " identical (initializing a BertForSequenceClassification model from a"
802
- " BertForSequenceClassification model)."
803
- )
804
- else:
805
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
806
- if len(missing_keys) > 0:
807
- logger.warning(
808
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
809
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
810
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
811
- )
812
- elif len(mismatched_keys) == 0:
813
- logger.info(
814
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
815
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
816
- f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
817
- " without further training."
818
- )
819
- if len(mismatched_keys) > 0:
820
- mismatched_warning = "\n".join(
821
- [
822
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
823
- for key, shape1, shape2 in mismatched_keys
824
- ]
825
- )
826
- logger.warning(
827
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
828
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
829
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
830
- " able to use it for predictions and inference."
831
- )
832
-
833
- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
834
-
835
- @property
836
- def device(self) -> device:
837
- """
838
- `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
839
- device).
840
- """
841
- return get_parameter_device(self)
842
-
843
- @property
844
- def dtype(self) -> torch.dtype:
845
- """
846
- `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
847
- """
848
- return get_parameter_dtype(self)
849
-
850
- def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
851
- """
852
- Get number of (trainable or non-embedding) parameters in the module.
853
-
854
- Args:
855
- only_trainable (`bool`, *optional*, defaults to `False`):
856
- Whether or not to return only the number of trainable parameters.
857
- exclude_embeddings (`bool`, *optional*, defaults to `False`):
858
- Whether or not to return only the number of non-embedding parameters.
859
-
860
- Returns:
861
- `int`: The number of parameters.
862
-
863
- Example:
864
-
865
- ```py
866
- from diffusers import UNet2DConditionModel
867
-
868
- model_id = "runwayml/stable-diffusion-v1-5"
869
- unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
870
- unet.num_parameters(only_trainable=True)
871
- 859520964
872
- ```
873
- """
874
-
875
- if exclude_embeddings:
876
- embedding_param_names = [
877
- f"{name}.weight"
878
- for name, module_type in self.named_modules()
879
- if isinstance(module_type, torch.nn.Embedding)
880
- ]
881
- non_embedding_parameters = [
882
- parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
883
- ]
884
- return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
885
- else:
886
- return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
887
-
888
- def _convert_deprecated_attention_blocks(self, state_dict):
889
- deprecated_attention_block_paths = []
890
-
891
- def recursive_find_attn_block(name, module):
892
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
893
- deprecated_attention_block_paths.append(name)
894
-
895
- for sub_name, sub_module in module.named_children():
896
- sub_name = sub_name if name == "" else f"{name}.{sub_name}"
897
- recursive_find_attn_block(sub_name, sub_module)
898
-
899
- recursive_find_attn_block("", self)
900
-
901
- # NOTE: we have to check if the deprecated parameters are in the state dict
902
- # because it is possible we are loading from a state dict that was already
903
- # converted
904
-
905
- for path in deprecated_attention_block_paths:
906
- # group_norm path stays the same
907
-
908
- # query -> to_q
909
- if f"{path}.query.weight" in state_dict:
910
- state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
911
- if f"{path}.query.bias" in state_dict:
912
- state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
913
-
914
- # key -> to_k
915
- if f"{path}.key.weight" in state_dict:
916
- state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
917
- if f"{path}.key.bias" in state_dict:
918
- state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
919
-
920
- # value -> to_v
921
- if f"{path}.value.weight" in state_dict:
922
- state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
923
- if f"{path}.value.bias" in state_dict:
924
- state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
925
-
926
- # proj_attn -> to_out.0
927
- if f"{path}.proj_attn.weight" in state_dict:
928
- state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
929
- if f"{path}.proj_attn.bias" in state_dict:
930
- state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
931
-
932
- def _temp_convert_self_to_deprecated_attention_blocks(self):
933
- deprecated_attention_block_modules = []
934
-
935
- def recursive_find_attn_block(module):
936
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
937
- deprecated_attention_block_modules.append(module)
938
-
939
- for sub_module in module.children():
940
- recursive_find_attn_block(sub_module)
941
-
942
- recursive_find_attn_block(self)
943
-
944
- for module in deprecated_attention_block_modules:
945
- module.query = module.to_q
946
- module.key = module.to_k
947
- module.value = module.to_v
948
- module.proj_attn = module.to_out[0]
949
-
950
- # We don't _have_ to delete the old attributes, but it's helpful to ensure
951
- # that _all_ the weights are loaded into the new attributes and we're not
952
- # making an incorrect assumption that this model should be converted when
953
- # it really shouldn't be.
954
- del module.to_q
955
- del module.to_k
956
- del module.to_v
957
- del module.to_out
958
-
959
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
960
- deprecated_attention_block_modules = []
961
-
962
- def recursive_find_attn_block(module):
963
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
964
- deprecated_attention_block_modules.append(module)
965
-
966
- for sub_module in module.children():
967
- recursive_find_attn_block(sub_module)
968
-
969
- recursive_find_attn_block(self)
970
-
971
- for module in deprecated_attention_block_modules:
972
- module.to_q = module.query
973
- module.to_k = module.key
974
- module.to_v = module.value
975
- module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
976
-
977
- del module.query
978
- del module.key
979
- del module.value
980
- del module.proj_attn
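For reference, the `ModelMixin` saving/loading API removed above follows the standard `diffusers` interface. Below is a minimal round-trip sketch, assuming the upstream `diffusers` and `safetensors` packages are still installed; the model id comes from the docstring example above, and `./my_unet` is a hypothetical output directory.

```py
import torch
from diffusers import UNet2DConditionModel

# Download a pretrained UNet; the `subfolder` layout matches the docstring example above.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Save with safetensors serialization; `variant` is inserted into the weights filename.
unet.save_pretrained("./my_unet", safe_serialization=True, variant="fp16")

# Reload the same weights from disk, matching the saved `variant`.
unet = UNet2DConditionModel.from_pretrained("./my_unet", variant="fp16", torch_dtype=torch.float16)
```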
 
4DoF/diffusers/models/prior_transformer.py DELETED
@@ -1,364 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Dict, Optional, Union
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
- from ..configuration_utils import ConfigMixin, register_to_config
9
- from ..utils import BaseOutput
10
- from .attention import BasicTransformerBlock
11
- from .attention_processor import AttentionProcessor, AttnProcessor
12
- from .embeddings import TimestepEmbedding, Timesteps
13
- from .modeling_utils import ModelMixin
14
-
15
-
16
- @dataclass
17
- class PriorTransformerOutput(BaseOutput):
18
- """
19
- The output of [`PriorTransformer`].
20
-
21
- Args:
22
- predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
23
- The predicted CLIP image embedding conditioned on the CLIP text embedding input.
24
- """
25
-
26
- predicted_image_embedding: torch.FloatTensor
27
-
28
-
29
- class PriorTransformer(ModelMixin, ConfigMixin):
30
- """
31
- A Prior Transformer model.
32
-
33
- Parameters:
34
- num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
35
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
36
- num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
37
- embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
38
- num_embeddings (`int`, *optional*, defaults to 77):
39
- The number of embeddings of the model input `hidden_states`
40
- additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
41
- projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
42
- additional_embeddings`.
43
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
44
- time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
45
- The activation function to use to create timestep embeddings.
46
- norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
47
- passing to Transformer blocks. Set it to `None` if normalization is not needed.
48
- embedding_proj_norm_type (`str`, *optional*, defaults to None):
49
- The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
50
- needed.
51
- encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
52
- The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
53
- `encoder_hidden_states` is `None`.
54
- added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
55
- Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
56
- product between the text embedding and image embedding as proposed in the unclip paper
57
- https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
58
- time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
59
- If None, will be set to `num_attention_heads * attention_head_dim`
60
- embedding_proj_dim (`int`, *optional*, default to None):
61
- The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
62
- clip_embed_dim (`int`, *optional*, default to None):
63
- The dimension of the output. If None, will be set to `embedding_dim`.
64
- """
65
-
66
- @register_to_config
67
- def __init__(
68
- self,
69
- num_attention_heads: int = 32,
70
- attention_head_dim: int = 64,
71
- num_layers: int = 20,
72
- embedding_dim: int = 768,
73
- num_embeddings=77,
74
- additional_embeddings=4,
75
- dropout: float = 0.0,
76
- time_embed_act_fn: str = "silu",
77
- norm_in_type: Optional[str] = None, # layer
78
- embedding_proj_norm_type: Optional[str] = None, # layer
79
- encoder_hid_proj_type: Optional[str] = "linear", # linear
80
- added_emb_type: Optional[str] = "prd", # prd
81
- time_embed_dim: Optional[int] = None,
82
- embedding_proj_dim: Optional[int] = None,
83
- clip_embed_dim: Optional[int] = None,
84
- ):
85
- super().__init__()
86
- self.num_attention_heads = num_attention_heads
87
- self.attention_head_dim = attention_head_dim
88
- inner_dim = num_attention_heads * attention_head_dim
89
- self.additional_embeddings = additional_embeddings
90
-
91
- time_embed_dim = time_embed_dim or inner_dim
92
- embedding_proj_dim = embedding_proj_dim or embedding_dim
93
- clip_embed_dim = clip_embed_dim or embedding_dim
94
-
95
- self.time_proj = Timesteps(inner_dim, True, 0)
96
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
97
-
98
- self.proj_in = nn.Linear(embedding_dim, inner_dim)
99
-
100
- if embedding_proj_norm_type is None:
101
- self.embedding_proj_norm = None
102
- elif embedding_proj_norm_type == "layer":
103
- self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
104
- else:
105
- raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
106
-
107
- self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
108
-
109
- if encoder_hid_proj_type is None:
110
- self.encoder_hidden_states_proj = None
111
- elif encoder_hid_proj_type == "linear":
112
- self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
113
- else:
114
- raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
115
-
116
- self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
117
-
118
- if added_emb_type == "prd":
119
- self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
120
- elif added_emb_type is None:
121
- self.prd_embedding = None
122
- else:
123
- raise ValueError(
124
- f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
125
- )
126
-
127
- self.transformer_blocks = nn.ModuleList(
128
- [
129
- BasicTransformerBlock(
130
- inner_dim,
131
- num_attention_heads,
132
- attention_head_dim,
133
- dropout=dropout,
134
- activation_fn="gelu",
135
- attention_bias=True,
136
- )
137
- for d in range(num_layers)
138
- ]
139
- )
140
-
141
- if norm_in_type == "layer":
142
- self.norm_in = nn.LayerNorm(inner_dim)
143
- elif norm_in_type is None:
144
- self.norm_in = None
145
- else:
146
- raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
147
-
148
- self.norm_out = nn.LayerNorm(inner_dim)
149
-
150
- self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
151
-
152
- causal_attention_mask = torch.full(
153
- [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
154
- )
155
- causal_attention_mask.triu_(1)
156
- causal_attention_mask = causal_attention_mask[None, ...]
157
- self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
158
-
159
- self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
160
- self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
161
-
162
- @property
163
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
164
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
165
- r"""
166
- Returns:
167
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
168
- indexed by its weight name.
169
- """
170
- # set recursively
171
- processors = {}
172
-
173
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
174
- if hasattr(module, "set_processor"):
175
- processors[f"{name}.processor"] = module.processor
176
-
177
- for sub_name, child in module.named_children():
178
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
179
-
180
- return processors
181
-
182
- for name, module in self.named_children():
183
- fn_recursive_add_processors(name, module, processors)
184
-
185
- return processors
186
-
187
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
188
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
189
- r"""
190
- Sets the attention processor to use to compute attention.
191
-
192
- Parameters:
193
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
194
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
195
- for **all** `Attention` layers.
196
-
197
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
198
- processor. This is strongly recommended when setting trainable attention processors.
199
-
200
- """
201
- count = len(self.attn_processors.keys())
202
-
203
- if isinstance(processor, dict) and len(processor) != count:
204
- raise ValueError(
205
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
206
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
207
- )
208
-
209
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
210
- if hasattr(module, "set_processor"):
211
- if not isinstance(processor, dict):
212
- module.set_processor(processor)
213
- else:
214
- module.set_processor(processor.pop(f"{name}.processor"))
215
-
216
- for sub_name, child in module.named_children():
217
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
218
-
219
- for name, module in self.named_children():
220
- fn_recursive_attn_processor(name, module, processor)
221
-
222
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
223
- def set_default_attn_processor(self):
224
- """
225
- Disables custom attention processors and sets the default attention implementation.
226
- """
227
- self.set_attn_processor(AttnProcessor())
228
-
229
- def forward(
230
- self,
231
- hidden_states,
232
- timestep: Union[torch.Tensor, float, int],
233
- proj_embedding: torch.FloatTensor,
234
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
235
- attention_mask: Optional[torch.BoolTensor] = None,
236
- return_dict: bool = True,
237
- ):
238
- """
239
- The [`PriorTransformer`] forward method.
240
-
241
- Args:
242
- hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
243
- The currently predicted image embeddings.
244
- timestep (`torch.LongTensor`):
245
- Current denoising step.
246
- proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
247
- Projected embedding vector the denoising process is conditioned on.
248
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
249
- Hidden states of the text embeddings the denoising process is conditioned on.
250
- attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
251
- Text mask for the text embeddings.
252
- return_dict (`bool`, *optional*, defaults to `True`):
253
- Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
254
- tuple.
255
-
256
- Returns:
257
- [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
258
- If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
259
- tuple is returned where the first element is the sample tensor.
260
- """
261
- batch_size = hidden_states.shape[0]
262
-
263
- timesteps = timestep
264
- if not torch.is_tensor(timesteps):
265
- timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
266
- elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
267
- timesteps = timesteps[None].to(hidden_states.device)
268
-
269
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
270
- timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
271
-
272
- timesteps_projected = self.time_proj(timesteps)
273
-
274
- # timesteps does not contain any weights and will always return f32 tensors
275
- # but time_embedding might be fp16, so we need to cast here.
276
- timesteps_projected = timesteps_projected.to(dtype=self.dtype)
277
- time_embeddings = self.time_embedding(timesteps_projected)
278
-
279
- if self.embedding_proj_norm is not None:
280
- proj_embedding = self.embedding_proj_norm(proj_embedding)
281
-
282
- proj_embeddings = self.embedding_proj(proj_embedding)
283
- if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
284
- encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
285
- elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
286
- raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
287
-
288
- hidden_states = self.proj_in(hidden_states)
289
-
290
- positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
291
-
292
- additional_embeds = []
293
- additional_embeddings_len = 0
294
-
295
- if encoder_hidden_states is not None:
296
- additional_embeds.append(encoder_hidden_states)
297
- additional_embeddings_len += encoder_hidden_states.shape[1]
298
-
299
- if len(proj_embeddings.shape) == 2:
300
- proj_embeddings = proj_embeddings[:, None, :]
301
-
302
- if len(hidden_states.shape) == 2:
303
- hidden_states = hidden_states[:, None, :]
304
-
305
- additional_embeds = additional_embeds + [
306
- proj_embeddings,
307
- time_embeddings[:, None, :],
308
- hidden_states,
309
- ]
310
-
311
- if self.prd_embedding is not None:
312
- prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
313
- additional_embeds.append(prd_embedding)
314
-
315
- hidden_states = torch.cat(
316
- additional_embeds,
317
- dim=1,
318
- )
319
-
320
- # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
321
- additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
322
- if positional_embeddings.shape[1] < hidden_states.shape[1]:
323
- positional_embeddings = F.pad(
324
- positional_embeddings,
325
- (
326
- 0,
327
- 0,
328
- additional_embeddings_len,
329
- self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
330
- ),
331
- value=0.0,
332
- )
333
-
334
- hidden_states = hidden_states + positional_embeddings
335
-
336
- if attention_mask is not None:
337
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
338
- attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
339
- attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
340
- attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
341
-
342
- if self.norm_in is not None:
343
- hidden_states = self.norm_in(hidden_states)
344
-
345
- for block in self.transformer_blocks:
346
- hidden_states = block(hidden_states, attention_mask=attention_mask)
347
-
348
- hidden_states = self.norm_out(hidden_states)
349
-
350
- if self.prd_embedding is not None:
351
- hidden_states = hidden_states[:, -1]
352
- else:
353
- hidden_states = hidden_states[:, additional_embeddings_len:]
354
-
355
- predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
356
-
357
- if not return_dict:
358
- return (predicted_image_embedding,)
359
-
360
- return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
361
-
362
- def post_process_latents(self, prior_latents):
363
- prior_latents = (prior_latents * self.clip_std) + self.clip_mean
364
- return prior_latents
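For reference, the forward contract documented in the deleted `PriorTransformer` docstring can be smoke-tested as below. This is a rough sketch against the upstream `diffusers` implementation; all sizes are arbitrary toy values chosen only so the shapes line up.

```py
import torch
from diffusers import PriorTransformer

# Toy configuration: inner_dim = num_attention_heads * attention_head_dim = 16.
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=2,
    embedding_dim=32,
    num_embeddings=4,
    additional_embeddings=4,
)

batch_size = 1
hidden_states = torch.randn(batch_size, 32)             # current image-embedding estimate
proj_embedding = torch.randn(batch_size, 32)            # conditioning embedding
encoder_hidden_states = torch.randn(batch_size, 4, 32)  # text hidden states (num_embeddings tokens)

out = prior(
    hidden_states,
    timestep=1,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
)
print(out.predicted_image_embedding.shape)  # torch.Size([1, 32]) (clip_embed_dim defaults to embedding_dim)
```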