chandler88 commited on
Commit
cc9262b
1 Parent(s): 99a1409

del gradient_checkpointing_enable()

Browse files

del gradient_checkpointing_enable()
because it did nothing and when training with transformers, it did **NOT** enable gradient_checkpointing.

Files changed (1) hide show
  1. modeling_chatglm.py +0 -4
modeling_chatglm.py CHANGED
@@ -797,10 +797,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
800
- def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
801
- if not self.supports_gradient_checkpointing:
802
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
803
-
804
 
805
  class Embedding(torch.nn.Module):
806
  """Language model embeddings."""
 
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
 
 
 
 
800
 
801
  class Embedding(torch.nn.Module):
802
  """Language model embeddings."""