update for transformers >= 4.29.1
modeling_lsg_bart.py  CHANGED  (+23, -3)
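The commit aligns the custom LSG-BART classes with the BART implementation shipped in transformers 4.29.1: LSGBartPretrainedModel gains the newer class-level attributes, the encoder's LayerDrop draw moves to torch.rand([]), and the task heads declare their shared parameters through _tied_weights_keys.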
@@ -643,6 +643,11 @@ class LSGBartEncoderLayer(BartEncoderLayer):
 class LSGBartPretrainedModel(BartPretrainedModel):
 
     config_class = LSGBartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _set_gradient_checkpointing(self, module, value=False):
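The added attributes mirror what BartPretrainedModel exposes in transformers 4.29: supports_gradient_checkpointing enables gradient_checkpointing_enable(), _no_split_modules keeps whole encoder/decoder layers on one device when weights are dispatched with device_map="auto", and _skip_keys_device_placement should exempt past_key_values from device placement during generation. A minimal loading sketch, with a placeholder checkpoint id that is not part of this commit:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "your-namespace/lsg-bart-checkpoint",   # placeholder id, not from this commit
    trust_remote_code=True,                 # picks up the custom modeling_lsg_bart.py
    device_map="auto",                      # honours _no_split_modules when sharding layers
)
model.gradient_checkpointing_enable()       # exposed because supports_gradient_checkpointing = True
```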
@@ -836,8 +841,13 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
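The LayerDrop branch now matches the upstream BartEncoder in 4.29: the random draw happens through torch.rand([]) (a scalar tensor on torch's RNG) rather than Python's random module, and the skip decision is factored into a to_drop flag. A standalone sketch of the new logic, with layerdrop = 0.1 assumed purely for illustration:

```python
import torch

def should_drop_layer(training: bool, layerdrop: float = 0.1) -> bool:
    """Replicates the updated LayerDrop check; `layerdrop` is an assumed value."""
    to_drop = False
    if training:
        dropout_probability = torch.rand([])   # 0-dim tensor drawn from torch's RNG
        if dropout_probability < layerdrop:    # skip the layer
            to_drop = True
    return to_drop

# Layers are only ever skipped while training, each with probability `layerdrop`.
print(should_drop_layer(training=True), should_drop_layer(training=False))
```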
@@ -879,6 +889,8 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
 
 class LSGBartModel(LSGBartPretrainedModel, BartModel):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config):
 
         LSGBartPretrainedModel.__init__(self, config)
@@ -984,7 +996,8 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
 class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
 
     base_model_prefix = "model"
-
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
     def __init__(self, config):
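transformers 4.29 introduced _tied_weights_keys as the declaration of shared parameters (previously folded into _keys_to_ignore_on_load_missing), which is why LSGBartModel and each head class below now list the embedding and lm_head weights explicitly. A quick check that the tie still holds after loading, with a placeholder checkpoint id:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "your-namespace/lsg-bart-checkpoint",   # placeholder id, not from this commit
    trust_remote_code=True,
)

# The parameters named in _tied_weights_keys should share storage with the
# input embeddings; data_ptr() equality is a cheap way to confirm the tie.
shared = model.get_input_embeddings().weight
lm_head = model.get_output_embeddings().weight
print(shared.data_ptr() == lm_head.data_ptr())  # expected: True
```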
@@ -999,6 +1012,8 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
 
 class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: LSGBartConfig, **kwargs):
 
         LSGBartPretrainedModel.__init__(self, config, **kwargs)
@@ -1015,6 +1030,8 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
 
 class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: LSGBartConfig):
 
         LSGBartPretrainedModel.__init__(self, config)
@@ -1030,6 +1047,9 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
 
 class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
 
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
+
     def __init__(self, config: LSGBartConfig):
 
         LSGBartPretrainedModel.__init__(self, config)
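Since the whole commit targets the 4.29.1 API, a minimal end-to-end sanity check under that assumption might look like the following; the checkpoint id and input text are placeholders, not taken from this repository:

```python
from packaging import version
import transformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# The commit targets transformers >= 4.29.1; fail fast on older installs.
assert version.parse(transformers.__version__) >= version.parse("4.29.1")

name = "your-namespace/lsg-bart-checkpoint"   # placeholder id
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(name, trust_remote_code=True)

inputs = tokenizer("A long document to summarize ...", return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```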