Update modeling_hf_nomic_bert.py
modeling_hf_nomic_bert.py CHANGED (+15 -21)
@@ -105,13 +105,7 @@ def filter_shapes(state_dict, model):
     return filtered_state_dict
 
 
-def remap_bert_state_dict(
-    state_dict,
-    config,
-    remove_bert=False,
-    remove_cls_weights=False,
-    add_pooling_layer=False,
-):
+def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
     """
     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
     """
@@ -311,12 +305,13 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if config is None:
             config = cls.config_class.from_pretrained(model_name)
         remove_cls = cls != NomicBertForPreTraining
-        remove_bert_prefix = cls != NomicBertForPreTraining
+        remove_bert_prefix = cls != NomicBertForPreTraining
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
-
-
+        if rotary_scaling_factor:
+            config.rotary_scaling_factor = rotary_scaling_factor
+
         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
             config.n_positions = 2048
         if num_labels:
@@ -325,7 +320,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if "add_pooling_layer" in kwargs:
             model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
         else:
-            model = cls(config, *inputs)
+            if cls == NomicBertModel:
+                model = cls(config, *inputs, add_pooling_layer=False)
+            else:
+                model = cls(config, *inputs)
             # TODO: fix this
             # Assuming we know what we're doing when loading from disk
             # Prob a bad assumption but i'm tired and want to train this asap
@@ -344,7 +342,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
             load_return = model.load_state_dict(state_dict, strict=False)
         else:
             # TODO: can probably check config class and see if we need to remap from a bert model
-            state_dict = state_dict_from_pretrained(model_name)
+            state_dict = state_dict_from_pretrained(model_name, safe_serialization=kwargs.get("safe_serialization", False))
             state_dict = remap_bert_state_dict(
                 state_dict,
                 config,
@@ -355,7 +353,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
             if ignore_mismatched_shapes:
                 state_dict = filter_shapes(state_dict, model)
 
-            load_return = model.load_state_dict(state_dict, strict=False)
+            load_return = model.load_state_dict(state_dict, strict=True)
             logger.warning(load_return)
         return model
 
@@ -726,7 +724,7 @@ class NomicBertAttention(nn.Module):
 
         self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
         if self.rotary_emb_dim > 0:
-            if
+            if config.rotary_scaling_factor:
                 self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
                     dim=self.rotary_emb_dim,
                     base=config.rotary_emb_base,
@@ -1057,11 +1055,10 @@ class NomicBertModel(NomicBertPreTrainedModel):
     def forward(
         self,
         input_ids,
-        position_ids=None,
-        token_type_ids=None,
         attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
         return_dict=None,
-        matryoshka_dim=None,
     ):
         if token_type_ids is None:
             token_type_ids = torch.zeros_like(input_ids)
@@ -1074,9 +1071,6 @@ class NomicBertModel(NomicBertPreTrainedModel):
 
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
-        if matryoshka_dim:
-            sequence_output = sequence_output[:, :matryoshka_dim]
-
         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
@@ -1224,4 +1218,4 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
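
For reference, a minimal loading sketch under the behavior introduced by this commit. It assumes the classes in modeling_hf_nomic_bert.py are importable (for example via trust_remote_code or by placing the file on the path); the checkpoint id and the scaling value are illustrative, not taken from this change.

from modeling_hf_nomic_bert import NomicBertModel  # hypothetical direct import of this file

# rotary_scaling_factor is popped from kwargs and written onto the config, which
# gates the NomicBertDynamicNTKRotaryEmbedding branch in NomicBertAttention.
model = NomicBertModel.from_pretrained(
    "nomic-ai/nomic-bert-2048",   # illustrative checkpoint id
    rotary_scaling_factor=2,      # illustrative value
    safe_serialization=False,     # read via kwargs.get when remapping a BERT-format checkpoint
)

# With no explicit add_pooling_layer kwarg, NomicBertModel is now constructed with
# add_pooling_layer=False, and the remapped state_dict is loaded with strict=True,
# so missing or unexpected keys raise instead of being silently skipped.

Note also that NomicBertModel.forward now orders its arguments as input_ids, attention_mask, token_type_ids, position_ids, return_dict and no longer accepts matryoshka_dim, so token_type_ids and position_ids are best passed by keyword, and any Matryoshka truncation of the outputs has to be done by the caller.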