feat: add flash attention v2

#9
Files changed (2)
  1. README.md +30 -1
  2. modeling_stablelm_epoch.py +246 -19
README.md CHANGED
@@ -36,7 +36,7 @@ model = AutoModelForCausalLM.from_pretrained(
   torch_dtype="auto",
 )
 model.cuda()
-inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to("cuda")
+inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
 tokens = model.generate(
   **inputs,
   max_new_tokens=64,
@@ -47,6 +47,35 @@ tokens = model.generate(
 print(tokenizer.decode(tokens[0], skip_special_tokens=True))
 ```
 
+### Run with Flash Attention 2 ⚡️
+
+<details>
+<summary> Click to expand </summary>
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")
+model = AutoModelForCausalLM.from_pretrained(
+  "stabilityai/stablelm-3b-4e1t",
+  trust_remote_code=True,
+  torch_dtype="auto",
+  use_flash_attention_2=True,
+)
+model.cuda()
+inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
+tokens = model.generate(
+  **inputs,
+  max_new_tokens=64,
+  temperature=0.75,
+  top_p=0.95,
+  do_sample=True,
+)
+print(tokenizer.decode(tokens[0], skip_special_tokens=True))
+```
+
+</details>
+
+
 ## Model Details
 
 * **Developed by**: [Stability AI](https://stability.ai/)
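
Note on how the README flag above reaches the new code path: in recent `transformers` releases, `use_flash_attention_2=True` passed to `from_pretrained` is resolved to `config._attn_implementation = "flash_attention_2"`, and the `ATTENTION_CLASSES` mapping added in `modeling_stablelm_epoch.py` below uses that value to pick `FlashAttention2` instead of the eager `Attention`. A minimal sketch of that dispatch pattern follows; the classes here are stand-ins, not the PR's code, and only the names `ATTENTION_CLASSES`, `Attention`, `FlashAttention2`, and `config._attn_implementation` come from the diff.

```python
# Minimal dispatch sketch with stand-in classes; mirrors the pattern in the diff below.
class Attention:
    """Stand-in for the eager attention module."""
    def __init__(self, config):
        self.config = config


class FlashAttention2(Attention):
    """Stand-in for the Flash Attention 2 variant; same interface, different kernels."""


ATTENTION_CLASSES = {
    "eager": Attention,
    "flash_attention_2": FlashAttention2,
}


class DecoderLayer:
    def __init__(self, config):
        # Configs without _attn_implementation fall back to the eager path in this sketch.
        attn_cls = ATTENTION_CLASSES[getattr(config, "_attn_implementation", "eager")]
        self.self_attn = attn_cls(config=config)
```

The real `DecoderLayer.__init__` in the diff indexes `ATTENTION_CLASSES[config._attn_implementation]` directly, so it relies on `transformers` having set that attribute during `from_pretrained`.
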
modeling_stablelm_epoch.py CHANGED
@@ -19,23 +19,46 @@
 """ PyTorch StableLM Epoch model. """
 from typing import Optional, Tuple, Union
 import math
+import warnings
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
+from transformers.utils import logging, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
+
 from .configuration_stablelm_epoch import StableLMEpochConfig
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 logger = logging.get_logger(__name__)
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size,
@@ -165,6 +188,7 @@ class Attention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.is_causal = True
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -269,10 +293,202 @@ class Attention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+class FlashAttention2(Attention):
+    """
+    Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        query_rot = query_states[..., : self.rotary_ndims]
+        query_pass = query_states[..., self.rotary_ndims :]
+        key_rot = key_states[..., : self.rotary_ndims]
+        key_pass = key_states[..., self.rotary_ndims :]
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+        # [batch_size, num_heads, seq_len, head_dim]
+        query_states = torch.cat((query_states, query_pass), dim=-1)
+        key_states = torch.cat((key_states, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            # Reuse k, v, self_attention
+            key_states = torch.cat((past_key_value[0], key_states), dim=2)
+            value_states = torch.cat((past_key_value[1], value_states), dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+ATTENTION_CLASSES = {
+    "eager": Attention,
+    "flash_attention_2": FlashAttention2,
+}
+
+
 class DecoderLayer(nn.Module):
     def __init__(self, config: StableLMEpochConfig):
         super().__init__()
-        self.self_attn = Attention(config)
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
         self.mlp = MLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
@@ -328,6 +544,7 @@ class StableLMEpochPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""
@@ -355,6 +572,7 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
        self.post_init()
@@ -428,10 +646,6 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
@@ -447,18 +661,22 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # Embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past),
-                dtype=torch.bool,
-                device=inputs_embeds.device,
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past),
+                    dtype=torch.bool,
+                    device=inputs_embeds.device,
+                )
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
 
         hidden_states = inputs_embeds
 
@@ -643,8 +861,17 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
         **kwargs,
     ):
        # Trim decoder_input_ids if past is used
-        if past_key_values and past_key_values[0] is not None:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
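
For local verification, a hedged smoke test of the new path. It assumes a CUDA GPU, a `flash-attn` 2.x install, and a `transformers` version that accepts `use_flash_attention_2`; the `model.model.layers[0].self_attn` attribute path is an assumption based on the class layout in this diff, not something the diff itself shows.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load with the flag added to the README in this PR (assumes flash-attn 2.x is installed).
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # flash-attn kernels require fp16/bf16
    use_flash_attention_2=True,
)
model.cuda()

# Assumed attribute path: ForCausalLM -> base model -> DecoderLayer.self_attn.
# With the dispatch added in this PR, the selected class should be FlashAttention2.
print(type(model.model.layers[0].self_attn).__name__)

inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
tokens = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
```

On newer `transformers` releases the boolean flag is deprecated in favor of `attn_implementation="flash_attention_2"`, which sets the same `config._attn_implementation` value the modeling code reads.
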