Add inputs_embeds argument
#7
by
jxm
- opened
- modeling_hf_nomic_bert.py +15 -3
modeling_hf_nomic_bert.py
CHANGED
@@ -977,14 +977,18 @@ class NomicBertEmbeddings(nn.Module):
|
|
977 |
if self.type_vocab_size > 0:
|
978 |
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
979 |
|
980 |
-
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
981 |
"""
|
982 |
input_ids: (batch, seqlen)
|
983 |
position_ids: (batch, seqlen)
|
984 |
token_type_ids: (batch, seqlen)
|
985 |
"""
|
986 |
batch_size, seqlen = input_ids.shape
|
987 |
-
|
|
|
|
|
|
|
|
|
988 |
|
989 |
if self.type_vocab_size > 0:
|
990 |
if token_type_ids is None:
|
@@ -1680,10 +1684,18 @@ class NomicBertModel(NomicBertPreTrainedModel):
|
|
1680 |
token_type_ids=None,
|
1681 |
return_dict=None,
|
1682 |
matryoshka_dim=None,
|
|
|
1683 |
):
|
|
|
|
|
1684 |
if token_type_ids is None:
|
1685 |
token_type_ids = torch.zeros_like(input_ids)
|
1686 |
-
hidden_states = self.embeddings(
|
|
|
|
|
|
|
|
|
|
|
1687 |
hidden_states = self.emb_ln(hidden_states)
|
1688 |
hidden_states = self.emb_drop(hidden_states)
|
1689 |
|
|
|
977 |
if self.type_vocab_size > 0:
|
978 |
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
979 |
|
980 |
+
def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
|
981 |
"""
|
982 |
input_ids: (batch, seqlen)
|
983 |
position_ids: (batch, seqlen)
|
984 |
token_type_ids: (batch, seqlen)
|
985 |
"""
|
986 |
batch_size, seqlen = input_ids.shape
|
987 |
+
|
988 |
+
if inputs_embeds is None:
|
989 |
+
embeddings = self.word_embeddings(input_ids)
|
990 |
+
else:
|
991 |
+
embeddings = inputs_embeds
|
992 |
|
993 |
if self.type_vocab_size > 0:
|
994 |
if token_type_ids is None:
|
|
|
1684 |
token_type_ids=None,
|
1685 |
return_dict=None,
|
1686 |
matryoshka_dim=None,
|
1687 |
+
inputs_embeds=None,
|
1688 |
):
|
1689 |
+
if input_ids is not None and inputs_embeds is not None:
|
1690 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1691 |
if token_type_ids is None:
|
1692 |
token_type_ids = torch.zeros_like(input_ids)
|
1693 |
+
hidden_states = self.embeddings(
|
1694 |
+
input_ids=input_ids,
|
1695 |
+
position_ids=position_ids,
|
1696 |
+
token_type_ids=token_type_ids,
|
1697 |
+
inputs_embeds=inputs_embeds,
|
1698 |
+
)
|
1699 |
hidden_states = self.emb_ln(hidden_states)
|
1700 |
hidden_states = self.emb_drop(hidden_states)
|
1701 |
|