AoiKazama committed
Commit ae1a53c
1 Parent(s): 13ef324

Upload modeling_attn_mask_utils.py

Files changed (1)
  1. modeling_attn_mask_utils.py +20 -19
modeling_attn_mask_utils.py CHANGED
@@ -240,6 +240,7 @@ class AttentionMaskConverter:
         inputs_embeds: torch.Tensor,
         past_key_values_length: int,
         sliding_window: Optional[int] = None,
+        is_training: bool = False,
     ) -> bool:
         """
         Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
@@ -249,7 +250,7 @@ class AttentionMaskConverter:
         allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
         """
 
-        batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
         key_value_length = query_length + past_key_values_length
 
         is_tracing = (
@@ -263,23 +264,19 @@
         if attention_mask is None:
             # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
             # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
-            # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
+            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
             #
             # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
             if (
-                not is_tracing
+                (is_training or not is_tracing)
                 and (query_length == 1 or key_value_length == query_length)
                 and (sliding_window is None or key_value_length < sliding_window)
             ):
                 ignore_causal_mask = True
         elif sliding_window is None or key_value_length < sliding_window:
             if len(attention_mask.shape) == 4:
-                expected_shape = (batch_size, 1, query_length, key_value_length)
-                if tuple(attention_mask.shape) != expected_shape:
-                    raise ValueError(
-                        f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
-                    )
-            elif not is_tracing and torch.all(attention_mask == 1):
+                return False
+            elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
                 if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True
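
The net effect of the three hunks above is a relaxed condition for dropping the causal mask: training bypasses the tracing/export guard, and custom 4D masks are never ignored. A minimal, self-contained sketch of that decision for illustration only; `can_ignore_causal_mask` and its explicit `is_tracing` argument are assumptions, not the library's actual helper:

from typing import Optional

import torch


def can_ignore_causal_mask(
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
    is_training: bool = False,
    is_tracing: bool = False,
) -> bool:
    # Illustrative re-statement of the logic in the hunks above, not the library method.
    key_value_length = query_length + past_key_values_length

    if attention_mask is None:
        # Without a user mask, SDPA's `is_causal` flag is enough when decoding a single
        # token or when there is no cache offset, unless we are tracing/exporting outside
        # of training, or a sliding window would apply.
        return (
            (is_training or not is_tracing)
            and (query_length == 1 or key_value_length == query_length)
            and (sliding_window is None or key_value_length < sliding_window)
        )

    if sliding_window is not None and key_value_length >= sliding_window:
        return False
    if attention_mask.dim() == 4:
        # Custom 4D masks are kept; the diff now returns False here instead of shape-checking.
        return False
    if (is_training or not is_tracing) and bool(torch.all(attention_mask == 1)):
        # An all-ones 2D mask adds nothing on top of causal masking.
        return query_length == 1 or key_value_length == query_length
    return False

With `is_training=True` the tracing guard is bypassed, which is what the rewritten TODO comment refers to.
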
@@ -386,12 +383,18 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
             input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
         )
     else:
-        expanded_4d_mask = attn_mask_converter.to_4d(
-            attention_mask,
-            input_shape[-1],
-            dtype=inputs_embeds.dtype,
-            key_value_length=key_value_length,
-        )
+        if attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            expanded_4d_mask = attention_mask
+        else:
+            expanded_4d_mask = attn_mask_converter.to_4d(
+                attention_mask,
+                input_shape[-1],
+                dtype=inputs_embeds.dtype,
+                key_value_length=key_value_length,
+            )
 
         # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
         # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
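
The new 4D branch expects the mask already in additive ("inverted") form: 0 where attention is allowed and a large negative value where it is blocked, so that `attention_mask.max() == 0`. A hedged sketch of building such a mask (shapes and the causal pattern are illustrative, not prescribed by the diff):

import torch

batch_size, num_heads, query_length, key_value_length = 1, 1, 4, 4
dtype = torch.float32

# Boolean pattern of allowed positions (here: plain causal, purely for illustration).
allowed = torch.tril(torch.ones(query_length, key_value_length, dtype=torch.bool))

# Additive form: 0.0 where allowed, dtype minimum where blocked.
custom_4d_mask = torch.zeros(query_length, key_value_length, dtype=dtype)
custom_4d_mask = custom_4d_mask.masked_fill(~allowed, torch.finfo(dtype).min)
custom_4d_mask = custom_4d_mask[None, None, :, :].expand(batch_size, num_heads, -1, -1)

assert custom_4d_mask.max() == 0  # passes the check added in the hunk above

Such a mask now flows through the first branch unchanged, while the pre-existing 2D path via `attn_mask_converter.to_4d(...)` is untouched.
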
@@ -445,10 +448,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )
 
-    if torch.all(mask == 1):
-        if is_tracing:
-            pass
-        elif tgt_len == 1:
+    if not is_tracing and torch.all(mask == 1):
+        if tgt_len == 1:
             # For query_length == 1, causal attention and bi-directional attention are the same.
             return None
         elif key_value_length == tgt_len:
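
This last hunk only folds the tracing check into the outer condition. A sketch of the resulting short-circuit, with an illustrative helper name rather than the library function; the `elif key_value_length == tgt_len` case is visible only as trailing context, so it is left out here:

from typing import Optional

import torch


def maybe_drop_all_ones_mask(
    mask: torch.Tensor, tgt_len: int, is_tracing: bool = False
) -> Optional[torch.Tensor]:
    # Outside of tracing, an all-ones 2D padding mask carries no information.
    if not is_tracing and bool(torch.all(mask == 1)):
        if tgt_len == 1:
            # Single-token queries: causal and bi-directional attention coincide.
            return None
    # Other cases are handled by the surrounding function (not shown in the hunk).
    return mask
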
 