damerajee committed on
Commit
1cc8a08
1 Parent(s): 5084b18

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +491 -341
modeling_Llamoe.py CHANGED
@@ -62,9 +62,11 @@ def load_balancing_loss_func(
62
  ) -> float:
63
  r"""
64
  Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
 
65
  See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
66
  function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
67
  experts is too unbalanced.
 
68
  Args:
69
  gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
70
  Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
@@ -74,6 +76,7 @@ def load_balancing_loss_func(
74
  shape [batch_size X sequence_length] if not None.
75
  num_experts (`int`, *optional*):
76
  Number of experts
 
77
  Returns:
78
  The auxiliary loss.
79
  """
@@ -130,15 +133,12 @@ def load_balancing_loss_func(
130
  return overall_loss * num_experts
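For reference, a minimal standalone sketch of the auxiliary loss described in the docstring above (equations (4)-(6) of the Switch Transformer paper), assuming top-1 routing and no attention mask; this is an illustration, not the file's implementation, which also handles top-k routing and padding:

import torch
import torch.nn.functional as F

def switch_aux_loss_sketch(gate_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    # gate_logits: (num_tokens, num_experts) router logits for a single layer
    routing_weights = F.softmax(gate_logits, dim=-1)                                  # router probabilities
    expert_index = routing_weights.argmax(dim=-1)                                     # top-1 expert per token
    tokens_per_expert = F.one_hot(expert_index, num_experts).float().mean(dim=0)      # fraction of tokens per expert (f_i)
    router_prob_per_expert = routing_weights.mean(dim=0)                              # mean router probability per expert (P_i)
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)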
131
 
132
 
133
-
134
- def approx_gelu(x):
135
- return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
136
-
137
  def _get_unpad_data(attention_mask):
138
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
139
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
140
  max_seqlen_in_batch = seqlens_in_batch.max().item()
141
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
142
  return (
143
  indices,
144
  cu_seqlens,
@@ -146,53 +146,60 @@ def _get_unpad_data(attention_mask):
146
  )
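To illustrate what `_get_unpad_data` computes for a padded batch (toy mask, not part of the commit): `cu_seqlens` is the cumulative sequence-length vector that Flash Attention's variable-length kernels expect.

import torch
import torch.nn.functional as F

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])                             # two sequences, right-padded
seqlens = mask.sum(dim=-1, dtype=torch.int32)                                  # tensor([3, 2])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()              # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))    # tensor([0, 3, 5])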
147
 
148
 
149
-
150
  class LlamoeRMSNorm(nn.Module):
151
- def __init__(self, dim: int, eps: float = 1e-6):
 
 
 
152
  super().__init__()
153
- self.eps = eps
154
- self.weight = nn.Parameter(torch.zeros(dim))
155
-
156
- def _norm(self, x):
157
- x_float = x.float()
158
- normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
159
- return normed_x
160
 
161
- def forward(self, x):
162
- normed_x = self._norm(x)
163
- # Downcast the result to the original dtype at the end
164
- normed_x = normed_x.type_as(x)
165
- return normed_x * (self.weight + 1)
 
166
 
167
- ALL_LAYERNORM_LAYERS.append(LlamoeRMSNorm)
168
 
 
169
  class LlamoeRotaryEmbedding(nn.Module):
170
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
171
  super().__init__()
 
172
  self.dim = dim
173
  self.max_position_embeddings = max_position_embeddings
174
  self.base = base
175
- self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
176
 
177
  def _set_cos_sin_cache(self, seq_len, device, dtype):
178
  self.max_seq_len_cached = seq_len
179
- freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
180
- timescale = self.base ** freq_exponents
181
- positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
182
- radians_new = positions[..., None] / timescale[None, None, :]
183
- radians_new = radians_new.squeeze(0)
184
- emb = torch.cat((radians_new, radians_new), dim=-1)
185
- cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
186
- sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
187
- self.register_buffer("cos_cached", cos, persistent=False)
188
- self.register_buffer("sin_cached", sin, persistent=False)
189
-
190
- def forward(self, x, position_ids=None, seq_len=None):
191
- if seq_len is None:
192
- seq_len = x.size(2)
193
  if seq_len > self.max_seq_len_cached:
194
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
195
- return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
196
 
197
  # Copied from transformers.models.llama.modeling_llama.rotate_half
198
  def rotate_half(x):
@@ -202,15 +209,35 @@ def rotate_half(x):
202
  return torch.cat((-x2, x1), dim=-1)
203
 
204
 
205
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
206
- seq_len, dim = q.shape[-2], q.shape[-1]
207
- cos = cos[:seq_len].view(1, 1, seq_len, dim)
208
- sin = sin[:seq_len].view(1, 1, seq_len, dim)
209
  q_embed = (q * cos) + (rotate_half(q) * sin)
210
  k_embed = (k * cos) + (rotate_half(k) * sin)
211
  return q_embed, k_embed
212
 
213
-
214
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
215
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
216
  """
@@ -223,11 +250,15 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
223
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
224
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
225
 
 
 
226
  class LlamoeAttention(nn.Module):
227
- """Multi-headed attention from 'Attention Is All You Need' paper"""
 
 
 
228
 
229
- # Ignore copy
230
- def __init__(self, config: LlamoeConfig, layer_idx: Optional[int] = None):
231
  super().__init__()
232
  self.config = config
233
  self.layer_idx = layer_idx
@@ -238,32 +269,35 @@ class LlamoeAttention(nn.Module):
238
  "when creating this class."
239
  )
240
 
241
- self.attention_dropout = config.attention_dropout
242
  self.hidden_size = config.hidden_size
243
  self.num_heads = config.num_attention_heads
244
- self.head_dim = config.head_dim
245
  self.num_key_value_heads = config.num_key_value_heads
246
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
247
  self.max_position_embeddings = config.max_position_embeddings
248
  self.rope_theta = config.rope_theta
249
  self.is_causal = True
 
250
 
251
- if self.hidden_size % self.num_heads != 0:
252
  raise ValueError(
253
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
254
  f" and `num_heads`: {self.num_heads})."
255
  )
 
 
 
 
256
 
257
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
258
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
259
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
260
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
261
- self.rotary_emb = LlamoeRotaryEmbedding(
262
  self.head_dim,
263
  max_position_embeddings=self.max_position_embeddings,
264
  base=self.rope_theta,
265
  )
266
 
 
 
 
267
  def forward(
268
  self,
269
  hidden_states: torch.Tensor,
@@ -272,9 +306,12 @@ class LlamoeAttention(nn.Module):
272
  past_key_value: Optional[Cache] = None,
273
  output_attentions: bool = False,
274
  use_cache: bool = False,
275
- cache_position: Optional[torch.LongTensor] = None,
276
  **kwargs,
277
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
278
  bsz, q_len, _ = hidden_states.size()
279
 
280
  query_states = self.q_proj(hidden_states)
@@ -285,26 +322,41 @@ class LlamoeAttention(nn.Module):
285
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
286
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
287
 
288
- past_key_value = getattr(self, "past_key_value", past_key_value)
289
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
290
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
291
 
292
  if past_key_value is not None:
293
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
294
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
295
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
296
 
 
297
  key_states = repeat_kv(key_states, self.num_key_value_groups)
298
  value_states = repeat_kv(value_states, self.num_key_value_groups)
299
 
300
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
301
 
302
- if attention_mask is not None: # no matter the length, we just slice it
303
- if cache_position is not None:
304
- causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
305
- else:
306
- causal_mask = attention_mask
307
- attn_weights = attn_weights + causal_mask
308
 
309
  # upcast attention to fp32
310
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -318,8 +370,8 @@ class LlamoeAttention(nn.Module):
318
  )
319
 
320
  attn_output = attn_output.transpose(1, 2).contiguous()
 
321
 
322
- attn_output = attn_output.view(bsz, q_len, -1)
323
  attn_output = self.o_proj(attn_output)
324
 
325
  if not output_attentions:
@@ -328,14 +380,15 @@ class LlamoeAttention(nn.Module):
328
  return attn_output, attn_weights, past_key_value
329
 
330
 
331
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemmoe
332
  class LlamoeFlashAttention2(LlamoeAttention):
333
  """
334
- Llamoe flash attention module. This module inherits from `LlamoeAttention` as the weights of the module stays
335
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
336
  flash attention and deal with padding tokens in case the input contains any of them.
337
  """
338
 
 
339
  def __init__(self, *args, **kwargs):
340
  super().__init__(*args, **kwargs)
341
 
@@ -344,57 +397,98 @@ class LlamoeFlashAttention2(LlamoeAttention):
344
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
345
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
346
 
347
- # Ignore copy
348
  def forward(
349
  self,
350
  hidden_states: torch.Tensor,
351
- attention_mask: Optional[torch.LongTensor] = None,
352
  position_ids: Optional[torch.LongTensor] = None,
353
  past_key_value: Optional[Cache] = None,
354
  output_attentions: bool = False,
355
  use_cache: bool = False,
356
- cache_position: Optional[torch.LongTensor] = None,
357
  **kwargs,
358
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
359
- output_attentions = False
 
 
 
360
 
 
 
361
  bsz, q_len, _ = hidden_states.size()
362
 
363
  query_states = self.q_proj(hidden_states)
364
  key_states = self.k_proj(hidden_states)
365
  value_states = self.v_proj(hidden_states)
366
 
367
- # Flash attention requires the input to have the shape
368
- # batch_size x seq_length x head_dim x hidden_dim
369
- # therefore we just need to keep the original shape
370
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
371
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
372
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
373
 
374
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
375
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
376
 
377
- past_key_value = getattr(self, "past_key_value", past_key_value)
378
 
379
  if past_key_value is not None:
380
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
381
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
382
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
 
383
 
384
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
385
- # to be able to avoid many of these transpose/reshape/view.
386
- query_states = query_states.transpose(1, 2)
387
- key_states = key_states.transpose(1, 2)
388
- value_states = value_states.transpose(1, 2)
389
 
390
- dropout_rate = self.attention_dropout if self.training else 0.0
391
 
392
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
393
  # therefore the input hidden states gets silently casted in float32. Hence, we need
394
- # cast them back in the correct dtype just to be sure everything works as expected.
395
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
396
- # in fp32. (GemmoeRMSNorm handles it correctly)
397
-
398
  input_dtype = query_states.dtype
399
  if input_dtype == torch.float32:
400
  if torch.is_autocast_enabled():
@@ -415,11 +509,22 @@ class LlamoeFlashAttention2(LlamoeAttention):
415
  key_states = key_states.to(target_dtype)
416
  value_states = value_states.to(target_dtype)
417
 
 
 
 
 
 
418
  attn_output = self._flash_attention_forward(
419
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
 
 
 
 
 
 
420
  )
421
 
422
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
423
  attn_output = self.o_proj(attn_output)
424
 
425
  if not output_attentions:
@@ -428,11 +533,20 @@ class LlamoeFlashAttention2(LlamoeAttention):
428
  return attn_output, attn_weights, past_key_value
429
 
430
  def _flash_attention_forward(
431
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
432
  ):
433
  """
434
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
435
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
436
  Args:
437
  query_states (`torch.Tensor`):
438
  Input query states to be passed to Flash Attention API
@@ -447,11 +561,13 @@ class LlamoeFlashAttention2(LlamoeAttention):
447
  Attention dropout
448
  softmax_scale (`float`, *optional*):
449
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
 
 
450
  """
451
  if not self._flash_attn_uses_top_left_mask:
452
  causal = self.is_causal
453
  else:
454
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmoeFlashAttention2 __init__.
455
  causal = self.is_causal and query_length != 1
456
 
457
  # Contains at least one padding token in the sequence
@@ -464,40 +580,75 @@ class LlamoeFlashAttention2(LlamoeAttention):
464
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
465
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
466
 
467
- attn_output_unpad = flash_attn_varlen_func(
468
- query_states,
469
- key_states,
470
- value_states,
471
- cu_seqlens_q=cu_seqlens_q,
472
- cu_seqlens_k=cu_seqlens_k,
473
- max_seqlen_q=max_seqlen_in_batch_q,
474
- max_seqlen_k=max_seqlen_in_batch_k,
475
- dropout_p=dropout,
476
- softmax_scale=softmax_scale,
477
- causal=causal,
478
- )
479
 
480
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
481
  else:
482
- attn_output = flash_attn_func(
483
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
484
- )
485
 
486
  return attn_output
487
 
488
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
489
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
490
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
491
 
492
- key_layer = index_first_axis(
493
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
494
- )
495
- value_layer = index_first_axis(
496
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
497
- )
498
  if query_length == kv_seq_len:
499
  query_layer = index_first_axis(
500
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
501
  )
502
  cu_seqlens_q = cu_seqlens_k
503
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -524,15 +675,15 @@ class LlamoeFlashAttention2(LlamoeAttention):
524
  )
525
 
526
 
527
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
528
  class LlamoeSdpaAttention(LlamoeAttention):
529
  """
530
- Llamoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
531
- `LlamoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
532
  SDPA API.
533
  """
534
 
535
- # Ignore copy
536
  def forward(
537
  self,
538
  hidden_states: torch.Tensor,
@@ -541,12 +692,11 @@ class LlamoeSdpaAttention(LlamoeAttention):
541
  past_key_value: Optional[Cache] = None,
542
  output_attentions: bool = False,
543
  use_cache: bool = False,
544
- cache_position: Optional[torch.LongTensor] = None,
545
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
546
  if output_attentions:
547
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
548
  logger.warning_once(
549
- "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
550
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
551
  )
552
  return super().forward(
@@ -556,41 +706,41 @@ class LlamoeSdpaAttention(LlamoeAttention):
556
  past_key_value=past_key_value,
557
  output_attentions=output_attentions,
558
  use_cache=use_cache,
559
- cache_position=cache_position,
560
  )
561
 
562
  bsz, q_len, _ = hidden_states.size()
563
 
564
-
565
  query_states = self.q_proj(hidden_states)
566
  key_states = self.k_proj(hidden_states)
567
  value_states = self.v_proj(hidden_states)
568
-
569
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
570
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
571
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
572
 
573
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
574
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
575
-
576
- past_key_value = getattr(self, "past_key_value", past_key_value)
577
-
578
  if past_key_value is not None:
579
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
580
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 
 
 
 
 
581
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
582
 
583
  key_states = repeat_kv(key_states, self.num_key_value_groups)
584
  value_states = repeat_kv(value_states, self.num_key_value_groups)
585
 
586
-
587
- causal_mask = attention_mask
588
- if attention_mask is not None and cache_position is not None:
589
- causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
 
590
 
591
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
592
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
593
- if query_states.device.type == "cuda" and causal_mask is not None:
594
  query_states = query_states.contiguous()
595
  key_states = key_states.contiguous()
596
  value_states = value_states.contiguous()
@@ -599,88 +749,129 @@ class LlamoeSdpaAttention(LlamoeAttention):
599
  query_states,
600
  key_states,
601
  value_states,
602
- attn_mask=causal_mask,
603
  dropout_p=self.attention_dropout if self.training else 0.0,
 
 
604
  )
605
 
606
  attn_output = attn_output.transpose(1, 2).contiguous()
607
- attn_output = attn_output.view(bsz, q_len, -1)
608
 
609
  attn_output = self.o_proj(attn_output)
610
 
611
  return attn_output, None, past_key_value
612
 
613
 
614
- LLAMOE_ATTENTION_CLASSES = {
615
  "eager": LlamoeAttention,
616
  "flash_attention_2": LlamoeFlashAttention2,
617
  "sdpa": LlamoeSdpaAttention,
618
  }
619
 
620
- class LlamoeBlockSparseTop2MLP(nn.Module):
 
621
  def __init__(self, config: LlamoeConfig):
622
  super().__init__()
623
  self.ffn_dim = config.intermediate_size
624
  self.hidden_dim = config.hidden_size
625
-
626
  self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
627
  self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
628
  self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
629
 
630
- self.act_fn = approx_gelu
631
 
632
  def forward(self, hidden_states):
633
  current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
634
  current_hidden_states = self.w2(current_hidden_states)
635
- return current_hidden_states.to(hidden_states.dtype)
636
 
637
 
638
  class LlamoeSparseMoeBlock(nn.Module):
639
  def __init__(self, config):
640
  super().__init__()
641
  self.hidden_dim = config.hidden_size
642
  self.ffn_dim = config.intermediate_size
643
  self.num_experts = config.num_local_experts
644
- self.top_k = 2
645
 
646
  # gating
647
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
648
 
649
  self.experts = nn.ModuleList([LlamoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
650
 
651
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
652
  batch_size, sequence_length, hidden_dim = hidden_states.shape
653
  hidden_states = hidden_states.view(-1, hidden_dim)
654
-
655
  # router_logits: (batch * sequence_length, n_experts)
656
  router_logits = self.gate(hidden_states)
657
- routing_weights = F.softmax(router_logits, dim=1)
658
- topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
659
- topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
660
 
661
- hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
662
 
663
- y = torch.empty_like(hidden_states)
 
 
664
 
665
- flat_topk_idx = topk_idx.view(-1)
666
- for i in range(self.num_experts):
667
- expert = self.experts[i]
668
- expert_output = expert(hidden_states[flat_topk_idx == i])
669
- y[flat_topk_idx == i] = expert_output
670
 
671
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
672
 
673
- final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
674
- return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
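The block above implements top-2 routing: softmax over the router logits, pick the two highest-scoring experts per token, renormalize their weights, and sum the two weighted expert outputs. A compact standalone sketch of the same idea (toy sizes and hypothetical names, not the file's dispatch code):

import torch
import torch.nn.functional as F

num_tokens, hidden_dim, num_experts, top_k = 4, 8, 4, 2
hidden = torch.randn(num_tokens, hidden_dim)
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=-1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)        # renormalize over the two picks

experts = [torch.nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts)]
output = torch.zeros_like(hidden)
for e in range(num_experts):
    token_pos, slot = torch.where(topk_idx == e)                         # tokens routed to expert e
    if token_pos.numel() > 0:
        output[token_pos] += topk_weight[token_pos, slot].unsqueeze(-1) * experts[e](hidden[token_pos])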
675
 
676
-
677
- # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
678
  class LlamoeDecoderLayer(nn.Module):
679
  def __init__(self, config: LlamoeConfig, layer_idx: int):
680
  super().__init__()
681
  self.hidden_size = config.hidden_size
682
 
683
- self.self_attn = LLAMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
684
 
685
  self.block_sparse_moe = LlamoeSparseMoeBlock(config)
686
  self.input_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -757,11 +948,13 @@ Llamoe_START_DOCSTRING = r"""
757
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
758
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
759
  etc.)
 
760
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
761
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
762
  and behavior.
 
763
  Parameters:
764
- config ([`LlamoeConfig`]):
765
  Model configuration class with all the parameters of the model. Initializing with a config file does not
766
  load the weights associated with the model, only the configuration. Check out the
767
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -769,17 +962,16 @@ Llamoe_START_DOCSTRING = r"""
769
 
770
 
771
  @add_start_docstrings(
772
- "The bare Llamoe Model outputting raw hidden-states without any specific head on top.",
773
  Llamoe_START_DOCSTRING,
774
  )
775
-
776
  class LlamoePreTrainedModel(PreTrainedModel):
777
  config_class = LlamoeConfig
778
  base_model_prefix = "model"
779
  supports_gradient_checkpointing = True
780
- _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
781
  _no_split_modules = ["LlamoeDecoderLayer"]
782
- _skip_keys_device_placement = ["past_key_values", "causal_mask"]
783
  _supports_flash_attn_2 = True
784
  _supports_sdpa = True
785
  _supports_cache_class = True
@@ -795,68 +987,53 @@ class LlamoePreTrainedModel(PreTrainedModel):
795
  if module.padding_idx is not None:
796
  module.weight.data[module.padding_idx].zero_()
797
 
798
- def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
799
- if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
800
- raise ValueError(
801
- "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
802
- "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
803
- )
804
-
805
- if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
806
- causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
807
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
808
-
809
- for layer in self.model.layers:
810
- weights = layer.self_attn.o_proj.weight
811
- layer.self_attn.past_key_value = cache_cls(
812
- self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
813
- )
814
-
815
- def _reset_cache(self):
816
- for layer in self.model.layers:
817
- layer.self_attn.past_key_value = None
818
 
819
-
820
- LLAMOE_INPUTS_DOCSTRING = r"""
821
  Args:
822
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
823
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
824
  it.
 
825
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
826
  [`PreTrainedTokenizer.__call__`] for details.
 
827
  [What are input IDs?](../glossary#input-ids)
828
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
829
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
830
  - 1 for tokens that are **not masked**,
831
  - 0 for tokens that are **masked**.
 
832
  [What are attention masks?](../glossary#attention-mask)
 
833
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
834
  [`PreTrainedTokenizer.__call__`] for details.
835
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
 
836
  `past_key_values`).
 
837
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
838
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
839
  information on the default strategy.
 
840
  - 1 indicates the head is **not masked**,
841
  - 0 indicates the head is **masked**.
842
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
843
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
844
  config.n_positions - 1]`.
 
845
  [What are position IDs?](../glossary#position-ids)
846
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
847
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
848
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
849
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
850
- Two formats are allowed:
851
- - a [`~cache_utils.Cache`] instance;
852
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
853
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
854
- cache format.
855
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
856
- legacy cache format will be returned.
857
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
858
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
859
- of shape `(batch_size, sequence_length)`.
860
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
861
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
862
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -870,28 +1047,28 @@ LLAMOE_INPUTS_DOCSTRING = r"""
870
  output_hidden_states (`bool`, *optional*):
871
  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
872
  more detail.
 
 
 
873
  return_dict (`bool`, *optional*):
874
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
875
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
876
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
877
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
878
- the complete sequence length.
879
  """
880
 
881
 
882
  @add_start_docstrings(
883
- "The bare Llamoe Model outputting raw hidden-states without any specific head on top.",
884
  Llamoe_START_DOCSTRING,
885
  )
886
-
887
- class LlamoeModel(LlamoePreTrainedModel):
888
  """
889
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamoeDecoderLayer`]
 
890
  Args:
891
- config: LlamoeConfig
892
  """
893
 
894
- def __init__(self, config: LlamoeConfig):
895
  super().__init__(config)
896
  self.padding_idx = config.pad_token_id
897
  self.vocab_size = config.vocab_size
@@ -900,15 +1077,10 @@ class LlamoeModel(LlamoePreTrainedModel):
900
  self.layers = nn.ModuleList(
901
  [LlamoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
902
  )
 
903
  self.norm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
904
- self.gradient_checkpointing = False
905
 
906
- # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
907
- # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
908
- causal_mask = torch.full(
909
- (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
910
- )
911
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
912
  # Initialize weights and apply final processing
913
  self.post_init()
914
 
@@ -918,7 +1090,8 @@ class LlamoeModel(LlamoePreTrainedModel):
918
  def set_input_embeddings(self, value):
919
  self.embed_tokens = value
920
 
921
- @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
 
922
  def forward(
923
  self,
924
  input_ids: torch.LongTensor = None,
@@ -931,89 +1104,118 @@ class LlamoeModel(LlamoePreTrainedModel):
931
  output_hidden_states: Optional[bool] = None,
932
  output_router_logits: Optional[bool] = None,
933
  return_dict: Optional[bool] = None,
934
- cache_position: Optional[torch.LongTensor] = None,
935
  ) -> Union[Tuple, MoeModelOutputWithPast]:
936
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
937
  output_hidden_states = (
938
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
939
  )
940
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
941
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
942
 
943
- if (input_ids is None) ^ (inputs_embeds is not None):
944
- raise ValueError(
945
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
946
- )
 
 
 
 
 
947
 
948
- if self.gradient_checkpointing and self.training and use_cache:
949
- logger.warning_once(
950
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
951
- )
952
- use_cache = False
953
 
954
- if inputs_embeds is None:
955
- inputs_embeds = self.embed_tokens(input_ids)
 
 
 
 
956
 
957
- # Scale embeddings
958
- # Fix for precision issue when casting to bfloat16
959
- hidden_size_sqrt = math.sqrt(self.config.hidden_size)
960
- if inputs_embeds.dtype == torch.bfloat16:
961
- pass
962
-
963
- hidden_states = inputs_embeds * hidden_size_sqrt
964
-
965
- past_seen_tokens = 0
966
- if use_cache: # kept for BC (cache positions)
967
- if not isinstance(past_key_values, StaticCache):
968
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
969
- past_seen_tokens = past_key_values.get_seq_length()
970
 
971
- if cache_position is None:
972
- cache_position = torch.arange(
973
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
974
  )
 
 
 
975
 
976
- if position_ids is None:
977
- position_ids = cache_position.unsqueeze(0)
978
 
979
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
 
 
 
 
 
 
 
980
 
981
- # embed positions
982
- hidden_states = inputs_embeds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983
 
984
- # normalized
985
- hidden_states = hidden_states * (self.config.hidden_size**0.5)
986
 
987
  # decoder layers
988
  all_hidden_states = () if output_hidden_states else None
989
  all_self_attns = () if output_attentions else None
 
990
  next_decoder_cache = None
991
 
992
  for decoder_layer in self.layers:
993
  if output_hidden_states:
994
  all_hidden_states += (hidden_states,)
 
 
995
  layer_outputs = self._gradient_checkpointing_func(
996
  decoder_layer.__call__,
997
  hidden_states,
998
- causal_mask,
999
  position_ids,
1000
  past_key_values,
1001
  output_attentions,
1002
  output_router_logits,
1003
- use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
1004
- cache_position,
1005
- output_router_logits,
1006
  )
1007
  else:
1008
  layer_outputs = decoder_layer(
1009
  hidden_states,
1010
- attention_mask=causal_mask,
1011
  position_ids=position_ids,
1012
  past_key_value=past_key_values,
1013
  output_attentions=output_attentions,
1014
  output_router_logits=output_router_logits,
1015
- use_cache=use_cache.item() if isinstance(use_cache, torch.Tensor) else use_cache,
1016
- cache_position=cache_position,
1017
  )
1018
 
1019
  hidden_states = layer_outputs[0]
@@ -1024,6 +1226,9 @@ class LlamoeModel(LlamoePreTrainedModel):
1024
  if output_attentions:
1025
  all_self_attns += (layer_outputs[1],)
1026
 
 
 
 
1027
  hidden_states = self.norm(hidden_states)
1028
 
1029
  # add hidden states from the last decoder layer
@@ -1032,74 +1237,29 @@ class LlamoeModel(LlamoePreTrainedModel):
1032
 
1033
  next_cache = None
1034
  if use_cache:
1035
- next_cache = (
1036
- next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1037
- )
1038
  if not return_dict:
1039
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
1040
  return MoeModelOutputWithPast(
1041
  last_hidden_state=hidden_states,
1042
  past_key_values=next_cache,
1043
  hidden_states=all_hidden_states,
1044
  attentions=all_self_attns,
 
1045
  )
1046
 
1047
- def _update_causal_mask(self, attention_mask, input_tensor):
1048
- if self.config._attn_implementation == "flash_attention_2":
1049
- if attention_mask is not None and 0.0 in attention_mask:
1050
- return attention_mask
1051
- return None
1052
-
1053
- batch_size, seq_length = input_tensor.shape[:2]
1054
- dtype = input_tensor.dtype
1055
- device = input_tensor.device
1056
-
1057
- # support going beyond cached `max_position_embedding`
1058
- if seq_length > self.causal_mask.shape[-1]:
1059
- causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1060
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1061
-
1062
- # We use the current dtype to avoid any overflows
1063
- min_dtype = torch.finfo(dtype).min
1064
-
1065
- causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
1066
- causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
1067
- if attention_mask is not None:
1068
- causal_mask = causal_mask.clone()
1069
- if attention_mask.dim() == 2:
1070
- mask_length = attention_mask.shape[-1]
1071
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1072
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1073
- elif attention_mask.dim() == 4:
1074
- mask_shape = attention_mask.shape
1075
- mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1076
- causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice
1077
-
1078
- if (
1079
- self.config._attn_implementation == "sdpa"
1080
- and attention_mask is not None
1081
- and attention_mask.device.type == "cuda"
1082
- ):
1083
- # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1084
- is_tracing = (
1085
- torch.jit.is_tracing()
1086
- or isinstance(input_tensor, torch.fx.Proxy)
1087
- or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1088
- )
1089
- if not is_tracing and torch.any(attention_mask != 1):
1090
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1091
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1092
- # Details: https://github.com/pytorch/pytorch/issues/110213
1093
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1094
-
1095
- return causal_mask
1096
 
1097
  class LlamoeForCausalLM(LlamoePreTrainedModel):
1098
  _tied_weights_keys = ["lm_head.weight"]
1099
 
1100
  def __init__(self, config):
1101
  super().__init__(config)
1102
- self.model = LlamoeModel(config)
1103
  self.vocab_size = config.vocab_size
1104
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1105
  self.router_aux_loss_coef = config.router_aux_loss_coef
@@ -1126,7 +1286,7 @@ class LlamoeForCausalLM(LlamoePreTrainedModel):
1126
  def get_decoder(self):
1127
  return self.model
1128
 
1129
- @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
1130
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1131
  # Ignore copy
1132
  def forward(
@@ -1149,14 +1309,20 @@ class LlamoeForCausalLM(LlamoePreTrainedModel):
1149
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1150
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1151
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
1152
  Returns:
 
1153
  Example:
 
1154
  ```python
1155
- >>> from transformers import AutoTokenizer, LlamoeForCausalLM
1156
- >>> model = LlamoeForCausalLM.from_pretrained("mistralai/Llamoe-8x7B-v0.1")
1157
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Llamoe-8x7B-v0.1")
 
 
1158
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1159
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1160
  >>> # Generate
1161
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1162
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1190,12 +1356,6 @@ class LlamoeForCausalLM(LlamoePreTrainedModel):
1190
  hidden_states = outputs[0]
1191
  logits = self.lm_head(hidden_states)
1192
  logits = logits.float()
1193
-
1194
- if self.training:
1195
- for expert in self.model.layers[-1].block_sparse_moe.experts:
1196
- for param in expert.parameters():
1197
- if param.requires_grad and param.grad is None:
1198
- param.grad = torch.zeros_like(param)
1199
 
1200
  loss = None
1201
  if labels is not None:
@@ -1299,14 +1459,4 @@ class LlamoeForCausalLM(LlamoePreTrainedModel):
1299
  "output_router_logits": output_router_logits,
1300
  }
1301
  )
1302
- return model_inputs
1303
-
1304
- @staticmethod
1305
- def _reorder_cache(past_key_values, beam_idx):
1306
- reordered_past = ()
1307
- for layer_past in past_key_values:
1308
- reordered_past += (
1309
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1310
- )
1311
- return reordered_past
1312
-
 
62
  ) -> float:
63
  r"""
64
  Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
65
+
66
  See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
67
  function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
68
  experts is too unbalanced.
69
+
70
  Args:
71
  gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
72
  Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
 
76
  shape [batch_size X sequence_length] if not None.
77
  num_experts (`int`, *optional*):
78
  Number of experts
79
+
80
  Returns:
81
  The auxiliary loss.
82
  """
 
133
  return overall_loss * num_experts
134
 
135
 
136
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
 
 
137
  def _get_unpad_data(attention_mask):
138
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
139
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
140
  max_seqlen_in_batch = seqlens_in_batch.max().item()
141
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
142
  return (
143
  indices,
144
  cu_seqlens,
 
146
  )
147
 
148
 
149
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
150
  class LlamoeRMSNorm(nn.Module):
151
+ def __init__(self, hidden_size, eps=1e-6):
152
+ """
153
+ LlamoeRMSNorm is equivalent to T5LayerNorm
154
+ """
155
  super().__init__()
156
+ self.weight = nn.Parameter(torch.ones(hidden_size))
157
+ self.variance_epsilon = eps
158
 
159
+ def forward(self, hidden_states):
160
+ input_dtype = hidden_states.dtype
161
+ hidden_states = hidden_states.to(torch.float32)
162
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
163
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
164
+ return self.weight * hidden_states.to(input_dtype)
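A brief usage check of the normalization above (toy sizes are assumptions): the variance is computed in float32 and the result is cast back to the input dtype.

norm = LlamoeRMSNorm(hidden_size=64, eps=1e-6).to(torch.bfloat16)
x = torch.randn(2, 5, 64, dtype=torch.bfloat16)
y = norm(x)
assert y.shape == x.shape and y.dtype == torch.bfloat16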
165
 
 
166
 
167
+ # Copied from transformers.models.mistral.modeling_mistral.LlamoeRotaryEmbedding with Mistral->Mixtral
168
  class LlamoeRotaryEmbedding(nn.Module):
169
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
170
  super().__init__()
171
+
172
  self.dim = dim
173
  self.max_position_embeddings = max_position_embeddings
174
  self.base = base
175
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
176
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
177
+
178
+ # Build here to make `torch.jit.trace` work.
179
+ self._set_cos_sin_cache(
180
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
181
+ )
182
 
183
  def _set_cos_sin_cache(self, seq_len, device, dtype):
184
  self.max_seq_len_cached = seq_len
185
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
186
+
187
+ freqs = torch.outer(t, self.inv_freq)
188
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
189
+ emb = torch.cat((freqs, freqs), dim=-1)
190
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
191
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
192
+
193
+ def forward(self, x, seq_len=None):
194
+ # x: [bs, num_attention_heads, seq_len, head_size]
 
 
 
 
195
  if seq_len > self.max_seq_len_cached:
196
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
197
+
198
+ return (
199
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
200
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
201
+ )
202
+
203
 
204
  # Copied from transformers.models.llama.modeling_llama.rotate_half
205
  def rotate_half(x):
 
209
  return torch.cat((-x2, x1), dim=-1)
210
 
211
 
212
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
213
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
214
+ """Applies Rotary Position Embedding to the query and key tensors.
215
+
216
+ Args:
217
+ q (`torch.Tensor`): The query tensor.
218
+ k (`torch.Tensor`): The key tensor.
219
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
220
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
221
+ position_ids (`torch.Tensor`):
222
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
223
+ used to pass offsetted position ids when working with a KV-cache.
224
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
225
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
226
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
227
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
228
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
229
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
230
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
231
+ Returns:
232
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
233
+ """
234
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
235
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
236
  q_embed = (q * cos) + (rotate_half(q) * sin)
237
  k_embed = (k * cos) + (rotate_half(k) * sin)
238
  return q_embed, k_embed
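A shape-level sketch of how the rotary embedding and `apply_rotary_pos_emb` fit together (toy sizes are assumptions, not the model's configuration):

bsz, n_heads, n_kv_heads, seq_len, head_dim = 1, 8, 2, 6, 16
rotary = LlamoeRotaryEmbedding(head_dim, max_position_embeddings=32)
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)            # (bsz, seq_len)
cos, sin = rotary(q, seq_len=seq_len)                        # each (seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape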
239
 
240
+
241
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
242
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
243
  """
 
250
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
251
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
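And a small shape check for `repeat_kv`, which expands each key/value head to serve `n_rep` query heads in grouped-query attention (toy sizes):

kv = torch.randn(2, 2, 5, 16)                  # (batch, num_key_value_heads, seq_len, head_dim)
assert repeat_kv(kv, n_rep=4).shape == (2, 8, 5, 16)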
252
 
253
+
254
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
255
  class LlamoeAttention(nn.Module):
256
+ """
257
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
258
+ and "Generating Long Sequences with Sparse Transformers".
259
+ """
260
 
261
+ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
 
262
  super().__init__()
263
  self.config = config
264
  self.layer_idx = layer_idx
 
269
  "when creating this class."
270
  )
271
 
 
272
  self.hidden_size = config.hidden_size
273
  self.num_heads = config.num_attention_heads
274
+ self.head_dim = self.hidden_size // self.num_heads
275
  self.num_key_value_heads = config.num_key_value_heads
276
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
277
  self.max_position_embeddings = config.max_position_embeddings
278
  self.rope_theta = config.rope_theta
279
  self.is_causal = True
280
+ self.attention_dropout = config.attention_dropout
281
 
282
+ if (self.head_dim * self.num_heads) != self.hidden_size:
283
  raise ValueError(
284
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
285
  f" and `num_heads`: {self.num_heads})."
286
  )
287
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
288
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
289
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
290
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
291
 
292
+ self.rotary_emb = MixtralRotaryEmbedding(
 
 
 
 
293
  self.head_dim,
294
  max_position_embeddings=self.max_position_embeddings,
295
  base=self.rope_theta,
296
  )
297
 
298
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
299
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
300
+
301
  def forward(
302
  self,
303
  hidden_states: torch.Tensor,
 
306
  past_key_value: Optional[Cache] = None,
307
  output_attentions: bool = False,
308
  use_cache: bool = False,
 
309
  **kwargs,
310
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
+ if "padding_mask" in kwargs:
312
+ warnings.warn(
313
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
314
+ )
315
  bsz, q_len, _ = hidden_states.size()
316
 
317
  query_states = self.q_proj(hidden_states)
 
322
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
323
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
324
 
325
+ kv_seq_len = key_states.shape[-2]
326
+ if past_key_value is not None:
327
+ if self.layer_idx is None:
328
+ raise ValueError(
329
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
330
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
331
+ "with a layer index."
332
+ )
333
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
334
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
335
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
336
 
337
  if past_key_value is not None:
338
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
339
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
340
 
341
+ # repeat k/v heads if n_kv_heads < n_heads
342
  key_states = repeat_kv(key_states, self.num_key_value_groups)
343
  value_states = repeat_kv(value_states, self.num_key_value_groups)
344
 
345
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
346
 
347
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
348
+ raise ValueError(
349
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
350
+ f" {attn_weights.size()}"
351
+ )
352
+
353
+ if attention_mask is not None:
354
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
355
+ raise ValueError(
356
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
357
+ )
358
+
359
+ attn_weights = attn_weights + attention_mask
360
 
361
  # upcast attention to fp32
362
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
370
  )
371
 
372
  attn_output = attn_output.transpose(1, 2).contiguous()
373
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
374
 
 
375
  attn_output = self.o_proj(attn_output)
376
 
377
  if not output_attentions:
 
380
  return attn_output, attn_weights, past_key_value
381
 
382
 
383
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
384
  class LlamoeFlashAttention2(LlamoeAttention):
385
  """
386
+ Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
387
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
388
  flash attention and deal with padding tokens in case the input contains any of them.
389
  """
390
 
391
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
392
  def __init__(self, *args, **kwargs):
393
  super().__init__(*args, **kwargs)
394
 
 
397
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
398
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
399
 
 
400
  def forward(
401
  self,
402
  hidden_states: torch.Tensor,
403
+ attention_mask: Optional[torch.Tensor] = None,
404
  position_ids: Optional[torch.LongTensor] = None,
405
  past_key_value: Optional[Cache] = None,
406
  output_attentions: bool = False,
407
  use_cache: bool = False,
 
408
  **kwargs,
409
+ ):
410
+ if "padding_mask" in kwargs:
411
+ warnings.warn(
412
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
413
+ )
414
 
415
+ # overwrite attention_mask with padding_mask
416
+ attention_mask = kwargs.pop("padding_mask")
417
  bsz, q_len, _ = hidden_states.size()
418
 
419
  query_states = self.q_proj(hidden_states)
420
  key_states = self.k_proj(hidden_states)
421
  value_states = self.v_proj(hidden_states)
422
 
 
 
 
423
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
424
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
425
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
426
 
427
+ kv_seq_len = key_states.shape[-2]
428
+ if past_key_value is not None:
429
+ if self.layer_idx is None:
430
+ raise ValueError(
431
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
432
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
433
+ "with a layer index."
434
+ )
435
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
436
 
437
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
438
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
439
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
440
+
441
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
442
+
443
+ use_sliding_windows = (
444
+ _flash_supports_window_size
445
+ and getattr(self.config, "sliding_window", None) is not None
446
+ and kv_seq_len > self.config.sliding_window
447
+ )
448
+
449
+ if not _flash_supports_window_size:
450
+ logger.warning_once(
451
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
452
+ " make sure to upgrade flash-attn library."
453
+ )
454
 
455
  if past_key_value is not None:
456
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
457
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
458
+ if (
459
+ getattr(self.config, "sliding_window", None) is not None
460
+ and kv_seq_len > self.config.sliding_window
461
+ and cache_has_contents
462
+ ):
463
+ slicing_tokens = 1 - self.config.sliding_window
464
 
465
+ past_key = past_key_value[self.layer_idx][0]
466
+ past_value = past_key_value[self.layer_idx][1]
467
+
468
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
469
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
470
+
471
+ if past_key.shape[-2] != self.config.sliding_window - 1:
472
+ raise ValueError(
473
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
474
+ f" {past_key.shape}"
475
+ )
476
+
477
+ if attention_mask is not None:
478
+ attention_mask = attention_mask[:, slicing_tokens:]
479
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
480
 
481
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
482
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
483
+
484
+ # repeat k/v heads if n_kv_heads < n_heads
485
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
486
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
487
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
488
 
489
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
490
  # therefore the input hidden states get silently cast to float32. Hence, we need
491
+ # to cast them back to float16 just to be sure everything works as expected.
 
 
 
492
  input_dtype = query_states.dtype
493
  if input_dtype == torch.float32:
494
  if torch.is_autocast_enabled():
 
509
  key_states = key_states.to(target_dtype)
510
  value_states = value_states.to(target_dtype)
511
 
512
+ # Reshape to the expected shape for Flash Attention
513
+ query_states = query_states.transpose(1, 2)
514
+ key_states = key_states.transpose(1, 2)
515
+ value_states = value_states.transpose(1, 2)
516
+
517
  attn_output = self._flash_attention_forward(
518
+ query_states,
519
+ key_states,
520
+ value_states,
521
+ attention_mask,
522
+ q_len,
523
+ dropout=dropout_rate,
524
+ use_sliding_windows=use_sliding_windows,
525
  )
526
 
527
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
528
  attn_output = self.o_proj(attn_output)
529
 
530
  if not output_attentions:
 
533
  return attn_output, attn_weights, past_key_value
534
 
535
  def _flash_attention_forward(
536
+ self,
537
+ query_states,
538
+ key_states,
539
+ value_states,
540
+ attention_mask,
541
+ query_length,
542
+ dropout=0.0,
543
+ softmax_scale=None,
544
+ use_sliding_windows=False,
545
  ):
546
  """
547
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
548
  first unpads the input, then computes the attention scores and pads the final attention scores.
549
+
550
  Args:
551
  query_states (`torch.Tensor`):
552
  Input query states to be passed to Flash Attention API
 
561
  Attention dropout
562
  softmax_scale (`float`, *optional*):
563
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
564
+ use_sliding_windows (`bool`, *optional*):
565
+ Whether to activate sliding window attention.
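+ 
+ As an illustration of the unpadding step (the mask values below are made up for the example), the 2D
+ attention mask is turned into the cumulative sequence lengths expected by the varlen flash-attention kernels:
+ 
+ ```python
+ >>> import torch
+ >>> import torch.nn.functional as F
+ 
+ >>> attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
+ >>> seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ >>> cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ >>> cu_seqlens
+ tensor([0, 3, 5], dtype=torch.int32)
+ ```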
566
  """
567
  if not self._flash_attn_uses_top_left_mask:
568
  causal = self.is_causal
569
  else:
570
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamoeFlashAttention2 __init__.
571
  causal = self.is_causal and query_length != 1
572
 
573
  # Contains at least one padding token in the sequence
 
580
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
581
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
582
 
583
+ if not use_sliding_windows:
584
+ attn_output_unpad = flash_attn_varlen_func(
585
+ query_states,
586
+ key_states,
587
+ value_states,
588
+ cu_seqlens_q=cu_seqlens_q,
589
+ cu_seqlens_k=cu_seqlens_k,
590
+ max_seqlen_q=max_seqlen_in_batch_q,
591
+ max_seqlen_k=max_seqlen_in_batch_k,
592
+ dropout_p=dropout,
593
+ softmax_scale=softmax_scale,
594
+ causal=causal,
595
+ )
596
+ else:
597
+ attn_output_unpad = flash_attn_varlen_func(
598
+ query_states,
599
+ key_states,
600
+ value_states,
601
+ cu_seqlens_q=cu_seqlens_q,
602
+ cu_seqlens_k=cu_seqlens_k,
603
+ max_seqlen_q=max_seqlen_in_batch_q,
604
+ max_seqlen_k=max_seqlen_in_batch_k,
605
+ dropout_p=dropout,
606
+ softmax_scale=softmax_scale,
607
+ causal=causal,
608
+ window_size=(self.config.sliding_window, self.config.sliding_window),
609
+ )
610
 
611
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
612
  else:
613
+ if not use_sliding_windows:
614
+ attn_output = flash_attn_func(
615
+ query_states,
616
+ key_states,
617
+ value_states,
618
+ dropout,
619
+ softmax_scale=softmax_scale,
620
+ causal=causal,
621
+ )
622
+ else:
623
+ attn_output = flash_attn_func(
624
+ query_states,
625
+ key_states,
626
+ value_states,
627
+ dropout,
628
+ softmax_scale=softmax_scale,
629
+ causal=causal,
630
+ window_size=(self.config.sliding_window, self.config.sliding_window),
631
+ )
632
 
633
  return attn_output
634
 
635
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
636
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
637
+
638
+ # On the first iteration we need to properly re-create the padding mask
639
+ # by slicing it on the proper place
640
+ if kv_seq_len != attention_mask.shape[-1]:
641
+ attention_mask_num_tokens = attention_mask.shape[-1]
642
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
643
+
644
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
645
 
646
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
647
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
648
+
 
 
 
649
  if query_length == kv_seq_len:
650
  query_layer = index_first_axis(
651
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
652
  )
653
  cu_seqlens_q = cu_seqlens_k
654
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
675
  )
676
 
677
 
678
+ # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Llamoe
679
  class LlamoeSdpaAttention(LlamoeAttention):
680
  """
681
+ Llamoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
682
+ `LlamoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
683
  SDPA API.
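+ 
+ A minimal sketch of the underlying PyTorch call; the tensor shapes `(batch, num_heads, seq_len, head_dim)`
+ below are assumptions made for the example only:
+ 
+ ```python
+ >>> import torch
+ 
+ >>> q = torch.randn(1, 8, 4, 16)
+ >>> k, v = torch.randn_like(q), torch.randn_like(q)
+ >>> out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
+ >>> out.shape
+ torch.Size([1, 8, 4, 16])
+ ```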
684
  """
685
 
686
+ # Adapted from LlamoeAttention.forward
687
  def forward(
688
  self,
689
  hidden_states: torch.Tensor,
 
692
  past_key_value: Optional[Cache] = None,
693
  output_attentions: bool = False,
694
  use_cache: bool = False,
 
695
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
696
  if output_attentions:
697
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
698
  logger.warning_once(
699
+ "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
700
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
701
  )
702
  return super().forward(
 
706
  past_key_value=past_key_value,
707
  output_attentions=output_attentions,
708
  use_cache=use_cache,
 
709
  )
710
 
711
  bsz, q_len, _ = hidden_states.size()
712
 
 
713
  query_states = self.q_proj(hidden_states)
714
  key_states = self.k_proj(hidden_states)
715
  value_states = self.v_proj(hidden_states)
716
+
717
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
718
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
719
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
720
 
721
+ kv_seq_len = key_states.shape[-2]
 
 
 
 
722
  if past_key_value is not None:
723
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
724
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
725
+
726
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
727
+
728
+ if past_key_value is not None:
729
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
730
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
731
 
732
  key_states = repeat_kv(key_states, self.num_key_value_groups)
733
  value_states = repeat_kv(value_states, self.num_key_value_groups)
734
 
735
+ if attention_mask is not None:
736
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
737
+ raise ValueError(
738
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
739
+ )
740
 
741
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
742
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
743
+ if query_states.device.type == "cuda" and attention_mask is not None:
744
  query_states = query_states.contiguous()
745
  key_states = key_states.contiguous()
746
  value_states = value_states.contiguous()
 
749
  query_states,
750
  key_states,
751
  value_states,
752
+ attn_mask=attention_mask,
753
  dropout_p=self.attention_dropout if self.training else 0.0,
754
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
755
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
756
  )
757
 
758
  attn_output = attn_output.transpose(1, 2).contiguous()
759
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
760
 
761
  attn_output = self.o_proj(attn_output)
762
 
763
  return attn_output, None, past_key_value
764
 
765
 
766
+ Llamoe_ATTENTION_CLASSES = {
767
  "eager": LlamoeAttention,
768
  "flash_attention_2": LlamoeFlashAttention2,
769
  "sdpa": LlamoeSdpaAttention,
770
  }
771
 
772
+
773
+ class LlamoeBlockSparseTop2MLP(nn.Module):
774
  def __init__(self, config: LlamoeConfig):
775
  super().__init__()
776
  self.ffn_dim = config.intermediate_size
777
  self.hidden_dim = config.hidden_size
778
+
779
  self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
780
  self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
781
  self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
782
 
783
+ self.act_fn = ACT2FN[config.hidden_act]
784
 
785
  def forward(self, hidden_states):
786
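+ # Gated MLP (SwiGLU-style): act_fn(w1(x)) * w3(x), projected back to the hidden size by w2.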
  current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
787
  current_hidden_states = self.w2(current_hidden_states)
788
+ return current_hidden_states
789
+
790
+
791
+ class LlamoeBLockSparseTop2MLP(LlamoeBlockSparseTop2MLP):
792
+ def __init__(self, *args, **kwargs):
793
+ logger.warning_once(
794
+ "LlamoeBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
795
+ )
796
+ super().__init__(*args, **kwargs)
797
 
798
 
799
  class LlamoeSparseMoeBlock(nn.Module):
800
+ """
801
+ This implementation is
802
+ strictly equivalent to standard MoE with full capacity (no
803
+ dropped tokens). It's faster since it formulates MoE operations
804
+ in terms of block-sparse operations to accommodate imbalanced
805
+ assignments of tokens to experts, whereas standard MoE either
806
+ (1) drops tokens at the cost of reduced performance or (2) sets the
807
+ capacity factor to the number of experts and thus wastes computation
808
+ and memory on padding.
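+ 
+ Illustrative sketch of the top-2 routing math implemented in `forward` (the tensor sizes below are
+ assumptions for the example, not model defaults):
+ 
+ ```python
+ >>> import torch
+ >>> import torch.nn.functional as F
+ 
+ >>> num_tokens, num_experts, top_k = 4, 8, 2
+ >>> router_logits = torch.randn(num_tokens, num_experts)
+ >>> routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ >>> routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+ >>> routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ >>> routing_weights.shape, selected_experts.shape
+ (torch.Size([4, 2]), torch.Size([4, 2]))
+ ```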
809
+ """
810
+
811
  def __init__(self, config):
812
  super().__init__()
813
  self.hidden_dim = config.hidden_size
814
  self.ffn_dim = config.intermediate_size
815
  self.num_experts = config.num_local_experts
816
+ self.top_k = config.num_experts_per_tok
817
 
818
  # gating
819
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
820
 
821
  self.experts = nn.ModuleList([LlamoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
822
 
823
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
824
+ """ """
825
  batch_size, sequence_length, hidden_dim = hidden_states.shape
826
  hidden_states = hidden_states.view(-1, hidden_dim)
 
827
  # router_logits: (batch * sequence_length, n_experts)
828
  router_logits = self.gate(hidden_states)
 
 
 
829
 
830
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
831
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
832
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
833
+ # we cast back to the input dtype
834
+ routing_weights = routing_weights.to(hidden_states.dtype)
835
+
836
+ final_hidden_states = torch.zeros(
837
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
838
+ )
839
+
840
+ # One hot encode the selected experts to create an expert mask
841
+ # this will be used to easily index which expert is going to be solicited
842
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
843
+
844
+ # Loop over all available experts in the model and perform the computation on each expert
845
+ for expert_idx in range(self.num_experts):
846
+ expert_layer = self.experts[expert_idx]
847
+ idx, top_x = torch.where(expert_mask[expert_idx])
848
+
849
+ if top_x.shape[0] == 0:
850
+ continue
851
 
852
+ # in torch it is faster to index using lists than torch tensors
853
+ top_x_list = top_x.tolist()
854
+ idx_list = idx.tolist()
855
 
856
+ # Index the correct hidden states and compute the expert hidden state for
857
+ # the current expert. We need to make sure to multiply the output hidden
858
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
859
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
860
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
861
 
862
+ # However `index_add_` only supports torch tensors for indexing so we'll use
863
+ # the `top_x` tensor here.
864
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
865
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
866
+ return final_hidden_states, router_logits
867
 
 
 
868
 
 
 
869
  class LlamoeDecoderLayer(nn.Module):
870
  def __init__(self, config: LlamoeConfig, layer_idx: int):
871
  super().__init__()
872
  self.hidden_size = config.hidden_size
873
 
874
+ self.self_attn = Llamoe_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
875
 
876
  self.block_sparse_moe = LlamoeSparseMoeBlock(config)
877
  self.input_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
948
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
949
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
950
  etc.)
951
+
952
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
953
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
954
  and behavior.
955
+
956
  Parameters:
957
+ config ([`LlamoeConfig`]):
958
  Model configuration class with all the parameters of the model. Initializing with a config file does not
959
  load the weights associated with the model, only the configuration. Check out the
960
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
962
 
963
 
964
  @add_start_docstrings(
965
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
966
  Llamoe_START_DOCSTRING,
967
  )
968
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Llamoe
969
  class LlamoePreTrainedModel(PreTrainedModel):
970
  config_class = LlamoeConfig
971
  base_model_prefix = "model"
972
  supports_gradient_checkpointing = True
 
973
  _no_split_modules = ["LlamoeDecoderLayer"]
974
+ _skip_keys_device_placement = "past_key_values"
975
  _supports_flash_attn_2 = True
976
  _supports_sdpa = True
977
  _supports_cache_class = True
 
987
  if module.padding_idx is not None:
988
  module.weight.data[module.padding_idx].zero_()
989
 
990
 
991
+ Llamoe_INPUTS_DOCSTRING = r"""
 
992
  Args:
993
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
994
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
995
  it.
996
+
997
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
998
  [`PreTrainedTokenizer.__call__`] for details.
999
+
1000
  [What are input IDs?](../glossary#input-ids)
1001
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1002
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1003
+
1004
  - 1 for tokens that are **not masked**,
1005
  - 0 for tokens that are **masked**.
1006
+
1007
  [What are attention masks?](../glossary#attention-mask)
1008
+
1009
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1010
  [`PreTrainedTokenizer.__call__`] for details.
1011
+
1012
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1013
  `past_key_values`).
1014
+
1015
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1016
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1017
  information on the default strategy.
1018
+
1019
  - 1 indicates the head is **not masked**,
1020
  - 0 indicates the head is **masked**.
1021
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1022
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1023
  config.n_positions - 1]`.
1024
+
1025
  [What are position IDs?](../glossary#position-ids)
1026
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1027
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1028
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1029
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1030
+
1031
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1032
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1033
+
1034
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1035
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1036
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
 
 
 
1037
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1038
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1039
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
 
1047
  output_hidden_states (`bool`, *optional*):
1048
  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1049
  more detail.
1050
+ output_router_logits (`bool`, *optional*):
1051
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1052
+ should not be returned during inference.
1053
  return_dict (`bool`, *optional*):
1054
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 
 
 
 
1055
  """
1056
 
1057
 
1058
  @add_start_docstrings(
1059
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1060
  Llamoe_START_DOCSTRING,
1061
  )
1062
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->LLAMOE,Mistral->Llamoe
1063
+ class LlamoeModel(LlamoePreTrainedModel):
1064
  """
1065
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamoeDecoderLayer`]
1066
+
1067
  Args:
1068
+ config: LlamoeConfig
1069
  """
1070
 
1071
+ def __init__(self, config: LlamoeConfig):
1072
  super().__init__(config)
1073
  self.padding_idx = config.pad_token_id
1074
  self.vocab_size = config.vocab_size
 
1077
  self.layers = nn.ModuleList(
1078
  [LlamoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1079
  )
1080
+ self._attn_implementation = config._attn_implementation
1081
  self.norm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
1082
 
1083
+ self.gradient_checkpointing = False
 
 
 
 
 
1084
  # Initialize weights and apply final processing
1085
  self.post_init()
1086
 
 
1090
  def set_input_embeddings(self, value):
1091
  self.embed_tokens = value
1092
 
1093
+ # Ignore copy
1094
+ @add_start_docstrings_to_model_forward(Llamoe_INPUTS_DOCSTRING)
1095
  def forward(
1096
  self,
1097
  input_ids: torch.LongTensor = None,
 
1104
  output_hidden_states: Optional[bool] = None,
1105
  output_router_logits: Optional[bool] = None,
1106
  return_dict: Optional[bool] = None,
 
1107
  ) -> Union[Tuple, MoeModelOutputWithPast]:
1108
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1109
+ output_router_logits = (
1110
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1111
+ )
1112
  output_hidden_states = (
1113
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1114
  )
1115
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1116
+
1117
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1118
 
1119
+ # retrieve input_ids and inputs_embeds
1120
+ if input_ids is not None and inputs_embeds is not None:
1121
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1122
+ elif input_ids is not None:
1123
+ batch_size, seq_length = input_ids.shape
1124
+ elif inputs_embeds is not None:
1125
+ batch_size, seq_length, _ = inputs_embeds.shape
1126
+ else:
1127
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1128
 
1129
+ past_key_values_length = 0
 
 
 
 
1130
 
1131
+ if self.gradient_checkpointing and self.training:
1132
+ if use_cache:
1133
+ logger.warning_once(
1134
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1135
+ )
1136
+ use_cache = False
1137
 
1138
+ if use_cache:
1139
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1140
+ if use_legacy_cache:
 
 
 
 
 
 
 
 
1141
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1142
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1143
 
1144
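+ # Build default position ids for the new tokens, offset by the number of cached key/value positions.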
+ if position_ids is None:
1145
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1146
+ position_ids = torch.arange(
1147
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1148
  )
1149
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1150
+ else:
1151
+ position_ids = position_ids.view(-1, seq_length).long()
1152
 
1153
+ if inputs_embeds is None:
1154
+ inputs_embeds = self.embed_tokens(input_ids)
1155
 
1156
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1157
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1158
+ if is_padding_right:
1159
+ raise ValueError(
1160
+ "You are attempting to perform batched generation with padding_side='right'"
1161
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
1162
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1163
+ )
1164
 
1165
+ if self._attn_implementation == "flash_attention_2":
1166
+ # 2d mask is passed through the layers
1167
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1168
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1169
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1170
+ # the manual implementation that requires a 4D causal mask in all cases.
1171
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1172
+ attention_mask,
1173
+ (batch_size, seq_length),
1174
+ inputs_embeds,
1175
+ past_key_values_length,
1176
+ )
1177
+ else:
1178
+ # 4d mask is passed through the layers
1179
+ attention_mask = _prepare_4d_causal_attention_mask(
1180
+ attention_mask,
1181
+ (batch_size, seq_length),
1182
+ inputs_embeds,
1183
+ past_key_values_length,
1184
+ sliding_window=self.config.sliding_window,
1185
+ )
1186
 
1187
+ hidden_states = inputs_embeds
 
1188
 
1189
  # decoder layers
1190
  all_hidden_states = () if output_hidden_states else None
1191
  all_self_attns = () if output_attentions else None
1192
+ all_router_logits = () if output_router_logits else None
1193
  next_decoder_cache = None
1194
 
1195
  for decoder_layer in self.layers:
1196
  if output_hidden_states:
1197
  all_hidden_states += (hidden_states,)
1198
+
1199
+ if self.gradient_checkpointing and self.training:
1200
  layer_outputs = self._gradient_checkpointing_func(
1201
  decoder_layer.__call__,
1202
  hidden_states,
1203
+ attention_mask,
1204
  position_ids,
1205
  past_key_values,
1206
  output_attentions,
1207
  output_router_logits,
1208
+ use_cache,
 
 
1209
  )
1210
  else:
1211
  layer_outputs = decoder_layer(
1212
  hidden_states,
1213
+ attention_mask=attention_mask,
1214
  position_ids=position_ids,
1215
  past_key_value=past_key_values,
1216
  output_attentions=output_attentions,
1217
  output_router_logits=output_router_logits,
1218
+ use_cache=use_cache,
 
1219
  )
1220
 
1221
  hidden_states = layer_outputs[0]
 
1226
  if output_attentions:
1227
  all_self_attns += (layer_outputs[1],)
1228
 
1229
+ if output_router_logits:
1230
+ all_router_logits += (layer_outputs[-1],)
1231
+
1232
  hidden_states = self.norm(hidden_states)
1233
 
1234
  # add hidden states from the last decoder layer
 
1237
 
1238
  next_cache = None
1239
  if use_cache:
1240
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1241
+
 
1242
  if not return_dict:
1243
+ return tuple(
1244
+ v
1245
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1246
+ if v is not None
1247
+ )
1248
  return MoeModelOutputWithPast(
1249
  last_hidden_state=hidden_states,
1250
  past_key_values=next_cache,
1251
  hidden_states=all_hidden_states,
1252
  attentions=all_self_attns,
1253
+ router_logits=all_router_logits,
1254
  )
1255
 
1256
 
1257
  class LlamoeForCausalLM(LlamoePreTrainedModel):
1258
  _tied_weights_keys = ["lm_head.weight"]
1259
 
1260
  def __init__(self, config):
1261
  super().__init__(config)
1262
+ self.model = LlamoeModel(config)
1263
  self.vocab_size = config.vocab_size
1264
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1265
  self.router_aux_loss_coef = config.router_aux_loss_coef
 
1286
  def get_decoder(self):
1287
  return self.model
1288
 
1289
+ @add_start_docstrings_to_model_forward(Llamoe_INPUTS_DOCSTRING)
1290
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1291
  # Ignore copy
1292
  def forward(
 
1309
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1310
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1311
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1312
+
1313
  Returns:
1314
+
1315
  Example:
1316
+
1317
  ```python
1318
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
1319
+
1320
+ >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1321
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1322
+
1323
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1324
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1325
+
1326
  >>> # Generate
1327
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1328
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
1356
  hidden_states = outputs[0]
1357
  logits = self.lm_head(hidden_states)
1358
  logits = logits.float()
 
 
 
 
 
 
1359
 
1360
  loss = None
1361
  if labels is not None:
 
1459
  "output_router_logits": output_router_logits,
1460
  }
1461
  )
1462
+ return model_inputs