kaizen9 committed on
Commit b58f295
1 Parent(s): 30fcc52

Update attention.py

Files changed (1)
  1. attention.py +58 -2
attention.py CHANGED
@@ -78,7 +78,7 @@ def check_valid_inputs(*tensors: torch.Tensor, valid_dtypes: Optional[List[torch
            raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
        if not tensor.is_cuda:
            raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
-
+'''
def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    try:
        from flash_attn import bert_padding, flash_attn_interface
@@ -123,7 +123,63 @@ def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
    return (output, None, past_key_value)
-
+'''
+def flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: int, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False, should_repeat_kv_for_gqa: Optional[bool]=True, sliding_window_size: int=-1, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
+    if key_padding_mask is not None:
+        raise ValueError('key_padding_mask should be None for flash attn.')
+    del key_padding_mask
+    if flash_attn_padding_info is None:
+        raise ValueError('flash_attn_padding_info is required for flash attn.')
+    try:
+        from flash_attn import bert_padding, flash_attn_interface
+    except:
+        raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.6')
+    check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+        past_key_value = (key, value)
+    if attn_bias is not None:
+        raise NotImplementedError(f'attn_bias not implemented for flash attn.')
+    (batch_size, seqlen) = query.shape[:2]
+    indices_q = flash_attn_padding_info['indices_q']
+    indices_k = flash_attn_padding_info['indices_k']
+    indices_v = flash_attn_padding_info['indices_v']
+    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
+    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
+    max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
+    max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
+    query_unpad = bert_padding.index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), indices_q)
+    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+    key_unpad = bert_padding.index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), indices_k)
+    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    value_unpad = bert_padding.index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), indices_v)
+    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    if kv_n_heads < n_heads and (not is_flash_v2_installed()) and (not should_repeat_kv_for_gqa):
+        raise ValueError('For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.')
+    if should_repeat_kv_for_gqa:
+        if kv_n_heads == 1:
+            key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
+            value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
+        elif kv_n_heads < n_heads:
+            key_unpad = repeat_kv_for_gqa(key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
+            value_unpad = repeat_kv_for_gqa(value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)
+    dropout_p = dropout_p if training else 0.0
+    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
+    if is_flash_v1_installed():
+        output_unpad = flash_attn_interface.flash_attn_unpadded_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
+    elif is_flash_v2_installed():
+        alibi_kwargs = {}
+        if check_alibi_support('flash'):
+            alibi_kwargs = {'alibi_slopes': alibi_slopes}
+        elif alibi_slopes is not None:
+            raise ValueError('alibi_slopes is only supported for flash-attn>=2.4.2')
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(q=query_unpad, k=key_unpad, v=value_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, window_size=(sliding_window_size, sliding_window_size), **alibi_kwargs)
+    else:
+        raise RuntimeError('flash-attn==1.0.9 or flash-attn==2.4.2 is required.')
+    output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
+    return (output, None, past_key_value)
def triton_flash_attn_fn(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int]=None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, softmax_scale: Optional[float]=None, attn_bias: Optional[torch.Tensor]=None, key_padding_mask: Optional[torch.Tensor]=None, is_causal: bool=False, dropout_p: float=0.0, training: bool=False, needs_weights: bool=False, multiquery: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    try:
        from .flash_attn_triton import flash_attn_func
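
For context, a minimal sketch of how the flash_attn_padding_info argument required by the new flash_attn_fn might be built. The dict keys (indices_q/k/v, cu_seqlens_q/k, max_seqlen_q/k) mirror the lookups in the function above; the helper name make_flash_attn_padding_info and the example values are hypothetical, and the sketch assumes queries and keys share a single (batch, seqlen) attention mask (i.e. no cached key/value prefix).

import torch
import torch.nn.functional as F

def make_flash_attn_padding_info(attention_mask: torch.Tensor) -> dict:
    # attention_mask: (batch, seqlen) bool, True where a token is real (not padding).
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the non-padding positions in the flattened (batch * seqlen) layout.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence lengths with a leading 0, as the varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return {'indices_q': indices, 'indices_k': indices, 'indices_v': indices,
            'cu_seqlens_q': cu_seqlens, 'cu_seqlens_k': cu_seqlens,
            'max_seqlen_q': max_seqlen, 'max_seqlen_k': max_seqlen}

# Example: two sequences of lengths 3 and 5, padded to seqlen 5.
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])
info = make_flash_attn_padding_info(mask)
# info['cu_seqlens_q'] -> tensor([0, 3, 8], dtype=torch.int32); info['max_seqlen_q'] -> 5
# The dict is then passed straight through, e.g.:
#   flash_attn_fn(query, key, value, n_heads, kv_n_heads,
#                 is_causal=True, flash_attn_padding_info=info)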