feat: add flash attention v2
#9
by
jon-tow
- opened
- README.md +30 -1
- modeling_stablelm_epoch.py +246 -19
README.md
CHANGED
@@ -36,7 +36,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
36 |
torch_dtype="auto",
|
37 |
)
|
38 |
model.cuda()
|
39 |
-
inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(
|
40 |
tokens = model.generate(
|
41 |
**inputs,
|
42 |
max_new_tokens=64,
|
@@ -47,6 +47,35 @@ tokens = model.generate(
|
|
47 |
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
|
48 |
```
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
## Model Details
|
51 |
|
52 |
* **Developed by**: [Stability AI](https://stability.ai/)
|
|
|
36 |
torch_dtype="auto",
|
37 |
)
|
38 |
model.cuda()
|
39 |
+
inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
|
40 |
tokens = model.generate(
|
41 |
**inputs,
|
42 |
max_new_tokens=64,
|
|
|
47 |
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
|
48 |
```
|
49 |
|
50 |
+
### Run with Flash Attention 2 ⚡️
|
51 |
+
|
52 |
+
<details>
|
53 |
+
<summary> Click to expand </summary>
|
54 |
+
|
55 |
+
```python
|
56 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
57 |
+
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")
|
58 |
+
model = AutoModelForCausalLM.from_pretrained(
|
59 |
+
"stabilityai/stablelm-3b-4e1t",
|
60 |
+
trust_remote_code=True,
|
61 |
+
torch_dtype="auto",
|
62 |
+
+ use_flash_attention_2=True,
|
63 |
+
)
|
64 |
+
model.cuda()
|
65 |
+
inputs = tokenizer("The weather is always wonderful", return_tensors="pt").to(model.device)
|
66 |
+
tokens = model.generate(
|
67 |
+
**inputs,
|
68 |
+
max_new_tokens=64,
|
69 |
+
temperature=0.75,
|
70 |
+
top_p=0.95,
|
71 |
+
do_sample=True,
|
72 |
+
)
|
73 |
+
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
|
74 |
+
```
|
75 |
+
|
76 |
+
</details>
|
77 |
+
|
78 |
+
|
79 |
## Model Details
|
80 |
|
81 |
* **Developed by**: [Stability AI](https://stability.ai/)
|
modeling_stablelm_epoch.py
CHANGED
@@ -19,23 +19,46 @@
|
|
19 |
""" PyTorch StableLM Epoch model. """
|
20 |
from typing import Optional, Tuple, Union
|
21 |
import math
|
|
|
22 |
|
23 |
import torch
|
|
|
24 |
import torch.utils.checkpoint
|
25 |
from torch import nn
|
26 |
from torch.nn import CrossEntropyLoss
|
|
|
|
|
27 |
from transformers.modeling_outputs import (
|
28 |
BaseModelOutputWithPast,
|
29 |
CausalLMOutputWithPast,
|
30 |
)
|
31 |
from transformers.modeling_utils import PreTrainedModel
|
32 |
-
from transformers.utils import logging
|
|
|
33 |
from .configuration_stablelm_epoch import StableLMEpochConfig
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
logger = logging.get_logger(__name__)
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
40 |
def _make_causal_mask(
|
41 |
input_ids_shape: torch.Size,
|
@@ -165,6 +188,7 @@ class Attention(nn.Module):
|
|
165 |
self.num_key_value_heads = config.num_key_value_heads
|
166 |
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
167 |
self.max_position_embeddings = config.max_position_embeddings
|
|
|
168 |
|
169 |
if (self.head_dim * self.num_heads) != self.hidden_size:
|
170 |
raise ValueError(
|
@@ -269,10 +293,202 @@ class Attention(nn.Module):
|
|
269 |
return attn_output, attn_weights, past_key_value
|
270 |
|
271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
class DecoderLayer(nn.Module):
|
273 |
def __init__(self, config: StableLMEpochConfig):
|
274 |
super().__init__()
|
275 |
-
self.self_attn =
|
276 |
self.mlp = MLP(config)
|
277 |
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
|
278 |
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
|
@@ -328,6 +544,7 @@ class StableLMEpochPreTrainedModel(PreTrainedModel):
|
|
328 |
supports_gradient_checkpointing = True
|
329 |
_no_split_modules = ["DecoderLayer"]
|
330 |
_skip_keys_device_placement = "past_key_values"
|
|
|
331 |
|
332 |
def _init_weights(self, module: nn.Module):
|
333 |
"""Initialize the weights"""
|
@@ -355,6 +572,7 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
|
|
355 |
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
356 |
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
|
357 |
|
|
|
358 |
self.gradient_checkpointing = False
|
359 |
# Initialize weights and apply final processing
|
360 |
self.post_init()
|
@@ -428,10 +646,6 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
|
|
428 |
seq_length_with_past = seq_length
|
429 |
past_key_values_length = 0
|
430 |
|
431 |
-
if past_key_values is not None:
|
432 |
-
past_key_values_length = past_key_values[0][0].shape[2]
|
433 |
-
seq_length_with_past = seq_length_with_past + past_key_values_length
|
434 |
-
|
435 |
if position_ids is None:
|
436 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
437 |
position_ids = torch.arange(
|
@@ -447,18 +661,22 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
|
|
447 |
if inputs_embeds is None:
|
448 |
inputs_embeds = self.embed_tokens(input_ids)
|
449 |
# Embed positions
|
450 |
-
if
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
455 |
)
|
456 |
-
attention_mask = self._prepare_decoder_attention_mask(
|
457 |
-
attention_mask,
|
458 |
-
(batch_size, seq_length),
|
459 |
-
inputs_embeds,
|
460 |
-
past_key_values_length,
|
461 |
-
)
|
462 |
|
463 |
hidden_states = inputs_embeds
|
464 |
|
@@ -643,8 +861,17 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
|
|
643 |
**kwargs,
|
644 |
):
|
645 |
# Trim decoder_input_ids if past is used
|
646 |
-
if past_key_values
|
647 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
648 |
|
649 |
position_ids = kwargs.get("position_ids", None)
|
650 |
if attention_mask is not None and position_ids is None:
|
|
|
19 |
""" PyTorch StableLM Epoch model. """
|
20 |
from typing import Optional, Tuple, Union
|
21 |
import math
|
22 |
+
import warnings
|
23 |
|
24 |
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
import torch.utils.checkpoint
|
27 |
from torch import nn
|
28 |
from torch.nn import CrossEntropyLoss
|
29 |
+
|
30 |
+
from transformers.cache_utils import Cache
|
31 |
from transformers.modeling_outputs import (
|
32 |
BaseModelOutputWithPast,
|
33 |
CausalLMOutputWithPast,
|
34 |
)
|
35 |
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
from transformers.utils import logging, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
|
37 |
+
|
38 |
from .configuration_stablelm_epoch import StableLMEpochConfig
|
39 |
|
40 |
|
41 |
+
if is_flash_attn_2_available():
|
42 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
43 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
44 |
+
|
45 |
+
|
46 |
logger = logging.get_logger(__name__)
|
47 |
|
48 |
|
49 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
50 |
+
def _get_unpad_data(attention_mask):
|
51 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
52 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
53 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
54 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
55 |
+
return (
|
56 |
+
indices,
|
57 |
+
cu_seqlens,
|
58 |
+
max_seqlen_in_batch,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
63 |
def _make_causal_mask(
|
64 |
input_ids_shape: torch.Size,
|
|
|
188 |
self.num_key_value_heads = config.num_key_value_heads
|
189 |
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
190 |
self.max_position_embeddings = config.max_position_embeddings
|
191 |
+
self.is_causal = True
|
192 |
|
193 |
if (self.head_dim * self.num_heads) != self.hidden_size:
|
194 |
raise ValueError(
|
|
|
293 |
return attn_output, attn_weights, past_key_value
|
294 |
|
295 |
|
296 |
+
class FlashAttention2(Attention):
|
297 |
+
"""
|
298 |
+
Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
|
299 |
+
"""
|
300 |
+
|
301 |
+
def __init__(self, *args, **kwargs):
|
302 |
+
super().__init__(*args, **kwargs)
|
303 |
+
|
304 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
305 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
306 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
307 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
308 |
+
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
hidden_states: torch.Tensor,
|
312 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
313 |
+
position_ids: Optional[torch.LongTensor] = None,
|
314 |
+
past_key_value: Optional[Cache] = None,
|
315 |
+
output_attentions: bool = False,
|
316 |
+
use_cache: bool = False,
|
317 |
+
**kwargs,
|
318 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
319 |
+
# FlashAttention2 attention does not support output_attentions
|
320 |
+
if "padding_mask" in kwargs:
|
321 |
+
warnings.warn(
|
322 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
323 |
+
)
|
324 |
+
|
325 |
+
# overwrite attention_mask with padding_mask
|
326 |
+
attention_mask = kwargs.pop("padding_mask")
|
327 |
+
|
328 |
+
output_attentions = False
|
329 |
+
|
330 |
+
bsz, q_len, _ = hidden_states.size()
|
331 |
+
|
332 |
+
query_states = self.q_proj(hidden_states)
|
333 |
+
key_states = self.k_proj(hidden_states)
|
334 |
+
value_states = self.v_proj(hidden_states)
|
335 |
+
|
336 |
+
# Flash attention requires the input to have the shape
|
337 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
338 |
+
# therefore we just need to keep the original shape
|
339 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
340 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
341 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
342 |
+
|
343 |
+
query_rot = query_states[..., : self.rotary_ndims]
|
344 |
+
query_pass = query_states[..., self.rotary_ndims :]
|
345 |
+
key_rot = key_states[..., : self.rotary_ndims]
|
346 |
+
key_pass = key_states[..., self.rotary_ndims :]
|
347 |
+
|
348 |
+
kv_seq_len = key_states.shape[-2]
|
349 |
+
if past_key_value is not None:
|
350 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
351 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
352 |
+
query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
353 |
+
|
354 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
355 |
+
query_states = torch.cat((query_states, query_pass), dim=-1)
|
356 |
+
key_states = torch.cat((key_states, key_pass), dim=-1)
|
357 |
+
|
358 |
+
if past_key_value is not None:
|
359 |
+
# Reuse k, v, self_attention
|
360 |
+
key_states = torch.cat((past_key_value[0], key_states), dim=2)
|
361 |
+
value_states = torch.cat((past_key_value[1], value_states), dim=2)
|
362 |
+
|
363 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
364 |
+
|
365 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
366 |
+
# to be able to avoid many of these transpose/reshape/view.
|
367 |
+
query_states = query_states.transpose(1, 2)
|
368 |
+
key_states = key_states.transpose(1, 2)
|
369 |
+
value_states = value_states.transpose(1, 2)
|
370 |
+
|
371 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
372 |
+
|
373 |
+
attn_output = self._flash_attention_forward(
|
374 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
375 |
+
)
|
376 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
377 |
+
attn_output = self.o_proj(attn_output)
|
378 |
+
|
379 |
+
if not output_attentions:
|
380 |
+
attn_weights = None
|
381 |
+
|
382 |
+
return attn_output, attn_weights, past_key_value
|
383 |
+
|
384 |
+
def _flash_attention_forward(
|
385 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
386 |
+
):
|
387 |
+
"""
|
388 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
389 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
390 |
+
|
391 |
+
Args:
|
392 |
+
query_states (`torch.Tensor`):
|
393 |
+
Input query states to be passed to Flash Attention API
|
394 |
+
key_states (`torch.Tensor`):
|
395 |
+
Input key states to be passed to Flash Attention API
|
396 |
+
value_states (`torch.Tensor`):
|
397 |
+
Input value states to be passed to Flash Attention API
|
398 |
+
attention_mask (`torch.Tensor`):
|
399 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
400 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
401 |
+
dropout (`int`, *optional*):
|
402 |
+
Attention dropout
|
403 |
+
softmax_scale (`float`, *optional*):
|
404 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
405 |
+
"""
|
406 |
+
if not self._flash_attn_uses_top_left_mask:
|
407 |
+
causal = self.is_causal
|
408 |
+
else:
|
409 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
|
410 |
+
causal = self.is_causal and query_length != 1
|
411 |
+
|
412 |
+
# Contains at least one padding token in the sequence
|
413 |
+
if attention_mask is not None:
|
414 |
+
batch_size = query_states.shape[0]
|
415 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
416 |
+
query_states, key_states, value_states, attention_mask, query_length
|
417 |
+
)
|
418 |
+
|
419 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
420 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
421 |
+
|
422 |
+
attn_output_unpad = flash_attn_varlen_func(
|
423 |
+
query_states,
|
424 |
+
key_states,
|
425 |
+
value_states,
|
426 |
+
cu_seqlens_q=cu_seqlens_q,
|
427 |
+
cu_seqlens_k=cu_seqlens_k,
|
428 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
429 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
430 |
+
dropout_p=dropout,
|
431 |
+
softmax_scale=softmax_scale,
|
432 |
+
causal=causal,
|
433 |
+
)
|
434 |
+
|
435 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
436 |
+
else:
|
437 |
+
attn_output = flash_attn_func(
|
438 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
439 |
+
)
|
440 |
+
|
441 |
+
return attn_output
|
442 |
+
|
443 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
444 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
445 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
446 |
+
|
447 |
+
key_layer = index_first_axis(
|
448 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
449 |
+
)
|
450 |
+
value_layer = index_first_axis(
|
451 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
452 |
+
)
|
453 |
+
if query_length == kv_seq_len:
|
454 |
+
query_layer = index_first_axis(
|
455 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
456 |
+
)
|
457 |
+
cu_seqlens_q = cu_seqlens_k
|
458 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
459 |
+
indices_q = indices_k
|
460 |
+
elif query_length == 1:
|
461 |
+
max_seqlen_in_batch_q = 1
|
462 |
+
cu_seqlens_q = torch.arange(
|
463 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
464 |
+
) # There is a memcpy here, that is very bad.
|
465 |
+
indices_q = cu_seqlens_q[:-1]
|
466 |
+
query_layer = query_layer.squeeze(1)
|
467 |
+
else:
|
468 |
+
# The -q_len: slice assumes left padding.
|
469 |
+
attention_mask = attention_mask[:, -query_length:]
|
470 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
471 |
+
|
472 |
+
return (
|
473 |
+
query_layer,
|
474 |
+
key_layer,
|
475 |
+
value_layer,
|
476 |
+
indices_q,
|
477 |
+
(cu_seqlens_q, cu_seqlens_k),
|
478 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
479 |
+
)
|
480 |
+
|
481 |
+
|
482 |
+
ATTENTION_CLASSES = {
|
483 |
+
"eager": Attention,
|
484 |
+
"flash_attention_2": FlashAttention2,
|
485 |
+
}
|
486 |
+
|
487 |
+
|
488 |
class DecoderLayer(nn.Module):
|
489 |
def __init__(self, config: StableLMEpochConfig):
|
490 |
super().__init__()
|
491 |
+
self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
|
492 |
self.mlp = MLP(config)
|
493 |
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
|
494 |
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
|
|
|
544 |
supports_gradient_checkpointing = True
|
545 |
_no_split_modules = ["DecoderLayer"]
|
546 |
_skip_keys_device_placement = "past_key_values"
|
547 |
+
_supports_flash_attn_2 = True
|
548 |
|
549 |
def _init_weights(self, module: nn.Module):
|
550 |
"""Initialize the weights"""
|
|
|
572 |
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
573 |
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
|
574 |
|
575 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
576 |
self.gradient_checkpointing = False
|
577 |
# Initialize weights and apply final processing
|
578 |
self.post_init()
|
|
|
646 |
seq_length_with_past = seq_length
|
647 |
past_key_values_length = 0
|
648 |
|
|
|
|
|
|
|
|
|
649 |
if position_ids is None:
|
650 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
651 |
position_ids = torch.arange(
|
|
|
661 |
if inputs_embeds is None:
|
662 |
inputs_embeds = self.embed_tokens(input_ids)
|
663 |
# Embed positions
|
664 |
+
if self._use_flash_attention_2:
|
665 |
+
# 2d mask is passed through the layers
|
666 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
667 |
+
else:
|
668 |
+
if attention_mask is None:
|
669 |
+
attention_mask = torch.ones(
|
670 |
+
(batch_size, seq_length_with_past),
|
671 |
+
dtype=torch.bool,
|
672 |
+
device=inputs_embeds.device,
|
673 |
+
)
|
674 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
675 |
+
attention_mask,
|
676 |
+
(batch_size, seq_length),
|
677 |
+
inputs_embeds,
|
678 |
+
past_key_values_length,
|
679 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
680 |
|
681 |
hidden_states = inputs_embeds
|
682 |
|
|
|
861 |
**kwargs,
|
862 |
):
|
863 |
# Trim decoder_input_ids if past is used
|
864 |
+
if past_key_values is not None:
|
865 |
+
past_length = past_key_values[0][0].shape[2]
|
866 |
+
|
867 |
+
# Some generation methods already pass only the last input ID
|
868 |
+
if input_ids.shape[1] > past_length:
|
869 |
+
remove_prefix_length = past_length
|
870 |
+
else:
|
871 |
+
# Default to old behavior: keep only final ID
|
872 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
873 |
+
|
874 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
875 |
|
876 |
position_ids = kwargs.get("position_ids", None)
|
877 |
if attention_mask is not None and position_ids is None:
|