bluestarburst committed
Commit: 9a6a590
Parent(s): 2057037
Upload folder using huggingface_hub
- animatediff/models/motion_module.py  +6 -1
- train.py  +12 -0
animatediff/models/motion_module.py
CHANGED
@@ -308,9 +308,14 @@ class VersatileAttention(CrossAttention):
         attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
         attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
+        if not hasattr(self, '_use_memory_efficient_attention_xformers'):
+            self._use_memory_efficient_attention_xformers = True
+
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+            self.set_use_memory_efficient_attention_xformers(True)
+            # hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
train.py
CHANGED
@@ -177,6 +177,7 @@ def main(
     for name, module in unet.named_modules():
         if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
             for params in module.parameters():
+                print("trainable", name)
                 params.requires_grad = True
 
     if enable_xformers_memory_efficient_attention:

@@ -370,10 +371,21 @@ def main(
                 avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                 train_loss += avg_loss.item() / gradient_accumulation_steps
 
+                for name, module in unet.named_modules():
+                    if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
+                        for params in module.parameters():
+                            params.requires_grad = True
+
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
+
+                # for param in unet.parameters():
+                #     print(param.grad)
+
+
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()