liminghong
commited on
Commit
•
f784b6e
1
Parent(s):
bfad706
Fix Bugs
Browse files- bert_layers.py +2 -2
bert_layers.py
CHANGED
@@ -599,9 +599,9 @@ class BertModel(BertPreTrainedModel):
|
|
599 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
600 |
|
601 |
if attention_mask is None:
|
602 |
-
attention_mask = torch.ones(input_shape, device=device)
|
603 |
if token_type_ids is None:
|
604 |
-
token_type_ids = torch.zeros(input_shape, device=device)
|
605 |
|
606 |
embedding_output = self.embeddings(
|
607 |
input_ids,
|
|
|
599 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
600 |
|
601 |
if attention_mask is None:
|
602 |
+
attention_mask = torch.ones(input_shape, dtype=torch.long, device=device)
|
603 |
if token_type_ids is None:
|
604 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
605 |
|
606 |
embedding_output = self.embeddings(
|
607 |
input_ids,
|