enable activation checkpointing

#10
by smangrul - opened
Files changed (1)
  1. modeling_phi.py +1 -2
modeling_phi.py CHANGED
@@ -525,7 +525,6 @@ class MHA(nn.Module):
         softmax_scale: Optional[float] = None,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
-        checkpointing: bool = False,
     ) -> None:
         super().__init__()
 
@@ -585,7 +584,7 @@
         self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
         self.layer_idx = layer_idx
         self.return_residual = return_residual
-        self.checkpointing = checkpointing
+        self.checkpointing = getattr(config, "checkpointing", False)
 
     def _forward_self_attn(
         self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
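
With this change, the checkpointing flag is read from the model config via getattr (defaulting to False when the attribute is absent) instead of being passed as an MHA constructor argument, so activation checkpointing can be toggled without changing how the attention modules are constructed. A minimal usage sketch, assuming the checkpoint is loaded through transformers with trust_remote_code=True so this custom modeling_phi.py is used; the repo id below is illustrative, not taken from this PR:

from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative repo id; any checkpoint that ships this modeling_phi.py behaves the same way.
model_id = "microsoft/phi-1_5"

# Set the flag on the config; MHA picks it up via getattr(config, "checkpointing", False),
# so omitting it keeps the previous behavior (no activation checkpointing).
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config.checkpointing = True

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
)

Because getattr falls back to False, existing configs without a checkpointing field load unchanged, which is why the constructor argument could be dropped without breaking callers.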