yinshengming committed
Commit: ab85cf9
Parent(s): 2fea44e
init
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +7 -0
- DragNUWA_net.py +297 -0
- app.py +386 -0
- assets/DragNUWA1.0/Figure1.gif +3 -0
- assets/DragNUWA1.0/Figure2.gif +3 -0
- assets/DragNUWA1.0/Figure3.gif +3 -0
- assets/DragNUWA1.5/Figure1.gif +3 -0
- assets/DragNUWA1.5/Figure2.gif +3 -0
- assets/DragNUWA1.5/Figure3.gif +3 -0
- assets/DragNUWA1.5/Figure4.gif +3 -0
- dragnuwa.md +81 -0
- dragnuwa/__init__.py +0 -0
- dragnuwa/__pycache__/__init__.cpython-38.pyc +0 -0
- dragnuwa/__pycache__/lora.cpython-38.pyc +0 -0
- dragnuwa/lora.py +412 -0
- dragnuwa/svd/__init__.py +0 -0
- dragnuwa/svd/__pycache__/__init__.cpython-38.pyc +0 -0
- dragnuwa/svd/__pycache__/util.cpython-38.pyc +0 -0
- dragnuwa/svd/models/__init__.py +0 -0
- dragnuwa/svd/models/__pycache__/__init__.cpython-38.pyc +0 -0
- dragnuwa/svd/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
- dragnuwa/svd/models/autoencoder.py +615 -0
- dragnuwa/svd/modules/__init__.py +6 -0
- dragnuwa/svd/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/__pycache__/attention.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/__pycache__/ema.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/__pycache__/video_attention.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/attention.py +759 -0
- dragnuwa/svd/modules/autoencoding/__init__.py +0 -0
- dragnuwa/svd/modules/autoencoding/__pycache__/__init__.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/autoencoding/__pycache__/temporal_ae.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/autoencoding/losses/__init__.py +7 -0
- dragnuwa/svd/modules/autoencoding/losses/discriminator_loss.py +306 -0
- dragnuwa/svd/modules/autoencoding/losses/lpips.py +73 -0
- dragnuwa/svd/modules/autoencoding/lpips/__init__.py +0 -0
- dragnuwa/svd/modules/autoencoding/lpips/loss/.gitignore +1 -0
- dragnuwa/svd/modules/autoencoding/lpips/loss/LICENSE +23 -0
- dragnuwa/svd/modules/autoencoding/lpips/loss/__init__.py +0 -0
- dragnuwa/svd/modules/autoencoding/lpips/loss/lpips.py +147 -0
- dragnuwa/svd/modules/autoencoding/lpips/model/LICENSE +58 -0
- dragnuwa/svd/modules/autoencoding/lpips/model/__init__.py +0 -0
- dragnuwa/svd/modules/autoencoding/lpips/model/model.py +88 -0
- dragnuwa/svd/modules/autoencoding/lpips/util.py +128 -0
- dragnuwa/svd/modules/autoencoding/lpips/vqperceptual.py +17 -0
- dragnuwa/svd/modules/autoencoding/regularizers/__init__.py +31 -0
- dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/__init__.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/base.cpython-38.pyc +0 -0
- dragnuwa/svd/modules/autoencoding/regularizers/base.py +40 -0
- dragnuwa/svd/modules/autoencoding/regularizers/quantize.py +487 -0
- dragnuwa/svd/modules/autoencoding/temporal_ae.py +349 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.0/Figure1.gif filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.0/Figure2.gif filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.0/Figure3.gif filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.5/Figure1.gif filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.5/Figure2.gif filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.5/Figure3.gif filter=lfs diff=lfs merge=lfs -text
+assets/DragNUWA1.5/Figure4.gif filter=lfs diff=lfs merge=lfs -text
DragNUWA_net.py
ADDED
@@ -0,0 +1,297 @@
from utils import *

#### SVD
from dragnuwa.svd.modules.diffusionmodules.video_model_flow import VideoUNet_flow, VideoResBlock_Embed
from dragnuwa.svd.modules.diffusionmodules.denoiser import Denoiser
from dragnuwa.svd.modules.diffusionmodules.denoiser_scaling import VScalingWithEDMcNoise
from dragnuwa.svd.modules.encoders.modules import *
from dragnuwa.svd.models.autoencoder import AutoencodingEngine
from dragnuwa.svd.modules.diffusionmodules.wrappers import OpenAIWrapper
from dragnuwa.svd.modules.diffusionmodules.sampling import EulerEDMSampler

from dragnuwa.lora import inject_trainable_lora, inject_trainable_lora_extended, extract_lora_ups_down, _find_modules

def get_gaussian_kernel(kernel_size, sigma, channels):
    print('parameters of gaussian kernel: kernel_size: {}, sigma: {}, channels: {}'.format(kernel_size, sigma, channels))
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
    mean = (kernel_size - 1)/2.
    variance = sigma**2.

    gaussian_kernel = torch.exp(
        -torch.sum((xy_grid - mean)**2., dim=-1) /
        (2*variance)
    )

    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, groups=channels, bias=False, padding=kernel_size//2)

    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False

    return gaussian_filter

def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, r=16):
    injector = (
        inject_trainable_lora if not is_extended
        else
        inject_trainable_lora_extended
    )

    params = None
    negation = None

    if use_lora:
        REPLACE_MODULES = replace_modules
        injector_args = {
            "model": model,
            "target_replace_module": REPLACE_MODULES,
            "r": r
        }
        if not is_extended: injector_args['dropout_p'] = dropout

        params, negation = injector(**injector_args)
        for _up, _down in extract_lora_ups_down(
            model,
            target_replace_module=REPLACE_MODULES):

            if all(x is not None for x in [_up, _down]):
                print(f"Lora successfully injected into {model.__class__.__name__}.")

            break

    return params, negation

class Args:
    ### basic
    fps = 4
    height = 320
    width = 576

    ### lora
    unet_lora_rank = 32

    ### gaussian filter parameters
    kernel_size = 199
    sigma = 20

    # model
    denoiser_config = {
        'scaling_config': {
            'target': 'dragnuwa.svd.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise',
        }
    }

    network_config = {
        'adm_in_channels': 768, 'num_classes': 'sequential', 'use_checkpoint': True, 'in_channels': 8, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_head_channels': 64, 'use_linear_in_transformer': True, 'transformer_depth': 1, 'context_dim': 1024, 'spatial_transformer_attn_type': 'softmax-xformers', 'extra_ff_mix_layer': True, 'use_spatial_context': True, 'merge_strategy': 'learned_with_images', 'video_kernel_size': [3, 1, 1], 'flow_dim_scale': 1,
    }

    conditioner_emb_models = [
        {'is_trainable': False,
         'input_key': 'cond_frames_without_noise',  # crossattn
         'ucg_rate': 0.1,
         'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder',
         'params': {
             'n_cond_frames': 1,
             'n_copies': 1,
             'open_clip_embedding_config': {
                 'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImageEmbedder',
                 'params': {
                     'freeze': True,
                 }
             }
         }
         },
        {'input_key': 'fps_id',  # vector
         'is_trainable': False,
         'ucg_rate': 0.1,
         'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
         'params': {
             'outdim': 256,
         }
         },
        {'input_key': 'motion_bucket_id',  # vector
         'ucg_rate': 0.1,
         'is_trainable': False,
         'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
         'params': {
             'outdim': 256,
         }
         },
        {'input_key': 'cond_frames',  # concat
         'is_trainable': False,
         'ucg_rate': 0.1,
         'target': 'dragnuwa.svd.modules.encoders.modules.VideoPredictionEmbedderWithEncoder',
         'params': {
             'en_and_decode_n_samples_a_time': 1,
             'disable_encoder_autocast': True,
             'n_cond_frames': 1,
             'n_copies': 1,
             'is_ae': True,
             'encoder_config': {
                 'target': 'dragnuwa.svd.models.autoencoder.AutoencoderKLModeOnly',
                 'params': {
                     'embed_dim': 4,
                     'monitor': 'val/rec_loss',
                     'ddconfig': {
                         'attn_type': 'vanilla-xformers',
                         'double_z': True,
                         'z_channels': 4,
                         'resolution': 256,
                         'in_channels': 3,
                         'out_ch': 3,
                         'ch': 128,
                         'ch_mult': [1, 2, 4, 4],
                         'num_res_blocks': 2,
                         'attn_resolutions': [],
                         'dropout': 0.0,
                     },
                     'lossconfig': {
                         'target': 'torch.nn.Identity',
                     }
                 }
             }
         }
         },
        {'input_key': 'cond_aug',  # vector
         'ucg_rate': 0.1,
         'is_trainable': False,
         'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
         'params': {
             'outdim': 256,
         }
         }
    ]

    first_stage_config = {
        'loss_config': {'target': 'torch.nn.Identity'},
        'regularizer_config': {'target': 'dragnuwa.svd.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'},
        'encoder_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.model.Encoder',
                           'params': {'attn_type': 'vanilla',
                                      'double_z': True,
                                      'z_channels': 4,
                                      'resolution': 256,
                                      'in_channels': 3,
                                      'out_ch': 3,
                                      'ch': 128,
                                      'ch_mult': [1, 2, 4, 4],
                                      'num_res_blocks': 2,
                                      'attn_resolutions': [],
                                      'dropout': 0.0,
                                      }
                           },
        'decoder_config': {'target': 'dragnuwa.svd.modules.autoencoding.temporal_ae.VideoDecoder',
                           'params': {'attn_type': 'vanilla',
                                      'double_z': True,
                                      'z_channels': 4,
                                      'resolution': 256,
                                      'in_channels': 3,
                                      'out_ch': 3,
                                      'ch': 128,
                                      'ch_mult': [1, 2, 4, 4],
                                      'num_res_blocks': 2,
                                      'attn_resolutions': [],
                                      'dropout': 0.0,
                                      'video_kernel_size': [3, 1, 1],
                                      }
                           },
    }

    sampler_config = {
        'discretization_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.discretizer.EDMDiscretization',
                                  'params': {'sigma_max': 700.0, },
                                  },
        'guider_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.guiders.LinearPredictionGuider',
                          'params': {'max_scale': 2.5,
                                     'min_scale': 1.0,
                                     'num_frames': 14},
                          },
        'num_steps': 25,
    }

    scale_factor = 0.18215
    num_frames = 14

    ### others
    seed = 42
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


args = Args()

def quick_freeze(model):
    for name, param in model.named_parameters():
        param.requires_grad = False
    return model

class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.args = args
        self.device = 'cpu'
        ### unet
        model = VideoUNet_flow(**args.network_config)
        self.model = OpenAIWrapper(model)

        ### denoiser and sampler
        self.denoiser = Denoiser(**args.denoiser_config)
        self.sampler = EulerEDMSampler(**args.sampler_config)

        ### conditioner
        self.conditioner = GeneralConditioner(args.conditioner_emb_models)

        ### first stage model
        self.first_stage_model = AutoencodingEngine(**args.first_stage_config).eval()

        self.scale_factor = args.scale_factor
        self.en_and_decode_n_samples_a_time = 1  # decode 1 frame each time to save GPU memory
        self.num_frames = args.num_frames
        self.guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=args.kernel_size, sigma=args.sigma, channels=2))

        unet_lora_params, unet_negation = inject_lora(
            True, self, ['OpenAIWrapper'], is_extended=False, r=args.unet_lora_rank
        )

    def to(self, *args, **kwargs):
        model_converted = super().to(*args, **kwargs)
        self.device = next(self.parameters()).device
        self.sampler.device = self.device
        for embedder in self.conditioner.embedders:
            if hasattr(embedder, "device"):
                embedder.device = self.device
        return model_converted

    def train(self, *args):
        super().train(*args)
        self.conditioner.eval()
        self.first_stage_model.eval()

    def apply_gaussian_filter_on_drag(self, drag):
        b, l, h, w, c = drag.shape
        drag = rearrange(drag, 'b l h w c -> (b l) c h w')
        drag = self.guassian_filter(drag)
        drag = rearrange(drag, '(b l) c h w -> b l h w c', b=b)
        return drag

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = self.en_and_decode_n_samples_a_time  # 1
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        for n in range(n_rounds):
            kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
            out = self.first_stage_model.decode(
                z[n * n_samples : (n + 1) * n_samples], **kwargs
            )
            all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out
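The Gaussian filter above turns each user drag, which touches only one pixel per frame, into a dense, smooth flow field the UNet can condition on. Below is a minimal sketch of that step with a deliberately small kernel and made-up toy dimensions for illustration (the shipped config uses kernel_size=199 and sigma=20); it is not part of the commit.

```python
import torch
import torch.nn as nn
from einops import rearrange

def toy_gaussian_filter(kernel_size=15, sigma=3.0, channels=2):
    # Same construction as get_gaussian_kernel above, shrunk for a quick check.
    coords = torch.arange(kernel_size).float()
    grid = torch.stack(torch.meshgrid(coords, coords, indexing="ij"), dim=-1)
    mean = (kernel_size - 1) / 2.0
    kernel = torch.exp(-((grid - mean) ** 2).sum(-1) / (2 * sigma ** 2))
    kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
    conv = nn.Conv2d(channels, channels, kernel_size, groups=channels,
                     bias=False, padding=kernel_size // 2)
    conv.weight.data = kernel
    conv.weight.requires_grad = False
    return conv

# A sparse drag map: batch 1, 13 flow frames, a 32x48 grid, (dx, dy) channels,
# with a single user-drawn displacement per frame.
drag = torch.zeros(1, 13, 32, 48, 2)
drag[0, :, 16, 24, 0] = 5.0  # 5 px to the right at one pixel

filt = toy_gaussian_filter()
dense = rearrange(drag, 'b l h w c -> (b l) c h w')
dense = filt(dense)
dense = rearrange(dense, '(b l) c h w -> b l h w c', b=1)
print(dense.shape)  # the single point has spread into a soft blob of flow
```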
app.py
ADDED
@@ -0,0 +1,386 @@
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageFilter
import uuid
from scipy.interpolate import interp1d, PchipInterpolator
import torchvision
from utils import *

output_dir = "outputs"
ensure_dirname(output_dir)

def interpolate_trajectory(points, n_points):
    x = [point[0] for point in points]
    y = [point[1] for point in points]

    t = np.linspace(0, 1, len(points))

    # fx = interp1d(t, x, kind='cubic')
    # fy = interp1d(t, y, kind='cubic')
    fx = PchipInterpolator(t, x)
    fy = PchipInterpolator(t, y)

    new_t = np.linspace(0, 1, n_points)

    new_x = fx(new_t)
    new_y = fy(new_t)
    new_points = list(zip(new_x, new_y))

    return new_points

def visualize_drag_v2(background_image_path, splited_tracks, width, height):
    trajectory_maps = []

    background_image = Image.open(background_image_path).convert('RGBA')
    background_image = background_image.resize((width, height))
    w, h = background_image.size
    transparent_background = np.array(background_image)
    transparent_background[:, :, -1] = 128
    transparent_background = Image.fromarray(transparent_background)

    # Create a transparent layer with the same size as the background image
    transparent_layer = np.zeros((h, w, 4))
    for splited_track in splited_tracks:
        if len(splited_track) > 1:
            splited_track = interpolate_trajectory(splited_track, 16)
            splited_track = splited_track[:16]
            for i in range(len(splited_track)-1):
                start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
                end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(splited_track)-2:
                    cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
        else:
            cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 5, (255, 0, 0, 192), -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    trajectory_maps.append(trajectory_map)
    return trajectory_maps, transparent_layer

class Drag:
    def __init__(self, device, model_path, cfg_path, height, width, model_length):
        self.device = device
        cf = import_filename(cfg_path)
        Net, args = cf.Net, cf.args
        drag_nuwa_net = Net(args)
        state_dict = file2data(model_path, map_location='cpu')
        adaptively_load_state_dict(drag_nuwa_net, state_dict)
        drag_nuwa_net.eval()
        drag_nuwa_net.to(device)
        # drag_nuwa_net.half()
        self.drag_nuwa_net = drag_nuwa_net
        self.height = height
        self.width = width
        _, model_step, _ = split_filename(model_path)
        self.ouput_prefix = f'{model_step}_{width}X{height}'
        self.model_length = model_length

    @torch.no_grad()
    def forward_sample(self, input_drag, input_first_frame, motion_bucket_id, outputs=dict()):
        device = self.device

        b, l, h, w, c = input_drag.size()
        drag = self.drag_nuwa_net.apply_gaussian_filter_on_drag(input_drag)
        drag = torch.cat([torch.zeros_like(drag[:, 0]).unsqueeze(1), drag], dim=1)  # pad the first frame with zero flow
        drag = rearrange(drag, 'b l h w c -> b l c h w')

        input_conditioner = dict()
        input_conditioner['cond_frames_without_noise'] = input_first_frame
        input_conditioner['cond_frames'] = (input_first_frame + 0.02 * torch.randn_like(input_first_frame))
        input_conditioner['motion_bucket_id'] = torch.tensor([motion_bucket_id]).to(drag.device).repeat(b * (l+1))
        input_conditioner['fps_id'] = torch.tensor([self.drag_nuwa_net.args.fps]).to(drag.device).repeat(b * (l+1))
        input_conditioner['cond_aug'] = torch.tensor([0.02]).to(drag.device).repeat(b * (l+1))

        input_conditioner_uc = {}
        for key in input_conditioner.keys():
            if key not in input_conditioner_uc and isinstance(input_conditioner[key], torch.Tensor):
                input_conditioner_uc[key] = input_conditioner[key].clone()

        c, uc = self.drag_nuwa_net.conditioner.get_unconditional_conditioning(
            input_conditioner,
            batch_uc=input_conditioner_uc,
            force_uc_zero_embeddings=[
                "cond_frames",
                "cond_frames_without_noise",
            ],
        )

        for k in ["crossattn", "concat"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=self.drag_nuwa_net.num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...")
            c[k] = repeat(c[k], "b ... -> b t ...", t=self.drag_nuwa_net.num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...")

        H, W = input_conditioner['cond_frames_without_noise'].shape[2:]
        shape = (self.drag_nuwa_net.num_frames, 4, H // 8, W // 8)
        randn = torch.randn(shape).to(self.device)

        additional_model_inputs = {}
        additional_model_inputs["image_only_indicator"] = torch.zeros(
            2, self.drag_nuwa_net.num_frames
        ).to(self.device)
        additional_model_inputs["num_video_frames"] = self.drag_nuwa_net.num_frames
        additional_model_inputs["flow"] = drag.repeat(2, 1, 1, 1, 1)  # c and uc

        def denoiser(input, sigma, c):
            return self.drag_nuwa_net.denoiser(self.drag_nuwa_net.model, input, sigma, c, **additional_model_inputs)

        samples_z = self.drag_nuwa_net.sampler(denoiser, randn, cond=c, uc=uc)
        samples = self.drag_nuwa_net.decode_first_stage(samples_z)

        outputs['logits_imgs'] = rearrange(samples, '(b l) c h w -> b l c h w', b=b)
        return outputs

    def run(self, first_frame_path, tracking_points, inference_batch_size, motion_bucket_id):
        original_width, original_height = 576, 320

        input_all_points = tracking_points.constructor_args['value']
        resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]

        input_drag = torch.zeros(self.model_length - 1, self.height, self.width, 2)
        for splited_track in resized_all_points:
            if len(splited_track) == 1:  # stationary point
                displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
                splited_track = tuple([splited_track[0], displacement_point])
            # interpolate the track
            splited_track = interpolate_trajectory(splited_track, self.model_length)
            splited_track = splited_track[:self.model_length]
            if len(splited_track) < self.model_length:
                splited_track = splited_track + [splited_track[-1]] * (self.model_length - len(splited_track))
            for i in range(self.model_length - 1):
                start_point = splited_track[i]
                end_point = splited_track[i+1]
                input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
                input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]

        dir, base, ext = split_filename(first_frame_path)
        id = base.split('_')[-1]

        image_pil = image2pil(first_frame_path)
        image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGB')

        visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)

        first_frames_transform = transforms.Compose([
            lambda x: Image.fromarray(x),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        outputs = None
        ouput_video_list = []
        num_inference = 1
        for i in tqdm(range(num_inference)):
            if not outputs:
                first_frames = image2arr(first_frame_path)
                first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=inference_batch_size).to(self.device)
            else:
                first_frames = outputs['logits_imgs'][:, -1]

            outputs = self.forward_sample(
                repeat(input_drag[i*(self.model_length - 1):(i+1)*(self.model_length - 1)], 'l h w c -> b l h w c', b=inference_batch_size).to(self.device),
                first_frames,
                motion_bucket_id)
            ouput_video_list.append(outputs['logits_imgs'])

        for i in range(inference_batch_size):
            ouput_tensor = [ouput_video_list[0][i]]
            for j in range(num_inference - 1):
                ouput_tensor.append(ouput_video_list[j+1][i][1:])
            ouput_tensor = torch.cat(ouput_tensor, dim=0)
            outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
            data2file([transforms.ToPILImage('RGB')(utils.make_grid(e.to(torch.float32).cpu(), normalize=True, range=(-1, 1))) for e in ouput_tensor], outputs_path,
                      printable=False, duration=1 / 6, override=True)

        return visualized_drag[0], outputs_path

with gr.Blocks() as demo:
    gr.Markdown("""<h1 align="center">DragNUWA 1.5</h1><br>""")

    gr.Markdown("""Official Gradio Demo for <a href='https://arxiv.org/abs/2308.08089'><b>DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory</b></a>.<br>
    🔥DragNUWA enables users to manipulate backgrounds or objects within images directly, and the model seamlessly translates these actions into **camera movements** or **object motions**, generating the corresponding video.<br>
    🔥DragNUWA 1.5 enables Stable Video Diffusion to animate an image according to a specific path.<br>""")

    gr.Markdown("""## Usage: <br>
    1. Upload an image via the "Upload Image" button.<br>
    2. Draw some drags.<br>
        2.1. Click "Add Drag" when you want to add a control path.<br>
        2.2. You can click several points to form a path.<br>
        2.3. Click "Delete last drag" to delete the whole latest path.<br>
        2.4. Click "Delete last step" to delete the latest clicked control point.<br>
    3. Animate the image according to the path with a click on the "Run" button. <br>""")

    DragNUWA_net = Drag("cuda:0", 'models/drag_nuwa_svd.pth', 'DragNUWA_net.py', 320, 576, 14)
    first_frame_path = gr.State()
    tracking_points = gr.State([])

    def reset_states(first_frame_path, tracking_points):
        first_frame_path = gr.State()
        tracking_points = gr.State([])
        return first_frame_path, tracking_points

    def preprocess_image(image):
        image_pil = image2pil(image.name)
        raw_w, raw_h = image_pil.size
        resize_ratio = max(576/raw_w, 320/raw_h)
        image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
        image_pil = transforms.CenterCrop((320, 576))(image_pil.convert('RGB'))

        first_frame_path = os.path.join(output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
        image_pil.save(first_frame_path)

        return first_frame_path, first_frame_path, gr.State([])

    def add_drag(tracking_points):
        tracking_points.constructor_args['value'].append([])
        return tracking_points

    def delete_last_drag(tracking_points, first_frame_path):
        tracking_points.constructor_args['value'].pop()
        transparent_background = Image.open(first_frame_path).convert('RGBA')
        w, h = transparent_background.size
        transparent_layer = np.zeros((h, w, 4))
        for track in tracking_points.constructor_args['value']:
            if len(track) > 1:
                for i in range(len(track)-1):
                    start_point = track[i]
                    end_point = track[i+1]
                    vx = end_point[0] - start_point[0]
                    vy = end_point[1] - start_point[1]
                    arrow_length = np.sqrt(vx**2 + vy**2)
                    if i == len(track)-2:
                        cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
                    else:
                        cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
            else:
                cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)

        transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
        trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
        return tracking_points, trajectory_map

    def delete_last_step(tracking_points, first_frame_path):
        tracking_points.constructor_args['value'][-1].pop()
        transparent_background = Image.open(first_frame_path).convert('RGBA')
        w, h = transparent_background.size
        transparent_layer = np.zeros((h, w, 4))
        for track in tracking_points.constructor_args['value']:
            if len(track) > 1:
                for i in range(len(track)-1):
                    start_point = track[i]
                    end_point = track[i+1]
                    vx = end_point[0] - start_point[0]
                    vy = end_point[1] - start_point[1]
                    arrow_length = np.sqrt(vx**2 + vy**2)
                    if i == len(track)-2:
                        cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
                    else:
                        cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
            else:
                cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)

        transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
        trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
        return tracking_points, trajectory_map

    def add_tracking_points(tracking_points, first_frame_path, evt: gr.SelectData):  # SelectData is a subclass of EventData
        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
        tracking_points.constructor_args['value'][-1].append(evt.index)

        transparent_background = Image.open(first_frame_path).convert('RGBA')
        w, h = transparent_background.size
        transparent_layer = np.zeros((h, w, 4))
        for track in tracking_points.constructor_args['value']:
            if len(track) > 1:
                for i in range(len(track)-1):
                    start_point = track[i]
                    end_point = track[i+1]
                    vx = end_point[0] - start_point[0]
                    vy = end_point[1] - start_point[1]
                    arrow_length = np.sqrt(vx**2 + vy**2)
                    if i == len(track)-2:
                        cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
                    else:
                        cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
            else:
                cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)

        transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
        trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
        return tracking_points, trajectory_map

    with gr.Row():
        with gr.Column(scale=1):
            image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
            add_drag_button = gr.Button(value="Add Drag")
            reset_button = gr.Button(value="Reset")
            run_button = gr.Button(value="Run")
            delete_last_drag_button = gr.Button(value="Delete last drag")
            delete_last_step_button = gr.Button(value="Delete last step")

        with gr.Column(scale=7):
            with gr.Row():
                with gr.Column(scale=6):
                    input_image = gr.Image(label=None,
                                           interactive=True,
                                           height=320,
                                           width=576,)
                with gr.Column(scale=6):
                    output_image = gr.Image(label=None,
                                            height=320,
                                            width=576,)

    with gr.Row():
        with gr.Column(scale=1):
            inference_batch_size = gr.Slider(label='Inference Batch Size',
                                             minimum=1,
                                             maximum=1,
                                             step=1,
                                             value=1)

            motion_bucket_id = gr.Slider(label='Motion Bucket',
                                         minimum=1,
                                         maximum=100,
                                         step=1,
                                         value=4)

        with gr.Column(scale=5):
            output_video = gr.Image(label="Output Video",
                                    height=320,
                                    width=576,)

    with gr.Row():
        gr.Markdown("""
            ## Citation
            ```bibtex
            @article{yin2023dragnuwa,
              title={Dragnuwa: Fine-grained control in video generation by integrating text, image, and trajectory},
              author={Yin, Shengming and Wu, Chenfei and Liang, Jian and Shi, Jie and Li, Houqiang and Ming, Gong and Duan, Nan},
              journal={arXiv preprint arXiv:2308.08089},
              year={2023}
            }
            ```
            """)


    image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])

    add_drag_button.click(add_drag, tracking_points, tracking_points)

    delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path], [tracking_points, input_image])

    delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path], [tracking_points, input_image])

    reset_button.click(reset_states, [first_frame_path, tracking_points], [first_frame_path, tracking_points])

    input_image.select(add_tracking_points, [tracking_points, first_frame_path], [tracking_points, input_image])

    run_button.click(DragNUWA_net.run, [first_frame_path, tracking_points, inference_batch_size, motion_bucket_id], [output_image, output_video])

demo.launch(server_name="0.0.0.0", debug=True)
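For reference, the core of the drag preprocessing in `run` is the PCHIP upsampling done by `interpolate_trajectory`: a handful of clicks becomes one position per generated frame (14 here), and consecutive positions give the per-frame (dx, dy) written into `input_drag`. A standalone sketch with hypothetical click coordinates, not part of the commit:

```python
import numpy as np
from scipy.interpolate import PchipInterpolator

def interpolate_trajectory(points, n_points):
    # Same logic as the helper in app.py: monotone cubic (PCHIP) interpolation
    # of x(t) and y(t) over a normalized parameter t in [0, 1].
    x = [p[0] for p in points]
    y = [p[1] for p in points]
    t = np.linspace(0, 1, len(points))
    fx, fy = PchipInterpolator(t, x), PchipInterpolator(t, y)
    new_t = np.linspace(0, 1, n_points)
    return list(zip(fx(new_t), fy(new_t)))

# Three made-up clicks on a 576x320 frame, expanded to 14 per-frame positions.
clicks = [(100, 200), (250, 180), (400, 220)]
track = interpolate_trajectory(clicks, 14)
deltas = [(x2 - x1, y2 - y1) for (x1, y1), (x2, y2) in zip(track[:-1], track[1:])]
print(len(track), len(deltas))  # 14 positions -> 13 per-frame displacement vectors
```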
assets/DragNUWA1.0/Figure1.gif
ADDED
Git LFS Details
assets/DragNUWA1.0/Figure2.gif
ADDED
Git LFS Details
assets/DragNUWA1.0/Figure3.gif
ADDED
Git LFS Details
assets/DragNUWA1.5/Figure1.gif
ADDED
Git LFS Details
assets/DragNUWA1.5/Figure2.gif
ADDED
Git LFS Details
assets/DragNUWA1.5/Figure3.gif
ADDED
Git LFS Details
assets/DragNUWA1.5/Figure4.gif
ADDED
Git LFS Details
dragnuwa.md
ADDED
@@ -0,0 +1,81 @@
# DragNUWA

**DragNUWA** enables users to manipulate backgrounds or objects within images directly, and the model seamlessly translates these actions into **camera movements** or **object motions**, generating the corresponding video.

See our paper: [DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory](https://arxiv.org/abs/2308.08089)

<a href="TOBEDONE">
<img src="https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue" alt="Open in Spaces">
</a>
<a href="TOBEDONE">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
</a>

### DragNUWA 1.5 (Updated on Jan 8, 2024)

**DragNUWA 1.5** enables Stable Video Diffusion to animate an image according to a specific path.

<p align="center">
<img src="assets/DragNUWA1.5/Figure1.gif" width="90%">
</p>

<p align="center">
<img src="assets/DragNUWA1.5/Figure2.gif" width="90%">
</p>
<p align="center">
<img src="assets/DragNUWA1.5/Figure3.gif" width="90%">
</p>
<p align="center">
<img src="assets/DragNUWA1.5/Figure4.gif" width="90%">
</p>

### DragNUWA 1.0 (Original Paper)
[**DragNUWA 1.0**](https://arxiv.org/abs/2308.08089) utilizes text, image, and trajectory as three essential control factors to facilitate highly controllable video generation from the semantic, spatial, and temporal aspects.

<p align="center">
<img src="assets/DragNUWA1.0/Figure1.gif" width="90%">
</p>
<p align="center">
<img src="assets/DragNUWA1.0/Figure2.gif" width="100%">
</p>
<p align="center">
<img src="assets/DragNUWA1.0/Figure3.gif" width="100%">
</p>

## Getting Started

### Setting Up the Environment
```Shell
git clone -b svd https://github.com/ProjectNUWA/DragNUWA.git
cd DragNUWA

conda create -n DragNUWA python=3.8
conda activate DragNUWA
pip install -r environment.txt
```

### Download Pretrained Weights
Download the [Pretrained Weights](https://drive.google.com/file/d/1Z4JOley0SJCb35kFF4PCc6N6P1ftfX4i/view) to the `models/` directory, or directly run `bash models/Download.sh`.

### Drag and Animate!
```Shell
python DragNUWA_demo.py
```
It will launch a Gradio demo, and you can drag an image and animate it!

### Acknowledgement
We appreciate the open-source releases of the following projects:
[Stable Video Diffusion](https://github.com/Stability-AI/generative-models)
[Hugging Face](https://github.com/huggingface)
[UniMatch](https://github.com/autonomousvision/unimatch)

### Citation
```bibtex
@article{yin2023dragnuwa,
  title={Dragnuwa: Fine-grained control in video generation by integrating text, image, and trajectory},
  author={Yin, Shengming and Wu, Chenfei and Liang, Jian and Shi, Jie and Li, Houqiang and Ming, Gong and Duan, Nan},
  journal={arXiv preprint arXiv:2308.08089},
  year={2023}
}
```
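As a rough orientation for readers of this Space, the Gradio app in this commit wires the checkpoint to the network roughly as sketched below. This is a hedged sketch, not a supported API: it reuses only names visible in this commit (`Net`, `args`, `models/drag_nuwa_svd.pth`, and the `utils` helpers `file2data` / `adaptively_load_state_dict`) and assumes it is run from the repository root so those modules are importable.

```python
# Minimal model-loading sketch, mirroring Drag.__init__ in app.py.
import torch
from utils import file2data, adaptively_load_state_dict
from DragNUWA_net import Net, args

net = Net(args)                                        # build UNet, conditioner, VAE, sampler
state_dict = file2data('models/drag_nuwa_svd.pth', map_location='cpu')
adaptively_load_state_dict(net, state_dict)            # tolerant weight loading
net.eval().to('cuda:0')                                # the demo targets a single GPU
```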
dragnuwa/__init__.py
ADDED
File without changes
dragnuwa/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (131 Bytes)
dragnuwa/__pycache__/lora.cpython-38.pyc
ADDED
Binary file (9.18 kB)
dragnuwa/lora.py
ADDED
@@ -0,0 +1,412 @@
1 |
+
from utils import *
|
2 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
3 |
+
|
4 |
+
|
5 |
+
class LoraInjectedLinear(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
|
8 |
+
):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
if r > min(in_features, out_features):
|
12 |
+
#raise ValueError(
|
13 |
+
# f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
14 |
+
#)
|
15 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
|
16 |
+
r = min(in_features, out_features)
|
17 |
+
|
18 |
+
self.r = r
|
19 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
20 |
+
self.lora_down = nn.Linear(in_features, r, bias=False)
|
21 |
+
self.dropout = nn.Dropout(dropout_p)
|
22 |
+
self.lora_up = nn.Linear(r, out_features, bias=False)
|
23 |
+
self.scale = scale
|
24 |
+
self.selector = nn.Identity()
|
25 |
+
|
26 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
27 |
+
nn.init.zeros_(self.lora_up.weight)
|
28 |
+
|
29 |
+
def forward(self, input):
|
30 |
+
return (
|
31 |
+
self.linear(input)
|
32 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
33 |
+
* self.scale
|
34 |
+
)
|
35 |
+
|
36 |
+
def realize_as_lora(self):
|
37 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
38 |
+
|
39 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
40 |
+
# diag is a 1D tensor of size (r,)
|
41 |
+
assert diag.shape == (self.r,)
|
42 |
+
self.selector = nn.Linear(self.r, self.r, bias=False)
|
43 |
+
self.selector.weight.data = torch.diag(diag)
|
44 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
45 |
+
self.lora_up.weight.device
|
46 |
+
).to(self.lora_up.weight.dtype)
|
47 |
+
|
48 |
+
class LoraInjectedConv2d(nn.Module):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
in_channels: int,
|
52 |
+
out_channels: int,
|
53 |
+
kernel_size,
|
54 |
+
stride=1,
|
55 |
+
padding=0,
|
56 |
+
dilation=1,
|
57 |
+
groups: int = 1,
|
58 |
+
bias: bool = True,
|
59 |
+
r: int = 4,
|
60 |
+
dropout_p: float = 0.1,
|
61 |
+
scale: float = 1.0,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
if r > min(in_channels, out_channels):
|
65 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
|
66 |
+
r = min(in_channels, out_channels)
|
67 |
+
|
68 |
+
self.r = r
|
69 |
+
self.conv = nn.Conv2d(
|
70 |
+
in_channels=in_channels,
|
71 |
+
out_channels=out_channels,
|
72 |
+
kernel_size=kernel_size,
|
73 |
+
stride=stride,
|
74 |
+
padding=padding,
|
75 |
+
dilation=dilation,
|
76 |
+
groups=groups,
|
77 |
+
bias=bias,
|
78 |
+
)
|
79 |
+
|
80 |
+
self.lora_down = nn.Conv2d(
|
81 |
+
in_channels=in_channels,
|
82 |
+
out_channels=r,
|
83 |
+
kernel_size=kernel_size,
|
84 |
+
stride=stride,
|
85 |
+
padding=padding,
|
86 |
+
dilation=dilation,
|
87 |
+
groups=groups,
|
88 |
+
bias=False,
|
89 |
+
)
|
90 |
+
self.dropout = nn.Dropout(dropout_p)
|
91 |
+
self.lora_up = nn.Conv2d(
|
92 |
+
in_channels=r,
|
93 |
+
out_channels=out_channels,
|
94 |
+
kernel_size=1,
|
95 |
+
stride=1,
|
96 |
+
padding=0,
|
97 |
+
bias=False,
|
98 |
+
)
|
99 |
+
self.selector = nn.Identity()
|
100 |
+
self.scale = scale
|
101 |
+
|
102 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
103 |
+
nn.init.zeros_(self.lora_up.weight)
|
104 |
+
|
105 |
+
def forward(self, input):
|
106 |
+
return (
|
107 |
+
self.conv(input)
|
108 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
109 |
+
* self.scale
|
110 |
+
)
|
111 |
+
|
112 |
+
def realize_as_lora(self):
|
113 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
114 |
+
|
115 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
116 |
+
# diag is a 1D tensor of size (r,)
|
117 |
+
assert diag.shape == (self.r,)
|
118 |
+
self.selector = nn.Conv2d(
|
119 |
+
in_channels=self.r,
|
120 |
+
out_channels=self.r,
|
121 |
+
kernel_size=1,
|
122 |
+
stride=1,
|
123 |
+
padding=0,
|
124 |
+
bias=False,
|
125 |
+
)
|
126 |
+
self.selector.weight.data = torch.diag(diag)
|
127 |
+
|
128 |
+
# same device + dtype as lora_up
|
129 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
130 |
+
self.lora_up.weight.device
|
131 |
+
).to(self.lora_up.weight.dtype)
|
132 |
+
|
133 |
+
class LoraInjectedConv3d(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
in_channels: int,
|
137 |
+
out_channels: int,
|
138 |
+
kernel_size: (3, 1, 1),
|
139 |
+
padding: (1, 0, 0),
|
140 |
+
bias: bool = False,
|
141 |
+
r: int = 4,
|
142 |
+
dropout_p: float = 0,
|
143 |
+
scale: float = 1.0,
|
144 |
+
):
|
145 |
+
super().__init__()
|
146 |
+
if r > min(in_channels, out_channels):
|
147 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
|
148 |
+
r = min(in_channels, out_channels)
|
149 |
+
|
150 |
+
self.r = r
|
151 |
+
self.kernel_size = kernel_size
|
152 |
+
self.padding = padding
|
153 |
+
self.conv = nn.Conv3d(
|
154 |
+
in_channels=in_channels,
|
155 |
+
out_channels=out_channels,
|
156 |
+
kernel_size=kernel_size,
|
157 |
+
padding=padding,
|
158 |
+
)
|
159 |
+
|
160 |
+
self.lora_down = nn.Conv3d(
|
161 |
+
in_channels=in_channels,
|
162 |
+
out_channels=r,
|
163 |
+
kernel_size=kernel_size,
|
164 |
+
bias=False,
|
165 |
+
padding=padding
|
166 |
+
)
|
167 |
+
self.dropout = nn.Dropout(dropout_p)
|
168 |
+
self.lora_up = nn.Conv3d(
|
169 |
+
in_channels=r,
|
170 |
+
out_channels=out_channels,
|
171 |
+
kernel_size=1,
|
172 |
+
stride=1,
|
173 |
+
padding=0,
|
174 |
+
bias=False,
|
175 |
+
)
|
176 |
+
self.selector = nn.Identity()
|
177 |
+
self.scale = scale
|
178 |
+
|
179 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
180 |
+
nn.init.zeros_(self.lora_up.weight)
|
181 |
+
|
182 |
+
def forward(self, input):
|
183 |
+
return (
|
184 |
+
self.conv(input)
|
185 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
186 |
+
* self.scale
|
187 |
+
)
|
188 |
+
|
189 |
+
def realize_as_lora(self):
|
190 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
191 |
+
|
192 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
193 |
+
# diag is a 1D tensor of size (r,)
|
194 |
+
assert diag.shape == (self.r,)
|
195 |
+
self.selector = nn.Conv3d(
|
196 |
+
in_channels=self.r,
|
197 |
+
out_channels=self.r,
|
198 |
+
kernel_size=1,
|
199 |
+
stride=1,
|
200 |
+
padding=0,
|
201 |
+
bias=False,
|
202 |
+
)
|
203 |
+
self.selector.weight.data = torch.diag(diag)
|
204 |
+
|
205 |
+
# same device + dtype as lora_up
|
206 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
207 |
+
self.lora_up.weight.device
|
208 |
+
).to(self.lora_up.weight.dtype)
|
209 |
+
|
210 |
+
def _find_modules(
|
211 |
+
model,
|
212 |
+
ancestor_class: Optional[Set[str]] = None,
|
213 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
214 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [
|
215 |
+
LoraInjectedLinear,
|
216 |
+
LoraInjectedConv2d,
|
217 |
+
LoraInjectedConv3d
|
218 |
+
],
|
219 |
+
):
|
220 |
+
"""
|
221 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
222 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
223 |
+
|
224 |
+
Returns all matching modules, along with the parent of those moduless and the
|
225 |
+
names they are referenced by.
|
226 |
+
"""
|
227 |
+
|
228 |
+
# Get the targets we should replace all linears under
|
229 |
+
if ancestor_class is not None:
|
230 |
+
ancestors = (
|
231 |
+
module
|
232 |
+
for module in model.modules()
|
233 |
+
if module.__class__.__name__ in ancestor_class
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
# this, incase you want to naively iterate over all modules.
|
237 |
+
ancestors = [module for module in model.modules()]
|
238 |
+
|
239 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
240 |
+
for ancestor in ancestors:
|
241 |
+
for fullname, module in ancestor.named_modules():
|
242 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
243 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
244 |
+
*path, name = fullname.split(".")
|
245 |
+
parent = ancestor
|
246 |
+
while path:
|
247 |
+
parent = parent.get_submodule(path.pop(0))
|
248 |
+
# Skip this linear if it's a child of a LoraInjectedLinear
|
249 |
+
if exclude_children_of and any(
|
250 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
251 |
+
):
|
252 |
+
continue
|
253 |
+
# Otherwise, yield it
|
254 |
+
yield parent, name, module
|
255 |
+
|
256 |
+
|
257 |
+
def inject_trainable_lora(
|
258 |
+
model: nn.Module,
|
259 |
+
target_replace_module,
|
260 |
+
r: int = 4,
|
261 |
+
loras=None, # path to lora .pt
|
262 |
+
verbose: bool = False,
|
263 |
+
dropout_p: float = 0.0,
|
264 |
+
scale: float = 1.0,
|
265 |
+
):
|
266 |
+
"""
|
267 |
+
inject lora into model, and returns lora parameter groups.
|
268 |
+
"""
|
269 |
+
|
270 |
+
require_grad_params = []
|
271 |
+
names = []
|
272 |
+
|
273 |
+
if loras != None:
|
274 |
+
loras = torch.load(loras)
|
275 |
+
|
276 |
+
for _module, name, _child_module in _find_modules(
|
277 |
+
model, target_replace_module, search_class=[nn.Linear]
|
278 |
+
):
|
279 |
+
weight = _child_module.weight
|
280 |
+
bias = _child_module.bias
|
281 |
+
if verbose:
|
282 |
+
print("LoRA Injection : injecting lora into ", name)
|
283 |
+
print("LoRA Injection : weight shape", weight.shape)
|
284 |
+
_tmp = LoraInjectedLinear(
|
285 |
+
_child_module.in_features,
|
286 |
+
_child_module.out_features,
|
287 |
+
_child_module.bias is not None,
|
288 |
+
r=r,
|
289 |
+
dropout_p=dropout_p,
|
290 |
+
scale=scale,
|
291 |
+
)
|
292 |
+
_tmp.linear.weight = weight
|
293 |
+
if bias is not None:
|
294 |
+
_tmp.linear.bias = bias
|
295 |
+
|
296 |
+
# switch the module
|
297 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
298 |
+
_module._modules[name] = _tmp
|
299 |
+
|
300 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
301 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
302 |
+
|
303 |
+
if loras != None:
|
304 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
305 |
+
```python
# dragnuwa/lora.py (continued)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names


def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module,
    r: int = 4,
    loras=None,  # path to lora .pt
):
    """
    inject lora into model, and returns lora parameter groups.
    """

    require_grad_params = []
    names = []

    if loras != None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
    ):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                _child_module.bias is not None,
                r=r,
            )
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedConv2d(
                _child_module.in_channels,
                _child_module.out_channels,
                _child_module.kernel_size,
                _child_module.stride,
                _child_module.padding,
                _child_module.dilation,
                _child_module.groups,
                _child_module.bias is not None,
                r=r,
            )

            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias

        elif _child_module.__class__ == nn.Conv3d:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedConv3d(
                _child_module.in_channels,
                _child_module.out_channels,
                bias=_child_module.bias is not None,
                kernel_size=_child_module.kernel_size,
                padding=_child_module.padding,
                r=r,
            )

            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias
        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        if bias is not None:
            _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)

        _module._modules[name] = _tmp
        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras != None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names


def extract_lora_ups_down(model, target_replace_module):

    loras = []

    for _m, _n, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
    ):
        loras.append((_child_module.lora_up, _child_module.lora_down))

    if len(loras) == 0:
        raise ValueError("No lora injected.")

    return loras
```
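The two injectors above replace matching nn.Linear / nn.Conv2d / nn.Conv3d children with LoRA-wrapped copies and return the new low-rank parameters for the optimizer. A minimal usage sketch, assuming `_find_modules` selects ancestor modules by class name (as in the upstream lora implementation this file follows); `ToyBlock`, the rank and the learning rate are illustrative, not values from the DragNUWA training code:

```python
import itertools

import torch
import torch.nn as nn

from dragnuwa.lora import inject_trainable_lora_extended


class ToyBlock(nn.Module):
    # Illustrative container: target_replace_module is matched against this class name.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)
        self.conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)


model = nn.Sequential(ToyBlock(), ToyBlock())
model.requires_grad_(False)  # freeze the base weights

# Swap every Linear/Conv2d/Conv3d inside ToyBlock modules for a rank-4 LoRA version.
lora_params, replaced_names = inject_trainable_lora_extended(
    model, target_replace_module={"ToyBlock"}, r=4
)

# Only the injected lora_up / lora_down weights require grad after injection.
optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)
print(replaced_names)  # e.g. ['proj', 'conv', 'proj', 'conv']
```

Training then proceeds as usual; `extract_lora_ups_down` can later pull the injected up/down pairs back out for saving.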
dragnuwa/svd/__init__.py
ADDED: file without changes

dragnuwa/svd/__pycache__/__init__.cpython-38.pyc
ADDED: binary file (135 Bytes)

dragnuwa/svd/__pycache__/util.cpython-38.pyc
ADDED: binary file (9.38 kB)

dragnuwa/svd/models/__init__.py
ADDED: file without changes

dragnuwa/svd/models/__pycache__/__init__.cpython-38.pyc
ADDED: binary file (142 Bytes)

dragnuwa/svd/models/__pycache__/autoencoder.cpython-38.pyc
ADDED: binary file (19.2 kB)
dragnuwa/svd/models/autoencoder.py
ADDED
@@ -0,0 +1,615 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
from abc import abstractmethod
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange
|
12 |
+
from packaging import version
|
13 |
+
|
14 |
+
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
15 |
+
from ..modules.ema import LitEma
|
16 |
+
from ..util import (default, get_nested_attribute, get_obj_from_str,
|
17 |
+
instantiate_from_config)
|
18 |
+
|
19 |
+
logpy = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class AbstractAutoencoder(pl.LightningModule):
|
23 |
+
"""
|
24 |
+
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
25 |
+
unCLIP models, etc. Hence, it is fairly general, and specific features
|
26 |
+
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
ema_decay: Union[None, float] = None,
|
32 |
+
monitor: Union[None, str] = None,
|
33 |
+
input_key: str = "jpg",
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.input_key = input_key
|
38 |
+
self.use_ema = ema_decay is not None
|
39 |
+
if monitor is not None:
|
40 |
+
self.monitor = monitor
|
41 |
+
|
42 |
+
if self.use_ema:
|
43 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
44 |
+
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
45 |
+
|
46 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
47 |
+
self.automatic_optimization = False
|
48 |
+
|
49 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
50 |
+
if ckpt is None:
|
51 |
+
return
|
52 |
+
if isinstance(ckpt, str):
|
53 |
+
ckpt = {
|
54 |
+
"target": "dragnuwa.svd.modules.checkpoint.CheckpointEngine",
|
55 |
+
"params": {"ckpt_path": ckpt},
|
56 |
+
}
|
57 |
+
engine = instantiate_from_config(ckpt)
|
58 |
+
engine(self)
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def get_input(self, batch) -> Any:
|
62 |
+
raise NotImplementedError()
|
63 |
+
|
64 |
+
def on_train_batch_end(self, *args, **kwargs):
|
65 |
+
# for EMA computation
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema(self)
|
68 |
+
|
69 |
+
@contextmanager
|
70 |
+
def ema_scope(self, context=None):
|
71 |
+
if self.use_ema:
|
72 |
+
self.model_ema.store(self.parameters())
|
73 |
+
self.model_ema.copy_to(self)
|
74 |
+
if context is not None:
|
75 |
+
logpy.info(f"{context}: Switched to EMA weights")
|
76 |
+
try:
|
77 |
+
yield None
|
78 |
+
finally:
|
79 |
+
if self.use_ema:
|
80 |
+
self.model_ema.restore(self.parameters())
|
81 |
+
if context is not None:
|
82 |
+
logpy.info(f"{context}: Restored training weights")
|
83 |
+
|
84 |
+
@abstractmethod
|
85 |
+
def encode(self, *args, **kwargs) -> torch.Tensor:
|
86 |
+
raise NotImplementedError("encode()-method of abstract base class called")
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def decode(self, *args, **kwargs) -> torch.Tensor:
|
90 |
+
raise NotImplementedError("decode()-method of abstract base class called")
|
91 |
+
|
92 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
93 |
+
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
94 |
+
return get_obj_from_str(cfg["target"])(
|
95 |
+
params, lr=lr, **cfg.get("params", dict())
|
96 |
+
)
|
97 |
+
|
98 |
+
def configure_optimizers(self) -> Any:
|
99 |
+
raise NotImplementedError()
|
100 |
+
|
101 |
+
|
102 |
+
class AutoencodingEngine(AbstractAutoencoder):
|
103 |
+
"""
|
104 |
+
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
105 |
+
(we also restore them explicitly as special cases for legacy reasons).
|
106 |
+
Regularizations such as KL or VQ are moved to the regularizer class.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
*args,
|
112 |
+
encoder_config: Dict,
|
113 |
+
decoder_config: Dict,
|
114 |
+
loss_config: Dict,
|
115 |
+
regularizer_config: Dict,
|
116 |
+
optimizer_config: Union[Dict, None] = None,
|
117 |
+
lr_g_factor: float = 1.0,
|
118 |
+
trainable_ae_params: Optional[List[List[str]]] = None,
|
119 |
+
ae_optimizer_args: Optional[List[dict]] = None,
|
120 |
+
trainable_disc_params: Optional[List[List[str]]] = None,
|
121 |
+
disc_optimizer_args: Optional[List[dict]] = None,
|
122 |
+
disc_start_iter: int = 0,
|
123 |
+
diff_boost_factor: float = 3.0,
|
124 |
+
ckpt_engine: Union[None, str, dict] = None,
|
125 |
+
ckpt_path: Optional[str] = None,
|
126 |
+
additional_decode_keys: Optional[List[str]] = None,
|
127 |
+
**kwargs,
|
128 |
+
):
|
129 |
+
super().__init__(*args, **kwargs)
|
130 |
+
self.automatic_optimization = False # pytorch lightning
|
131 |
+
|
132 |
+
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
133 |
+
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
134 |
+
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
135 |
+
self.regularization: AbstractRegularizer = instantiate_from_config(
|
136 |
+
regularizer_config
|
137 |
+
)
|
138 |
+
self.optimizer_config = default(
|
139 |
+
optimizer_config, {"target": "torch.optim.Adam"}
|
140 |
+
)
|
141 |
+
self.diff_boost_factor = diff_boost_factor
|
142 |
+
self.disc_start_iter = disc_start_iter
|
143 |
+
self.lr_g_factor = lr_g_factor
|
144 |
+
self.trainable_ae_params = trainable_ae_params
|
145 |
+
if self.trainable_ae_params is not None:
|
146 |
+
self.ae_optimizer_args = default(
|
147 |
+
ae_optimizer_args,
|
148 |
+
[{} for _ in range(len(self.trainable_ae_params))],
|
149 |
+
)
|
150 |
+
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
151 |
+
else:
|
152 |
+
self.ae_optimizer_args = [{}] # makes type consitent
|
153 |
+
|
154 |
+
self.trainable_disc_params = trainable_disc_params
|
155 |
+
if self.trainable_disc_params is not None:
|
156 |
+
self.disc_optimizer_args = default(
|
157 |
+
disc_optimizer_args,
|
158 |
+
[{} for _ in range(len(self.trainable_disc_params))],
|
159 |
+
)
|
160 |
+
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
161 |
+
else:
|
162 |
+
self.disc_optimizer_args = [{}] # makes type consitent
|
163 |
+
|
164 |
+
if ckpt_path is not None:
|
165 |
+
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
166 |
+
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
167 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
168 |
+
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
169 |
+
|
170 |
+
def get_input(self, batch: Dict) -> torch.Tensor:
|
171 |
+
# assuming unified data format, dataloader returns a dict.
|
172 |
+
# image tensors should be scaled to -1 ... 1 and in channels-first
|
173 |
+
# format (e.g., bchw instead if bhwc)
|
174 |
+
return batch[self.input_key]
|
175 |
+
|
176 |
+
def get_autoencoder_params(self) -> list:
|
177 |
+
params = []
|
178 |
+
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
179 |
+
params += list(self.loss.get_trainable_autoencoder_parameters())
|
180 |
+
if hasattr(self.regularization, "get_trainable_parameters"):
|
181 |
+
params += list(self.regularization.get_trainable_parameters())
|
182 |
+
params = params + list(self.encoder.parameters())
|
183 |
+
params = params + list(self.decoder.parameters())
|
184 |
+
return params
|
185 |
+
|
186 |
+
def get_discriminator_params(self) -> list:
|
187 |
+
if hasattr(self.loss, "get_trainable_parameters"):
|
188 |
+
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
189 |
+
else:
|
190 |
+
params = []
|
191 |
+
return params
|
192 |
+
|
193 |
+
def get_last_layer(self):
|
194 |
+
return self.decoder.get_last_layer()
|
195 |
+
|
196 |
+
def encode(
|
197 |
+
self,
|
198 |
+
x: torch.Tensor,
|
199 |
+
return_reg_log: bool = False,
|
200 |
+
unregularized: bool = False,
|
201 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
202 |
+
z = self.encoder(x)
|
203 |
+
if unregularized:
|
204 |
+
return z, dict()
|
205 |
+
z, reg_log = self.regularization(z)
|
206 |
+
if return_reg_log:
|
207 |
+
return z, reg_log
|
208 |
+
return z
|
209 |
+
|
210 |
+
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
211 |
+
x = self.decoder(z, **kwargs)
|
212 |
+
return x
|
213 |
+
|
214 |
+
def forward(
|
215 |
+
self, x: torch.Tensor, **additional_decode_kwargs
|
216 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
217 |
+
z, reg_log = self.encode(x, return_reg_log=True)
|
218 |
+
dec = self.decode(z, **additional_decode_kwargs)
|
219 |
+
return z, dec, reg_log
|
220 |
+
|
221 |
+
def inner_training_step(
|
222 |
+
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
223 |
+
) -> torch.Tensor:
|
224 |
+
x = self.get_input(batch)
|
225 |
+
additional_decode_kwargs = {
|
226 |
+
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
227 |
+
}
|
228 |
+
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
229 |
+
if hasattr(self.loss, "forward_keys"):
|
230 |
+
extra_info = {
|
231 |
+
"z": z,
|
232 |
+
"optimizer_idx": optimizer_idx,
|
233 |
+
"global_step": self.global_step,
|
234 |
+
"last_layer": self.get_last_layer(),
|
235 |
+
"split": "train",
|
236 |
+
"regularization_log": regularization_log,
|
237 |
+
"autoencoder": self,
|
238 |
+
}
|
239 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
240 |
+
else:
|
241 |
+
extra_info = dict()
|
242 |
+
|
243 |
+
if optimizer_idx == 0:
|
244 |
+
# autoencode
|
245 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
246 |
+
if isinstance(out_loss, tuple):
|
247 |
+
aeloss, log_dict_ae = out_loss
|
248 |
+
else:
|
249 |
+
# simple loss function
|
250 |
+
aeloss = out_loss
|
251 |
+
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
252 |
+
|
253 |
+
self.log_dict(
|
254 |
+
log_dict_ae,
|
255 |
+
prog_bar=False,
|
256 |
+
logger=True,
|
257 |
+
on_step=True,
|
258 |
+
on_epoch=True,
|
259 |
+
sync_dist=False,
|
260 |
+
)
|
261 |
+
self.log(
|
262 |
+
"loss",
|
263 |
+
aeloss.mean().detach(),
|
264 |
+
prog_bar=True,
|
265 |
+
logger=False,
|
266 |
+
on_epoch=False,
|
267 |
+
on_step=True,
|
268 |
+
)
|
269 |
+
return aeloss
|
270 |
+
elif optimizer_idx == 1:
|
271 |
+
# discriminator
|
272 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
273 |
+
# -> discriminator always needs to return a tuple
|
274 |
+
self.log_dict(
|
275 |
+
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
276 |
+
)
|
277 |
+
return discloss
|
278 |
+
else:
|
279 |
+
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
280 |
+
|
281 |
+
def training_step(self, batch: dict, batch_idx: int):
|
282 |
+
opts = self.optimizers()
|
283 |
+
if not isinstance(opts, list):
|
284 |
+
# Non-adversarial case
|
285 |
+
opts = [opts]
|
286 |
+
optimizer_idx = batch_idx % len(opts)
|
287 |
+
if self.global_step < self.disc_start_iter:
|
288 |
+
optimizer_idx = 0
|
289 |
+
opt = opts[optimizer_idx]
|
290 |
+
opt.zero_grad()
|
291 |
+
with opt.toggle_model():
|
292 |
+
loss = self.inner_training_step(
|
293 |
+
batch, batch_idx, optimizer_idx=optimizer_idx
|
294 |
+
)
|
295 |
+
self.manual_backward(loss)
|
296 |
+
opt.step()
|
297 |
+
|
298 |
+
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
299 |
+
log_dict = self._validation_step(batch, batch_idx)
|
300 |
+
with self.ema_scope():
|
301 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
302 |
+
log_dict.update(log_dict_ema)
|
303 |
+
return log_dict
|
304 |
+
|
305 |
+
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
306 |
+
x = self.get_input(batch)
|
307 |
+
|
308 |
+
z, xrec, regularization_log = self(x)
|
309 |
+
if hasattr(self.loss, "forward_keys"):
|
310 |
+
extra_info = {
|
311 |
+
"z": z,
|
312 |
+
"optimizer_idx": 0,
|
313 |
+
"global_step": self.global_step,
|
314 |
+
"last_layer": self.get_last_layer(),
|
315 |
+
"split": "val" + postfix,
|
316 |
+
"regularization_log": regularization_log,
|
317 |
+
"autoencoder": self,
|
318 |
+
}
|
319 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
320 |
+
else:
|
321 |
+
extra_info = dict()
|
322 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
323 |
+
if isinstance(out_loss, tuple):
|
324 |
+
aeloss, log_dict_ae = out_loss
|
325 |
+
else:
|
326 |
+
# simple loss function
|
327 |
+
aeloss = out_loss
|
328 |
+
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
329 |
+
full_log_dict = log_dict_ae
|
330 |
+
|
331 |
+
if "optimizer_idx" in extra_info:
|
332 |
+
extra_info["optimizer_idx"] = 1
|
333 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
334 |
+
full_log_dict.update(log_dict_disc)
|
335 |
+
self.log(
|
336 |
+
f"val{postfix}/loss/rec",
|
337 |
+
log_dict_ae[f"val{postfix}/loss/rec"],
|
338 |
+
sync_dist=True,
|
339 |
+
)
|
340 |
+
self.log_dict(full_log_dict, sync_dist=True)
|
341 |
+
return full_log_dict
|
342 |
+
|
343 |
+
def get_param_groups(
|
344 |
+
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
345 |
+
) -> Tuple[List[Dict[str, Any]], int]:
|
346 |
+
groups = []
|
347 |
+
num_params = 0
|
348 |
+
for names, args in zip(parameter_names, optimizer_args):
|
349 |
+
params = []
|
350 |
+
for pattern_ in names:
|
351 |
+
pattern_params = []
|
352 |
+
pattern = re.compile(pattern_)
|
353 |
+
for p_name, param in self.named_parameters():
|
354 |
+
if re.match(pattern, p_name):
|
355 |
+
pattern_params.append(param)
|
356 |
+
num_params += param.numel()
|
357 |
+
if len(pattern_params) == 0:
|
358 |
+
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
359 |
+
params.extend(pattern_params)
|
360 |
+
groups.append({"params": params, **args})
|
361 |
+
return groups, num_params
|
362 |
+
|
363 |
+
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
364 |
+
if self.trainable_ae_params is None:
|
365 |
+
ae_params = self.get_autoencoder_params()
|
366 |
+
else:
|
367 |
+
ae_params, num_ae_params = self.get_param_groups(
|
368 |
+
self.trainable_ae_params, self.ae_optimizer_args
|
369 |
+
)
|
370 |
+
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
371 |
+
if self.trainable_disc_params is None:
|
372 |
+
disc_params = self.get_discriminator_params()
|
373 |
+
else:
|
374 |
+
disc_params, num_disc_params = self.get_param_groups(
|
375 |
+
self.trainable_disc_params, self.disc_optimizer_args
|
376 |
+
)
|
377 |
+
logpy.info(
|
378 |
+
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
379 |
+
)
|
380 |
+
opt_ae = self.instantiate_optimizer_from_config(
|
381 |
+
ae_params,
|
382 |
+
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
383 |
+
self.optimizer_config,
|
384 |
+
)
|
385 |
+
opts = [opt_ae]
|
386 |
+
if len(disc_params) > 0:
|
387 |
+
opt_disc = self.instantiate_optimizer_from_config(
|
388 |
+
disc_params, self.learning_rate, self.optimizer_config
|
389 |
+
)
|
390 |
+
opts.append(opt_disc)
|
391 |
+
|
392 |
+
return opts
|
393 |
+
|
394 |
+
@torch.no_grad()
|
395 |
+
def log_images(
|
396 |
+
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
397 |
+
) -> dict:
|
398 |
+
log = dict()
|
399 |
+
additional_decode_kwargs = {}
|
400 |
+
x = self.get_input(batch)
|
401 |
+
additional_decode_kwargs.update(
|
402 |
+
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
403 |
+
)
|
404 |
+
|
405 |
+
_, xrec, _ = self(x, **additional_decode_kwargs)
|
406 |
+
log["inputs"] = x
|
407 |
+
log["reconstructions"] = xrec
|
408 |
+
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
409 |
+
diff.clamp_(0, 1.0)
|
410 |
+
log["diff"] = 2.0 * diff - 1.0
|
411 |
+
# diff_boost shows location of small errors, by boosting their
|
412 |
+
# brightness.
|
413 |
+
log["diff_boost"] = (
|
414 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
415 |
+
)
|
416 |
+
if hasattr(self.loss, "log_images"):
|
417 |
+
log.update(self.loss.log_images(x, xrec))
|
418 |
+
with self.ema_scope():
|
419 |
+
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
420 |
+
log["reconstructions_ema"] = xrec_ema
|
421 |
+
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
422 |
+
diff_ema.clamp_(0, 1.0)
|
423 |
+
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
424 |
+
log["diff_boost_ema"] = (
|
425 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
426 |
+
)
|
427 |
+
if additional_log_kwargs:
|
428 |
+
additional_decode_kwargs.update(additional_log_kwargs)
|
429 |
+
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
430 |
+
log_str = "reconstructions-" + "-".join(
|
431 |
+
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
432 |
+
)
|
433 |
+
log[log_str] = xrec_add
|
434 |
+
return log
|
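The manual-optimization `training_step` above round-robins over the configured optimizers by batch index and forces autoencoder-only updates until `disc_start_iter` is reached. A tiny self-contained sketch of just that index selection (the helper name is made up for illustration):

```python
def pick_optimizer_idx(batch_idx: int, num_optimizers: int,
                       global_step: int, disc_start_iter: int) -> int:
    # Alternate autoencoder (idx 0) and discriminator (idx 1) updates batch by batch ...
    idx = batch_idx % num_optimizers
    # ... but keep the discriminator frozen until its warm-up step is reached.
    if global_step < disc_start_iter:
        idx = 0
    return idx


assert pick_optimizer_idx(batch_idx=3, num_optimizers=2, global_step=100, disc_start_iter=5000) == 0
assert pick_optimizer_idx(batch_idx=3, num_optimizers=2, global_step=9000, disc_start_iter=5000) == 1
```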
435 |
+
|
436 |
+
|
437 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
|
438 |
+
def __init__(self, embed_dim: int, **kwargs):
|
439 |
+
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
440 |
+
ddconfig = kwargs.pop("ddconfig")
|
441 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
442 |
+
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
443 |
+
super().__init__(
|
444 |
+
encoder_config={
|
445 |
+
"target": "dragnuwa.svd.modules.diffusionmodules.model.Encoder",
|
446 |
+
"params": ddconfig,
|
447 |
+
},
|
448 |
+
decoder_config={
|
449 |
+
"target": "dragnuwa.svd.modules.diffusionmodules.model.Decoder",
|
450 |
+
"params": ddconfig,
|
451 |
+
},
|
452 |
+
**kwargs,
|
453 |
+
)
|
454 |
+
self.quant_conv = torch.nn.Conv2d(
|
455 |
+
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
456 |
+
(1 + ddconfig["double_z"]) * embed_dim,
|
457 |
+
1,
|
458 |
+
)
|
459 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
460 |
+
self.embed_dim = embed_dim
|
461 |
+
|
462 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
463 |
+
|
464 |
+
def get_autoencoder_params(self) -> list:
|
465 |
+
params = super().get_autoencoder_params()
|
466 |
+
return params
|
467 |
+
|
468 |
+
def encode(
|
469 |
+
self, x: torch.Tensor, return_reg_log: bool = False
|
470 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
471 |
+
if self.max_batch_size is None:
|
472 |
+
z = self.encoder(x)
|
473 |
+
z = self.quant_conv(z)
|
474 |
+
else:
|
475 |
+
N = x.shape[0]
|
476 |
+
bs = self.max_batch_size
|
477 |
+
n_batches = int(math.ceil(N / bs))
|
478 |
+
z = list()
|
479 |
+
for i_batch in range(n_batches):
|
480 |
+
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
481 |
+
z_batch = self.quant_conv(z_batch)
|
482 |
+
z.append(z_batch)
|
483 |
+
z = torch.cat(z, 0)
|
484 |
+
|
485 |
+
z, reg_log = self.regularization(z)
|
486 |
+
if return_reg_log:
|
487 |
+
return z, reg_log
|
488 |
+
return z
|
489 |
+
|
490 |
+
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
491 |
+
if self.max_batch_size is None:
|
492 |
+
dec = self.post_quant_conv(z)
|
493 |
+
dec = self.decoder(dec, **decoder_kwargs)
|
494 |
+
else:
|
495 |
+
N = z.shape[0]
|
496 |
+
bs = self.max_batch_size
|
497 |
+
n_batches = int(math.ceil(N / bs))
|
498 |
+
dec = list()
|
499 |
+
for i_batch in range(n_batches):
|
500 |
+
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
501 |
+
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
502 |
+
dec.append(dec_batch)
|
503 |
+
dec = torch.cat(dec, 0)
|
504 |
+
|
505 |
+
return dec
|
506 |
+
|
507 |
+
|
508 |
+
class AutoencoderKL(AutoencodingEngineLegacy):
|
509 |
+
def __init__(self, **kwargs):
|
510 |
+
if "lossconfig" in kwargs:
|
511 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
512 |
+
super().__init__(
|
513 |
+
regularizer_config={
|
514 |
+
"target": (
|
515 |
+
"dragnuwa.svd.modules.autoencoding.regularizers"
|
516 |
+
".DiagonalGaussianRegularizer"
|
517 |
+
)
|
518 |
+
},
|
519 |
+
**kwargs,
|
520 |
+
)
|
521 |
+
|
522 |
+
|
523 |
+
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
524 |
+
def __init__(
|
525 |
+
self,
|
526 |
+
embed_dim: int,
|
527 |
+
n_embed: int,
|
528 |
+
sane_index_shape: bool = False,
|
529 |
+
**kwargs,
|
530 |
+
):
|
531 |
+
if "lossconfig" in kwargs:
|
532 |
+
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
533 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
534 |
+
super().__init__(
|
535 |
+
regularizer_config={
|
536 |
+
"target": (
|
537 |
+
"dragnuwa.svd.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
538 |
+
),
|
539 |
+
"params": {
|
540 |
+
"n_e": n_embed,
|
541 |
+
"e_dim": embed_dim,
|
542 |
+
"sane_index_shape": sane_index_shape,
|
543 |
+
},
|
544 |
+
},
|
545 |
+
**kwargs,
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
class IdentityFirstStage(AbstractAutoencoder):
|
550 |
+
def __init__(self, *args, **kwargs):
|
551 |
+
super().__init__(*args, **kwargs)
|
552 |
+
|
553 |
+
def get_input(self, x: Any) -> Any:
|
554 |
+
return x
|
555 |
+
|
556 |
+
def encode(self, x: Any, *args, **kwargs) -> Any:
|
557 |
+
return x
|
558 |
+
|
559 |
+
def decode(self, x: Any, *args, **kwargs) -> Any:
|
560 |
+
return x
|
561 |
+
|
562 |
+
|
563 |
+
class AEIntegerWrapper(nn.Module):
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
model: nn.Module,
|
567 |
+
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
568 |
+
regularization_key: str = "regularization",
|
569 |
+
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
570 |
+
):
|
571 |
+
super().__init__()
|
572 |
+
self.model = model
|
573 |
+
assert hasattr(model, "encode") and hasattr(
|
574 |
+
model, "decode"
|
575 |
+
), "Need AE interface"
|
576 |
+
self.regularization = get_nested_attribute(model, regularization_key)
|
577 |
+
self.shape = shape
|
578 |
+
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
579 |
+
|
580 |
+
def encode(self, x) -> torch.Tensor:
|
581 |
+
assert (
|
582 |
+
not self.training
|
583 |
+
), f"{self.__class__.__name__} only supports inference currently"
|
584 |
+
_, log = self.model.encode(x, **self.encoder_kwargs)
|
585 |
+
assert isinstance(log, dict)
|
586 |
+
inds = log["min_encoding_indices"]
|
587 |
+
return rearrange(inds, "b ... -> b (...)")
|
588 |
+
|
589 |
+
def decode(
|
590 |
+
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
591 |
+
) -> torch.Tensor:
|
592 |
+
# expect inds shape (b, s) with s = h*w
|
593 |
+
shape = default(shape, self.shape) # Optional[(h, w)]
|
594 |
+
if shape is not None:
|
595 |
+
assert len(shape) == 2, f"Unhandeled shape {shape}"
|
596 |
+
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
597 |
+
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
598 |
+
h = rearrange(h, "b h w c -> b c h w")
|
599 |
+
return self.model.decode(h)
|
600 |
+
|
601 |
+
|
602 |
+
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
603 |
+
def __init__(self, **kwargs):
|
604 |
+
if "lossconfig" in kwargs:
|
605 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
606 |
+
super().__init__(
|
607 |
+
regularizer_config={
|
608 |
+
"target": (
|
609 |
+
"dragnuwa.svd.modules.autoencoding.regularizers"
|
610 |
+
".DiagonalGaussianRegularizer"
|
611 |
+
),
|
612 |
+
"params": {"sample": False},
|
613 |
+
},
|
614 |
+
**kwargs,
|
615 |
+
)
|
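`AutoencodingEngineLegacy.encode` and `decode` above optionally process the batch in `max_batch_size` slices to bound memory. The same pattern in isolation, as a hedged sketch (pure PyTorch; the function name and shapes are illustrative, not part of the repo):

```python
import math
from typing import Callable, Optional

import torch


def chunked_apply(fn: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor,
                  max_batch_size: Optional[int]) -> torch.Tensor:
    # Apply fn to the whole batch, or in max_batch_size slices along dim 0,
    # mirroring the chunked encode/decode in AutoencodingEngineLegacy above.
    if max_batch_size is None:
        return fn(x)
    n_batches = int(math.ceil(x.shape[0] / max_batch_size))
    outs = [fn(x[i * max_batch_size:(i + 1) * max_batch_size]) for i in range(n_batches)]
    return torch.cat(outs, 0)


# Stand-in "encoder" that just scales its input; sizes are illustrative.
z = chunked_apply(lambda t: 2.0 * t, torch.randn(10, 3, 64, 64), max_batch_size=4)
print(z.shape)  # torch.Size([10, 3, 64, 64])
```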
dragnuwa/svd/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
```python
from .encoders.modules import GeneralConditioner

UNCONDITIONAL_CONFIG = {
    "target": "dragnuwa.svd.modules.GeneralConditioner",
    "params": {"emb_models": []},
}
```
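`UNCONDITIONAL_CONFIG` is a target/params dictionary of the kind `instantiate_from_config` turns into an object elsewhere in this package. A generic sketch of that resolution pattern with a hypothetical resolver (not the repo's own `dragnuwa.svd.util` helper):

```python
import importlib


def instantiate(config: dict):
    # Resolve "package.module.ClassName" and construct it with the given params.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))


# Works for any config in this style, e.g. a stand-in target from the stdlib:
print(instantiate({"target": "collections.OrderedDict", "params": {}}))
```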
dragnuwa/svd/modules/__pycache__/__init__.cpython-38.pyc
ADDED: binary file (315 Bytes)

dragnuwa/svd/modules/__pycache__/attention.cpython-38.pyc
ADDED: binary file (18 kB)

dragnuwa/svd/modules/__pycache__/ema.cpython-38.pyc
ADDED: binary file (3.19 kB)

dragnuwa/svd/modules/__pycache__/video_attention.cpython-38.pyc
ADDED: binary file (6.23 kB)
dragnuwa/svd/modules/attention.py
ADDED
@@ -0,0 +1,759 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from inspect import isfunction
|
4 |
+
from typing import Any, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
from packaging import version
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
logpy = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
16 |
+
SDP_IS_AVAILABLE = True
|
17 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
18 |
+
|
19 |
+
BACKEND_MAP = {
|
20 |
+
SDPBackend.MATH: {
|
21 |
+
"enable_math": True,
|
22 |
+
"enable_flash": False,
|
23 |
+
"enable_mem_efficient": False,
|
24 |
+
},
|
25 |
+
SDPBackend.FLASH_ATTENTION: {
|
26 |
+
"enable_math": False,
|
27 |
+
"enable_flash": True,
|
28 |
+
"enable_mem_efficient": False,
|
29 |
+
},
|
30 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
31 |
+
"enable_math": False,
|
32 |
+
"enable_flash": False,
|
33 |
+
"enable_mem_efficient": True,
|
34 |
+
},
|
35 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
36 |
+
}
|
37 |
+
else:
|
38 |
+
from contextlib import nullcontext
|
39 |
+
|
40 |
+
SDP_IS_AVAILABLE = False
|
41 |
+
sdp_kernel = nullcontext
|
42 |
+
BACKEND_MAP = {}
|
43 |
+
logpy.warn(
|
44 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
45 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
46 |
+
f"You might want to consider upgrading."
|
47 |
+
)
|
48 |
+
|
49 |
+
try:
|
50 |
+
import xformers
|
51 |
+
import xformers.ops
|
52 |
+
|
53 |
+
XFORMERS_IS_AVAILABLE = True
|
54 |
+
except:
|
55 |
+
XFORMERS_IS_AVAILABLE = False
|
56 |
+
logpy.warn("no module 'xformers'. Processing without...")
|
57 |
+
|
58 |
+
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
59 |
+
|
60 |
+
|
61 |
+
def exists(val):
|
62 |
+
return val is not None
|
63 |
+
|
64 |
+
|
65 |
+
def uniq(arr):
|
66 |
+
return {el: True for el in arr}.keys()
|
67 |
+
|
68 |
+
|
69 |
+
def default(val, d):
|
70 |
+
if exists(val):
|
71 |
+
return val
|
72 |
+
return d() if isfunction(d) else d
|
73 |
+
|
74 |
+
|
75 |
+
def max_neg_value(t):
|
76 |
+
return -torch.finfo(t.dtype).max
|
77 |
+
|
78 |
+
|
79 |
+
def init_(tensor):
|
80 |
+
dim = tensor.shape[-1]
|
81 |
+
std = 1 / math.sqrt(dim)
|
82 |
+
tensor.uniform_(-std, std)
|
83 |
+
return tensor
|
84 |
+
|
85 |
+
|
86 |
+
# feedforward
|
87 |
+
class GEGLU(nn.Module):
|
88 |
+
def __init__(self, dim_in, dim_out):
|
89 |
+
super().__init__()
|
90 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
94 |
+
return x * F.gelu(gate)
|
95 |
+
|
96 |
+
|
97 |
+
class FeedForward(nn.Module):
|
98 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
99 |
+
super().__init__()
|
100 |
+
inner_dim = int(dim * mult)
|
101 |
+
dim_out = default(dim_out, dim)
|
102 |
+
project_in = (
|
103 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
104 |
+
if not glu
|
105 |
+
else GEGLU(dim, inner_dim)
|
106 |
+
)
|
107 |
+
|
108 |
+
self.net = nn.Sequential(
|
109 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
return self.net(x)
|
114 |
+
|
115 |
+
|
116 |
+
def zero_module(module):
|
117 |
+
"""
|
118 |
+
Zero out the parameters of a module and return it.
|
119 |
+
"""
|
120 |
+
for p in module.parameters():
|
121 |
+
p.detach().zero_()
|
122 |
+
return module
|
123 |
+
|
124 |
+
|
125 |
+
def Normalize(in_channels):
|
126 |
+
return torch.nn.GroupNorm(
|
127 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class LinearAttention(nn.Module):
|
132 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
133 |
+
super().__init__()
|
134 |
+
self.heads = heads
|
135 |
+
hidden_dim = dim_head * heads
|
136 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
137 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
b, c, h, w = x.shape
|
141 |
+
qkv = self.to_qkv(x)
|
142 |
+
q, k, v = rearrange(
|
143 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
144 |
+
)
|
145 |
+
k = k.softmax(dim=-1)
|
146 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
147 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
148 |
+
out = rearrange(
|
149 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
150 |
+
)
|
151 |
+
return self.to_out(out)
|
152 |
+
|
153 |
+
|
154 |
+
class SelfAttention(nn.Module):
|
155 |
+
ATTENTION_MODES = ("xformers", "torch", "math")
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
dim: int,
|
160 |
+
num_heads: int = 8,
|
161 |
+
qkv_bias: bool = False,
|
162 |
+
qk_scale: Optional[float] = None,
|
163 |
+
attn_drop: float = 0.0,
|
164 |
+
proj_drop: float = 0.0,
|
165 |
+
attn_mode: str = "xformers",
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.num_heads = num_heads
|
169 |
+
head_dim = dim // num_heads
|
170 |
+
self.scale = qk_scale or head_dim**-0.5
|
171 |
+
|
172 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
173 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
174 |
+
self.proj = nn.Linear(dim, dim)
|
175 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
176 |
+
assert attn_mode in self.ATTENTION_MODES
|
177 |
+
self.attn_mode = attn_mode
|
178 |
+
|
179 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
180 |
+
B, L, C = x.shape
|
181 |
+
|
182 |
+
qkv = self.qkv(x)
|
183 |
+
if self.attn_mode == "torch":
|
184 |
+
qkv = rearrange(
|
185 |
+
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
186 |
+
).float()
|
187 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
188 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
189 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
190 |
+
elif self.attn_mode == "xformers":
|
191 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
192 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
193 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
194 |
+
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
195 |
+
elif self.attn_mode == "math":
|
196 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
197 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
198 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
199 |
+
attn = attn.softmax(dim=-1)
|
200 |
+
attn = self.attn_drop(attn)
|
201 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
202 |
+
else:
|
203 |
+
raise NotImplemented
|
204 |
+
|
205 |
+
x = self.proj(x)
|
206 |
+
x = self.proj_drop(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class SpatialSelfAttention(nn.Module):
|
211 |
+
def __init__(self, in_channels):
|
212 |
+
super().__init__()
|
213 |
+
self.in_channels = in_channels
|
214 |
+
|
215 |
+
self.norm = Normalize(in_channels)
|
216 |
+
self.q = torch.nn.Conv2d(
|
217 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
218 |
+
)
|
219 |
+
self.k = torch.nn.Conv2d(
|
220 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
221 |
+
)
|
222 |
+
self.v = torch.nn.Conv2d(
|
223 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
224 |
+
)
|
225 |
+
self.proj_out = torch.nn.Conv2d(
|
226 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
227 |
+
)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
h_ = x
|
231 |
+
h_ = self.norm(h_)
|
232 |
+
q = self.q(h_)
|
233 |
+
k = self.k(h_)
|
234 |
+
v = self.v(h_)
|
235 |
+
|
236 |
+
# compute attention
|
237 |
+
b, c, h, w = q.shape
|
238 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
239 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
240 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
241 |
+
|
242 |
+
w_ = w_ * (int(c) ** (-0.5))
|
243 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
244 |
+
|
245 |
+
# attend to values
|
246 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
247 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
248 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
249 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
250 |
+
h_ = self.proj_out(h_)
|
251 |
+
|
252 |
+
return x + h_
|
253 |
+
|
254 |
+
|
255 |
+
class CrossAttention(nn.Module):
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
query_dim,
|
259 |
+
context_dim=None,
|
260 |
+
heads=8,
|
261 |
+
dim_head=64,
|
262 |
+
dropout=0.0,
|
263 |
+
backend=None,
|
264 |
+
):
|
265 |
+
super().__init__()
|
266 |
+
inner_dim = dim_head * heads
|
267 |
+
context_dim = default(context_dim, query_dim)
|
268 |
+
|
269 |
+
self.scale = dim_head**-0.5
|
270 |
+
self.heads = heads
|
271 |
+
|
272 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
273 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
274 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
275 |
+
|
276 |
+
self.to_out = nn.Sequential(
|
277 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
278 |
+
)
|
279 |
+
self.backend = backend
|
280 |
+
|
281 |
+
def forward(
|
282 |
+
self,
|
283 |
+
x,
|
284 |
+
context=None,
|
285 |
+
mask=None,
|
286 |
+
additional_tokens=None,
|
287 |
+
n_times_crossframe_attn_in_self=0,
|
288 |
+
):
|
289 |
+
h = self.heads
|
290 |
+
|
291 |
+
if additional_tokens is not None:
|
292 |
+
# get the number of masked tokens at the beginning of the output sequence
|
293 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
294 |
+
# add additional token
|
295 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
296 |
+
|
297 |
+
q = self.to_q(x)
|
298 |
+
context = default(context, x)
|
299 |
+
k = self.to_k(context)
|
300 |
+
v = self.to_v(context)
|
301 |
+
|
302 |
+
if n_times_crossframe_attn_in_self:
|
303 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
304 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
305 |
+
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
306 |
+
k = repeat(
|
307 |
+
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
308 |
+
)
|
309 |
+
v = repeat(
|
310 |
+
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
311 |
+
)
|
312 |
+
|
313 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
314 |
+
|
315 |
+
## old
|
316 |
+
"""
|
317 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
318 |
+
del q, k
|
319 |
+
|
320 |
+
if exists(mask):
|
321 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
322 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
323 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
324 |
+
sim.masked_fill_(~mask, max_neg_value)
|
325 |
+
|
326 |
+
# attention, what we cannot get enough of
|
327 |
+
sim = sim.softmax(dim=-1)
|
328 |
+
|
329 |
+
out = einsum('b i j, b j d -> b i d', sim, v)
|
330 |
+
"""
|
331 |
+
## new
|
332 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
333 |
+
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
334 |
+
out = F.scaled_dot_product_attention(
|
335 |
+
q, k, v, attn_mask=mask
|
336 |
+
) # scale is dim_head ** -0.5 per default
|
337 |
+
|
338 |
+
del q, k, v
|
339 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
340 |
+
|
341 |
+
if additional_tokens is not None:
|
342 |
+
# remove additional token
|
343 |
+
out = out[:, n_tokens_to_mask:]
|
344 |
+
return self.to_out(out)
|
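`CrossAttention.forward` above ends in torch's fused `scaled_dot_product_attention`, dispatched under an `sdp_kernel` context chosen from `BACKEND_MAP`. The core call on its own, as a sketch for PyTorch >= 2.0 (the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# b h n d layout, matching the rearrange right before the call above:
# batch 2, 8 heads, 16 query tokens, 16 context tokens, head dim 64.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# The scale defaults to dim_head ** -0.5, as the inline comment above notes.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```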
345 |
+
|
346 |
+
|
347 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
348 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
349 |
+
def __init__(
|
350 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
351 |
+
):
|
352 |
+
super().__init__()
|
353 |
+
logpy.debug(
|
354 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
355 |
+
f"context_dim is {context_dim} and using {heads} heads with a "
|
356 |
+
f"dimension of {dim_head}."
|
357 |
+
)
|
358 |
+
inner_dim = dim_head * heads
|
359 |
+
context_dim = default(context_dim, query_dim)
|
360 |
+
|
361 |
+
self.heads = heads
|
362 |
+
self.dim_head = dim_head
|
363 |
+
|
364 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
365 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
366 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
367 |
+
|
368 |
+
self.to_out = nn.Sequential(
|
369 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
370 |
+
)
|
371 |
+
self.attention_op: Optional[Any] = None
|
372 |
+
|
373 |
+
def forward(
|
374 |
+
self,
|
375 |
+
x,
|
376 |
+
context=None,
|
377 |
+
mask=None,
|
378 |
+
additional_tokens=None,
|
379 |
+
n_times_crossframe_attn_in_self=0,
|
380 |
+
):
|
381 |
+
if additional_tokens is not None:
|
382 |
+
# get the number of masked tokens at the beginning of the output sequence
|
383 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
384 |
+
# add additional token
|
385 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
386 |
+
q = self.to_q(x)
|
387 |
+
context = default(context, x)
|
388 |
+
k = self.to_k(context)
|
389 |
+
v = self.to_v(context)
|
390 |
+
|
391 |
+
if n_times_crossframe_attn_in_self:
|
392 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
393 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
394 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
395 |
+
k = repeat(
|
396 |
+
k[::n_times_crossframe_attn_in_self],
|
397 |
+
"b ... -> (b n) ...",
|
398 |
+
n=n_times_crossframe_attn_in_self,
|
399 |
+
)
|
400 |
+
v = repeat(
|
401 |
+
v[::n_times_crossframe_attn_in_self],
|
402 |
+
"b ... -> (b n) ...",
|
403 |
+
n=n_times_crossframe_attn_in_self,
|
404 |
+
)
|
405 |
+
|
406 |
+
b, _, _ = q.shape
|
407 |
+
q, k, v = map(
|
408 |
+
lambda t: t.unsqueeze(3)
|
409 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
410 |
+
.permute(0, 2, 1, 3)
|
411 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
412 |
+
.contiguous(),
|
413 |
+
(q, k, v),
|
414 |
+
)
|
415 |
+
|
416 |
+
# actually compute the attention, what we cannot get enough of
|
417 |
+
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
418 |
+
# NOTE: workaround for
|
419 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
420 |
+
max_bs = 32768
|
421 |
+
N = q.shape[0]
|
422 |
+
n_batches = math.ceil(N / max_bs)
|
423 |
+
out = list()
|
424 |
+
for i_batch in range(n_batches):
|
425 |
+
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
426 |
+
out.append(
|
427 |
+
xformers.ops.memory_efficient_attention(
|
428 |
+
q[batch],
|
429 |
+
k[batch],
|
430 |
+
v[batch],
|
431 |
+
attn_bias=None,
|
432 |
+
op=self.attention_op,
|
433 |
+
)
|
434 |
+
)
|
435 |
+
out = torch.cat(out, 0)
|
436 |
+
else:
|
437 |
+
out = xformers.ops.memory_efficient_attention(
|
438 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
439 |
+
)
|
440 |
+
|
441 |
+
# TODO: Use this directly in the attention operation, as a bias
|
442 |
+
if exists(mask):
|
443 |
+
raise NotImplementedError
|
444 |
+
out = (
|
445 |
+
out.unsqueeze(0)
|
446 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
447 |
+
.permute(0, 2, 1, 3)
|
448 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
449 |
+
)
|
450 |
+
if additional_tokens is not None:
|
451 |
+
# remove additional token
|
452 |
+
out = out[:, n_tokens_to_mask:]
|
453 |
+
return self.to_out(out)
|
454 |
+
|
455 |
+
|
456 |
+
class BasicTransformerBlock(nn.Module):
|
457 |
+
ATTENTION_MODES = {
|
458 |
+
"softmax": CrossAttention, # vanilla attention
|
459 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
460 |
+
}
|
461 |
+
|
462 |
+
def __init__(
|
463 |
+
self,
|
464 |
+
dim,
|
465 |
+
n_heads,
|
466 |
+
d_head,
|
467 |
+
dropout=0.0,
|
468 |
+
context_dim=None,
|
469 |
+
gated_ff=True,
|
470 |
+
checkpoint=True,
|
471 |
+
disable_self_attn=False,
|
472 |
+
attn_mode="softmax",
|
473 |
+
sdp_backend=None,
|
474 |
+
):
|
475 |
+
super().__init__()
|
476 |
+
assert attn_mode in self.ATTENTION_MODES
|
477 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
478 |
+
logpy.warn(
|
479 |
+
f"Attention mode '{attn_mode}' is not available. Falling "
|
480 |
+
f"back to native attention. This is not a problem in "
|
481 |
+
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
482 |
+
f"version {torch.__version__}."
|
483 |
+
)
|
484 |
+
attn_mode = "softmax"
|
485 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
486 |
+
logpy.warn(
|
487 |
+
"We do not support vanilla attention anymore, as it is too "
|
488 |
+
"expensive. Sorry."
|
489 |
+
)
|
490 |
+
if not XFORMERS_IS_AVAILABLE:
|
491 |
+
assert (
|
492 |
+
False
|
493 |
+
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
494 |
+
else:
|
495 |
+
logpy.info("Falling back to xformers efficient attention.")
|
496 |
+
attn_mode = "softmax-xformers"
|
497 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
498 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
499 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
500 |
+
else:
|
501 |
+
assert sdp_backend is None
|
502 |
+
self.disable_self_attn = disable_self_attn
|
503 |
+
self.attn1 = attn_cls(
|
504 |
+
query_dim=dim,
|
505 |
+
heads=n_heads,
|
506 |
+
dim_head=d_head,
|
507 |
+
dropout=dropout,
|
508 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
509 |
+
backend=sdp_backend,
|
510 |
+
) # is a self-attention if not self.disable_self_attn
|
511 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
512 |
+
self.attn2 = attn_cls(
|
513 |
+
query_dim=dim,
|
514 |
+
context_dim=context_dim,
|
515 |
+
heads=n_heads,
|
516 |
+
dim_head=d_head,
|
517 |
+
dropout=dropout,
|
518 |
+
backend=sdp_backend,
|
519 |
+
) # is self-attn if context is none
|
520 |
+
self.norm1 = nn.LayerNorm(dim)
|
521 |
+
self.norm2 = nn.LayerNorm(dim)
|
522 |
+
self.norm3 = nn.LayerNorm(dim)
|
523 |
+
self.checkpoint = checkpoint
|
524 |
+
if self.checkpoint:
|
525 |
+
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
526 |
+
|
527 |
+
def forward(
|
528 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
529 |
+
):
|
530 |
+
kwargs = {"x": x}
|
531 |
+
|
532 |
+
if context is not None:
|
533 |
+
kwargs.update({"context": context})
|
534 |
+
|
535 |
+
if additional_tokens is not None:
|
536 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
537 |
+
|
538 |
+
if n_times_crossframe_attn_in_self:
|
539 |
+
kwargs.update(
|
540 |
+
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
541 |
+
)
|
542 |
+
|
543 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
544 |
+
if self.checkpoint:
|
545 |
+
# inputs = {"x": x, "context": context}
|
546 |
+
return checkpoint(self._forward, x, context)
|
547 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
548 |
+
else:
|
549 |
+
return self._forward(**kwargs)
|
550 |
+
|
551 |
+
def _forward(
|
552 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
553 |
+
):
|
554 |
+
x = (
|
555 |
+
self.attn1(
|
556 |
+
self.norm1(x),
|
557 |
+
context=context if self.disable_self_attn else None,
|
558 |
+
additional_tokens=additional_tokens,
|
559 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
560 |
+
if not self.disable_self_attn
|
561 |
+
else 0,
|
562 |
+
)
|
563 |
+
+ x
|
564 |
+
)
|
565 |
+
x = (
|
566 |
+
self.attn2(
|
567 |
+
self.norm2(x), context=context, additional_tokens=additional_tokens
|
568 |
+
)
|
569 |
+
+ x
|
570 |
+
)
|
571 |
+
x = self.ff(self.norm3(x)) + x
|
572 |
+
return x
|
573 |
+
|
574 |
+
|
575 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
576 |
+
ATTENTION_MODES = {
|
577 |
+
"softmax": CrossAttention, # vanilla attention
|
578 |
+
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
579 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
580 |
+
}
|
581 |
+
|
582 |
+
def __init__(
|
583 |
+
self,
|
584 |
+
dim,
|
585 |
+
n_heads,
|
586 |
+
d_head,
|
587 |
+
dropout=0.0,
|
588 |
+
context_dim=None,
|
589 |
+
gated_ff=True,
|
590 |
+
checkpoint=True,
|
591 |
+
attn_mode="softmax",
|
592 |
+
):
|
593 |
+
super().__init__()
|
594 |
+
assert attn_mode in self.ATTENTION_MODES
|
595 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
596 |
+
self.attn1 = attn_cls(
|
597 |
+
query_dim=dim,
|
598 |
+
heads=n_heads,
|
599 |
+
dim_head=d_head,
|
600 |
+
dropout=dropout,
|
601 |
+
context_dim=context_dim,
|
602 |
+
)
|
603 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
604 |
+
self.norm1 = nn.LayerNorm(dim)
|
605 |
+
self.norm2 = nn.LayerNorm(dim)
|
606 |
+
self.checkpoint = checkpoint
|
607 |
+
|
608 |
+
def forward(self, x, context=None):
|
609 |
+
# inputs = {"x": x, "context": context}
|
610 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
611 |
+
return checkpoint(self._forward, x, context)
|
612 |
+
|
613 |
+
def _forward(self, x, context=None):
|
614 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
615 |
+
x = self.ff(self.norm2(x)) + x
|
616 |
+
return x
|
617 |
+
|
618 |
+
|
619 |
+
class SpatialTransformer(nn.Module):
|
620 |
+
"""
|
621 |
+
Transformer block for image-like data.
|
622 |
+
First, project the input (aka embedding)
|
623 |
+
and reshape to b, t, d.
|
624 |
+
Then apply standard transformer action.
|
625 |
+
Finally, reshape to image
|
626 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
627 |
+
"""
|
628 |
+
|
629 |
+
def __init__(
|
630 |
+
self,
|
631 |
+
in_channels,
|
632 |
+
n_heads,
|
633 |
+
d_head,
|
634 |
+
depth=1,
|
635 |
+
dropout=0.0,
|
636 |
+
context_dim=None,
|
637 |
+
disable_self_attn=False,
|
638 |
+
use_linear=False,
|
639 |
+
attn_type="softmax",
|
640 |
+
use_checkpoint=True,
|
641 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
642 |
+
sdp_backend=None,
|
643 |
+
):
|
644 |
+
super().__init__()
|
645 |
+
logpy.debug(
|
646 |
+
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
647 |
+
f"{in_channels} channels and {n_heads} heads."
|
648 |
+
)
|
649 |
+
|
650 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
651 |
+
context_dim = [context_dim]
|
652 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
653 |
+
if depth != len(context_dim):
|
654 |
+
logpy.warn(
|
655 |
+
f"{self.__class__.__name__}: Found context dims "
|
656 |
+
f"{context_dim} of depth {len(context_dim)}, which does not "
|
657 |
+
f"match the specified 'depth' of {depth}. Setting context_dim "
|
658 |
+
f"to {depth * [context_dim[0]]} now."
|
659 |
+
)
|
660 |
+
# depth does not match context dims.
|
661 |
+
assert all(
|
662 |
+
map(lambda x: x == context_dim[0], context_dim)
|
663 |
+
), "need homogenous context_dim to match depth automatically"
|
664 |
+
context_dim = depth * [context_dim[0]]
|
665 |
+
elif context_dim is None:
|
666 |
+
context_dim = [None] * depth
|
667 |
+
self.in_channels = in_channels
|
668 |
+
inner_dim = n_heads * d_head
|
669 |
+
self.norm = Normalize(in_channels)
|
670 |
+
if not use_linear:
|
671 |
+
self.proj_in = nn.Conv2d(
|
672 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
673 |
+
)
|
674 |
+
else:
|
675 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
676 |
+
|
677 |
+
self.transformer_blocks = nn.ModuleList(
|
678 |
+
[
|
679 |
+
BasicTransformerBlock(
|
680 |
+
inner_dim,
|
681 |
+
n_heads,
|
682 |
+
d_head,
|
683 |
+
dropout=dropout,
|
684 |
+
context_dim=context_dim[d],
|
685 |
+
disable_self_attn=disable_self_attn,
|
686 |
+
attn_mode=attn_type,
|
687 |
+
checkpoint=use_checkpoint,
|
688 |
+
sdp_backend=sdp_backend,
|
689 |
+
)
|
690 |
+
for d in range(depth)
|
691 |
+
]
|
692 |
+
)
|
693 |
+
if not use_linear:
|
694 |
+
self.proj_out = zero_module(
|
695 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
696 |
+
)
|
697 |
+
else:
|
698 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
699 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
700 |
+
self.use_linear = use_linear
|
701 |
+
|
702 |
+
def forward(self, x, context=None):
|
703 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
704 |
+
if not isinstance(context, list):
|
705 |
+
context = [context]
|
706 |
+
b, c, h, w = x.shape
|
707 |
+
x_in = x
|
708 |
+
x = self.norm(x)
|
709 |
+
if not self.use_linear:
|
710 |
+
x = self.proj_in(x)
|
711 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
712 |
+
if self.use_linear:
|
713 |
+
x = self.proj_in(x)
|
714 |
+
for i, block in enumerate(self.transformer_blocks):
|
715 |
+
if i > 0 and len(context) == 1:
|
716 |
+
i = 0 # use same context for each block
|
717 |
+
x = block(x, context=context[i])
|
718 |
+
if self.use_linear:
|
719 |
+
x = self.proj_out(x)
|
720 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
721 |
+
if not self.use_linear:
|
722 |
+
x = self.proj_out(x)
|
723 |
+
return x + x_in
|
724 |
+
|
725 |
+
|
726 |
+
class SimpleTransformer(nn.Module):
|
727 |
+
def __init__(
|
728 |
+
self,
|
729 |
+
dim: int,
|
730 |
+
depth: int,
|
731 |
+
heads: int,
|
732 |
+
dim_head: int,
|
733 |
+
context_dim: Optional[int] = None,
|
734 |
+
dropout: float = 0.0,
|
735 |
+
checkpoint: bool = True,
|
736 |
+
):
|
737 |
+
super().__init__()
|
738 |
+
self.layers = nn.ModuleList([])
|
739 |
+
for _ in range(depth):
|
740 |
+
self.layers.append(
|
741 |
+
BasicTransformerBlock(
|
742 |
+
dim,
|
743 |
+
heads,
|
744 |
+
dim_head,
|
745 |
+
dropout=dropout,
|
746 |
+
context_dim=context_dim,
|
747 |
+
attn_mode="softmax-xformers",
|
748 |
+
checkpoint=checkpoint,
|
749 |
+
)
|
750 |
+
)
|
751 |
+
|
752 |
+
def forward(
|
753 |
+
self,
|
754 |
+
x: torch.Tensor,
|
755 |
+
context: Optional[torch.Tensor] = None,
|
756 |
+
) -> torch.Tensor:
|
757 |
+
for layer in self.layers:
|
758 |
+
x = layer(x, context)
|
759 |
+
return x
|
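`SpatialTransformer` above flattens `b c h w` feature maps into `b (h w) c` token sequences, runs the transformer blocks, then reshapes back and adds the input residually. A shape-level sketch of that round trip (einops only; standalone, with illustrative sizes):

```python
import torch
from einops import rearrange

b, c, h, w = 2, 320, 16, 16
x = torch.randn(b, c, h, w)

# One token per spatial position; channels become the feature dimension.
tokens = rearrange(x, "b c h w -> b (h w) c")
print(tokens.shape)  # torch.Size([2, 256, 320])

# ... the BasicTransformerBlocks would operate on `tokens` here ...

# Back to a feature map, then the module adds the input residually.
x_out = rearrange(tokens, "b (h w) c -> b c h w", h=h, w=w) + x
print(x_out.shape)  # torch.Size([2, 320, 16, 16])
```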
dragnuwa/svd/modules/autoencoding/__init__.py
ADDED: file without changes

dragnuwa/svd/modules/autoencoding/__pycache__/__init__.cpython-38.pyc
ADDED: binary file (156 Bytes)

dragnuwa/svd/modules/autoencoding/__pycache__/temporal_ae.cpython-38.pyc
ADDED: binary file (9.13 kB)
dragnuwa/svd/modules/autoencoding/losses/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__all__ = [
|
2 |
+
"GeneralLPIPSWithDiscriminator",
|
3 |
+
"LatentLPIPS",
|
4 |
+
]
|
5 |
+
|
6 |
+
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
7 |
+
from .lpips import LatentLPIPS
|
dragnuwa/svd/modules/autoencoding/losses/discriminator_loss.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision
|
7 |
+
from einops import rearrange
|
8 |
+
from matplotlib import colormaps
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
|
11 |
+
from ....util import default, instantiate_from_config
|
12 |
+
from ..lpips.loss.lpips import LPIPS
|
13 |
+
from ..lpips.model.model import weights_init
|
14 |
+
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
15 |
+
|
16 |
+
|
17 |
+
class GeneralLPIPSWithDiscriminator(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
disc_start: int,
|
21 |
+
logvar_init: float = 0.0,
|
22 |
+
disc_num_layers: int = 3,
|
23 |
+
disc_in_channels: int = 3,
|
24 |
+
disc_factor: float = 1.0,
|
25 |
+
disc_weight: float = 1.0,
|
26 |
+
perceptual_weight: float = 1.0,
|
27 |
+
disc_loss: str = "hinge",
|
28 |
+
scale_input_to_tgt_size: bool = False,
|
29 |
+
dims: int = 2,
|
30 |
+
learn_logvar: bool = False,
|
31 |
+
regularization_weights: Union[None, Dict[str, float]] = None,
|
32 |
+
additional_log_keys: Optional[List[str]] = None,
|
33 |
+
discriminator_config: Optional[Dict] = None,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
self.dims = dims
|
37 |
+
if self.dims > 2:
|
38 |
+
print(
|
39 |
+
f"running with dims={dims}. This means that for perceptual loss "
|
40 |
+
f"calculation, the LPIPS loss will be applied to each frame "
|
41 |
+
f"independently."
|
42 |
+
)
|
43 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
44 |
+
assert disc_loss in ["hinge", "vanilla"]
|
45 |
+
self.perceptual_loss = LPIPS().eval()
|
46 |
+
self.perceptual_weight = perceptual_weight
|
47 |
+
# output log variance
|
48 |
+
self.logvar = nn.Parameter(
|
49 |
+
torch.full((), logvar_init), requires_grad=learn_logvar
|
50 |
+
)
|
51 |
+
self.learn_logvar = learn_logvar
|
52 |
+
|
53 |
+
discriminator_config = default(
|
54 |
+
discriminator_config,
|
55 |
+
{
|
56 |
+
"target": "dragnuwa.svd.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
57 |
+
"params": {
|
58 |
+
"input_nc": disc_in_channels,
|
59 |
+
"n_layers": disc_num_layers,
|
60 |
+
"use_actnorm": False,
|
61 |
+
},
|
62 |
+
},
|
63 |
+
)
|
64 |
+
|
65 |
+
self.discriminator = instantiate_from_config(discriminator_config).apply(
|
66 |
+
weights_init
|
67 |
+
)
|
68 |
+
self.discriminator_iter_start = disc_start
|
69 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
70 |
+
self.disc_factor = disc_factor
|
71 |
+
self.discriminator_weight = disc_weight
|
72 |
+
self.regularization_weights = default(regularization_weights, {})
|
73 |
+
|
74 |
+
self.forward_keys = [
|
75 |
+
"optimizer_idx",
|
76 |
+
"global_step",
|
77 |
+
"last_layer",
|
78 |
+
"split",
|
79 |
+
"regularization_log",
|
80 |
+
]
|
81 |
+
|
82 |
+
self.additional_log_keys = set(default(additional_log_keys, []))
|
83 |
+
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
84 |
+
|
85 |
+
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
86 |
+
return self.discriminator.parameters()
|
87 |
+
|
88 |
+
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
89 |
+
if self.learn_logvar:
|
90 |
+
yield self.logvar
|
91 |
+
yield from ()
|
92 |
+
|
93 |
+
@torch.no_grad()
|
94 |
+
def log_images(
|
95 |
+
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
96 |
+
) -> Dict[str, torch.Tensor]:
|
97 |
+
# calc logits of real/fake
|
98 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
99 |
+
if len(logits_real.shape) < 4:
|
100 |
+
# Non patch-discriminator
|
101 |
+
return dict()
|
102 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
103 |
+
# -> (b, 1, h, w)
|
104 |
+
|
105 |
+
# parameters for colormapping
|
106 |
+
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
107 |
+
cmap = colormaps["PiYG"] # diverging colormap
|
108 |
+
|
109 |
+
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
110 |
+
"""(b, 1, ...) -> (b, 3, ...)"""
|
111 |
+
logits = (logits + high) / (2 * high)
|
112 |
+
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
113 |
+
# -> (b, 1, ..., 3)
|
114 |
+
logits = torch.from_numpy(logits_np).to(logits.device)
|
115 |
+
return rearrange(logits, "b 1 ... c -> b c ...")
|
116 |
+
|
117 |
+
logits_real = torch.nn.functional.interpolate(
|
118 |
+
logits_real,
|
119 |
+
size=inputs.shape[-2:],
|
120 |
+
mode="nearest",
|
121 |
+
antialias=False,
|
122 |
+
)
|
123 |
+
logits_fake = torch.nn.functional.interpolate(
|
124 |
+
logits_fake,
|
125 |
+
size=reconstructions.shape[-2:],
|
126 |
+
mode="nearest",
|
127 |
+
antialias=False,
|
128 |
+
)
|
129 |
+
|
130 |
+
# alpha value of logits for overlay
|
131 |
+
alpha_real = torch.abs(logits_real) / high
|
132 |
+
alpha_fake = torch.abs(logits_fake) / high
|
133 |
+
# -> (b, 1, h, w) in range [0, 0.5]
|
134 |
+
# alpha value of lines don't really matter, since the values are the same
|
135 |
+
# for both images and logits anyway
|
136 |
+
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
137 |
+
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
138 |
+
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
139 |
+
# -> (1, h, w)
|
140 |
+
# blend logits and images together
|
141 |
+
|
142 |
+
# prepare logits for plotting
|
143 |
+
logits_real = to_colormap(logits_real)
|
144 |
+
logits_fake = to_colormap(logits_fake)
|
145 |
+
# resize logits
|
146 |
+
# -> (b, 3, h, w)
|
147 |
+
|
148 |
+
# make some grids
|
149 |
+
# add all logits to one plot
|
150 |
+
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
151 |
+
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
152 |
+
# I just love how torchvision calls the number of columns `nrow`
|
153 |
+
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
154 |
+
# -> (3, h, w)
|
155 |
+
|
156 |
+
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
157 |
+
grid_images_fake = torchvision.utils.make_grid(
|
158 |
+
0.5 * reconstructions + 0.5, nrow=4
|
159 |
+
)
|
160 |
+
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
161 |
+
# -> (3, h, w) in range [0, 1]
|
162 |
+
|
163 |
+
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
164 |
+
|
165 |
+
# Create labeled colorbar
|
166 |
+
dpi = 100
|
167 |
+
height = 128 / dpi
|
168 |
+
width = grid_logits.shape[2] / dpi
|
169 |
+
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
170 |
+
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
171 |
+
plt.colorbar(
|
172 |
+
img,
|
173 |
+
cax=ax,
|
174 |
+
orientation="horizontal",
|
175 |
+
fraction=0.9,
|
176 |
+
aspect=width / height,
|
177 |
+
pad=0.0,
|
178 |
+
)
|
179 |
+
img.set_visible(False)
|
180 |
+
fig.tight_layout()
|
181 |
+
fig.canvas.draw()
|
182 |
+
# manually convert figure to numpy
|
183 |
+
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
184 |
+
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
185 |
+
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
186 |
+
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
187 |
+
|
188 |
+
# Add colorbar to plot
|
189 |
+
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
190 |
+
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
191 |
+
return {
|
192 |
+
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
193 |
+
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
194 |
+
}
|
195 |
+
|
196 |
+
def calculate_adaptive_weight(
|
197 |
+
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
198 |
+
) -> torch.Tensor:
|
199 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
200 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
201 |
+
|
202 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
203 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
204 |
+
d_weight = d_weight * self.discriminator_weight
|
205 |
+
return d_weight
|
206 |
+
|
207 |
+
def forward(
|
208 |
+
self,
|
209 |
+
inputs: torch.Tensor,
|
210 |
+
reconstructions: torch.Tensor,
|
211 |
+
*, # added because I changed the order here
|
212 |
+
regularization_log: Dict[str, torch.Tensor],
|
213 |
+
optimizer_idx: int,
|
214 |
+
global_step: int,
|
215 |
+
last_layer: torch.Tensor,
|
216 |
+
split: str = "train",
|
217 |
+
weights: Union[None, float, torch.Tensor] = None,
|
218 |
+
) -> Tuple[torch.Tensor, dict]:
|
219 |
+
if self.scale_input_to_tgt_size:
|
220 |
+
inputs = torch.nn.functional.interpolate(
|
221 |
+
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
222 |
+
)
|
223 |
+
|
224 |
+
if self.dims > 2:
|
225 |
+
inputs, reconstructions = map(
|
226 |
+
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
227 |
+
(inputs, reconstructions),
|
228 |
+
)
|
229 |
+
|
230 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
231 |
+
if self.perceptual_weight > 0:
|
232 |
+
p_loss = self.perceptual_loss(
|
233 |
+
inputs.contiguous(), reconstructions.contiguous()
|
234 |
+
)
|
235 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
236 |
+
|
237 |
+
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
238 |
+
|
239 |
+
# now the GAN part
|
240 |
+
if optimizer_idx == 0:
|
241 |
+
# generator update
|
242 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
243 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
244 |
+
g_loss = -torch.mean(logits_fake)
|
245 |
+
if self.training:
|
246 |
+
d_weight = self.calculate_adaptive_weight(
|
247 |
+
nll_loss, g_loss, last_layer=last_layer
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
d_weight = torch.tensor(1.0)
|
251 |
+
else:
|
252 |
+
d_weight = torch.tensor(0.0)
|
253 |
+
g_loss = torch.tensor(0.0, requires_grad=True)
|
254 |
+
|
255 |
+
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
256 |
+
log = dict()
|
257 |
+
for k in regularization_log:
|
258 |
+
if k in self.regularization_weights:
|
259 |
+
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
260 |
+
if k in self.additional_log_keys:
|
261 |
+
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
262 |
+
|
263 |
+
log.update(
|
264 |
+
{
|
265 |
+
f"{split}/loss/total": loss.clone().detach().mean(),
|
266 |
+
f"{split}/loss/nll": nll_loss.detach().mean(),
|
267 |
+
f"{split}/loss/rec": rec_loss.detach().mean(),
|
268 |
+
f"{split}/loss/g": g_loss.detach().mean(),
|
269 |
+
f"{split}/scalars/logvar": self.logvar.detach(),
|
270 |
+
f"{split}/scalars/d_weight": d_weight.detach(),
|
271 |
+
}
|
272 |
+
)
|
273 |
+
|
274 |
+
return loss, log
|
275 |
+
elif optimizer_idx == 1:
|
276 |
+
# second pass for discriminator update
|
277 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
278 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
279 |
+
|
280 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
281 |
+
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
282 |
+
else:
|
283 |
+
d_loss = torch.tensor(0.0, requires_grad=True)
|
284 |
+
|
285 |
+
log = {
|
286 |
+
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
287 |
+
f"{split}/logits/real": logits_real.detach().mean(),
|
288 |
+
f"{split}/logits/fake": logits_fake.detach().mean(),
|
289 |
+
}
|
290 |
+
return d_loss, log
|
291 |
+
else:
|
292 |
+
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
293 |
+
|
294 |
+
def get_nll_loss(
|
295 |
+
self,
|
296 |
+
rec_loss: torch.Tensor,
|
297 |
+
weights: Optional[Union[float, torch.Tensor]] = None,
|
298 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
299 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
300 |
+
weighted_nll_loss = nll_loss
|
301 |
+
if weights is not None:
|
302 |
+
weighted_nll_loss = weights * nll_loss
|
303 |
+
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
304 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
305 |
+
|
306 |
+
return nll_loss, weighted_nll_loss
|
dragnuwa/svd/modules/autoencoding/losses/lpips.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from ....util import default, instantiate_from_config
|
5 |
+
from ..lpips.loss.lpips import LPIPS
|
6 |
+
|
7 |
+
|
8 |
+
class LatentLPIPS(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
decoder_config,
|
12 |
+
perceptual_weight=1.0,
|
13 |
+
latent_weight=1.0,
|
14 |
+
scale_input_to_tgt_size=False,
|
15 |
+
scale_tgt_to_input_size=False,
|
16 |
+
perceptual_weight_on_inputs=0.0,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
20 |
+
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
21 |
+
self.init_decoder(decoder_config)
|
22 |
+
self.perceptual_loss = LPIPS().eval()
|
23 |
+
self.perceptual_weight = perceptual_weight
|
24 |
+
self.latent_weight = latent_weight
|
25 |
+
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
26 |
+
|
27 |
+
def init_decoder(self, config):
|
28 |
+
self.decoder = instantiate_from_config(config)
|
29 |
+
if hasattr(self.decoder, "encoder"):
|
30 |
+
del self.decoder.encoder
|
31 |
+
|
32 |
+
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
33 |
+
log = dict()
|
34 |
+
loss = (latent_inputs - latent_predictions) ** 2
|
35 |
+
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
36 |
+
image_reconstructions = None
|
37 |
+
if self.perceptual_weight > 0.0:
|
38 |
+
image_reconstructions = self.decoder.decode(latent_predictions)
|
39 |
+
image_targets = self.decoder.decode(latent_inputs)
|
40 |
+
perceptual_loss = self.perceptual_loss(
|
41 |
+
image_targets.contiguous(), image_reconstructions.contiguous()
|
42 |
+
)
|
43 |
+
loss = (
|
44 |
+
self.latent_weight * loss.mean()
|
45 |
+
+ self.perceptual_weight * perceptual_loss.mean()
|
46 |
+
)
|
47 |
+
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
48 |
+
|
49 |
+
if self.perceptual_weight_on_inputs > 0.0:
|
50 |
+
image_reconstructions = default(
|
51 |
+
image_reconstructions, self.decoder.decode(latent_predictions)
|
52 |
+
)
|
53 |
+
if self.scale_input_to_tgt_size:
|
54 |
+
image_inputs = torch.nn.functional.interpolate(
|
55 |
+
image_inputs,
|
56 |
+
image_reconstructions.shape[2:],
|
57 |
+
mode="bicubic",
|
58 |
+
antialias=True,
|
59 |
+
)
|
60 |
+
elif self.scale_tgt_to_input_size:
|
61 |
+
image_reconstructions = torch.nn.functional.interpolate(
|
62 |
+
image_reconstructions,
|
63 |
+
image_inputs.shape[2:],
|
64 |
+
mode="bicubic",
|
65 |
+
antialias=True,
|
66 |
+
)
|
67 |
+
|
68 |
+
perceptual_loss2 = self.perceptual_loss(
|
69 |
+
image_inputs.contiguous(), image_reconstructions.contiguous()
|
70 |
+
)
|
71 |
+
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
72 |
+
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
73 |
+
return loss, log
|
dragnuwa/svd/modules/autoencoding/lpips/__init__.py
ADDED
File without changes
|
dragnuwa/svd/modules/autoencoding/lpips/loss/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
vgg.pth
|
dragnuwa/svd/modules/autoencoding/lpips/loss/LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
dragnuwa/svd/modules/autoencoding/lpips/loss/__init__.py
ADDED
File without changes
|
dragnuwa/svd/modules/autoencoding/lpips/loss/lpips.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision import models
|
8 |
+
|
9 |
+
from ..util import get_ckpt_path
|
10 |
+
|
11 |
+
|
12 |
+
class LPIPS(nn.Module):
|
13 |
+
# Learned perceptual metric
|
14 |
+
def __init__(self, use_dropout=True):
|
15 |
+
super().__init__()
|
16 |
+
self.scaling_layer = ScalingLayer()
|
17 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
18 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
19 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
20 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
21 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
22 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
23 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
24 |
+
self.load_from_pretrained()
|
25 |
+
for param in self.parameters():
|
26 |
+
param.requires_grad = False
|
27 |
+
|
28 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
29 |
+
ckpt = get_ckpt_path(name, "models/svd/modules/autoencoding/lpips/loss")
|
30 |
+
self.load_state_dict(
|
31 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
32 |
+
)
|
33 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
37 |
+
if name != "vgg_lpips":
|
38 |
+
raise NotImplementedError
|
39 |
+
model = cls()
|
40 |
+
ckpt = get_ckpt_path(name)
|
41 |
+
model.load_state_dict(
|
42 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
43 |
+
)
|
44 |
+
return model
|
45 |
+
|
46 |
+
def forward(self, input, target):
|
47 |
+
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
48 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
49 |
+
feats0, feats1, diffs = {}, {}, {}
|
50 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
51 |
+
for kk in range(len(self.chns)):
|
52 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
53 |
+
outs1[kk]
|
54 |
+
)
|
55 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
56 |
+
|
57 |
+
res = [
|
58 |
+
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
|
59 |
+
for kk in range(len(self.chns))
|
60 |
+
]
|
61 |
+
val = res[0]
|
62 |
+
for l in range(1, len(self.chns)):
|
63 |
+
val += res[l]
|
64 |
+
return val
|
65 |
+
|
66 |
+
|
67 |
+
class ScalingLayer(nn.Module):
|
68 |
+
def __init__(self):
|
69 |
+
super(ScalingLayer, self).__init__()
|
70 |
+
self.register_buffer(
|
71 |
+
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
|
72 |
+
)
|
73 |
+
self.register_buffer(
|
74 |
+
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
|
75 |
+
)
|
76 |
+
|
77 |
+
def forward(self, inp):
|
78 |
+
return (inp - self.shift) / self.scale
|
79 |
+
|
80 |
+
|
81 |
+
class NetLinLayer(nn.Module):
|
82 |
+
"""A single linear layer which does a 1x1 conv"""
|
83 |
+
|
84 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
85 |
+
super(NetLinLayer, self).__init__()
|
86 |
+
layers = (
|
87 |
+
[
|
88 |
+
nn.Dropout(),
|
89 |
+
]
|
90 |
+
if (use_dropout)
|
91 |
+
else []
|
92 |
+
)
|
93 |
+
layers += [
|
94 |
+
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
95 |
+
]
|
96 |
+
self.model = nn.Sequential(*layers)
|
97 |
+
|
98 |
+
|
99 |
+
class vgg16(torch.nn.Module):
|
100 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
101 |
+
super(vgg16, self).__init__()
|
102 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
103 |
+
self.slice1 = torch.nn.Sequential()
|
104 |
+
self.slice2 = torch.nn.Sequential()
|
105 |
+
self.slice3 = torch.nn.Sequential()
|
106 |
+
self.slice4 = torch.nn.Sequential()
|
107 |
+
self.slice5 = torch.nn.Sequential()
|
108 |
+
self.N_slices = 5
|
109 |
+
for x in range(4):
|
110 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(4, 9):
|
112 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(9, 16):
|
114 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(16, 23):
|
116 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
for x in range(23, 30):
|
118 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
119 |
+
if not requires_grad:
|
120 |
+
for param in self.parameters():
|
121 |
+
param.requires_grad = False
|
122 |
+
|
123 |
+
def forward(self, X):
|
124 |
+
h = self.slice1(X)
|
125 |
+
h_relu1_2 = h
|
126 |
+
h = self.slice2(h)
|
127 |
+
h_relu2_2 = h
|
128 |
+
h = self.slice3(h)
|
129 |
+
h_relu3_3 = h
|
130 |
+
h = self.slice4(h)
|
131 |
+
h_relu4_3 = h
|
132 |
+
h = self.slice5(h)
|
133 |
+
h_relu5_3 = h
|
134 |
+
vgg_outputs = namedtuple(
|
135 |
+
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
136 |
+
)
|
137 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
def normalize_tensor(x, eps=1e-10):
|
142 |
+
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
|
143 |
+
return x / (norm_factor + eps)
|
144 |
+
|
145 |
+
|
146 |
+
def spatial_average(x, keepdim=True):
|
147 |
+
return x.mean([2, 3], keepdim=keepdim)
|
dragnuwa/svd/modules/autoencoding/lpips/model/LICENSE
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
24 |
+
|
25 |
+
|
26 |
+
--------------------------- LICENSE FOR pix2pix --------------------------------
|
27 |
+
BSD License
|
28 |
+
|
29 |
+
For pix2pix software
|
30 |
+
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
|
31 |
+
All rights reserved.
|
32 |
+
|
33 |
+
Redistribution and use in source and binary forms, with or without
|
34 |
+
modification, are permitted provided that the following conditions are met:
|
35 |
+
|
36 |
+
* Redistributions of source code must retain the above copyright notice, this
|
37 |
+
list of conditions and the following disclaimer.
|
38 |
+
|
39 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
40 |
+
this list of conditions and the following disclaimer in the documentation
|
41 |
+
and/or other materials provided with the distribution.
|
42 |
+
|
43 |
+
----------------------------- LICENSE FOR DCGAN --------------------------------
|
44 |
+
BSD License
|
45 |
+
|
46 |
+
For dcgan.torch software
|
47 |
+
|
48 |
+
Copyright (c) 2015, Facebook, Inc. All rights reserved.
|
49 |
+
|
50 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
51 |
+
|
52 |
+
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
53 |
+
|
54 |
+
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
55 |
+
|
56 |
+
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
57 |
+
|
58 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
dragnuwa/svd/modules/autoencoding/lpips/model/__init__.py
ADDED
File without changes
|
dragnuwa/svd/modules/autoencoding/lpips/model/model.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from ..util import ActNorm
|
6 |
+
|
7 |
+
|
8 |
+
def weights_init(m):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find("Conv") != -1:
|
11 |
+
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
+
elif classname.find("BatchNorm") != -1:
|
13 |
+
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
+
nn.init.constant_(m.bias.data, 0)
|
15 |
+
|
16 |
+
|
17 |
+
class NLayerDiscriminator(nn.Module):
|
18 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
23 |
+
"""Construct a PatchGAN discriminator
|
24 |
+
Parameters:
|
25 |
+
input_nc (int) -- the number of channels in input images
|
26 |
+
ndf (int) -- the number of filters in the last conv layer
|
27 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
28 |
+
norm_layer -- normalization layer
|
29 |
+
"""
|
30 |
+
super(NLayerDiscriminator, self).__init__()
|
31 |
+
if not use_actnorm:
|
32 |
+
norm_layer = nn.BatchNorm2d
|
33 |
+
else:
|
34 |
+
norm_layer = ActNorm
|
35 |
+
if (
|
36 |
+
type(norm_layer) == functools.partial
|
37 |
+
): # no need to use bias as BatchNorm2d has affine parameters
|
38 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
39 |
+
else:
|
40 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
41 |
+
|
42 |
+
kw = 4
|
43 |
+
padw = 1
|
44 |
+
sequence = [
|
45 |
+
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
46 |
+
nn.LeakyReLU(0.2, True),
|
47 |
+
]
|
48 |
+
nf_mult = 1
|
49 |
+
nf_mult_prev = 1
|
50 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
51 |
+
nf_mult_prev = nf_mult
|
52 |
+
nf_mult = min(2**n, 8)
|
53 |
+
sequence += [
|
54 |
+
nn.Conv2d(
|
55 |
+
ndf * nf_mult_prev,
|
56 |
+
ndf * nf_mult,
|
57 |
+
kernel_size=kw,
|
58 |
+
stride=2,
|
59 |
+
padding=padw,
|
60 |
+
bias=use_bias,
|
61 |
+
),
|
62 |
+
norm_layer(ndf * nf_mult),
|
63 |
+
nn.LeakyReLU(0.2, True),
|
64 |
+
]
|
65 |
+
|
66 |
+
nf_mult_prev = nf_mult
|
67 |
+
nf_mult = min(2**n_layers, 8)
|
68 |
+
sequence += [
|
69 |
+
nn.Conv2d(
|
70 |
+
ndf * nf_mult_prev,
|
71 |
+
ndf * nf_mult,
|
72 |
+
kernel_size=kw,
|
73 |
+
stride=1,
|
74 |
+
padding=padw,
|
75 |
+
bias=use_bias,
|
76 |
+
),
|
77 |
+
norm_layer(ndf * nf_mult),
|
78 |
+
nn.LeakyReLU(0.2, True),
|
79 |
+
]
|
80 |
+
|
81 |
+
sequence += [
|
82 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
|
83 |
+
] # output 1 channel prediction map
|
84 |
+
self.main = nn.Sequential(*sequence)
|
85 |
+
|
86 |
+
def forward(self, input):
|
87 |
+
"""Standard forward."""
|
88 |
+
return self.main(input)
|
dragnuwa/svd/modules/autoencoding/lpips/util.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
import requests
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
|
10 |
+
|
11 |
+
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
|
12 |
+
|
13 |
+
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
|
14 |
+
|
15 |
+
|
16 |
+
def download(url, local_path, chunk_size=1024):
|
17 |
+
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
18 |
+
with requests.get(url, stream=True) as r:
|
19 |
+
total_size = int(r.headers.get("content-length", 0))
|
20 |
+
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
21 |
+
with open(local_path, "wb") as f:
|
22 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
23 |
+
if data:
|
24 |
+
f.write(data)
|
25 |
+
pbar.update(chunk_size)
|
26 |
+
|
27 |
+
|
28 |
+
def md5_hash(path):
|
29 |
+
with open(path, "rb") as f:
|
30 |
+
content = f.read()
|
31 |
+
return hashlib.md5(content).hexdigest()
|
32 |
+
|
33 |
+
|
34 |
+
def get_ckpt_path(name, root, check=False):
|
35 |
+
assert name in URL_MAP
|
36 |
+
path = os.path.join(root, CKPT_MAP[name])
|
37 |
+
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
38 |
+
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
39 |
+
download(URL_MAP[name], path)
|
40 |
+
md5 = md5_hash(path)
|
41 |
+
assert md5 == MD5_MAP[name], md5
|
42 |
+
return path
|
43 |
+
|
44 |
+
|
45 |
+
class ActNorm(nn.Module):
|
46 |
+
def __init__(
|
47 |
+
self, num_features, logdet=False, affine=True, allow_reverse_init=False
|
48 |
+
):
|
49 |
+
assert affine
|
50 |
+
super().__init__()
|
51 |
+
self.logdet = logdet
|
52 |
+
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
53 |
+
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
54 |
+
self.allow_reverse_init = allow_reverse_init
|
55 |
+
|
56 |
+
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
|
57 |
+
|
58 |
+
def initialize(self, input):
|
59 |
+
with torch.no_grad():
|
60 |
+
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
61 |
+
mean = (
|
62 |
+
flatten.mean(1)
|
63 |
+
.unsqueeze(1)
|
64 |
+
.unsqueeze(2)
|
65 |
+
.unsqueeze(3)
|
66 |
+
.permute(1, 0, 2, 3)
|
67 |
+
)
|
68 |
+
std = (
|
69 |
+
flatten.std(1)
|
70 |
+
.unsqueeze(1)
|
71 |
+
.unsqueeze(2)
|
72 |
+
.unsqueeze(3)
|
73 |
+
.permute(1, 0, 2, 3)
|
74 |
+
)
|
75 |
+
|
76 |
+
self.loc.data.copy_(-mean)
|
77 |
+
self.scale.data.copy_(1 / (std + 1e-6))
|
78 |
+
|
79 |
+
def forward(self, input, reverse=False):
|
80 |
+
if reverse:
|
81 |
+
return self.reverse(input)
|
82 |
+
if len(input.shape) == 2:
|
83 |
+
input = input[:, :, None, None]
|
84 |
+
squeeze = True
|
85 |
+
else:
|
86 |
+
squeeze = False
|
87 |
+
|
88 |
+
_, _, height, width = input.shape
|
89 |
+
|
90 |
+
if self.training and self.initialized.item() == 0:
|
91 |
+
self.initialize(input)
|
92 |
+
self.initialized.fill_(1)
|
93 |
+
|
94 |
+
h = self.scale * (input + self.loc)
|
95 |
+
|
96 |
+
if squeeze:
|
97 |
+
h = h.squeeze(-1).squeeze(-1)
|
98 |
+
|
99 |
+
if self.logdet:
|
100 |
+
log_abs = torch.log(torch.abs(self.scale))
|
101 |
+
logdet = height * width * torch.sum(log_abs)
|
102 |
+
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
103 |
+
return h, logdet
|
104 |
+
|
105 |
+
return h
|
106 |
+
|
107 |
+
def reverse(self, output):
|
108 |
+
if self.training and self.initialized.item() == 0:
|
109 |
+
if not self.allow_reverse_init:
|
110 |
+
raise RuntimeError(
|
111 |
+
"Initializing ActNorm in reverse direction is "
|
112 |
+
"disabled by default. Use allow_reverse_init=True to enable."
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
self.initialize(output)
|
116 |
+
self.initialized.fill_(1)
|
117 |
+
|
118 |
+
if len(output.shape) == 2:
|
119 |
+
output = output[:, :, None, None]
|
120 |
+
squeeze = True
|
121 |
+
else:
|
122 |
+
squeeze = False
|
123 |
+
|
124 |
+
h = output / self.scale - self.loc
|
125 |
+
|
126 |
+
if squeeze:
|
127 |
+
h = h.squeeze(-1).squeeze(-1)
|
128 |
+
return h
|
dragnuwa/svd/modules/autoencoding/lpips/vqperceptual.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def hinge_d_loss(logits_real, logits_fake):
|
6 |
+
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
7 |
+
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
8 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
9 |
+
return d_loss
|
10 |
+
|
11 |
+
|
12 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
13 |
+
d_loss = 0.5 * (
|
14 |
+
torch.mean(torch.nn.functional.softplus(-logits_real))
|
15 |
+
+ torch.mean(torch.nn.functional.softplus(logits_fake))
|
16 |
+
)
|
17 |
+
return d_loss
|
dragnuwa/svd/modules/autoencoding/regularizers/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import Any, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from ....modules.distributions.distributions import \
|
9 |
+
DiagonalGaussianDistribution
|
10 |
+
from .base import AbstractRegularizer
|
11 |
+
|
12 |
+
|
13 |
+
class DiagonalGaussianRegularizer(AbstractRegularizer):
|
14 |
+
def __init__(self, sample: bool = True):
|
15 |
+
super().__init__()
|
16 |
+
self.sample = sample
|
17 |
+
|
18 |
+
def get_trainable_parameters(self) -> Any:
|
19 |
+
yield from ()
|
20 |
+
|
21 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
22 |
+
log = dict()
|
23 |
+
posterior = DiagonalGaussianDistribution(z)
|
24 |
+
if self.sample:
|
25 |
+
z = posterior.sample()
|
26 |
+
else:
|
27 |
+
z = posterior.mode()
|
28 |
+
kl_loss = posterior.kl()
|
29 |
+
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
30 |
+
log["kl_loss"] = kl_loss
|
31 |
+
return z, log
|
dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.49 kB). View file
|
|
dragnuwa/svd/modules/autoencoding/regularizers/__pycache__/base.cpython-38.pyc
ADDED
Binary file (2.03 kB). View file
|
|
dragnuwa/svd/modules/autoencoding/regularizers/base.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import Any, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
class AbstractRegularizer(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
14 |
+
raise NotImplementedError()
|
15 |
+
|
16 |
+
@abstractmethod
|
17 |
+
def get_trainable_parameters(self) -> Any:
|
18 |
+
raise NotImplementedError()
|
19 |
+
|
20 |
+
|
21 |
+
class IdentityRegularizer(AbstractRegularizer):
|
22 |
+
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
23 |
+
return z, dict()
|
24 |
+
|
25 |
+
def get_trainable_parameters(self) -> Any:
|
26 |
+
yield from ()
|
27 |
+
|
28 |
+
|
29 |
+
def measure_perplexity(
|
30 |
+
predicted_indices: torch.Tensor, num_centroids: int
|
31 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
32 |
+
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
33 |
+
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
34 |
+
encodings = (
|
35 |
+
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
36 |
+
)
|
37 |
+
avg_probs = encodings.mean(0)
|
38 |
+
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
39 |
+
cluster_use = torch.sum(avg_probs > 0)
|
40 |
+
return perplexity, cluster_use
|
dragnuwa/svd/modules/autoencoding/regularizers/quantize.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import abstractmethod
|
3 |
+
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
from torch import einsum
|
11 |
+
|
12 |
+
from .base import AbstractRegularizer, measure_perplexity
|
13 |
+
|
14 |
+
logpy = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class AbstractQuantizer(AbstractRegularizer):
|
18 |
+
def __init__(self):
|
19 |
+
super().__init__()
|
20 |
+
# Define these in your init
|
21 |
+
# shape (N,)
|
22 |
+
self.used: Optional[torch.Tensor]
|
23 |
+
self.re_embed: int
|
24 |
+
self.unknown_index: Union[Literal["random"], int]
|
25 |
+
|
26 |
+
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
|
27 |
+
assert self.used is not None, "You need to define used indices for remap"
|
28 |
+
ishape = inds.shape
|
29 |
+
assert len(ishape) > 1
|
30 |
+
inds = inds.reshape(ishape[0], -1)
|
31 |
+
used = self.used.to(inds)
|
32 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
33 |
+
new = match.argmax(-1)
|
34 |
+
unknown = match.sum(2) < 1
|
35 |
+
if self.unknown_index == "random":
|
36 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
37 |
+
device=new.device
|
38 |
+
)
|
39 |
+
else:
|
40 |
+
new[unknown] = self.unknown_index
|
41 |
+
return new.reshape(ishape)
|
42 |
+
|
43 |
+
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
|
44 |
+
assert self.used is not None, "You need to define used indices for remap"
|
45 |
+
ishape = inds.shape
|
46 |
+
assert len(ishape) > 1
|
47 |
+
inds = inds.reshape(ishape[0], -1)
|
48 |
+
used = self.used.to(inds)
|
49 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
50 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
51 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
52 |
+
return back.reshape(ishape)
|
53 |
+
|
54 |
+
@abstractmethod
|
55 |
+
def get_codebook_entry(
|
56 |
+
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
57 |
+
) -> torch.Tensor:
|
58 |
+
raise NotImplementedError()
|
59 |
+
|
60 |
+
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
61 |
+
yield from self.parameters()
|
62 |
+
|
63 |
+
|
64 |
+
class GumbelQuantizer(AbstractQuantizer):
|
65 |
+
"""
|
66 |
+
credit to @karpathy:
|
67 |
+
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
68 |
+
Gumbel Softmax trick quantizer
|
69 |
+
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
70 |
+
https://arxiv.org/abs/1611.01144
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
num_hiddens: int,
|
76 |
+
embedding_dim: int,
|
77 |
+
n_embed: int,
|
78 |
+
straight_through: bool = True,
|
79 |
+
kl_weight: float = 5e-4,
|
80 |
+
temp_init: float = 1.0,
|
81 |
+
remap: Optional[str] = None,
|
82 |
+
unknown_index: str = "random",
|
83 |
+
loss_key: str = "loss/vq",
|
84 |
+
) -> None:
|
85 |
+
super().__init__()
|
86 |
+
|
87 |
+
self.loss_key = loss_key
|
88 |
+
self.embedding_dim = embedding_dim
|
89 |
+
self.n_embed = n_embed
|
90 |
+
|
91 |
+
self.straight_through = straight_through
|
92 |
+
self.temperature = temp_init
|
93 |
+
self.kl_weight = kl_weight
|
94 |
+
|
95 |
+
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
96 |
+
self.embed = nn.Embedding(n_embed, embedding_dim)
|
97 |
+
|
98 |
+
self.remap = remap
|
99 |
+
if self.remap is not None:
|
100 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
101 |
+
self.re_embed = self.used.shape[0]
|
102 |
+
else:
|
103 |
+
self.used = None
|
104 |
+
self.re_embed = n_embed
|
105 |
+
if unknown_index == "extra":
|
106 |
+
self.unknown_index = self.re_embed
|
107 |
+
self.re_embed = self.re_embed + 1
|
108 |
+
else:
|
109 |
+
assert unknown_index == "random" or isinstance(
|
110 |
+
unknown_index, int
|
111 |
+
), "unknown index needs to be 'random', 'extra' or any integer"
|
112 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
113 |
+
if self.remap is not None:
|
114 |
+
logpy.info(
|
115 |
+
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
116 |
+
f"Using {self.unknown_index} for unknown indices."
|
117 |
+
)
|
118 |
+
|
119 |
+
def forward(
|
120 |
+
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
|
121 |
+
) -> Tuple[torch.Tensor, Dict]:
|
122 |
+
# force hard = True when we are in eval mode, as we must quantize.
|
123 |
+
# actually, always true seems to work
|
124 |
+
hard = self.straight_through if self.training else True
|
125 |
+
temp = self.temperature if temp is None else temp
|
126 |
+
out_dict = {}
|
127 |
+
logits = self.proj(z)
|
128 |
+
if self.remap is not None:
|
129 |
+
# continue only with used logits
|
130 |
+
full_zeros = torch.zeros_like(logits)
|
131 |
+
logits = logits[:, self.used, ...]
|
132 |
+
|
133 |
+
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
134 |
+
if self.remap is not None:
|
135 |
+
# go back to all entries but unused set to zero
|
136 |
+
full_zeros[:, self.used, ...] = soft_one_hot
|
137 |
+
soft_one_hot = full_zeros
|
138 |
+
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
139 |
+
|
140 |
+
# + kl divergence to the prior loss
|
141 |
+
qy = F.softmax(logits, dim=1)
|
142 |
+
diff = (
|
143 |
+
self.kl_weight
|
144 |
+
* torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
145 |
+
)
|
146 |
+
out_dict[self.loss_key] = diff
|
147 |
+
|
148 |
+
ind = soft_one_hot.argmax(dim=1)
|
149 |
+
out_dict["indices"] = ind
|
150 |
+
if self.remap is not None:
|
151 |
+
ind = self.remap_to_used(ind)
|
152 |
+
|
153 |
+
if return_logits:
|
154 |
+
out_dict["logits"] = logits
|
155 |
+
|
156 |
+
return z_q, out_dict
|
157 |
+
|
158 |
+
def get_codebook_entry(self, indices, shape):
|
159 |
+
# TODO: shape not yet optional
|
160 |
+
b, h, w, c = shape
|
161 |
+
assert b * h * w == indices.shape[0]
|
162 |
+
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
|
163 |
+
if self.remap is not None:
|
164 |
+
indices = self.unmap_to_all(indices)
|
165 |
+
one_hot = (
|
166 |
+
F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
167 |
+
)
|
168 |
+
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
|
169 |
+
return z_q
|
170 |
+
|
171 |
+
|
172 |
+
class VectorQuantizer(AbstractQuantizer):
|
173 |
+
"""
|
174 |
+
____________________________________________
|
175 |
+
Discretization bottleneck part of the VQ-VAE.
|
176 |
+
Inputs:
|
177 |
+
- n_e : number of embeddings
|
178 |
+
- e_dim : dimension of embedding
|
179 |
+
- beta : commitment cost used in loss term,
|
180 |
+
beta * ||z_e(x)-sg[e]||^2
|
181 |
+
_____________________________________________
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
n_e: int,
|
187 |
+
e_dim: int,
|
188 |
+
beta: float = 0.25,
|
189 |
+
remap: Optional[str] = None,
|
190 |
+
unknown_index: str = "random",
|
191 |
+
sane_index_shape: bool = False,
|
192 |
+
log_perplexity: bool = False,
|
193 |
+
embedding_weight_norm: bool = False,
|
194 |
+
loss_key: str = "loss/vq",
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
self.n_e = n_e
|
198 |
+
self.e_dim = e_dim
|
199 |
+
self.beta = beta
|
200 |
+
self.loss_key = loss_key
|
201 |
+
|
202 |
+
if not embedding_weight_norm:
|
203 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
204 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
205 |
+
else:
|
206 |
+
self.embedding = torch.nn.utils.weight_norm(
|
207 |
+
nn.Embedding(self.n_e, self.e_dim), dim=1
|
208 |
+
)
|
209 |
+
|
210 |
+
self.remap = remap
|
211 |
+
if self.remap is not None:
|
212 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
213 |
+
self.re_embed = self.used.shape[0]
|
214 |
+
else:
|
215 |
+
self.used = None
|
216 |
+
self.re_embed = n_e
|
217 |
+
if unknown_index == "extra":
|
218 |
+
self.unknown_index = self.re_embed
|
219 |
+
self.re_embed = self.re_embed + 1
|
220 |
+
else:
|
221 |
+
assert unknown_index == "random" or isinstance(
|
222 |
+
unknown_index, int
|
223 |
+
), "unknown index needs to be 'random', 'extra' or any integer"
|
224 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
225 |
+
if self.remap is not None:
|
226 |
+
logpy.info(
|
227 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
228 |
+
f"Using {self.unknown_index} for unknown indices."
|
229 |
+
)
|
230 |
+
|
231 |
+
self.sane_index_shape = sane_index_shape
|
232 |
+
self.log_perplexity = log_perplexity
|
233 |
+
|
234 |
+
def forward(
|
235 |
+
self,
|
236 |
+
z: torch.Tensor,
|
237 |
+
) -> Tuple[torch.Tensor, Dict]:
|
238 |
+
do_reshape = z.ndim == 4
|
239 |
+
if do_reshape:
|
240 |
+
# # reshape z -> (batch, height, width, channel) and flatten
|
241 |
+
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
242 |
+
|
243 |
+
else:
|
244 |
+
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
|
245 |
+
z = z.contiguous()
|
246 |
+
|
247 |
+
z_flattened = z.view(-1, self.e_dim)
|
248 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
249 |
+
|
250 |
+
d = (
|
251 |
+
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2
            * torch.einsum(
                "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
            )
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        loss_dict = {}
        if self.log_perplexity:
            perplexity, cluster_usage = measure_perplexity(
                min_encoding_indices.detach(), self.n_e
            )
            loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})

        # compute loss for embedding
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
            (z_q - z.detach()) ** 2
        )
        loss_dict[self.loss_key] = loss

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        if do_reshape:
            z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(
                z.shape[0], -1
            )  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            if do_reshape:
                min_encoding_indices = min_encoding_indices.reshape(
                    z_q.shape[0], z_q.shape[2], z_q.shape[3]
                )
            else:
                min_encoding_indices = rearrange(
                    min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
                )

        loss_dict["min_encoding_indices"] = min_encoding_indices

        return z_q, loss_dict

    def get_codebook_entry(
        self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
    ) -> torch.Tensor:
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            assert shape is not None, "Need to give shape for remap"
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        weight = torch.randn(num_tokens, codebook_dim)
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        self.update = True

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(
            new_cluster_size, alpha=1 - self.decay
        )

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)


class EMAVectorQuantizer(AbstractQuantizer):
    def __init__(
        self,
        n_embed: int,
        embedding_dim: int,
        beta: float,
        decay: float = 0.99,
        eps: float = 1e-5,
        remap: Optional[str] = None,
        unknown_index: str = "random",
        loss_key: str = "loss/vq",
    ):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.loss_key = loss_key

        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
        else:
            self.used = None
            self.re_embed = n_embed
        if unknown_index == "extra":
            self.unknown_index = self.re_embed
            self.re_embed = self.re_embed + 1
        else:
            assert unknown_index == "random" or isinstance(
                unknown_index, int
            ), "unknown index needs to be 'random', 'extra' or any integer"
            self.unknown_index = unknown_index  # "random" or "extra" or integer
        if self.remap is not None:
            logpy.info(
                f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        z = rearrange(z, "b c h w -> b h w c")
        z_flattened = z.reshape(-1, self.codebook_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (
            z_flattened.pow(2).sum(dim=1, keepdim=True)
            + self.embedding.weight.pow(2).sum(dim=1)
            - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
        )  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training and self.embedding.update:
            # EMA cluster size
            encodings_sum = encodings.sum(0)
            self.embedding.cluster_size_ema_update(encodings_sum)
            # EMA embedding average
            embed_sum = encodings.transpose(0, 1) @ z_flattened
            self.embedding.embed_avg_ema_update(embed_sum)
            # normalize embed_avg and update weight
            self.embedding.weight_update(self.num_tokens)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        z_q = rearrange(z_q, "b h w c -> b c h w")

        out_dict = {
            self.loss_key: loss,
            "encodings": encodings,
            "encoding_indices": encoding_indices,
            "perplexity": perplexity,
        }

        return z_q, out_dict


class VectorQuantizerWithInputProjection(VectorQuantizer):
    def __init__(
        self,
        input_dim: int,
        n_codes: int,
        codebook_dim: int,
        beta: float = 1.0,
        output_dim: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(n_codes, codebook_dim, beta, **kwargs)
        self.proj_in = nn.Linear(input_dim, codebook_dim)
        self.output_dim = output_dim
        if output_dim is not None:
            self.proj_out = nn.Linear(codebook_dim, output_dim)
        else:
            self.proj_out = nn.Identity()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        rearr = False
        in_shape = z.shape

        if z.ndim > 3:
            rearr = self.output_dim is not None
            z = rearrange(z, "b c ... -> b (...) c")
        z = self.proj_in(z)
        z_q, loss_dict = super().forward(z)

        z_q = self.proj_out(z_q)
        if rearr:
            if len(in_shape) == 4:
                z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
            elif len(in_shape) == 5:
                z_q = rearrange(
                    z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
                )
            else:
                raise NotImplementedError(
                    f"rearranging not available for {len(in_shape)}-dimensional input."
                )

        return z_q, loss_dict
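Usage note (not part of the commit): the quantizers in this file all return the straight-through quantized tensor together with a loss dictionary keyed by loss_key. A minimal smoke-test sketch for EMAVectorQuantizer, assuming the module's own imports (torch, einops, the dict keys shown above); the shapes and hyperparameters here are illustrative only:

# Hypothetical example, not part of the repository.
import torch

quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25)
z = torch.randn(2, 64, 8, 8)        # (batch, channels, height, width) feature map
z_q, out = quantizer(z)             # straight-through quantized latents + stats
print(z_q.shape)                    # same shape as z
print(out["loss/vq"], out["perplexity"])  # commitment loss and codebook perplexity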
dragnuwa/svd/modules/autoencoding/temporal_ae.py
ADDED
@@ -0,0 +1,349 @@
from typing import Callable, Iterable, Union

import torch
from einops import rearrange, repeat

from dragnuwa.svd.modules.diffusionmodules.model import (
    XFORMERS_IS_AVAILABLE,
    AttnBlock,
    Decoder,
    MemoryEfficientAttnBlock,
    ResnetBlock,
)
from dragnuwa.svd.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from dragnuwa.svd.modules.video_attention import VideoTransformerBlock
from dragnuwa.svd.util import partialclass


class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        if timesteps is None:
            timesteps = self.timesteps

        b, c, h, w = x.shape

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x


class AE3DConv(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps, skip_video=False):
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")


class VideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = VideoTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
            attn_mode="softmax",
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            torch.nn.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps, skip_video=False):
        if skip_video:
            return super().forward(x)

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(
        self,
    ):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")


class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = VideoTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
            attn_mode="softmax-xformers",
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            torch.nn.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(
        self,
    ):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")


def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
    ], f"attn_type {attn_type} not supported for spatio-temporal attention"
    print(
        f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
    )
    if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
        print(
            f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
            f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
        )
        attn_type = "vanilla"

    if attn_type == "vanilla":
        assert attn_kwargs is None
        return partialclass(
            VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
        )
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return partialclass(
            MemoryEfficientVideoBlock,
            in_channels,
            alpha=alpha,
            merge_strategy=merge_strategy,
        )
    else:
        return NotImplementedError()


class Conv2DWrapper(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        return super().forward(input)


class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"
        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )

    def _make_attn(self) -> Callable:
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            return partialclass(
                make_time_attn,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )
        else:
            return super()._make_attn()

    def _make_conv(self) -> Callable:
        if self.time_mode != "attn-only":
            return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        else:
            return Conv2DWrapper

    def _make_resblock(self) -> Callable:
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            return partialclass(
                VideoResBlock,
                video_kernel_size=self.video_kernel_size,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )
        else:
            return super()._make_resblock()
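Usage note (not part of the commit): VideoDecoder keeps the spatial Decoder layout and swaps in temporal variants of its conv, attention, and residual blocks according to time_mode. A minimal construction sketch under the assumption that the base Decoder in dragnuwa/svd/modules/diffusionmodules/model.py takes the usual ch / ch_mult / z_channels keyword arguments; the values below are illustrative, not the ones used by DragNUWA's configs:

# Hypothetical construction, hyperparameters are illustrative only.
decoder = VideoDecoder(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=(),
    in_channels=3,
    resolution=256,
    z_channels=4,
    video_kernel_size=[3, 1, 1],
    time_mode="conv-only",       # temporal mixing only in the conv blocks
)
# z has shape (batch * frames, z_channels, h, w); timesteps is frames per clip,
# e.g. frames = decoder(z, timesteps=14), assuming Decoder.forward passes
# extra keyword arguments through to its blocks.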