wanzin committed on
Commit
32ddfbe
1 Parent(s): c98e0b1

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +406 -461
modeling_llama.py CHANGED
@@ -17,108 +17,123 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- """PyTorch LLaMA model."""
21
-
22
  import math
23
- import warnings
24
  from typing import List, Optional, Tuple, Union
25
 
26
  import torch
27
  import torch.nn.functional as F
28
  import torch.utils.checkpoint
29
  from torch import nn
30
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
 
 
32
  from transformers.activations import ACT2FN
33
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
 
34
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
35
  from transformers.modeling_outputs import (
36
  BaseModelOutputWithPast,
37
  CausalLMOutputWithPast,
38
  QuestionAnsweringModelOutput,
39
  SequenceClassifierOutputWithPast,
 
40
  )
 
41
  from transformers.modeling_utils import PreTrainedModel
 
42
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
43
  from transformers.utils import (
 
 
44
  add_start_docstrings,
45
  add_start_docstrings_to_model_forward,
46
- is_flash_attn_2_available,
47
  is_flash_attn_greater_or_equal_2_10,
48
  logging,
49
  replace_return_docstrings,
50
  )
51
  from .configuration_llama import LlamaConfig
52
- from transformers.models.llama.modeling_llama import LlamaRMSNorm
53
-
54
-
55
- if is_flash_attn_2_available():
56
- from flash_attn import flash_attn_func, flash_attn_varlen_func
57
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
58
-
59
 
60
 
61
  logger = logging.get_logger(__name__)
62
 
 
63
  _CONFIG_FOR_DOC = "LlamaConfig"
64
 
65
-
66
- def _get_unpad_data(attention_mask):
67
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
- max_seqlen_in_batch = seqlens_in_batch.max().item()
70
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
- return (
72
- indices,
73
- cu_seqlens,
74
- max_seqlen_in_batch,
75
- )
76
-
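For reference, the deleted `_get_unpad_data` helper converted a padding mask into the cumulative-sequence-length metadata expected by flash-attention's varlen kernels; the padding-aware path now comes from `transformers.modeling_flash_attention_utils`, which the new version of this file imports. A minimal sketch of the same computation on a toy mask, using only `torch`:

```python
import torch
import torch.nn.functional as F

# Toy padding mask: batch of 2, second sequence has one padded (0) position.
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

print(indices, cu_seqlens, max_seqlen_in_batch)
```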
77
  ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
78
 
79
 
80
  class LlamaRotaryEmbedding(nn.Module):
81
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
82
  super().__init__()
83
- self.scaling_factor = scaling_factor
84
- self.dim = dim
85
- self.max_position_embeddings = max_position_embeddings
86
- self.base = base
87
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
88
  self.register_buffer("inv_freq", inv_freq, persistent=False)
89
- # For BC we register cos and sin cached
90
- self.max_seq_len_cached = max_position_embeddings
91
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
92
- t = t / self.scaling_factor
93
- freqs = torch.outer(t, self.inv_freq)
94
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
95
- emb = torch.cat((freqs, freqs), dim=-1)
96
- self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
97
- self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
98
-
99
- @property
100
- def sin_cached(self):
101
- logger.warning_once(
102
- "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
103
- "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
104
- )
105
- return self._sin_cached
106
 
107
- @property
108
- def cos_cached(self):
109
- logger.warning_once(
110
- "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
111
- "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
112
- )
113
- return self._cos_cached
114
 
115
  @torch.no_grad()
116
  def forward(self, x, position_ids):
117
- # x: [bs, num_attention_heads, seq_len, head_size]
 
 
 
118
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
119
  position_ids_expanded = position_ids[:, None, :].float()
120
- # Force float32 since bfloat16 loses precision on long contexts
121
- # See https://github.com/huggingface/transformers/pull/29285
122
  device_type = x.device.type
123
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
124
  with torch.autocast(device_type=device_type, enabled=False):
@@ -126,36 +141,37 @@ class LlamaRotaryEmbedding(nn.Module):
126
  emb = torch.cat((freqs, freqs), dim=-1)
127
  cos = emb.cos()
128
  sin = emb.sin()
129
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
130
 
131
 
132
  class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
133
  """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
134
 
135
- def forward(self, x, position_ids):
136
- # difference to the original RoPE: a scaling factor is applied to the position ids
137
- position_ids = position_ids.float() / self.scaling_factor
138
- cos, sin = super().forward(x, position_ids)
139
- return cos, sin
 
 
140
 
141
 
142
  class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
143
  """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
144
 
145
- def forward(self, x, position_ids):
146
- # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
147
- seq_len = torch.max(position_ids) + 1
148
- if seq_len > self.max_position_embeddings:
149
- base = self.base * (
150
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
151
- ) ** (self.dim / (self.dim - 2))
152
- inv_freq = 1.0 / (
153
- base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
154
- )
155
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
156
-
157
- cos, sin = super().forward(x, position_ids)
158
- return cos, sin
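The dynamic NTK branch removed here rescales the RoPE base once the sequence grows past the trained context; the equivalent arithmetic is now selected through `ROPE_INIT_FUNCTIONS` (see the new `LlamaRotaryEmbedding` later in this diff). A standalone sketch of the rescaling with illustrative numbers:

```python
import torch

# Illustrative values only, not tied to any particular checkpoint.
dim, base, max_position_embeddings, scaling_factor = 128, 10000.0, 4096, 2.0
seq_len = 8192  # current sequence length, beyond the trained context

if seq_len > max_position_embeddings:
    # NTK-aware base rescaling, mirroring the removed forward()
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(base, inv_freq.shape)  # a larger base lowers the frequencies, stretching the usable context
```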
159
 
160
 
161
  def rotate_half(x):
@@ -198,9 +214,9 @@ class LlamaMLP(nn.Module):
198
  self.config = config
199
  self.hidden_size = config.hidden_size
200
  self.intermediate_size = config.intermediate_size
201
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
202
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
203
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
204
  self.act_fn = ACT2FN[config.hidden_act]
205
 
206
  def forward(self, x):
@@ -255,51 +271,20 @@ class LlamaAttention(nn.Module):
255
  self.attention_dropout = config.attention_dropout
256
  self.hidden_size = config.hidden_size
257
  self.num_heads = config.num_attention_heads
258
- self.head_dim = self.hidden_size // self.num_heads
259
  self.num_key_value_heads = config.num_key_value_heads
260
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
261
  self.max_position_embeddings = config.max_position_embeddings
262
  self.rope_theta = config.rope_theta
263
  self.is_causal = True
264
 
265
- if (self.head_dim * self.num_heads) != self.hidden_size:
266
- raise ValueError(
267
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
268
- f" and `num_heads`: {self.num_heads})."
269
- )
270
-
271
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
272
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
273
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
274
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
275
- self._init_rope()
276
-
277
- def _init_rope(self):
278
- if self.config.rope_scaling is None:
279
- self.rotary_emb = LlamaRotaryEmbedding(
280
- self.head_dim,
281
- max_position_embeddings=self.max_position_embeddings,
282
- base=self.rope_theta,
283
- )
284
- else:
285
- scaling_type = self.config.rope_scaling["type"]
286
- scaling_factor = self.config.rope_scaling["factor"]
287
- if scaling_type == "linear":
288
- self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
289
- self.head_dim,
290
- max_position_embeddings=self.max_position_embeddings,
291
- scaling_factor=scaling_factor,
292
- base=self.rope_theta,
293
- )
294
- elif scaling_type == "dynamic":
295
- self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
296
- self.head_dim,
297
- max_position_embeddings=self.max_position_embeddings,
298
- scaling_factor=scaling_factor,
299
- base=self.rope_theta,
300
- )
301
- else:
302
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
303
 
304
  def forward(
305
  self,
@@ -310,6 +295,7 @@ class LlamaAttention(nn.Module):
310
  output_attentions: bool = False,
311
  use_cache: bool = False,
312
  cache_position: Optional[torch.LongTensor] = None,
 
313
  **kwargs,
314
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
315
  bsz, q_len, _ = hidden_states.size()
@@ -340,8 +326,16 @@ class LlamaAttention(nn.Module):
340
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
341
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
342
 
343
- past_key_value = getattr(self, "past_key_value", past_key_value)
344
- cos, sin = self.rotary_emb(value_states, position_ids)
345
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
346
 
347
  if past_key_value is not None:
@@ -351,7 +345,6 @@ class LlamaAttention(nn.Module):
351
 
352
  key_states = repeat_kv(key_states, self.num_key_value_groups)
353
  value_states = repeat_kv(value_states, self.num_key_value_groups)
354
-
355
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
356
 
357
  if attention_mask is not None: # no matter the length, we just slice it
@@ -371,7 +364,7 @@ class LlamaAttention(nn.Module):
371
 
372
  attn_output = attn_output.transpose(1, 2).contiguous()
373
 
374
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
375
 
376
  if self.config.pretraining_tp > 1:
377
  attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
@@ -410,8 +403,15 @@ class LlamaFlashAttention2(LlamaAttention):
410
  output_attentions: bool = False,
411
  use_cache: bool = False,
412
  cache_position: Optional[torch.LongTensor] = None,
413
- **kwargs,
 
414
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
415
  output_attentions = False
416
 
417
  bsz, q_len, _ = hidden_states.size()
@@ -427,11 +427,18 @@ class LlamaFlashAttention2(LlamaAttention):
427
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
428
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
429
 
430
- cos, sin = self.rotary_emb(value_states, position_ids)
431
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
432
 
433
- past_key_value = getattr(self, "past_key_value", past_key_value)
434
-
435
  if past_key_value is not None:
436
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
437
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -471,11 +478,21 @@ class LlamaFlashAttention2(LlamaAttention):
471
  key_states = key_states.to(target_dtype)
472
  value_states = value_states.to(target_dtype)
473
 
474
- attn_output = self._flash_attention_forward(
475
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
476
  )
477
 
478
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
479
  attn_output = self.o_proj(attn_output)
480
 
481
  if not output_attentions:
@@ -483,103 +500,6 @@ class LlamaFlashAttention2(LlamaAttention):
483
 
484
  return attn_output, attn_weights, past_key_value
485
 
486
- def _flash_attention_forward(
487
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
488
- ):
489
- """
490
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
491
- first unpad the input, then computes the attention scores and pad the final attention scores.
492
-
493
- Args:
494
- query_states (`torch.Tensor`):
495
- Input query states to be passed to Flash Attention API
496
- key_states (`torch.Tensor`):
497
- Input key states to be passed to Flash Attention API
498
- value_states (`torch.Tensor`):
499
- Input value states to be passed to Flash Attention API
500
- attention_mask (`torch.Tensor`):
501
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
502
- position of padding tokens and 1 for the position of non-padding tokens.
503
- dropout (`float`):
504
- Attention dropout
505
- softmax_scale (`float`, *optional*):
506
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
507
- """
508
- if not self._flash_attn_uses_top_left_mask:
509
- causal = self.is_causal
510
- else:
511
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
512
- causal = self.is_causal and query_length != 1
513
-
514
- # Contains at least one padding token in the sequence
515
- if attention_mask is not None:
516
- batch_size = query_states.shape[0]
517
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
518
- query_states, key_states, value_states, attention_mask, query_length
519
- )
520
-
521
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
522
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
523
-
524
- attn_output_unpad = flash_attn_varlen_func(
525
- query_states,
526
- key_states,
527
- value_states,
528
- cu_seqlens_q=cu_seqlens_q,
529
- cu_seqlens_k=cu_seqlens_k,
530
- max_seqlen_q=max_seqlen_in_batch_q,
531
- max_seqlen_k=max_seqlen_in_batch_k,
532
- dropout_p=dropout,
533
- softmax_scale=softmax_scale,
534
- causal=causal,
535
- )
536
-
537
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
538
- else:
539
- attn_output = flash_attn_func(
540
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
541
- )
542
-
543
- return attn_output
544
-
545
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
546
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
547
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
548
-
549
- key_layer = index_first_axis(
550
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
551
- )
552
- value_layer = index_first_axis(
553
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
- )
555
- if query_length == kv_seq_len:
556
- query_layer = index_first_axis(
557
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
558
- )
559
- cu_seqlens_q = cu_seqlens_k
560
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
561
- indices_q = indices_k
562
- elif query_length == 1:
563
- max_seqlen_in_batch_q = 1
564
- cu_seqlens_q = torch.arange(
565
- batch_size + 1, dtype=torch.int32, device=query_layer.device
566
- ) # There is a memcpy here, that is very bad.
567
- indices_q = cu_seqlens_q[:-1]
568
- query_layer = query_layer.squeeze(1)
569
- else:
570
- # The -q_len: slice assumes left padding.
571
- attention_mask = attention_mask[:, -query_length:]
572
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
573
-
574
- return (
575
- query_layer,
576
- key_layer,
577
- value_layer,
578
- indices_q,
579
- (cu_seqlens_q, cu_seqlens_k),
580
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
581
- )
582
-
583
 
584
  class LlamaSdpaAttention(LlamaAttention):
585
  """
@@ -598,6 +518,8 @@ class LlamaSdpaAttention(LlamaAttention):
598
  output_attentions: bool = False,
599
  use_cache: bool = False,
600
  cache_position: Optional[torch.LongTensor] = None,
 
 
601
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
602
  if output_attentions:
603
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -613,6 +535,7 @@ class LlamaSdpaAttention(LlamaAttention):
613
  output_attentions=output_attentions,
614
  use_cache=use_cache,
615
  cache_position=cache_position,
 
616
  )
617
 
618
  bsz, q_len, _ = hidden_states.size()
@@ -625,12 +548,18 @@ class LlamaSdpaAttention(LlamaAttention):
625
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
626
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
627
 
628
- cos, sin = self.rotary_emb(value_states, position_ids)
629
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
630
 
631
- # In case static cache is used, it is an instance attribute.
632
- past_key_value = getattr(self, "past_key_value", past_key_value)
633
-
634
  if past_key_value is not None:
635
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
636
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -650,19 +579,21 @@ class LlamaSdpaAttention(LlamaAttention):
650
  key_states = key_states.contiguous()
651
  value_states = value_states.contiguous()
652
 
653
- # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
654
- # relying on the `is_causal` argument.
 
 
655
  attn_output = torch.nn.functional.scaled_dot_product_attention(
656
  query_states,
657
  key_states,
658
  value_states,
659
  attn_mask=causal_mask,
660
  dropout_p=self.attention_dropout if self.training else 0.0,
661
- is_causal=causal_mask is None and q_len > 1,
662
  )
663
 
664
  attn_output = attn_output.transpose(1, 2).contiguous()
665
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
666
 
667
  attn_output = self.o_proj(attn_output)
668
 
@@ -692,10 +623,11 @@ class LlamaDecoderLayer(nn.Module):
692
  hidden_states: torch.Tensor,
693
  attention_mask: Optional[torch.Tensor] = None,
694
  position_ids: Optional[torch.LongTensor] = None,
695
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
696
  output_attentions: Optional[bool] = False,
697
  use_cache: Optional[bool] = False,
698
  cache_position: Optional[torch.LongTensor] = None,
 
699
  **kwargs,
700
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
701
  """
@@ -711,12 +643,15 @@ class LlamaDecoderLayer(nn.Module):
711
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
712
  (see `past_key_values`).
713
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
714
  """
715
- if "padding_mask" in kwargs:
716
- warnings.warn(
717
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
718
- )
719
-
720
  residual = hidden_states
721
 
722
  hidden_states = self.input_layernorm(hidden_states)
@@ -730,6 +665,7 @@ class LlamaDecoderLayer(nn.Module):
730
  output_attentions=output_attentions,
731
  use_cache=use_cache,
732
  cache_position=cache_position,
 
733
  **kwargs,
734
  )
735
  hidden_states = residual + hidden_states
@@ -781,6 +717,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
781
  _supports_flash_attn_2 = True
782
  _supports_sdpa = True
783
  _supports_cache_class = True
 
 
784
 
785
  def _init_weights(self, module):
786
  std = self.config.initializer_range
@@ -793,27 +731,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
793
  if module.padding_idx is not None:
794
  module.weight.data[module.padding_idx].zero_()
795
 
796
- def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
797
- if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
798
- raise ValueError(
799
- "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
800
- "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
801
- )
802
-
803
- for layer in self.model.layers:
804
- device = layer.input_layernorm.weight.device
805
- if hasattr(self.config, "_pre_quantization_dtype"):
806
- dtype = self.config._pre_quantization_dtype
807
- else:
808
- dtype = layer.self_attn.o_proj.weight.dtype
809
- layer.self_attn.past_key_value = cache_cls(
810
- self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
811
- )
812
-
813
- def _reset_cache(self):
814
- for layer in self.model.layers:
815
- layer.self_attn.past_key_value = None
816
-
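With `_setup_cache`/`_reset_cache` removed, per-layer `past_key_value` attributes are no longer installed on the attention modules; a static cache is instead requested through the generation API. A hedged usage sketch (the checkpoint name is just an example, and `cache_implementation="static"` assumes a recent `transformers` release that accepts it in `generate`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Hello", return_tensors="pt")
# Let generate() allocate a StaticCache internally instead of calling _setup_cache() on the model.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```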
817
 
818
  LLAMA_INPUTS_DOCSTRING = r"""
819
  Args:
@@ -856,7 +773,8 @@ LLAMA_INPUTS_DOCSTRING = r"""
856
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
857
 
858
  Two formats are allowed:
859
- - a [`~cache_utils.Cache`] instance;
 
860
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
861
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
862
  cache format.
@@ -911,6 +829,7 @@ class LlamaModel(LlamaPreTrainedModel):
911
  [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
912
  )
913
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
914
  self.gradient_checkpointing = False
915
 
916
  # Initialize weights and apply final processing
@@ -928,13 +847,14 @@ class LlamaModel(LlamaPreTrainedModel):
928
  input_ids: torch.LongTensor = None,
929
  attention_mask: Optional[torch.Tensor] = None,
930
  position_ids: Optional[torch.LongTensor] = None,
931
- past_key_values: Optional[List[torch.FloatTensor]] = None,
932
  inputs_embeds: Optional[torch.FloatTensor] = None,
933
  use_cache: Optional[bool] = None,
934
  output_attentions: Optional[bool] = None,
935
  output_hidden_states: Optional[bool] = None,
936
  return_dict: Optional[bool] = None,
937
  cache_position: Optional[torch.LongTensor] = None,
 
938
  ) -> Union[Tuple, BaseModelOutputWithPast]:
939
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
940
  output_hidden_states = (
@@ -944,9 +864,7 @@ class LlamaModel(LlamaPreTrainedModel):
944
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
945
 
946
  if (input_ids is None) ^ (inputs_embeds is not None):
947
- raise ValueError(
948
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
949
- )
950
 
951
  if self.gradient_checkpointing and self.training and use_cache:
952
  logger.warning_once(
@@ -957,27 +875,36 @@ class LlamaModel(LlamaPreTrainedModel):
957
  if inputs_embeds is None:
958
  inputs_embeds = self.embed_tokens(input_ids)
959
 
960
- past_seen_tokens = 0
961
- if use_cache: # kept for BC (cache positions)
962
- if past_key_values is not None and not isinstance(past_key_values, StaticCache):
 
 
 
 
963
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
964
- past_seen_tokens = past_key_values.get_seq_length()
 
 
 
 
965
 
966
  if cache_position is None:
967
- if isinstance(past_key_values, StaticCache):
968
- raise ValueError("cache_position is a required argument when using StaticCache.")
969
  cache_position = torch.arange(
970
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
971
  )
972
-
973
  if position_ids is None:
974
  position_ids = cache_position.unsqueeze(0)
975
 
976
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
977
-
978
- # embed positions
979
  hidden_states = inputs_embeds
980
 
 
 
 
981
  # decoder layers
982
  all_hidden_states = () if output_hidden_states else None
983
  all_self_attns = () if output_attentions else None
@@ -997,6 +924,7 @@ class LlamaModel(LlamaPreTrainedModel):
997
  output_attentions,
998
  use_cache,
999
  cache_position,
 
1000
  )
1001
  else:
1002
  layer_outputs = decoder_layer(
@@ -1007,6 +935,8 @@ class LlamaModel(LlamaPreTrainedModel):
1007
  output_attentions=output_attentions,
1008
  use_cache=use_cache,
1009
  cache_position=cache_position,
 
 
1010
  )
1011
 
1012
  hidden_states = layer_outputs[0]
@@ -1023,11 +953,10 @@ class LlamaModel(LlamaPreTrainedModel):
1023
  if output_hidden_states:
1024
  all_hidden_states += (hidden_states,)
1025
 
1026
- next_cache = None
1027
- if use_cache:
1028
- next_cache = (
1029
- next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1030
- )
1031
  if not return_dict:
1032
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1033
  return BaseModelOutputWithPast(
@@ -1042,76 +971,127 @@ class LlamaModel(LlamaPreTrainedModel):
1042
  attention_mask: torch.Tensor,
1043
  input_tensor: torch.Tensor,
1044
  cache_position: torch.Tensor,
1045
- past_seen_tokens: int,
 
1046
  ):
1047
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1048
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1049
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1050
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1051
-
1052
  if self.config._attn_implementation == "flash_attention_2":
1053
  if attention_mask is not None and 0.0 in attention_mask:
1054
  return attention_mask
1055
  return None
1056
 
1057
- if self.config._attn_implementation == "sdpa":
1058
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
1059
- # in order to dispatch on Flash Attention 2.
1060
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
1061
- attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
 
 
 
1062
  ):
1063
  return None
1064
 
1065
  dtype, device = input_tensor.dtype, input_tensor.device
1066
- min_dtype = torch.finfo(dtype).min
1067
  sequence_length = input_tensor.shape[1]
1068
- if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
1069
- target_length = self.config.max_position_embeddings
1070
- else: # dynamic cache
1071
  target_length = (
1072
  attention_mask.shape[-1]
1073
  if isinstance(attention_mask, torch.Tensor)
1074
  else past_seen_tokens + sequence_length + 1
1075
  )
1076
 
1077
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1078
- if sequence_length != 1:
1079
- causal_mask = torch.triu(causal_mask, diagonal=1)
1080
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1081
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1082
- if attention_mask is not None:
1083
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1084
- if attention_mask.dim() == 2:
1085
- mask_length = attention_mask.shape[-1]
1086
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1087
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1088
- elif attention_mask.dim() == 4:
1089
- # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1090
- # cache. In that case, the 4D attention mask attends to the newest tokens only.
1091
- if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1092
- offset = cache_position[0]
1093
- else:
1094
- offset = 0
1095
- mask_shape = attention_mask.shape
1096
- mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1097
- causal_mask[
1098
- : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
1099
- ] = mask_slice
1100
 
1101
  if (
1102
  self.config._attn_implementation == "sdpa"
1103
  and attention_mask is not None
1104
  and attention_mask.device.type == "cuda"
 
1105
  ):
1106
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1107
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1108
  # Details: https://github.com/pytorch/pytorch/issues/110213
 
1109
  causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1110
 
1111
  return causal_mask
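For reference, the additive causal mask that the removed block above constructed has shape `(batch, 1, q_len, kv_len)`: `min_dtype` above the diagonal (and in padded columns), zeros where attention is allowed. A compact sketch of that recipe with toy shapes:

```python
import torch

dtype, device = torch.float32, "cpu"
batch_size, sequence_length, target_length = 1, 4, 4
cache_position = torch.arange(sequence_length, device=device)
min_dtype = torch.finfo(dtype).min

# Start fully masked, then open the lower triangle relative to cache_position.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print(causal_mask[0, 0])  # 0.0 where attention is allowed, min_dtype where it is blocked
```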
1112
 
1113
 
1114
- class LlamaForCausalLM(LlamaPreTrainedModel):
1115
  _tied_weights_keys = ["lm_head.weight"]
1116
 
1117
  def __init__(self, config):
@@ -1148,7 +1128,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1148
  input_ids: torch.LongTensor = None,
1149
  attention_mask: Optional[torch.Tensor] = None,
1150
  position_ids: Optional[torch.LongTensor] = None,
1151
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1152
  inputs_embeds: Optional[torch.FloatTensor] = None,
1153
  labels: Optional[torch.LongTensor] = None,
1154
  use_cache: Optional[bool] = None,
@@ -1156,6 +1136,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1156
  output_hidden_states: Optional[bool] = None,
1157
  return_dict: Optional[bool] = None,
1158
  cache_position: Optional[torch.LongTensor] = None,
 
 
1159
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1160
  r"""
1161
  Args:
@@ -1164,6 +1146,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1164
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1165
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1166
 
1167
  Returns:
1168
 
1169
  Example:
@@ -1200,6 +1187,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1200
  output_hidden_states=output_hidden_states,
1201
  return_dict=return_dict,
1202
  cache_position=cache_position,
 
1203
  )
1204
 
1205
  hidden_states = outputs[0]
@@ -1208,21 +1196,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1208
  logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1209
  logits = torch.cat(logits, dim=-1)
1210
  else:
1211
- logits = self.lm_head(hidden_states)
1212
- logits = logits.float()
1213
 
1214
  loss = None
1215
  if labels is not None:
1216
- # Shift so that tokens < n predict n
1217
- shift_logits = logits[..., :-1, :].contiguous()
1218
- shift_labels = labels[..., 1:].contiguous()
1219
- # Flatten the tokens
1220
- loss_fct = CrossEntropyLoss()
1221
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1222
- shift_labels = shift_labels.view(-1)
1223
- # Enable model parallelism
1224
- shift_labels = shift_labels.to(shift_logits.device)
1225
- loss = loss_fct(shift_logits, shift_labels)
1226
 
1227
  if not return_dict:
1228
  output = (logits,) + outputs[1:]
@@ -1236,97 +1215,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1236
  attentions=outputs.attentions,
1237
  )
1238
 
1239
- def prepare_inputs_for_generation(
1240
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
1241
- ):
1242
- # With static cache, the `past_key_values` is None
1243
- # TODO joao: standardize interface for the different Cache classes and remove of this if
1244
- has_static_cache = False
1245
- if past_key_values is None:
1246
- past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
1247
- has_static_cache = past_key_values is not None
1248
-
1249
- past_length = 0
1250
- if past_key_values is not None:
1251
- if isinstance(past_key_values, Cache):
1252
- past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1253
- max_cache_length = (
1254
- torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1255
- if past_key_values.get_max_length() is not None
1256
- else None
1257
- )
1258
- cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1259
- # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1260
- else:
1261
- cache_length = past_length = past_key_values[0][0].shape[2]
1262
- max_cache_length = None
1263
-
1264
- # Keep only the unprocessed tokens:
1265
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1266
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1267
- # input)
1268
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1269
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1270
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1271
- # input_ids based on the past_length.
1272
- elif past_length < input_ids.shape[1]:
1273
- input_ids = input_ids[:, past_length:]
1274
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1275
-
1276
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1277
- if (
1278
- max_cache_length is not None
1279
- and attention_mask is not None
1280
- and cache_length + input_ids.shape[1] > max_cache_length
1281
- ):
1282
- attention_mask = attention_mask[:, -max_cache_length:]
1283
-
1284
- position_ids = kwargs.get("position_ids", None)
1285
- if attention_mask is not None and position_ids is None:
1286
- # create position_ids on the fly for batch generation
1287
- position_ids = attention_mask.long().cumsum(-1) - 1
1288
- position_ids.masked_fill_(attention_mask == 0, 1)
1289
- if past_key_values:
1290
- position_ids = position_ids[:, -input_ids.shape[1] :]
1291
-
1292
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1293
- if inputs_embeds is not None and past_key_values is None:
1294
- model_inputs = {"inputs_embeds": inputs_embeds}
1295
- else:
1296
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1297
- # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1298
- # TODO: use `next_tokens` directly instead.
1299
- model_inputs = {"input_ids": input_ids.contiguous()}
1300
-
1301
- input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1302
- if cache_position is None:
1303
- cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1304
- else:
1305
- cache_position = cache_position[-input_length:]
1306
-
1307
- if has_static_cache:
1308
- past_key_values = None
1309
-
1310
- model_inputs.update(
1311
- {
1312
- "position_ids": position_ids,
1313
- "cache_position": cache_position,
1314
- "past_key_values": past_key_values,
1315
- "use_cache": kwargs.get("use_cache"),
1316
- "attention_mask": attention_mask,
1317
- }
1318
- )
1319
- return model_inputs
1320
-
1321
- @staticmethod
1322
- def _reorder_cache(past_key_values, beam_idx):
1323
- reordered_past = ()
1324
- for layer_past in past_key_values:
1325
- reordered_past += (
1326
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1327
- )
1328
- return reordered_past
1329
-
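The removed `prepare_inputs_for_generation` derived `position_ids` from the padding mask on the fly; that responsibility now falls to `GenerationMixin`, which the new version of the file imports. The core trick, shown standalone on a left-padded toy batch:

```python
import torch

# Left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum yields 0-based positions for real tokens; padded slots get a dummy value of 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```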
1330
 
1331
  @add_start_docstrings(
1332
  """
@@ -1362,10 +1250,10 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1362
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1363
  def forward(
1364
  self,
1365
- input_ids: torch.LongTensor = None,
1366
  attention_mask: Optional[torch.Tensor] = None,
1367
  position_ids: Optional[torch.LongTensor] = None,
1368
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1369
  inputs_embeds: Optional[torch.FloatTensor] = None,
1370
  labels: Optional[torch.LongTensor] = None,
1371
  use_cache: Optional[bool] = None,
@@ -1417,27 +1305,8 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1417
 
1418
  loss = None
1419
  if labels is not None:
1420
- labels = labels.to(logits.device)
1421
- if self.config.problem_type is None:
1422
- if self.num_labels == 1:
1423
- self.config.problem_type = "regression"
1424
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1425
- self.config.problem_type = "single_label_classification"
1426
- else:
1427
- self.config.problem_type = "multi_label_classification"
1428
-
1429
- if self.config.problem_type == "regression":
1430
- loss_fct = MSELoss()
1431
- if self.num_labels == 1:
1432
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1433
- else:
1434
- loss = loss_fct(pooled_logits, labels)
1435
- elif self.config.problem_type == "single_label_classification":
1436
- loss_fct = CrossEntropyLoss()
1437
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1438
- elif self.config.problem_type == "multi_label_classification":
1439
- loss_fct = BCEWithLogitsLoss()
1440
- loss = loss_fct(pooled_logits, labels)
1441
  if not return_dict:
1442
  output = (pooled_logits,) + transformer_outputs[1:]
1443
  return ((loss,) + output) if loss is not None else output
@@ -1482,13 +1351,14 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1482
  input_ids: Optional[torch.LongTensor] = None,
1483
  attention_mask: Optional[torch.FloatTensor] = None,
1484
  position_ids: Optional[torch.LongTensor] = None,
1485
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1486
  inputs_embeds: Optional[torch.FloatTensor] = None,
1487
  start_positions: Optional[torch.LongTensor] = None,
1488
  end_positions: Optional[torch.LongTensor] = None,
1489
  output_attentions: Optional[bool] = None,
1490
  output_hidden_states: Optional[bool] = None,
1491
  return_dict: Optional[bool] = None,
 
1492
  ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1493
  r"""
1494
  start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1520,31 +1390,106 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1520
  start_logits = start_logits.squeeze(-1).contiguous()
1521
  end_logits = end_logits.squeeze(-1).contiguous()
1522
 
1523
- total_loss = None
1524
  if start_positions is not None and end_positions is not None:
1525
- # If we are on multi-GPU, split add a dimension
1526
- if len(start_positions.size()) > 1:
1527
- start_positions = start_positions.squeeze(-1).to(start_logits.device)
1528
- if len(end_positions.size()) > 1:
1529
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
1530
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1531
- ignored_index = start_logits.size(1)
1532
- start_positions = start_positions.clamp(0, ignored_index)
1533
- end_positions = end_positions.clamp(0, ignored_index)
1534
-
1535
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1536
- start_loss = loss_fct(start_logits, start_positions)
1537
- end_loss = loss_fct(end_logits, end_positions)
1538
- total_loss = (start_loss + end_loss) / 2
1539
 
1540
  if not return_dict:
1541
  output = (start_logits, end_logits) + outputs[2:]
1542
- return ((total_loss,) + output) if total_loss is not None else output
1543
 
1544
  return QuestionAnsweringModelOutput(
1545
- loss=total_loss,
1546
  start_logits=start_logits,
1547
  end_logits=end_logits,
1548
  hidden_states=outputs.hidden_states,
1549
  attentions=outputs.attentions,
1550
  )
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
 
 
20
  import math
 
21
  from typing import List, Optional, Tuple, Union
22
 
23
  import torch
24
  import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from torch import nn
 
27
 
28
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
31
+ from transformers.generation import GenerationMixin
32
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
33
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
34
  from transformers.modeling_outputs import (
35
  BaseModelOutputWithPast,
36
  CausalLMOutputWithPast,
37
  QuestionAnsweringModelOutput,
38
  SequenceClassifierOutputWithPast,
39
+ TokenClassifierOutput,
40
  )
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
42
  from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.processing_utils import Unpack
44
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
45
  from transformers.utils import (
46
+ LossKwargs,
47
+ add_code_sample_docstrings,
48
  add_start_docstrings,
49
  add_start_docstrings_to_model_forward,
 
50
  is_flash_attn_greater_or_equal_2_10,
51
  logging,
52
  replace_return_docstrings,
53
  )
54
  from .configuration_llama import LlamaConfig
55
 
56
 
57
  logger = logging.get_logger(__name__)
58
 
59
+ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
60
  _CONFIG_FOR_DOC = "LlamaConfig"
61
 
62
  ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
63
 
64
 
65
  class LlamaRotaryEmbedding(nn.Module):
66
+ def __init__(
67
+ self,
68
+ dim=None,
69
+ max_position_embeddings=2048,
70
+ base=10000,
71
+ device=None,
72
+ scaling_factor=1.0,
73
+ rope_type="default",
74
+ config: Optional[LlamaConfig] = None,
75
+ ):
76
  super().__init__()
77
+ # TODO (joao): remove the `if` below, only used for BC
78
+ self.rope_kwargs = {}
79
+ if config is None:
80
+ logger.warning_once(
81
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
82
+ "`config` argument. All other arguments will be removed in v4.46"
83
+ )
84
+ self.rope_kwargs = {
85
+ "rope_type": rope_type,
86
+ "factor": scaling_factor,
87
+ "dim": dim,
88
+ "base": base,
89
+ "max_position_embeddings": max_position_embeddings,
90
+ }
91
+ self.rope_type = rope_type
92
+ self.max_seq_len_cached = max_position_embeddings
93
+ self.original_max_seq_len = max_position_embeddings
94
+ else:
95
+ # BC: "rope_type" was originally "type"
96
+ if config.rope_scaling is not None:
97
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
98
+ else:
99
+ self.rope_type = "default"
100
+ self.max_seq_len_cached = config.max_position_embeddings
101
+ self.original_max_seq_len = config.max_position_embeddings
102
+
103
+ self.config = config
104
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
105
+
106
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
107
  self.register_buffer("inv_freq", inv_freq, persistent=False)
108
+ self.original_inv_freq = self.inv_freq
109
 
110
+ def _dynamic_frequency_update(self, position_ids, device):
111
+ """
112
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
113
+ 1 - growing beyond the cached sequence length (allow scaling)
114
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
115
+ """
116
+ seq_len = torch.max(position_ids) + 1
117
+ if seq_len > self.max_seq_len_cached: # growth
118
+ inv_freq, self.attention_scaling = self.rope_init_fn(
119
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
120
+ )
121
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
122
+ self.max_seq_len_cached = seq_len
123
+
124
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
125
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
126
+ self.max_seq_len_cached = self.original_max_seq_len
127
 
128
  @torch.no_grad()
129
  def forward(self, x, position_ids):
130
+ if "dynamic" in self.rope_type:
131
+ self._dynamic_frequency_update(position_ids, device=x.device)
132
+
133
+ # Core RoPE block
134
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
135
  position_ids_expanded = position_ids[:, None, :].float()
136
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
 
137
  device_type = x.device.type
138
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
139
  with torch.autocast(device_type=device_type, enabled=False):
 
141
  emb = torch.cat((freqs, freqs), dim=-1)
142
  cos = emb.cos()
143
  sin = emb.sin()
144
+
145
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
146
+ cos = cos * self.attention_scaling
147
+ sin = sin * self.attention_scaling
148
+
149
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
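As this hunk shows, the rotary embedding is now parameterized by the model config (with `rope_type` dispatching into `ROPE_INIT_FUNCTIONS`) and returns a `(cos, sin)` pair that the model computes once and hands to every decoder layer as `position_embeddings`. A small usage sketch with an illustrative config:

```python
import torch
from transformers import LlamaConfig

# Illustrative config; head_dim = hidden_size // num_attention_heads = 64.
config = LlamaConfig(hidden_size=256, num_attention_heads=4, max_position_embeddings=2048)
rotary_emb = LlamaRotaryEmbedding(config=config)  # the class defined above

hidden_states = torch.randn(1, 6, config.hidden_size)  # (batch, seq_len, hidden)
position_ids = torch.arange(6).unsqueeze(0)             # (batch, seq_len)

# Computed once per LlamaModel forward pass, then passed to each decoder layer.
cos, sin = rotary_emb(hidden_states, position_ids)
print(cos.shape, sin.shape)  # torch.Size([1, 6, 64]) for both
```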
150
 
151
 
152
  class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
153
  """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
154
 
155
+ def __init__(self, *args, **kwargs):
156
+ logger.warning_once(
157
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.46. Please use "
158
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
159
+ )
160
+ kwargs["rope_type"] = "linear"
161
+ super().__init__(*args, **kwargs)
162
 
163
 
164
  class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
165
  """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
166
 
167
+ def __init__(self, *args, **kwargs):
168
+ logger.warning_once(
169
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated and will be removed in v4.46. Please use "
170
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
171
+ "__init__)."
172
+ )
173
+ kwargs["rope_type"] = "dynamic"
174
+ super().__init__(*args, **kwargs)
175
 
176
 
177
  def rotate_half(x):
 
214
  self.config = config
215
  self.hidden_size = config.hidden_size
216
  self.intermediate_size = config.intermediate_size
217
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
218
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
219
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
220
  self.act_fn = ACT2FN[config.hidden_act]
221
 
222
  def forward(self, x):
 
271
  self.attention_dropout = config.attention_dropout
272
  self.hidden_size = config.hidden_size
273
  self.num_heads = config.num_attention_heads
274
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
275
  self.num_key_value_heads = config.num_key_value_heads
276
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
277
  self.max_position_embeddings = config.max_position_embeddings
278
  self.rope_theta = config.rope_theta
279
  self.is_causal = True
280
 
281
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
282
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
283
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
284
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
285
+
286
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
287
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
288
 
289
  def forward(
290
  self,
 
295
  output_attentions: bool = False,
296
  use_cache: bool = False,
297
  cache_position: Optional[torch.LongTensor] = None,
298
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
299
  **kwargs,
300
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
301
  bsz, q_len, _ = hidden_states.size()
 
326
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
327
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
328
 
329
+ if position_embeddings is None:
330
+ logger.warning_once(
331
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
332
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
333
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
334
+ "removed and `position_embeddings` will be mandatory."
335
+ )
336
+ cos, sin = self.rotary_emb(value_states, position_ids)
337
+ else:
338
+ cos, sin = position_embeddings
339
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
340
 
341
  if past_key_value is not None:
 
345
 
346
  key_states = repeat_kv(key_states, self.num_key_value_groups)
347
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
348
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
349
 
350
  if attention_mask is not None: # no matter the length, we just slice it
 
364
 
365
  attn_output = attn_output.transpose(1, 2).contiguous()
366
 
367
+ attn_output = attn_output.reshape(bsz, q_len, -1)
368
 
369
  if self.config.pretraining_tp > 1:
370
  attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
 
403
  output_attentions: bool = False,
404
  use_cache: bool = False,
405
  cache_position: Optional[torch.LongTensor] = None,
406
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
407
+ **kwargs: Unpack[FlashAttentionKwargs],
408
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
409
+ if isinstance(past_key_value, StaticCache):
410
+ raise ValueError(
411
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
412
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
413
+ )
414
+
415
  output_attentions = False
416
 
417
  bsz, q_len, _ = hidden_states.size()
 
427
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
428
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
429
 
430
+ if position_embeddings is None:
431
+ logger.warning_once(
432
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
433
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
434
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
435
+ "removed and `position_embeddings` will be mandatory."
436
+ )
437
+ cos, sin = self.rotary_emb(value_states, position_ids)
438
+ else:
439
+ cos, sin = position_embeddings
440
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
441
 
 
 
442
  if past_key_value is not None:
443
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
444
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 
478
  key_states = key_states.to(target_dtype)
479
  value_states = value_states.to(target_dtype)
480
 
481
+ attn_output = _flash_attention_forward(
482
+ query_states,
483
+ key_states,
484
+ value_states,
485
+ attention_mask,
486
+ q_len,
487
+ position_ids=position_ids,
488
+ dropout=dropout_rate,
489
+ sliding_window=getattr(self, "sliding_window", None),
490
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
491
+ is_causal=self.is_causal,
492
+ **kwargs,
493
  )
494
 
495
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
496
  attn_output = self.o_proj(attn_output)
497
 
498
  if not output_attentions:
 
500
 
501
  return attn_output, attn_weights, past_key_value
502
 
503
 
504
  class LlamaSdpaAttention(LlamaAttention):
  """

  output_attentions: bool = False,
  use_cache: bool = False,
  cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  if output_attentions:
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.

  output_attentions=output_attentions,
  use_cache=use_cache,
  cache_position=cache_position,
+ position_embeddings=position_embeddings,
  )

  bsz, q_len, _ = hidden_states.size()

  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

  if past_key_value is not None:
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

  key_states = key_states.contiguous()
  value_states = value_states.contiguous()

+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  query_states,
  key_states,
  value_states,
  attn_mask=causal_mask,
  dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
  )

  attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)

  attn_output = self.o_proj(attn_output)
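A small sketch (not part of the diff) of why the explicit `is_causal` flag matters: when no mask is passed and `q_len > 1`, SDPA can apply the causal mask itself and dispatch to its fused kernels, and the result matches an explicit additive causal mask. Toy shapes only:

# Illustrative only: the two calls below are numerically equivalent.
import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim = 1, 2, 4, 8
q, k, v = (torch.randn(bsz, heads, q_len, head_dim) for _ in range(3))

# Explicit additive causal mask: 0 on/below the diagonal, -inf above it.
mask = torch.zeros(q_len, q_len).masked_fill(
    torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1), float("-inf")
)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out_masked, out_causal)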
 
 
  hidden_states: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
  output_attentions: Optional[bool] = False,
  use_cache: Optional[bool] = False,
  cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
  **kwargs,
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  """

  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  (see `past_key_values`).
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
  """

  residual = hidden_states

  hidden_states = self.input_layernorm(hidden_states)

  output_attentions=output_attentions,
  use_cache=use_cache,
  cache_position=cache_position,
+ position_embeddings=position_embeddings,
  **kwargs,
  )
  hidden_states = residual + hidden_states
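A brief sketch (not part of the diff) of the pre-norm residual pattern the decoder layer follows here: normalize, run the sub-layer, then add the untouched residual back. The module below is a placeholder, not code from this file:

# Illustrative only: `sublayer` stands in for self-attention or the MLP.
import torch
from torch import nn

class PreNormBlock(nn.Module):
    def __init__(self, hidden_size: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)  # the model itself uses RMSNorm
        self.sublayer = sublayer

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.sublayer(hidden_states)
        return residual + hidden_states

block = PreNormBlock(16, nn.Linear(16, 16))
print(block(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 3, 16])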
 
  _supports_flash_attn_2 = True
  _supports_sdpa = True
  _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True

  def _init_weights(self, module):
  std = self.config.initializer_range

  if module.padding_idx is not None:
  module.weight.data[module.padding_idx].zero_()
  LLAMA_INPUTS_DOCSTRING = r"""
  Args:

  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

  Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  cache format.

  [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  )
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
  self.gradient_checkpointing = False

  # Initialize weights and apply final processing
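With `self.rotary_emb` now owned by the model, the cos/sin pair is computed once per forward and shared by every decoder layer. A toy sketch (not part of the diff) of producing such a pair; the base value and shapes are assumptions for the example, not the model's configuration:

# Illustrative only: a shared (cos, sin) pair of shape (batch_size, seq_len, head_dim).
import torch

batch_size, seq_len, head_dim, base = 1, 6, 8, 10000.0

inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim / 2,)
position_ids = torch.arange(seq_len)[None, :].float()                          # (batch_size, seq_len)

freqs = position_ids[..., None] * inv_freq                                      # (batch_size, seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)                                          # (batch_size, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

# This (cos, sin) tuple is what `position_embeddings` carries into every decoder layer,
# so RoPE is not recomputed per layer.
print(cos.shape, sin.shape)  # torch.Size([1, 6, 8]) torch.Size([1, 6, 8])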
 
  input_ids: torch.LongTensor = None,
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
  ) -> Union[Tuple, BaseModelOutputWithPast]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (

  return_dict = return_dict if return_dict is not None else self.config.use_return_dict

  if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

  if self.gradient_checkpointing and self.training and use_cache:
  logger.warning_once(

  if inputs_embeds is None:
  inputs_embeds = self.embed_tokens(input_ids)

+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )

  if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  cache_position = torch.arange(
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  )
  if position_ids is None:
  position_ids = cache_position.unsqueeze(0)

+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
  hidden_states = inputs_embeds

+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
  # decoder layers
  all_hidden_states = () if output_hidden_states else None
  all_self_attns = () if output_attentions else None

  output_attentions,
  use_cache,
  cache_position,
+ position_embeddings,
  )
  else:
  layer_outputs = decoder_layer(

  output_attentions=output_attentions,
  use_cache=use_cache,
  cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
  )

  hidden_states = layer_outputs[0]

  if output_hidden_states:
  all_hidden_states += (hidden_states,)

+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
  if not return_dict:
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
  return BaseModelOutputWithPast(
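A small sketch (not part of the diff) of the back-compat path above: the legacy tuple-of-tuples cache is converted to a `DynamicCache`, its length seeds `cache_position`, and `to_legacy_cache()` converts back when requested. Shapes are toy values:

# Illustrative only: legacy cache format versus the `Cache` API used by the new forward.
import torch
from transformers.cache_utils import DynamicCache

batch, num_heads, seq_len, head_dim = 1, 2, 4, 8
legacy = tuple(
    (torch.randn(batch, num_heads, seq_len, head_dim), torch.randn(batch, num_heads, seq_len, head_dim))
    for _ in range(2)  # one (key, value) pair per layer
)

cache = DynamicCache.from_legacy_cache(legacy)      # what the BC branch above does
past_seen_tokens = cache.get_seq_length()           # 4

# `cache_position` for the next chunk of 3 tokens, mirroring the torch.arange call above.
new_tokens = 3
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)                               # tensor([4, 5, 6])

round_tripped = cache.to_legacy_cache()             # back to tuple-of-tuples if requested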
 
  attention_mask: torch.Tensor,
  input_tensor: torch.Tensor,
  cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
  ):

  if self.config._attn_implementation == "flash_attention_2":
  if attention_mask is not None and 0.0 in attention_mask:
  return attention_mask
  return None

+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
  ):
  return None

  dtype, device = input_tensor.dtype, input_tensor.device
  sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
  target_length = (
  attention_mask.shape[-1]
  if isinstance(attention_mask, torch.Tensor)
  else past_seen_tokens + sequence_length + 1
  )

+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )

  if (
  self.config._attn_implementation == "sdpa"
  and attention_mask is not None
  and attention_mask.device.type == "cuda"
+ and not output_attentions
  ):
  # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
  causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

  return causal_mask

+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
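A toy call (not part of the diff) exercising the static method above, assuming this module is importable as `modeling_llama`; since it is a `@staticmethod`, no model instance is needed. Two query tokens attend to a cache of length 5, and the last key position is padding:

# Illustrative only: 2D padding mask -> 4D additive mask.
import torch
from modeling_llama import LlamaModel  # adjust the import to how the repo is packaged

attention_mask = torch.tensor([[1, 1, 1, 1, 0]])     # (batch_size=1, key_value_length=5)
cache_position = torch.tensor([3, 4])                # positions of the 2 query tokens

causal_mask = LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=2,
    target_length=5,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=cache_position,
    batch_size=1,
)
print(causal_mask.shape)  # torch.Size([1, 1, 2, 5]); masked slots hold torch.finfo(torch.float32).min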
+
+
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
  _tied_weights_keys = ["lm_head.weight"]

  def __init__(self, config):
  input_ids: torch.LongTensor = None,
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  labels: Optional[torch.LongTensor] = None,
  use_cache: Optional[bool] = None,

  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:

  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

+ num_logits_to_keep (`int`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
+ for that token alone saves memory, which becomes significant for long sequences or large vocabulary sizes.
+
  Returns:

  Example:

  output_hidden_states=output_hidden_states,
  return_dict=return_dict,
  cache_position=cache_position,
+ **kwargs,
  )

  hidden_states = outputs[0]

  logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
  logits = torch.cat(logits, dim=-1)
  else:
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

  loss = None
  if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

  if not return_dict:
  output = (logits,) + outputs[1:]

  attentions=outputs.attentions,
  )
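A short sketch (not part of the diff) of the slicing trick behind `num_logits_to_keep`: since `-0` is `0` in Python, the slice `-0:` keeps every position, while `-1:` keeps only the last token, which is all generation needs. Toy shapes only:

# Illustrative only: the `hidden_states[:, -num_logits_to_keep:, :]` slice above.
import torch

hidden_states = torch.randn(2, 7, 16)  # (batch, seq_len, hidden)

num_logits_to_keep = 0
print(hidden_states[:, -num_logits_to_keep:, :].shape)  # torch.Size([2, 7, 16]) - all tokens

num_logits_to_keep = 1
print(hidden_states[:, -num_logits_to_keep:, :].shape)  # torch.Size([2, 1, 16]) - last token only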
  @add_start_docstrings(
  """

  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  def forward(
  self,
+ input_ids: Optional[torch.LongTensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  labels: Optional[torch.LongTensor] = None,
  use_cache: Optional[bool] = None,

  loss = None
  if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
  if not return_dict:
  output = (pooled_logits,) + transformer_outputs[1:]
  return ((loss,) + output) if loss is not None else output
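The classification head scores every position, but the loss above uses `pooled_logits`. A sketch (not part of the diff) of one common way to pool a decoder-only model: take the logits at each sequence's last non-padding token. The pooling code itself is not shown in this hunk, so treat this as the idea rather than the exact implementation, and note it assumes right padding:

# Illustrative only: pooling per-token logits down to one vector per sequence.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],   # 2 padded positions
                          [8, 9, 3, 4, 2]])  # no padding
logits = torch.randn(2, 5, 3)                # (batch, seq_len, num_labels)

last_non_pad = (input_ids != pad_token_id).sum(dim=-1) - 1          # tensor([2, 4])
pooled_logits = logits[torch.arange(logits.shape[0]), last_non_pad]
print(pooled_logits.shape)                   # torch.Size([2, 3])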
 
  input_ids: Optional[torch.LongTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  start_positions: Optional[torch.LongTensor] = None,
  end_positions: Optional[torch.LongTensor] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[Tuple, QuestionAnsweringModelOutput]:
  r"""
  start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

  start_logits = start_logits.squeeze(-1).contiguous()
  end_logits = end_logits.squeeze(-1).contiguous()

+ loss = None
  if start_positions is not None and end_positions is not None:
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

  if not return_dict:
  output = (start_logits, end_logits) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output

  return QuestionAnsweringModelOutput(
+ loss=loss,
  start_logits=start_logits,
  end_logits=end_logits,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
  )
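A sketch (not part of the diff) of how the start/end logits returned above are typically turned into an answer span at inference time; this post-processing is generic and not part of the model:

# Illustrative only: naive span decoding from start/end logits.
import torch

start_logits = torch.randn(1, 12)  # (batch, seq_len)
end_logits = torch.randn(1, 12)

start_idx = int(start_logits.argmax(dim=-1))  # most likely start position
end_idx = int(end_logits.argmax(dim=-1))      # most likely end position
if end_idx < start_idx:                       # naive fix-up for an inverted span
    end_idx = start_idx
print((start_idx, end_idx))                   # token indices of the predicted answer span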
+
+
+ @add_start_docstrings(
+ """
+ The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ LLAMA_START_DOCSTRING,
+ )
+ class LlamaForTokenClassification(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.config)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
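A smoke-test sketch (not part of the diff) for the new head, using a tiny random config. It assumes this file and its `configuration_llama` companion are importable; the import paths and all sizes below are illustrative assumptions, not values from the repository:

# Illustrative only: forward pass of LlamaForTokenClassification with toy dimensions.
import torch
from configuration_llama import LlamaConfig          # adjust to how the repo is packaged
from modeling_llama import LlamaForTokenClassification

config = LlamaConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=64,
    num_labels=5,
)
model = LlamaForTokenClassification(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 10))
with torch.no_grad():
    outputs = model(input_ids)
print(outputs.logits.shape)  # torch.Size([1, 10, 5]) - one label score per token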