add kv cache
Browse files- attention.py +57 -16
- blocks.py +3 -2
- config.json +2 -1
- generation_config.json +1 -1
- modeling_mpt.py +23 -5
attention.py
CHANGED
@@ -18,6 +18,7 @@ class PastKeyValue(NamedTuple):
|
|
18 |
class AttnFnOutput(NamedTuple):
|
19 |
attns: torch.Tensor
|
20 |
attn_probs: Optional[torch.Tensor]
|
|
|
21 |
|
22 |
class AttnFn(Protocol):
|
23 |
def __call__(
|
@@ -81,6 +82,7 @@ def scaled_multihead_dot_product_attention(
|
|
81 |
key: torch.Tensor,
|
82 |
value: torch.Tensor,
|
83 |
n_heads: int,
|
|
|
84 |
softmax_scale: Optional[float] = None,
|
85 |
attn_bias: Optional[torch.Tensor] = None,
|
86 |
key_padding_mask: Optional[torch.ByteTensor] = None,
|
@@ -91,23 +93,41 @@ def scaled_multihead_dot_product_attention(
|
|
91 |
multiquery = False,
|
92 |
) -> AttnFnOutput:
|
93 |
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
(b, _, s_q, d) = q.shape
|
98 |
s_k = k.size(-1)
|
99 |
if softmax_scale is None:
|
100 |
softmax_scale = 1 / math.sqrt(d)
|
101 |
attn_weight = q.matmul(k) * softmax_scale
|
102 |
if attn_bias is not None:
|
|
|
|
|
|
|
|
|
103 |
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
104 |
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
105 |
attn_weight = attn_weight + attn_bias
|
|
|
106 |
if key_padding_mask is not None:
|
107 |
if attn_bias is not None:
|
108 |
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
109 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
110 |
-
if is_causal:
|
111 |
s = max(s_q, s_k)
|
112 |
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
113 |
causal_mask = causal_mask.tril()
|
@@ -121,8 +141,8 @@ def scaled_multihead_dot_product_attention(
|
|
121 |
out = attn_weight.matmul(v)
|
122 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
123 |
if needs_weights:
|
124 |
-
return AttnFnOutput(out, attn_weight)
|
125 |
-
return AttnFnOutput(out, None)
|
126 |
|
127 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
128 |
for tensor in tensors:
|
@@ -136,6 +156,7 @@ def flash_attn_fn(
|
|
136 |
key: torch.Tensor,
|
137 |
value: torch.Tensor,
|
138 |
n_heads: int,
|
|
|
139 |
softmax_scale: Optional[float] = None,
|
140 |
attn_bias: Optional[torch.Tensor] = None,
|
141 |
key_padding_mask: Optional[torch.ByteTensor] = None,
|
@@ -150,6 +171,18 @@ def flash_attn_fn(
|
|
150 |
except:
|
151 |
raise RuntimeError('Please install flash-attn==1.0.3.post0')
|
152 |
check_valid_inputs(query, key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
if attn_bias is not None:
|
154 |
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
155 |
(batch_size, seqlen) = query.shape[:2]
|
@@ -169,13 +202,14 @@ def flash_attn_fn(
|
|
169 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
170 |
output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
171 |
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
172 |
-
return AttnFnOutput(output, None)
|
173 |
|
174 |
def triton_flash_attn_fn(
|
175 |
query: torch.Tensor,
|
176 |
key: torch.Tensor,
|
177 |
value: torch.Tensor,
|
178 |
n_heads: int,
|
|
|
179 |
softmax_scale: Optional[float] = None,
|
180 |
attn_bias: Optional[torch.Tensor] = None,
|
181 |
key_padding_mask: Optional[torch.ByteTensor] = None,
|
@@ -198,6 +232,18 @@ def triton_flash_attn_fn(
|
|
198 |
if not _installed:
|
199 |
raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
|
200 |
check_valid_inputs(query, key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
if dropout_p:
|
202 |
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
203 |
if needs_weights:
|
@@ -217,7 +263,7 @@ def triton_flash_attn_fn(
|
|
217 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
218 |
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
219 |
output = attn_output.view(*attn_output.shape[:2], -1)
|
220 |
-
return AttnFnOutput(output, None)
|
221 |
|
222 |
class MultiheadAttention(nn.Module, Attn):
|
223 |
"""Multi-head self attention.
|
@@ -278,13 +324,6 @@ class MultiheadAttention(nn.Module, Attn):
|
|
278 |
dtype = query.dtype
|
279 |
query = self.q_ln(query).to(dtype)
|
280 |
key = self.k_ln(key).to(dtype)
|
281 |
-
if past_key_value is not None:
|
282 |
-
if len(past_key_value) != 0:
|
283 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
284 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
285 |
-
past_key_value = PastKeyValue(key, value)
|
286 |
-
if attn_bias is not None:
|
287 |
-
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
|
288 |
if self.training and self.gradient_checkpointing:
|
289 |
ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
|
290 |
def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
|
@@ -337,6 +376,7 @@ class MultiheadAttention(nn.Module, Attn):
|
|
337 |
key,
|
338 |
value,
|
339 |
self.n_heads,
|
|
|
340 |
softmax_scale=self.softmax_scale,
|
341 |
attn_bias=attn_bias,
|
342 |
key_padding_mask=key_padding_mask,
|
@@ -345,7 +385,7 @@ class MultiheadAttention(nn.Module, Attn):
|
|
345 |
training=self.training,
|
346 |
needs_weights=needs_weights,
|
347 |
)
|
348 |
-
context, attn_weights = attn_fn_out
|
349 |
return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
|
350 |
|
351 |
class MultiQueryAttention(nn.Module, Attn):
|
@@ -465,6 +505,7 @@ class MultiQueryAttention(nn.Module, Attn):
|
|
465 |
key,
|
466 |
value,
|
467 |
self.n_heads,
|
|
|
468 |
softmax_scale=self.softmax_scale,
|
469 |
attn_bias=attn_bias,
|
470 |
key_padding_mask=key_padding_mask,
|
|
|
18 |
class AttnFnOutput(NamedTuple):
|
19 |
attns: torch.Tensor
|
20 |
attn_probs: Optional[torch.Tensor]
|
21 |
+
past_key_value: Union[PastKeyValue, Tuple, None]
|
22 |
|
23 |
class AttnFn(Protocol):
|
24 |
def __call__(
|
|
|
82 |
key: torch.Tensor,
|
83 |
value: torch.Tensor,
|
84 |
n_heads: int,
|
85 |
+
past_key_value=None,
|
86 |
softmax_scale: Optional[float] = None,
|
87 |
attn_bias: Optional[torch.Tensor] = None,
|
88 |
key_padding_mask: Optional[torch.ByteTensor] = None,
|
|
|
93 |
multiquery = False,
|
94 |
) -> AttnFnOutput:
|
95 |
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
96 |
+
kv_n_heads = 1 if multiquery else n_heads
|
97 |
+
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
|
98 |
+
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
|
99 |
+
|
100 |
+
if past_key_value is not None:
|
101 |
+
# attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
|
102 |
+
# kv_cache is therefore stored using that shape.
|
103 |
+
# attn_impl: torch stores the kv_cache in the ordering which is most advantageous
|
104 |
+
# for its attn computation ie
|
105 |
+
# keys are stored as tensors with shape [b, h, d_head, s] and
|
106 |
+
# values are stored as tensors with shape [b, h, s, d_head]
|
107 |
+
if len(past_key_value) != 0:
|
108 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
109 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
110 |
+
|
111 |
+
past_key_value = (k, v)
|
112 |
(b, _, s_q, d) = q.shape
|
113 |
s_k = k.size(-1)
|
114 |
if softmax_scale is None:
|
115 |
softmax_scale = 1 / math.sqrt(d)
|
116 |
attn_weight = q.matmul(k) * softmax_scale
|
117 |
if attn_bias is not None:
|
118 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
119 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
120 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
121 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
122 |
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
123 |
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
124 |
attn_weight = attn_weight + attn_bias
|
125 |
+
min_val = torch.finfo(q.dtype).min
|
126 |
if key_padding_mask is not None:
|
127 |
if attn_bias is not None:
|
128 |
warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
129 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
130 |
+
if is_causal and (not q.size(2) == 1):
|
131 |
s = max(s_q, s_k)
|
132 |
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
133 |
causal_mask = causal_mask.tril()
|
|
|
141 |
out = attn_weight.matmul(v)
|
142 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
143 |
if needs_weights:
|
144 |
+
return AttnFnOutput(out, attn_weight, past_key_value)
|
145 |
+
return AttnFnOutput(out, None, past_key_value)
|
146 |
|
147 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
148 |
for tensor in tensors:
|
|
|
156 |
key: torch.Tensor,
|
157 |
value: torch.Tensor,
|
158 |
n_heads: int,
|
159 |
+
past_key_value=None,
|
160 |
softmax_scale: Optional[float] = None,
|
161 |
attn_bias: Optional[torch.Tensor] = None,
|
162 |
key_padding_mask: Optional[torch.ByteTensor] = None,
|
|
|
171 |
except:
|
172 |
raise RuntimeError('Please install flash-attn==1.0.3.post0')
|
173 |
check_valid_inputs(query, key, value)
|
174 |
+
if past_key_value is not None:
|
175 |
+
if len(past_key_value) != 0:
|
176 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
177 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
178 |
+
|
179 |
+
past_key_value = (key, value)
|
180 |
+
|
181 |
+
if attn_bias is not None:
|
182 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
183 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
184 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
185 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
186 |
if attn_bias is not None:
|
187 |
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
188 |
(batch_size, seqlen) = query.shape[:2]
|
|
|
202 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
203 |
output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
204 |
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
205 |
+
return AttnFnOutput(output, None, past_key_value)
|
206 |
|
207 |
def triton_flash_attn_fn(
|
208 |
query: torch.Tensor,
|
209 |
key: torch.Tensor,
|
210 |
value: torch.Tensor,
|
211 |
n_heads: int,
|
212 |
+
past_key_value=None,
|
213 |
softmax_scale: Optional[float] = None,
|
214 |
attn_bias: Optional[torch.Tensor] = None,
|
215 |
key_padding_mask: Optional[torch.ByteTensor] = None,
|
|
|
232 |
if not _installed:
|
233 |
raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
|
234 |
check_valid_inputs(query, key, value)
|
235 |
+
if past_key_value is not None:
|
236 |
+
if len(past_key_value) != 0:
|
237 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
238 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
239 |
+
|
240 |
+
past_key_value = (key, value)
|
241 |
+
|
242 |
+
if attn_bias is not None:
|
243 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
244 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
245 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
246 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
247 |
if dropout_p:
|
248 |
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
249 |
if needs_weights:
|
|
|
263 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
264 |
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
265 |
output = attn_output.view(*attn_output.shape[:2], -1)
|
266 |
+
return AttnFnOutput(output, None, past_key_value)
|
267 |
|
268 |
class MultiheadAttention(nn.Module, Attn):
|
269 |
"""Multi-head self attention.
|
|
|
324 |
dtype = query.dtype
|
325 |
query = self.q_ln(query).to(dtype)
|
326 |
key = self.k_ln(key).to(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
if self.training and self.gradient_checkpointing:
|
328 |
ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
|
329 |
def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
|
|
|
376 |
key,
|
377 |
value,
|
378 |
self.n_heads,
|
379 |
+
past_key_value=past_key_value,
|
380 |
softmax_scale=self.softmax_scale,
|
381 |
attn_bias=attn_bias,
|
382 |
key_padding_mask=key_padding_mask,
|
|
|
385 |
training=self.training,
|
386 |
needs_weights=needs_weights,
|
387 |
)
|
388 |
+
context, attn_weights, past_key_value = attn_fn_out
|
389 |
return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
|
390 |
|
391 |
class MultiQueryAttention(nn.Module, Attn):
|
|
|
505 |
key,
|
506 |
value,
|
507 |
self.n_heads,
|
508 |
+
past_key_value=past_key_value,
|
509 |
softmax_scale=self.softmax_scale,
|
510 |
attn_bias=attn_bias,
|
511 |
key_padding_mask=key_padding_mask,
|
blocks.py
CHANGED
@@ -7,6 +7,7 @@ from .norm import NORM_CLASS_REGISTRY
|
|
7 |
|
8 |
class MPTBlockOutput(NamedTuple):
|
9 |
hidden_states: torch.Tensor
|
|
|
10 |
past_key_value: Union[PastKeyValue, Tuple, None]
|
11 |
|
12 |
class MPTMLP(nn.Module):
|
@@ -38,9 +39,9 @@ class MPTBlock(nn.Module):
|
|
38 |
|
39 |
def forward(self, x: torch.Tensor, past_key_value: Union[PastKeyValue, Tuple, None] = None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> MPTBlockOutput:
|
40 |
a = self.norm_1(x)
|
41 |
-
(b,
|
42 |
x = x + self.resid_attn_dropout(b)
|
43 |
m = self.norm_2(x)
|
44 |
n = self.ffn(m)
|
45 |
x = x + self.resid_ffn_dropout(n)
|
46 |
-
return MPTBlockOutput(x, past_key_value)
|
|
|
7 |
|
8 |
class MPTBlockOutput(NamedTuple):
|
9 |
hidden_states: torch.Tensor
|
10 |
+
attn_probs: Optional[torch.Tensor]
|
11 |
past_key_value: Union[PastKeyValue, Tuple, None]
|
12 |
|
13 |
class MPTMLP(nn.Module):
|
|
|
39 |
|
40 |
def forward(self, x: torch.Tensor, past_key_value: Union[PastKeyValue, Tuple, None] = None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> MPTBlockOutput:
|
41 |
a = self.norm_1(x)
|
42 |
+
(b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
|
43 |
x = x + self.resid_attn_dropout(b)
|
44 |
m = self.norm_2(x)
|
45 |
n = self.ffn(m)
|
46 |
x = x + self.resid_ffn_dropout(n)
|
47 |
+
return MPTBlockOutput(x, attn_weights, past_key_value)
|
config.json
CHANGED
@@ -21,6 +21,7 @@
|
|
21 |
"d_model": 4096,
|
22 |
"emb_pdrop": 0,
|
23 |
"embedding_fraction": 1.0,
|
|
|
24 |
"expansion_ratio": 4,
|
25 |
"init_config": {
|
26 |
"emb_init_std": null,
|
@@ -46,7 +47,7 @@
|
|
46 |
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
47 |
"torch_dtype": "bfloat16",
|
48 |
"transformers_version": "4.29.2",
|
49 |
-
"use_cache":
|
50 |
"verbose": 0,
|
51 |
"vocab_size": 50432
|
52 |
}
|
|
|
21 |
"d_model": 4096,
|
22 |
"emb_pdrop": 0,
|
23 |
"embedding_fraction": 1.0,
|
24 |
+
"eos_token_id": 0,
|
25 |
"expansion_ratio": 4,
|
26 |
"init_config": {
|
27 |
"emb_init_std": null,
|
|
|
47 |
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
48 |
"torch_dtype": "bfloat16",
|
49 |
"transformers_version": "4.29.2",
|
50 |
+
"use_cache": true,
|
51 |
"verbose": 0,
|
52 |
"vocab_size": 50432
|
53 |
}
|
generation_config.json
CHANGED
@@ -2,5 +2,5 @@
|
|
2 |
"_from_model_config": true,
|
3 |
"transformers_version": "4.29.2",
|
4 |
"eos_token_id": 0,
|
5 |
-
"use_cache":
|
6 |
}
|
|
|
2 |
"_from_model_config": true,
|
3 |
"transformers_version": "4.29.2",
|
4 |
"eos_token_id": 0,
|
5 |
+
"use_cache": true
|
6 |
}
|
modeling_mpt.py
CHANGED
@@ -116,7 +116,9 @@ class MPTModel(MPTPreTrainedModel):
|
|
116 |
if attn_bias is None:
|
117 |
attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
|
118 |
else:
|
119 |
-
|
|
|
|
|
120 |
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
|
121 |
raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
|
122 |
min_val = torch.finfo(attn_bias.dtype).min
|
@@ -164,7 +166,10 @@ class MPTModel(MPTPreTrainedModel):
|
|
164 |
if not return_dict:
|
165 |
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
166 |
if output_attentions:
|
167 |
-
|
|
|
|
|
|
|
168 |
if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
|
169 |
raise NotImplementedError('MPT does not support training with left padding.')
|
170 |
if self.prefix_lm and prefix_mask is None:
|
@@ -184,7 +189,12 @@ class MPTModel(MPTPreTrainedModel):
|
|
184 |
if past_key_values is not None:
|
185 |
if len(past_key_values) != self.config.n_layers:
|
186 |
raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
|
|
|
|
|
|
|
187 |
past_position = past_key_values[0][0].size(1)
|
|
|
|
|
188 |
if S + past_position > self.config.max_seq_len:
|
189 |
raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
|
190 |
pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
|
@@ -202,6 +212,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
202 |
if use_cache and past_key_values is None:
|
203 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
204 |
all_hidden_states = () if output_hidden_states else None
|
|
|
205 |
for (b_idx, block) in enumerate(self.blocks):
|
206 |
if output_hidden_states:
|
207 |
assert all_hidden_states is not None
|
@@ -242,12 +253,19 @@ class MPTModel(MPTPreTrainedModel):
|
|
242 |
attention_mask=attention_mask,
|
243 |
is_causal=self.is_causal,
|
244 |
)
|
245 |
-
x, past_key_value = block_out
|
246 |
del block_out
|
247 |
if past_key_values is not None:
|
248 |
past_key_values[b_idx] = past_key_value
|
|
|
|
|
|
|
249 |
x = self.norm_f(x)
|
250 |
-
|
|
|
|
|
|
|
|
|
251 |
|
252 |
def param_init_fn(self, module):
|
253 |
init_fn_name = self.config.init_config['name']
|
@@ -308,7 +326,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
308 |
labels = torch.roll(labels, shifts=-1)
|
309 |
labels[:, -1] = -100
|
310 |
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
|
311 |
-
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
|
312 |
|
313 |
def param_init_fn(self, module):
|
314 |
init_fn_name = self.config.init_config['name']
|
|
|
116 |
if attn_bias is None:
|
117 |
attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
|
118 |
else:
|
119 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
120 |
+
_s_k = max(0, attn_bias.size(-1) - s_k)
|
121 |
+
attn_bias = attn_bias[:, :, :, _s_k:]
|
122 |
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
|
123 |
raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
|
124 |
min_val = torch.finfo(attn_bias.dtype).min
|
|
|
166 |
if not return_dict:
|
167 |
raise NotImplementedError('return_dict False is not implemented yet for MPT')
|
168 |
if output_attentions:
|
169 |
+
if self.attn_impl != 'torch':
|
170 |
+
raise NotImplementedError(
|
171 |
+
'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
|
172 |
+
)
|
173 |
if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
|
174 |
raise NotImplementedError('MPT does not support training with left padding.')
|
175 |
if self.prefix_lm and prefix_mask is None:
|
|
|
189 |
if past_key_values is not None:
|
190 |
if len(past_key_values) != self.config.n_layers:
|
191 |
raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
|
192 |
+
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
|
193 |
+
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
|
194 |
+
# Here we shift position embedding using the `seq` dim of the past key
|
195 |
past_position = past_key_values[0][0].size(1)
|
196 |
+
if self.attn_impl == 'torch':
|
197 |
+
past_position = past_key_values[0][0].size(3)
|
198 |
if S + past_position > self.config.max_seq_len:
|
199 |
raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
|
200 |
pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
|
|
|
212 |
if use_cache and past_key_values is None:
|
213 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
214 |
all_hidden_states = () if output_hidden_states else None
|
215 |
+
all_self_attns = () if output_attentions else None
|
216 |
for (b_idx, block) in enumerate(self.blocks):
|
217 |
if output_hidden_states:
|
218 |
assert all_hidden_states is not None
|
|
|
253 |
attention_mask=attention_mask,
|
254 |
is_causal=self.is_causal,
|
255 |
)
|
256 |
+
x, attn_weights, past_key_value = block_out
|
257 |
del block_out
|
258 |
if past_key_values is not None:
|
259 |
past_key_values[b_idx] = past_key_value
|
260 |
+
if output_attentions:
|
261 |
+
assert all_self_attns is not None # pyright
|
262 |
+
all_self_attns = all_self_attns + (attn_weights,)
|
263 |
x = self.norm_f(x)
|
264 |
+
# add hidden states from the last decoder layer
|
265 |
+
if output_hidden_states:
|
266 |
+
assert all_hidden_states is not None # pyright
|
267 |
+
all_hidden_states = all_hidden_states + (x,)
|
268 |
+
return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
|
269 |
|
270 |
def param_init_fn(self, module):
|
271 |
init_fn_name = self.config.init_config['name']
|
|
|
326 |
labels = torch.roll(labels, shifts=-1)
|
327 |
labels[:, -1] = -100
|
328 |
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
|
329 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
330 |
|
331 |
def param_init_fn(self, module):
|
332 |
init_fn_name = self.config.init_config['name']
|