yinshengming committed on
Commit
ab85cf9
1 Parent(s): 2fea44e
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +7 -0
  2. DragNUWA_net.py +297 -0
  3. app.py +386 -0
  4. assets/DragNUWA1.0/Figure1.gif +3 -0
  5. assets/DragNUWA1.0/Figure2.gif +3 -0
  6. assets/DragNUWA1.0/Figure3.gif +3 -0
  7. assets/DragNUWA1.5/Figure1.gif +3 -0
  8. assets/DragNUWA1.5/Figure2.gif +3 -0
  9. assets/DragNUWA1.5/Figure3.gif +3 -0
  10. assets/DragNUWA1.5/Figure4.gif +3 -0
  11. dragnuwa.md +81 -0
  12. dragnuwa/__init__.py +0 -0
  13. dragnuwa/__pycache__/__init__.cpython-38.pyc +0 -0
  14. dragnuwa/__pycache__/lora.cpython-38.pyc +0 -0
  15. dragnuwa/lora.py +412 -0
  16. dragnuwa/svd/__init__.py +0 -0
  17. dragnuwa/svd/__pycache__/__init__.cpython-38.pyc +0 -0
  18. dragnuwa/svd/__pycache__/util.cpython-38.pyc +0 -0
  19. dragnuwa/svd/models/__init__.py +0 -0
  20. dragnuwa/svd/models/__pycache__/__init__.cpython-38.pyc +0 -0
  21. dragnuwa/svd/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  22. dragnuwa/svd/models/autoencoder.py +615 -0
  23. dragnuwa/svd/modules/__init__.py +6 -0
  24. dragnuwa/svd/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  25. dragnuwa/svd/modules/__pycache__/attention.cpython-38.pyc +0 -0
  26. dragnuwa/svd/modules/__pycache__/ema.cpython-38.pyc +0 -0
  27. dragnuwa/svd/modules/__pycache__/video_attention.cpython-38.pyc +0 -0
  28. dragnuwa/svd/modules/attention.py +759 -0
  29. dragnuwa/svd/modules/autoencoding/__init__.py +0 -0
  30. dragnuwa/svd/modules/autoencoding/__pycache__/__init__.cpython-38.pyc +0 -0
  31. dragnuwa/svd/modules/autoencoding/__pycache__/temporal_ae.cpython-38.pyc +0 -0
  32. dragnuwa/svd/modules/autoencoding/losses/__init__.py +7 -0
  33. dragnuwa/svd/modules/autoencoding/losses/discriminator_loss.py +306 -0
  34. dragnuwa/svd/modules/autoencoding/losses/lpips.py +73 -0
  35. dragnuwa/svd/modules/autoencoding/lpips/__init__.py +0 -0
  36. dragnuwa/svd/modules/autoencoding/lpips/loss/.gitignore +1 -0
  37. dragnuwa/svd/modules/autoencoding/lpips/loss/LICENSE +23 -0
  38. dragnuwa/svd/modules/autoencoding/lpips/loss/__init__.py +0 -0
  39. dragnuwa/svd/modules/autoencoding/lpips/loss/lpips.py +147 -0
  40. dragnuwa/svd/modules/autoencoding/lpips/model/LICENSE +58 -0
  41. dragnuwa/svd/modules/autoencoding/lpips/model/__init__.py +0 -0
  42. dragnuwa/svd/modules/autoencoding/lpips/model/model.py +88 -0
  43. dragnuwa/svd/modules/autoencoding/lpips/util.py +128 -0
  44. dragnuwa/svd/modules/autoencoding/lpips/vqperceptual.py +17 -0
  45. dragnuwa/svd/modules/autoencoding/regularizers/__init__.py +31 -0
  46. dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/__init__.cpython-38.pyc +0 -0
  47. dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/base.cpython-38.pyc +0 -0
  48. dragnuwa/svd/modules/autoencoding/regularizers/base.py +40 -0
  49. dragnuwa/svd/modules/autoencoding/regularizers/quantize.py +487 -0
  50. dragnuwa/svd/modules/autoencoding/temporal_ae.py +349 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/DragNUWA1.0/Figure1.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/DragNUWA1.0/Figure2.gif filter=lfs diff=lfs merge=lfs -text
38
+ assets/DragNUWA1.0/Figure3.gif filter=lfs diff=lfs merge=lfs -text
39
+ assets/DragNUWA1.5/Figure1.gif filter=lfs diff=lfs merge=lfs -text
40
+ assets/DragNUWA1.5/Figure2.gif filter=lfs diff=lfs merge=lfs -text
41
+ assets/DragNUWA1.5/Figure3.gif filter=lfs diff=lfs merge=lfs -text
42
+ assets/DragNUWA1.5/Figure4.gif filter=lfs diff=lfs merge=lfs -text
DragNUWA_net.py ADDED
@@ -0,0 +1,297 @@
1
+ from utils import *
2
+
3
+ #### SVD
4
+ from dragnuwa.svd.modules.diffusionmodules.video_model_flow import VideoUNet_flow, VideoResBlock_Embed
5
+ from dragnuwa.svd.modules.diffusionmodules.denoiser import Denoiser
6
+ from dragnuwa.svd.modules.diffusionmodules.denoiser_scaling import VScalingWithEDMcNoise
7
+ from dragnuwa.svd.modules.encoders.modules import *
8
+ from dragnuwa.svd.models.autoencoder import AutoencodingEngine
9
+ from dragnuwa.svd.modules.diffusionmodules.wrappers import OpenAIWrapper
10
+ from dragnuwa.svd.modules.diffusionmodules.sampling import EulerEDMSampler
11
+
12
+ from dragnuwa.lora import inject_trainable_lora, inject_trainable_lora_extended, extract_lora_ups_down, _find_modules
13
+
14
+ def get_gaussian_kernel(kernel_size, sigma, channels):
15
+ print('parameters of gaussian kernel: kernel_size: {}, sigma: {}, channels: {}'.format(kernel_size, sigma, channels))
16
+ x_coord = torch.arange(kernel_size)
17
+ x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
18
+ y_grid = x_grid.t()
19
+ xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
20
+ mean = (kernel_size - 1)/2.
21
+ variance = sigma**2.
22
+
23
+ gaussian_kernel = torch.exp(
24
+ -torch.sum((xy_grid - mean)**2., dim=-1) /\
25
+ (2*variance)
26
+ )
27
+
28
+ gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
29
+ gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
30
+
31
+ gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,kernel_size=kernel_size, groups=channels, bias=False, padding=kernel_size//2)
32
+
33
+ gaussian_filter.weight.data = gaussian_kernel
34
+ gaussian_filter.weight.requires_grad = False
35
+
36
+ return gaussian_filter
37
+
38
+ def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, r=16):
39
+ injector = (
40
+ inject_trainable_lora if not is_extended
41
+ else
42
+ inject_trainable_lora_extended
43
+ )
44
+
45
+ params = None
46
+ negation = None
47
+
48
+ if use_lora:
49
+ REPLACE_MODULES = replace_modules
50
+ injector_args = {
51
+ "model": model,
52
+ "target_replace_module": REPLACE_MODULES,
53
+ "r": r
54
+ }
55
+ if not is_extended: injector_args['dropout_p'] = dropout
56
+
57
+ params, negation = injector(**injector_args)
58
+ for _up, _down in extract_lora_ups_down(
59
+ model,
60
+ target_replace_module=REPLACE_MODULES):
61
+
62
+ if all(x is not None for x in [_up, _down]):
63
+ print(f"Lora successfully injected into {model.__class__.__name__}.")
64
+
65
+ break
66
+
67
+ return params, negation
68
+
69
+ class Args:
70
+ ### basic
71
+ fps = 4
72
+ height = 320
73
+ width = 576
74
+
75
+ ### lora
76
+ unet_lora_rank = 32
77
+
78
+ ### gaussian filter parameters
79
+ kernel_size = 199
80
+ sigma = 20
81
+
82
+ # model
83
+ denoiser_config = {
84
+ 'scaling_config':{
85
+ 'target': 'dragnuwa.svd.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise',
86
+ }
87
+ }
88
+
89
+ network_config = {
90
+ 'adm_in_channels': 768, 'num_classes': 'sequential', 'use_checkpoint': True, 'in_channels': 8, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_head_channels': 64, 'use_linear_in_transformer': True, 'transformer_depth': 1, 'context_dim': 1024, 'spatial_transformer_attn_type': 'softmax-xformers', 'extra_ff_mix_layer': True, 'use_spatial_context': True, 'merge_strategy': 'learned_with_images', 'video_kernel_size': [3, 1, 1], 'flow_dim_scale': 1,
91
+ }
92
+
93
+ conditioner_emb_models = [
94
+ {'is_trainable': False,
95
+ 'input_key': 'cond_frames_without_noise', # crossattn
96
+ 'ucg_rate': 0.1,
97
+ 'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder',
98
+ 'params':{
99
+ 'n_cond_frames': 1,
100
+ 'n_copies': 1,
101
+ 'open_clip_embedding_config': {
102
+ 'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImageEmbedder',
103
+ 'params': {
104
+ 'freeze':True,
105
+ }
106
+ }
107
+ }
108
+ },
109
+ {'input_key': 'fps_id', # vector
110
+ 'is_trainable': False,
111
+ 'ucg_rate': 0.1,
112
+ 'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
113
+ 'params': {
114
+ 'outdim': 256,
115
+ }
116
+ },
117
+ {'input_key': 'motion_bucket_id', # vector
118
+ 'ucg_rate': 0.1,
119
+ 'is_trainable': False,
120
+ 'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
121
+ 'params': {
122
+ 'outdim': 256,
123
+ }
124
+ },
125
+ {'input_key': 'cond_frames', # concat
126
+ 'is_trainable': False,
127
+ 'ucg_rate': 0.1,
128
+ 'target': 'dragnuwa.svd.modules.encoders.modules.VideoPredictionEmbedderWithEncoder',
129
+ 'params': {
130
+ 'en_and_decode_n_samples_a_time': 1,
131
+ 'disable_encoder_autocast': True,
132
+ 'n_cond_frames': 1,
133
+ 'n_copies': 1,
134
+ 'is_ae': True,
135
+ 'encoder_config': {
136
+ 'target': 'dragnuwa.svd.models.autoencoder.AutoencoderKLModeOnly',
137
+ 'params': {
138
+ 'embed_dim': 4,
139
+ 'monitor': 'val/rec_loss',
140
+ 'ddconfig': {
141
+ 'attn_type': 'vanilla-xformers',
142
+ 'double_z': True,
143
+ 'z_channels': 4,
144
+ 'resolution': 256,
145
+ 'in_channels': 3,
146
+ 'out_ch': 3,
147
+ 'ch': 128,
148
+ 'ch_mult': [1, 2, 4, 4],
149
+ 'num_res_blocks': 2,
150
+ 'attn_resolutions': [],
151
+ 'dropout': 0.0,
152
+ },
153
+ 'lossconfig': {
154
+ 'target': 'torch.nn.Identity',
155
+ }
156
+ }
157
+ }
158
+ }
159
+ },
160
+ {'input_key': 'cond_aug', # vector
161
+ 'ucg_rate': 0.1,
162
+ 'is_trainable': False,
163
+ 'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
164
+ 'params': {
165
+ 'outdim': 256,
166
+ }
167
+ }
168
+ ]
169
+
170
+ first_stage_config = {
171
+ 'loss_config': {'target': 'torch.nn.Identity'},
172
+ 'regularizer_config': {'target': 'dragnuwa.svd.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'},
173
+ 'encoder_config':{'target': 'dragnuwa.svd.modules.diffusionmodules.model.Encoder',
174
+ 'params': { 'attn_type':'vanilla',
175
+ 'double_z': True,
176
+ 'z_channels': 4,
177
+ 'resolution': 256,
178
+ 'in_channels': 3,
179
+ 'out_ch': 3,
180
+ 'ch': 128,
181
+ 'ch_mult': [1, 2, 4, 4],
182
+ 'num_res_blocks': 2,
183
+ 'attn_resolutions': [],
184
+ 'dropout': 0.0,
185
+ }
186
+ },
187
+ 'decoder_config':{'target': 'dragnuwa.svd.modules.autoencoding.temporal_ae.VideoDecoder',
188
+ 'params': {'attn_type': 'vanilla',
189
+ 'double_z': True,
190
+ 'z_channels': 4,
191
+ 'resolution': 256,
192
+ 'in_channels': 3,
193
+ 'out_ch': 3,
194
+ 'ch': 128,
195
+ 'ch_mult': [1, 2, 4, 4],
196
+ 'num_res_blocks': 2,
197
+ 'attn_resolutions': [],
198
+ 'dropout': 0.0,
199
+ 'video_kernel_size': [3, 1, 1],
200
+ }
201
+ },
202
+ }
203
+
204
+ sampler_config = {
205
+ 'discretization_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.discretizer.EDMDiscretization',
206
+ 'params': {'sigma_max': 700.0,},
207
+ },
208
+ 'guider_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.guiders.LinearPredictionGuider',
209
+ 'params': {'max_scale':2.5,
210
+ 'min_scale':1.0,
211
+ 'num_frames':14},
212
+ },
213
+ 'num_steps': 25,
214
+ }
215
+
216
+ scale_factor = 0.18215
217
+ num_frames = 14
218
+
219
+ ### others
220
+ seed = 42
221
+ os.environ["PL_GLOBAL_SEED"] = str(seed)
222
+ random.seed(seed)
223
+ np.random.seed(seed)
224
+ torch.manual_seed(seed)
225
+ torch.cuda.manual_seed_all(seed)
226
+
227
+
228
+ args = Args()
229
+
230
+ def quick_freeze(model):
231
+ for name, param in model.named_parameters():
232
+ param.requires_grad = False
233
+ return model
234
+
235
+ class Net(nn.Module):
236
+ def __init__(self, args):
237
+ super(Net, self).__init__()
238
+ self.args = args
239
+ self.device = 'cpu'
240
+ ### unet
241
+ model = VideoUNet_flow(**args.network_config)
242
+ self.model = OpenAIWrapper(model)
243
+
244
+ ### denoiser and sampler
245
+ self.denoiser = Denoiser(**args.denoiser_config)
246
+ self.sampler = EulerEDMSampler(**args.sampler_config)
247
+
248
+ ### conditioner
249
+ self.conditioner = GeneralConditioner(args.conditioner_emb_models)
250
+
251
+ ### first stage model
252
+ self.first_stage_model = AutoencodingEngine(**args.first_stage_config).eval()
253
+
254
+ self.scale_factor = args.scale_factor
255
+ self.en_and_decode_n_samples_a_time = 1 # decode 1 frame each time to save GPU memory
256
+ self.num_frames = args.num_frames
257
+ self.guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=args.kernel_size, sigma=args.sigma, channels=2))
258
+
259
+ unet_lora_params, unet_negation = inject_lora(
260
+ True, self, ['OpenAIWrapper'], is_extended=False, r=args.unet_lora_rank
261
+ )
262
+
263
+ def to(self, *args, **kwargs):
264
+ model_converted = super().to(*args, **kwargs)
265
+ self.device = next(self.parameters()).device
266
+ self.sampler.device = self.device
267
+ for embedder in self.conditioner.embedders:
268
+ if hasattr(embedder, "device"):
269
+ embedder.device = self.device
270
+ return model_converted
271
+
272
+ def train(self, *args):
273
+ super().train(*args)
274
+ self.conditioner.eval()
275
+ self.first_stage_model.eval()
276
+
277
+ def apply_gaussian_filter_on_drag(self, drag):
278
+ b, l, h, w, c = drag.shape
279
+ drag = rearrange(drag, 'b l h w c -> (b l) c h w')
280
+ drag = self.guassian_filter(drag)
281
+ drag = rearrange(drag, '(b l) c h w -> b l h w c', b=b)
282
+ return drag
283
+
284
+ @torch.no_grad()
285
+ def decode_first_stage(self, z):
286
+ z = 1.0 / self.scale_factor * z
287
+ n_samples = self.en_and_decode_n_samples_a_time # 1
288
+ n_rounds = math.ceil(z.shape[0] / n_samples)
289
+ all_out = []
290
+ for n in range(n_rounds):
291
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
292
+ out = self.first_stage_model.decode(
293
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
294
+ )
295
+ all_out.append(out)
296
+ out = torch.cat(all_out, dim=0)
297
+ return out
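Note on the file above: `get_gaussian_kernel` builds a frozen depthwise convolution whose weights are an unnormalized isotropic Gaussian (kernel_size=199, sigma=20 in `Args`), and `apply_gaussian_filter_on_drag` uses it to diffuse sparse user drags into a dense 2-channel flow map. The sketch below is a minimal, hedged re-creation of that construction; the helper name `gaussian_blur` and the example tensor sizes are illustrative, not part of the commit, and only PyTorch is assumed.

```python
import torch
import torch.nn as nn

def gaussian_blur(kernel_size: int, sigma: float, channels: int) -> nn.Conv2d:
    # Same recipe as get_gaussian_kernel above: an (unnormalized) isotropic 2D
    # Gaussian, replicated per channel and wrapped in a frozen depthwise conv.
    coords = torch.arange(kernel_size).float()
    grid_y, grid_x = torch.meshgrid(coords, coords, indexing="ij")
    mean = (kernel_size - 1) / 2.0
    kernel = torch.exp(-((grid_x - mean) ** 2 + (grid_y - mean) ** 2) / (2 * sigma ** 2))
    kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
    conv = nn.Conv2d(channels, channels, kernel_size, groups=channels,
                     bias=False, padding=kernel_size // 2)
    conv.weight.data = kernel
    conv.weight.requires_grad = False
    return conv

# A sparse drag field: a single pixel carries a (dx, dy) displacement of (5, -3).
drag = torch.zeros(1, 2, 320, 576)
drag[0, 0, 100, 200] = 5.0
drag[0, 1, 100, 200] = -3.0
smoothed = gaussian_blur(kernel_size=199, sigma=20, channels=2)(drag)
print(smoothed.shape)  # torch.Size([1, 2, 320, 576])
```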
app.py ADDED
@@ -0,0 +1,386 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image, ImageFilter
5
+ import uuid
6
+ from scipy.interpolate import interp1d, PchipInterpolator
7
+ import torchvision
8
+ from utils import *
9
+
10
+ output_dir = "outputs"
11
+ ensure_dirname(output_dir)
12
+
13
+ def interpolate_trajectory(points, n_points):
14
+ x = [point[0] for point in points]
15
+ y = [point[1] for point in points]
16
+
17
+ t = np.linspace(0, 1, len(points))
18
+
19
+ # fx = interp1d(t, x, kind='cubic')
20
+ # fy = interp1d(t, y, kind='cubic')
21
+ fx = PchipInterpolator(t, x)
22
+ fy = PchipInterpolator(t, y)
23
+
24
+ new_t = np.linspace(0, 1, n_points)
25
+
26
+ new_x = fx(new_t)
27
+ new_y = fy(new_t)
28
+ new_points = list(zip(new_x, new_y))
29
+
30
+ return new_points
31
+
32
+ def visualize_drag_v2(background_image_path, splited_tracks, width, height):
33
+ trajectory_maps = []
34
+
35
+ background_image = Image.open(background_image_path).convert('RGBA')
36
+ background_image = background_image.resize((width, height))
37
+ w, h = background_image.size
38
+ transparent_background = np.array(background_image)
39
+ transparent_background[:, :, -1] = 128
40
+ transparent_background = Image.fromarray(transparent_background)
41
+
42
+ # Create a transparent layer with the same size as the background image
43
+ transparent_layer = np.zeros((h, w, 4))
44
+ for splited_track in splited_tracks:
45
+ if len(splited_track) > 1:
46
+ splited_track = interpolate_trajectory(splited_track, 16)
47
+ splited_track = splited_track[:16]
48
+ for i in range(len(splited_track)-1):
49
+ start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
50
+ end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
51
+ vx = end_point[0] - start_point[0]
52
+ vy = end_point[1] - start_point[1]
53
+ arrow_length = np.sqrt(vx**2 + vy**2)
54
+ if i == len(splited_track)-2:
55
+ cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
56
+ else:
57
+ cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
58
+ else:
59
+ cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 5, (255, 0, 0, 192), -1)
60
+
61
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
62
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
63
+ trajectory_maps.append(trajectory_map)
64
+ return trajectory_maps, transparent_layer
65
+
66
+ class Drag:
67
+ def __init__(self, device, model_path, cfg_path, height, width, model_length):
68
+ self.device = device
69
+ cf = import_filename(cfg_path)
70
+ Net, args = cf.Net, cf.args
71
+ drag_nuwa_net = Net(args)
72
+ state_dict = file2data(model_path, map_location='cpu')
73
+ adaptively_load_state_dict(drag_nuwa_net, state_dict)
74
+ drag_nuwa_net.eval()
75
+ drag_nuwa_net.to(device)
76
+ # drag_nuwa_net.half()
77
+ self.drag_nuwa_net = drag_nuwa_net
78
+ self.height = height
79
+ self.width = width
80
+ _, model_step, _ = split_filename(model_path)
81
+ self.ouput_prefix = f'{model_step}_{width}X{height}'
82
+ self.model_length = model_length
83
+
84
+ @torch.no_grad()
85
+ def forward_sample(self, input_drag, input_first_frame, motion_bucket_id, outputs=dict()):
86
+ device = self.device
87
+
88
+ b, l, h, w, c = input_drag.size()
89
+ drag = self.drag_nuwa_net.apply_gaussian_filter_on_drag(input_drag)
90
+ drag = torch.cat([torch.zeros_like(drag[:, 0]).unsqueeze(1), drag], dim=1) # pad the first frame with zero flow
91
+ drag = rearrange(drag, 'b l h w c -> b l c h w')
92
+
93
+ input_conditioner = dict()
94
+ input_conditioner['cond_frames_without_noise'] = input_first_frame
95
+ input_conditioner['cond_frames'] = (input_first_frame + 0.02 * torch.randn_like(input_first_frame))
96
+ input_conditioner['motion_bucket_id'] = torch.tensor([motion_bucket_id]).to(drag.device).repeat(b * (l+1))
97
+ input_conditioner['fps_id'] = torch.tensor([self.drag_nuwa_net.args.fps]).to(drag.device).repeat(b * (l+1))
98
+ input_conditioner['cond_aug'] = torch.tensor([0.02]).to(drag.device).repeat(b * (l+1))
99
+
100
+ input_conditioner_uc = {}
101
+ for key in input_conditioner.keys():
102
+ if key not in input_conditioner_uc and isinstance(input_conditioner[key], torch.Tensor):
103
+ input_conditioner_uc[key] = input_conditioner[key].clone()
104
+
105
+ c, uc = self.drag_nuwa_net.conditioner.get_unconditional_conditioning(
106
+ input_conditioner,
107
+ batch_uc=input_conditioner_uc,
108
+ force_uc_zero_embeddings=[
109
+ "cond_frames",
110
+ "cond_frames_without_noise",
111
+ ],
112
+ )
113
+
114
+ for k in ["crossattn", "concat"]:
115
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=self.drag_nuwa_net.num_frames)
116
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...")
117
+ c[k] = repeat(c[k], "b ... -> b t ...", t=self.drag_nuwa_net.num_frames)
118
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...")
119
+
120
+ H, W = input_conditioner['cond_frames_without_noise'].shape[2:]
121
+ shape = (self.drag_nuwa_net.num_frames, 4, H // 8, W // 8)
122
+ randn = torch.randn(shape).to(self.device)
123
+
124
+ additional_model_inputs = {}
125
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
126
+ 2, self.drag_nuwa_net.num_frames
127
+ ).to(self.device)
128
+ additional_model_inputs["num_video_frames"] = self.drag_nuwa_net.num_frames
129
+ additional_model_inputs["flow"] = drag.repeat(2, 1, 1, 1, 1) # c and uc
130
+
131
+ def denoiser(input, sigma, c):
132
+ return self.drag_nuwa_net.denoiser(self.drag_nuwa_net.model, input, sigma, c, **additional_model_inputs)
133
+
134
+ samples_z = self.drag_nuwa_net.sampler(denoiser, randn, cond=c, uc=uc)
135
+ samples = self.drag_nuwa_net.decode_first_stage(samples_z)
136
+
137
+ outputs['logits_imgs'] = rearrange(samples, '(b l) c h w -> b l c h w', b=b)
138
+ return outputs
139
+
140
+ def run(self, first_frame_path, tracking_points, inference_batch_size, motion_bucket_id):
141
+ original_width, original_height=576, 320
142
+
143
+ input_all_points = tracking_points.constructor_args['value']
144
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
145
+
146
+ input_drag = torch.zeros(self.model_length - 1, self.height, self.width, 2)
147
+ for splited_track in resized_all_points:
148
+ if len(splited_track) == 1: # stationary point
149
+ displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
150
+ splited_track = tuple([splited_track[0], displacement_point])
151
+ # interpolate the track
152
+ splited_track = interpolate_trajectory(splited_track, self.model_length)
153
+ splited_track = splited_track[:self.model_length]
154
+ if len(splited_track) < self.model_length:
155
+ splited_track = splited_track + [splited_track[-1]] * (self.model_length -len(splited_track))
156
+ for i in range(self.model_length - 1):
157
+ start_point = splited_track[i]
158
+ end_point = splited_track[i+1]
159
+ input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
160
+ input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
161
+
162
+ dir, base, ext = split_filename(first_frame_path)
163
+ id = base.split('_')[-1]
164
+
165
+ image_pil = image2pil(first_frame_path)
166
+ image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGB')
167
+
168
+ visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)
169
+
170
+ first_frames_transform = transforms.Compose([
171
+ lambda x: Image.fromarray(x),
172
+ transforms.ToTensor(),
173
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
174
+ ])
175
+
176
+ outputs = None
177
+ ouput_video_list = []
178
+ num_inference = 1
179
+ for i in tqdm(range(num_inference)):
180
+ if not outputs:
181
+ first_frames = image2arr(first_frame_path)
182
+ first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=inference_batch_size).to(self.device)
183
+ else:
184
+ first_frames = outputs['logits_imgs'][:, -1]
185
+
186
+ outputs = self.forward_sample(
187
+ repeat(input_drag[i*(self.model_length - 1):(i+1)*(self.model_length - 1)], 'l h w c -> b l h w c', b=inference_batch_size).to(self.device),
188
+ first_frames,
189
+ motion_bucket_id)
190
+ ouput_video_list.append(outputs['logits_imgs'])
191
+
192
+ for i in range(inference_batch_size):
193
+ ouput_tensor = [ouput_video_list[0][i]]
194
+ for j in range(num_inference - 1):
195
+ ouput_tensor.append(ouput_video_list[j+1][i][1:])
196
+ ouput_tensor = torch.cat(ouput_tensor, dim=0)
197
+ outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
198
+ data2file([transforms.ToPILImage('RGB')(utils.make_grid(e.to(torch.float32).cpu(), normalize=True, range=(-1, 1))) for e in ouput_tensor], outputs_path,
199
+ printable=False, duration=1 / 6, override=True)
200
+
201
+ return visualized_drag[0], outputs_path
202
+
203
+ with gr.Blocks() as demo:
204
+ gr.Markdown("""<h1 align="center">DragNUWA 1.5</h1><br>""")
205
+
206
+ gr.Markdown("""Official Gradio Demo for <a href='https://arxiv.org/abs/2308.08089'><b>DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory</b></a>.<br>
207
+ 🔥DragNUWA enables users to manipulate backgrounds or objects within images directly, and the model seamlessly translates these actions into **camera movements** or **object motions**, generating the corresponding video.<br>
208
+ 🔥DragNUWA 1.5 enables Stable Video Diffusion to animate an image according to a specific path.<br>""")
209
+
210
+ gr.Markdown("""## Usage: <br>
211
+ 1. Upload an image via the "Upload Image" button.<br>
212
+ 2. Draw some drags.<br>
213
+ 2.1. Click "Add Drag" when you want to add a control path.<br>
214
+ 2.2. You can click several points which forms a path.<br>
215
+ 2.3. Click "Delete last drag" to delete the whole lastest path.<br>
216
+ 2.4. Click "Delete last step" to delete the lastest clicked control point.<br>
217
+ 3. Animate the image according the path with a click on "Run" button. <br>""")
218
+
219
+ DragNUWA_net = Drag("cuda:0", 'models/drag_nuwa_svd.pth', 'DragNUWA_net.py', 320, 576, 14)
220
+ first_frame_path = gr.State()
221
+ tracking_points = gr.State([])
222
+
223
+ def reset_states(first_frame_path, tracking_points):
224
+ first_frame_path = gr.State()
225
+ tracking_points = gr.State([])
226
+ return first_frame_path, tracking_points
227
+
228
+ def preprocess_image(image):
229
+ image_pil = image2pil(image.name)
230
+ raw_w, raw_h = image_pil.size
231
+ resize_ratio = max(576/raw_w, 320/raw_h)
232
+ image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
233
+ image_pil = transforms.CenterCrop((320, 576))(image_pil.convert('RGB'))
234
+
235
+ first_frame_path = os.path.join(output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
236
+ image_pil.save(first_frame_path)
237
+
238
+ return first_frame_path, first_frame_path, gr.State([])
239
+
240
+ def add_drag(tracking_points):
241
+ tracking_points.constructor_args['value'].append([])
242
+ return tracking_points
243
+
244
+ def delete_last_drag(tracking_points, first_frame_path):
245
+ tracking_points.constructor_args['value'].pop()
246
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
247
+ w, h = transparent_background.size
248
+ transparent_layer = np.zeros((h, w, 4))
249
+ for track in tracking_points.constructor_args['value']:
250
+ if len(track) > 1:
251
+ for i in range(len(track)-1):
252
+ start_point = track[i]
253
+ end_point = track[i+1]
254
+ vx = end_point[0] - start_point[0]
255
+ vy = end_point[1] - start_point[1]
256
+ arrow_length = np.sqrt(vx**2 + vy**2)
257
+ if i == len(track)-2:
258
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
259
+ else:
260
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
261
+ else:
262
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
263
+
264
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
265
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
266
+ return tracking_points, trajectory_map
267
+
268
+ def delete_last_step(tracking_points, first_frame_path):
269
+ tracking_points.constructor_args['value'][-1].pop()
270
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
271
+ w, h = transparent_background.size
272
+ transparent_layer = np.zeros((h, w, 4))
273
+ for track in tracking_points.constructor_args['value']:
274
+ if len(track) > 1:
275
+ for i in range(len(track)-1):
276
+ start_point = track[i]
277
+ end_point = track[i+1]
278
+ vx = end_point[0] - start_point[0]
279
+ vy = end_point[1] - start_point[1]
280
+ arrow_length = np.sqrt(vx**2 + vy**2)
281
+ if i == len(track)-2:
282
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
283
+ else:
284
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
285
+ else:
286
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
287
+
288
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
289
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
290
+ return tracking_points, trajectory_map
291
+
292
+ def add_tracking_points(tracking_points, first_frame_path, evt: gr.SelectData): # SelectData is a subclass of EventData
293
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
294
+ tracking_points.constructor_args['value'][-1].append(evt.index)
295
+
296
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
297
+ w, h = transparent_background.size
298
+ transparent_layer = np.zeros((h, w, 4))
299
+ for track in tracking_points.constructor_args['value']:
300
+ if len(track) > 1:
301
+ for i in range(len(track)-1):
302
+ start_point = track[i]
303
+ end_point = track[i+1]
304
+ vx = end_point[0] - start_point[0]
305
+ vy = end_point[1] - start_point[1]
306
+ arrow_length = np.sqrt(vx**2 + vy**2)
307
+ if i == len(track)-2:
308
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
309
+ else:
310
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
311
+ else:
312
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
313
+
314
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
315
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
316
+ return tracking_points, trajectory_map
317
+
318
+ with gr.Row():
319
+ with gr.Column(scale=1):
320
+ image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
321
+ add_drag_button = gr.Button(value="Add Drag")
322
+ reset_button = gr.Button(value="Reset")
323
+ run_button = gr.Button(value="Run")
324
+ delete_last_drag_button = gr.Button(value="Delete last drag")
325
+ delete_last_step_button = gr.Button(value="Delete last step")
326
+
327
+ with gr.Column(scale=7):
328
+ with gr.Row():
329
+ with gr.Column(scale=6):
330
+ input_image = gr.Image(label=None,
331
+ interactive=True,
332
+ height=320,
333
+ width=576,)
334
+ with gr.Column(scale=6):
335
+ output_image = gr.Image(label=None,
336
+ height=320,
337
+ width=576,)
338
+
339
+ with gr.Row():
340
+ with gr.Column(scale=1):
341
+ inference_batch_size = gr.Slider(label='Inference Batch Size',
342
+ minimum=1,
343
+ maximum=1,
344
+ step=1,
345
+ value=1)
346
+
347
+ motion_bucket_id = gr.Slider(label='Motion Bucket',
348
+ minimum=1,
349
+ maximum=100,
350
+ step=1,
351
+ value=4)
352
+
353
+ with gr.Column(scale=5):
354
+ output_video = gr.Image(label="Output Video",
355
+ height=320,
356
+ width=576,)
357
+
358
+ with gr.Row():
359
+ gr.Markdown("""
360
+ ## Citation
361
+ ```bibtex
362
+ @article{yin2023dragnuwa,
363
+ title={Dragnuwa: Fine-grained control in video generation by integrating text, image, and trajectory},
364
+ author={Yin, Shengming and Wu, Chenfei and Liang, Jian and Shi, Jie and Li, Houqiang and Ming, Gong and Duan, Nan},
365
+ journal={arXiv preprint arXiv:2308.08089},
366
+ year={2023}
367
+ }
368
+ ```
369
+ """)
370
+
371
+
372
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
373
+
374
+ add_drag_button.click(add_drag, tracking_points, tracking_points)
375
+
376
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path], [tracking_points, input_image])
377
+
378
+ delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path], [tracking_points, input_image])
379
+
380
+ reset_button.click(reset_states, [first_frame_path, tracking_points], [first_frame_path, tracking_points])
381
+
382
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path], [tracking_points, input_image])
383
+
384
+ run_button.click(DragNUWA_net.run, [first_frame_path, tracking_points, inference_batch_size, motion_bucket_id], [output_image, output_video])
385
+
386
+ demo.launch(server_name="0.0.0.0", debug=True)
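A note on the trajectory handling in `app.py`: user clicks are resampled with `scipy.interpolate.PchipInterpolator` so that every generated frame gets one control point, and `Drag.run` calls this with `model_length` (14 here). The sketch below reproduces `interpolate_trajectory` and its call pattern; the three example clicks are an illustrative assumption, and only NumPy and SciPy are required.

```python
import numpy as np
from scipy.interpolate import PchipInterpolator

def interpolate_trajectory(points, n_points):
    # Monotone cubic (PCHIP) interpolation of the clicked (x, y) points,
    # parameterized uniformly on t in [0, 1].
    x = [p[0] for p in points]
    y = [p[1] for p in points]
    t = np.linspace(0, 1, len(points))
    fx, fy = PchipInterpolator(t, x), PchipInterpolator(t, y)
    new_t = np.linspace(0, 1, n_points)
    return list(zip(fx(new_t), fy(new_t)))

# Three clicks on the 576x320 canvas become 14 per-frame control points.
clicks = [(100, 160), (250, 120), (420, 200)]
track = interpolate_trajectory(clicks, 14)
print(len(track))  # 14; track[0] ≈ (100.0, 160.0), track[-1] ≈ (420.0, 200.0)
```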
assets/DragNUWA1.0/Figure1.gif ADDED

Git LFS Details

  • SHA256: 523d732979aa1ab6f8d80902e0d33351bbfd0a9d7d1cba6ecf7072ecf598b9b2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.27 MB
assets/DragNUWA1.0/Figure2.gif ADDED

Git LFS Details

  • SHA256: 9cbc27131e06230799327bea6af235accb28c1a29cb3f146460751ab37b0910d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
assets/DragNUWA1.0/Figure3.gif ADDED

Git LFS Details

  • SHA256: b636d8818e849e6d05226b9c9379e58e78141c9cb8c0d00d64947e7d4005143d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
assets/DragNUWA1.5/Figure1.gif ADDED

Git LFS Details

  • SHA256: f688a52aa22e3956861d9eae78f5936bb7206a2d4e3aacf054ef703428aaf63b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
assets/DragNUWA1.5/Figure2.gif ADDED

Git LFS Details

  • SHA256: f04c0916ea138c61c013176aef8732bbbaebe4d6c64e77ad53ab15a4b5cc2c42
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
assets/DragNUWA1.5/Figure3.gif ADDED

Git LFS Details

  • SHA256: b34954df08ed95028a60c161600b5c5869f57f164f6e3890466902e497d2fd9d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
assets/DragNUWA1.5/Figure4.gif ADDED

Git LFS Details

  • SHA256: a9bb44da6df81dd5a7521af0e315cbcd408f5e3079988f57d2f068ed879b0a5b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
dragnuwa.md ADDED
@@ -0,0 +1,81 @@
1
+ # DragNUWA
2
+
3
+ **DragNUWA** enables users to manipulate backgrounds or objects within images directly, and the model seamlessly translates these actions into **camera movements** or **object motions**, generating the corresponding video.
4
+
5
+ See our paper: [DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory](https://arxiv.org/abs/2308.08089)
6
+
7
+ <a src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" href="TOBEDONE">
8
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" alt="Open in Spaces">
9
+ </a>
10
+ <a src="https://colab.research.google.com/assets/colab-badge.svg" href="TOBEDONE">
11
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
12
+ </a>
13
+
14
+ ### DragNUWA 1.5 (Updated on Jan 8, 2024)
15
+
16
+ **DragNUWA 1.5** enables Stable Video Diffusion to animate an image according to a specific path.
17
+
18
+ <p align="center">
19
+ <img src="assets/DragNUWA1.5/Figure1.gif" width="90%">
20
+ </p>
21
+
22
+ <p align="center">
23
+ <img src="assets/DragNUWA1.5/Figure2.gif" width="90%">
24
+ </p>
25
+ <p align="center">
26
+ <img src="assets/DragNUWA1.5/Figure3.gif" width="90%">
27
+ </p>
28
+ <p align="center">
29
+ <img src="assets/DragNUWA1.5/Figure4.gif" width="90%">
30
+ </p>
31
+
32
+ ### DragNUWA 1.0 (Original Paper)
33
+ [**DragNUWA 1.0**](https://arxiv.org/abs/2308.08089) utilizes text, images, and trajectory as three essential control factors to facilitate highly controllable video generation from semantic, spatial, and temporal aspects.
34
+
35
+ <p align="center">
36
+ <img src="assets/DragNUWA1.0/Figure1.gif" width="90%">
37
+ </p>
38
+ <p align="center">
39
+ <img src="assets/DragNUWA1.0/Figure2.gif" width="100%">
40
+ </p>
41
+ <p align="center">
42
+ <img src="assets/DragNUWA1.0/Figure3.gif" width="100%">
43
+ </p>
44
+
45
+ ## Getting Started
46
+
47
+ ### Setting Environment
48
+ ```Shell
49
+ git clone -b svd https://github.com/ProjectNUWA/DragNUWA.git
50
+ cd DragNUWA
51
+
52
+ conda create -n DragNUWA python=3.8
53
+ conda activate DragNUWA
54
+ pip install -r environment.txt
55
+ ```
56
+
57
+ ### Download Pretrained Weights
58
+ Download the [Pretrained Weights](https://drive.google.com/file/d/1Z4JOley0SJCb35kFF4PCc6N6P1ftfX4i/view) to the `models/` directory, or simply run `bash models/Download.sh`.
59
+
60
+ ### Drag and Animate!
61
+ ```Shell
62
+ python DragNUWA_demo.py
63
+ ```
64
+ This will launch a Gradio demo where you can drag an image and animate it!
65
+
66
+ ### Acknowledgement
67
+ We appreciate the open-source contributions of the following projects:
68
+ [Stable Video Diffusion](https://github.com/Stability-AI/generative-models) &#8194;
69
+ [Hugging Face](https://github.com/huggingface) &#8194;
70
+ [UniMatch](https://github.com/autonomousvision/unimatch)&#8194;
71
+
72
+ ### Citation
73
+ ```bibtex
74
+ @article{yin2023dragnuwa,
75
+ title={Dragnuwa: Fine-grained control in video generation by integrating text, image, and trajectory},
76
+ author={Yin, Shengming and Wu, Chenfei and Liang, Jian and Shi, Jie and Li, Houqiang and Ming, Gong and Duan, Nan},
77
+ journal={arXiv preprint arXiv:2308.08089},
78
+ year={2023}
79
+ }
80
+ ```
81
+
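To make the "drag and animate" step above concrete: in the accompanying `app.py`, each interpolated track is rasterized into a sparse displacement tensor of shape `(model_length - 1, height, width, 2)` that conditions the video model. The sketch below mirrors that loop from `Drag.run`; the straight-line example track and the sizes are illustrative assumptions, and only PyTorch is required.

```python
import torch

model_length, height, width = 14, 320, 576
track = [(100 + 20 * i, 160) for i in range(model_length)]  # one interpolated path

input_drag = torch.zeros(model_length - 1, height, width, 2)
for i in range(model_length - 1):
    (x0, y0), (x1, y1) = track[i], track[i + 1]
    # Frame i stores the (dx, dy) step at the track's current pixel location.
    input_drag[i, int(y0), int(x0), 0] = x1 - x0
    input_drag[i, int(y0), int(x0), 1] = y1 - y0

print(input_drag.shape, input_drag.abs().sum())  # torch.Size([13, 320, 576, 2]) tensor(260.)
```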
dragnuwa/__init__.py ADDED
File without changes
dragnuwa/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (131 Bytes).
 
dragnuwa/__pycache__/lora.cpython-38.pyc ADDED
Binary file (9.18 kB).
 
dragnuwa/lora.py ADDED
@@ -0,0 +1,412 @@
1
+ from utils import *
2
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
3
+
4
+
5
+ class LoraInjectedLinear(nn.Module):
6
+ def __init__(
7
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
8
+ ):
9
+ super().__init__()
10
+
11
+ if r > min(in_features, out_features):
12
+ #raise ValueError(
13
+ # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
14
+ #)
15
+ print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
16
+ r = min(in_features, out_features)
17
+
18
+ self.r = r
19
+ self.linear = nn.Linear(in_features, out_features, bias)
20
+ self.lora_down = nn.Linear(in_features, r, bias=False)
21
+ self.dropout = nn.Dropout(dropout_p)
22
+ self.lora_up = nn.Linear(r, out_features, bias=False)
23
+ self.scale = scale
24
+ self.selector = nn.Identity()
25
+
26
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
27
+ nn.init.zeros_(self.lora_up.weight)
28
+
29
+ def forward(self, input):
30
+ return (
31
+ self.linear(input)
32
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
33
+ * self.scale
34
+ )
35
+
36
+ def realize_as_lora(self):
37
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
38
+
39
+ def set_selector_from_diag(self, diag: torch.Tensor):
40
+ # diag is a 1D tensor of size (r,)
41
+ assert diag.shape == (self.r,)
42
+ self.selector = nn.Linear(self.r, self.r, bias=False)
43
+ self.selector.weight.data = torch.diag(diag)
44
+ self.selector.weight.data = self.selector.weight.data.to(
45
+ self.lora_up.weight.device
46
+ ).to(self.lora_up.weight.dtype)
47
+
48
+ class LoraInjectedConv2d(nn.Module):
49
+ def __init__(
50
+ self,
51
+ in_channels: int,
52
+ out_channels: int,
53
+ kernel_size,
54
+ stride=1,
55
+ padding=0,
56
+ dilation=1,
57
+ groups: int = 1,
58
+ bias: bool = True,
59
+ r: int = 4,
60
+ dropout_p: float = 0.1,
61
+ scale: float = 1.0,
62
+ ):
63
+ super().__init__()
64
+ if r > min(in_channels, out_channels):
65
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
66
+ r = min(in_channels, out_channels)
67
+
68
+ self.r = r
69
+ self.conv = nn.Conv2d(
70
+ in_channels=in_channels,
71
+ out_channels=out_channels,
72
+ kernel_size=kernel_size,
73
+ stride=stride,
74
+ padding=padding,
75
+ dilation=dilation,
76
+ groups=groups,
77
+ bias=bias,
78
+ )
79
+
80
+ self.lora_down = nn.Conv2d(
81
+ in_channels=in_channels,
82
+ out_channels=r,
83
+ kernel_size=kernel_size,
84
+ stride=stride,
85
+ padding=padding,
86
+ dilation=dilation,
87
+ groups=groups,
88
+ bias=False,
89
+ )
90
+ self.dropout = nn.Dropout(dropout_p)
91
+ self.lora_up = nn.Conv2d(
92
+ in_channels=r,
93
+ out_channels=out_channels,
94
+ kernel_size=1,
95
+ stride=1,
96
+ padding=0,
97
+ bias=False,
98
+ )
99
+ self.selector = nn.Identity()
100
+ self.scale = scale
101
+
102
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
103
+ nn.init.zeros_(self.lora_up.weight)
104
+
105
+ def forward(self, input):
106
+ return (
107
+ self.conv(input)
108
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
109
+ * self.scale
110
+ )
111
+
112
+ def realize_as_lora(self):
113
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
114
+
115
+ def set_selector_from_diag(self, diag: torch.Tensor):
116
+ # diag is a 1D tensor of size (r,)
117
+ assert diag.shape == (self.r,)
118
+ self.selector = nn.Conv2d(
119
+ in_channels=self.r,
120
+ out_channels=self.r,
121
+ kernel_size=1,
122
+ stride=1,
123
+ padding=0,
124
+ bias=False,
125
+ )
126
+ self.selector.weight.data = torch.diag(diag)
127
+
128
+ # same device + dtype as lora_up
129
+ self.selector.weight.data = self.selector.weight.data.to(
130
+ self.lora_up.weight.device
131
+ ).to(self.lora_up.weight.dtype)
132
+
133
+ class LoraInjectedConv3d(nn.Module):
134
+ def __init__(
135
+ self,
136
+ in_channels: int,
137
+ out_channels: int,
138
+ kernel_size: (3, 1, 1),
139
+ padding: (1, 0, 0),
140
+ bias: bool = False,
141
+ r: int = 4,
142
+ dropout_p: float = 0,
143
+ scale: float = 1.0,
144
+ ):
145
+ super().__init__()
146
+ if r > min(in_channels, out_channels):
147
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
148
+ r = min(in_channels, out_channels)
149
+
150
+ self.r = r
151
+ self.kernel_size = kernel_size
152
+ self.padding = padding
153
+ self.conv = nn.Conv3d(
154
+ in_channels=in_channels,
155
+ out_channels=out_channels,
156
+ kernel_size=kernel_size,
157
+ padding=padding,
158
+ )
159
+
160
+ self.lora_down = nn.Conv3d(
161
+ in_channels=in_channels,
162
+ out_channels=r,
163
+ kernel_size=kernel_size,
164
+ bias=False,
165
+ padding=padding
166
+ )
167
+ self.dropout = nn.Dropout(dropout_p)
168
+ self.lora_up = nn.Conv3d(
169
+ in_channels=r,
170
+ out_channels=out_channels,
171
+ kernel_size=1,
172
+ stride=1,
173
+ padding=0,
174
+ bias=False,
175
+ )
176
+ self.selector = nn.Identity()
177
+ self.scale = scale
178
+
179
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
180
+ nn.init.zeros_(self.lora_up.weight)
181
+
182
+ def forward(self, input):
183
+ return (
184
+ self.conv(input)
185
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
186
+ * self.scale
187
+ )
188
+
189
+ def realize_as_lora(self):
190
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
191
+
192
+ def set_selector_from_diag(self, diag: torch.Tensor):
193
+ # diag is a 1D tensor of size (r,)
194
+ assert diag.shape == (self.r,)
195
+ self.selector = nn.Conv3d(
196
+ in_channels=self.r,
197
+ out_channels=self.r,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0,
201
+ bias=False,
202
+ )
203
+ self.selector.weight.data = torch.diag(diag)
204
+
205
+ # same device + dtype as lora_up
206
+ self.selector.weight.data = self.selector.weight.data.to(
207
+ self.lora_up.weight.device
208
+ ).to(self.lora_up.weight.dtype)
209
+
210
+ def _find_modules(
211
+ model,
212
+ ancestor_class: Optional[Set[str]] = None,
213
+ search_class: List[Type[nn.Module]] = [nn.Linear],
214
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [
215
+ LoraInjectedLinear,
216
+ LoraInjectedConv2d,
217
+ LoraInjectedConv3d
218
+ ],
219
+ ):
220
+ """
221
+ Find all modules of a certain class (or union of classes) that are direct or
222
+ indirect descendants of other modules of a certain class (or union of classes).
223
+
224
+ Returns all matching modules, along with the parent of those modules and the
225
+ names they are referenced by.
226
+ """
227
+
228
+ # Get the targets we should replace all linears under
229
+ if ancestor_class is not None:
230
+ ancestors = (
231
+ module
232
+ for module in model.modules()
233
+ if module.__class__.__name__ in ancestor_class
234
+ )
235
+ else:
236
+ # this, in case you want to naively iterate over all modules.
237
+ ancestors = [module for module in model.modules()]
238
+
239
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
240
+ for ancestor in ancestors:
241
+ for fullname, module in ancestor.named_modules():
242
+ if any([isinstance(module, _class) for _class in search_class]):
243
+ # Find the direct parent if this is a descendant, not a child, of target
244
+ *path, name = fullname.split(".")
245
+ parent = ancestor
246
+ while path:
247
+ parent = parent.get_submodule(path.pop(0))
248
+ # Skip this linear if it's a child of a LoraInjectedLinear
249
+ if exclude_children_of and any(
250
+ [isinstance(parent, _class) for _class in exclude_children_of]
251
+ ):
252
+ continue
253
+ # Otherwise, yield it
254
+ yield parent, name, module
255
+
256
+
257
+ def inject_trainable_lora(
258
+ model: nn.Module,
259
+ target_replace_module,
260
+ r: int = 4,
261
+ loras=None, # path to lora .pt
262
+ verbose: bool = False,
263
+ dropout_p: float = 0.0,
264
+ scale: float = 1.0,
265
+ ):
266
+ """
267
+ inject lora into model, and returns lora parameter groups.
268
+ """
269
+
270
+ require_grad_params = []
271
+ names = []
272
+
273
+ if loras != None:
274
+ loras = torch.load(loras)
275
+
276
+ for _module, name, _child_module in _find_modules(
277
+ model, target_replace_module, search_class=[nn.Linear]
278
+ ):
279
+ weight = _child_module.weight
280
+ bias = _child_module.bias
281
+ if verbose:
282
+ print("LoRA Injection : injecting lora into ", name)
283
+ print("LoRA Injection : weight shape", weight.shape)
284
+ _tmp = LoraInjectedLinear(
285
+ _child_module.in_features,
286
+ _child_module.out_features,
287
+ _child_module.bias is not None,
288
+ r=r,
289
+ dropout_p=dropout_p,
290
+ scale=scale,
291
+ )
292
+ _tmp.linear.weight = weight
293
+ if bias is not None:
294
+ _tmp.linear.bias = bias
295
+
296
+ # switch the module
297
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
298
+ _module._modules[name] = _tmp
299
+
300
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
301
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
302
+
303
+ if loras != None:
304
+ _module._modules[name].lora_up.weight = loras.pop(0)
305
+ _module._modules[name].lora_down.weight = loras.pop(0)
306
+
307
+ _module._modules[name].lora_up.weight.requires_grad = True
308
+ _module._modules[name].lora_down.weight.requires_grad = True
309
+ names.append(name)
310
+
311
+ return require_grad_params, names
312
+
313
+
314
+ def inject_trainable_lora_extended(
315
+ model: nn.Module,
316
+ target_replace_module,
317
+ r: int = 4,
318
+ loras=None, # path to lora .pt
319
+ ):
320
+ """
321
+ inject lora into model, and returns lora parameter groups.
322
+ """
323
+
324
+ require_grad_params = []
325
+ names = []
326
+
327
+ if loras != None:
328
+ loras = torch.load(loras)
329
+
330
+ for _module, name, _child_module in _find_modules(
331
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
332
+ ):
333
+ if _child_module.__class__ == nn.Linear:
334
+ weight = _child_module.weight
335
+ bias = _child_module.bias
336
+ _tmp = LoraInjectedLinear(
337
+ _child_module.in_features,
338
+ _child_module.out_features,
339
+ _child_module.bias is not None,
340
+ r=r,
341
+ )
342
+ _tmp.linear.weight = weight
343
+ if bias is not None:
344
+ _tmp.linear.bias = bias
345
+ elif _child_module.__class__ == nn.Conv2d:
346
+ weight = _child_module.weight
347
+ bias = _child_module.bias
348
+ _tmp = LoraInjectedConv2d(
349
+ _child_module.in_channels,
350
+ _child_module.out_channels,
351
+ _child_module.kernel_size,
352
+ _child_module.stride,
353
+ _child_module.padding,
354
+ _child_module.dilation,
355
+ _child_module.groups,
356
+ _child_module.bias is not None,
357
+ r=r,
358
+ )
359
+
360
+ _tmp.conv.weight = weight
361
+ if bias is not None:
362
+ _tmp.conv.bias = bias
363
+
364
+ elif _child_module.__class__ == nn.Conv3d:
365
+ weight = _child_module.weight
366
+ bias = _child_module.bias
367
+ _tmp = LoraInjectedConv3d(
368
+ _child_module.in_channels,
369
+ _child_module.out_channels,
370
+ bias=_child_module.bias is not None,
371
+ kernel_size=_child_module.kernel_size,
372
+ padding=_child_module.padding,
373
+ r=r,
374
+ )
375
+
376
+ _tmp.conv.weight = weight
377
+ if bias is not None:
378
+ _tmp.conv.bias = bias
379
+ # switch the module
380
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
381
+ if bias is not None:
382
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
383
+
384
+ _module._modules[name] = _tmp
385
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
386
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
387
+
388
+ if loras != None:
389
+ _module._modules[name].lora_up.weight = loras.pop(0)
390
+ _module._modules[name].lora_down.weight = loras.pop(0)
391
+
392
+ _module._modules[name].lora_up.weight.requires_grad = True
393
+ _module._modules[name].lora_down.weight.requires_grad = True
394
+ names.append(name)
395
+
396
+ return require_grad_params, names
397
+
398
+ def extract_lora_ups_down(model, target_replace_module):
399
+
400
+ loras = []
401
+
402
+ for _m, _n, _child_module in _find_modules(
403
+ model,
404
+ target_replace_module,
405
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
406
+ ):
407
+ loras.append((_child_module.lora_up, _child_module.lora_down))
408
+
409
+ if len(loras) == 0:
410
+ raise ValueError("No lora injected.")
411
+
412
+ return loras
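For context on `lora.py`: `inject_trainable_lora` (and its extended variant) swaps the `nn.Linear`/conv children found under the target module classes for the `LoraInjected*` wrappers, which add a trainable low-rank correction on top of a frozen base layer. A minimal sketch of the arithmetic a `LoraInjectedLinear` performs is below; the layer sizes and rank are illustrative, and only PyTorch is assumed.

```python
import torch
import torch.nn as nn

in_f, out_f, r, scale = 64, 64, 4, 1.0
base = nn.Linear(in_f, out_f)                 # frozen pretrained weight in practice
lora_down = nn.Linear(in_f, r, bias=False)    # projects into the low-rank bottleneck
lora_up = nn.Linear(r, out_f, bias=False)     # projects back to the output space
nn.init.normal_(lora_down.weight, std=1 / r)
nn.init.zeros_(lora_up.weight)                # zero init: layer starts equal to the base

x = torch.randn(2, in_f)
y = base(x) + lora_up(lora_down(x)) * scale   # what LoraInjectedLinear.forward computes
assert torch.allclose(y, base(x))             # holds only at initialization
```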
dragnuwa/svd/__init__.py ADDED
File without changes
dragnuwa/svd/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (135 Bytes).
 
dragnuwa/svd/__pycache__/util.cpython-38.pyc ADDED
Binary file (9.38 kB).
 
dragnuwa/svd/models/__init__.py ADDED
File without changes
dragnuwa/svd/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (142 Bytes).
 
dragnuwa/svd/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (19.2 kB).
 
dragnuwa/svd/models/autoencoder.py ADDED
@@ -0,0 +1,615 @@
1
+ import logging
2
+ import math
3
+ import re
4
+ from abc import abstractmethod
5
+ from contextlib import contextmanager
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from packaging import version
13
+
14
+ from ..modules.autoencoding.regularizers import AbstractRegularizer
15
+ from ..modules.ema import LitEma
16
+ from ..util import (default, get_nested_attribute, get_obj_from_str,
17
+ instantiate_from_config)
18
+
19
+ logpy = logging.getLogger(__name__)
20
+
21
+
22
+ class AbstractAutoencoder(pl.LightningModule):
23
+ """
24
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
25
+ unCLIP models, etc. Hence, it is fairly general, and specific features
26
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ ema_decay: Union[None, float] = None,
32
+ monitor: Union[None, str] = None,
33
+ input_key: str = "jpg",
34
+ ):
35
+ super().__init__()
36
+
37
+ self.input_key = input_key
38
+ self.use_ema = ema_decay is not None
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ if self.use_ema:
43
+ self.model_ema = LitEma(self, decay=ema_decay)
44
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
50
+ if ckpt is None:
51
+ return
52
+ if isinstance(ckpt, str):
53
+ ckpt = {
54
+ "target": "dragnuwa.svd.modules.checkpoint.CheckpointEngine",
55
+ "params": {"ckpt_path": ckpt},
56
+ }
57
+ engine = instantiate_from_config(ckpt)
58
+ engine(self)
59
+
60
+ @abstractmethod
61
+ def get_input(self, batch) -> Any:
62
+ raise NotImplementedError()
63
+
64
+ def on_train_batch_end(self, *args, **kwargs):
65
+ # for EMA computation
66
+ if self.use_ema:
67
+ self.model_ema(self)
68
+
69
+ @contextmanager
70
+ def ema_scope(self, context=None):
71
+ if self.use_ema:
72
+ self.model_ema.store(self.parameters())
73
+ self.model_ema.copy_to(self)
74
+ if context is not None:
75
+ logpy.info(f"{context}: Switched to EMA weights")
76
+ try:
77
+ yield None
78
+ finally:
79
+ if self.use_ema:
80
+ self.model_ema.restore(self.parameters())
81
+ if context is not None:
82
+ logpy.info(f"{context}: Restored training weights")
83
+
84
+ @abstractmethod
85
+ def encode(self, *args, **kwargs) -> torch.Tensor:
86
+ raise NotImplementedError("encode()-method of abstract base class called")
87
+
88
+ @abstractmethod
89
+ def decode(self, *args, **kwargs) -> torch.Tensor:
90
+ raise NotImplementedError("decode()-method of abstract base class called")
91
+
92
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
93
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
94
+ return get_obj_from_str(cfg["target"])(
95
+ params, lr=lr, **cfg.get("params", dict())
96
+ )
97
+
98
+ def configure_optimizers(self) -> Any:
99
+ raise NotImplementedError()
100
+
101
+
102
+ class AutoencodingEngine(AbstractAutoencoder):
103
+ """
104
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
105
+ (we also restore them explicitly as special cases for legacy reasons).
106
+ Regularizations such as KL or VQ are moved to the regularizer class.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ *args,
112
+ encoder_config: Dict,
113
+ decoder_config: Dict,
114
+ loss_config: Dict,
115
+ regularizer_config: Dict,
116
+ optimizer_config: Union[Dict, None] = None,
117
+ lr_g_factor: float = 1.0,
118
+ trainable_ae_params: Optional[List[List[str]]] = None,
119
+ ae_optimizer_args: Optional[List[dict]] = None,
120
+ trainable_disc_params: Optional[List[List[str]]] = None,
121
+ disc_optimizer_args: Optional[List[dict]] = None,
122
+ disc_start_iter: int = 0,
123
+ diff_boost_factor: float = 3.0,
124
+ ckpt_engine: Union[None, str, dict] = None,
125
+ ckpt_path: Optional[str] = None,
126
+ additional_decode_keys: Optional[List[str]] = None,
127
+ **kwargs,
128
+ ):
129
+ super().__init__(*args, **kwargs)
130
+ self.automatic_optimization = False # pytorch lightning
131
+
132
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
133
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
134
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
135
+ self.regularization: AbstractRegularizer = instantiate_from_config(
136
+ regularizer_config
137
+ )
138
+ self.optimizer_config = default(
139
+ optimizer_config, {"target": "torch.optim.Adam"}
140
+ )
141
+ self.diff_boost_factor = diff_boost_factor
142
+ self.disc_start_iter = disc_start_iter
143
+ self.lr_g_factor = lr_g_factor
144
+ self.trainable_ae_params = trainable_ae_params
145
+ if self.trainable_ae_params is not None:
146
+ self.ae_optimizer_args = default(
147
+ ae_optimizer_args,
148
+ [{} for _ in range(len(self.trainable_ae_params))],
149
+ )
150
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
151
+ else:
152
+ self.ae_optimizer_args = [{}] # makes type consistent
153
+
154
+ self.trainable_disc_params = trainable_disc_params
155
+ if self.trainable_disc_params is not None:
156
+ self.disc_optimizer_args = default(
157
+ disc_optimizer_args,
158
+ [{} for _ in range(len(self.trainable_disc_params))],
159
+ )
160
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
161
+ else:
162
+ self.disc_optimizer_args = [{}] # makes type consistent
163
+
164
+ if ckpt_path is not None:
165
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
166
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
167
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
168
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
169
+
170
+ def get_input(self, batch: Dict) -> torch.Tensor:
171
+ # assuming unified data format, dataloader returns a dict.
172
+ # image tensors should be scaled to -1 ... 1 and in channels-first
173
+ # format (e.g., bchw instead of bhwc)
174
+ return batch[self.input_key]
175
+
176
+ def get_autoencoder_params(self) -> list:
177
+ params = []
178
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
179
+ params += list(self.loss.get_trainable_autoencoder_parameters())
180
+ if hasattr(self.regularization, "get_trainable_parameters"):
181
+ params += list(self.regularization.get_trainable_parameters())
182
+ params = params + list(self.encoder.parameters())
183
+ params = params + list(self.decoder.parameters())
184
+ return params
185
+
186
+ def get_discriminator_params(self) -> list:
187
+ if hasattr(self.loss, "get_trainable_parameters"):
188
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
189
+ else:
190
+ params = []
191
+ return params
192
+
193
+ def get_last_layer(self):
194
+ return self.decoder.get_last_layer()
195
+
196
+ def encode(
197
+ self,
198
+ x: torch.Tensor,
199
+ return_reg_log: bool = False,
200
+ unregularized: bool = False,
201
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
202
+ z = self.encoder(x)
203
+ if unregularized:
204
+ return z, dict()
205
+ z, reg_log = self.regularization(z)
206
+ if return_reg_log:
207
+ return z, reg_log
208
+ return z
209
+
210
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
211
+ x = self.decoder(z, **kwargs)
212
+ return x
213
+
214
+ def forward(
215
+ self, x: torch.Tensor, **additional_decode_kwargs
216
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
217
+ z, reg_log = self.encode(x, return_reg_log=True)
218
+ dec = self.decode(z, **additional_decode_kwargs)
219
+ return z, dec, reg_log
220
+
221
+ def inner_training_step(
222
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
223
+ ) -> torch.Tensor:
224
+ x = self.get_input(batch)
225
+ additional_decode_kwargs = {
226
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
227
+ }
228
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
229
+ if hasattr(self.loss, "forward_keys"):
230
+ extra_info = {
231
+ "z": z,
232
+ "optimizer_idx": optimizer_idx,
233
+ "global_step": self.global_step,
234
+ "last_layer": self.get_last_layer(),
235
+ "split": "train",
236
+ "regularization_log": regularization_log,
237
+ "autoencoder": self,
238
+ }
239
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
240
+ else:
241
+ extra_info = dict()
242
+
243
+ if optimizer_idx == 0:
244
+ # autoencode
245
+ out_loss = self.loss(x, xrec, **extra_info)
246
+ if isinstance(out_loss, tuple):
247
+ aeloss, log_dict_ae = out_loss
248
+ else:
249
+ # simple loss function
250
+ aeloss = out_loss
251
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
252
+
253
+ self.log_dict(
254
+ log_dict_ae,
255
+ prog_bar=False,
256
+ logger=True,
257
+ on_step=True,
258
+ on_epoch=True,
259
+ sync_dist=False,
260
+ )
261
+ self.log(
262
+ "loss",
263
+ aeloss.mean().detach(),
264
+ prog_bar=True,
265
+ logger=False,
266
+ on_epoch=False,
267
+ on_step=True,
268
+ )
269
+ return aeloss
270
+ elif optimizer_idx == 1:
271
+ # discriminator
272
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
273
+ # -> discriminator always needs to return a tuple
274
+ self.log_dict(
275
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
276
+ )
277
+ return discloss
278
+ else:
279
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
280
+
281
+ def training_step(self, batch: dict, batch_idx: int):
282
+ opts = self.optimizers()
283
+ if not isinstance(opts, list):
284
+ # Non-adversarial case
285
+ opts = [opts]
286
+ optimizer_idx = batch_idx % len(opts)
287
+ if self.global_step < self.disc_start_iter:
288
+ optimizer_idx = 0
289
+ opt = opts[optimizer_idx]
290
+ opt.zero_grad()
291
+ with opt.toggle_model():
292
+ loss = self.inner_training_step(
293
+ batch, batch_idx, optimizer_idx=optimizer_idx
294
+ )
295
+ self.manual_backward(loss)
296
+ opt.step()
297
+
298
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
299
+ log_dict = self._validation_step(batch, batch_idx)
300
+ with self.ema_scope():
301
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
302
+ log_dict.update(log_dict_ema)
303
+ return log_dict
304
+
305
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
306
+ x = self.get_input(batch)
307
+
308
+ z, xrec, regularization_log = self(x)
309
+ if hasattr(self.loss, "forward_keys"):
310
+ extra_info = {
311
+ "z": z,
312
+ "optimizer_idx": 0,
313
+ "global_step": self.global_step,
314
+ "last_layer": self.get_last_layer(),
315
+ "split": "val" + postfix,
316
+ "regularization_log": regularization_log,
317
+ "autoencoder": self,
318
+ }
319
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
320
+ else:
321
+ extra_info = dict()
322
+ out_loss = self.loss(x, xrec, **extra_info)
323
+ if isinstance(out_loss, tuple):
324
+ aeloss, log_dict_ae = out_loss
325
+ else:
326
+ # simple loss function
327
+ aeloss = out_loss
328
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
329
+ full_log_dict = log_dict_ae
330
+
331
+ if "optimizer_idx" in extra_info:
332
+ extra_info["optimizer_idx"] = 1
333
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
334
+ full_log_dict.update(log_dict_disc)
335
+ self.log(
336
+ f"val{postfix}/loss/rec",
337
+ log_dict_ae[f"val{postfix}/loss/rec"],
338
+ sync_dist=True,
339
+ )
340
+ self.log_dict(full_log_dict, sync_dist=True)
341
+ return full_log_dict
342
+
343
+ def get_param_groups(
344
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
345
+ ) -> Tuple[List[Dict[str, Any]], int]:
346
+ groups = []
347
+ num_params = 0
348
+ for names, args in zip(parameter_names, optimizer_args):
349
+ params = []
350
+ for pattern_ in names:
351
+ pattern_params = []
352
+ pattern = re.compile(pattern_)
353
+ for p_name, param in self.named_parameters():
354
+ if re.match(pattern, p_name):
355
+ pattern_params.append(param)
356
+ num_params += param.numel()
357
+ if len(pattern_params) == 0:
358
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
359
+ params.extend(pattern_params)
360
+ groups.append({"params": params, **args})
361
+ return groups, num_params
362
+
363
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
364
+ if self.trainable_ae_params is None:
365
+ ae_params = self.get_autoencoder_params()
366
+ else:
367
+ ae_params, num_ae_params = self.get_param_groups(
368
+ self.trainable_ae_params, self.ae_optimizer_args
369
+ )
370
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
371
+ if self.trainable_disc_params is None:
372
+ disc_params = self.get_discriminator_params()
373
+ else:
374
+ disc_params, num_disc_params = self.get_param_groups(
375
+ self.trainable_disc_params, self.disc_optimizer_args
376
+ )
377
+ logpy.info(
378
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
379
+ )
380
+ opt_ae = self.instantiate_optimizer_from_config(
381
+ ae_params,
382
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
383
+ self.optimizer_config,
384
+ )
385
+ opts = [opt_ae]
386
+ if len(disc_params) > 0:
387
+ opt_disc = self.instantiate_optimizer_from_config(
388
+ disc_params, self.learning_rate, self.optimizer_config
389
+ )
390
+ opts.append(opt_disc)
391
+
392
+ return opts
393
+
394
+ @torch.no_grad()
395
+ def log_images(
396
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
397
+ ) -> dict:
398
+ log = dict()
399
+ additional_decode_kwargs = {}
400
+ x = self.get_input(batch)
401
+ additional_decode_kwargs.update(
402
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
403
+ )
404
+
405
+ _, xrec, _ = self(x, **additional_decode_kwargs)
406
+ log["inputs"] = x
407
+ log["reconstructions"] = xrec
408
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
409
+ diff.clamp_(0, 1.0)
410
+ log["diff"] = 2.0 * diff - 1.0
411
+ # diff_boost shows location of small errors, by boosting their
412
+ # brightness.
413
+ log["diff_boost"] = (
414
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
415
+ )
416
+ if hasattr(self.loss, "log_images"):
417
+ log.update(self.loss.log_images(x, xrec))
418
+ with self.ema_scope():
419
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
420
+ log["reconstructions_ema"] = xrec_ema
421
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
422
+ diff_ema.clamp_(0, 1.0)
423
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
424
+ log["diff_boost_ema"] = (
425
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
426
+ )
427
+ if additional_log_kwargs:
428
+ additional_decode_kwargs.update(additional_log_kwargs)
429
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
430
+ log_str = "reconstructions-" + "-".join(
431
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
432
+ )
433
+ log[log_str] = xrec_add
434
+ return log
435
+
436
+
437
+ class AutoencodingEngineLegacy(AutoencodingEngine):
438
+ def __init__(self, embed_dim: int, **kwargs):
439
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
440
+ ddconfig = kwargs.pop("ddconfig")
441
+ ckpt_path = kwargs.pop("ckpt_path", None)
442
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
443
+ super().__init__(
444
+ encoder_config={
445
+ "target": "dragnuwa.svd.modules.diffusionmodules.model.Encoder",
446
+ "params": ddconfig,
447
+ },
448
+ decoder_config={
449
+ "target": "dragnuwa.svd.modules.diffusionmodules.model.Decoder",
450
+ "params": ddconfig,
451
+ },
452
+ **kwargs,
453
+ )
454
+ self.quant_conv = torch.nn.Conv2d(
455
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
456
+ (1 + ddconfig["double_z"]) * embed_dim,
457
+ 1,
458
+ )
459
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
460
+ self.embed_dim = embed_dim
461
+
462
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
463
+
464
+ def get_autoencoder_params(self) -> list:
465
+ params = super().get_autoencoder_params()
466
+ return params
467
+
468
+ def encode(
469
+ self, x: torch.Tensor, return_reg_log: bool = False
470
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
471
+ if self.max_batch_size is None:
472
+ z = self.encoder(x)
473
+ z = self.quant_conv(z)
474
+ else:
475
+ N = x.shape[0]
476
+ bs = self.max_batch_size
477
+ n_batches = int(math.ceil(N / bs))
478
+ z = list()
479
+ for i_batch in range(n_batches):
480
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
481
+ z_batch = self.quant_conv(z_batch)
482
+ z.append(z_batch)
483
+ z = torch.cat(z, 0)
484
+
485
+ z, reg_log = self.regularization(z)
486
+ if return_reg_log:
487
+ return z, reg_log
488
+ return z
489
+
490
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
491
+ if self.max_batch_size is None:
492
+ dec = self.post_quant_conv(z)
493
+ dec = self.decoder(dec, **decoder_kwargs)
494
+ else:
495
+ N = z.shape[0]
496
+ bs = self.max_batch_size
497
+ n_batches = int(math.ceil(N / bs))
498
+ dec = list()
499
+ for i_batch in range(n_batches):
500
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
501
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
502
+ dec.append(dec_batch)
503
+ dec = torch.cat(dec, 0)
504
+
505
+ return dec
506
+
507
+
508
+ class AutoencoderKL(AutoencodingEngineLegacy):
509
+ def __init__(self, **kwargs):
510
+ if "lossconfig" in kwargs:
511
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
512
+ super().__init__(
513
+ regularizer_config={
514
+ "target": (
515
+ "dragnuwa.svd.modules.autoencoding.regularizers"
516
+ ".DiagonalGaussianRegularizer"
517
+ )
518
+ },
519
+ **kwargs,
520
+ )
521
+
522
+
523
+ class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
524
+ def __init__(
525
+ self,
526
+ embed_dim: int,
527
+ n_embed: int,
528
+ sane_index_shape: bool = False,
529
+ **kwargs,
530
+ ):
531
+ if "lossconfig" in kwargs:
532
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
533
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
534
+ super().__init__(
535
+ regularizer_config={
536
+ "target": (
537
+ "dragnuwa.svd.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
538
+ ),
539
+ "params": {
540
+ "n_e": n_embed,
541
+ "e_dim": embed_dim,
542
+ "sane_index_shape": sane_index_shape,
543
+ },
544
+ },
545
+ **kwargs,
546
+ )
547
+
548
+
549
+ class IdentityFirstStage(AbstractAutoencoder):
550
+ def __init__(self, *args, **kwargs):
551
+ super().__init__(*args, **kwargs)
552
+
553
+ def get_input(self, x: Any) -> Any:
554
+ return x
555
+
556
+ def encode(self, x: Any, *args, **kwargs) -> Any:
557
+ return x
558
+
559
+ def decode(self, x: Any, *args, **kwargs) -> Any:
560
+ return x
561
+
562
+
563
+ class AEIntegerWrapper(nn.Module):
564
+ def __init__(
565
+ self,
566
+ model: nn.Module,
567
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
568
+ regularization_key: str = "regularization",
569
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
570
+ ):
571
+ super().__init__()
572
+ self.model = model
573
+ assert hasattr(model, "encode") and hasattr(
574
+ model, "decode"
575
+ ), "Need AE interface"
576
+ self.regularization = get_nested_attribute(model, regularization_key)
577
+ self.shape = shape
578
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
579
+
580
+ def encode(self, x) -> torch.Tensor:
581
+ assert (
582
+ not self.training
583
+ ), f"{self.__class__.__name__} only supports inference currently"
584
+ _, log = self.model.encode(x, **self.encoder_kwargs)
585
+ assert isinstance(log, dict)
586
+ inds = log["min_encoding_indices"]
587
+ return rearrange(inds, "b ... -> b (...)")
588
+
589
+ def decode(
590
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
591
+ ) -> torch.Tensor:
592
+ # expect inds shape (b, s) with s = h*w
593
+ shape = default(shape, self.shape) # Optional[(h, w)]
594
+ if shape is not None:
595
+ assert len(shape) == 2, f"Unhandled shape {shape}"
596
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
597
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
598
+ h = rearrange(h, "b h w c -> b c h w")
599
+ return self.model.decode(h)
600
+
601
+
602
+ class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
603
+ def __init__(self, **kwargs):
604
+ if "lossconfig" in kwargs:
605
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
606
+ super().__init__(
607
+ regularizer_config={
608
+ "target": (
609
+ "dragnuwa.svd.modules.autoencoding.regularizers"
610
+ ".DiagonalGaussianRegularizer"
611
+ ),
612
+ "params": {"sample": False},
613
+ },
614
+ **kwargs,
615
+ )
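A hedged usage sketch for the classes above: the ddconfig values are assumed SD-style VAE settings and the Identity stand-in for loss_config is a placeholder, not a config shipped with this repository.

import torch

ddconfig = {  # assumed SD-style VAE settings
    "double_z": True, "z_channels": 4, "resolution": 256, "in_channels": 3,
    "out_ch": 3, "ch": 128, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2,
    "attn_resolutions": [], "dropout": 0.0,
}
vae = AutoencoderKL(
    embed_dim=4,
    ddconfig=ddconfig,
    loss_config={"target": "torch.nn.Identity"},  # placeholder; training would use a real loss config
).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # image scaled to [-1, 1]
    z = vae.encode(x)                # sampled diagonal-Gaussian latent, (1, 4, 32, 32)
    x_rec = vae.decode(z)            # reconstruction, (1, 3, 256, 256)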
dragnuwa/svd/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "dragnuwa.svd.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
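A brief hedged sketch of how this default is consumed, assuming instantiate_from_config from dragnuwa.svd.util (the helper imported elsewhere in this diff):

from dragnuwa.svd.util import instantiate_from_config

# yields a GeneralConditioner with an empty embedder list, i.e. no conditioning
conditioner = instantiate_from_config(UNCONDITIONAL_CONFIG)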
dragnuwa/svd/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (315 Bytes). View file
 
dragnuwa/svd/modules/__pycache__/attention.cpython-38.pyc ADDED
Binary file (18 kB). View file
 
dragnuwa/svd/modules/__pycache__/ema.cpython-38.pyc ADDED
Binary file (3.19 kB). View file
 
dragnuwa/svd/modules/__pycache__/video_attention.cpython-38.pyc ADDED
Binary file (6.23 kB). View file
 
dragnuwa/svd/modules/attention.py ADDED
@@ -0,0 +1,759 @@
1
+ import logging
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from packaging import version
10
+ from torch import nn
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ logpy = logging.getLogger(__name__)
14
+
15
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
16
+ SDP_IS_AVAILABLE = True
17
+ from torch.backends.cuda import SDPBackend, sdp_kernel
18
+
19
+ BACKEND_MAP = {
20
+ SDPBackend.MATH: {
21
+ "enable_math": True,
22
+ "enable_flash": False,
23
+ "enable_mem_efficient": False,
24
+ },
25
+ SDPBackend.FLASH_ATTENTION: {
26
+ "enable_math": False,
27
+ "enable_flash": True,
28
+ "enable_mem_efficient": False,
29
+ },
30
+ SDPBackend.EFFICIENT_ATTENTION: {
31
+ "enable_math": False,
32
+ "enable_flash": False,
33
+ "enable_mem_efficient": True,
34
+ },
35
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
36
+ }
37
+ else:
38
+ from contextlib import nullcontext
39
+
40
+ SDP_IS_AVAILABLE = False
41
+ sdp_kernel = nullcontext
42
+ BACKEND_MAP = {}
43
+ logpy.warn(
44
+ f"No SDP backend available, likely because you are running in pytorch "
45
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
46
+ f"You might want to consider upgrading."
47
+ )
48
+
49
+ try:
50
+ import xformers
51
+ import xformers.ops
52
+
53
+ XFORMERS_IS_AVAILABLE = True
54
+ except ImportError:
55
+ XFORMERS_IS_AVAILABLE = False
56
+ logpy.warn("no module 'xformers'. Processing without...")
57
+
58
+ # from .diffusionmodules.util import mixed_checkpoint as checkpoint
59
+
60
+
61
+ def exists(val):
62
+ return val is not None
63
+
64
+
65
+ def uniq(arr):
66
+ return {el: True for el in arr}.keys()
67
+
68
+
69
+ def default(val, d):
70
+ if exists(val):
71
+ return val
72
+ return d() if isfunction(d) else d
73
+
74
+
75
+ def max_neg_value(t):
76
+ return -torch.finfo(t.dtype).max
77
+
78
+
79
+ def init_(tensor):
80
+ dim = tensor.shape[-1]
81
+ std = 1 / math.sqrt(dim)
82
+ tensor.uniform_(-std, std)
83
+ return tensor
84
+
85
+
86
+ # feedforward
87
+ class GEGLU(nn.Module):
88
+ def __init__(self, dim_in, dim_out):
89
+ super().__init__()
90
+ self.proj = nn.Linear(dim_in, dim_out * 2)
91
+
92
+ def forward(self, x):
93
+ x, gate = self.proj(x).chunk(2, dim=-1)
94
+ return x * F.gelu(gate)
95
+
96
+
97
+ class FeedForward(nn.Module):
98
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
99
+ super().__init__()
100
+ inner_dim = int(dim * mult)
101
+ dim_out = default(dim_out, dim)
102
+ project_in = (
103
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
104
+ if not glu
105
+ else GEGLU(dim, inner_dim)
106
+ )
107
+
108
+ self.net = nn.Sequential(
109
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(x)
114
+
115
+
116
+ def zero_module(module):
117
+ """
118
+ Zero out the parameters of a module and return it.
119
+ """
120
+ for p in module.parameters():
121
+ p.detach().zero_()
122
+ return module
123
+
124
+
125
+ def Normalize(in_channels):
126
+ return torch.nn.GroupNorm(
127
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
128
+ )
129
+
130
+
131
+ class LinearAttention(nn.Module):
132
+ def __init__(self, dim, heads=4, dim_head=32):
133
+ super().__init__()
134
+ self.heads = heads
135
+ hidden_dim = dim_head * heads
136
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
137
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
138
+
139
+ def forward(self, x):
140
+ b, c, h, w = x.shape
141
+ qkv = self.to_qkv(x)
142
+ q, k, v = rearrange(
143
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
144
+ )
145
+ k = k.softmax(dim=-1)
146
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
147
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
148
+ out = rearrange(
149
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
150
+ )
151
+ return self.to_out(out)
152
+
153
+
154
+ class SelfAttention(nn.Module):
155
+ ATTENTION_MODES = ("xformers", "torch", "math")
156
+
157
+ def __init__(
158
+ self,
159
+ dim: int,
160
+ num_heads: int = 8,
161
+ qkv_bias: bool = False,
162
+ qk_scale: Optional[float] = None,
163
+ attn_drop: float = 0.0,
164
+ proj_drop: float = 0.0,
165
+ attn_mode: str = "xformers",
166
+ ):
167
+ super().__init__()
168
+ self.num_heads = num_heads
169
+ head_dim = dim // num_heads
170
+ self.scale = qk_scale or head_dim**-0.5
171
+
172
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ assert attn_mode in self.ATTENTION_MODES
177
+ self.attn_mode = attn_mode
178
+
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ B, L, C = x.shape
181
+
182
+ qkv = self.qkv(x)
183
+ if self.attn_mode == "torch":
184
+ qkv = rearrange(
185
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
186
+ ).float()
187
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
188
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
189
+ x = rearrange(x, "B H L D -> B L (H D)")
190
+ elif self.attn_mode == "xformers":
191
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
192
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
193
+ x = xformers.ops.memory_efficient_attention(q, k, v)
194
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
195
+ elif self.attn_mode == "math":
196
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
197
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
198
+ attn = (q @ k.transpose(-2, -1)) * self.scale
199
+ attn = attn.softmax(dim=-1)
200
+ attn = self.attn_drop(attn)
201
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
202
+ else:
203
+ raise NotImplementedError(f"attn_mode {self.attn_mode} not implemented")
204
+
205
+ x = self.proj(x)
206
+ x = self.proj_drop(x)
207
+ return x
208
+
209
+
210
+ class SpatialSelfAttention(nn.Module):
211
+ def __init__(self, in_channels):
212
+ super().__init__()
213
+ self.in_channels = in_channels
214
+
215
+ self.norm = Normalize(in_channels)
216
+ self.q = torch.nn.Conv2d(
217
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
218
+ )
219
+ self.k = torch.nn.Conv2d(
220
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
221
+ )
222
+ self.v = torch.nn.Conv2d(
223
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
224
+ )
225
+ self.proj_out = torch.nn.Conv2d(
226
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
227
+ )
228
+
229
+ def forward(self, x):
230
+ h_ = x
231
+ h_ = self.norm(h_)
232
+ q = self.q(h_)
233
+ k = self.k(h_)
234
+ v = self.v(h_)
235
+
236
+ # compute attention
237
+ b, c, h, w = q.shape
238
+ q = rearrange(q, "b c h w -> b (h w) c")
239
+ k = rearrange(k, "b c h w -> b c (h w)")
240
+ w_ = torch.einsum("bij,bjk->bik", q, k)
241
+
242
+ w_ = w_ * (int(c) ** (-0.5))
243
+ w_ = torch.nn.functional.softmax(w_, dim=2)
244
+
245
+ # attend to values
246
+ v = rearrange(v, "b c h w -> b c (h w)")
247
+ w_ = rearrange(w_, "b i j -> b j i")
248
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
249
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
250
+ h_ = self.proj_out(h_)
251
+
252
+ return x + h_
253
+
254
+
255
+ class CrossAttention(nn.Module):
256
+ def __init__(
257
+ self,
258
+ query_dim,
259
+ context_dim=None,
260
+ heads=8,
261
+ dim_head=64,
262
+ dropout=0.0,
263
+ backend=None,
264
+ ):
265
+ super().__init__()
266
+ inner_dim = dim_head * heads
267
+ context_dim = default(context_dim, query_dim)
268
+
269
+ self.scale = dim_head**-0.5
270
+ self.heads = heads
271
+
272
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
273
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
274
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
275
+
276
+ self.to_out = nn.Sequential(
277
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
278
+ )
279
+ self.backend = backend
280
+
281
+ def forward(
282
+ self,
283
+ x,
284
+ context=None,
285
+ mask=None,
286
+ additional_tokens=None,
287
+ n_times_crossframe_attn_in_self=0,
288
+ ):
289
+ h = self.heads
290
+
291
+ if additional_tokens is not None:
292
+ # get the number of masked tokens at the beginning of the output sequence
293
+ n_tokens_to_mask = additional_tokens.shape[1]
294
+ # add additional token
295
+ x = torch.cat([additional_tokens, x], dim=1)
296
+
297
+ q = self.to_q(x)
298
+ context = default(context, x)
299
+ k = self.to_k(context)
300
+ v = self.to_v(context)
301
+
302
+ if n_times_crossframe_attn_in_self:
303
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
304
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
305
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
306
+ k = repeat(
307
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
308
+ )
309
+ v = repeat(
310
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
311
+ )
312
+
313
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
314
+
315
+ ## old
316
+ """
317
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
318
+ del q, k
319
+
320
+ if exists(mask):
321
+ mask = rearrange(mask, 'b ... -> b (...)')
322
+ max_neg_value = -torch.finfo(sim.dtype).max
323
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
324
+ sim.masked_fill_(~mask, max_neg_value)
325
+
326
+ # attention, what we cannot get enough of
327
+ sim = sim.softmax(dim=-1)
328
+
329
+ out = einsum('b i j, b j d -> b i d', sim, v)
330
+ """
331
+ ## new
332
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
333
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
334
+ out = F.scaled_dot_product_attention(
335
+ q, k, v, attn_mask=mask
336
+ ) # scale is dim_head ** -0.5 per default
337
+
338
+ del q, k, v
339
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
340
+
341
+ if additional_tokens is not None:
342
+ # remove additional token
343
+ out = out[:, n_tokens_to_mask:]
344
+ return self.to_out(out)
345
+
346
+
347
+ class MemoryEfficientCrossAttention(nn.Module):
348
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
349
+ def __init__(
350
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
351
+ ):
352
+ super().__init__()
353
+ logpy.debug(
354
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
355
+ f"context_dim is {context_dim} and using {heads} heads with a "
356
+ f"dimension of {dim_head}."
357
+ )
358
+ inner_dim = dim_head * heads
359
+ context_dim = default(context_dim, query_dim)
360
+
361
+ self.heads = heads
362
+ self.dim_head = dim_head
363
+
364
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
365
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
366
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
367
+
368
+ self.to_out = nn.Sequential(
369
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
370
+ )
371
+ self.attention_op: Optional[Any] = None
372
+
373
+ def forward(
374
+ self,
375
+ x,
376
+ context=None,
377
+ mask=None,
378
+ additional_tokens=None,
379
+ n_times_crossframe_attn_in_self=0,
380
+ ):
381
+ if additional_tokens is not None:
382
+ # get the number of masked tokens at the beginning of the output sequence
383
+ n_tokens_to_mask = additional_tokens.shape[1]
384
+ # add additional token
385
+ x = torch.cat([additional_tokens, x], dim=1)
386
+ q = self.to_q(x)
387
+ context = default(context, x)
388
+ k = self.to_k(context)
389
+ v = self.to_v(context)
390
+
391
+ if n_times_crossframe_attn_in_self:
392
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
393
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
394
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
395
+ k = repeat(
396
+ k[::n_times_crossframe_attn_in_self],
397
+ "b ... -> (b n) ...",
398
+ n=n_times_crossframe_attn_in_self,
399
+ )
400
+ v = repeat(
401
+ v[::n_times_crossframe_attn_in_self],
402
+ "b ... -> (b n) ...",
403
+ n=n_times_crossframe_attn_in_self,
404
+ )
405
+
406
+ b, _, _ = q.shape
407
+ q, k, v = map(
408
+ lambda t: t.unsqueeze(3)
409
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
410
+ .permute(0, 2, 1, 3)
411
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
412
+ .contiguous(),
413
+ (q, k, v),
414
+ )
415
+
416
+ # actually compute the attention, what we cannot get enough of
417
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
418
+ # NOTE: workaround for
419
+ # https://github.com/facebookresearch/xformers/issues/845
420
+ max_bs = 32768
421
+ N = q.shape[0]
422
+ n_batches = math.ceil(N / max_bs)
423
+ out = list()
424
+ for i_batch in range(n_batches):
425
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
426
+ out.append(
427
+ xformers.ops.memory_efficient_attention(
428
+ q[batch],
429
+ k[batch],
430
+ v[batch],
431
+ attn_bias=None,
432
+ op=self.attention_op,
433
+ )
434
+ )
435
+ out = torch.cat(out, 0)
436
+ else:
437
+ out = xformers.ops.memory_efficient_attention(
438
+ q, k, v, attn_bias=None, op=self.attention_op
439
+ )
440
+
441
+ # TODO: Use this directly in the attention operation, as a bias
442
+ if exists(mask):
443
+ raise NotImplementedError
444
+ out = (
445
+ out.unsqueeze(0)
446
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
447
+ .permute(0, 2, 1, 3)
448
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
449
+ )
450
+ if additional_tokens is not None:
451
+ # remove additional token
452
+ out = out[:, n_tokens_to_mask:]
453
+ return self.to_out(out)
454
+
455
+
456
+ class BasicTransformerBlock(nn.Module):
457
+ ATTENTION_MODES = {
458
+ "softmax": CrossAttention, # vanilla attention
459
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
460
+ }
461
+
462
+ def __init__(
463
+ self,
464
+ dim,
465
+ n_heads,
466
+ d_head,
467
+ dropout=0.0,
468
+ context_dim=None,
469
+ gated_ff=True,
470
+ checkpoint=True,
471
+ disable_self_attn=False,
472
+ attn_mode="softmax",
473
+ sdp_backend=None,
474
+ ):
475
+ super().__init__()
476
+ assert attn_mode in self.ATTENTION_MODES
477
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
478
+ logpy.warn(
479
+ f"Attention mode '{attn_mode}' is not available. Falling "
480
+ f"back to native attention. This is not a problem in "
481
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
482
+ f"version {torch.__version__}."
483
+ )
484
+ attn_mode = "softmax"
485
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
486
+ logpy.warn(
487
+ "We do not support vanilla attention anymore, as it is too "
488
+ "expensive. Sorry."
489
+ )
490
+ if not XFORMERS_IS_AVAILABLE:
491
+ assert (
492
+ False
493
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
494
+ else:
495
+ logpy.info("Falling back to xformers efficient attention.")
496
+ attn_mode = "softmax-xformers"
497
+ attn_cls = self.ATTENTION_MODES[attn_mode]
498
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
499
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
500
+ else:
501
+ assert sdp_backend is None
502
+ self.disable_self_attn = disable_self_attn
503
+ self.attn1 = attn_cls(
504
+ query_dim=dim,
505
+ heads=n_heads,
506
+ dim_head=d_head,
507
+ dropout=dropout,
508
+ context_dim=context_dim if self.disable_self_attn else None,
509
+ backend=sdp_backend,
510
+ ) # is a self-attention if not self.disable_self_attn
511
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
512
+ self.attn2 = attn_cls(
513
+ query_dim=dim,
514
+ context_dim=context_dim,
515
+ heads=n_heads,
516
+ dim_head=d_head,
517
+ dropout=dropout,
518
+ backend=sdp_backend,
519
+ ) # is self-attn if context is none
520
+ self.norm1 = nn.LayerNorm(dim)
521
+ self.norm2 = nn.LayerNorm(dim)
522
+ self.norm3 = nn.LayerNorm(dim)
523
+ self.checkpoint = checkpoint
524
+ if self.checkpoint:
525
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
526
+
527
+ def forward(
528
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
529
+ ):
530
+ kwargs = {"x": x}
531
+
532
+ if context is not None:
533
+ kwargs.update({"context": context})
534
+
535
+ if additional_tokens is not None:
536
+ kwargs.update({"additional_tokens": additional_tokens})
537
+
538
+ if n_times_crossframe_attn_in_self:
539
+ kwargs.update(
540
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
541
+ )
542
+
543
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
544
+ if self.checkpoint:
545
+ # inputs = {"x": x, "context": context}
546
+ return checkpoint(self._forward, x, context)
547
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
548
+ else:
549
+ return self._forward(**kwargs)
550
+
551
+ def _forward(
552
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
553
+ ):
554
+ x = (
555
+ self.attn1(
556
+ self.norm1(x),
557
+ context=context if self.disable_self_attn else None,
558
+ additional_tokens=additional_tokens,
559
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
560
+ if not self.disable_self_attn
561
+ else 0,
562
+ )
563
+ + x
564
+ )
565
+ x = (
566
+ self.attn2(
567
+ self.norm2(x), context=context, additional_tokens=additional_tokens
568
+ )
569
+ + x
570
+ )
571
+ x = self.ff(self.norm3(x)) + x
572
+ return x
573
+
574
+
575
+ class BasicTransformerSingleLayerBlock(nn.Module):
576
+ ATTENTION_MODES = {
577
+ "softmax": CrossAttention, # vanilla attention
578
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
579
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
580
+ }
581
+
582
+ def __init__(
583
+ self,
584
+ dim,
585
+ n_heads,
586
+ d_head,
587
+ dropout=0.0,
588
+ context_dim=None,
589
+ gated_ff=True,
590
+ checkpoint=True,
591
+ attn_mode="softmax",
592
+ ):
593
+ super().__init__()
594
+ assert attn_mode in self.ATTENTION_MODES
595
+ attn_cls = self.ATTENTION_MODES[attn_mode]
596
+ self.attn1 = attn_cls(
597
+ query_dim=dim,
598
+ heads=n_heads,
599
+ dim_head=d_head,
600
+ dropout=dropout,
601
+ context_dim=context_dim,
602
+ )
603
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
604
+ self.norm1 = nn.LayerNorm(dim)
605
+ self.norm2 = nn.LayerNorm(dim)
606
+ self.checkpoint = checkpoint
607
+
608
+ def forward(self, x, context=None):
609
+ # inputs = {"x": x, "context": context}
610
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
611
+ return checkpoint(self._forward, x, context)
612
+
613
+ def _forward(self, x, context=None):
614
+ x = self.attn1(self.norm1(x), context=context) + x
615
+ x = self.ff(self.norm2(x)) + x
616
+ return x
617
+
618
+
619
+ class SpatialTransformer(nn.Module):
620
+ """
621
+ Transformer block for image-like data.
622
+ First, project the input (aka embedding)
623
+ and reshape to b, t, d.
624
+ Then apply standard transformer action.
625
+ Finally, reshape to image
626
+ NEW: use_linear for more efficiency instead of the 1x1 convs
627
+ """
628
+
629
+ def __init__(
630
+ self,
631
+ in_channels,
632
+ n_heads,
633
+ d_head,
634
+ depth=1,
635
+ dropout=0.0,
636
+ context_dim=None,
637
+ disable_self_attn=False,
638
+ use_linear=False,
639
+ attn_type="softmax",
640
+ use_checkpoint=True,
641
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
642
+ sdp_backend=None,
643
+ ):
644
+ super().__init__()
645
+ logpy.debug(
646
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
647
+ f"{in_channels} channels and {n_heads} heads."
648
+ )
649
+
650
+ if exists(context_dim) and not isinstance(context_dim, list):
651
+ context_dim = [context_dim]
652
+ if exists(context_dim) and isinstance(context_dim, list):
653
+ if depth != len(context_dim):
654
+ logpy.warn(
655
+ f"{self.__class__.__name__}: Found context dims "
656
+ f"{context_dim} of depth {len(context_dim)}, which does not "
657
+ f"match the specified 'depth' of {depth}. Setting context_dim "
658
+ f"to {depth * [context_dim[0]]} now."
659
+ )
660
+ # depth does not match context dims.
661
+ assert all(
662
+ map(lambda x: x == context_dim[0], context_dim)
663
+ ), "need homogenous context_dim to match depth automatically"
664
+ context_dim = depth * [context_dim[0]]
665
+ elif context_dim is None:
666
+ context_dim = [None] * depth
667
+ self.in_channels = in_channels
668
+ inner_dim = n_heads * d_head
669
+ self.norm = Normalize(in_channels)
670
+ if not use_linear:
671
+ self.proj_in = nn.Conv2d(
672
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
673
+ )
674
+ else:
675
+ self.proj_in = nn.Linear(in_channels, inner_dim)
676
+
677
+ self.transformer_blocks = nn.ModuleList(
678
+ [
679
+ BasicTransformerBlock(
680
+ inner_dim,
681
+ n_heads,
682
+ d_head,
683
+ dropout=dropout,
684
+ context_dim=context_dim[d],
685
+ disable_self_attn=disable_self_attn,
686
+ attn_mode=attn_type,
687
+ checkpoint=use_checkpoint,
688
+ sdp_backend=sdp_backend,
689
+ )
690
+ for d in range(depth)
691
+ ]
692
+ )
693
+ if not use_linear:
694
+ self.proj_out = zero_module(
695
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
696
+ )
697
+ else:
698
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
699
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
700
+ self.use_linear = use_linear
701
+
702
+ def forward(self, x, context=None):
703
+ # note: if no context is given, cross-attention defaults to self-attention
704
+ if not isinstance(context, list):
705
+ context = [context]
706
+ b, c, h, w = x.shape
707
+ x_in = x
708
+ x = self.norm(x)
709
+ if not self.use_linear:
710
+ x = self.proj_in(x)
711
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
712
+ if self.use_linear:
713
+ x = self.proj_in(x)
714
+ for i, block in enumerate(self.transformer_blocks):
715
+ if i > 0 and len(context) == 1:
716
+ i = 0 # use same context for each block
717
+ x = block(x, context=context[i])
718
+ if self.use_linear:
719
+ x = self.proj_out(x)
720
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
721
+ if not self.use_linear:
722
+ x = self.proj_out(x)
723
+ return x + x_in
724
+
725
+
726
+ class SimpleTransformer(nn.Module):
727
+ def __init__(
728
+ self,
729
+ dim: int,
730
+ depth: int,
731
+ heads: int,
732
+ dim_head: int,
733
+ context_dim: Optional[int] = None,
734
+ dropout: float = 0.0,
735
+ checkpoint: bool = True,
736
+ ):
737
+ super().__init__()
738
+ self.layers = nn.ModuleList([])
739
+ for _ in range(depth):
740
+ self.layers.append(
741
+ BasicTransformerBlock(
742
+ dim,
743
+ heads,
744
+ dim_head,
745
+ dropout=dropout,
746
+ context_dim=context_dim,
747
+ attn_mode="softmax-xformers",
748
+ checkpoint=checkpoint,
749
+ )
750
+ )
751
+
752
+ def forward(
753
+ self,
754
+ x: torch.Tensor,
755
+ context: Optional[torch.Tensor] = None,
756
+ ) -> torch.Tensor:
757
+ for layer in self.layers:
758
+ x = layer(x, context)
759
+ return x
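A hedged usage sketch for the attention stack above; the channel counts, token shapes, and the PyTorch >= 2.0 requirement for the default "softmax" mode are assumptions for the sketch.

import torch

block = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40,   # inner_dim = 8 * 40 = 320
    depth=1, context_dim=1024, use_linear=True,
)
feats = torch.randn(2, 320, 32, 32)    # UNet-style feature map (b, c, h, w)
context = torch.randn(2, 77, 1024)     # conditioning tokens (b, t, d), e.g. CLIP embeddings
out = block(feats, context=context)    # same shape as feats: (2, 320, 32, 32)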
dragnuwa/svd/modules/autoencoding/__init__.py ADDED
File without changes
dragnuwa/svd/modules/autoencoding/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (156 Bytes). View file
 
dragnuwa/svd/modules/autoencoding/__pycache__/temporal_ae.cpython-38.pyc ADDED
Binary file (9.13 kB). View file
 
dragnuwa/svd/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ __all__ = [
2
+ "GeneralLPIPSWithDiscriminator",
3
+ "LatentLPIPS",
4
+ ]
5
+
6
+ from .discriminator_loss import GeneralLPIPSWithDiscriminator
7
+ from .lpips import LatentLPIPS
dragnuwa/svd/modules/autoencoding/losses/discriminator_loss.py ADDED
@@ -0,0 +1,306 @@
1
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from einops import rearrange
8
+ from matplotlib import colormaps
9
+ from matplotlib import pyplot as plt
10
+
11
+ from ....util import default, instantiate_from_config
12
+ from ..lpips.loss.lpips import LPIPS
13
+ from ..lpips.model.model import weights_init
14
+ from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
15
+
16
+
17
+ class GeneralLPIPSWithDiscriminator(nn.Module):
18
+ def __init__(
19
+ self,
20
+ disc_start: int,
21
+ logvar_init: float = 0.0,
22
+ disc_num_layers: int = 3,
23
+ disc_in_channels: int = 3,
24
+ disc_factor: float = 1.0,
25
+ disc_weight: float = 1.0,
26
+ perceptual_weight: float = 1.0,
27
+ disc_loss: str = "hinge",
28
+ scale_input_to_tgt_size: bool = False,
29
+ dims: int = 2,
30
+ learn_logvar: bool = False,
31
+ regularization_weights: Union[None, Dict[str, float]] = None,
32
+ additional_log_keys: Optional[List[str]] = None,
33
+ discriminator_config: Optional[Dict] = None,
34
+ ):
35
+ super().__init__()
36
+ self.dims = dims
37
+ if self.dims > 2:
38
+ print(
39
+ f"running with dims={dims}. This means that for perceptual loss "
40
+ f"calculation, the LPIPS loss will be applied to each frame "
41
+ f"independently."
42
+ )
43
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
44
+ assert disc_loss in ["hinge", "vanilla"]
45
+ self.perceptual_loss = LPIPS().eval()
46
+ self.perceptual_weight = perceptual_weight
47
+ # output log variance
48
+ self.logvar = nn.Parameter(
49
+ torch.full((), logvar_init), requires_grad=learn_logvar
50
+ )
51
+ self.learn_logvar = learn_logvar
52
+
53
+ discriminator_config = default(
54
+ discriminator_config,
55
+ {
56
+ "target": "dragnuwa.svd.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
57
+ "params": {
58
+ "input_nc": disc_in_channels,
59
+ "n_layers": disc_num_layers,
60
+ "use_actnorm": False,
61
+ },
62
+ },
63
+ )
64
+
65
+ self.discriminator = instantiate_from_config(discriminator_config).apply(
66
+ weights_init
67
+ )
68
+ self.discriminator_iter_start = disc_start
69
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
70
+ self.disc_factor = disc_factor
71
+ self.discriminator_weight = disc_weight
72
+ self.regularization_weights = default(regularization_weights, {})
73
+
74
+ self.forward_keys = [
75
+ "optimizer_idx",
76
+ "global_step",
77
+ "last_layer",
78
+ "split",
79
+ "regularization_log",
80
+ ]
81
+
82
+ self.additional_log_keys = set(default(additional_log_keys, []))
83
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
84
+
85
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
86
+ return self.discriminator.parameters()
87
+
88
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
89
+ if self.learn_logvar:
90
+ yield self.logvar
91
+ yield from ()
92
+
93
+ @torch.no_grad()
94
+ def log_images(
95
+ self, inputs: torch.Tensor, reconstructions: torch.Tensor
96
+ ) -> Dict[str, torch.Tensor]:
97
+ # calc logits of real/fake
98
+ logits_real = self.discriminator(inputs.contiguous().detach())
99
+ if len(logits_real.shape) < 4:
100
+ # Non patch-discriminator
101
+ return dict()
102
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
103
+ # -> (b, 1, h, w)
104
+
105
+ # parameters for colormapping
106
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
107
+ cmap = colormaps["PiYG"] # diverging colormap
108
+
109
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
110
+ """(b, 1, ...) -> (b, 3, ...)"""
111
+ logits = (logits + high) / (2 * high)
112
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
113
+ # -> (b, 1, ..., 3)
114
+ logits = torch.from_numpy(logits_np).to(logits.device)
115
+ return rearrange(logits, "b 1 ... c -> b c ...")
116
+
117
+ logits_real = torch.nn.functional.interpolate(
118
+ logits_real,
119
+ size=inputs.shape[-2:],
120
+ mode="nearest",
121
+ antialias=False,
122
+ )
123
+ logits_fake = torch.nn.functional.interpolate(
124
+ logits_fake,
125
+ size=reconstructions.shape[-2:],
126
+ mode="nearest",
127
+ antialias=False,
128
+ )
129
+
130
+ # alpha value of logits for overlay
131
+ alpha_real = torch.abs(logits_real) / high
132
+ alpha_fake = torch.abs(logits_fake) / high
133
+ # -> (b, 1, h, w) in range [0, 0.5]
134
+ # alpha value of the lines doesn't really matter, since the values are the same
135
+ # for both images and logits anyway
136
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
137
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
138
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
139
+ # -> (1, h, w)
140
+ # blend logits and images together
141
+
142
+ # prepare logits for plotting
143
+ logits_real = to_colormap(logits_real)
144
+ logits_fake = to_colormap(logits_fake)
145
+ # resize logits
146
+ # -> (b, 3, h, w)
147
+
148
+ # make some grids
149
+ # add all logits to one plot
150
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
151
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
152
+ # I just love how torchvision calls the number of columns `nrow`
153
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
154
+ # -> (3, h, w)
155
+
156
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
157
+ grid_images_fake = torchvision.utils.make_grid(
158
+ 0.5 * reconstructions + 0.5, nrow=4
159
+ )
160
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
161
+ # -> (3, h, w) in range [0, 1]
162
+
163
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
164
+
165
+ # Create labeled colorbar
166
+ dpi = 100
167
+ height = 128 / dpi
168
+ width = grid_logits.shape[2] / dpi
169
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
170
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
171
+ plt.colorbar(
172
+ img,
173
+ cax=ax,
174
+ orientation="horizontal",
175
+ fraction=0.9,
176
+ aspect=width / height,
177
+ pad=0.0,
178
+ )
179
+ img.set_visible(False)
180
+ fig.tight_layout()
181
+ fig.canvas.draw()
182
+ # manually convert figure to numpy
183
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
184
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
185
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
186
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
187
+
188
+ # Add colorbar to plot
189
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
190
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
191
+ return {
192
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
193
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
194
+ }
195
+
196
+ def calculate_adaptive_weight(
197
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
198
+ ) -> torch.Tensor:
199
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
200
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
201
+
202
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
203
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
204
+ d_weight = d_weight * self.discriminator_weight
205
+ return d_weight
206
+
207
+ def forward(
208
+ self,
209
+ inputs: torch.Tensor,
210
+ reconstructions: torch.Tensor,
211
+ *, # added because I changed the order here
212
+ regularization_log: Dict[str, torch.Tensor],
213
+ optimizer_idx: int,
214
+ global_step: int,
215
+ last_layer: torch.Tensor,
216
+ split: str = "train",
217
+ weights: Union[None, float, torch.Tensor] = None,
218
+ ) -> Tuple[torch.Tensor, dict]:
219
+ if self.scale_input_to_tgt_size:
220
+ inputs = torch.nn.functional.interpolate(
221
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
222
+ )
223
+
224
+ if self.dims > 2:
225
+ inputs, reconstructions = map(
226
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
227
+ (inputs, reconstructions),
228
+ )
229
+
230
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
231
+ if self.perceptual_weight > 0:
232
+ p_loss = self.perceptual_loss(
233
+ inputs.contiguous(), reconstructions.contiguous()
234
+ )
235
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
236
+
237
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
238
+
239
+ # now the GAN part
240
+ if optimizer_idx == 0:
241
+ # generator update
242
+ if global_step >= self.discriminator_iter_start or not self.training:
243
+ logits_fake = self.discriminator(reconstructions.contiguous())
244
+ g_loss = -torch.mean(logits_fake)
245
+ if self.training:
246
+ d_weight = self.calculate_adaptive_weight(
247
+ nll_loss, g_loss, last_layer=last_layer
248
+ )
249
+ else:
250
+ d_weight = torch.tensor(1.0)
251
+ else:
252
+ d_weight = torch.tensor(0.0)
253
+ g_loss = torch.tensor(0.0, requires_grad=True)
254
+
255
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
256
+ log = dict()
257
+ for k in regularization_log:
258
+ if k in self.regularization_weights:
259
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
260
+ if k in self.additional_log_keys:
261
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
262
+
263
+ log.update(
264
+ {
265
+ f"{split}/loss/total": loss.clone().detach().mean(),
266
+ f"{split}/loss/nll": nll_loss.detach().mean(),
267
+ f"{split}/loss/rec": rec_loss.detach().mean(),
268
+ f"{split}/loss/g": g_loss.detach().mean(),
269
+ f"{split}/scalars/logvar": self.logvar.detach(),
270
+ f"{split}/scalars/d_weight": d_weight.detach(),
271
+ }
272
+ )
273
+
274
+ return loss, log
275
+ elif optimizer_idx == 1:
276
+ # second pass for discriminator update
277
+ logits_real = self.discriminator(inputs.contiguous().detach())
278
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
279
+
280
+ if global_step >= self.discriminator_iter_start or not self.training:
281
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
282
+ else:
283
+ d_loss = torch.tensor(0.0, requires_grad=True)
284
+
285
+ log = {
286
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
287
+ f"{split}/logits/real": logits_real.detach().mean(),
288
+ f"{split}/logits/fake": logits_fake.detach().mean(),
289
+ }
290
+ return d_loss, log
291
+ else:
292
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
293
+
294
+ def get_nll_loss(
295
+ self,
296
+ rec_loss: torch.Tensor,
297
+ weights: Optional[Union[float, torch.Tensor]] = None,
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
300
+ weighted_nll_loss = nll_loss
301
+ if weights is not None:
302
+ weighted_nll_loss = weights * nll_loss
303
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
304
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
305
+
306
+ return nll_loss, weighted_nll_loss
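A small hedged numerical sketch (toy tensors, not repo data) of what get_nll_loss computes: a Gaussian negative log-likelihood with a single, optionally learnable, log-variance.

import torch

rec_loss = torch.rand(4, 3, 8, 8)             # |x - x_rec| (+ weighted LPIPS term)
logvar = torch.tensor(0.0)                    # corresponds to self.logvar above
nll = rec_loss / torch.exp(logvar) + logvar   # elementwise NLL up to constants
nll_per_sample = nll.sum() / nll.shape[0]     # same sum()/batch_size reduction as above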
dragnuwa/svd/modules/autoencoding/losses/lpips.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ....util import default, instantiate_from_config
5
+ from ..lpips.loss.lpips import LPIPS
6
+
7
+
8
+ class LatentLPIPS(nn.Module):
9
+ def __init__(
10
+ self,
11
+ decoder_config,
12
+ perceptual_weight=1.0,
13
+ latent_weight=1.0,
14
+ scale_input_to_tgt_size=False,
15
+ scale_tgt_to_input_size=False,
16
+ perceptual_weight_on_inputs=0.0,
17
+ ):
18
+ super().__init__()
19
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
20
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
21
+ self.init_decoder(decoder_config)
22
+ self.perceptual_loss = LPIPS().eval()
23
+ self.perceptual_weight = perceptual_weight
24
+ self.latent_weight = latent_weight
25
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
26
+
27
+ def init_decoder(self, config):
28
+ self.decoder = instantiate_from_config(config)
29
+ if hasattr(self.decoder, "encoder"):
30
+ del self.decoder.encoder
31
+
32
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
33
+ log = dict()
34
+ loss = (latent_inputs - latent_predictions) ** 2
35
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
36
+ image_reconstructions = None
37
+ if self.perceptual_weight > 0.0:
38
+ image_reconstructions = self.decoder.decode(latent_predictions)
39
+ image_targets = self.decoder.decode(latent_inputs)
40
+ perceptual_loss = self.perceptual_loss(
41
+ image_targets.contiguous(), image_reconstructions.contiguous()
42
+ )
43
+ loss = (
44
+ self.latent_weight * loss.mean()
45
+ + self.perceptual_weight * perceptual_loss.mean()
46
+ )
47
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
48
+
49
+ if self.perceptual_weight_on_inputs > 0.0:
50
+ image_reconstructions = default(
51
+ image_reconstructions, self.decoder.decode(latent_predictions)
52
+ )
53
+ if self.scale_input_to_tgt_size:
54
+ image_inputs = torch.nn.functional.interpolate(
55
+ image_inputs,
56
+ image_reconstructions.shape[2:],
57
+ mode="bicubic",
58
+ antialias=True,
59
+ )
60
+ elif self.scale_tgt_to_input_size:
61
+ image_reconstructions = torch.nn.functional.interpolate(
62
+ image_reconstructions,
63
+ image_inputs.shape[2:],
64
+ mode="bicubic",
65
+ antialias=True,
66
+ )
67
+
68
+ perceptual_loss2 = self.perceptual_loss(
69
+ image_inputs.contiguous(), image_reconstructions.contiguous()
70
+ )
71
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
72
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
73
+ return loss, log
dragnuwa/svd/modules/autoencoding/lpips/__init__.py ADDED
File without changes
dragnuwa/svd/modules/autoencoding/lpips/loss/.gitignore ADDED
@@ -0,0 +1 @@
1
+ vgg.pth
dragnuwa/svd/modules/autoencoding/lpips/loss/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
dragnuwa/svd/modules/autoencoding/lpips/loss/__init__.py ADDED
File without changes
dragnuwa/svd/modules/autoencoding/lpips/loss/lpips.py ADDED
@@ -0,0 +1,147 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from ..util import get_ckpt_path
10
+
11
+
12
+ class LPIPS(nn.Module):
13
+ # Learned perceptual metric
14
+ def __init__(self, use_dropout=True):
15
+ super().__init__()
16
+ self.scaling_layer = ScalingLayer()
17
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
18
+ self.net = vgg16(pretrained=True, requires_grad=False)
19
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24
+ self.load_from_pretrained()
25
+ for param in self.parameters():
26
+ param.requires_grad = False
27
+
28
+ def load_from_pretrained(self, name="vgg_lpips"):
29
+ ckpt = get_ckpt_path(name, "models/svd/modules/autoencoding/lpips/loss")
30
+ self.load_state_dict(
31
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32
+ )
33
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, name="vgg_lpips"):
37
+ if name != "vgg_lpips":
38
+ raise NotImplementedError
39
+ model = cls()
40
+ ckpt = get_ckpt_path(name, "models/svd/modules/autoencoding/lpips/loss")
41
+ model.load_state_dict(
42
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43
+ )
44
+ return model
45
+
46
+ def forward(self, input, target):
47
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
49
+ feats0, feats1, diffs = {}, {}, {}
50
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51
+ for kk in range(len(self.chns)):
52
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53
+ outs1[kk]
54
+ )
55
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56
+
57
+ res = [
58
+ spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59
+ for kk in range(len(self.chns))
60
+ ]
61
+ val = res[0]
62
+ for l in range(1, len(self.chns)):
63
+ val += res[l]
64
+ return val
65
+
66
+
67
+ class ScalingLayer(nn.Module):
68
+ def __init__(self):
69
+ super(ScalingLayer, self).__init__()
70
+ self.register_buffer(
71
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72
+ )
73
+ self.register_buffer(
74
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75
+ )
76
+
77
+ def forward(self, inp):
78
+ return (inp - self.shift) / self.scale
79
+
80
+
81
+ class NetLinLayer(nn.Module):
82
+ """A single linear layer which does a 1x1 conv"""
83
+
84
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
85
+ super(NetLinLayer, self).__init__()
86
+ layers = (
87
+ [
88
+ nn.Dropout(),
89
+ ]
90
+ if (use_dropout)
91
+ else []
92
+ )
93
+ layers += [
94
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95
+ ]
96
+ self.model = nn.Sequential(*layers)
97
+
98
+
99
+ class vgg16(torch.nn.Module):
100
+ def __init__(self, requires_grad=False, pretrained=True):
101
+ super(vgg16, self).__init__()
102
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103
+ self.slice1 = torch.nn.Sequential()
104
+ self.slice2 = torch.nn.Sequential()
105
+ self.slice3 = torch.nn.Sequential()
106
+ self.slice4 = torch.nn.Sequential()
107
+ self.slice5 = torch.nn.Sequential()
108
+ self.N_slices = 5
109
+ for x in range(4):
110
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(4, 9):
112
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(9, 16):
114
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(16, 23):
116
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(23, 30):
118
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
119
+ if not requires_grad:
120
+ for param in self.parameters():
121
+ param.requires_grad = False
122
+
123
+ def forward(self, X):
124
+ h = self.slice1(X)
125
+ h_relu1_2 = h
126
+ h = self.slice2(h)
127
+ h_relu2_2 = h
128
+ h = self.slice3(h)
129
+ h_relu3_3 = h
130
+ h = self.slice4(h)
131
+ h_relu4_3 = h
132
+ h = self.slice5(h)
133
+ h_relu5_3 = h
134
+ vgg_outputs = namedtuple(
135
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136
+ )
137
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138
+ return out
139
+
140
+
141
+ def normalize_tensor(x, eps=1e-10):
142
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143
+ return x / (norm_factor + eps)
144
+
145
+
146
+ def spatial_average(x, keepdim=True):
147
+ return x.mean([2, 3], keepdim=keepdim)
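A hedged usage sketch for the LPIPS module defined above (illustrative, not part of the commit; it assumes the LPIPS class is importable and that the vgg.pth checkpoint can be fetched by get_ckpt_path): inputs are RGB batches roughly in [-1, 1], the ScalingLayer maps them to VGG16 statistics, and the result is one perceptual distance per sample.

import torch

lpips = LPIPS().eval()                       # downloads/loads vgg.pth on first use
x = torch.rand(4, 3, 256, 256) * 2 - 1       # dummy reference batch in [-1, 1]
y = torch.rand(4, 3, 256, 256) * 2 - 1       # dummy reconstruction batch
with torch.no_grad():
    dist = lpips(x, y)                       # shape (4, 1, 1, 1), one distance per sample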
dragnuwa/svd/modules/autoencoding/lpips/model/LICENSE ADDED
@@ -0,0 +1,58 @@
1
+ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
25
+
26
+ --------------------------- LICENSE FOR pix2pix --------------------------------
27
+ BSD License
28
+
29
+ For pix2pix software
30
+ Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31
+ All rights reserved.
32
+
33
+ Redistribution and use in source and binary forms, with or without
34
+ modification, are permitted provided that the following conditions are met:
35
+
36
+ * Redistributions of source code must retain the above copyright notice, this
37
+ list of conditions and the following disclaimer.
38
+
39
+ * Redistributions in binary form must reproduce the above copyright notice,
40
+ this list of conditions and the following disclaimer in the documentation
41
+ and/or other materials provided with the distribution.
42
+
43
+ ----------------------------- LICENSE FOR DCGAN --------------------------------
44
+ BSD License
45
+
46
+ For dcgan.torch software
47
+
48
+ Copyright (c) 2015, Facebook, Inc. All rights reserved.
49
+
50
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51
+
52
+ Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53
+
54
+ Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55
+
56
+ Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57
+
58
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
dragnuwa/svd/modules/autoencoding/lpips/model/__init__.py ADDED
File without changes
dragnuwa/svd/modules/autoencoding/lpips/model/model.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+
3
+ import torch.nn as nn
4
+
5
+ from ..util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find("BatchNorm") != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+
22
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
+ """Construct a PatchGAN discriminator
24
+ Parameters:
25
+ input_nc (int) -- the number of channels in input images
26
+ ndf (int) -- the number of filters in the last conv layer
27
+ n_layers (int) -- the number of conv layers in the discriminator
28
+ use_actnorm (bool) -- use ActNorm instead of BatchNorm2d for normalization
29
+ """
30
+ super(NLayerDiscriminator, self).__init__()
31
+ if not use_actnorm:
32
+ norm_layer = nn.BatchNorm2d
33
+ else:
34
+ norm_layer = ActNorm
35
+ if (
36
+ type(norm_layer) == functools.partial
37
+ ): # no need to use bias as BatchNorm2d has affine parameters
38
+ use_bias = norm_layer.func != nn.BatchNorm2d
39
+ else:
40
+ use_bias = norm_layer != nn.BatchNorm2d
41
+
42
+ kw = 4
43
+ padw = 1
44
+ sequence = [
45
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
+ nn.LeakyReLU(0.2, True),
47
+ ]
48
+ nf_mult = 1
49
+ nf_mult_prev = 1
50
+ for n in range(1, n_layers): # gradually increase the number of filters
51
+ nf_mult_prev = nf_mult
52
+ nf_mult = min(2**n, 8)
53
+ sequence += [
54
+ nn.Conv2d(
55
+ ndf * nf_mult_prev,
56
+ ndf * nf_mult,
57
+ kernel_size=kw,
58
+ stride=2,
59
+ padding=padw,
60
+ bias=use_bias,
61
+ ),
62
+ norm_layer(ndf * nf_mult),
63
+ nn.LeakyReLU(0.2, True),
64
+ ]
65
+
66
+ nf_mult_prev = nf_mult
67
+ nf_mult = min(2**n_layers, 8)
68
+ sequence += [
69
+ nn.Conv2d(
70
+ ndf * nf_mult_prev,
71
+ ndf * nf_mult,
72
+ kernel_size=kw,
73
+ stride=1,
74
+ padding=padw,
75
+ bias=use_bias,
76
+ ),
77
+ norm_layer(ndf * nf_mult),
78
+ nn.LeakyReLU(0.2, True),
79
+ ]
80
+
81
+ sequence += [
82
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
+ ] # output 1 channel prediction map
84
+ self.main = nn.Sequential(*sequence)
85
+
86
+ def forward(self, input):
87
+ """Standard forward."""
88
+ return self.main(input)
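A hedged sketch of how the PatchGAN discriminator above is typically driven (illustrative only; assumes NLayerDiscriminator and weights_init are in scope): it returns a grid of per-patch real/fake logits rather than a single scalar.

import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
img = torch.rand(2, 3, 256, 256)             # dummy image batch
logits = disc(img)                           # per-patch real/fake logits
print(logits.shape)                          # torch.Size([2, 1, 30, 30]) for this setup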
dragnuwa/svd/modules/autoencoding/lpips/util.py ADDED
@@ -0,0 +1,128 @@
1
+ import hashlib
2
+ import os
3
+
4
+ import requests
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
+
11
+ CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
+
13
+ MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
+
15
+
16
+ def download(url, local_path, chunk_size=1024):
17
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
+ with requests.get(url, stream=True) as r:
19
+ total_size = int(r.headers.get("content-length", 0))
20
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
+ with open(local_path, "wb") as f:
22
+ for data in r.iter_content(chunk_size=chunk_size):
23
+ if data:
24
+ f.write(data)
25
+ pbar.update(chunk_size)
26
+
27
+
28
+ def md5_hash(path):
29
+ with open(path, "rb") as f:
30
+ content = f.read()
31
+ return hashlib.md5(content).hexdigest()
32
+
33
+
34
+ def get_ckpt_path(name, root, check=False):
35
+ assert name in URL_MAP
36
+ path = os.path.join(root, CKPT_MAP[name])
37
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
+ download(URL_MAP[name], path)
40
+ md5 = md5_hash(path)
41
+ assert md5 == MD5_MAP[name], md5
42
+ return path
43
+
44
+
45
+ class ActNorm(nn.Module):
46
+ def __init__(
47
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
+ ):
49
+ assert affine
50
+ super().__init__()
51
+ self.logdet = logdet
52
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
+ self.allow_reverse_init = allow_reverse_init
55
+
56
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
+
58
+ def initialize(self, input):
59
+ with torch.no_grad():
60
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
+ mean = (
62
+ flatten.mean(1)
63
+ .unsqueeze(1)
64
+ .unsqueeze(2)
65
+ .unsqueeze(3)
66
+ .permute(1, 0, 2, 3)
67
+ )
68
+ std = (
69
+ flatten.std(1)
70
+ .unsqueeze(1)
71
+ .unsqueeze(2)
72
+ .unsqueeze(3)
73
+ .permute(1, 0, 2, 3)
74
+ )
75
+
76
+ self.loc.data.copy_(-mean)
77
+ self.scale.data.copy_(1 / (std + 1e-6))
78
+
79
+ def forward(self, input, reverse=False):
80
+ if reverse:
81
+ return self.reverse(input)
82
+ if len(input.shape) == 2:
83
+ input = input[:, :, None, None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ _, _, height, width = input.shape
89
+
90
+ if self.training and self.initialized.item() == 0:
91
+ self.initialize(input)
92
+ self.initialized.fill_(1)
93
+
94
+ h = self.scale * (input + self.loc)
95
+
96
+ if squeeze:
97
+ h = h.squeeze(-1).squeeze(-1)
98
+
99
+ if self.logdet:
100
+ log_abs = torch.log(torch.abs(self.scale))
101
+ logdet = height * width * torch.sum(log_abs)
102
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
103
+ return h, logdet
104
+
105
+ return h
106
+
107
+ def reverse(self, output):
108
+ if self.training and self.initialized.item() == 0:
109
+ if not self.allow_reverse_init:
110
+ raise RuntimeError(
111
+ "Initializing ActNorm in reverse direction is "
112
+ "disabled by default. Use allow_reverse_init=True to enable."
113
+ )
114
+ else:
115
+ self.initialize(output)
116
+ self.initialized.fill_(1)
117
+
118
+ if len(output.shape) == 2:
119
+ output = output[:, :, None, None]
120
+ squeeze = True
121
+ else:
122
+ squeeze = False
123
+
124
+ h = output / self.scale - self.loc
125
+
126
+ if squeeze:
127
+ h = h.squeeze(-1).squeeze(-1)
128
+ return h
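A small sketch of the ActNorm behaviour above (illustrative, not part of the commit; assumes the ActNorm class is in scope): the first forward pass in training mode initialises loc/scale from batch statistics, after which forward and reverse invert each other.

import torch

act = ActNorm(num_features=8)
x = torch.randn(4, 8, 16, 16) * 3 + 1
act.train()
y = act(x)                                   # triggers data-dependent initialisation
x_rec = act.reverse(y)                       # h / scale - loc undoes scale * (x + loc)
print(torch.allclose(x, x_rec, atol=1e-4))   # True up to numerical error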
dragnuwa/svd/modules/autoencoding/lpips/vqperceptual.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def hinge_d_loss(logits_real, logits_fake):
6
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
7
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
+ d_loss = 0.5 * (loss_real + loss_fake)
9
+ return d_loss
10
+
11
+
12
+ def vanilla_d_loss(logits_real, logits_fake):
13
+ d_loss = 0.5 * (
14
+ torch.mean(torch.nn.functional.softplus(-logits_real))
15
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
16
+ )
17
+ return d_loss
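For orientation, a tiny numeric sketch of the two discriminator losses above (logit values are made up): both reward high logits on real samples and low logits on fakes; the hinge variant saturates outside the unit margin while the softplus ("vanilla") variant stays smooth.

import torch

logits_real = torch.tensor([2.0, 0.5, -1.0])
logits_fake = torch.tensor([-2.0, 0.0, 1.5])
print(hinge_d_loss(logits_real, logits_fake))    # approximately 1.0 for these values
print(vanilla_d_loss(logits_real, logits_fake))  # smooth counterpart of the same objective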
dragnuwa/svd/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ....modules.distributions.distributions import \
9
+ DiagonalGaussianDistribution
10
+ from .base import AbstractRegularizer
11
+
12
+
13
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
14
+ def __init__(self, sample: bool = True):
15
+ super().__init__()
16
+ self.sample = sample
17
+
18
+ def get_trainable_parameters(self) -> Any:
19
+ yield from ()
20
+
21
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
22
+ log = dict()
23
+ posterior = DiagonalGaussianDistribution(z)
24
+ if self.sample:
25
+ z = posterior.sample()
26
+ else:
27
+ z = posterior.mode()
28
+ kl_loss = posterior.kl()
29
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
30
+ log["kl_loss"] = kl_loss
31
+ return z, log
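A hedged sketch of the DiagonalGaussianRegularizer above (illustrative; it assumes DiagonalGaussianDistribution splits the channel axis into mean and log-variance halves, as in the rest of this codebase): an encoder output with 2*C channels yields a C-channel latent plus a KL term.

import torch

reg = DiagonalGaussianRegularizer(sample=True)
moments = torch.randn(2, 8, 16, 16)          # 2 * C channels: mean and log-variance
z, log = reg(moments)                        # reparameterised sample when sample=True
print(z.shape, log["kl_loss"])               # torch.Size([2, 4, 16, 16]) and a scalar KL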
dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.49 kB).
dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/base.cpython-38.pyc ADDED
Binary file (2.03 kB).
dragnuwa/svd/modules/autoencoding/regularizers/base.py ADDED
@@ -0,0 +1,40 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+
9
+ class AbstractRegularizer(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
14
+ raise NotImplementedError()
15
+
16
+ @abstractmethod
17
+ def get_trainable_parameters(self) -> Any:
18
+ raise NotImplementedError()
19
+
20
+
21
+ class IdentityRegularizer(AbstractRegularizer):
22
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
23
+ return z, dict()
24
+
25
+ def get_trainable_parameters(self) -> Any:
26
+ yield from ()
27
+
28
+
29
+ def measure_perplexity(
30
+ predicted_indices: torch.Tensor, num_centroids: int
31
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
33
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
34
+ encodings = (
35
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
36
+ )
37
+ avg_probs = encodings.mean(0)
38
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
39
+ cluster_use = torch.sum(avg_probs > 0)
40
+ return perplexity, cluster_use
dragnuwa/svd/modules/autoencoding/regularizers/quantize.py ADDED
@@ -0,0 +1,487 @@
1
+ import logging
2
+ from abc import abstractmethod
3
+ from typing import Dict, Iterator, Literal, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from torch import einsum
11
+
12
+ from .base import AbstractRegularizer, measure_perplexity
13
+
14
+ logpy = logging.getLogger(__name__)
15
+
16
+
17
+ class AbstractQuantizer(AbstractRegularizer):
18
+ def __init__(self):
19
+ super().__init__()
20
+ # Define these in your init
21
+ # shape (N,)
22
+ self.used: Optional[torch.Tensor]
23
+ self.re_embed: int
24
+ self.unknown_index: Union[Literal["random"], int]
25
+
26
+ def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
27
+ assert self.used is not None, "You need to define used indices for remap"
28
+ ishape = inds.shape
29
+ assert len(ishape) > 1
30
+ inds = inds.reshape(ishape[0], -1)
31
+ used = self.used.to(inds)
32
+ match = (inds[:, :, None] == used[None, None, ...]).long()
33
+ new = match.argmax(-1)
34
+ unknown = match.sum(2) < 1
35
+ if self.unknown_index == "random":
36
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
37
+ device=new.device
38
+ )
39
+ else:
40
+ new[unknown] = self.unknown_index
41
+ return new.reshape(ishape)
42
+
43
+ def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
44
+ assert self.used is not None, "You need to define used indices for remap"
45
+ ishape = inds.shape
46
+ assert len(ishape) > 1
47
+ inds = inds.reshape(ishape[0], -1)
48
+ used = self.used.to(inds)
49
+ if self.re_embed > self.used.shape[0]: # extra token
50
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
51
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
52
+ return back.reshape(ishape)
53
+
54
+ @abstractmethod
55
+ def get_codebook_entry(
56
+ self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
57
+ ) -> torch.Tensor:
58
+ raise NotImplementedError()
59
+
60
+ def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
61
+ yield from self.parameters()
62
+
63
+
64
+ class GumbelQuantizer(AbstractQuantizer):
65
+ """
66
+ credit to @karpathy:
67
+ https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
68
+ Gumbel Softmax trick quantizer
69
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
70
+ https://arxiv.org/abs/1611.01144
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ num_hiddens: int,
76
+ embedding_dim: int,
77
+ n_embed: int,
78
+ straight_through: bool = True,
79
+ kl_weight: float = 5e-4,
80
+ temp_init: float = 1.0,
81
+ remap: Optional[str] = None,
82
+ unknown_index: str = "random",
83
+ loss_key: str = "loss/vq",
84
+ ) -> None:
85
+ super().__init__()
86
+
87
+ self.loss_key = loss_key
88
+ self.embedding_dim = embedding_dim
89
+ self.n_embed = n_embed
90
+
91
+ self.straight_through = straight_through
92
+ self.temperature = temp_init
93
+ self.kl_weight = kl_weight
94
+
95
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
96
+ self.embed = nn.Embedding(n_embed, embedding_dim)
97
+
98
+ self.remap = remap
99
+ if self.remap is not None:
100
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
101
+ self.re_embed = self.used.shape[0]
102
+ else:
103
+ self.used = None
104
+ self.re_embed = n_embed
105
+ if unknown_index == "extra":
106
+ self.unknown_index = self.re_embed
107
+ self.re_embed = self.re_embed + 1
108
+ else:
109
+ assert unknown_index == "random" or isinstance(
110
+ unknown_index, int
111
+ ), "unknown index needs to be 'random', 'extra' or any integer"
112
+ self.unknown_index = unknown_index # "random" or "extra" or integer
113
+ if self.remap is not None:
114
+ logpy.info(
115
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
116
+ f"Using {self.unknown_index} for unknown indices."
117
+ )
118
+
119
+ def forward(
120
+ self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
121
+ ) -> Tuple[torch.Tensor, Dict]:
122
+ # force hard = True when we are in eval mode, as we must quantize.
123
+ # actually, always true seems to work
124
+ hard = self.straight_through if self.training else True
125
+ temp = self.temperature if temp is None else temp
126
+ out_dict = {}
127
+ logits = self.proj(z)
128
+ if self.remap is not None:
129
+ # continue only with used logits
130
+ full_zeros = torch.zeros_like(logits)
131
+ logits = logits[:, self.used, ...]
132
+
133
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
134
+ if self.remap is not None:
135
+ # go back to all entries but unused set to zero
136
+ full_zeros[:, self.used, ...] = soft_one_hot
137
+ soft_one_hot = full_zeros
138
+ z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
139
+
140
+ # + kl divergence to the prior loss
141
+ qy = F.softmax(logits, dim=1)
142
+ diff = (
143
+ self.kl_weight
144
+ * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
145
+ )
146
+ out_dict[self.loss_key] = diff
147
+
148
+ ind = soft_one_hot.argmax(dim=1)
149
+ out_dict["indices"] = ind
150
+ if self.remap is not None:
151
+ ind = self.remap_to_used(ind)
152
+
153
+ if return_logits:
154
+ out_dict["logits"] = logits
155
+
156
+ return z_q, out_dict
157
+
158
+ def get_codebook_entry(self, indices, shape):
159
+ # TODO: shape not yet optional
160
+ b, h, w, c = shape
161
+ assert b * h * w == indices.shape[0]
162
+ indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
163
+ if self.remap is not None:
164
+ indices = self.unmap_to_all(indices)
165
+ one_hot = (
166
+ F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
167
+ )
168
+ z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
169
+ return z_q
170
+
171
+
172
+ class VectorQuantizer(AbstractQuantizer):
173
+ """
174
+ ____________________________________________
175
+ Discretization bottleneck part of the VQ-VAE.
176
+ Inputs:
177
+ - n_e : number of embeddings
178
+ - e_dim : dimension of embedding
179
+ - beta : commitment cost used in loss term,
180
+ beta * ||z_e(x)-sg[e]||^2
181
+ _____________________________________________
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ n_e: int,
187
+ e_dim: int,
188
+ beta: float = 0.25,
189
+ remap: Optional[str] = None,
190
+ unknown_index: str = "random",
191
+ sane_index_shape: bool = False,
192
+ log_perplexity: bool = False,
193
+ embedding_weight_norm: bool = False,
194
+ loss_key: str = "loss/vq",
195
+ ):
196
+ super().__init__()
197
+ self.n_e = n_e
198
+ self.e_dim = e_dim
199
+ self.beta = beta
200
+ self.loss_key = loss_key
201
+
202
+ if not embedding_weight_norm:
203
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
204
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
205
+ else:
206
+ self.embedding = torch.nn.utils.weight_norm(
207
+ nn.Embedding(self.n_e, self.e_dim), dim=1
208
+ )
209
+
210
+ self.remap = remap
211
+ if self.remap is not None:
212
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
213
+ self.re_embed = self.used.shape[0]
214
+ else:
215
+ self.used = None
216
+ self.re_embed = n_e
217
+ if unknown_index == "extra":
218
+ self.unknown_index = self.re_embed
219
+ self.re_embed = self.re_embed + 1
220
+ else:
221
+ assert unknown_index == "random" or isinstance(
222
+ unknown_index, int
223
+ ), "unknown index needs to be 'random', 'extra' or any integer"
224
+ self.unknown_index = unknown_index # "random" or "extra" or integer
225
+ if self.remap is not None:
226
+ logpy.info(
227
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
228
+ f"Using {self.unknown_index} for unknown indices."
229
+ )
230
+
231
+ self.sane_index_shape = sane_index_shape
232
+ self.log_perplexity = log_perplexity
233
+
234
+ def forward(
235
+ self,
236
+ z: torch.Tensor,
237
+ ) -> Tuple[torch.Tensor, Dict]:
238
+ do_reshape = z.ndim == 4
239
+ if do_reshape:
240
+ # # reshape z -> (batch, height, width, channel) and flatten
241
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
242
+
243
+ else:
244
+ assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
245
+ z = z.contiguous()
246
+
247
+ z_flattened = z.view(-1, self.e_dim)
248
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
249
+
250
+ d = (
251
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
252
+ + torch.sum(self.embedding.weight**2, dim=1)
253
+ - 2
254
+ * torch.einsum(
255
+ "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
256
+ )
257
+ )
258
+
259
+ min_encoding_indices = torch.argmin(d, dim=1)
260
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
261
+ loss_dict = {}
262
+ if self.log_perplexity:
263
+ perplexity, cluster_usage = measure_perplexity(
264
+ min_encoding_indices.detach(), self.n_e
265
+ )
266
+ loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
267
+
268
+ # compute loss for embedding
269
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
270
+ (z_q - z.detach()) ** 2
271
+ )
272
+ loss_dict[self.loss_key] = loss
273
+
274
+ # preserve gradients
275
+ z_q = z + (z_q - z).detach()
276
+
277
+ # reshape back to match original input shape
278
+ if do_reshape:
279
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
280
+
281
+ if self.remap is not None:
282
+ min_encoding_indices = min_encoding_indices.reshape(
283
+ z.shape[0], -1
284
+ ) # add batch axis
285
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
286
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
287
+
288
+ if self.sane_index_shape:
289
+ if do_reshape:
290
+ min_encoding_indices = min_encoding_indices.reshape(
291
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
292
+ )
293
+ else:
294
+ min_encoding_indices = rearrange(
295
+ min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
296
+ )
297
+
298
+ loss_dict["min_encoding_indices"] = min_encoding_indices
299
+
300
+ return z_q, loss_dict
301
+
302
+ def get_codebook_entry(
303
+ self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
304
+ ) -> torch.Tensor:
305
+ # shape specifying (batch, height, width, channel)
306
+ if self.remap is not None:
307
+ assert shape is not None, "Need to give shape for remap"
308
+ indices = indices.reshape(shape[0], -1) # add batch axis
309
+ indices = self.unmap_to_all(indices)
310
+ indices = indices.reshape(-1) # flatten again
311
+
312
+ # get quantized latent vectors
313
+ z_q = self.embedding(indices)
314
+
315
+ if shape is not None:
316
+ z_q = z_q.view(shape)
317
+ # reshape back to match original input shape
318
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
319
+
320
+ return z_q
321
+
322
+
323
+ class EmbeddingEMA(nn.Module):
324
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
325
+ super().__init__()
326
+ self.decay = decay
327
+ self.eps = eps
328
+ weight = torch.randn(num_tokens, codebook_dim)
329
+ self.weight = nn.Parameter(weight, requires_grad=False)
330
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
331
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
332
+ self.update = True
333
+
334
+ def forward(self, embed_id):
335
+ return F.embedding(embed_id, self.weight)
336
+
337
+ def cluster_size_ema_update(self, new_cluster_size):
338
+ self.cluster_size.data.mul_(self.decay).add_(
339
+ new_cluster_size, alpha=1 - self.decay
340
+ )
341
+
342
+ def embed_avg_ema_update(self, new_embed_avg):
343
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
344
+
345
+ def weight_update(self, num_tokens):
346
+ n = self.cluster_size.sum()
347
+ smoothed_cluster_size = (
348
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
349
+ )
350
+ # normalize embedding average with smoothed cluster size
351
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
352
+ self.weight.data.copy_(embed_normalized)
353
+
354
+
355
+ class EMAVectorQuantizer(AbstractQuantizer):
356
+ def __init__(
357
+ self,
358
+ n_embed: int,
359
+ embedding_dim: int,
360
+ beta: float,
361
+ decay: float = 0.99,
362
+ eps: float = 1e-5,
363
+ remap: Optional[str] = None,
364
+ unknown_index: str = "random",
365
+ loss_key: str = "loss/vq",
366
+ ):
367
+ super().__init__()
368
+ self.codebook_dim = embedding_dim
369
+ self.num_tokens = n_embed
370
+ self.beta = beta
371
+ self.loss_key = loss_key
372
+
373
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
374
+
375
+ self.remap = remap
376
+ if self.remap is not None:
377
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
378
+ self.re_embed = self.used.shape[0]
379
+ else:
380
+ self.used = None
381
+ self.re_embed = n_embed
382
+ if unknown_index == "extra":
383
+ self.unknown_index = self.re_embed
384
+ self.re_embed = self.re_embed + 1
385
+ else:
386
+ assert unknown_index == "random" or isinstance(
387
+ unknown_index, int
388
+ ), "unknown index needs to be 'random', 'extra' or any integer"
389
+ self.unknown_index = unknown_index # "random" or "extra" or integer
390
+ if self.remap is not None:
391
+ logpy.info(
392
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
393
+ f"Using {self.unknown_index} for unknown indices."
394
+ )
395
+
396
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
397
+ # reshape z -> (batch, height, width, channel) and flatten
398
+ # z, 'b c h w -> b h w c'
399
+ z = rearrange(z, "b c h w -> b h w c")
400
+ z_flattened = z.reshape(-1, self.codebook_dim)
401
+
402
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
403
+ d = (
404
+ z_flattened.pow(2).sum(dim=1, keepdim=True)
405
+ + self.embedding.weight.pow(2).sum(dim=1)
406
+ - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
407
+ ) # 'n d -> d n'
408
+
409
+ encoding_indices = torch.argmin(d, dim=1)
410
+
411
+ z_q = self.embedding(encoding_indices).view(z.shape)
412
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
413
+ avg_probs = torch.mean(encodings, dim=0)
414
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
415
+
416
+ if self.training and self.embedding.update:
417
+ # EMA cluster size
418
+ encodings_sum = encodings.sum(0)
419
+ self.embedding.cluster_size_ema_update(encodings_sum)
420
+ # EMA embedding average
421
+ embed_sum = encodings.transpose(0, 1) @ z_flattened
422
+ self.embedding.embed_avg_ema_update(embed_sum)
423
+ # normalize embed_avg and update weight
424
+ self.embedding.weight_update(self.num_tokens)
425
+
426
+ # compute loss for embedding
427
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
428
+
429
+ # preserve gradients
430
+ z_q = z + (z_q - z).detach()
431
+
432
+ # reshape back to match original input shape
433
+ # z_q, 'b h w c -> b c h w'
434
+ z_q = rearrange(z_q, "b h w c -> b c h w")
435
+
436
+ out_dict = {
437
+ self.loss_key: loss,
438
+ "encodings": encodings,
439
+ "encoding_indices": encoding_indices,
440
+ "perplexity": perplexity,
441
+ }
442
+
443
+ return z_q, out_dict
444
+
445
+
446
+ class VectorQuantizerWithInputProjection(VectorQuantizer):
447
+ def __init__(
448
+ self,
449
+ input_dim: int,
450
+ n_codes: int,
451
+ codebook_dim: int,
452
+ beta: float = 1.0,
453
+ output_dim: Optional[int] = None,
454
+ **kwargs,
455
+ ):
456
+ super().__init__(n_codes, codebook_dim, beta, **kwargs)
457
+ self.proj_in = nn.Linear(input_dim, codebook_dim)
458
+ self.output_dim = output_dim
459
+ if output_dim is not None:
460
+ self.proj_out = nn.Linear(codebook_dim, output_dim)
461
+ else:
462
+ self.proj_out = nn.Identity()
463
+
464
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
465
+ rearr = False
466
+ in_shape = z.shape
467
+
468
+ if z.ndim > 3:
469
+ rearr = self.output_dim is not None
470
+ z = rearrange(z, "b c ... -> b (...) c")
471
+ z = self.proj_in(z)
472
+ z_q, loss_dict = super().forward(z)
473
+
474
+ z_q = self.proj_out(z_q)
475
+ if rearr:
476
+ if len(in_shape) == 4:
477
+ z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
478
+ elif len(in_shape) == 5:
479
+ z_q = rearrange(
480
+ z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
481
+ )
482
+ else:
483
+ raise NotImplementedError(
484
+ f"rearranging not available for {len(in_shape)}-dimensional input."
485
+ )
486
+
487
+ return z_q, loss_dict
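A hedged usage sketch for the plain VectorQuantizer above (illustrative only; assumes the class is in scope): a (B, C, H, W) latent is snapped to its nearest codebook entry, gradients pass through the straight-through estimator, and the commitment/codebook loss is reported under the "loss/vq" key.

import torch

vq = VectorQuantizer(n_e=512, e_dim=4, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 4, 16, 16, requires_grad=True)
z_q, out = vq(z)                             # straight-through: z_q = z + (z_q - z).detach()
print(z_q.shape)                             # torch.Size([2, 4, 16, 16])
print(out["loss/vq"], out["min_encoding_indices"].shape)   # scalar loss, indices (2, 16, 16)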
dragnuwa/svd/modules/autoencoding/temporal_ae.py ADDED
@@ -0,0 +1,349 @@
1
+ from typing import Callable, Iterable, Union
2
+
3
+ import torch
4
+ from einops import rearrange, repeat
5
+
6
+ from dragnuwa.svd.modules.diffusionmodules.model import (
7
+ XFORMERS_IS_AVAILABLE,
8
+ AttnBlock,
9
+ Decoder,
10
+ MemoryEfficientAttnBlock,
11
+ ResnetBlock,
12
+ )
13
+ from dragnuwa.svd.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
14
+ from dragnuwa.svd.modules.video_attention import VideoTransformerBlock
15
+ from dragnuwa.svd.util import partialclass
16
+
17
+
18
+ class VideoResBlock(ResnetBlock):
19
+ def __init__(
20
+ self,
21
+ out_channels,
22
+ *args,
23
+ dropout=0.0,
24
+ video_kernel_size=3,
25
+ alpha=0.0,
26
+ merge_strategy="learned",
27
+ **kwargs,
28
+ ):
29
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
30
+ if video_kernel_size is None:
31
+ video_kernel_size = [3, 1, 1]
32
+ self.time_stack = ResBlock(
33
+ channels=out_channels,
34
+ emb_channels=0,
35
+ dropout=dropout,
36
+ dims=3,
37
+ use_scale_shift_norm=False,
38
+ use_conv=False,
39
+ up=False,
40
+ down=False,
41
+ kernel_size=video_kernel_size,
42
+ use_checkpoint=False,
43
+ skip_t_emb=True,
44
+ )
45
+
46
+ self.merge_strategy = merge_strategy
47
+ if self.merge_strategy == "fixed":
48
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
49
+ elif self.merge_strategy == "learned":
50
+ self.register_parameter(
51
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
52
+ )
53
+ else:
54
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
55
+
56
+ def get_alpha(self, bs):
57
+ if self.merge_strategy == "fixed":
58
+ return self.mix_factor
59
+ elif self.merge_strategy == "learned":
60
+ return torch.sigmoid(self.mix_factor)
61
+ else:
62
+ raise NotImplementedError()
63
+
64
+ def forward(self, x, temb, skip_video=False, timesteps=None):
65
+ if timesteps is None:
66
+ timesteps = self.timesteps
67
+
68
+ b, c, h, w = x.shape
69
+
70
+ x = super().forward(x, temb)
71
+
72
+ if not skip_video:
73
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
74
+
75
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
76
+
77
+ x = self.time_stack(x, temb)
78
+
79
+ alpha = self.get_alpha(bs=b // timesteps)
80
+ x = alpha * x + (1.0 - alpha) * x_mix
81
+
82
+ x = rearrange(x, "b c t h w -> (b t) c h w")
83
+ return x
84
+
85
+
86
+ class AE3DConv(torch.nn.Conv2d):
87
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
88
+ super().__init__(in_channels, out_channels, *args, **kwargs)
89
+ if isinstance(video_kernel_size, Iterable):
90
+ padding = [int(k // 2) for k in video_kernel_size]
91
+ else:
92
+ padding = int(video_kernel_size // 2)
93
+
94
+ self.time_mix_conv = torch.nn.Conv3d(
95
+ in_channels=out_channels,
96
+ out_channels=out_channels,
97
+ kernel_size=video_kernel_size,
98
+ padding=padding,
99
+ )
100
+
101
+ def forward(self, input, timesteps, skip_video=False):
102
+ x = super().forward(input)
103
+ if skip_video:
104
+ return x
105
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
106
+ x = self.time_mix_conv(x)
107
+ return rearrange(x, "b c t h w -> (b t) c h w")
108
+
109
+
110
+ class VideoBlock(AttnBlock):
111
+ def __init__(
112
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
113
+ ):
114
+ super().__init__(in_channels)
115
+ # no context, single headed, as in base class
116
+ self.time_mix_block = VideoTransformerBlock(
117
+ dim=in_channels,
118
+ n_heads=1,
119
+ d_head=in_channels,
120
+ checkpoint=False,
121
+ ff_in=True,
122
+ attn_mode="softmax",
123
+ )
124
+
125
+ time_embed_dim = self.in_channels * 4
126
+ self.video_time_embed = torch.nn.Sequential(
127
+ torch.nn.Linear(self.in_channels, time_embed_dim),
128
+ torch.nn.SiLU(),
129
+ torch.nn.Linear(time_embed_dim, self.in_channels),
130
+ )
131
+
132
+ self.merge_strategy = merge_strategy
133
+ if self.merge_strategy == "fixed":
134
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
135
+ elif self.merge_strategy == "learned":
136
+ self.register_parameter(
137
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
138
+ )
139
+ else:
140
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
141
+
142
+ def forward(self, x, timesteps, skip_video=False):
143
+ if skip_video:
144
+ return super().forward(x)
145
+
146
+ x_in = x
147
+ x = self.attention(x)
148
+ h, w = x.shape[2:]
149
+ x = rearrange(x, "b c h w -> b (h w) c")
150
+
151
+ x_mix = x
152
+ num_frames = torch.arange(timesteps, device=x.device)
153
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
154
+ num_frames = rearrange(num_frames, "b t -> (b t)")
155
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
156
+ emb = self.video_time_embed(t_emb) # b, n_channels
157
+ emb = emb[:, None, :]
158
+ x_mix = x_mix + emb
159
+
160
+ alpha = self.get_alpha()
161
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
162
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
163
+
164
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
165
+ x = self.proj_out(x)
166
+
167
+ return x_in + x
168
+
169
+ def get_alpha(
170
+ self,
171
+ ):
172
+ if self.merge_strategy == "fixed":
173
+ return self.mix_factor
174
+ elif self.merge_strategy == "learned":
175
+ return torch.sigmoid(self.mix_factor)
176
+ else:
177
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
178
+
179
+
180
+ class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
181
+ def __init__(
182
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
183
+ ):
184
+ super().__init__(in_channels)
185
+ # no context, single headed, as in base class
186
+ self.time_mix_block = VideoTransformerBlock(
187
+ dim=in_channels,
188
+ n_heads=1,
189
+ d_head=in_channels,
190
+ checkpoint=False,
191
+ ff_in=True,
192
+ attn_mode="softmax-xformers",
193
+ )
194
+
195
+ time_embed_dim = self.in_channels * 4
196
+ self.video_time_embed = torch.nn.Sequential(
197
+ torch.nn.Linear(self.in_channels, time_embed_dim),
198
+ torch.nn.SiLU(),
199
+ torch.nn.Linear(time_embed_dim, self.in_channels),
200
+ )
201
+
202
+ self.merge_strategy = merge_strategy
203
+ if self.merge_strategy == "fixed":
204
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
205
+ elif self.merge_strategy == "learned":
206
+ self.register_parameter(
207
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
208
+ )
209
+ else:
210
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
211
+
212
+ def forward(self, x, timesteps, skip_time_block=False):
213
+ if skip_time_block:
214
+ return super().forward(x)
215
+
216
+ x_in = x
217
+ x = self.attention(x)
218
+ h, w = x.shape[2:]
219
+ x = rearrange(x, "b c h w -> b (h w) c")
220
+
221
+ x_mix = x
222
+ num_frames = torch.arange(timesteps, device=x.device)
223
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
224
+ num_frames = rearrange(num_frames, "b t -> (b t)")
225
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
226
+ emb = self.video_time_embed(t_emb) # b, n_channels
227
+ emb = emb[:, None, :]
228
+ x_mix = x_mix + emb
229
+
230
+ alpha = self.get_alpha()
231
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
232
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
233
+
234
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
235
+ x = self.proj_out(x)
236
+
237
+ return x_in + x
238
+
239
+ def get_alpha(
240
+ self,
241
+ ):
242
+ if self.merge_strategy == "fixed":
243
+ return self.mix_factor
244
+ elif self.merge_strategy == "learned":
245
+ return torch.sigmoid(self.mix_factor)
246
+ else:
247
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
248
+
249
+
250
+ def make_time_attn(
251
+ in_channels,
252
+ attn_type="vanilla",
253
+ attn_kwargs=None,
254
+ alpha: float = 0,
255
+ merge_strategy: str = "learned",
256
+ ):
257
+ assert attn_type in [
258
+ "vanilla",
259
+ "vanilla-xformers",
260
+ ], f"attn_type {attn_type} not supported for spatio-temporal attention"
261
+ print(
262
+ f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
263
+ )
264
+ if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
265
+ print(
266
+ f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
267
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
268
+ )
269
+ attn_type = "vanilla"
270
+
271
+ if attn_type == "vanilla":
272
+ assert attn_kwargs is None
273
+ return partialclass(
274
+ VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
275
+ )
276
+ elif attn_type == "vanilla-xformers":
277
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
278
+ return partialclass(
279
+ MemoryEfficientVideoBlock,
280
+ in_channels,
281
+ alpha=alpha,
282
+ merge_strategy=merge_strategy,
283
+ )
284
+ else:
285
+ raise NotImplementedError()
286
+
287
+
288
+ class Conv2DWrapper(torch.nn.Conv2d):
289
+ def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
290
+ return super().forward(input)
291
+
292
+
293
+ class VideoDecoder(Decoder):
294
+ available_time_modes = ["all", "conv-only", "attn-only"]
295
+
296
+ def __init__(
297
+ self,
298
+ *args,
299
+ video_kernel_size: Union[int, list] = 3,
300
+ alpha: float = 0.0,
301
+ merge_strategy: str = "learned",
302
+ time_mode: str = "conv-only",
303
+ **kwargs,
304
+ ):
305
+ self.video_kernel_size = video_kernel_size
306
+ self.alpha = alpha
307
+ self.merge_strategy = merge_strategy
308
+ self.time_mode = time_mode
309
+ assert (
310
+ self.time_mode in self.available_time_modes
311
+ ), f"time_mode parameter has to be in {self.available_time_modes}"
312
+ super().__init__(*args, **kwargs)
313
+
314
+ def get_last_layer(self, skip_time_mix=False, **kwargs):
315
+ if self.time_mode == "attn-only":
316
+ raise NotImplementedError("TODO")
317
+ else:
318
+ return (
319
+ self.conv_out.time_mix_conv.weight
320
+ if not skip_time_mix
321
+ else self.conv_out.weight
322
+ )
323
+
324
+ def _make_attn(self) -> Callable:
325
+ if self.time_mode not in ["conv-only", "only-last-conv"]:
326
+ return partialclass(
327
+ make_time_attn,
328
+ alpha=self.alpha,
329
+ merge_strategy=self.merge_strategy,
330
+ )
331
+ else:
332
+ return super()._make_attn()
333
+
334
+ def _make_conv(self) -> Callable:
335
+ if self.time_mode != "attn-only":
336
+ return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
337
+ else:
338
+ return Conv2DWrapper
339
+
340
+ def _make_resblock(self) -> Callable:
341
+ if self.time_mode not in ["attn-only", "only-last-conv"]:
342
+ return partialclass(
343
+ VideoResBlock,
344
+ video_kernel_size=self.video_kernel_size,
345
+ alpha=self.alpha,
346
+ merge_strategy=self.merge_strategy,
347
+ )
348
+ else:
349
+ return super()._make_resblock()
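Finally, a hedged sketch of the pattern used by AE3DConv above (illustrative; tensor sizes are made up and the class is assumed importable): frames stay flattened into the batch axis for the 2D convolution, then a 3D convolution mixes information along time before flattening back.

import torch

conv = AE3DConv(in_channels=4, out_channels=4, video_kernel_size=3,
                kernel_size=3, padding=1)    # kernel_size/padding are forwarded to the 2D conv
frames = torch.randn(2 * 6, 4, 32, 32)       # 2 clips of 6 frames, flattened as (b*t, c, h, w)
out = conv(frames, timesteps=6)              # per-frame 2D conv, then temporal 3D mixing
print(out.shape)                             # torch.Size([12, 4, 32, 32])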