import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

@functools.cache  # cache the result so each head size is only probed once
def test_xformers_backwards(size):
    """Return True if xformers memory-efficient attention supports a backward
    pass for the given head dimension on this GPU."""

    def _grad(size):
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception:
        print(size, "fail")
        return False

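# Standalone sketch (an assumption, not part of the module: it presumes a CUDA
# device and xformers are present): probing a few common attention head sizes
# directly, e.g. those used by Stable Diffusion UNets.
#
#     for dim_head in (40, 64, 80, 160):
#         test_xformers_backwards(dim_head)
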

def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored anywhere, so back-calculate it from the
            # value-projection width and the number of heads
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)
            # If the backward pass fails for this head size, turn xformers off
            # for this block
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    if valid:
        fn_test_dim_head(module)
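

# Usage sketch (an assumption, not part of this module): apply the helper to a
# diffusers model whose transformer blocks may contain LoraInjectedLinear
# layers. The checkpoint name below is illustrative only.
#
#     from diffusers import UNet2DConditionModel
#
#     unet = UNet2DConditionModel.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", subfolder="unet"
#     ).to("cuda")
#     # Enable xformers attention, then probe each BasicTransformerBlock and
#     # disable it wherever the backward pass is unsupported for its head size.
#     set_use_memory_efficient_attention_xformers(unet, True)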