# Source: MotionInversion / utils/gpu_utils.py (Hugging Face file view)
# Author: ziyangmai - commit 113884e ("page demo") - 2.28 kB
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available
from transformers.models.clip.modeling_clip import CLIPEncoder
# Module-level defaults for the helpers below.
# Default for `unet_enable` in unet_and_text_g_c().
GRADIENT_CHECKPOINTING = True
# Default for `text_enable` in unet_and_text_g_c().
TEXT_ENCODER_GRADIENT_CHECKPOINTING = True
# Default for `enable_xformers_memory_efficient_attention` in handle_memory_attention().
ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION = True
# Default for `enable_torch_2_attn` in handle_memory_attention().
ENABLE_TORCH_2_ATTN = True
def is_attn(name):
    """Return True if the last dotted component of *name* is ``attn1`` or ``attn2``.

    Args:
        name: Dotted module path, e.g. a key yielded by ``named_modules()``.

    Returns:
        bool: Whether the final path component is ``attn1`` or ``attn2``.
    """
    # Test membership explicitly: ``'attn1' or 'attn2' == x`` short-circuits
    # to the truthy string 'attn1' and therefore matches every name.
    return name.split('.')[-1] in ('attn1', 'attn2')


def unet_and_text_g_c(
    unet,
    text_encoder,
    unet_enable=GRADIENT_CHECKPOINTING,
    text_enable=TEXT_ENCODER_GRADIENT_CHECKPOINTING,
):
    """Toggle gradient checkpointing on the UNet and on the text encoder.

    Args:
        unet: Model exposing ``_set_gradient_checkpointing(value=...)``.
        text_encoder: Model exposing ``gradient_checkpointing_enable()`` and
            ``gradient_checkpointing_disable()``.
            NOTE(review): presumably a transformers CLIP text model - confirm
            against callers.
        unet_enable: Value forwarded to the UNet's checkpointing hook.
        text_enable: Whether to enable (truthy) or disable (falsy) gradient
            checkpointing on the text encoder.
    """
    unet._set_gradient_checkpointing(value=unet_enable)

    # Honour `text_enable` through the public toggle methods instead of handing
    # the CLIPEncoder class to the private hook, which ignored the flag.
    if text_enable:
        text_encoder.gradient_checkpointing_enable()
    else:
        text_encoder.gradient_checkpointing_disable()


def set_processors(attentions):
    """Install a fresh ``AttnProcessor2_0`` on every module in *attentions*.

    Args:
        attentions: Iterable of modules exposing ``set_processor(...)``.
    """
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    """Switch the attention layers of transformer blocks to ``AttnProcessor2_0``.

    Every ``torch.nn.ModuleList`` inside *unet* is scanned; for each
    ``BasicTransformerBlock`` it directly contains, the block's ``attn1`` and
    ``attn2`` modules receive a new ``AttnProcessor2_0``. A summary line is
    printed when at least one layer was switched.

    Args:
        unet: Module whose submodules are searched via ``modules()``.
    """
    optim_count = 0

    for module in unet.modules():
        # Only blocks held directly in a ModuleList are considered.
        if not isinstance(module, torch.nn.ModuleList):
            continue

        for block in module:
            if not isinstance(block, BasicTransformerBlock):
                continue

            # Skip an attention slot that is None rather than crashing on it.
            attentions = [
                attn for attn in (block.attn1, block.attn2) if attn is not None
            ]
            set_processors(attentions)
            # Count layers, not blocks, so the message below is accurate.
            optim_count += len(attentions)

    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
def handle_memory_attention(
    unet,
    enable_xformers_memory_efficient_attention=ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION,
    enable_torch_2_attn=ENABLE_TORCH_2_ATTN
):
    """Enable a memory-efficient attention backend on *unet* (best effort).

    Torch 2 scaled-dot-product attention is used when ``torch.nn.functional``
    provides ``scaled_dot_product_attention`` and *enable_torch_2_attn* is
    truthy. xformers is used only when that path is not taken and
    *enable_xformers_memory_efficient_attention* is truthy.

    Any failure is reported on stdout and is not propagated to the caller.

    Args:
        unet: Model to configure; must expose
            ``enable_xformers_memory_efficient_attention(attention_op=...)``
            for the xformers path.
        enable_xformers_memory_efficient_attention: Request the xformers backend.
        enable_torch_2_attn: Request the Torch 2 SDPA backend.
    """
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
        enable_torch_2 = is_torch_2 and enable_torch_2_attn

        # xformers is a fallback: it is skipped whenever Torch 2 SDPA is used.
        if enable_xformers_memory_efficient_attention and not enable_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(
                    attention_op=MemoryEfficientAttentionFlashAttentionOp
                )
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if enable_torch_2:
            set_torch_2_attn(unet)
    except Exception as exc:
        # Stay best-effort, but do not swallow KeyboardInterrupt/SystemExit,
        # and surface the cause instead of hiding it.
        print("Could not enable memory efficient attention for xformers or Torch 2.0.")
        print(f"Reason: {exc!r}")