lukasmoeller committed
Commit 06c1397
1 Parent(s): 9c0f1cf

Remove flash attn support

Files changed (1)
  1. attention.py +4 -71
attention.py CHANGED
@@ -58,59 +58,6 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
-def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
-    try:
-        from flash_attn import bert_padding, flash_attn_interface
-    except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0')
-    check_valid_inputs(query, key, value)
-    if attn_bias is not None:
-        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
-    (batch_size, seqlen) = query.shape[:2]
-    if key_padding_mask is None:
-        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-    query_padding_mask = key_padding_mask[:, -query.size(1):]
-    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
-    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
-    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
-    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
-    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
-    if multiquery:
-        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
-        value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
-    dropout_p = dropout_p if training else 0.0
-    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
-    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-    return (output, None)
-
-def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
-    try:
-        from flash_attn import flash_attn_triton
-    except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
-    check_valid_inputs(query, key, value)
-    if dropout_p:
-        raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
-    if needs_weights:
-        raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
-    if key_padding_mask is not None:
-        warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
-        (b_size, s_k) = key_padding_mask.shape[:2]
-        if attn_bias is None:
-            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
-        attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
-    query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
-    key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
-    value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
-    if multiquery:
-        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
-        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
-    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
-    output = attn_output.view(*attn_output.shape[:2], -1)
-    return (output, None)
 
 class MultiheadAttention(nn.Module):
     """Multi-head self attention.
@@ -137,12 +84,7 @@ class MultiheadAttention(nn.Module):
             layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
             self.q_ln = layernorm_class(self.d_model, device=device)
             self.k_ln = layernorm_class(self.d_model, device=device)
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'triton':
-            self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
-        elif self.attn_impl == 'torch':
+        if self.attn_impl == 'torch':
             self.attn_fn = scaled_multihead_dot_product_attention
             if torch.cuda.is_available():
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
@@ -197,12 +139,7 @@ class MultiQueryAttention(nn.Module):
             layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
             self.q_ln = layernorm_class(d_model, device=device)
             self.k_ln = layernorm_class(self.head_dim, device=device)
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'triton':
-            self.attn_fn = triton_flash_attn_fn
-            warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
-        elif self.attn_impl == 'torch':
+        if self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
             if torch.cuda.is_available():
                 warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
@@ -232,9 +169,7 @@
         return (self.out_proj(context), attn_weights, past_key_value)
 
 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
-    if attn_impl == 'flash':
-        return None
-    elif attn_impl in ['torch', 'triton']:
+    if attn_impl in ['torch', 'triton']:
         if alibi:
             if (prefix_lm or not causal) or use_sequence_id:
                 return (1, n_heads, seq_len, seq_len)
@@ -246,9 +181,7 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
         raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
 def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
-    if attn_impl == 'flash':
-        return None
-    elif attn_impl in ['torch', 'triton']:
+    if attn_impl in ['torch', 'triton']:
         if alibi:
             (device, dtype) = (attn_bias.device, attn_bias.dtype)
             attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
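Note on what is being removed: flash_attn_fn above follows an unpad → variable-length kernel → re-pad pattern, delegating the token bookkeeping to bert_padding.unpad_input / pad_input and the kernel itself to flash_attn_interface.flash_attn_unpadded_func. Below is a rough, self-contained sketch of just that packing step in plain torch; the unpad/repad helpers are made up for illustration and are not part of this repo or of the flash-attn package.

import torch

# Gather the non-padded tokens of a (batch, seq, dim) tensor into a packed
# (nnz, dim) tensor, remembering their flat indices so the kernel output can
# be scattered back to (batch, seq, dim) afterwards.
def unpad(x, padding_mask):
    (b, s, d) = x.shape
    indices = torch.nonzero(padding_mask.reshape(-1), as_tuple=False).flatten()
    return (x.reshape(b * s, d)[indices], indices)

def repad(x_packed, indices, batch_size, seqlen):
    out = x_packed.new_zeros(batch_size * seqlen, x_packed.size(-1))
    out[indices] = x_packed
    return out.reshape(batch_size, seqlen, -1)

x = torch.randn(2, 4, 8)
mask = torch.tensor([[True, True, True, False], [True, True, False, False]])
(packed, idx) = unpad(x, mask)       # packed.shape == (5, 8): only real tokens
restored = repad(packed, idx, 2, 4)  # padded positions come back as zeros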
 
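With both functions gone, the only attention implementation these modules still dispatch to is scaled_multihead_dot_product_attention (the `attn_impl: 'torch'` branch kept in the hunks above). That function is not shown in this diff; as a rough sketch of what such a torch path computes, assuming the same (batch, seq, n_heads * head_dim) layout and additive attn_bias convention used by the removed functions:

import math
import torch
from einops import rearrange

# Illustrative stand-in, not the repo's scaled_multihead_dot_product_attention:
# plain scaled dot-product attention with an optional additive bias such as
# the ALiBi/padding bias assembled by build_attn_bias.
def simple_torch_attn(query, key, value, n_heads, softmax_scale=None, attn_bias=None):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(q.size(-1))
    attn_weight = q.matmul(k) * softmax_scale    # (b, h, s_q, s_k)
    if attn_bias is not None:
        attn_weight = attn_weight + attn_bias    # bias shapes broadcast, cf. attn_bias_shape
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return rearrange(attn_weight.matmul(v), 'b h s d -> b s (h d)')

out = simple_torch_attn(torch.randn(2, 16, 512), torch.randn(2, 16, 512), torch.randn(2, 16, 512), n_heads=8)  # (2, 16, 512)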