Added flash attention
- configuration_megatron_gpt.py +4 -0
- modeling_megatron_gpt.py +48 -2
configuration_megatron_gpt.py
CHANGED
@@ -81,6 +81,8 @@ class MegatronGPTConfig(PretrainedConfig):
            Whether to calculate and apply the relative position bias within the attention function.
            If this is False, then model.generate will require you to calculate the triangular attention
            mask and pass it through in the attention mask.
+        use_flash_attention (`bool`, *optional*, defaults to `False`):
+            Whether to attempt to use flash attention (if it is installed) when computing attention, or to always skip it and use the regular implementation.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
@@ -118,6 +120,7 @@ class MegatronGPTConfig(PretrainedConfig):
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_scaling=None,
+        use_flash_attention=False,
        **kwargs,
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -141,6 +144,7 @@ class MegatronGPTConfig(PretrainedConfig):
        self.use_cache = use_cache
        self.self_attention_relative_position_bias = self_attention_relative_position_bias
        self.tie_word_embeddings = tie_word_embeddings
+        self.use_flash_attention = use_flash_attention
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
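Since the flag is just another field on `MegatronGPTConfig`, it can be flipped before loading the model. A minimal usage sketch, assuming a hypothetical checkpoint name ("my-org/megatron-gpt") that ships this remote code:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("my-org/megatron-gpt", trust_remote_code=True)
config.use_flash_attention = True  # opt in; the model falls back to the regular attention path if flash-attn is not importable

model = AutoModelForCausalLM.from_pretrained(
    "my-org/megatron-gpt",
    config=config,
    torch_dtype=torch.float16,  # the flash path casts q/k/v to half precision in any case
    trust_remote_code=True,
)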
modeling_megatron_gpt.py
CHANGED
@@ -21,6 +21,7 @@
""" PyTorch MegatronGPT model."""

from dataclasses import dataclass
+import math
from typing import Optional, Tuple, Union

import torch
@@ -43,8 +44,21 @@ from transformers.modeling_outputs import (
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
+# try to load using a relative path, but if it fails try loading it directly
from .configuration_megatron_gpt import MegatronGPTConfig

+try:
+    from flash_attn.bert_padding import unpad_input, pad_input
+    from flash_attn import flash_attn_varlen_func as flash_attn_func
+    HAS_FLASH = True
+except:
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_func
+        HAS_FLASH = True
+    except:
+        HAS_FLASH = False
+
+
def get_activation(act):
    if act in ["gelu", "geglu", "fast-geglu"]:
        act = 'gelu'
@@ -111,9 +125,10 @@ class MegatronGPTAttention(nn.Module):
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
        self._init_rope()

+        self.norm_factor_float = math.sqrt(self.head_size if config.normalize_attention_scores else 1.0)
        self.register_buffer(
            "norm_factor",
-            torch.
+            torch.tensor(self.norm_factor_float, dtype=torch.float32).to(torch.get_default_dtype()),
            persistent=False,
        )
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
@@ -207,7 +222,10 @@ class MegatronGPTAttention(nn.Module):
        present = (key, value) if use_cache else None

        # Compute attention
-        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        if not HAS_FLASH or output_attentions or head_mask is not None or not self.config.use_flash_attention:
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        else:
+            attn_output = self._flash_attn(query, key, value, attention_mask)

        # Reshape outputs
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
@@ -244,6 +262,34 @@ class MegatronGPTAttention(nn.Module):
        # -> [bs, seq_len, hidden_size]
        return tensor

+    def _flash_attn(self, query, key, value, attention_mask=None):
+        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
+        batch_size, num_attention_heads, query_seq_length, attn_head_size = query.size()
+
+        # flash-attn expects [bs, seq_len, num_heads, head_dim]
+        query_layer = query.transpose(1, 2).half()
+        key_layer = key.transpose(1, 2).half()
+        value_layer = value.transpose(1, 2).half()
+
+        # convert the additive mask back to a 0/1 padding mask and drop the broadcast dims
+        attention_mask = (attention_mask == 0).int().squeeze(1).squeeze(1)
+        query_layer, query_indicies, cu_seqlens_q, max_seqlen_q = unpad_input(query_layer, attention_mask[:, -query_seq_length:])
+        key_layer, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_layer, attention_mask)
+        value_layer, _, _, _ = unpad_input(value_layer, attention_mask)
+
+        # returns [total_tokens, num_heads, head_dim]
+        context_layer = flash_attn_func(query_layer, key_layer, value_layer,
+                                        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                                        dropout_p=self.config.attention_dropout, softmax_scale=1 / self.norm_factor_float,
+                                        causal=self.self_attention_relative_position_bias if max_seqlen_q > 1 else False)
+
+        # restore the shape to [bs, num_attention_heads, seq_len, attn_head_size]
+        context_layer = pad_input(context_layer, query_indicies, batch_size, query_seq_length)
+        context_layer = context_layer.view(batch_size, query_seq_length, num_attention_heads, attn_head_size) \
+            .transpose(1, 2)
+
+        return context_layer.to(value.dtype)
+
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        # compute causal mask from causal mask buffer
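For context on what the new `_flash_attn` helper does: it packs the padded [bs, seq_len, num_heads, head_dim] tensors into flash-attn's unpadded ("varlen") layout with `unpad_input`, runs `flash_attn_varlen_func` over the packed tokens, and scatters the result back with `pad_input`. A standalone sketch of that round trip, assuming a flash-attn 2.x build where `unpad_input` returns four values (as the code above expects; newer releases return five) and a CUDA device:

import torch
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn import flash_attn_varlen_func

bs, seq, heads, dim = 2, 8, 4, 64
q = torch.randn(bs, seq, heads, dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# 1 = real token, 0 = padding; here the second sequence has three padded positions
mask = torch.ones(bs, seq, dtype=torch.int32, device="cuda")
mask[1, 5:] = 0

# pack away the padding: tensors become [total_tokens, heads, head_dim] plus cumulative sequence lengths
q_u, indices, cu_q, max_q = unpad_input(q, mask)
k_u, _, cu_k, max_k = unpad_input(k, mask)
v_u, _, _, _ = unpad_input(v, mask)

out = flash_attn_varlen_func(q_u, k_u, v_u, cu_q, cu_k, max_q, max_k,
                             dropout_p=0.0, softmax_scale=dim ** -0.5, causal=True)

# scatter the packed result back to [bs, seq, heads, head_dim]
out = pad_input(out, indices, bs, seq)

In the diff, `softmax_scale=1 / self.norm_factor_float` plays the role of `dim ** -0.5` here, and `causal` is only enabled when `max_seqlen_q > 1` (and the config asks for the internal causal mask), since a single-token decode step has nothing to mask.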