`README.md` script error
Hi, first off, thanks for putting this model out!
I'd like to use it for my research, but I'm having trouble getting the base example script from the README.md to run. I've included the error below. Notably, I'm having trouble:
- changing the `attn_bias` to the proper type
- changing the model to `torch.float16` without running into other errors
I am using the following packages:
```
pytorch           2.4.0          py3.10_cuda12.4_cudnn9.1.0_0    pytorch
pytorch-cuda      12.4           hc786d27_6                      pytorch
xformers          0.0.27.post2   pypi_0                          pypi
```
Please also let me know if there's any other documentation I might be missing!
Original error:
```
NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(4, 75, 12, 64) (torch.float32)
     key         : shape=(4, 75, 12, 64) (torch.float32)
     value       : shape=(4, 75, 12, 64) (torch.float32)
     attn_bias   : <class 'torch.Tensor'>
     p           : 0
`[email protected]` is not supported because:
    device=cpu (supported: {'cuda'})
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`cutlassF-pt` is not supported because:
    device=cpu (supported: {'cuda'})
    attn_bias.stride(-2) % 4 != 0 (attn_bias.stride() = (75, 0, 0, 1))
    HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`
`smallkF` is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 32
    device=cpu (supported: {'cuda'})
    unsupported embed per head: 64
```
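In case it helps clarify what I've tried: if I understand the HINT correctly, it wants the bias allocated from a wider buffer and then sliced back, something like this (a minimal sketch; `seq_len` here is just my sequence length, not code from the README):

```python
import torch

seq_len = 75  # assumed sequence length; not a multiple of 8
aligned = (seq_len + 7) // 8 * 8  # round up to the next multiple of 8 (80)

# Allocate an aligned buffer and slice back down, as the HINT suggests.
# The sliced view is seq_len wide, but attn_bias.stride(-2) stays 80,
# which satisfies the stride alignment check.
attn_bias = torch.zeros(1, 1, seq_len, aligned)[:, :, :, :seq_len]
```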
Hi @reaganjlee, thanks for pointing out this issue! Since this model uses `xops.memory_efficient_attention`, it requires all tensors to be on the GPU and the input sequence to be padded to a multiple of 8. I have modified the example code in the README accordingly. Looking forward to your reply, and thanks for your interest in this work!
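For reference, the required changes look roughly like this (a minimal sketch, not the exact README snippet; it assumes the model loads through Hugging Face transformers, and `MODEL_NAME` and the input text are placeholders):

```python
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = "..."  # placeholder: use the checkpoint named in the README

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# xops.memory_efficient_attention only supports CUDA tensors in
# fp16/bf16, so load in half precision and move the model to the GPU.
model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).cuda()

# Pad inputs to a multiple of 8 so the attention bias memory is aligned.
inputs = tokenizer(
    ["example input text"],
    padding=True,
    pad_to_multiple_of=8,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    outputs = model(**inputs)
```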
@dwzhu I've verified that the new changes work properly on my end. Thank you for the help and for putting this model out there!