chandler88
commited on
Commit
•
cc9262b
1
Parent(s):
99a1409
del gradient_checkpointing_enable()
Browse filesdel gradient_checkpointing_enable()
because it did nothing and when training with transformers, it did **NOT** enable gradient_checkpointing.
- modeling_chatglm.py +0 -4
modeling_chatglm.py
CHANGED
@@ -797,10 +797,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
798 |
return position_ids
|
799 |
|
800 |
-
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
801 |
-
if not self.supports_gradient_checkpointing:
|
802 |
-
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
803 |
-
|
804 |
|
805 |
class Embedding(torch.nn.Module):
|
806 |
"""Language model embeddings."""
|
|
|
797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
798 |
return position_ids
|
799 |
|
|
|
|
|
|
|
|
|
800 |
|
801 |
class Embedding(torch.nn.Module):
|
802 |
"""Language model embeddings."""
|