ccdv committed on
Commit
c870da8
1 Parent(s): 6e5504c

fix for transformers >= 4.35.2

Files changed (2)
  1. README.md +2 -2
  2. modeling_lsg_bart.py +3 -3
README.md CHANGED

@@ -18,7 +18,7 @@ model-index:
 <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 should probably proofread and complete it, then remove this comment. -->
 
-**Transformers >= 4.23.1**\
+**Transformers >= 4.35.2**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
@@ -105,7 +105,7 @@ The following hyperparameters were used during generation:
 
 ### Framework versions
 
-- Transformers 4.23.1
+- Transformers 4.35.2
 - Pytorch 1.12.1
 - Datasets 2.3.2
 - Tokenizers 0.11.6
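
The README keeps the note that the custom modeling file is only picked up when `trust_remote_code=True` is passed. A minimal loading sketch under the bumped requirement (the model id below is a placeholder for this repository's name, not a value taken from the commit):

```python
# Minimal usage sketch for transformers >= 4.35.2.
# "ccdv/lsg-bart-base-4096" is a placeholder id; substitute this repository's name.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id = "ccdv/lsg-bart-base-4096"

# trust_remote_code=True is required so the custom modeling_lsg_bart.py is loaded.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)

text = "Long document to summarize goes here."
inputs = tokenizer(text, return_tensors="pt")
summary_ids = model.generate(**inputs, max_length=64)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```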
modeling_lsg_bart.py CHANGED

@@ -1,7 +1,7 @@
 from logging import warn
 import torch
 from transformers.models.bart.modeling_bart import *
-from transformers.models.bart.modeling_bart import _expand_mask
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 import torch.nn as nn
 import sys
 
@@ -852,7 +852,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1093,4 +1093,4 @@ try:
         str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.23.1 to fix.")
+    warn("Update to transformers >= 4.35.2 to fix.")
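
The code change is essentially a rename plus relocation of the same helper: transformers >= 4.35 moved the 2D-to-4D attention-mask expansion out of `modeling_bart` into `modeling_attn_mask_utils`, which is why the old `_expand_mask` import breaks. A hedged compatibility sketch (not part of the committed file) showing that both helpers take the same arguments and produce the same additive 4D mask:

```python
# Sketch of a version-compatibility shim; the committed file simply switches to the new import.
try:
    # transformers >= 4.35.2
    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask as expand_mask
except ImportError:
    # older transformers, e.g. 4.23.x
    from transformers.models.bart.modeling_bart import _expand_mask as expand_mask

import torch

# Both helpers take a [bsz, seq_len] padding mask (1 = keep, 0 = pad) and return a
# [bsz, 1, tgt_seq_len, src_seq_len] additive mask (0.0 where kept, dtype-min where masked).
attention_mask = torch.tensor([[1, 1, 1, 0]])
expanded = expand_mask(attention_mask, torch.float32)
print(expanded.shape)  # torch.Size([1, 1, 4, 4])
```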