liminghong commited on
Commit
f784b6e
1 Parent(s): bfad706
Files changed (1) hide show
  1. bert_layers.py +2 -2
bert_layers.py CHANGED
@@ -599,9 +599,9 @@ class BertModel(BertPreTrainedModel):
599
  device = input_ids.device if input_ids is not None else inputs_embeds.device
600
 
601
  if attention_mask is None:
602
- attention_mask = torch.ones(input_shape, device=device)
603
  if token_type_ids is None:
604
- token_type_ids = torch.zeros(input_shape, device=device)
605
 
606
  embedding_output = self.embeddings(
607
  input_ids,
 
599
  device = input_ids.device if input_ids is not None else inputs_embeds.device
600
 
601
  if attention_mask is None:
602
+ attention_mask = torch.ones(input_shape, dtype=torch.long, device=device)
603
  if token_type_ids is None:
604
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
605
 
606
  embedding_output = self.embeddings(
607
  input_ids,