oweller2 committed on
Commit
6f2cf23
1 Parent(s): 3556a25
Files changed (1) hide show
  1. modeling_flexbert.py +1 -1
modeling_flexbert.py CHANGED
@@ -1644,7 +1644,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1644
  if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1645
  batch_size, seq_len = input_ids.shape[:2]
1646
  if attention_mask is None: # Create causal mask (lower triangular)
1647
- attention_mask = torch.tril(torch.ones(batch_size, seq_len), diagonal=0)
1648
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1649
  input_ids, attention_mask, position_ids, labels
1650
  )
 
1644
  if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1645
  batch_size, seq_len = input_ids.shape[:2]
1646
  if attention_mask is None: # Create causal mask (lower triangular)
1647
+ attention_mask = torch.tril(torch.ones(batch_size, seq_len, device=input_ids.device), diagonal=0)
1648
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1649
  input_ids, attention_mask, position_ids, labels
1650
  )