damerajee committed on
Commit 8f6227e
1 Parent(s): 36552a5

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +1175 -768
modeling_Llamoe.py CHANGED
@@ -4,967 +4,1374 @@
  # Copyright (c) 2022, Tri Dao, [email protected].
  # Licensed under the BSD 3-Clause License.

- from __future__ import annotations
-
  import math
- from dataclasses import dataclass, field
- from typing import Any, Dict, Optional, Tuple, Union

  import torch
- import torch.nn as nn
- from einops import rearrange, repeat
- from transformers import PretrainedConfig, PreTrainedModel
- from transformers.activations import ACT2FN
- from transformers.modeling_outputs import CausalLMOutputWithPast

  from .configuration_Llamoe import LlamoeConfig

- try:
-     from flash_attn.bert_padding import pad_input, unpad_input
-     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
-     from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
-     from flash_attn.ops.fused_dense import FusedDense
- except:
-     pad_input, unpad_input = None, None
-     FlashRotaryEmbedding = None
-     FlashSelfAttention, FlashCrossAttention = None, None
-     FusedDense = None
-
-
- @dataclass
- class InferenceParams:
-     """Inference parameters passed to model to efficiently calculate
-     and store context during inference.
-     Reference:
-         https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
-     Args:
-         max_seqlen: Maximum sequence length.
-         max_batch_size: Maximum batch size.
-         seqlen_offset: Sequence length offset.
-         batch_size_offset: Batch size offset.
-         key_value_memory_dict: Key value memory dictionary.
-         lengths_per_sample: Lengths per sample.
-     """
-
-     max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
-     max_batch_size: int = field(metadata={"help": "Maximum batch size."})
-     seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
-     batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
-     key_value_memory_dict: Dict[str, Any] = field(
-         default_factory=dict, metadata={"help": "Key value memory dictionary."}
-     )
-     lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})


- class Embedding(nn.Module):
-     """Token embedding with dropout."""

-     def __init__(self, config: LlamoeConfig) -> None:
-         super().__init__()

-         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
-         self.drop = nn.Dropout(config.embd_pdrop)

-     def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
-         input_shape = input_ids.size()
-         input_ids = input_ids.view(-1, input_shape[-1])

-         hidden_states = self.wte(input_ids)
-         hidden_states = self.drop(hidden_states)

-         return hidden_states


- def _apply_rotary_emb(
-     x: torch.FloatTensor,
-     cos: torch.FloatTensor,
-     sin: torch.FloatTensor,
- ) -> torch.FloatTensor:
-     _, seqlen, _, _ = x.shape
-     _, rotary_dim = cos.shape
-     rotary_dim *= 2

-     x_rot = x[:, :, :, :rotary_dim]
-     x_pass = x[:, :, :, rotary_dim:]

-     x1, x2 = x_rot.chunk(2, dim=-1)
-     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-     x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]

-     x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)

-     return torch.cat([x_rot, x_pass], axis=-1)


- def _apply_rotary_emb_kv(
-     kv: torch.FloatTensor,
-     cos: torch.FloatTensor,
-     sin: torch.FloatTensor,
-     cos_k: Optional[torch.FloatTensor] = None,
-     sin_k: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
-     _, seqlen, _, _, _ = kv.shape
-     _, rotary_dim = cos.shape
-     rotary_dim *= 2

-     k_rot = kv[:, :, 0, :, :rotary_dim]
-     k_pass = kv[:, :, 0, :, rotary_dim:]

-     k1, k2 = k_rot.chunk(2, dim=-1)
-     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-     k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]

-     k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)

-     return torch.cat(
-         [
-             torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
-             kv[:, :, 1:2, :, :],
-         ],
-         axis=2,
      )


- def _apply_rotary_emb_qkv(
-     qkv: torch.FloatTensor,
-     cos: torch.FloatTensor,
-     sin: torch.FloatTensor,
-     cos_k: Optional[torch.FloatTensor] = None,
-     sin_k: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
-     _, seqlen, _, _, _ = qkv.shape
-     _, rotary_dim = cos.shape
-     rotary_dim *= 2
-
-     q_rot = qkv[:, :, 0, :, :rotary_dim]
-     q_pass = qkv[:, :, 0, :, rotary_dim:]
-
-     k_rot = qkv[:, :, 1, :, :rotary_dim]
-     k_pass = qkv[:, :, 1, :, rotary_dim:]
-
-     q1, q2 = q_rot.chunk(2, dim=-1)
-     k1, k2 = k_rot.chunk(2, dim=-1)
-     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-     q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
-
-     q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
-     k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
-
-     return torch.cat(
-         [
-             torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
-             torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
-             qkv[:, :, 2:3, :, :],
-         ],
-         axis=2,
-     )


- class RotaryEmbedding(nn.Module):
-     """Rotary positional embedding (RoPE).
-     Reference:
-         RoFormer: Enhanced Transformer with Rotary Position Embedding.
-         https://arxiv.org/pdf/2104.09864.pdf.
-     """

-     def __init__(
-         self,
-         dim: int,
-         base: int = 10000,
-         scale_base: Optional[float] = None,
-         pos_idx_in_fp32: bool = True,
-         max_position_embeddings: int = 2048,
-         device: Optional[str] = None,
-         **kwargs,
-     ) -> None:
          super().__init__()

-         if scale_base is not None:
-             raise NotImplementedError
-
          self.dim = dim
-         self.base = float(base)
-         self.scale_base = scale_base
-         self.pos_idx_in_fp32 = pos_idx_in_fp32
          self.max_position_embeddings = max_position_embeddings
-         self.device = device
-
-         # Generate and save the inverse frequency buffer (non-trainable)
-         inv_freq = self._compute_inv_freq(device)
          self.register_buffer("inv_freq", inv_freq, persistent=False)

-         # Generate and save the scale buffer (non-trainable)
-         scale = (
-             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-             if scale_base is not None
-             else None
          )
-         self.register_buffer("scale", scale, persistent=False)
-
-         # Initialize cached attributes since ONNX can't rely on dynamic initialization
-         self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
-
-     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
-         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-     def _update_cos_sin_cache(
-         self,
-         seqlen: int,
-         device: Optional[str] = None,
-         dtype: Optional[torch.dtype] = None,
-     ) -> None:
-         self._seq_len_cached = seqlen
-
-         # fp32 is preferred since the output of `torch.arange` can be quite large
-         # and bf16 would lose a lot of precision
-         if self.pos_idx_in_fp32:
-             t = torch.arange(seqlen, device=device, dtype=torch.float32)
-             if self.inv_freq.dtype != torch.float32:
-                 inv_freq = self._compute_inv_freq(device=device)
-             else:
-                 inv_freq = self.inv_freq
-         else:
-             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-             inv_freq = self.inv_freq
-
-         # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-         freqs = torch.outer(t, inv_freq)
-         if self.scale is None:
-             self._cos_cached = torch.cos(freqs).to(dtype)
-             self._sin_cached = torch.sin(freqs).to(dtype)
-         else:
-             power = (
-                 torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-             ) / self.scale_base
-             scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

-             # Force the scale multiplication to happen in fp32
-             self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-             self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-             self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-     def forward(
-         self,
-         qkv: torch.Tensor,
-         kv: Optional[torch.Tensor] = None,
-         seqlen_offset: int = 0,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         if (
-             self._seq_len_cached < qkv.shape[1] + seqlen_offset
-             or self._cos_cached.device != qkv.device
-             or self._cos_cached.dtype != qkv.dtype
-             or (self.training and self._cos_cached.is_inference())
-         ):
-             self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-
-         if kv is None:
-             return _apply_rotary_emb_qkv(
-                 qkv,
-                 self._cos_cached[seqlen_offset:],
-                 self._sin_cached[seqlen_offset:],
-             )
-         else:
-             q = _apply_rotary_emb(
-                 qkv,
-                 self._cos_cached[seqlen_offset:],
-                 self._sin_cached[seqlen_offset:],
-             )
-             kv = _apply_rotary_emb_kv(
-                 kv,
-                 self._cos_cached[seqlen_offset:],
-                 self._sin_cached[seqlen_offset:],
-             )

-             return q, kv


- class MoE(nn.Module):
-     def __init__(
-         self,
-         config: LlamoeConfig,
-     ):
-         super().__init__()
-         self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
-         self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
-         self.num_experts_per_tok = config.num_experts_per_tok

-     def forward(self, x):
-         orig_shape = x.shape
-         x = x.view(-1, x.shape[-1])

-         scores = self.gate(x)
-         expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
-         expert_weights = expert_weights.softmax(dim=-1)
-         flat_expert_indices = expert_indices.view(-1)

-         x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
-         y = torch.empty_like(x)
-         for i, expert in enumerate(self.mlp):
-             y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
-         y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
-         return y.view(*orig_shape)


- class MLP(nn.Module):
-     """Multi-Layer Perceptron.
-     Reference:
-         Attention Is All You Need.
-         https://arxiv.org/pdf/1706.03762.pdf.
      """

-     def __init__(
-         self,
-         config: PretrainedConfig,
-         n_inner: Optional[int] = None,
-         act_fn: Optional[str] = None,
-     ) -> None:
-         super().__init__()
-
-         act_fn = config.activation_function if act_fn is None else act_fn
-
-         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
-         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
-
-         self.fc1 = nn.Linear(config.n_embd, n_inner)
-         self.fc2 = nn.Linear(n_inner, config.n_embd)
-         self.act = ACT2FN[act_fn]
-
-     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-         hidden_states = self.fc1(hidden_states)
-         hidden_states = self.act(hidden_states)
-         hidden_states = self.fc2(hidden_states)

          return hidden_states


- class SelfAttention(nn.Module):
-     """Self-attention layer (compatible with PyTorch).
-     Reference:
-         https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
-     """

-     def __init__(
-         self,
-         causal: bool = True,
-         softmax_scale: Optional[float] = None,
-         attention_dropout: float = 0.0,
-     ) -> None:
          super().__init__()

-         self.causal = causal
-         self.softmax_scale = softmax_scale
-         self.drop = nn.Dropout(attention_dropout)

-     @torch.autocast("cpu", enabled=False)
-     @torch.autocast("cuda", enabled=False)
      def forward(
          self,
-         qkv: torch.FloatTensor,
-         causal: bool = None,
-         key_padding_mask: Optional[torch.BoolTensor] = None,
          **kwargs,
-     ) -> torch.FloatTensor:
-         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-         q, k, v = qkv.unbind(dim=2)

-         q = q.to(torch.float32)
-         k = k.to(torch.float32)

-         causal = self.causal if causal is None else causal
-         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])

-         # Autocast is manually disabled to avoid `torch.einsum` performing the operation
-         # using float16, which might lead to overflow
-         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

-         if key_padding_mask is not None:
-             padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
-             padding_mask.masked_fill_(key_padding_mask, 0.0)

-             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

-         if causal:
-             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
-             scores = scores + causal_mask.to(dtype=scores.dtype)

-         attention = torch.softmax(scores, dim=-1).to(v.dtype)
-         attention = self.drop(attention)

-         output = torch.einsum("bhts,bshd->bthd", attention, v)

-         return output


- class CrossAttention(nn.Module):
-     """Cross-attention layer (compatible with PyTorch).
-     Reference:
-         https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
      """

-     def __init__(
-         self,
-         causal: bool = True,
-         softmax_scale: Optional[float] = None,
-         attention_dropout: float = 0.0,
-     ) -> None:
-         super().__init__()

-         self.causal = causal
-         self.softmax_scale = softmax_scale
-         self.drop = nn.Dropout(attention_dropout)

-     @torch.autocast("cpu", enabled=False)
-     @torch.autocast("cuda", enabled=False)
      def forward(
          self,
-         q: torch.FloatTensor,
-         kv: torch.FloatTensor,
-         causal: bool = None,
-         key_padding_mask: Optional[torch.BoolTensor] = None,
          **kwargs,
-     ) -> torch.FloatTensor:
-         batch_size, seqlen_q = q.shape[0], q.shape[1]
-         seqlen_k = kv.shape[1]
-
-         if kv.shape[3] != q.shape[2]:
-             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
-         k, v = kv.unbind(dim=2)
-
-         q = q.to(torch.float32)
-         k = k.to(torch.float32)
-
-         causal = self.causal if causal is None else causal
-         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-
-         # Autocast is manually disabled to avoid `torch.einsum` performing the operation
-         # using float16, which might lead to overflow
-         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-
-         if key_padding_mask is not None:
-             padding_mask = torch.full(
-                 (batch_size, seqlen_k),
-                 -10000.0,
-                 dtype=scores.dtype,
-                 device=scores.device,
-             )
-             padding_mask.masked_fill_(key_padding_mask, 0.0)

-             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

-         if causal:
-             rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
-             cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
-             causal_mask = cols > rows + seqlen_k - seqlen_q

-             scores = scores.masked_fill(causal_mask, -10000.0)

-         attention = torch.softmax(scores, dim=-1).to(v.dtype)
-         attention = self.drop(attention)

-         output = torch.einsum("bhts,bshd->bthd", attention, v)

-         return output


- def _find_mha_dims(
-     config: PretrainedConfig,
-     n_head: Optional[int] = None,
-     n_head_kv: Optional[int] = None,
-     head_dim: Optional[int] = None,
- ) -> Tuple[int, int]:
-     if n_head is None and head_dim is None:
-         head_dim = config.n_embd // config.n_head
-         n_head = config.n_head
-     elif n_head is None or head_dim is None:
-         raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")

-     if n_head_kv is None:
-         n_head_kv = getattr(config, "n_head_kv", None) or n_head

-     return n_head, n_head_kv, head_dim


- def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
-     num_heads, head_dim = kv.shape[-2:]

-     if layer_idx not in inference_params.key_value_memory_dict:
-         inference_params.key_value_memory_dict[layer_idx] = torch.empty(
-             inference_params.max_batch_size,
-             inference_params.max_seqlen,
-             2,
-             num_heads,
-             head_dim,
-             dtype=kv.dtype,
-             device=kv.device,
          )

-     batch_start = inference_params.batch_size_offset
-     batch_end = batch_start + kv.shape[0]
-
-     sequence_start = inference_params.seqlen_offset
-     sequence_end = sequence_start + kv.shape[1]

-     # When the current sequence length is equal to or larger than the maximum sequence length,
-     # we need to concatenate the current `kv` with the cached `kv` to expand its length
-     if sequence_end >= inference_params.max_seqlen:
-         inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)

-     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-     kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
-
-     return kv


- class MHA(nn.Module):
-     """Multi-head attention layer."""

-     def __init__(
-         self,
-         config: PretrainedConfig,
-         dtype: Optional[torch.dtype] = None,
-         device: Optional[str] = None,
-         rotary_dim: Optional[int] = None,
-         rotary_base: float = 10000.0,
-         rotary_scale_base: Optional[float] = None,
-         n_head: Optional[int] = None,
-         n_head_kv: Optional[int] = None,
-         head_dim: Optional[int] = None,
-         bias: bool = True,
-         causal: bool = True,
-         softmax_scale: Optional[float] = None,
-         layer_idx: Optional[int] = None,
-         return_residual: bool = False,
-         checkpointing: bool = False,
-     ) -> None:
-         super().__init__()

-         # Rotary embedding
-         self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
-         if self.rotary_dim > 0:
-             rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
-             if rotary_cls is None:
-                 rotary_cls = RotaryEmbedding
-
-             rotary_kwargs = {}
-             if rotary_cls is RotaryEmbedding:
-                 rotary_kwargs["max_position_embeddings"] = config.n_positions
-
-             self.rotary_emb = rotary_cls(
-                 self.rotary_dim,
-                 base=rotary_base,
-                 scale_base=rotary_scale_base,
-                 device=device,
-                 **rotary_kwargs,
-             )

-         # MLP
-         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
-             config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
-         )
-         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
-         hidden_size = config.n_embd

-         linear_cls = FusedDense if config.fused_dense else nn.Linear
-         if linear_cls is None:
-             linear_cls = nn.Linear

-         self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
-         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)

-         # Attention
-         attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
-         if attn_cls is None:
-             attn_cls = SelfAttention

-         cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
-         if cross_attn_cls is None:
-             cross_attn_cls = CrossAttention

-         self.inner_attn = attn_cls(
-             causal=causal,
-             softmax_scale=softmax_scale,
-             attention_dropout=config.attn_pdrop,
-         )
-         self.inner_cross_attn = cross_attn_cls(
-             causal=causal,
-             softmax_scale=softmax_scale,
-             attention_dropout=config.attn_pdrop,
          )

-         self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
-         self.layer_idx = layer_idx
-         self.return_residual = return_residual
-         self.checkpointing = checkpointing
-
-     def _forward_self_attn(
-         self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
-     ) -> torch.FloatTensor:
-         qkv = self.Wqkv(x)
-         qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-
-         if self.rotary_dim > 0:
-             qkv = self.rotary_emb(qkv)
-
-         if self.flash_attn:
-             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-
-             cu_seqlens, max_seqlen = None, None
-             if key_padding_mask is not None:
-                 # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
-                 # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
-                 qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
-
-             if self.checkpointing:
-                 attn_output = torch.utils.checkpoint.checkpoint(
-                     self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
-                 )
-             else:
-                 attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)

-             # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
-             return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output

-         if self.checkpointing:
-             return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)

-         return self.inner_attn(qkv, key_padding_mask=key_padding_mask)

-     def _forward_cross_attn(
-         self,
-         x: torch.FloatTensor,
-         past_key_values: Optional[InferenceParams],
-         key_padding_mask: Optional[torch.BoolTensor],
-     ) -> torch.FloatTensor:
-         batch_size = x.shape[0]

-         qkv = self.Wqkv(x)

-         q = qkv[..., : self.n_head * self.head_dim]
-         q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)

-         kv = qkv[..., self.n_head * self.head_dim :]
-         kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)

-         seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
-         causal = None if seqlen_offset == 0 else False
-         if self.rotary_dim > 0:
-             q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)

-         if past_key_values is not None:
-             kv = _update_kv_cache(kv, past_key_values, self.layer_idx)

-         if self.flash_attn:
-             batch_size, seqlen_q = q.shape[0], q.shape[1]
-             seqlen_k = kv.shape[1]

-             cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
-                 None,
-                 None,
-                 None,
-                 None,
-             )
-             if key_padding_mask is not None:
-                 kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
-
-                 if seqlen_q == 1:
-                     key_padding_mask = torch.ones(batch_size, 1, device=q.device)
-                 elif seqlen_q != seqlen_k:
-                     key_padding_mask = key_padding_mask[:, -seqlen_q:]
-
-                 q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
-
-             if self.checkpointing:
-                 attn_output = torch.utils.checkpoint.checkpoint(
-                     self.inner_cross_attn,
-                     q,
-                     kv,
-                     causal=causal,
-                     cu_seqlens=cu_seqlens_q,
-                     max_seqlen=max_seqlen_q,
-                     cu_seqlens_k=cu_seqlens_k,
-                     max_seqlen_k=max_seqlen_k,
-                 )
-             else:
-                 attn_output = self.inner_cross_attn(
-                     q,
-                     kv,
-                     causal=causal,
-                     cu_seqlens=cu_seqlens_q,
-                     max_seqlen=max_seqlen_q,
-                     cu_seqlens_k=cu_seqlens_k,
-                     max_seqlen_k=max_seqlen_k,
-                 )

-             return (
-                 pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
-                 if key_padding_mask is not None
-                 else attn_output
-             )

-         if self.checkpointing:
-             return torch.utils.checkpoint.checkpoint(
-                 self.inner_cross_attn,
-                 q,
-                 kv,
-                 key_padding_mask=key_padding_mask,
-                 causal=causal,
-             )

-         return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)

-     def forward(
-         self,
-         x: torch.FloatTensor,
-         past_key_values: Optional[InferenceParams] = None,
-         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-         if attention_mask is not None:
-             attention_mask = attention_mask.bool()
-         else:
-             attention_mask = None

-         # MHA
-         if self.n_head == self.n_head_kv:
-             if past_key_values is None:
-                 # If `past_key_values` are not supplied, we run self-attention
-                 attn_output = self._forward_self_attn(x, attention_mask)
-             else:
-                 # If `past_key_values` are supplied, it means that we might have cached values and
-                 # could take advantage of cross-attention
-                 attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
-         # MQA / GQA
-         else:
-             # Regardless of `past_key_values` being supplied or not, it always use cross-attention
-             # because `q` and `kv` lengths might be different
-             attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)

-         output = rearrange(attn_output, "... h d -> ... (h d)")
-         output = self.out_proj(output)

-         return output if not self.return_residual else (output, x)


- class ParallelBlock(nn.Module):
-     """Parallel block.
-     This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
-     """

-     def __init__(
-         self,
-         config: PretrainedConfig,
-         block_idx: Optional[int] = None,
-     ) -> None:
          super().__init__()

-         self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-         self.block_idx = block_idx

-         self.mixer = MHA(config, layer_idx=block_idx)
-         self.moe = MoE(config)

      def forward(
          self,
-         hidden_states: torch.FloatTensor,
-         past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
-         attention_mask: Optional[torch.BoolTensor] = None,
          **kwargs,
-     ) -> torch.FloatTensor:
          residual = hidden_states
-         hidden_states = self.ln(hidden_states)

-         attn_outputs = self.mixer(
-             hidden_states,
-             past_key_values=past_key_values,
              attention_mask=attention_mask,
          )
-         if isinstance(attn_outputs, tuple):
-             attn_outputs = attn_outputs[0]
-
-         attn_outputs = self.resid_dropout(attn_outputs)
-         feed_forward_hidden_states = self.resid_dropout(self.moe(hidden_states))
-
-         hidden_states = attn_outputs + feed_forward_hidden_states + residual
-
-         return hidden_states
-

- class CausalLMHead(nn.Module):
-     """Causal Language Modeling head.
-     Reference:
-         Improving Language Understanding by Generative Pre-Training.
-         https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
-     """
-
-     def __init__(self, config: PretrainedConfig) -> None:
-         super().__init__()
-
-         self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-         self.linear = nn.Linear(config.n_embd, config.vocab_size)

-     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-         hidden_states = self.ln(hidden_states)
-         logits = self.linear(hidden_states).to(torch.float32)

-         return logits


- class CausalLMLoss(nn.Module):
-     """Causal Language Modeling loss.
-     Reference:
-         Improving Language Understanding by Generative Pre-Training.
-         https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
-     """

-     def __init__(self, shift_labels: bool = True) -> None:
-         super().__init__()

-         self.shift_labels = shift_labels
-         self.loss_fct = nn.CrossEntropyLoss()

-     def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
-         if self.shift_labels:
-             logits = logits[..., :-1, :].contiguous()
-             labels = labels[..., 1:].contiguous()

-         loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

-         return loss


  class LlamoePreTrainedModel(PreTrainedModel):
-     """Phi pre-trained model."""
-
      config_class = LlamoeConfig
-     base_model_prefix = "transformer"
-     supports_gradient_checkpointing = False
-     _no_split_modules = ["ParallelBlock"]
-
-     def __init__(self, *inputs, **kwargs) -> None:
-         super().__init__(*inputs, **kwargs)
-
-     def _init_weights(self, module: nn.Module) -> None:
-         if isinstance(module, (nn.Linear,)):
-             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.padding_idx is not None:
                  module.weight.data[module.padding_idx].zero_()
-         elif isinstance(module, nn.LayerNorm):
-             if module.bias is not None:
-                 module.bias.data.zero_()
-             module.weight.data.fill_(1.0)
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
-         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
-         **kwargs,
-     ) -> Dict[str, Any]:
-         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
-             past_key_values = InferenceParams(
-                 max_seqlen=self.config.n_positions,
-                 max_batch_size=input_ids.shape[0],
-                 seqlen_offset=0,
-                 batch_size_offset=0,
-                 key_value_memory_dict={},
-                 lengths_per_sample=None,
-             )
-         else:
-             # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-             past_key_values.seqlen_offset = input_ids.shape[1] - 1
-             input_ids = input_ids[:, -1].unsqueeze(-1)
-
-         return {
-             "input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "attention_mask": attention_mask,
-         }


  class LlamoeModel(LlamoePreTrainedModel):
-     """Llamoe model."""

-     _keys_to_ignore_on_load_missing = [""]
-     _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]

-     def __init__(self, config: LlamoeConfig) -> None:
          super().__init__(config)

-         self.embd = Embedding(config)
-         self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
          self.gradient_checkpointing = False
          self.post_init()

-     def get_input_embeddings(self) -> nn.Embedding:
-         return self.embd.wte

-     def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
-         self.embd.wte = new_embeddings

      def forward(
          self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
-         attention_mask: Optional[torch.BoolTensor] = None,
-     ) -> torch.FloatTensor:
-         hidden_states = self.embd(input_ids)
-
-         for layer in self.h:
-             hidden_states = layer(
-                 hidden_states,
-                 past_key_values=past_key_values,
-                 attention_mask=attention_mask,
              )

-         return hidden_states


- class LlamoeForCausalLM(LlamoePreTrainedModel):
-     """Llamoe for Causal Language Modeling."""

-     _keys_to_ignore_on_load_missing = [""]
-     _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]

-     def __init__(self, config: PhiConfig) -> None:
-         super().__init__(config)

-         self.transformer = LlamoeModel(config)
-         self.lm_head = CausalLMHead(config)
-         self.loss = CausalLMLoss()

          self.post_init()

-     def get_output_embeddings(self) -> nn.Linear:
-         return self.lm_head.linear

-     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
-         self.lm_head.linear = new_embeddings

      def forward(
          self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
-         attention_mask: Optional[torch.BoolTensor] = None,
          labels: Optional[torch.LongTensor] = None,
-         **kwargs,
-     ) -> CausalLMOutputWithPast:
-         hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
-         lm_logits = self.lm_head(hidden_states)

          loss = None
          if labels is not None:
-             loss = self.loss(lm_logits, labels)

-         return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
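For reference, the expert dispatch that the removed MoE block above implements can be exercised on its own. The sketch below is illustrative only (not part of the commit) and uses hypothetical small sizes (4 experts, top-2 routing, 16-dim hidden states); it mirrors the removed MoE.forward logic.

# Illustrative sketch, not part of the commit: top-k expert routing as in the removed MoE.forward.
import torch
import torch.nn as nn

n_embd, num_experts, top_k = 16, 4, 2                     # hypothetical sizes
experts = nn.ModuleList([nn.Linear(n_embd, n_embd) for _ in range(num_experts)])
gate = nn.Linear(n_embd, num_experts, bias=False)

x = torch.randn(3, 5, n_embd)                             # (batch, seq, hidden)
flat = x.view(-1, n_embd)
weights, idx = torch.topk(gate(flat), top_k, dim=-1)      # pick top-2 experts per token
weights = weights.softmax(dim=-1)
flat_idx = idx.view(-1)
rep = flat.repeat_interleave(top_k, dim=0)                # one copy of each token per selected expert
y = torch.empty_like(rep)
for i, expert in enumerate(experts):
    y[flat_idx == i] = expert(rep[flat_idx == i])         # route each copy through its expert
y = (y.view(*weights.shape, -1) * weights.unsqueeze(-1)).sum(dim=1)
out = y.view_as(x)                                        # mixture output, same shape as the input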
  # Copyright (c) 2022, Tri Dao, [email protected].
  # Licensed under the BSD 3-Clause License.

+ import inspect
  import math
+ import warnings
+ from typing import List, Optional, Tuple, Union

  import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+ from transformers.activations import ACT2FN
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_attn_mask_utils import (
+     _prepare_4d_causal_attention_mask,
+     _prepare_4d_causal_attention_mask_for_sdpa,
+ )
+ from transformers.modeling_outputs import (
+     MoeCausalLMOutputWithPast,
+     MoeModelOutputWithPast,
+     SequenceClassifierOutputWithPast,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     is_flash_attn_2_available,
+     is_flash_attn_greater_or_equal_2_10,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.utils.import_utils import is_torch_fx_available
  from .configuration_Llamoe import LlamoeConfig

+ if is_flash_attn_2_available():
+     from flash_attn import flash_attn_func, flash_attn_varlen_func
+     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

+     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+ # It means that the function will not be traced through and simply appear as a node in the graph.
+ if is_torch_fx_available():
+     if not is_torch_greater_or_equal_than_1_13:
+         import torch.fx

+     _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


+ logger = logging.get_logger(__name__)

+ _CONFIG_FOR_DOC = "LlamoeConfig"


+ def load_balancing_loss_func(
+     gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
+ ) -> float:
+     r"""
+     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+     See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+     function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+     experts is too unbalanced.
+
+     Args:
+         gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+             shape [batch_size X sequence_length, num_experts].
+         attention_mask (`torch.Tensor`, None):
+             The attention_mask used in forward function
+             shape [batch_size X sequence_length] if not None.
+         num_experts (`int`, *optional*):
+             Number of experts
+
+     Returns:
+         The auxiliary loss.
+     """
+     if gate_logits is None or not isinstance(gate_logits, tuple):
+         return 0

+     if isinstance(gate_logits, tuple):
+         compute_device = gate_logits[0].device
+         concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

+     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

+     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

+     expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

+     if attention_mask is None:
+         # Compute the percentage of tokens routed to each experts
+         tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

+         # Compute the average probability of routing to these experts
+         router_prob_per_expert = torch.mean(routing_weights, dim=0)
+     else:
+         batch_size, sequence_length = attention_mask.shape
+         num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

+         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+         expert_attention_mask = (
+             attention_mask[None, :, :, None, None]
+             .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+             .reshape(-1, top_k, num_experts)
+             .to(compute_device)
+         )

+         # Compute the percentage of tokens routed to each experts
+         tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+             expert_attention_mask, dim=0
+         )

+         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+         router_per_expert_attention_mask = (
+             attention_mask[None, :, :, None]
+             .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+             .reshape(-1, num_experts)
+             .to(compute_device)
+         )

+         # Compute the average probability of routing to these experts
+         router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+             router_per_expert_attention_mask, dim=0
+         )

+     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+     return overall_loss * num_experts
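A quick, illustrative sanity check (not part of the commit) of the auxiliary loss defined above, using hypothetical shapes (2 layers, 8x16 tokens, 4 experts). With roughly uniform routing the value comes out near top_k; imbalanced routing increases it.

# Illustrative sketch, not part of the commit: calling the router loss on random gate logits.
import torch

gate_logits = tuple(torch.randn(8 * 16, 4) for _ in range(2))   # one [tokens, num_experts] tensor per layer
aux_loss = load_balancing_loss_func(gate_logits, num_experts=4, top_k=2)
print(float(aux_loss))                                          # ~2 for near-balanced routing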
 

+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
+ def _get_unpad_data(attention_mask):
+     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+     max_seqlen_in_batch = seqlens_in_batch.max().item()
+     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+     return (
+         indices,
+         cu_seqlens,
+         max_seqlen_in_batch,
      )


+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
+ class LlamoeRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         MixtralRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps

+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
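A minimal numerical check (illustrative only, not part of the commit) of the RMS normalization above, with a hypothetical hidden size of 8: each hidden vector is scaled by 1/sqrt(mean(x^2) + eps), so with the default weight of ones the normalized rows have RMS close to 1.

# Illustrative sketch, not part of the commit: LlamoeRMSNorm output rows have unit RMS before `weight`.
import torch

norm = LlamoeRMSNorm(hidden_size=8)
x = torch.randn(2, 3, 8)
y = norm(x)
print(y.pow(2).mean(-1).sqrt())   # all entries close to 1.0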

+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+ class LlamoeRotaryEmbedding(nn.Module):
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
          super().__init__()
          self.dim = dim
          self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
          self.register_buffer("inv_freq", inv_freq, persistent=False)

+         # Build here to make `torch.jit.trace` work.
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
          )

+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

+         freqs = torch.outer(t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

+         return (
+             self.cos_cached[:seq_len].to(dtype=x.dtype),
+             self.sin_cached[:seq_len].to(dtype=x.dtype),
+         )


+ # Copied from transformers.models.llama.modeling_llama.rotate_half
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)


+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.

+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`):
+             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+             used to pass offsetted position ids when working with a KV-cache.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
      """
+     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
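The rotary helpers above can be exercised on their own. The sketch below is illustrative only (not part of the commit) and uses hypothetical shapes with queries and keys laid out as [batch, heads, seq_len, head_dim], so the default unsqueeze_dim=1 applies.

# Illustrative sketch, not part of the commit: building cos/sin tables and applying RoPE.
import torch

bsz, n_heads, seq_len, head_dim = 2, 4, 10, 64                      # hypothetical sizes
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)   # [batch, seq_len]

rope = LlamoeRotaryEmbedding(head_dim, max_position_embeddings=2048)
cos, sin = rope(q, seq_len=seq_len)                                 # each [seq_len, head_dim]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape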

+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
          return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
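As the docstring above notes, repeat_kv matches torch.repeat_interleave along the head dimension; a quick illustrative check (not part of the commit) with hypothetical shapes:

# Illustrative sketch, not part of the commit: repeat_kv expands grouped KV heads for GQA.
import torch

kv = torch.randn(2, 4, 10, 64)   # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(kv, 3), torch.repeat_interleave(kv, repeats=3, dim=1))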
255
 
256
 
257
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
258
+ class LlamoeAttention(nn.Module):
259
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
 
 
260
 
261
+ def __init__(self, config: LlamoeConfig, layer_idx: Optional[int] = None):
 
 
 
 
 
262
  super().__init__()
263
+ self.config = config
264
+ self.layer_idx = layer_idx
265
+ if layer_idx is None:
266
+ logger.warning_once(
267
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
268
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
269
+ "when creating this class."
270
+ )
271
 
272
+ self.attention_dropout = config.attention_dropout
273
+ self.hidden_size = config.hidden_size
274
+ self.num_heads = config.num_attention_heads
275
+ self.head_dim = self.hidden_size // self.num_heads
276
+ self.num_key_value_heads = config.num_key_value_heads
277
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
278
+ self.max_position_embeddings = config.max_position_embeddings
279
+ self.rope_theta = config.rope_theta
280
+ self.is_causal = True
281
+
282
+ if (self.head_dim * self.num_heads) != self.hidden_size:
283
+ raise ValueError(
284
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
285
+ f" and `num_heads`: {self.num_heads})."
286
+ )
287
+
288
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
289
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
290
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
291
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
292
+ self._init_rope()
293
+
294
+ def _init_rope(self):
295
+ if self.config.rope_scaling is None:
296
+ self.rotary_emb = LlamoeRotaryEmbedding(
297
+ self.head_dim,
298
+ max_position_embeddings=self.max_position_embeddings,
299
+ base=self.rope_theta,
300
+ )
301
+
302
+ else:
303
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
304
 
 
 
305
  def forward(
306
  self,
307
+ hidden_states: torch.Tensor,
308
+ attention_mask: Optional[torch.Tensor] = None,
309
+ position_ids: Optional[torch.LongTensor] = None,
310
+ past_key_value: Optional[Cache] = None,
311
+ output_attentions: bool = False,
312
+ use_cache: bool = False,
313
+ cache_position: Optional[torch.LongTensor] = None,
314
  **kwargs,
315
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
316
+ bsz, q_len, _ = hidden_states.size()
317
+
318
+ if self.config.pretraining_tp > 1:
319
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
320
+ query_slices = self.q_proj.weight.split(
321
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
322
+ )
323
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
324
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
325
 
326
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
327
+ query_states = torch.cat(query_states, dim=-1)
328
 
329
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
330
+ key_states = torch.cat(key_states, dim=-1)
331
 
332
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
333
+ value_states = torch.cat(value_states, dim=-1)
 
334
 
335
+ else:
336
+ query_states = self.q_proj(hidden_states)
337
+ key_states = self.k_proj(hidden_states)
338
+ value_states = self.v_proj(hidden_states)
339
+
340
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
341
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
342
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
343
+
344
+ past_key_value = getattr(self, "past_key_value", past_key_value)
345
+ cos, sin = self.rotary_emb(value_states, position_ids)
346
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
347
+
348
+ if past_key_value is not None:
349
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
350
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
351
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
352
+
353
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
354
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
355
+
356
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
357
+
358
+ if attention_mask is not None: # no matter the length, we just slice it
359
+ causal_mask = attention_mask
360
+ if cache_position is not None:
361
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
362
+ attn_weights = attn_weights + causal_mask
363
+
364
+ # upcast attention to fp32
365
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
366
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
367
+ attn_output = torch.matmul(attn_weights, value_states)
368
+
369
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
370
+ raise ValueError(
371
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
372
+ f" {attn_output.size()}"
373
+ )
374
 
375
+ attn_output = attn_output.transpose(1, 2).contiguous()
376
 
377
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
 
378
 
379
+ if self.config.pretraining_tp > 1:
380
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
381
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
382
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
383
+ else:
384
+ attn_output = self.o_proj(attn_output)
385
 
386
+ if not output_attentions:
387
+ attn_weights = None
388
 
389
+ return attn_output, attn_weights, past_key_value
390
 
391
 
392
+ class LlamoeFlashAttention2(LlamoeAttention):
393
+ """
394
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
395
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
396
+ flash attention and deal with padding tokens in case the input contains any of them.
397
  """
398
 
399
+ def __init__(self, *args, **kwargs):
400
+ super().__init__(*args, **kwargs)
 
 
 
 
 
401
 
402
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
403
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
404
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
405
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
406
 
 
 
407
  def forward(
408
  self,
409
+ hidden_states: torch.Tensor,
410
+ attention_mask: Optional[torch.LongTensor] = None,
411
+ position_ids: Optional[torch.LongTensor] = None,
412
+ past_key_value: Optional[Cache] = None,
413
+ output_attentions: bool = False,
414
+ use_cache: bool = False,
415
+ cache_position: Optional[torch.LongTensor] = None,
416
  **kwargs,
417
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
418
+ output_attentions = False
419
+
420
+ bsz, q_len, _ = hidden_states.size()
421
+
422
+ query_states = self.q_proj(hidden_states)
423
+ key_states = self.k_proj(hidden_states)
424
+ value_states = self.v_proj(hidden_states)
425
+
426
+ # Flash attention requires the input to have the shape
427
+ # batch_size x seq_length x head_dim x hidden_dim
428
+ # therefore we just need to keep the original shape
429
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
430
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
431
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
432
+
433
+ cos, sin = self.rotary_emb(value_states, position_ids)
434
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
435
+
436
+ past_key_value = getattr(self, "past_key_value", past_key_value)
437
+
438
+ if past_key_value is not None:
439
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
440
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
441
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
442
+
443
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
444
+ # to be able to avoid many of these transpose/reshape/view.
445
+ query_states = query_states.transpose(1, 2)
446
+ key_states = key_states.transpose(1, 2)
447
+ value_states = value_states.transpose(1, 2)
448
+
449
+ dropout_rate = self.attention_dropout if self.training else 0.0
450
+
451
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
452
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
453
+ # cast them back to the correct dtype just to be sure everything works as expected.
454
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
455
+ # in fp32. (LlamoeRMSNorm handles it correctly.)
456
+
457
+ input_dtype = query_states.dtype
458
+ if input_dtype == torch.float32:
459
+ if torch.is_autocast_enabled():
460
+ target_dtype = torch.get_autocast_gpu_dtype()
461
+ # Handle the case where the model is quantized
462
+ elif hasattr(self.config, "_pre_quantization_dtype"):
463
+ target_dtype = self.config._pre_quantization_dtype
464
+ else:
465
+ target_dtype = self.q_proj.weight.dtype
466
 
467
+ logger.warning_once(
468
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
469
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
470
+ f" {target_dtype}."
471
+ )
472
 
473
+ query_states = query_states.to(target_dtype)
474
+ key_states = key_states.to(target_dtype)
475
+ value_states = value_states.to(target_dtype)
 
476
 
477
+ attn_output = self._flash_attention_forward(
478
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
479
+ )
480
 
481
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
482
+ attn_output = self.o_proj(attn_output)
483
 
484
+ if not output_attentions:
485
+ attn_weights = None
486
 
487
+ return attn_output, attn_weights, past_key_value
488
 
489
+ def _flash_attention_forward(
490
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
491
+ ):
492
+ """
493
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
494
+ first unpads the input, then computes the attention scores and pads the final attention output.
495
+
496
+ Args:
497
+ query_states (`torch.Tensor`):
498
+ Input query states to be passed to Flash Attention API
499
+ key_states (`torch.Tensor`):
500
+ Input key states to be passed to Flash Attention API
501
+ value_states (`torch.Tensor`):
502
+ Input value states to be passed to Flash Attention API
503
+ attention_mask (`torch.Tensor`):
504
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
505
+ position of padding tokens and 1 for the position of non-padding tokens.
506
+ dropout (`float`):
507
+ Attention dropout
508
+ softmax_scale (`float`, *optional*):
509
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
510
+ """
511
+ if not self._flash_attn_uses_top_left_mask:
512
+ causal = self.is_causal
513
+ else:
514
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamoeFlashAttention2 __init__.
515
+ causal = self.is_causal and query_length != 1
516
 
517
+ # Contains at least one padding token in the sequence
518
+ if attention_mask is not None:
519
+ batch_size = query_states.shape[0]
520
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
521
+ query_states, key_states, value_states, attention_mask, query_length
522
+ )
 
 
 
 
 
523
 
524
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
525
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
526
+
527
+ attn_output_unpad = flash_attn_varlen_func(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ cu_seqlens_q=cu_seqlens_q,
532
+ cu_seqlens_k=cu_seqlens_k,
533
+ max_seqlen_q=max_seqlen_in_batch_q,
534
+ max_seqlen_k=max_seqlen_in_batch_k,
535
+ dropout_p=dropout,
536
+ softmax_scale=softmax_scale,
537
+ causal=causal,
538
+ )
539
 
540
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
541
+ else:
542
+ attn_output = flash_attn_func(
543
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
544
+ )
545
 
546
+ return attn_output
547
 
548
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
549
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
550
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
551
 
552
+ key_layer = index_first_axis(
553
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
+ )
555
+ value_layer = index_first_axis(
556
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+ if query_length == kv_seq_len:
559
+ query_layer = index_first_axis(
560
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
561
+ )
562
+ cu_seqlens_q = cu_seqlens_k
563
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
564
+ indices_q = indices_k
565
+ elif query_length == 1:
566
+ max_seqlen_in_batch_q = 1
567
+ cu_seqlens_q = torch.arange(
568
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
569
+ ) # There is a memcpy here, that is very bad.
570
+ indices_q = cu_seqlens_q[:-1]
571
+ query_layer = query_layer.squeeze(1)
572
+ else:
573
+ # The -q_len: slice assumes left padding.
574
+ attention_mask = attention_mask[:, -query_length:]
575
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
576
+
577
+ return (
578
+ query_layer,
579
+ key_layer,
580
+ value_layer,
581
+ indices_q,
582
+ (cu_seqlens_q, cu_seqlens_k),
583
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
584
  )
585
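The unpadding above packs every real token into one flat batch and describes it with cumulative sequence lengths. A minimal sketch of that bookkeeping, roughly what the `_get_unpad_data` helper called at the top of `_upad_input` computes (treat the exact return layout as an assumption):

```python
import torch
import torch.nn.functional as F

# 1 = real token, 0 = padding (two sequences of true lengths 3 and 2)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
max_seqlen_in_batch = int(seqlens_in_batch.max())                            # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

# flash_attn_varlen_func then receives the packed (total_tokens, num_heads, head_dim)
# tensors plus cu_seqlens / max_seqlen, and pad_input scatters the output back to
# (batch, seq_len, ...) using `indices`.
```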
 
 
 
 
 
 
586
 
587
+ class LlamoeSdpaAttention(LlamoeAttention):
588
+ """
589
+ Llamoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
590
+ `LlamoeAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
591
+ SDPA API.
592
+ """
593
 
594
+ # Adapted from LlamaAttention.forward
595
+ def forward(
596
+ self,
597
+ hidden_states: torch.Tensor,
598
+ attention_mask: Optional[torch.Tensor] = None,
599
+ position_ids: Optional[torch.LongTensor] = None,
600
+ past_key_value: Optional[Cache] = None,
601
+ output_attentions: bool = False,
602
+ use_cache: bool = False,
603
+ cache_position: Optional[torch.LongTensor] = None,
604
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
605
+ if output_attentions:
606
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
607
+ logger.warning_once(
608
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
609
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
610
+ )
611
+ return super().forward(
612
+ hidden_states=hidden_states,
613
+ attention_mask=attention_mask,
614
+ position_ids=position_ids,
615
+ past_key_value=past_key_value,
616
+ output_attentions=output_attentions,
617
+ use_cache=use_cache,
618
+ cache_position=cache_position,
619
+ )
620
 
621
+ bsz, q_len, _ = hidden_states.size()
622
 
623
+ query_states = self.q_proj(hidden_states)
624
+ key_states = self.k_proj(hidden_states)
625
+ value_states = self.v_proj(hidden_states)
626
 
627
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
628
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
629
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
 
631
+ cos, sin = self.rotary_emb(value_states, position_ids)
632
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
+ # In case static cache is used, it is an instance attribute.
635
+ past_key_value = getattr(self, "past_key_value", past_key_value)
 
 
 
 
636
 
637
+ if past_key_value is not None:
638
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
639
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
640
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
641
 
642
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
643
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
644
 
645
+ causal_mask = attention_mask
646
+ if attention_mask is not None and cache_position is not None:
647
+ causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
 
648
 
649
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
650
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
651
+ if query_states.device.type == "cuda" and causal_mask is not None:
652
+ query_states = query_states.contiguous()
653
+ key_states = key_states.contiguous()
654
+ value_states = value_states.contiguous()
655
 
656
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
657
+ query_states,
658
+ key_states,
659
+ value_states,
660
+ attn_mask=causal_mask,
661
+ dropout_p=self.attention_dropout if self.training else 0.0,
 
 
 
662
  )
663
 
664
+ attn_output = attn_output.transpose(1, 2).contiguous()
665
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
 
667
+ attn_output = self.o_proj(attn_output)
 
668
 
669
+ return attn_output, None, past_key_value
 
670
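`repeat_kv` used above expands the grouped key/value heads so that every query head gets a matching key/value head before `scaled_dot_product_attention`. A small sketch of the effect, assuming `num_key_value_groups = num_heads // num_key_value_heads` (the sizes below are made up):

```python
import torch

bsz, num_kv_heads, seq_len, head_dim = 1, 2, 4, 8
n_rep = 4  # num_heads (8) // num_key_value_heads (2)

kv = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# Equivalent to repeat_kv(kv, n_rep): every kv head is shared by n_rep query heads.
expanded = kv[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

assert expanded.shape == (1, 8, 4, 8)
assert torch.equal(expanded[:, 0], expanded[:, 1])  # heads 0..3 all reuse kv head 0
```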
 
 
671
 
672
+ LLAMOE_ATTENTION_CLASSES = {
673
+ "eager": LlamoeAttention,
674
+ "flash_attention_2": LlamoeFlashAttention2,
675
+ "sdpa": LlamoeSdpaAttention,
676
+ }
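The decoder layers below resolve their attention class through `config._attn_implementation`, which `transformers` fills in at load time. A hedged usage sketch (the checkpoint path is a placeholder, not a real model id):

```python
# import torch
# from transformers import AutoModelForCausalLM

# model = AutoModelForCausalLM.from_pretrained(
#     "<path-to-llamoe-checkpoint>",            # placeholder
#     attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
#     torch_dtype=torch.bfloat16,
#     trust_remote_code=True,                   # this modeling file is custom code
# )
# The weights are identical in all three cases; only the attention forward pass differs.
```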
 
 
677
 
 
678
 
679
+ class LlamoeBlockSparseTop2MLP(nn.Module):
680
+ def __init__(self, config: LlamoeConfig):
681
+ super().__init__()
682
+ self.ffn_dim = config.intermediate_size
683
+ self.hidden_dim = config.hidden_size
684
 
685
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
686
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
687
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
688
 
689
+ self.act_fn = ACT2FN[config.hidden_act]
 
 
 
690
 
691
+ def forward(self, hidden_states):
692
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
693
+ current_hidden_states = self.w2(current_hidden_states)
694
+ return current_hidden_states
695
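Each expert above is a gated MLP: `w2(act_fn(w1(x)) * w3(x))`. A self-contained sketch with toy sizes (the dimensions and the SiLU activation are assumptions; the real values come from `config.hidden_size`, `config.intermediate_size`, and `config.hidden_act`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, ffn_dim = 16, 64                    # toy sizes
w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # gate projection
w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)   # up projection
w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)   # down projection

x = torch.randn(5, hidden_dim)                    # 5 tokens
y = w2(F.silu(w1(x)) * w3(x))                     # assumes hidden_act="silu"
assert y.shape == x.shape
```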
 
 
 
 
696
 
697
+ class LlamoeBLockSparseTop2MLP(LlamoeBlockSparseTop2MLP):
698
+ def __init__(self, *args, **kwargs):
699
+ logger.warning_once(
700
+ "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
701
+ )
702
+ super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703
 
 
 
 
 
 
704
 
705
+ class LlamoeSparseMoeBlock(nn.Module):
706
+ """
707
+ This implementation is
708
+ strictly equivalent to standard MoE with full capacity (no
709
+ dropped tokens). It's faster since it formulates MoE operations
710
+ in terms of block-sparse operations to accomodate imbalanced
711
+ assignments of tokens to experts, whereas standard MoE either
712
+ (1) drop tokens at the cost of reduced performance or (2) set
713
+ capacity factor to number of experts and thus waste computation
714
+ and memory on padding.
715
+ """
716
 
717
+ def __init__(self, config):
718
+ super().__init__()
719
+ self.hidden_dim = config.hidden_size
720
+ self.ffn_dim = config.intermediate_size
721
+ self.num_experts = config.num_local_experts
722
+ self.top_k = config.num_experts_per_tok
723
+
724
+ # gating
725
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
726
+
727
+ self.experts = nn.ModuleList([LlamoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
728
+
729
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
730
+ """ """
731
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
732
+ hidden_states = hidden_states.view(-1, hidden_dim)
733
+ # router_logits: (batch * sequence_length, n_experts)
734
+ router_logits = self.gate(hidden_states)
735
+
736
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
737
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
738
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
739
+ # we cast back to the input dtype
740
+ routing_weights = routing_weights.to(hidden_states.dtype)
741
+
742
+ final_hidden_states = torch.zeros(
743
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
744
+ )
745
 
746
+ # One hot encode the selected experts to create an expert mask
747
+ # this will be used to easily index which expert is going to be solicited
748
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
 
 
 
 
 
 
 
749
 
750
+ # Loop over all available experts in the model and perform the computation on each expert
751
+ for expert_idx in range(self.num_experts):
752
+ expert_layer = self.experts[expert_idx]
753
+ idx, top_x = torch.where(expert_mask[expert_idx])
 
 
 
 
 
 
 
 
 
 
754
 
755
+ if top_x.shape[0] == 0:
756
+ continue
757
 
758
+ # in torch it is faster to index using lists than torch tensors
759
+ top_x_list = top_x.tolist()
760
+ idx_list = idx.tolist()
761
 
762
+ # Index the correct hidden states and compute the expert hidden state for
763
+ # the current expert. We need to make sure to multiply the output hidden
764
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
765
+ current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
766
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
767
 
768
+ # However, `index_add_` only supports torch tensors for indexing, so we'll use
769
+ # the `top_x` tensor here.
770
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
771
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
772
+ return final_hidden_states, router_logits
773
 
774
+
775
+ class LlamoeDecoderLayer(nn.Module):
776
+ def __init__(self, config: LlamoeConfig, layer_idx: int):
 
 
777
  super().__init__()
778
+ self.hidden_size = config.hidden_size
779
 
780
+ self.self_attn = LLAMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
 
781
 
782
+ self.block_sparse_moe = LlamoeSparseMoeBlock(config)
783
+ self.input_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
784
+ self.post_attention_layernorm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
785
 
786
  def forward(
787
  self,
788
+ hidden_states: torch.Tensor,
789
+ attention_mask: Optional[torch.Tensor] = None,
790
+ position_ids: Optional[torch.LongTensor] = None,
791
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
792
+ output_attentions: Optional[bool] = False,
793
+ output_router_logits: Optional[bool] = False,
794
+ use_cache: Optional[bool] = False,
795
  **kwargs,
796
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
797
+ if "padding_mask" in kwargs:
798
+ warnings.warn(
799
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
800
+ )
801
+ """
802
+ Args:
803
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
804
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
805
+ `(batch, sequence_length)` where padding elements are indicated by 0.
806
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
807
+ output_attentions (`bool`, *optional*):
808
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
809
+ returned tensors for more detail.
810
+ output_router_logits (`bool`, *optional*):
811
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
812
+ should not be returned during inference.
813
+ use_cache (`bool`, *optional*):
814
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
815
+ (see `past_key_values`).
816
+ """
817
+
818
  residual = hidden_states
 
819
 
820
+ hidden_states = self.input_layernorm(hidden_states)
821
+
822
+ # Self Attention
823
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
824
+ hidden_states=hidden_states,
825
  attention_mask=attention_mask,
826
+ position_ids=position_ids,
827
+ past_key_value=past_key_value,
828
+ output_attentions=output_attentions,
829
+ use_cache=use_cache,
830
  )
831
+ hidden_states = residual + hidden_states
 
 
 
 
 
 
 
 
 
832
 
833
+ # Fully Connected
834
+ residual = hidden_states
835
+ hidden_states = self.post_attention_layernorm(hidden_states)
836
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
837
+ hidden_states = residual + hidden_states
 
 
 
 
 
 
 
838
 
839
+ outputs = (hidden_states,)
 
 
840
 
841
+ if output_attentions:
842
+ outputs += (self_attn_weights,)
843
 
844
+ if use_cache:
845
+ outputs += (present_key_value,)
846
 
847
+ if output_router_logits:
848
+ outputs += (router_logits,)
 
 
 
 
849
 
850
+ return outputs
 
851
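Stripped of caching and optional outputs, the decoder layer above is a standard pre-norm residual block: attention over `input_layernorm(x)` is added back to `x`, then the sparse-MoE feed-forward over `post_attention_layernorm(x)` is added back again. A schematic sketch (not the class itself):

```python
def decoder_layer_schematic(x, self_attn, block_sparse_moe, input_layernorm, post_attention_layernorm):
    # Self-attention sub-block with residual connection
    x = x + self_attn(input_layernorm(x))
    # Sparse-MoE feed-forward sub-block with residual connection
    moe_out, router_logits = block_sparse_moe(post_attention_layernorm(x))
    x = x + moe_out
    return x, router_logits
```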
 
 
 
852
 
853
+ LLAMOE_START_DOCSTRING = r"""
854
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
855
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
856
+ etc.)
857
 
858
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
859
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
860
+ and behavior.
861
 
862
+ Parameters:
863
+ config ([`LlamoeConfig`]):
864
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
865
+ load the weights associated with the model, only the configuration. Check out the
866
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
867
+ """
868
 
869
 
870
+ @add_start_docstrings(
871
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
872
+ LLAMOE_START_DOCSTRING,
873
+ )
874
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Llamoe
875
  class LlamoePreTrainedModel(PreTrainedModel):
 
 
876
  config_class = LlamoeConfig
877
+ base_model_prefix = "model"
878
+ supports_gradient_checkpointing = True
879
+ _no_split_modules = ["LlamoeDecoderLayer"]
880
+ _skip_keys_device_placement = "past_key_values"
881
+ _supports_flash_attn_2 = True
882
+ _supports_sdpa = True
883
+ _supports_cache_class = True
884
+
885
+ def _init_weights(self, module):
886
+ std = self.config.initializer_range
887
+ if isinstance(module, nn.Linear):
888
+ module.weight.data.normal_(mean=0.0, std=std)
889
  if module.bias is not None:
890
  module.bias.data.zero_()
891
  elif isinstance(module, nn.Embedding):
892
+ module.weight.data.normal_(mean=0.0, std=std)
893
  if module.padding_idx is not None:
894
  module.weight.data[module.padding_idx].zero_()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
895
 
896
 
897
+ LLAMOE_INPUTS_DOCSTRING = r"""
898
+ Args:
899
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
900
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
901
+ it.
902
+
903
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
904
+ [`PreTrainedTokenizer.__call__`] for details.
905
+
906
+ [What are input IDs?](../glossary#input-ids)
907
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
908
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
909
+
910
+ - 1 for tokens that are **not masked**,
911
+ - 0 for tokens that are **masked**.
912
+
913
+ [What are attention masks?](../glossary#attention-mask)
914
+
915
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
916
+ [`PreTrainedTokenizer.__call__`] for details.
917
+
918
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
919
+ `past_key_values`).
920
+
921
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
922
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
923
+ information on the default strategy.
924
+
925
+ - 1 indicates the head is **not masked**,
926
+ - 0 indicates the head is **masked**.
927
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
928
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
929
+ config.n_positions - 1]`.
930
+
931
+ [What are position IDs?](../glossary#position-ids)
932
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
933
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
934
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
935
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
936
+
937
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
938
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
939
+
940
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
941
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
942
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
943
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
944
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
945
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
946
+ model's internal embedding lookup matrix.
947
+ use_cache (`bool`, *optional*):
948
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
949
+ `past_key_values`).
950
+ output_attentions (`bool`, *optional*):
951
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
952
+ tensors for more detail.
953
+ output_hidden_states (`bool`, *optional*):
954
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
955
+ more detail.
956
+ output_router_logits (`bool`, *optional*):
957
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
958
+ should not be returned during inference.
959
+ return_dict (`bool`, *optional*):
960
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
961
+ """
962
+
963
+
964
+ @add_start_docstrings(
965
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
966
+ LLAMOE_START_DOCSTRING,
967
+ )
968
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->LLAMOE,Mistral->Llamoe
969
  class LlamoeModel(LlamoePreTrainedModel):
970
+ """
971
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamoeDecoderLayer`].
972
 
973
+ Args:
974
+ config: LlamoeConfig
975
+ """
976
 
977
+ def __init__(self, config: LlamoeConfig):
978
  super().__init__(config)
979
+ self.padding_idx = config.pad_token_id
980
+ self.vocab_size = config.vocab_size
981
+
982
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
983
+ self.layers = nn.ModuleList(
984
+ [LlamoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
985
+ )
986
+ self._attn_implementation = config._attn_implementation
987
+ self.norm = LlamoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
988
 
 
 
989
  self.gradient_checkpointing = False
990
+ # Initialize weights and apply final processing
991
  self.post_init()
992
 
993
+ def get_input_embeddings(self):
994
+ return self.embed_tokens
995
 
996
+ def set_input_embeddings(self, value):
997
+ self.embed_tokens = value
998
 
999
+ # Ignore copy
1000
+ @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
1001
  def forward(
1002
  self,
1003
+ input_ids: torch.LongTensor = None,
1004
+ attention_mask: Optional[torch.Tensor] = None,
1005
+ position_ids: Optional[torch.LongTensor] = None,
1006
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1007
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1008
+ use_cache: Optional[bool] = None,
1009
+ output_attentions: Optional[bool] = None,
1010
+ output_hidden_states: Optional[bool] = None,
1011
+ output_router_logits: Optional[bool] = None,
1012
+ return_dict: Optional[bool] = None,
1013
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1014
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1015
+ output_router_logits = (
1016
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1017
+ )
1018
+ output_hidden_states = (
1019
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1020
+ )
1021
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1022
+
1023
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1024
+
1025
+ # retrieve input_ids and inputs_embeds
1026
+ if input_ids is not None and inputs_embeds is not None:
1027
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1028
+ elif input_ids is not None:
1029
+ batch_size, seq_length = input_ids.shape
1030
+ elif inputs_embeds is not None:
1031
+ batch_size, seq_length, _ = inputs_embeds.shape
1032
+ else:
1033
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1034
+
1035
+ past_key_values_length = 0
1036
+
1037
+ if self.gradient_checkpointing and self.training:
1038
+ if use_cache:
1039
+ logger.warning_once(
1040
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1041
+ )
1042
+ use_cache = False
1043
+
1044
+ if use_cache:
1045
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1046
+ if use_legacy_cache:
1047
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1048
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1049
+
1050
+ if position_ids is None:
1051
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1052
+ position_ids = torch.arange(
1053
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1054
  )
1055
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1056
+ else:
1057
+ position_ids = position_ids.view(-1, seq_length).long()
1058
+
1059
+ if inputs_embeds is None:
1060
+ inputs_embeds = self.embed_tokens(input_ids)
1061
+
1062
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1063
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1064
+ if is_padding_right:
1065
+ raise ValueError(
1066
+ "You are attempting to perform batched generation with padding_side='right'"
1067
+ " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
1068
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1069
+ )
1070
 
1071
+ if self._attn_implementation == "flash_attention_2":
1072
+ # 2d mask is passed through the layers
1073
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1074
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1075
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1076
+ # the manual implementation that requires a 4D causal mask in all cases.
1077
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1078
+ attention_mask,
1079
+ (batch_size, seq_length),
1080
+ inputs_embeds,
1081
+ past_key_values_length,
1082
+ )
1083
+ else:
1084
+ # 4d mask is passed through the layers
1085
+ attention_mask = _prepare_4d_causal_attention_mask(
1086
+ attention_mask,
1087
+ (batch_size, seq_length),
1088
+ inputs_embeds,
1089
+ past_key_values_length,
1090
+ sliding_window=self.config.sliding_window,
1091
+ )
1092
 
1093
+ hidden_states = inputs_embeds
1094
+
1095
+ # decoder layers
1096
+ all_hidden_states = () if output_hidden_states else None
1097
+ all_self_attns = () if output_attentions else None
1098
+ all_router_logits = () if output_router_logits else None
1099
+ next_decoder_cache = None
1100
+
1101
+ for decoder_layer in self.layers:
1102
+ if output_hidden_states:
1103
+ all_hidden_states += (hidden_states,)
1104
+
1105
+ if self.gradient_checkpointing and self.training:
1106
+ layer_outputs = self._gradient_checkpointing_func(
1107
+ decoder_layer.__call__,
1108
+ hidden_states,
1109
+ attention_mask,
1110
+ position_ids,
1111
+ past_key_values,
1112
+ output_attentions,
1113
+ output_router_logits,
1114
+ use_cache,
1115
+ )
1116
+ else:
1117
+ layer_outputs = decoder_layer(
1118
+ hidden_states,
1119
+ attention_mask=attention_mask,
1120
+ position_ids=position_ids,
1121
+ past_key_value=past_key_values,
1122
+ output_attentions=output_attentions,
1123
+ output_router_logits=output_router_logits,
1124
+ use_cache=use_cache,
1125
+ )
1126
 
1127
+ hidden_states = layer_outputs[0]
 
1128
 
1129
+ if use_cache:
1130
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1131
 
1132
+ if output_attentions:
1133
+ all_self_attns += (layer_outputs[1],)
1134
+
1135
+ if output_router_logits:
1136
+ all_router_logits += (layer_outputs[-1],)
1137
+
1138
+ hidden_states = self.norm(hidden_states)
1139
+
1140
+ # add hidden states from the last decoder layer
1141
+ if output_hidden_states:
1142
+ all_hidden_states += (hidden_states,)
1143
 
1144
+ next_cache = None
1145
+ if use_cache:
1146
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1147
+
1148
+ if not return_dict:
1149
+ return tuple(
1150
+ v
1151
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1152
+ if v is not None
1153
+ )
1154
+ return MoeModelOutputWithPast(
1155
+ last_hidden_state=hidden_states,
1156
+ past_key_values=next_cache,
1157
+ hidden_states=all_hidden_states,
1158
+ attentions=all_self_attns,
1159
+ router_logits=all_router_logits,
1160
+ )
1161
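When `output_router_logits=True`, the `router_logits` field returned above is a tuple with one tensor per decoder layer, each of shape `(batch_size * seq_len, num_local_experts)` because the MoE block flattens the tokens. A small sketch for inspecting expert usage (continuing the hypothetical `out` from the earlier usage sketch):

```python
# per_layer = out.router_logits               # tuple, len == config.num_hidden_layers
# stacked = torch.stack(per_layer)            # (num_layers, batch*seq_len, num_experts)
# greedy_expert = stacked.argmax(dim=-1)      # most-probable expert per token and layer
```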
+
1162
+
1163
+ class LlamoeForCausalLM(LlamoePreTrainedModel):
1164
+ _tied_weights_keys = ["lm_head.weight"]
1165
 
1166
+ def __init__(self, config):
1167
+ super().__init__(config)
1168
+ self.model = LlamoeModel(config)
1169
+ self.vocab_size = config.vocab_size
1170
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1171
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1172
+ self.num_experts = config.num_local_experts
1173
+ self.num_experts_per_tok = config.num_experts_per_tok
1174
+ # Initialize weights and apply final processing
1175
  self.post_init()
1176
 
1177
+ def get_input_embeddings(self):
1178
+ return self.model.embed_tokens
1179
+
1180
+ def set_input_embeddings(self, value):
1181
+ self.model.embed_tokens = value
1182
+
1183
+ def get_output_embeddings(self):
1184
+ return self.lm_head
1185
 
1186
+ def set_output_embeddings(self, new_embeddings):
1187
+ self.lm_head = new_embeddings
1188
 
1189
+ def set_decoder(self, decoder):
1190
+ self.model = decoder
1191
+
1192
+ def get_decoder(self):
1193
+ return self.model
1194
+
1195
+ @add_start_docstrings_to_model_forward(LLAMOE_INPUTS_DOCSTRING)
1196
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1197
+ # Ignore copy
1198
  def forward(
1199
  self,
1200
+ input_ids: torch.LongTensor = None,
1201
+ attention_mask: Optional[torch.Tensor] = None,
1202
+ position_ids: Optional[torch.LongTensor] = None,
1203
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1204
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1205
  labels: Optional[torch.LongTensor] = None,
1206
+ use_cache: Optional[bool] = None,
1207
+ output_attentions: Optional[bool] = None,
1208
+ output_hidden_states: Optional[bool] = None,
1209
+ output_router_logits: Optional[bool] = None,
1210
+ return_dict: Optional[bool] = None,
1211
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1212
+ r"""
1213
+ Args:
1214
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1215
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1216
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1217
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1218
+
1219
+ Returns:
1220
+
1221
+ Example:
1222
+
1223
+ ```python
1224
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
1225
+
1226
+ >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1227
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
1228
+
1229
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1230
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1231
+
1232
+ >>> # Generate
1233
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1234
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1235
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1236
+ ```"""
1237
+
1238
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1239
+ output_router_logits = (
1240
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1241
+ )
1242
+
1243
+ output_hidden_states = (
1244
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1245
+ )
1246
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1247
+
1248
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1249
+ outputs = self.model(
1250
+ input_ids=input_ids,
1251
+ attention_mask=attention_mask,
1252
+ position_ids=position_ids,
1253
+ past_key_values=past_key_values,
1254
+ inputs_embeds=inputs_embeds,
1255
+ use_cache=use_cache,
1256
+ output_attentions=output_attentions,
1257
+ output_hidden_states=output_hidden_states,
1258
+ output_router_logits=output_router_logits,
1259
+ return_dict=return_dict,
1260
+ )
1261
+
1262
+ hidden_states = outputs[0]
1263
+ logits = self.lm_head(hidden_states)
1264
+ logits = logits.float()
1265
 
1266
  loss = None
1267
  if labels is not None:
1268
+ # Shift so that tokens < n predict n
1269
+ shift_logits = logits[..., :-1, :].contiguous()
1270
+ shift_labels = labels[..., 1:].contiguous()
1271
+ # Flatten the tokens
1272
+ loss_fct = CrossEntropyLoss()
1273
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1274
+ shift_labels = shift_labels.view(-1)
1275
+ # Enable model parallelism
1276
+ shift_labels = shift_labels.to(shift_logits.device)
1277
+ loss = loss_fct(shift_logits, shift_labels)
1278
+
1279
+ aux_loss = None
1280
+ if output_router_logits:
1281
+ aux_loss = load_balancing_loss_func(
1282
+ outputs.router_logits if return_dict else outputs[-1],
1283
+ self.num_experts,
1284
+ self.num_experts_per_tok,
1285
+ attention_mask,
1286
+ )
1287
+ if labels is not None:
1288
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure the aux loss is on the same device as the main loss
1289
+
1290
+ if not return_dict:
1291
+ output = (logits,) + outputs[1:]
1292
+ if output_router_logits:
1293
+ output = (aux_loss,) + output
1294
+ return (loss,) + output if loss is not None else output
1295
+
1296
+ return MoeCausalLMOutputWithPast(
1297
+ loss=loss,
1298
+ aux_loss=aux_loss,
1299
+ logits=logits,
1300
+ past_key_values=outputs.past_key_values,
1301
+ hidden_states=outputs.hidden_states,
1302
+ attentions=outputs.attentions,
1303
+ router_logits=outputs.router_logits,
1304
+ )
1305
 
1306
+ def prepare_inputs_for_generation(
1307
+ self,
1308
+ input_ids,
1309
+ past_key_values=None,
1310
+ attention_mask=None,
1311
+ inputs_embeds=None,
1312
+ output_router_logits=False,
1313
+ **kwargs,
1314
+ ):
1315
+ # Omit tokens covered by past_key_values
1316
+ if past_key_values is not None:
1317
+ if isinstance(past_key_values, Cache):
1318
+ cache_length = past_key_values.get_seq_length()
1319
+ past_length = past_key_values.seen_tokens
1320
+ max_cache_length = past_key_values.get_max_length()
1321
+ else:
1322
+ cache_length = past_length = past_key_values[0][0].shape[2]
1323
+ max_cache_length = None
1324
+
1325
+ # Keep only the unprocessed tokens:
1326
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1327
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1328
+ # input)
1329
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1330
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1331
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1332
+ # input_ids based on the past_length.
1333
+ elif past_length < input_ids.shape[1]:
1334
+ input_ids = input_ids[:, past_length:]
1335
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1336
+
1337
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1338
+ if (
1339
+ max_cache_length is not None
1340
+ and attention_mask is not None
1341
+ and cache_length + input_ids.shape[1] > max_cache_length
1342
+ ):
1343
+ attention_mask = attention_mask[:, -max_cache_length:]
1344
+
1345
+ position_ids = kwargs.get("position_ids", None)
1346
+ if attention_mask is not None and position_ids is None:
1347
+ # create position_ids on the fly for batch generation
1348
+ position_ids = attention_mask.long().cumsum(-1) - 1
1349
+ position_ids.masked_fill_(attention_mask == 0, 1)
1350
+ if past_key_values:
1351
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1352
+
1353
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1354
+ if inputs_embeds is not None and past_key_values is None:
1355
+ model_inputs = {"inputs_embeds": inputs_embeds}
1356
+ else:
1357
+ model_inputs = {"input_ids": input_ids}
1358
+
1359
+ model_inputs.update(
1360
+ {
1361
+ "position_ids": position_ids,
1362
+ "past_key_values": past_key_values,
1363
+ "use_cache": kwargs.get("use_cache"),
1364
+ "attention_mask": attention_mask,
1365
+ "output_router_logits": output_router_logits,
1366
+ }
1367
+ )
1368
+ return model_inputs
1369
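The on-the-fly `position_ids` above are just a cumulative count of real tokens, so a left-padded row still starts at position 0 at its first real token. A small worked example:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```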
+
1370
+ @staticmethod
1371
+ def _reorder_cache(past_key_values, beam_idx):
1372
+ reordered_past = ()
1373
+ for layer_past in past_key_values:
1374
+ reordered_past += (
1375
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1376
+ )
1377
+ return reordered_past
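`_reorder_cache` realigns every cached key/value tensor with the beams that survived a beam-search step: `beam_idx` indexes the batch (beam) dimension of each cached tensor. A toy illustration:

```python
import torch

past_state = torch.arange(3 * 2 * 4).reshape(3, 2, 4)  # toy cache tensor: (num_beams, ..., ...)
beam_idx = torch.tensor([2, 2, 0])                      # beams 2 and 0 survived; beam 2 was duplicated

reordered = past_state.index_select(0, beam_idx)
assert torch.equal(reordered[0], past_state[2])
assert torch.equal(reordered[1], past_state[2])
assert torch.equal(reordered[2], past_state[0])
```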