Update modeling_hf_nomic_bert.py

modeling_hf_nomic_bert.py (CHANGED: +145, -149)
@@ -3,39 +3,34 @@
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
 # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
 
+import logging
+
 # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 import os
-import
+import re
+from collections import OrderedDict
 from functools import partial
-from typing import
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
+from safetensors.torch import load_file as safe_load_file
 from transformers import GPT2Config, PreTrainedModel
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     MaskedLMOutput,
-    SequenceClassifierOutput
-)
-
-import re
-from collections import OrderedDict
-from safetensors.torch import load_file as safe_load_file
-from transformers.utils import (
-    SAFE_WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    WEIGHTS_NAME,
+    SequenceClassifierOutput,
 )
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
-
 from .configuration_hf_nomic_bert import NomicBertConfig
 
 logger = logging.getLogger(__name__)
 
+
 # adapted from flash attention, added safe serialization option for hf models
 def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
     # If not fp32, then we don't want to load directly to the GPU
@@ -50,18 +45,12 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
         safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
 
         if os.path.isfile(weights_path):
-            resolved_archive_file = cached_file(
-                model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
-            )
+            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
         elif os.path.isfile(weights_index_path):
-            resolved_archive_file = cached_file(
-                model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
-            )
+            resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
             is_sharded = True
         elif os.path.isfile(safe_weights_path):
-            resolved_archive_file = cached_file(
-                model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
-            )
+            resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
             load_safe = True
         elif os.path.isfile(safe_weights_index_path):
             resolved_archive_file = cached_file(
@@ -74,8 +63,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
         resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
         if resolved_archive_file is None:
             weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
-            resolved_archive_file = cached_file(model_name, weight_index,
-                                                _raise_exceptions_for_missing_entries=False)
+            resolved_archive_file = cached_file(model_name, weight_index, _raise_exceptions_for_missing_entries=False)
             if resolved_archive_file is not None:
                 is_sharded = True
 
@@ -92,9 +80,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
     if is_sharded:
         # resolved_archive_file becomes a list of files that point to the different
         # checkpoint shards in this case.
-        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
-            model_name, resolved_archive_file
-        )
+        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
@@ -106,7 +92,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
 
-
+
 def filter_shapes(state_dict, model):
    """
    Filters the state dict to match the current model shape.
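Note on the loader hunks above: state_dict_from_pretrained still prefers a local weights file, falls back to a sharded index, and can read safetensors; the changes only collapse the multi-line cached_file(...) and get_checkpoint_shard_files(...) calls onto single lines. A minimal sketch of the call pattern, assuming this module is importable and using a placeholder checkpoint id (which file the repo actually ships decides the branch taken):

    # hypothetical usage sketch
    import torch

    state_dict = state_dict_from_pretrained(
        "nomic-ai/nomic-bert-2048",   # placeholder checkpoint id or local directory
        safe_serialization=False,     # set True to prefer *.safetensors over pytorch_model.bin
        device="cpu",
        dtype=torch.float32,
    )
    print(f"loaded {len(state_dict)} tensors")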
@@ -118,11 +104,18 @@ def filter_shapes(state_dict, model):
             filtered_state_dict[key] = value
     return filtered_state_dict
 
-
-def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
+
+def remap_bert_state_dict(
+    state_dict,
+    config,
+    remove_bert=False,
+    remove_cls_weights=False,
+    add_pooling_layer=False,
+):
     """
     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
     """
+
     def add_bert_prefix(key):
         # prepend bert. to the key
         if key.startswith("bert.") or key.startswith("cls."):
@@ -130,7 +123,7 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
         return f"bert.{key}"
 
     state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
-
+
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):
         key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -195,9 +188,7 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
         bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
         bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
         if not (last_layer_subset and d == config.num_hidden_layers - 1):
-            state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat(
-                [Wq, Wk, Wv], dim=0
-            )
+            state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
             state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
         else:
             state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
@@ -217,7 +208,6 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
     def key_mapping_decoder_bias(key):
         return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
 
-
     # remove nsp weights, we don't use
     state_dict.pop("cls.seq_relationship.weight", None)
     state_dict.pop("cls.seq_relationship.bias", None)
@@ -226,12 +216,14 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
     state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
 
     if remove_cls_weights:
-        cls_weights = [
+        cls_weights = [
+            "cls.predictions.decoder.bias",
+            "cls.predictions.transform.dense.weight",
+            "cls.predictions.transform.dense.bias",
+            "cls.predictions.transform.layer_norm.weight",
+            "cls.predictions.transform.layer_norm.bias",
+            "cls.predictions.decoder.weight",
+        ]
         for weight in cls_weights:
             state_dict.pop(weight, None)
 
@@ -257,20 +249,21 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
     )
 
     if add_pooling_layer is False:
-        pooler_weights = [
+        pooler_weights = [
+            "bert.pooler.dense.weight",
+            "bert.pooler.dense.bias",
+        ]
         for key in pooler_weights:
             state_dict.pop(key, None)
 
     if remove_bert:
+
         def remove_bert_prefix(key):
             key = re.sub(r"^bert.", "", key)
             return key
 
         state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
 
-
     return state_dict
 
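The Wqkv hunk above fuses each layer's query/key/value projections by concatenating them along the output dimension, which is what makes the checkpoint flash_attn-compatible. A small standalone sketch of that fusion (the hidden size of 8 is an arbitrary example, not a config value):

    # hypothetical sketch of the per-layer Wqkv fusion
    import torch

    hidden = 8
    Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
    bq, bk, bv = (torch.randn(hidden) for _ in range(3))
    Wqkv_weight = torch.cat([Wq, Wk, Wv], dim=0)  # shape (3 * hidden, hidden)
    Wqkv_bias = torch.cat([bq, bk, bv], dim=0)    # shape (3 * hidden,)
    assert Wqkv_weight.shape == (3 * hidden, hidden) and Wqkv_bias.shape == (3 * hidden,)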
@@ -278,6 +271,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for dowloading and loading pretrained models.
     """
+
     config_class = NomicBertConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
@@ -317,14 +311,13 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if config is None:
             config = cls.config_class.from_pretrained(model_name)
         remove_cls = cls != NomicBertForPreTraining
-        remove_bert_prefix = cls != NomicBertForPreTraining
+        remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
         if rotary_scaling_factor:
             config.rotary_scaling_factor = rotary_scaling_factor
-        else:
-            config.rotary_scaling_factor = None
+
         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
             config.n_positions = 2048
         if num_labels:
@@ -341,26 +334,34 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         # Assuming we know what we're doing when loading from disk
         # Prob a bad assumption but i'm tired and want to train this asap
         if os.path.exists(model_name):
-            state_dict = torch.load(f"{model_name}/pytorch_model.bin")
+            model_path = f"{model_name}/pytorch_model.bin"
+            if os.path.exists(model_path):
+                state_dict = torch.load(f"{model_name}/pytorch_model.bin")
+            else:
+                model_path = f"{model_name}/model.safetensors"
+                if not os.path.exists(model_path):
+                    raise ValueError(f"Model path {model_path} not found")
+                state_dict = safe_load_file(model_path)
+
             if ignore_mismatched_shapes:
                 state_dict = filter_shapes(state_dict, model)
             load_return = model.load_state_dict(state_dict, strict=False)
         else:
             # TODO: can probably check config class and see if we need to remap from a bert model
-            state_dict = state_dict_from_pretrained(
+            state_dict = state_dict_from_pretrained(
+                model_name, safe_serialization=kwargs.get("safe_serialization", False)
+            )
+            state_dict = remap_bert_state_dict(
+                state_dict,
+                config,
+                remove_bert=remove_bert_prefix,
+                remove_cls_weights=remove_cls,
+                add_pooling_layer=getattr(config, "add_pooling_layer", False),
+            )
             if ignore_mismatched_shapes:
                 state_dict = filter_shapes(state_dict, model)
 
-            load_return = model.load_state_dict(
-                state_dict,
-                strict=True
-            )
+            load_return = model.load_state_dict(state_dict, strict=True)
             logger.warning(load_return)
             return model
 
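In the rewritten loading branch above, a local directory is checked first for pytorch_model.bin, then for model.safetensors, and an error is raised only if neither exists; any other checkpoint still goes through state_dict_from_pretrained followed by remap_bert_state_dict. A hedged sketch of the keyword arguments this classmethod pops, based only on the lines in these hunks (the checkpoint id is a placeholder):

    # hypothetical usage sketch, assuming the classes in this file are importable
    model = NomicBertModel.from_pretrained(
        "nomic-ai/nomic-bert-2048",    # placeholder checkpoint id or local path
        safe_serialization=False,      # forwarded to state_dict_from_pretrained
        ignore_mismatched_sizes=True,  # filter_shapes() drops tensors whose shapes differ
        rotary_scaling_factor=2,       # stored on the config and passed to the rotary embedding
    )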
@@ -380,25 +381,21 @@ def _init_weights(module, initializer_range=0.02):
         if module.padding_idx is not None:
             nn.init.zeros_(module.weight[module.padding_idx])
 
-
+
 class NomicBertEmbeddings(nn.Module):
-    def __init__(
-        self,
-        config
-    ):
+    def __init__(self, config):
         """
         If max_position_embeddings <= 0, there's no position embeddings
         If type_vocab_size <= 0, there's no token type embeddings
         """
         super().__init__()
-        self.word_embeddings = nn.Embedding(
-            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
-        )
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
         self.type_vocab_size = config.type_vocab_size
         if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
             self.position_embeddings = nn.Embedding(
-                config.max_position_embeddings, config.hidden_size
+                config.max_position_embeddings,
+                config.hidden_size,
             )
         if self.type_vocab_size > 0:
             self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
@@ -425,6 +422,7 @@ class NomicBertEmbeddings(nn.Module):
             embeddings = embeddings + position_embeddings
         return embeddings
 
+
 class NomicBertMLP(nn.Module):
     def __init__(
         self,
@@ -442,11 +440,7 @@ class NomicBertMLP(nn.Module):
         hidden_features = hidden_features if hidden_features is not None else in_features * 4
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
-        approximate = (
-            "tanh"
-            if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
-            else "none"
-        )
+        approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
         self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
 
@@ -456,7 +450,7 @@ class NomicBertMLP(nn.Module):
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
 
-
+
 class NomciBertGatedMLP(nn.Module):
     def __init__(
         self,
@@ -474,9 +468,7 @@ class NomciBertGatedMLP(nn.Module):
     ):
         super().__init__()
         out_features = out_features if out_features is not None else in_features
-        hidden_features = (
-            hidden_features if hidden_features is not None else int(8 * in_features / 3)
-        )
+        hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         self.return_residual = return_residual
 
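The collapsed line above keeps the gated MLP sizing rule unchanged: default to int(8 * in_features / 3) hidden units, then round up to a multiple of multiple_of. A quick arithmetic sketch with illustrative numbers (not values taken from any shipped config):

    # hypothetical sizing sketch
    in_features, multiple_of = 512, 256
    hidden_features = int(8 * in_features / 3)  # 1365
    hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
    assert hidden_features == 1536  # rounded up to the next multiple of 256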
@@ -513,8 +505,8 @@ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
     ro_dim = cos.shape[-1] * 2
     assert ro_dim <= x.shape[-1]
     cos, sin = (
-        cos[offset: offset + x.shape[1]],
-        sin[offset: offset + x.shape[1]],
+        cos[offset : offset + x.shape[1]],
+        sin[offset : offset + x.shape[1]],
     )
     cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
     sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
@@ -571,10 +563,7 @@ class NomicBertRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
 
     def _compute_inv_freq(self, device=None):
-        return 1.0 / (
-            self.base
-            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
-        )
+        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
 
     def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
         # Reset the tables if the sequence length has changed,
@@ -646,14 +635,10 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
         self.rotary_scaling_factor = rotary_scaling_factor
         self.max_position_embeddings = max_position_embeddings
 
-
     def _compute_inv_freq(self, base=None, device=None):
         if base is None:
             base = self.base
-        return 1.0 / (
-            base
-            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
-        )
+        return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
 
     def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
         # Reset the tables if the sequence length has changed,
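Both collapsed returns above compute the usual rotary inverse frequencies, 1 / base ** (i / dim) over even channel indices i; the dynamic-NTK subclass only swaps in a rescaled base. A standalone sketch of the same formula (dim and base are example values):

    # hypothetical sketch of the inverse-frequency computation
    import torch

    dim, base = 64, 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    print(inv_freq.shape)  # torch.Size([32]): one frequency per pair of rotary channels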
@@ -704,8 +689,7 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
             self._sin_cached = torch.sin(freqs).to(dtype)
         else:
             power = (
-                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
-                - seqlen // 2
+                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
             ) / self.scale_base
             scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
             # We want the multiplication by scale to happen in fp32
@@ -714,6 +698,7 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
+
 class NomicBertAttention(nn.Module):
     """Multi-head self-attention and cross-attention"""
 
@@ -754,8 +739,8 @@ class NomicBertAttention(nn.Module):
                 scale_base=config.rotary_emb_scale_base,
                 interleaved=config.rotary_emb_interleaved,
                 rotary_scaling_factor=config.rotary_scaling_factor,
-                max_position_embeddings=config.
-            )
+                max_position_embeddings=config.max_trained_positions,
+            )
         else:
             self.rotary_emb = NomicBertRotaryEmbedding(
                 dim=self.rotary_emb_dim,
@@ -826,7 +811,7 @@ class NomicBertAttention(nn.Module):
         attn_output = self.out_proj(attn_output)
 
         return attn_output
-
+
 
 class NomicBertBlock(nn.Module):
     def __init__(
@@ -836,17 +821,31 @@ class NomicBertBlock(nn.Module):
         super().__init__()
         self.prenorm = config.prenorm
         self.fused_dropout_add_ln = config.fused_dropout_add_ln
-
-        self.attn = NomicBertAttention(config)
+
+        self.attn = NomicBertAttention(config)
         activation = (
+            F.sigmoid
+            if config.activation_function == "glu"
+            else (F.silu if config.activation_function == "swiglu" else F.gelu)
         )
         if config.activation_function in ["glu", "swiglu", "geglu"]:
-            self.mlp = NomciBertGatedMLP(
+            self.mlp = NomciBertGatedMLP(
+                config.n_embd,
+                hidden_features=config.n_inner,
+                bias1=config.mlp_fc1_bias,
+                bias2=config.mlp_fc2_bias,
+                activation=activation,
+                fused_bias_fc=config.fused_bias_fc,
+            )
         else:
-            self.mlp = NomicBertMLP(
+            self.mlp = NomicBertMLP(
+                config.n_embd,
+                hidden_features=config.n_inner,
+                bias1=config.mlp_fc1_bias,
+                bias2=config.mlp_fc2_bias,
+                activation=activation,
+                fused_bias_fc=config.fused_bias_fc,
+            )
 
         self.dropout1 = nn.Dropout(config.resid_pdrop)
         self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -880,7 +879,13 @@ class NomicBertBlock(nn.Module):
             dropped = self.dropout1(hidden_states)
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
-            hidden_states = self.attn(
+            hidden_states = self.attn(
+                hidden_states,
+                attention_mask=attention_mask,
+                is_padded_inputs=is_padded_inputs,
+                cu_seqlens=cu_seqlens,
+                max_seq_len=max_seq_len,
+            )
 
             dropped = self.dropout2(hidden_states)
             residual = (dropped + residual) if residual is not None else dropped
@@ -890,36 +895,29 @@ class NomicBertBlock(nn.Module):
             return hidden_states, None, residual
         else:
             assert residual is None
-            attn_outputs = self.attn(
-                (self.dropout1(attn_outputs) + hidden_states).to(
-                    dtype=self.norm1.weight.dtype
-                )
+            attn_outputs = self.attn(
+                hidden_states,
+                attention_mask=attention_mask,
+                is_padded_inputs=is_padded_inputs,
+                cu_seqlens=cu_seqlens,
+                max_seq_len=max_seq_len,
             )
+            hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
             mlp_out = self.mlp(hidden_states)
 
-            hidden_states = self.norm2(
-                (self.dropout2(mlp_out) + hidden_states).to(
-                    dtype=self.norm2.weight.dtype
-                )
-            )
+            hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
             return hidden_states, None, None
 
 
 class NomicBertEncoder(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.layers = nn.ModuleList(
-            [NomicBertBlock(config) for _ in range(config.n_layer)]
-        )
+        self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
         self.gradient_checkpointing = False
         self.config = config
 
-    def forward(
+    def forward(
+        self,
         hidden_states: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -929,8 +927,8 @@ class NomicBertEncoder(nn.Module):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        is_padded_inputs: Optional[bool] = True,
-    ):
+        is_padded_inputs: Optional[bool] = True,
+    ):
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
@@ -938,7 +936,6 @@ class NomicBertEncoder(nn.Module):
         hidden_states2 = None
         residual = None
 
-
         for _, layer in enumerate(self.layers):
             if self.gradient_checkpointing and self.training:
 
@@ -998,11 +995,7 @@ class NomicBertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
-        approximate = (
-            "tanh"
-            if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
-            else "none"
-        )
+        approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
         if config.activation_function == "swiglu":
             self.transform_act_fn = F.silu
         else:
@@ -1047,15 +1040,19 @@ class NomicBertModel(NomicBertPreTrainedModel):
         super().__init__(config)
         self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
         if config.vocab_size % self.pad_vocab_size_multiple != 0:
-            config.vocab_size += self.pad_vocab_size_multiple - (
+            config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
+
+        assert config.activation_function in [
+            "gelu",
+            "gelu_new",
+            "gelu_fast",
+            "gelu_pytorch_tanh",
+            "swiglu",
+            "geglu",
+            "glu",
+        ]
+
+        self.embeddings = NomicBertEmbeddings(config)
         self.emb_drop = nn.Dropout(config.resid_pdrop)
         self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.encoder = NomicBertEncoder(config)
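The one-liner above pads the vocabulary up to the next multiple of pad_vocab_size_multiple. A quick arithmetic check with illustrative numbers (not the values of any particular checkpoint):

    # hypothetical padding sketch
    vocab_size, pad_vocab_size_multiple = 30522, 64
    if vocab_size % pad_vocab_size_multiple != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    assert vocab_size == 30528  # 30522 rounded up to the next multiple of 64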
@@ -1069,22 +1066,23 @@ class NomicBertModel(NomicBertPreTrainedModel):
         position_ids=None,
         token_type_ids=None,
         attention_mask=None,
+        return_dict=None,
+        matryoshka_dim=None,
     ):
         if token_type_ids is None:
             token_type_ids = torch.zeros_like(input_ids)
-        hidden_states = self.embeddings(
-            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
-        )
+        hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
         hidden_states = self.emb_ln(hidden_states)
         hidden_states = self.emb_drop(hidden_states)
 
         attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
-        sequence_output = self.encoder(
-            hidden_states, attention_mask=attention_mask
-        )
+        sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
 
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
+        if matryoshka_dim:
+            sequence_output = sequence_output[:, :matryoshka_dim]
+
         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
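The forward hunk above threads two new keyword arguments through NomicBertModel.forward: return_dict is forwarded to the encoder, and matryoshka_dim truncates sequence_output exactly as shown before it is wrapped in the output object. A hedged usage sketch; the checkpoint and tokenizer ids are placeholders and the classes are assumed importable from this file:

    # hypothetical usage sketch
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")        # placeholder tokenizer
    model = NomicBertModel.from_pretrained("nomic-ai/nomic-bert-2048")    # placeholder checkpoint
    inputs = tokenizer("hello world", return_tensors="pt")
    with torch.no_grad():
        out = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            matryoshka_dim=256,  # applied to sequence_output as in the hunk above
        )
    print(out.last_hidden_state.shape)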
@@ -1151,10 +1149,10 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
             loss=total_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
-            attentions=None,
+            attentions=None,
         )
 
-
+
 class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1162,9 +1160,7 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
         self.config = config
 
         self.bert = NomicBertModel(config)
-        classifier_dropout = (
-            getattr(config, "classifier_dropout", config.embd_pdrop)
-        )
+        classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.n_embd, config.num_labels)