Upload modeling_attn_mask_utils.py

modeling_attn_mask_utils.py  CHANGED  (+20 -19)
@@ -240,6 +240,7 @@ class AttentionMaskConverter:
         inputs_embeds: torch.Tensor,
         past_key_values_length: int,
         sliding_window: Optional[int] = None,
+        is_training: bool = False,
     ) -> bool:
         """
         Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
@@ -249,7 +250,7 @@ class AttentionMaskConverter:
         allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
         """
 
-        batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
         key_value_length = query_length + past_key_values_length
 
         is_tracing = (
@@ -263,23 +264,19 @@
         if attention_mask is None:
             # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
             # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
-            # Thus, we
+            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
             #
             # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
             if (
-                not is_tracing
+                (is_training or not is_tracing)
                 and (query_length == 1 or key_value_length == query_length)
                 and (sliding_window is None or key_value_length < sliding_window)
             ):
                 ignore_causal_mask = True
         elif sliding_window is None or key_value_length < sliding_window:
             if len(attention_mask.shape) == 4:
-                expected_shape = (batch_size, 1, query_length, key_value_length)
-                if tuple(attention_mask.shape) != expected_shape:
-                    raise ValueError(
-                        f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
-                    )
-            elif not is_tracing and torch.all(attention_mask == 1):
+                return False
+            elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
                 if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True
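The effect of the new `is_training` flag, sketched as a minimal usage example (the import path `transformers.modeling_attn_mask_utils` is assumed here as the module's upstream location, not something stated by this upload):

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter  # assumed import path

# Dummy hidden states; the helper only reads the batch size and query length from this tensor.
inputs_embeds = torch.zeros(2, 16, 64)

# Typical training step: no padding mask and query_length == key_value_length, so the
# causal mask can be dropped and SDPA's `is_causal=True` used instead. With this change,
# `is_training=True` also bypasses the tracing guard during torch.compile / tracing.
can_ignore = AttentionMaskConverter._ignore_causal_mask_sdpa(
    attention_mask=None,
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
    sliding_window=None,
    is_training=True,
)
print(can_ignore)  # True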
@@ -386,12 +383,18 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
             input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
         )
     else:
-        expanded_4d_mask = attn_mask_converter.to_4d(
-            attention_mask,
-            input_shape[-1],
-            dtype=inputs_embeds.dtype,
-            key_value_length=key_value_length,
-        )
+        if attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            expanded_4d_mask = attention_mask
+        else:
+            expanded_4d_mask = attn_mask_converter.to_4d(
+                attention_mask,
+                input_shape[-1],
+                dtype=inputs_embeds.dtype,
+                key_value_length=key_value_length,
+            )
 
     # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
     # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
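For reference, the "inverted form" required by the new 4D branch is an additive mask: 0.0 where attention is allowed and a large negative value (typically torch.finfo(dtype).min) where it is blocked, so a valid custom mask has max() == 0. A minimal sketch of building such a custom causal 4D mask (shapes picked purely for illustration):

import torch

dtype = torch.float16
batch, query_len, kv_len = 1, 4, 4

# Additive mask: 0.0 for positions that may be attended, torch.finfo(dtype).min elsewhere.
allowed = torch.tril(torch.ones(query_len, kv_len, dtype=torch.bool))  # causal pattern
custom_4d_mask = torch.zeros(batch, 1, query_len, kv_len, dtype=dtype)
custom_4d_mask = custom_4d_mask.masked_fill(~allowed, torch.finfo(dtype).min)

assert custom_4d_mask.dim() == 4 and custom_4d_mask.max().item() == 0  # satisfies the check added above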
@@ -445,10 +448,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )
 
-    if torch.all(mask == 1):
-        if is_tracing:
-            pass
-        elif tgt_len == 1:
+    if not is_tracing and torch.all(mask == 1):
+        if tgt_len == 1:
             # For query_length == 1, causal attention and bi-directional attention are the same.
             return None
         elif key_value_length == tgt_len:
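A quick sketch of the padding-mask helper after this change (again assuming the upstream import path): outside of tracing or compilation an all-ones mask is dropped entirely, while under tracing the mask is kept so the captured graph does not hard-code the mask-free path.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa  # assumed import path

padding_mask = torch.ones(2, 8, dtype=torch.long)  # 2D padding mask with no padded positions

# Decoding step (tgt_len == 1) with an all-ones mask: returns None so SDPA can run without
# an explicit attn_mask; per the `elif` above, the same holds when tgt_len equals the key/value length.
print(_prepare_4d_attention_mask_for_sdpa(padding_mask, torch.float16, tgt_len=1))  # None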