liminghong committed
Commit bfad706 · Parent: b498c31

Bert model allows inputs_embeds input

Files changed (1)
  1. bert_layers.py +19 -5
bert_layers.py CHANGED
@@ -579,21 +579,35 @@ class BertModel(BertPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_all_encoded_layers: Optional[bool] = False,
         masked_tokens_mask: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, device=device)
 
-        embedding_output = self.embeddings(input_ids, token_type_ids,
-                                           position_ids)
+        embedding_output = self.embeddings(
+            input_ids,
+            token_type_ids,
+            position_ids
+        ) if inputs_embeds is None else inputs_embeds
 
         subset_mask = []
         first_col_mask = []
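
For reference, a minimal usage sketch of the patched forward. The `model` object, vocabulary size, and tensor shapes below are illustrative assumptions, not part of the commit; `model` stands in for an already-constructed BertModel from this file.

import torch

# Hypothetical example: exercise both input paths of the patched forward
# (`model` and the shapes here are placeholders, not from the diff).
input_ids = torch.randint(0, 30000, (2, 16))   # (batch, seq_len)

# Path 1: token ids, same behavior as before the patch.
encoder_output, pooled = model(input_ids=input_ids)

# Path 2: precomputed embeddings of shape (batch, seq_len, hidden_size).
# The patch hands `inputs_embeds` straight to the encoder, bypassing
# `self.embeddings`, so any token-type/position information must already
# be baked into the tensor.
inputs_embeds = torch.randn(2, 16, 768)
encoder_output, pooled = model(inputs_embeds=inputs_embeds)

# Per the new guard clauses, passing both input_ids and inputs_embeds
# (or neither) raises a ValueError.

One design choice worth noting in the last hunk: when `inputs_embeds` is supplied, it is used as the embedding output as-is and `self.embeddings` is skipped entirely, whereas upstream Hugging Face BERT routes `inputs_embeds` through the embeddings module so that position and token-type embeddings are still added.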