Jackmin801
commited on
Commit
•
f4624e0
1
Parent(s):
43f3955
feat: no splt modules for device auto map
Browse files- modeling_bert.py +1 -0
modeling_bert.py
CHANGED
@@ -956,6 +956,7 @@ class JinaBertPreTrainedModel(PreTrainedModel):
|
|
956 |
load_tf_weights = load_tf_weights_in_bert
|
957 |
base_model_prefix = "bert"
|
958 |
supports_gradient_checkpointing = True
|
|
|
959 |
|
960 |
def _init_weights(self, module):
|
961 |
"""Initialize the weights"""
|
|
|
956 |
load_tf_weights = load_tf_weights_in_bert
|
957 |
base_model_prefix = "bert"
|
958 |
supports_gradient_checkpointing = True
|
959 |
+
_no_split_modules = ["JinaBertLayer"]
|
960 |
|
961 |
def _init_weights(self, module):
|
962 |
"""Initialize the weights"""
|