oweller2 commited on
Commit
1f59624
1 Parent(s): 855df5e

update model

Browse files
Files changed (2) hide show
  1. modeling_flexbert.py +4 -2
  2. pytorch_model.bin +1 -1
modeling_flexbert.py CHANGED
@@ -1507,7 +1507,8 @@ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
1507
 
1508
 
1509
  class FlexBertForCasualLM(FlexBertPreTrainedModel):
1510
- config_class = FlexBertConfig # Add this line
 
1511
  """Bert Model transformer with a LM head.
1512
 
1513
  This head is just a standard LM head module. Used for causal language modeling tasks.
@@ -1701,8 +1702,8 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1701
  shift_labels.view(-1)
1702
  )
1703
 
1704
- assert False
1705
  if self.pad_logits:
 
1706
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1707
  if len(new_logits.shape) == 2:
1708
  new_logits = new_logits.unsqueeze(0)
@@ -1713,6 +1714,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1713
  attentions=None,
1714
  )
1715
  else:
 
1716
  if len(logits.shape) == 2:
1717
  logits = logits.unsqueeze(0)
1718
  return CausalLMOutput(
 
1507
 
1508
 
1509
  class FlexBertForCasualLM(FlexBertPreTrainedModel):
1510
+ config_class = FlexBertConfig
1511
+
1512
  """Bert Model transformer with a LM head.
1513
 
1514
  This head is just a standard LM head module. Used for causal language modeling tasks.
 
1702
  shift_labels.view(-1)
1703
  )
1704
 
 
1705
  if self.pad_logits:
1706
+ print(f"Padding logits: {logits.shape}")
1707
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1708
  if len(new_logits.shape) == 2:
1709
  new_logits = new_logits.unsqueeze(0)
 
1714
  attentions=None,
1715
  )
1716
  else:
1717
+ print(f"Non-padding logits: {logits.shape}")
1718
  if len(logits.shape) == 2:
1719
  logits = logits.unsqueeze(0)
1720
  return CausalLMOutput(
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:96348b6ba9cd41656884fe36b14263dbb74f76da5632fa0f4b92c0781353f2b0
3
  size 598685038
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67a284058719cc19a4ec63c15f1f3cd8b57e5d4e59494b1704b7ae33ab717fb9
3
  size 598685038