oweller2
commited on
Commit
•
6f2cf23
1
Parent(s):
3556a25
fix
Browse files- modeling_flexbert.py +1 -1
modeling_flexbert.py
CHANGED
@@ -1644,7 +1644,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1644 |
if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
|
1645 |
batch_size, seq_len = input_ids.shape[:2]
|
1646 |
if attention_mask is None: # Create causal mask (lower triangular)
|
1647 |
-
attention_mask = torch.tril(torch.ones(batch_size, seq_len), diagonal=0)
|
1648 |
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
|
1649 |
input_ids, attention_mask, position_ids, labels
|
1650 |
)
|
|
|
1644 |
if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
|
1645 |
batch_size, seq_len = input_ids.shape[:2]
|
1646 |
if attention_mask is None: # Create causal mask (lower triangular)
|
1647 |
+
attention_mask = torch.tril(torch.ones(batch_size, seq_len, device=input_ids.device), diagonal=0)
|
1648 |
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
|
1649 |
input_ids, attention_mask, position_ids, labels
|
1650 |
)
|