import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

GRADIENT_CHECKPOINTING = True
TEXT_ENCODER_GRADIENT_CHECKPOINTING = True
ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION = True
ENABLE_TORCH_2_ATTN = True

def is_attn(name):
    # True if the (sub)module name is a transformer block's self-attention ('attn1')
    # or cross-attention ('attn2') layer.
    return name.split('.')[-1] in ('attn1', 'attn2')

def unet_and_text_g_c(unet, text_encoder, unet_enable=GRADIENT_CHECKPOINTING, text_enable=TEXT_ENCODER_GRADIENT_CHECKPOINTING):
    # Toggle gradient checkpointing on the UNet and the CLIP text encoder to trade compute for memory.
    if unet_enable: unet.enable_gradient_checkpointing()
    else: unet.disable_gradient_checkpointing()
    if text_enable: text_encoder.gradient_checkpointing_enable()
    else: text_encoder.gradient_checkpointing_disable()

def set_processors(attentions):
    for attn in attentions: attn.set_processor(AttnProcessor2_0())

def set_torch_2_attn(unet):
    optim_count = 0

    for name, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock):
            # Each BasicTransformerBlock holds a self-attention (attn1) and a
            # cross-attention (attn2) layer; route both through the PyTorch 2.0
            # scaled dot-product attention processor.
            set_processors([m for n, m in module.named_children() if is_attn(n)])
            optim_count += 1

    if optim_count > 0:
        print(f"{optim_count} attention blocks using scaled dot product attention.")

def handle_memory_attention(
        unet,
        enable_xformers_memory_efficient_attention=ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION, 
        enable_torch_2_attn=ENABLE_TORCH_2_ATTN
    ):
    try:
        # Prefer PyTorch 2.0's native scaled dot-product attention when it is available.
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
        enable_torch_2 = is_torch_2 and enable_torch_2_attn

        if enable_xformers_memory_efficient_attention and not enable_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if enable_torch_2:
            set_torch_2_attn(unet)

    except Exception as e:
        print(f"Could not enable memory efficient attention for xformers or Torch 2.0: {e}")
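# Usage sketch: these helpers expect the UNet and text encoder of a diffusers
# text-to-video pipeline. The checkpoint name below is only an illustrative
# assumption; any pipeline exposing `unet` and `text_encoder` works the same way.
#
#   from diffusers import DiffusionPipeline
#
#   pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
#   handle_memory_attention(pipe.unet)               # xformers or Torch 2.0 SDPA
#   unet_and_text_g_c(pipe.unet, pipe.text_encoder)  # gradient checkpointing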