lukasmoeller
committed
Commit 06c1397
Parent(s): 9c0f1cf

Remove flash attn support

attention.py  CHANGED  (+4 -71)

@@ -58,59 +58,6 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
-def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
-    try:
-        from flash_attn import bert_padding, flash_attn_interface
-    except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0')
-    check_valid_inputs(query, key, value)
-    if attn_bias is not None:
-        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
-    (batch_size, seqlen) = query.shape[:2]
-    if key_padding_mask is None:
-        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-    query_padding_mask = key_padding_mask[:, -query.size(1):]
-    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
-    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
-    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
-    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
-    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
-    if multiquery:
-        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
-        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
-    dropout_p = dropout_p if training else 0.0
-    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
-    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-    return (output, None)
-
-def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
-    try:
-        from flash_attn import flash_attn_triton
-    except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
-    check_valid_inputs(query, key, value)
-    if dropout_p:
-        raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
-    if needs_weights:
-        raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
-    if key_padding_mask is not None:
-        warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
-        (b_size, s_k) = key_padding_mask.shape[:2]
-        if attn_bias is None:
-            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
-        attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
-    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
-    key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
-    value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
-    if multiquery:
-        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
-        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
-    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
-    output = attn_output.view(*attn_output.shape[:2], -1)
-    return (output, None)
-
 class MultiheadAttention(nn.Module):
     """Multi-head self attention.
@@ -137,12 +84,7 @@ class MultiheadAttention(nn.Module):
         layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
         self.q_ln = layernorm_class(self.d_model, device=device)
         self.k_ln = layernorm_class(self.d_model, device=device)
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'triton':
-            self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
-        elif self.attn_impl == 'torch':
+        if self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
             if torch.cuda.is_available():
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
@@ -197,12 +139,7 @@ class MultiQueryAttention(nn.Module):
         layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
         self.q_ln = layernorm_class(d_model, device=device)
         self.k_ln = layernorm_class(self.head_dim, device=device)
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'triton':
-            self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
-        elif self.attn_impl == 'torch':
+        if self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available():
                warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
@@ -232,9 +169,7 @@ class MultiQueryAttention(nn.Module):
         return (self.out_proj(context), attn_weights, past_key_value)
 
 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
-    if attn_impl == 'flash':
-        return None
-    elif attn_impl in ['torch', 'triton']:
+    if attn_impl in ['torch', 'triton']:
         if alibi:
             if (prefix_lm or not causal) or use_sequence_id:
                 return (1, n_heads, seq_len, seq_len)
@@ -246,9 +181,7 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
     raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
 def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
-    if attn_impl == 'flash':
-        return None
-    elif attn_impl in ['torch', 'triton']:
+    if attn_impl in ['torch', 'triton']:
         if alibi:
             (device, dtype) = (attn_bias.device, attn_bias.dtype)
             attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
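With the flash branch gone, attn_bias_shape and build_attn_bias no longer short-circuit with an early return None; only the 'torch' and 'triton' paths remain. Below is a minimal usage sketch of the two helpers as they stand after this commit. The argument values, the import line, and the assumption that build_attn_bias returns the filled-in bias tensor (suggested by the attn_bias.add(...) line above, but not shown in this diff) are illustrative, not taken from the commit.

import torch
from attention import attn_bias_shape, build_attn_bias  # hypothetical import of the module edited here

n_heads, seq_len = 8, 128
# For attn_impl 'torch' with alibi enabled, a bias shape is returned rather than None
# (the flash-specific early exit no longer exists).
shape = attn_bias_shape('torch', n_heads, seq_len, alibi=True, prefix_lm=False, causal=True, use_sequence_id=False)
attn_bias = None
if shape is not None:
    # Allocate the bias and let build_attn_bias add the ALiBi term into it.
    attn_bias = torch.zeros(shape, dtype=torch.float16)
    attn_bias = build_attn_bias('torch', attn_bias, n_heads, seq_len, causal=True, alibi=True, alibi_bias_max=8)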
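After this change, MultiheadAttention and MultiQueryAttention only wire up scaled_multihead_dot_product_attention, whose body is not part of this diff. For orientation, here is a minimal sketch of a plain PyTorch scaled-dot-product attention over the same packed [batch, seq, n_heads * head_dim] layout; the function name and signature below are illustrative assumptions, not the repository's implementation.

import math
import torch
from einops import rearrange

def simple_torch_attention(query, key, value, n_heads, attn_bias=None, is_causal=False, softmax_scale=None):
    # Split the packed projections into heads: [b, s, h*d] -> [b, h, s, d].
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(q.size(-1))
    # Attention scores plus an optional additive bias (padding mask / ALiBi folded into attn_bias).
    scores = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        scores = scores + attn_bias
    if is_causal:
        # Mask out future positions; the diagonal offset handles the case where
        # key is longer than query (e.g. cached past keys).
        s_q, s_k = scores.size(-2), scores.size(-1)
        causal_mask = torch.triu(torch.ones(s_q, s_k, device=scores.device), diagonal=1 + s_k - s_q).bool()
        scores = scores.masked_fill(causal_mask, torch.finfo(scores.dtype).min)
    attn_weights = torch.softmax(scores, dim=-1)
    # Weighted sum over values, then merge heads back to the packed layout.
    out = attn_weights.matmul(v)
    return rearrange(out, 'b h s d -> b s (h d)'), attn_weights

Called with tensors of shape [batch, seq, n_heads * head_dim], this returns the context in the same packed layout plus the attention weights, mirroring the two-tuple return convention of the functions removed above.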