Commit bfad706 by liminghong
Parent: b498c31

Bert model allows inputs_embeds input

bert_layers.py CHANGED (+19 -5)
@@ -579,21 +579,35 @@ class BertModel(BertPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_all_encoded_layers: Optional[bool] = False,
         masked_tokens_mask: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
+            token_type_ids = torch.zeros(input_shape, device=device)
 
-        embedding_output = self.embeddings(input_ids, token_type_ids,
-                                           position_ids)
+        embedding_output = self.embeddings(
+            input_ids,
+            token_type_ids,
+            position_ids
+        ) if inputs_embeds is None else inputs_embeds
 
         subset_mask = []
         first_col_mask = []
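With this patch, forward accepts either input_ids or a precomputed inputs_embeds tensor and infers the batch shape and device from whichever is given. One caveat worth noting: unlike upstream Hugging Face BERT, where inputs_embeds replaces only the word-embedding lookup, this version uses inputs_embeds directly as the output of self.embeddings, so token-type and position information must already be baked into the tensor. A minimal usage sketch under that assumption follows; the checkpoint name is a placeholder, and trust_remote_code=True is assumed so that this repo's bert_layers.py is the class actually loaded.

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    # Placeholder checkpoint name; any checkpoint served by this repo's
    # bert_layers.py behaves the same way.
    model = AutoModel.from_pretrained("some-org/some-mosaicbert-checkpoint",
                                      trust_remote_code=True)

    enc = tokenizer("inputs_embeds now works", return_tensors="pt")

    # Old path: token ids, unchanged behaviour.
    out_ids = model(input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"])

    # New path: a (batch, seq_len, hidden_size) tensor instead of ids.
    # Since the patched forward skips self.embeddings entirely when
    # inputs_embeds is given, we reproduce its call here (token_type_ids
    # as integer zeros, position_ids as None) to get a tensor with word,
    # token-type, and position embeddings already applied.
    with torch.no_grad():
        embeds = model.embeddings(enc["input_ids"],
                                  torch.zeros_like(enc["input_ids"]),
                                  None)
    out_embeds = model(inputs_embeds=embeds,
                       attention_mask=enc["attention_mask"])

Passing both input_ids and inputs_embeds raises a ValueError, as does passing neither. One nit in the patch itself: torch.zeros(input_shape, device=device) defaults to float32, while the token-type embedding lookup needs integer indices; upstream BERT uses dtype=torch.long here, which is why the sketch above builds token_type_ids with zeros_like on the long input_ids.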