Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code

Add inputs_embeds argument

#7
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +15 -3
modeling_hf_nomic_bert.py CHANGED
@@ -977,14 +977,18 @@ class NomicBertEmbeddings(nn.Module):
977
  if self.type_vocab_size > 0:
978
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
979
 
980
- def forward(self, input_ids, position_ids=None, token_type_ids=None):
981
  """
982
  input_ids: (batch, seqlen)
983
  position_ids: (batch, seqlen)
984
  token_type_ids: (batch, seqlen)
985
  """
986
  batch_size, seqlen = input_ids.shape
987
- embeddings = self.word_embeddings(input_ids)
 
 
 
 
988
 
989
  if self.type_vocab_size > 0:
990
  if token_type_ids is None:
@@ -1680,10 +1684,18 @@ class NomicBertModel(NomicBertPreTrainedModel):
1680
  token_type_ids=None,
1681
  return_dict=None,
1682
  matryoshka_dim=None,
 
1683
  ):
 
 
1684
  if token_type_ids is None:
1685
  token_type_ids = torch.zeros_like(input_ids)
1686
- hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
 
 
 
 
 
1687
  hidden_states = self.emb_ln(hidden_states)
1688
  hidden_states = self.emb_drop(hidden_states)
1689
 
 
977
  if self.type_vocab_size > 0:
978
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
979
 
980
+ def forward(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None):
981
  """
982
  input_ids: (batch, seqlen)
983
  position_ids: (batch, seqlen)
984
  token_type_ids: (batch, seqlen)
985
  """
986
  batch_size, seqlen = input_ids.shape
987
+
988
+ if inputs_embeds is None:
989
+ embeddings = self.word_embeddings(input_ids)
990
+ else:
991
+ embeddings = inputs_embeds
992
 
993
  if self.type_vocab_size > 0:
994
  if token_type_ids is None:
 
1684
  token_type_ids=None,
1685
  return_dict=None,
1686
  matryoshka_dim=None,
1687
+ inputs_embeds=None,
1688
  ):
1689
+ if input_ids is not None and inputs_embeds is not None:
1690
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1691
  if token_type_ids is None:
1692
  token_type_ids = torch.zeros_like(input_ids)
1693
+ hidden_states = self.embeddings(
1694
+ input_ids=input_ids,
1695
+ position_ids=position_ids,
1696
+ token_type_ids=token_type_ids,
1697
+ inputs_embeds=inputs_embeds,
1698
+ )
1699
  hidden_states = self.emb_ln(hidden_states)
1700
  hidden_states = self.emb_drop(hidden_states)
1701