oweller2
commited on
Commit
•
1f59624
1
Parent(s):
855df5e
update model
Browse files- modeling_flexbert.py +4 -2
- pytorch_model.bin +1 -1
modeling_flexbert.py
CHANGED
@@ -1507,7 +1507,8 @@ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
|
|
1507 |
|
1508 |
|
1509 |
class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
1510 |
-
config_class = FlexBertConfig
|
|
|
1511 |
"""Bert Model transformer with a LM head.
|
1512 |
|
1513 |
This head is just a standard LM head module. Used for causal language modeling tasks.
|
@@ -1701,8 +1702,8 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
|
1701 |
shift_labels.view(-1)
|
1702 |
)
|
1703 |
|
1704 |
-
assert False
|
1705 |
if self.pad_logits:
|
|
|
1706 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1707 |
if len(new_logits.shape) == 2:
|
1708 |
new_logits = new_logits.unsqueeze(0)
|
@@ -1713,6 +1714,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
|
1713 |
attentions=None,
|
1714 |
)
|
1715 |
else:
|
|
|
1716 |
if len(logits.shape) == 2:
|
1717 |
logits = logits.unsqueeze(0)
|
1718 |
return CausalLMOutput(
|
|
|
1507 |
|
1508 |
|
1509 |
class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
1510 |
+
config_class = FlexBertConfig
|
1511 |
+
|
1512 |
"""Bert Model transformer with a LM head.
|
1513 |
|
1514 |
This head is just a standard LM head module. Used for causal language modeling tasks.
|
|
|
1702 |
shift_labels.view(-1)
|
1703 |
)
|
1704 |
|
|
|
1705 |
if self.pad_logits:
|
1706 |
+
print(f"Padding logits: {logits.shape}")
|
1707 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1708 |
if len(new_logits.shape) == 2:
|
1709 |
new_logits = new_logits.unsqueeze(0)
|
|
|
1714 |
attentions=None,
|
1715 |
)
|
1716 |
else:
|
1717 |
+
print(f"Non-padding logits: {logits.shape}")
|
1718 |
if len(logits.shape) == 2:
|
1719 |
logits = logits.unsqueeze(0)
|
1720 |
return CausalLMOutput(
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 598685038
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:67a284058719cc19a4ec63c15f1f3cd8b57e5d4e59494b1704b7ae33ab717fb9
|
3 |
size 598685038
|