Update attention.py

attention.py: CHANGED (+58 -2)
@@ -78,7 +78,7 @@ def check_valid_inputs(*tensors: torch.Tensor, valid_dtypes: Optional[List[torch
             raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
-
+'''
 def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
     try:
         from flash_attn import bert_padding, flash_attn_interface
@@ -123,7 +123,63 @@ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n
     output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
     output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
     return (output, None, past_key_value)
-
+'''
+def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False, should_repeat_kv_for_gqa: Optional[bool]=True, sliding_window_size: int=-1, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
+    if key_padding_mask is not None:
+        raise ValueError('key_padding_mask should be None for flash attn.')
+    del key_padding_mask
+    if flash_attn_padding_info is None:
+        raise ValueError('flash_attn_padding_info is required for flash attn.')
+    try:
+        from flash_attn import bert_padding, flash_attn_interface
+    except:
+        raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.6')
+    check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+        past_key_value = (key, value)
+    if attn_bias is not None:
+        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
+    (batch_size, seqlen) = query.shape[:2]
+    indices_q = flash_attn_padding_info['indices_q']
+    indices_k = flash_attn_padding_info['indices_k']
+    indices_v = flash_attn_padding_info['indices_v']
+    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
+    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
+    max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
+    max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
+    query_unpad = bert_padding.index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), indices_q)
+    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+    key_unpad = bert_padding.index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), indices_k)
+    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    value_unpad = bert_padding.index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), indices_v)
+    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    if kv_n_heads < n_heads and (not is_flash_v2_installed()) and (not should_repeat_kv_for_gqa):
+        raise ValueError('For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.')
+    if should_repeat_kv_for_gqa:
+        if kv_n_heads == 1:
+            key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
+            value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
+        elif kv_n_heads < n_heads:
+            key_unpad = repeat_kv_for_gqa(key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
+            value_unpad = repeat_kv_for_gqa(value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
+    dropout_p = dropout_p if training else 0.0
+    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
+    if is_flash_v1_installed():
+        output_unpad = flash_attn_interface.flash_attn_unpadded_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
+    elif is_flash_v2_installed():
+        alibi_kwargs = {}
+        if check_alibi_support('flash'):
+            alibi_kwargs = {'alibi_slopes': alibi_slopes}
+        elif alibi_slopes is not None:
+            raise ValueError('alibi_slopes is only supported for flash-attn>=2.4.2')
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, window_size=(sliding_window_size, sliding_window_size), **alibi_kwargs)
+    else:
+        raise RuntimeError('flash-attn==1.0.9 or flash-attn==2.4.2 is required.')
+    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
+    return (output, None, past_key_value)
 def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
     try:
         from .flash_attn_triton import flash_attn_func
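
To illustrate how the new flash_attn_padding_info argument might be populated, below is a minimal sketch (plain PyTorch, not part of this diff) that derives the keys the new flash_attn_fn reads (indices_q/indices_k/indices_v, cu_seqlens_q/cu_seqlens_k, max_seqlen_q/max_seqlen_k) from a boolean attention mask and then calls the function. The helper name gen_flash_attn_padding_info_from_mask, the import path, and the (batch, seqlen, n_heads * head_dim) input shapes are assumptions for illustration only.

# Hypothetical sketch: building flash_attn_padding_info for the new flash_attn_fn.
# Assumes q/k/v are shaped (batch, seqlen, n_heads * head_dim) and attention_mask is
# a (batch, seqlen) bool tensor with True marking real (non-padding) tokens.
import torch

def gen_flash_attn_padding_info_from_mask(attention_mask: torch.Tensor) -> dict:
    # Per-sequence token counts, e.g. mask [[1,1,1,0],[1,1,0,0]] -> [3, 2].
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of non-padding positions in the (batch * seqlen) view; these
    # drive bert_padding.index_first_axis inside flash_attn_fn.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence lengths with a leading 0, as flash-attn's varlen kernels expect.
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    # Queries, keys and values share the same padding layout in this self-attention sketch.
    return {
        'indices_q': indices, 'indices_k': indices, 'indices_v': indices,
        'cu_seqlens_q': cu_seqlens, 'cu_seqlens_k': cu_seqlens,
        'max_seqlen_q': max_seqlen, 'max_seqlen_k': max_seqlen,
    }

# Illustrative call into the new signature (needs a CUDA device, flash-attn installed,
# and flash_attn_fn imported from this module; the module path here is hypothetical).
if torch.cuda.is_available():
    batch, seqlen, n_heads, kv_n_heads, head_dim = 2, 8, 8, 2, 64
    device, dtype = 'cuda', torch.bfloat16
    query = torch.randn(batch, seqlen, n_heads * head_dim, device=device, dtype=dtype)
    key = torch.randn(batch, seqlen, kv_n_heads * head_dim, device=device, dtype=dtype)
    value = torch.randn(batch, seqlen, kv_n_heads * head_dim, device=device, dtype=dtype)
    mask = torch.ones(batch, seqlen, dtype=torch.bool, device=device)
    padding_info = gen_flash_attn_padding_info_from_mask(mask)
    output, _, _ = flash_attn_fn(
        query=query, key=key, value=value,
        n_heads=n_heads, kv_n_heads=kv_n_heads,
        is_causal=True, flash_attn_padding_info=padding_info,
    )

Presumably the reason for passing this dictionary in, rather than recomputing it inside the attention wrapper, is that the unpadding indices and cumulative sequence lengths depend only on the padding mask, so a caller can build them once per batch and reuse them across every attention layer.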
|