jaygala24 committed
Commit e58fe12
1 Parent(s): cbb5f4c

Flash Attention.

config.json CHANGED
@@ -41,5 +41,6 @@
   "share_decoder_input_output_embed": false,
   "torch_dtype": "float32",
   "transformers_version": "4.32.1",
-  "use_cache": true
+  "use_cache": true,
+  "attn_implementation": "eager"
 }
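
Usage note (not part of the commit): with `attn_implementation` now recorded in the config, the attention backend can also be picked at load time. A minimal sketch, assuming a transformers release recent enough to accept the `attn_implementation` argument, a hypothetical checkpoint id, and that the custom IndicTrans code is loaded with `trust_remote_code=True`; `flash_attention_2` additionally needs a CUDA device, the `flash-attn` package, and fp16/bf16 weights:

import torch
from transformers import AutoModelForSeq2SeqLM

# Hypothetical repo id, used only for illustration.
model_id = "ai4bharat/indictrans2-en-indic-1B"

# Pick one of "eager" (the default set here), "sdpa", or "flash_attention_2".
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    trust_remote_code=True,       # the IndicTrans modeling code lives in the repo
    torch_dtype=torch.float16,    # flash-attention kernels require fp16/bf16
    attn_implementation="flash_attention_2",
)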
configuration_indictrans.py CHANGED
@@ -118,6 +118,7 @@ class IndicTransConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
+        attn_implementation="eager",
         **kwargs,
     ):
         self.encoder_vocab_size = encoder_vocab_size
@@ -146,7 +147,8 @@ class IndicTransConfig(PretrainedConfig):
         self.num_hidden_layers = encoder_layers
         self.scale_embedding = scale_embedding
         self.share_decoder_input_output_embed = share_decoder_input_output_embed
-
+        self.attn_implementation = attn_implementation
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
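
For orientation (not part of the commit): the config now stores the requested backend as a plain attribute, and the modeling file dispatches on the implementation name. A minimal sketch, assuming the file above is importable locally; how the public `attn_implementation` value reaches the private `_attn_implementation` read by the modeling code is handled by `PretrainedConfig` and depends on the transformers version:

from configuration_indictrans import IndicTransConfig

# Request the PyTorch SDPA backend instead of the default "eager" attention.
config = IndicTransConfig(attn_implementation="sdpa")
print(config.attn_implementation)  # -> "sdpa"

# modeling_indictrans.py then selects the matching attention class, roughly:
#   INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](embed_dim=..., num_heads=..., config=config)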
modeling_indictrans.py CHANGED
@@ -23,15 +23,28 @@ import torch.nn as nn
 from torch.nn import functional as F
 
 from transformers.activations import ACT2FN
+
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
-    Seq2SeqModelOutput,
+    Seq2SeqModelOutput
 )
 
-from transformers.utils import logging
+from transformers.utils import (
+    logging,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+)
+
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_indictrans import IndicTransConfig
@@ -39,10 +52,25 @@ from .configuration_indictrans import IndicTransConfig
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "IndicTransConfig"
-
 INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
 
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
 def shift_tokens_right(
@@ -63,54 +91,6 @@ def shift_tokens_right(
     return shifted_input_ids
 
 
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size,
-    dtype: torch.dtype,
-    device: torch.device,
-    past_key_values_length: int = 0,
-):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat(
-            [
-                torch.zeros(
-                    tgt_len, past_key_values_length, dtype=dtype, device=device
-                ),
-                mask,
-            ],
-            dim=-1,
-        )
-    return mask[None, None, :, :].expand(
-        bsz, 1, tgt_len, tgt_len + past_key_values_length
-    )
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(
-        inverted_mask.to(torch.bool), torch.finfo(dtype).min
-    )
-
-
 def create_position_ids_from_input_ids(
     input_ids, padding_idx, past_key_values_length=0
 ):
@@ -247,12 +227,15 @@ class IndicTransAttention(nn.Module):
         dropout: float = 0.0,
         is_decoder: bool = False,
         bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[IndicTransConfig] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
+        self.config = config
 
         if (self.head_dim * num_heads) != self.embed_dim:
             raise ValueError(
@@ -261,6 +244,7 @@ class IndicTransAttention(nn.Module):
             )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
+        self.is_causal = is_causal
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -402,17 +386,345 @@ class IndicTransAttention(nn.Module):
         attn_output = self.out_proj(attn_output)
 
         return attn_output, attn_weights_reshaped, past_key_value
+
+
+class IndicTransFlashAttention2(IndicTransAttention):
+    """
+    IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # IndicTransFlashAttention2 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0].transpose(1, 2)
+            value_states = past_key_value[1].transpose(1, 2)
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+        else:
+            # self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+class IndicTransSdpaAttention(IndicTransAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz)
+
+        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+        attn_output = F.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+        )
+
+        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+INDICTRANS_ATTENTION_CLASSES = {
+    "eager": IndicTransAttention,
+    "sdpa": IndicTransSdpaAttention,
+    "flash_attention_2": IndicTransFlashAttention2,
+}
 
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
 class IndicTransEncoderLayer(nn.Module):
     def __init__(self, config: IndicTransConfig):
         super().__init__()
         self.embed_dim = config.encoder_embed_dim
-        self.self_attn = IndicTransAttention(
+        self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
+            config=config,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -490,22 +802,25 @@ class IndicTransDecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.decoder_embed_dim
 
-        self.self_attn = IndicTransAttention(
+        self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
+           is_causal=True,
+           config=config,
        )
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
 
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.encoder_attn = IndicTransAttention(
+        self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
+           config=config,
        )
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
@@ -693,6 +1008,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
         )
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -782,10 +1100,18 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             hidden_states = self.layernorm_embedding(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
 
-        # expand attention_mask
         if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+            if self._use_flash_attention_2:
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self._use_sdpa and head_mask is None and not output_attentions:
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -909,6 +1235,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
             nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
         )
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -1031,29 +1360,43 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+            # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
                 input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
+                inputs_embeds,
+                past_key_values_length,
             )
-
-        if attention_mask is not None and combined_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-            )
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
 
         # embed positions
         positions = self.embed_positions(
@@ -1124,7 +1467,7 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
                 layer_outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(decoder_layer),
                     hidden_states,
-                    combined_attention_mask,
+                    attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
@@ -1136,7 +1479,7 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=combined_attention_mask,
+                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(
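
Illustration only (not part of the commit): `_get_unpad_data` above supplies the token indices and cumulative sequence lengths that `flash_attn_varlen_func` expects for padded batches. A toy run on a 2x4 padding mask shows what it computes:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])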