ammarnasr committed on
Commit
41aaa4e
1 Parent(s): 29791f4

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +35 -0
  3. configuration_t5mimo.py +152 -0
  4. model.safetensors +3 -0
  5. modeling_t5mimo.py +1751 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
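The quick-start code is still marked as missing above; as a stopgap, here is a minimal, hedged sketch. It assumes the custom classes uploaded in this commit (`T5MIMOConfig` / `T5MIMOModel`, wired up through the `auto_map` entries in `config.json`) and uses a placeholder repository id.

```python
from transformers import AutoConfig, AutoModel

repo_id = "<user>/<this-repo>"  # placeholder: substitute the actual Hub repo id

# trust_remote_code=True is required because config.json's auto_map points to the
# custom T5MIMOConfig / T5MIMOModel classes shipped alongside the weights.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.num_seqs, config.d_model)  # 3, 256 per the uploaded config.json
```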
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "architectures": [
3
+ "T5MIMOModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_t5mimo.T5MIMOConfig",
7
+ "AutoModel": "modeling_t5mimo.T5MIMOModel"
8
+ },
9
+ "classifier_dropout": 0.0,
10
+ "d_ff": 1024,
11
+ "d_kv": 64,
12
+ "d_model": 256,
13
+ "decoder_start_token_id": 0,
14
+ "dense_act_fn": "relu",
15
+ "dropout_rate": 0.1,
16
+ "eos_token_id": 1,
17
+ "feed_forward_proj": "relu",
18
+ "initializer_factor": 0.05,
19
+ "is_encoder_decoder": true,
20
+ "is_gated_act": false,
21
+ "layer_norm_epsilon": 1e-06,
22
+ "model_type": "t5",
23
+ "num_decoder_layers": 4,
24
+ "num_filters": 64,
25
+ "num_heads": 4,
26
+ "num_layers": 4,
27
+ "num_seqs": 3,
28
+ "pad_token_id": 0,
29
+ "relative_attention_max_distance": 128,
30
+ "relative_attention_num_buckets": 32,
31
+ "torch_dtype": "float32",
32
+ "transformers_version": "4.41.1",
33
+ "use_cache": true,
34
+ "vocab_size": 4096
35
+ }
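A quick, hedged sanity check of the dimensions above, assuming a local copy of this `config.json` next to the `configuration_t5mimo.py` file added below: the attention inner dimension `num_heads * d_kv` (4 × 64) matches `d_model` (256).

```python
from configuration_t5mimo import T5MIMOConfig  # file added below in this commit

cfg = T5MIMOConfig.from_json_file("config.json")
assert cfg.num_heads * cfg.d_kv == cfg.d_model == 256
print(cfg.num_seqs, cfg.num_filters)  # MIMO-specific fields: 3, 64
```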
configuration_t5mimo.py ADDED
@@ -0,0 +1,152 @@
1
+ from typing import Mapping
2
+ from transformers.configuration_utils import PretrainedConfig
3
+ from transformers.onnx import OnnxSeq2SeqConfigWithPast
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class T5MIMOConfig(PretrainedConfig):
11
+ r"""
12
+ This is the configuration class to store the configuration of a [`T5MIMOModel`]. It is used to
13
+ instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
14
+ configuration with the defaults will yield a similar configuration to that of the T5
15
+ [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.
16
+
17
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
18
+ documentation from [`PretrainedConfig`] for more information.
19
+
20
+ Arguments:
21
+ vocab_size (`int`, *optional*, defaults to 32128):
22
+ Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
23
+ `input_ids` passed when calling [`T5MIMOModel`].
24
+ d_model (`int`, *optional*, defaults to 512):
25
+ Size of the encoder layers and the pooler layer.
26
+ d_kv (`int`, *optional*, defaults to 64):
27
+ Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
28
+ be defined as `num_heads * d_kv`.
29
+ d_ff (`int`, *optional*, defaults to 2048):
30
+ Size of the intermediate feed forward layer in each `T5Block`.
31
+ num_layers (`int`, *optional*, defaults to 6):
32
+ Number of hidden layers in the Transformer encoder.
33
+ num_decoder_layers (`int`, *optional*):
34
+ Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
35
+ num_heads (`int`, *optional*, defaults to 8):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ relative_attention_num_buckets (`int`, *optional*, defaults to 32):
38
+ The number of buckets to use for each attention layer.
39
+ relative_attention_max_distance (`int`, *optional*, defaults to 128):
40
+ The maximum distance of the longer sequences for the bucket separation.
41
+ dropout_rate (`float`, *optional*, defaults to 0.1):
42
+ The ratio for all dropout layers.
43
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
44
+ The dropout ratio for the classification head.
45
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
46
+ The epsilon used by the layer normalization layers.
47
+ initializer_factor (`float`, *optional*, defaults to 1):
48
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
49
+ testing).
50
+ feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
51
+ Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
52
+ `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
53
+ use_cache (`bool`, *optional*, defaults to `True`):
54
+ Whether or not the model should return the last key/values attentions (not used by all models).
55
+ """
56
+
57
+ model_type = "t5"
58
+ keys_to_ignore_at_inference = ["past_key_values"]
59
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
60
+
61
+ def __init__(
62
+ self,
63
+ vocab_size=32128,
64
+ d_model=512,
65
+ d_kv=64,
66
+ d_ff=2048,
67
+ num_layers=6,
68
+ num_decoder_layers=None,
69
+ num_heads=8,
70
+ relative_attention_num_buckets=32,
71
+ relative_attention_max_distance=128,
72
+ dropout_rate=0.1,
73
+ layer_norm_epsilon=1e-6,
74
+ initializer_factor=1.0,
75
+ feed_forward_proj="relu",
76
+ is_encoder_decoder=True,
77
+ use_cache=True,
78
+ pad_token_id=0,
79
+ eos_token_id=1,
80
+ decoder_start_token_id = 0,
81
+ classifier_dropout=0.0,
82
+ num_seqs=3,
83
+ num_filters=64,
84
+ **kwargs,
85
+ ):
86
+ self.vocab_size = vocab_size
87
+ self.d_model = d_model
88
+ self.d_kv = d_kv
89
+ self.d_ff = d_ff
90
+ self.num_layers = num_layers
91
+ self.num_decoder_layers = (
92
+ num_decoder_layers if num_decoder_layers is not None else self.num_layers
93
+ ) # default = symmetry
94
+ self.num_heads = num_heads
95
+ self.relative_attention_num_buckets = relative_attention_num_buckets
96
+ self.relative_attention_max_distance = relative_attention_max_distance
97
+ self.dropout_rate = dropout_rate
98
+ self.classifier_dropout = classifier_dropout
99
+ self.layer_norm_epsilon = layer_norm_epsilon
100
+ self.initializer_factor = initializer_factor
101
+ self.feed_forward_proj = feed_forward_proj
102
+ self.use_cache = use_cache
103
+ self.num_seqs = num_seqs
104
+ self.num_filters = num_filters
105
+
106
+ act_info = self.feed_forward_proj.split("-")
107
+ self.dense_act_fn = act_info[-1]
108
+ self.is_gated_act = act_info[0] == "gated"
109
+
110
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
111
+ raise ValueError(
112
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
113
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
114
+ "'gated-gelu' or 'relu'"
115
+ )
116
+
117
+ # for backwards compatibility
118
+ if feed_forward_proj == "gated-gelu":
119
+ self.dense_act_fn = "gelu_new"
120
+
121
+ super().__init__(
122
+ pad_token_id=pad_token_id,
123
+ eos_token_id=eos_token_id,
124
+ decoder_start_token_id=decoder_start_token_id,
125
+ is_encoder_decoder=is_encoder_decoder,
126
+ **kwargs,
127
+ )
128
+
129
+
130
+ class T5MIMOOnnxConfig(OnnxSeq2SeqConfigWithPast):
131
+ @property
132
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
133
+ common_inputs = {
134
+ "input_ids": {0: "batch", 1: "encoder_sequence"},
135
+ "attention_mask": {0: "batch", 1: "encoder_sequence"},
136
+ }
137
+ if self.use_past:
138
+ common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
139
+ common_inputs["decoder_input_ids"] = {0: "batch"}
140
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
141
+ else:
142
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
143
+ common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
144
+
145
+ if self.use_past:
146
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
147
+
148
+ return common_inputs
149
+
150
+ @property
151
+ def default_onnx_opset(self) -> int:
152
+ return 13
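The `feed_forward_proj` handling in `__init__` above splits the string on `-` to derive `dense_act_fn` and `is_gated_act`, with a backwards-compatibility remap for `"gated-gelu"`. A small sketch of that behavior, assuming the file is importable locally:

```python
from configuration_t5mimo import T5MIMOConfig

relu_cfg = T5MIMOConfig(feed_forward_proj="relu")
assert relu_cfg.dense_act_fn == "relu" and relu_cfg.is_gated_act is False

gated_cfg = T5MIMOConfig(feed_forward_proj="gated-gelu")
# "gated-gelu" is remapped to the gelu_new activation for backwards compatibility
assert gated_cfg.dense_act_fn == "gelu_new" and gated_cfg.is_gated_act is True
```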
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9d97da94a794f0b0aad1566b9e13205267ea4b3b70ae8c6cd147e6fe6e651cb
3
+ size 33588312
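The pointer above only records LFS metadata; as a rough, hedged back-of-the-envelope check, a 33,588,312-byte float32 checkpoint corresponds to roughly 8.4M parameters (the small safetensors header is ignored).

```python
print(33_588_312 / 4)  # ≈ 8.4e6 float32 parameters (approximate)
```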
modeling_t5mimo.py ADDED
@@ -0,0 +1,1751 @@
1
+ import copy
2
+ import math
3
+ import warnings
4
+ from typing import Optional, Tuple, Union
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.modeling_outputs import (
10
+ BaseModelOutput,
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ Seq2SeqLMOutput,
13
+ Seq2SeqModelOutput,
14
+ )
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
17
+ from transformers.utils import (
18
+ DUMMY_INPUTS,
19
+ DUMMY_MASK,
20
+ is_torch_fx_proxy,
21
+ logging,
22
+ )
23
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
24
+ from .configuration_t5mimo import T5MIMOConfig
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+
31
+ class T5LayerNorm(nn.Module):
32
+ def __init__(self, hidden_size, eps=1e-6):
33
+ """
34
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
35
+ """
36
+ super().__init__()
37
+ self.weight = nn.Parameter(torch.ones(hidden_size))
38
+ self.variance_epsilon = eps
39
+
40
+ def forward(self, hidden_states):
41
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
42
+ # Square Layer Normalization (https://arxiv.org/abs/1910.07467), thus the variance is calculated
43
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
44
+ # half-precision inputs is done in fp32
45
+
46
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
47
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
48
+
49
+ # convert into half-precision if necessary
50
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
51
+ hidden_states = hidden_states.to(self.weight.dtype)
52
+
53
+ return self.weight * hidden_states
54
+
55
+
56
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
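The comment in `T5LayerNorm.forward` describes an RMS-style norm (scale only, no mean subtraction, no bias). A minimal numerical check of that claim, written as if it lived in this module so `T5LayerNorm` is in scope:

```python
import torch

# y = weight * x / sqrt(mean(x^2) + eps)  -- no mean subtraction, no bias term
x = torch.randn(2, 5, 8)
ln = T5LayerNorm(hidden_size=8, eps=1e-6)
manual = ln.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(ln(x), manual, atol=1e-6)
```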
57
+
58
+
59
+ class T5DenseActDense(nn.Module):
60
+ def __init__(self, config: T5MIMOConfig):
61
+ super().__init__()
62
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
63
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
64
+ self.dropout = nn.Dropout(config.dropout_rate)
65
+ self.act = ACT2FN[config.dense_act_fn]
66
+
67
+ def forward(self, hidden_states):
68
+ hidden_states = self.wi(hidden_states)
69
+ hidden_states = self.act(hidden_states)
70
+ hidden_states = self.dropout(hidden_states)
71
+ if (
72
+ isinstance(self.wo.weight, torch.Tensor)
73
+ and hidden_states.dtype != self.wo.weight.dtype
74
+ and self.wo.weight.dtype != torch.int8
75
+ ):
76
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
77
+ hidden_states = self.wo(hidden_states)
78
+ return hidden_states
79
+
80
+
81
+ class T5DenseGatedActDense(nn.Module):
82
+ def __init__(self, config: T5MIMOConfig):
83
+ super().__init__()
84
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
85
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
86
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
87
+ self.dropout = nn.Dropout(config.dropout_rate)
88
+ self.act = ACT2FN[config.dense_act_fn]
89
+
90
+ def forward(self, hidden_states):
91
+ hidden_gelu = self.act(self.wi_0(hidden_states))
92
+ hidden_linear = self.wi_1(hidden_states)
93
+ hidden_states = hidden_gelu * hidden_linear
94
+ hidden_states = self.dropout(hidden_states)
95
+
96
+ # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
97
+ # See https://github.com/huggingface/transformers/issues/20287
98
+ # We also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`.
99
+ if (
100
+ isinstance(self.wo.weight, torch.Tensor)
101
+ and hidden_states.dtype != self.wo.weight.dtype
102
+ and self.wo.weight.dtype != torch.int8
103
+ ):
104
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
105
+
106
+ hidden_states = self.wo(hidden_states)
107
+ return hidden_states
108
+
109
+
110
+ class T5LayerFF(nn.Module):
111
+ def __init__(self, config: T5MIMOConfig):
112
+ super().__init__()
113
+ if config.is_gated_act:
114
+ self.DenseReluDense = T5DenseGatedActDense(config)
115
+ else:
116
+ self.DenseReluDense = T5DenseActDense(config)
117
+
118
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
119
+ self.dropout = nn.Dropout(config.dropout_rate)
120
+
121
+ def forward(self, hidden_states):
122
+ forwarded_states = self.layer_norm(hidden_states)
123
+ forwarded_states = self.DenseReluDense(forwarded_states)
124
+ hidden_states = hidden_states + self.dropout(forwarded_states)
125
+ return hidden_states
126
+
127
+
128
+ class T5Attention(nn.Module):
129
+ def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
130
+ super().__init__()
131
+ self.is_decoder = config.is_decoder
132
+ self.has_relative_attention_bias = has_relative_attention_bias
133
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
134
+ self.relative_attention_max_distance = config.relative_attention_max_distance
135
+ self.d_model = config.d_model
136
+ self.key_value_proj_dim = config.d_kv
137
+ self.n_heads = config.num_heads
138
+ self.dropout = config.dropout_rate
139
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
140
+
141
+ # Mesh TensorFlow initialization to avoid scaling before softmax
142
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
143
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
144
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
145
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
146
+
147
+ if self.has_relative_attention_bias:
148
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
149
+ self.pruned_heads = set()
150
+ self.gradient_checkpointing = False
151
+
152
+ def prune_heads(self, heads):
153
+ if len(heads) == 0:
154
+ return
155
+ heads, index = find_pruneable_heads_and_indices(
156
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
157
+ )
158
+ # Prune linear layers
159
+ self.q = prune_linear_layer(self.q, index)
160
+ self.k = prune_linear_layer(self.k, index)
161
+ self.v = prune_linear_layer(self.v, index)
162
+ self.o = prune_linear_layer(self.o, index, dim=1)
163
+ # Update hyper params
164
+ self.n_heads = self.n_heads - len(heads)
165
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
166
+ self.pruned_heads = self.pruned_heads.union(heads)
167
+
168
+ @staticmethod
169
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
170
+ """
171
+ Adapted from Mesh Tensorflow:
172
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
173
+
174
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
175
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
176
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
177
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
178
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
179
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
180
+
181
+ Args:
182
+ relative_position: an int32 Tensor
183
+ bidirectional: a boolean - whether the attention is bidirectional
184
+ num_buckets: an integer
185
+ max_distance: an integer
186
+
187
+ Returns:
188
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
189
+ """
190
+ relative_buckets = 0
191
+ if bidirectional:
192
+ num_buckets //= 2
193
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
194
+ relative_position = torch.abs(relative_position)
195
+ else:
196
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
197
+ # now relative_position is in the range [0, inf)
198
+
199
+ # half of the buckets are for exact increments in positions
200
+ max_exact = num_buckets // 2
201
+ is_small = relative_position < max_exact
202
+
203
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
204
+ relative_position_if_large = max_exact + (
205
+ torch.log(relative_position.float() / max_exact)
206
+ / math.log(max_distance / max_exact)
207
+ * (num_buckets - max_exact)
208
+ ).to(torch.long)
209
+ relative_position_if_large = torch.min(
210
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
211
+ )
212
+
213
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
214
+ return relative_buckets
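A small illustration of the bucketing described in the docstring, using the static method above (again assuming it is in scope): nearby offsets keep distinct buckets, positive and negative offsets occupy different halves in the bidirectional case, and offsets at or beyond `max_distance` saturate into the last bucket.

```python
import torch

bucket = T5Attention._relative_position_bucket  # static method defined above

rel = torch.tensor([-3, -1, 0, 1, 3])
print(bucket(rel, bidirectional=True, num_buckets=32, max_distance=128))
# small negative and positive offsets land in separate halves of the 32 buckets

far = torch.tensor([128, 500])
buckets_far = bucket(far, bidirectional=True, num_buckets=32, max_distance=128)
assert buckets_far[0] == buckets_far[1]  # offsets >= max_distance share one bucket
```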
215
+
216
+ def compute_bias(self, query_length, key_length, multivar_dim=-1, device=None):
217
+ """Compute binned relative position bias"""
218
+ if device is None:
219
+ device = self.relative_attention_bias.weight.device
220
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
221
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
222
+ relative_position = memory_position - context_position # shape (query_length, key_length)
223
+ relative_position_bucket = self._relative_position_bucket(
224
+ relative_position, # shape (query_length, key_length)
225
+ bidirectional=(not self.is_decoder),
226
+ num_buckets=self.relative_attention_num_buckets,
227
+ max_distance=self.relative_attention_max_distance,
228
+ )
229
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
230
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
231
+ if multivar_dim != -1:  # shape (1, multivar_dim, num_heads, query_length, key_length), broadcast across the multivariate dimension
232
+ values = values.expand(1, multivar_dim, -1, -1, -1)
233
+
234
+ return values
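A hedged shape check for `compute_bias` above: the default path returns a `(1, num_heads, query_length, key_length)` bias, and passing `multivar_dim` adds a broadcast copy per variate, giving `(1, multivar_dim, num_heads, query_length, key_length)`. Uses `T5MIMOConfig` and `T5Attention` as defined in this commit.

```python
cfg = T5MIMOConfig(d_model=64, d_kv=16, num_heads=4, num_layers=2)
attn = T5Attention(cfg, has_relative_attention_bias=True)

assert attn.compute_bias(10, 10).shape == (1, 4, 10, 10)
assert attn.compute_bias(10, 10, multivar_dim=3).shape == (1, 3, 4, 10, 10)
```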
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ mask=None,
240
+ key_value_states=None,
241
+ position_bias=None,
242
+ past_key_value=None,
243
+ layer_head_mask=None,
244
+ query_length=None,
245
+ use_cache=False,
246
+ output_attentions=False,
247
+ ):
248
+ """
249
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
250
+ """
251
+ # Input is (batch_size, seq_length, dim)
252
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
253
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
254
+ if len(hidden_states.shape) == 3:
255
+ batch_size, seq_length = hidden_states.shape[:2]
256
+ else:
257
+ batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[2]
258
+ multivar_dim = hidden_states.shape[1]
259
+ real_seq_length = seq_length
260
+
261
+ if past_key_value is not None:
262
+ if len(past_key_value) != 2:
263
+ raise ValueError(
264
+ f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
265
+ )
266
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
267
+
268
+ if len(hidden_states.shape) == 3:
269
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
270
+ else:
271
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[2]
272
+
273
+
274
+ def shape(states):
275
+ """projection"""
276
+ # states: torch.Size([3, 16, 512]) -> query_states: torch.Size([3, 8, 16, 64])
277
+ # states: torch.Size([3, 6, 16, 512]) -> query_states: torch.Size([3, 6, 8, 16, 64])
278
+ if len(states.shape) == 3:
279
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
280
+ else:
281
+ return states.view(batch_size, multivar_dim, -1, self.n_heads, self.key_value_proj_dim).transpose(2, 3)
282
+
283
+
284
+ def unshape(states):
285
+ """reshape"""
286
+ if len(states.shape) == 4:
287
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
288
+ else:
289
+ return states.transpose(2, 3).contiguous().view(batch_size, multivar_dim, -1, self.inner_dim)
290
+
291
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
292
+ """projects hidden states correctly to key/query states"""
293
+ if key_value_states is None:
294
+ # self-attn
295
+ # (batch_size, n_heads, seq_length, dim_per_head)
296
+ hidden_states = shape(proj_layer(hidden_states))
297
+ elif past_key_value is None:
298
+ # cross-attn
299
+ # (batch_size, n_heads, seq_length, dim_per_head)
300
+ hidden_states = shape(proj_layer(key_value_states))
301
+
302
+ if past_key_value is not None:
303
+ if key_value_states is None:
304
+ # self-attn
305
+ # (batch_size, n_heads, key_length, dim_per_head)
306
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
307
+ elif past_key_value.shape[2] != key_value_states.shape[1]:
308
+ # checking that the `sequence_length` of the `past_key_value` is the same as
309
+ # the provided `key_value_states` to support prefix tuning
310
+ # cross-attn
311
+ # (batch_size, n_heads, seq_length, dim_per_head)
312
+ hidden_states = shape(proj_layer(key_value_states))
313
+ else:
314
+ # cross-attn
315
+ hidden_states = past_key_value
316
+ return hidden_states
317
+
318
+ # get query states
319
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
320
+
321
+
322
+ # get key/value states
323
+ key_states = project(
324
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
325
+ )
326
+ value_states = project(
327
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
328
+ )
329
+
330
+
331
+
332
+ # compute scores
333
+ if len(hidden_states.shape) == 3:
334
+ scores = torch.matmul(
335
+ query_states, key_states.transpose(3, 2)
336
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
337
+ else:
338
+ scores = torch.matmul(
339
+ query_states, key_states.transpose(4, 3)
340
+ )
341
+
342
+
343
+
344
+
345
+
346
+ if position_bias is None:
347
+ if not self.has_relative_attention_bias:
348
+
349
+ if len(hidden_states.shape) == 3:
350
+ position_bias = torch.zeros(
351
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
352
+ )
353
+ else:
354
+ position_bias = torch.zeros(
355
+ (1,multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
356
+ )
357
+ if self.gradient_checkpointing and self.training:
358
+ position_bias.requires_grad = True
359
+ else:
360
+
361
+ if len(hidden_states.shape) == 3:
362
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
363
+ else:
364
+ position_bias = self.compute_bias(real_seq_length, key_length,multivar_dim=multivar_dim, device=scores.device)
365
+
366
+ # if key and values are already calculated
367
+ # we want only the last query position bias
368
+ if past_key_value is not None:
369
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
370
+
371
+ if mask is not None:
372
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
373
+
374
+
375
+
376
+ if self.pruned_heads:
377
+ mask = torch.ones(position_bias.shape[1])
378
+ mask[list(self.pruned_heads)] = 0
379
+ position_bias_masked = position_bias[:, mask.bool()]
380
+ else:
381
+ position_bias_masked = position_bias
382
+
383
+
384
+ scores += position_bias_masked
385
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
386
+ scores
387
+ ) # (batch_size, n_heads, seq_length, key_length)
388
+ attn_weights = nn.functional.dropout(
389
+ attn_weights, p=self.dropout, training=self.training
390
+ ) # (batch_size, n_heads, seq_length, key_length)
391
+
392
+ # Mask heads if we want to
393
+ if layer_head_mask is not None:
394
+ attn_weights = attn_weights * layer_head_mask
395
+
396
+
397
+ if len(hidden_states.shape) == 3:
398
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
399
+ else:
400
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, multivar_dim, seq_length, dim)
401
+ attn_output = self.o(attn_output)
402
+
403
+
404
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
405
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
406
+
407
+
408
+ if output_attentions:
409
+ outputs = outputs + (attn_weights,)
410
+
411
+ return outputs
412
+
413
+
414
+ class T5LayerSelfAttention(nn.Module):
415
+ def __init__(self, config, has_relative_attention_bias=False):
416
+ super().__init__()
417
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
418
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
419
+ self.dropout = nn.Dropout(config.dropout_rate)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states,
424
+ attention_mask=None,
425
+ position_bias=None,
426
+ layer_head_mask=None,
427
+ past_key_value=None,
428
+ use_cache=False,
429
+ output_attentions=False,
430
+ ):
431
+ normed_hidden_states = self.layer_norm(hidden_states)
432
+ attention_output = self.SelfAttention(
433
+ normed_hidden_states,
434
+ mask=attention_mask,
435
+ position_bias=position_bias,
436
+ layer_head_mask=layer_head_mask,
437
+ past_key_value=past_key_value,
438
+ use_cache=use_cache,
439
+ output_attentions=output_attentions,
440
+ )
441
+
442
+ hidden_states = hidden_states + self.dropout(attention_output[0])
443
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
444
+ return outputs
445
+
446
+
447
+ class T5LayerCrossAttention(nn.Module):
448
+ def __init__(self, config):
449
+ super().__init__()
450
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
451
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
452
+ self.dropout = nn.Dropout(config.dropout_rate)
453
+
454
+ def forward(
455
+ self,
456
+ hidden_states,
457
+ key_value_states,
458
+ attention_mask=None,
459
+ position_bias=None,
460
+ layer_head_mask=None,
461
+ past_key_value=None,
462
+ use_cache=False,
463
+ query_length=None,
464
+ output_attentions=False,
465
+ ):
466
+
467
+ normed_hidden_states = self.layer_norm(hidden_states)
468
+ attention_output = self.EncDecAttention(
469
+ normed_hidden_states,
470
+ mask=attention_mask,
471
+ key_value_states=key_value_states,
472
+ position_bias=position_bias,
473
+ layer_head_mask=layer_head_mask,
474
+ past_key_value=past_key_value,
475
+ use_cache=use_cache,
476
+ query_length=query_length,
477
+ output_attentions=output_attentions,
478
+ )
479
+ layer_output = hidden_states + self.dropout(attention_output[0])
480
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
481
+ return outputs
482
+
483
+
484
+ class T5Block(nn.Module):
485
+ def __init__(self, config, has_relative_attention_bias=False):
486
+ super().__init__()
487
+ self.is_decoder = config.is_decoder
488
+ self.layer = nn.ModuleList()
489
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
490
+ if self.is_decoder:
491
+ self.layer.append(T5LayerCrossAttention(config))
492
+
493
+ self.layer.append(T5LayerFF(config))
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ position_bias=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ encoder_decoder_position_bias=None,
503
+ layer_head_mask=None,
504
+ cross_attn_layer_head_mask=None,
505
+ past_key_value=None,
506
+ use_cache=False,
507
+ output_attentions=False,
508
+ return_dict=True,
509
+ ):
510
+ if past_key_value is not None:
511
+ if not self.is_decoder:
512
+ logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
513
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
514
+
515
+ if len(past_key_value) != expected_num_past_key_values:
516
+ raise ValueError(
517
+ f"There should be {expected_num_past_key_values} past states. "
518
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
519
+ f"Got {len(past_key_value)} past key / value states"
520
+ )
521
+
522
+ self_attn_past_key_value = past_key_value[:2]
523
+ cross_attn_past_key_value = past_key_value[2:]
524
+ else:
525
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
526
+
527
+ self_attention_outputs = self.layer[0](
528
+ hidden_states,
529
+ attention_mask=attention_mask,
530
+ position_bias=position_bias,
531
+ layer_head_mask=layer_head_mask,
532
+ past_key_value=self_attn_past_key_value,
533
+ use_cache=use_cache,
534
+ output_attentions=output_attentions,
535
+ )
536
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
537
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
538
+
539
+ # clamp inf values to enable fp16 training
540
+ if hidden_states.dtype == torch.float16:
541
+ clamp_value = torch.where(
542
+ torch.isinf(hidden_states).any(),
543
+ torch.finfo(hidden_states.dtype).max - 1000,
544
+ torch.finfo(hidden_states.dtype).max,
545
+ )
546
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
547
+
548
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
549
+ if do_cross_attention:
550
+ # the actual query length is unknown for cross attention
551
+ # if using past key value states. Need to inject it here
552
+ if present_key_value_state is not None:
553
+ query_length = present_key_value_state[0].shape[2]
554
+ else:
555
+ query_length = None
556
+
557
+ cross_attention_outputs = self.layer[1](
558
+ hidden_states,
559
+ key_value_states=encoder_hidden_states,
560
+ attention_mask=encoder_attention_mask,
561
+ position_bias=encoder_decoder_position_bias,
562
+ layer_head_mask=cross_attn_layer_head_mask,
563
+ past_key_value=cross_attn_past_key_value,
564
+ query_length=query_length,
565
+ use_cache=use_cache,
566
+ output_attentions=output_attentions,
567
+ )
568
+ hidden_states = cross_attention_outputs[0]
569
+
570
+ # clamp inf values to enable fp16 training
571
+ if hidden_states.dtype == torch.float16:
572
+ clamp_value = torch.where(
573
+ torch.isinf(hidden_states).any(),
574
+ torch.finfo(hidden_states.dtype).max - 1000,
575
+ torch.finfo(hidden_states.dtype).max,
576
+ )
577
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
578
+
579
+ # Combine self attn and cross attn key value states
580
+ if present_key_value_state is not None:
581
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
582
+
583
+ # Keep cross-attention outputs and relative position weights
584
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
585
+
586
+ # Apply Feed Forward layer
587
+ hidden_states = self.layer[-1](hidden_states)
588
+
589
+ # clamp inf values to enable fp16 training
590
+ if hidden_states.dtype == torch.float16:
591
+ clamp_value = torch.where(
592
+ torch.isinf(hidden_states).any(),
593
+ torch.finfo(hidden_states.dtype).max - 1000,
594
+ torch.finfo(hidden_states.dtype).max,
595
+ )
596
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
597
+
598
+ outputs = (hidden_states,)
599
+
600
+ if use_cache:
601
+ outputs = outputs + (present_key_value_state,) + attention_outputs
602
+ else:
603
+ outputs = outputs + attention_outputs
604
+
605
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
606
+
607
+
608
+ class T5ClassificationHead(nn.Module):
609
+ """Head for sentence-level classification tasks."""
610
+
611
+ def __init__(self, config: T5MIMOConfig):
612
+ super().__init__()
613
+ self.dense = nn.Linear(config.d_model, config.d_model)
614
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
615
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
616
+
617
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
618
+ hidden_states = self.dropout(hidden_states)
619
+ hidden_states = self.dense(hidden_states)
620
+ hidden_states = torch.tanh(hidden_states)
621
+ hidden_states = self.dropout(hidden_states)
622
+ hidden_states = self.out_proj(hidden_states)
623
+ return hidden_states
624
+
625
+
626
+ class T5PreTrainedModel(PreTrainedModel):
627
+ """
628
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
629
+ models.
630
+ """
631
+
632
+ config_class = T5MIMOConfig
633
+ base_model_prefix = "transformer"
634
+ is_parallelizable = True
635
+ supports_gradient_checkpointing = True
636
+ _no_split_modules = ["T5Block"]
637
+ _keep_in_fp32_modules = ["wo"]
638
+
639
+ @property
640
+ def dummy_inputs(self):
641
+ input_ids = torch.tensor(DUMMY_INPUTS)
642
+ input_mask = torch.tensor(DUMMY_MASK)
643
+ dummy_inputs = {
644
+ "decoder_input_ids": input_ids,
645
+ "input_ids": input_ids,
646
+ "decoder_attention_mask": input_mask,
647
+ }
648
+ return dummy_inputs
649
+
650
+ def _init_weights(self, module):
651
+ """Initialize the weights"""
652
+ factor = self.config.initializer_factor # Used for testing weights initialization
653
+ if isinstance(module, T5LayerNorm):
654
+ module.weight.data.fill_(factor * 1.0)
655
+ elif isinstance(
656
+ module,
657
+ (T5MIMOModel, T5MIMOForConditionalGeneration, T5MIMOEncoderModel),
658
+ ):
659
+ # Mesh TensorFlow embeddings initialization
660
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
661
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
662
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
663
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
664
+ if hasattr(module, "qa_outputs"):
665
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
666
+ module.qa_outputs.bias.data.zero_()
667
+ elif isinstance(module, T5ClassificationHead):
668
+ module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
669
+ if hasattr(module.dense, "bias") and module.dense.bias is not None:
670
+ module.dense.bias.data.zero_()
671
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
672
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
673
+ module.out_proj.bias.data.zero_()
674
+ elif isinstance(module, T5DenseActDense):
675
+ # Mesh TensorFlow FF initialization
676
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
677
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
678
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
679
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
680
+ module.wi.bias.data.zero_()
681
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
682
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
683
+ module.wo.bias.data.zero_()
684
+ elif isinstance(module, T5DenseGatedActDense):
685
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
686
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
687
+ module.wi_0.bias.data.zero_()
688
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
689
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
690
+ module.wi_1.bias.data.zero_()
691
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
692
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
693
+ module.wo.bias.data.zero_()
694
+ elif isinstance(module, T5Attention):
695
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
696
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
697
+ d_model = self.config.d_model
698
+ key_value_proj_dim = self.config.d_kv
699
+ n_heads = self.config.num_heads
700
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
701
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
702
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
703
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
704
+ if module.has_relative_attention_bias:
705
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
706
+
707
+ def _shift_right(self, input_ids):
708
+ decoder_start_token_id = self.config.decoder_start_token_id
709
+ pad_token_id = self.config.pad_token_id
710
+
711
+ if decoder_start_token_id is None:
712
+ raise ValueError(
713
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
714
+ "See T5 docs for more information."
715
+ )
716
+
717
+ # shift inputs to the right
718
+ if is_torch_fx_proxy(input_ids):
719
+ # Item assignment is not supported natively for proxies.
720
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
721
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
722
+ else:
723
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
724
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
725
+ shifted_input_ids[..., 0] = decoder_start_token_id
726
+
727
+ if pad_token_id is None:
728
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
729
+ # replace possible -100 values in labels by `pad_token_id`
730
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
731
+
732
+ return shifted_input_ids
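A standalone worked example of the shift-right logic above (mirroring the tensor branch rather than calling the method itself), with the `decoder_start_token_id=0` / `pad_token_id=0` values from the uploaded `config.json`:

```python
import torch

labels = torch.tensor([[10, 20, -100]])   # -100 marks positions ignored by the loss
decoder_start_token_id, pad_token_id = 0, 0

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)

print(shifted)  # tensor([[ 0, 10, 20]])
```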
733
+
734
+
735
+ class T5Stack(T5PreTrainedModel):
736
+ def __init__(self, config, embed_tokens=None):
737
+ super().__init__(config)
738
+
739
+ self.embed_tokens = embed_tokens
740
+ self.is_decoder = config.is_decoder
741
+
742
+ self.block = nn.ModuleList(
743
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
744
+ )
745
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
746
+ self.dropout = nn.Dropout(config.dropout_rate)
747
+
748
+ # Initialize weights and apply final processing
749
+ self.post_init()
750
+ # Model parallel
751
+ self.model_parallel = False
752
+ self.device_map = None
753
+ self.gradient_checkpointing = False
754
+
755
+ def parallelize(self, device_map=None):
756
+ warnings.warn(
757
+ "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
758
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
759
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
760
+ " 'block.1': 1, ...}",
761
+ FutureWarning,
762
+ )
763
+ # Check validity of device_map
764
+ self.device_map = (
765
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
766
+ )
767
+ assert_device_map(self.device_map, len(self.block))
768
+ self.model_parallel = True
769
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
770
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
771
+ # Load onto devices
772
+ for k, v in self.device_map.items():
773
+ for layer in v:
774
+ cuda_device = "cuda:" + str(k)
775
+ self.block[layer] = self.block[layer].to(cuda_device)
776
+
777
+ # Set embed_tokens to first layer
778
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
779
+ # Set final layer norm to last device
780
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
781
+
782
+
783
+ def deparallelize(self):
784
+ warnings.warn(
785
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
786
+ FutureWarning,
787
+ )
788
+ self.model_parallel = False
789
+ self.device_map = None
790
+ self.first_device = "cpu"
791
+ self.last_device = "cpu"
792
+ for i in range(len(self.block)):
793
+ self.block[i] = self.block[i].to("cpu")
794
+ self.embed_tokens = self.embed_tokens.to("cpu")
795
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
796
+ torch.cuda.empty_cache()
797
+
798
+ def get_input_embeddings(self):
799
+ return self.embed_tokens
800
+
801
+ def set_input_embeddings(self, new_embeddings):
802
+ self.embed_tokens = new_embeddings
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ encoder_hidden_states=None,
809
+ encoder_attention_mask=None,
810
+ inputs_embeds=None,
811
+ head_mask=None,
812
+ cross_attn_head_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ ):
819
+ # Model parallel
820
+ if self.model_parallel:
821
+ torch.cuda.set_device(self.first_device)
822
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
823
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
824
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
825
+ output_hidden_states = (
826
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
827
+ )
828
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
829
+
830
+ if input_ids is not None and inputs_embeds is not None:
831
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
832
+ raise ValueError(
833
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
834
+ )
835
+ elif input_ids is not None:
836
+ input_shape = input_ids.size()
837
+ # input_ids = input_ids.view(-1, input_shape[-1])
838
+ elif inputs_embeds is not None:
839
+ input_shape = inputs_embeds.size()[:-1]
840
+ else:
841
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
842
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
843
+
844
+ if inputs_embeds is None:
845
+ if self.embed_tokens is None:
846
+ raise ValueError("You have to initialize the model with valid token embeddings")
847
+ inputs_embeds = self.embed_tokens(input_ids)
848
+
849
+ if len(input_shape) == 3:
850
+ batch_size, multivar_seqs, seq_length = input_shape
851
+ else:
852
+ batch_size, seq_length = input_shape
853
+
854
+ # required mask seq length can be calculated via length of past
855
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
856
+
857
+ if use_cache is True:
858
+ if not self.is_decoder:
859
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
860
+
861
+ # initialize past_key_values with `None` if past does not exist
862
+ if past_key_values is None:
863
+ past_key_values = [None] * len(self.block)
864
+
865
+ if attention_mask is None:
866
+ if len(input_shape) == 2:
867
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
868
+ else:
869
+ attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
870
+
871
+
872
+
873
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
874
+ # ourselves in which case we just need to make it broadcastable to all heads.
875
+ if len(input_shape) == 2:
876
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
877
+ else:
878
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
879
+ # permute from [batch_size, 1, multivar_seqs, seq_length] to [batch_size, multivar_seqs, 1, seq_length]
880
+ extended_attention_mask = extended_attention_mask.permute(0, 2, 1, 3)
881
+ # Now make it [batch_size, multivar_seqs, 1, 1, seq_length]
882
+ extended_attention_mask = extended_attention_mask.unsqueeze(3)
883
+
884
+ # If a 2D or 3D attention mask is provided for the cross-attention
885
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
886
+ if self.is_decoder and encoder_hidden_states is not None:
887
+
888
+ if len(encoder_hidden_states.size()) == 3:
889
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
890
+ else:
891
+ encoder_batch_size, multivar_dim, encoder_sequence_length, _ = encoder_hidden_states.size()
892
+
893
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
894
+ if encoder_attention_mask is None:
895
+ encoder_attention_mask = torch.ones(
896
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
897
+ )
898
+ if len(input_shape) == 2:
899
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
900
+ else:
901
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
902
+ multivar_dim = extended_attention_mask.shape[1]
903
+ encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
904
+ encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 3, 1, 2, 4)
905
+
906
+ else:
907
+ encoder_extended_attention_mask = None
908
+
909
+
910
+
911
+ if self.gradient_checkpointing and self.training:
912
+ if use_cache:
913
+ logger.warning_once(
914
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
915
+ )
916
+ use_cache = False
917
+
918
+ # Prepare head mask if needed
919
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
920
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
921
+ present_key_value_states = () if use_cache else None
922
+ all_hidden_states = () if output_hidden_states else None
923
+ all_attentions = () if output_attentions else None
924
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
925
+ position_bias = None
926
+ encoder_decoder_position_bias = None
927
+
928
+ hidden_states = self.dropout(inputs_embeds)
929
+
930
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
931
+ layer_head_mask = head_mask[i]
932
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
933
+ # Model parallel
934
+ if self.model_parallel:
935
+ torch.cuda.set_device(hidden_states.device)
936
+ # Ensure that attention_mask is always on the same device as hidden_states
937
+ if attention_mask is not None:
938
+ attention_mask = attention_mask.to(hidden_states.device)
939
+ if position_bias is not None:
940
+ position_bias = position_bias.to(hidden_states.device)
941
+ if encoder_hidden_states is not None:
942
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
943
+ if encoder_extended_attention_mask is not None:
944
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
945
+ if encoder_decoder_position_bias is not None:
946
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
947
+ if layer_head_mask is not None:
948
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
949
+ if cross_attn_layer_head_mask is not None:
950
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
951
+ if output_hidden_states:
952
+ all_hidden_states = all_hidden_states + (hidden_states,)
953
+
954
+ if self.gradient_checkpointing and self.training:
955
+ layer_outputs = self._gradient_checkpointing_func(
956
+ layer_module.forward,
957
+ hidden_states,
958
+ extended_attention_mask,
959
+ position_bias,
960
+ encoder_hidden_states,
961
+ encoder_extended_attention_mask,
962
+ encoder_decoder_position_bias,
963
+ layer_head_mask,
964
+ cross_attn_layer_head_mask,
965
+ None, # past_key_value is always None with gradient checkpointing
966
+ use_cache,
967
+ output_attentions,
968
+ )
969
+ else:
970
+ layer_outputs = layer_module(
971
+ hidden_states,
972
+ attention_mask=extended_attention_mask,
973
+ position_bias=position_bias,
974
+ encoder_hidden_states=encoder_hidden_states,
975
+ encoder_attention_mask=encoder_extended_attention_mask,
976
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
977
+ layer_head_mask=layer_head_mask,
978
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
979
+ past_key_value=past_key_value,
980
+ use_cache=use_cache,
981
+ output_attentions=output_attentions,
982
+ )
983
+
984
+ # layer_outputs is a tuple with:
985
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
986
+ if use_cache is False:
987
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
988
+
989
+ hidden_states, present_key_value_state = layer_outputs[:2]
990
+
991
+ # We share the position biases between the layers - the first layer stores them
992
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
993
+ # (cross-attention position bias), (cross-attention weights)
994
+ position_bias = layer_outputs[2]
995
+ if self.is_decoder and encoder_hidden_states is not None:
996
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
997
+ # append next layer key value states
998
+ if use_cache:
999
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
1000
+
1001
+ if output_attentions:
1002
+ all_attentions = all_attentions + (layer_outputs[3],)
1003
+ if self.is_decoder:
1004
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1005
+
1006
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1007
+ if self.model_parallel:
1008
+ for k, v in self.device_map.items():
1009
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1010
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1011
+
1012
+ hidden_states = self.final_layer_norm(hidden_states)
1013
+ hidden_states = self.dropout(hidden_states)
1014
+
1015
+ # Add last layer
1016
+ if output_hidden_states:
1017
+ all_hidden_states = all_hidden_states + (hidden_states,)
1018
+
1019
+ if not return_dict:
1020
+ return tuple(
1021
+ v
1022
+ for v in [
1023
+ hidden_states,
1024
+ present_key_value_states,
1025
+ all_hidden_states,
1026
+ all_attentions,
1027
+ all_cross_attentions,
1028
+ ]
1029
+ if v is not None
1030
+ )
1031
+ return BaseModelOutputWithPastAndCrossAttentions(
1032
+ last_hidden_state=hidden_states,
1033
+ past_key_values=present_key_value_states,
1034
+ hidden_states=all_hidden_states,
1035
+ attentions=all_attentions,
1036
+ cross_attentions=all_cross_attentions,
1037
+ )
1038
+
1039
+
1040
+
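+ # Shape conventions for the multivariate (MIMO) path, as inferred from the 3-D branches
+ # in T5Stack above (2-D inputs fall back to the standard univariate T5 behaviour):
+ # input_ids / attention_mask : [batch_size, num_seqs, seq_length]
+ # inputs_embeds / last_hidden_state : [batch_size, num_seqs, seq_length, d_model]
+ # extended self-attention mask : [batch_size, num_seqs, 1, 1, seq_length]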
1041
+ class T5MIMOModel(T5PreTrainedModel):
1042
+ _keys_to_ignore_on_load_unexpected = [
1043
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1044
+ ]
1045
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1046
+
1047
+ def __init__(self, config: T5MIMOConfig):
1048
+ super().__init__(config)
1049
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1050
+
1051
+ encoder_config = copy.deepcopy(config)
1052
+ encoder_config.is_decoder = False
1053
+ encoder_config.use_cache = False
1054
+ encoder_config.is_encoder_decoder = False
1055
+ self.encoder = T5Stack(encoder_config, self.shared)
1056
+
1057
+ decoder_config = copy.deepcopy(config)
1058
+ decoder_config.is_decoder = True
1059
+ decoder_config.is_encoder_decoder = False
1060
+ decoder_config.num_layers = config.num_decoder_layers
1061
+ self.decoder = T5Stack(decoder_config, self.shared)
1062
+
1063
+ # Initialize weights and apply final processing
1064
+ self.post_init()
1065
+
1066
+ # Model parallel
1067
+ self.model_parallel = False
1068
+ self.device_map = None
1069
+
1070
+
1071
+ def parallelize(self, device_map=None):
1072
+ warnings.warn(
1073
+ "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
1074
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1075
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
1076
+ " 0, 'encoder.block.1': 1, ...}",
1077
+ FutureWarning,
1078
+ )
1079
+ self.device_map = (
1080
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1081
+ if device_map is None
1082
+ else device_map
1083
+ )
1084
+ assert_device_map(self.device_map, len(self.encoder.block))
1085
+ self.encoder.parallelize(self.device_map)
1086
+ self.decoder.parallelize(self.device_map)
1087
+ self.model_parallel = True
1088
+
1089
+
1090
+ def deparallelize(self):
1091
+ warnings.warn(
1092
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1093
+ FutureWarning,
1094
+ )
1095
+ self.encoder.deparallelize()
1096
+ self.decoder.deparallelize()
1097
+ self.encoder = self.encoder.to("cpu")
1098
+ self.decoder = self.decoder.to("cpu")
1099
+ self.model_parallel = False
1100
+ self.device_map = None
1101
+ torch.cuda.empty_cache()
1102
+
1103
+ def get_input_embeddings(self):
1104
+ return self.shared
1105
+
1106
+ def set_input_embeddings(self, new_embeddings):
1107
+ self.shared = new_embeddings
1108
+ self.encoder.set_input_embeddings(new_embeddings)
1109
+ self.decoder.set_input_embeddings(new_embeddings)
1110
+
1111
+ def _tie_weights(self):
1112
+ if self.config.tie_word_embeddings:
1113
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1114
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1115
+
1116
+ def get_encoder(self):
1117
+ return self.encoder
1118
+
1119
+ def get_decoder(self):
1120
+ return self.decoder
1121
+
1122
+ def _prune_heads(self, heads_to_prune):
1123
+ """
1124
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1125
+ class PreTrainedModel
1126
+ """
1127
+ for layer, heads in heads_to_prune.items():
1128
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1129
+
1130
+ def forward(
1131
+ self,
1132
+ input_ids: Optional[torch.LongTensor] = None,
1133
+ attention_mask: Optional[torch.FloatTensor] = None,
1134
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1135
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1136
+ head_mask: Optional[torch.FloatTensor] = None,
1137
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1138
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1139
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1140
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1141
+ inputs_embeds: Optional[torch.Tensor] = None,
1142
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1143
+ use_cache: Optional[bool] = None,
1144
+ output_attentions: Optional[bool] = None,
1145
+ output_hidden_states: Optional[bool] = None,
1146
+ return_dict: Optional[bool] = None,
1147
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1148
+ r"""
1149
+ Returns:
1150
+
1151
+ Example:
1152
+
1153
+ ```python
1154
+ >>> from transformers import AutoTokenizer, T5Model
1155
+
1156
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1157
+ >>> model = T5Model.from_pretrained("google-t5/t5-small")
1158
+
1159
+ >>> input_ids = tokenizer(
1160
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1161
+ ... ).input_ids # Batch size 1
1162
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1163
+
1164
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
1165
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
1166
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1167
+
1168
+ >>> # forward pass
1169
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1170
+ >>> last_hidden_states = outputs.last_hidden_state
1171
+ ```"""
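+ # NOTE: the usage example above is carried over from the upstream T5Model docstring;
+ # the customised classes in this repository are typically loaded via the auto classes
+ # with trust_remote_code=True.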
1172
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1173
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1174
+
1175
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1176
+ if head_mask is not None and decoder_head_mask is None:
1177
+ if self.config.num_layers == self.config.num_decoder_layers:
1178
+ decoder_head_mask = head_mask
1179
+
1180
+ # Encode if needed (training, first prediction pass)
1181
+ if encoder_outputs is None:
1182
+ encoder_outputs = self.encoder(
1183
+ input_ids=input_ids,
1184
+ attention_mask=attention_mask,
1185
+ inputs_embeds=inputs_embeds,
1186
+ head_mask=head_mask,
1187
+ output_attentions=output_attentions,
1188
+ output_hidden_states=output_hidden_states,
1189
+ return_dict=return_dict,
1190
+ )
1191
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1192
+ encoder_outputs = BaseModelOutput(
1193
+ last_hidden_state=encoder_outputs[0],
1194
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1195
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1196
+ )
1197
+
1198
+ hidden_states = encoder_outputs[0]
1199
+
1200
+ # Set device for model parallelism
1201
+ if self.model_parallel:
1202
+ torch.cuda.set_device(self.decoder.first_device)
1203
+ hidden_states = hidden_states.to(self.decoder.first_device)
1204
+ if decoder_input_ids is not None:
1205
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1206
+ if attention_mask is not None:
1207
+ attention_mask = attention_mask.to(self.decoder.first_device)
1208
+ if decoder_attention_mask is not None:
1209
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1210
+
1211
+ # Decode
1212
+ decoder_outputs = self.decoder(
1213
+ input_ids=decoder_input_ids,
1214
+ attention_mask=decoder_attention_mask,
1215
+ inputs_embeds=decoder_inputs_embeds,
1216
+ past_key_values=past_key_values,
1217
+ encoder_hidden_states=hidden_states,
1218
+ encoder_attention_mask=attention_mask,
1219
+ head_mask=decoder_head_mask,
1220
+ cross_attn_head_mask=cross_attn_head_mask,
1221
+ use_cache=use_cache,
1222
+ output_attentions=output_attentions,
1223
+ output_hidden_states=output_hidden_states,
1224
+ return_dict=return_dict,
1225
+ )
1226
+
1227
+ if not return_dict:
1228
+ return decoder_outputs + encoder_outputs
1229
+
1230
+ return Seq2SeqModelOutput(
1231
+ last_hidden_state=decoder_outputs.last_hidden_state,
1232
+ past_key_values=decoder_outputs.past_key_values,
1233
+ decoder_hidden_states=decoder_outputs.hidden_states,
1234
+ decoder_attentions=decoder_outputs.attentions,
1235
+ cross_attentions=decoder_outputs.cross_attentions,
1236
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1237
+ encoder_hidden_states=encoder_outputs.hidden_states,
1238
+ encoder_attentions=encoder_outputs.attentions,
1239
+ )
1240
+
1241
+
1242
+
1243
+ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
1244
+ _keys_to_ignore_on_load_unexpected = [
1245
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1246
+ ]
1247
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1248
+
1249
+ def __init__(self, config: T5MIMOConfig):
1250
+ super().__init__(config)
1251
+ self.model_dim = config.d_model
1252
+
1253
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1254
+
1255
+ encoder_config = copy.deepcopy(config)
1256
+ encoder_config.is_decoder = False
1257
+ encoder_config.use_cache = False
1258
+ encoder_config.is_encoder_decoder = False
1259
+ self.encoder = T5Stack(encoder_config, self.shared)
1260
+
1261
+ decoder_config = copy.deepcopy(config)
1262
+ decoder_config.is_decoder = True
1263
+ decoder_config.is_encoder_decoder = False
1264
+ decoder_config.num_layers = config.num_decoder_layers
1265
+ self.decoder = T5Stack(decoder_config, self.shared)
1266
+
1267
+
1268
+ self.conv_block = MultivariateConvBlock(config.num_seqs, config.d_model, num_filters=config.num_filters)
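+ # Conv block mixes information across the num_seqs series (the channel dim of its first
+ # convolution) before the hidden states are projected to the vocabulary by lm_head.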
1269
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1270
+
1271
+ # Initialize weights and apply final processing
1272
+ self.post_init()
1273
+
1274
+ # Model parallel
1275
+ self.model_parallel = False
1276
+ self.device_map = None
1277
+
1278
+
1279
+ def parallelize(self, device_map=None):
1280
+ warnings.warn(
1281
+ "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
1282
+ " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
1283
+ " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1284
+ " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
1285
+ FutureWarning,
1286
+ )
1287
+ self.device_map = (
1288
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1289
+ if device_map is None
1290
+ else device_map
1291
+ )
1292
+ assert_device_map(self.device_map, len(self.encoder.block))
1293
+ self.encoder.parallelize(self.device_map)
1294
+ self.decoder.parallelize(self.device_map)
1295
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1296
+ self.model_parallel = True
1297
+
1298
+
1299
+ def deparallelize(self):
1300
+ warnings.warn(
1301
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1302
+ FutureWarning,
1303
+ )
1304
+ self.encoder.deparallelize()
1305
+ self.decoder.deparallelize()
1306
+ self.encoder = self.encoder.to("cpu")
1307
+ self.decoder = self.decoder.to("cpu")
1308
+ self.lm_head = self.lm_head.to("cpu")
1309
+ self.model_parallel = False
1310
+ self.device_map = None
1311
+ torch.cuda.empty_cache()
1312
+
1313
+ def get_input_embeddings(self):
1314
+ return self.shared
1315
+
1316
+ def set_input_embeddings(self, new_embeddings):
1317
+ self.shared = new_embeddings
1318
+ self.encoder.set_input_embeddings(new_embeddings)
1319
+ self.decoder.set_input_embeddings(new_embeddings)
1320
+
1321
+ def _tie_weights(self):
1322
+ if self.config.tie_word_embeddings:
1323
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1324
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1325
+
1326
+ def set_output_embeddings(self, new_embeddings):
1327
+ self.lm_head = new_embeddings
1328
+
1329
+ def get_output_embeddings(self):
1330
+ return self.lm_head
1331
+
1332
+ def get_encoder(self):
1333
+ return self.encoder
1334
+
1335
+ def get_decoder(self):
1336
+ return self.decoder
1337
+
1338
+ def forward(
1339
+ self,
1340
+ input_ids: Optional[torch.LongTensor] = None,
1341
+ attention_mask: Optional[torch.FloatTensor] = None,
1342
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1343
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1344
+ head_mask: Optional[torch.FloatTensor] = None,
1345
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1346
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1347
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1348
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1349
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1350
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1351
+ labels: Optional[torch.LongTensor] = None,
1352
+ use_cache: Optional[bool] = None,
1353
+ output_attentions: Optional[bool] = None,
1354
+ output_hidden_states: Optional[bool] = None,
1355
+ return_dict: Optional[bool] = None,
1356
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1357
+ r"""
1358
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1359
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
1360
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1361
+ labels in `[0, ..., config.vocab_size]`
1362
+
1363
+ Returns:
1364
+
1365
+ Examples:
1366
+
1367
+ ```python
1368
+ >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
1369
+
1370
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1371
+ >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
1372
+
1373
+ >>> # training
1374
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1375
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1376
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1377
+ >>> loss = outputs.loss
1378
+ >>> logits = outputs.logits
1379
+
1380
+ >>> # inference
1381
+ >>> input_ids = tokenizer(
1382
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1383
+ ... ).input_ids # Batch size 1
1384
+ >>> outputs = model.generate(input_ids)
1385
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1386
+ >>> # studies have shown that owning a dog is good for you.
1387
+ ```"""
1388
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1389
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1390
+
1391
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1392
+ if head_mask is not None and decoder_head_mask is None:
1393
+ if self.config.num_layers == self.config.num_decoder_layers:
1394
+ decoder_head_mask = head_mask
1395
+
1396
+ # Encode if needed (training, first prediction pass)
1397
+ if encoder_outputs is None:
1398
+ # Convert encoder inputs in embeddings if needed
1399
+ encoder_outputs = self.encoder(
1400
+ input_ids=input_ids,
1401
+ attention_mask=attention_mask,
1402
+ inputs_embeds=inputs_embeds,
1403
+ head_mask=head_mask,
1404
+ output_attentions=output_attentions,
1405
+ output_hidden_states=output_hidden_states,
1406
+ return_dict=return_dict,
1407
+ )
1408
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1409
+ encoder_outputs = BaseModelOutput(
1410
+ last_hidden_state=encoder_outputs[0],
1411
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1412
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1413
+ )
1414
+
1415
+ hidden_states = encoder_outputs[0]
1416
+
1417
+ if self.model_parallel:
1418
+ torch.cuda.set_device(self.decoder.first_device)
1419
+
1420
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1421
+ # get decoder inputs from shifting lm labels to the right
1422
+ decoder_input_ids = self._shift_right(labels)
1423
+
1424
+ # Set device for model parallelism
1425
+ if self.model_parallel:
1426
+ torch.cuda.set_device(self.decoder.first_device)
1427
+ hidden_states = hidden_states.to(self.decoder.first_device)
1428
+ if decoder_input_ids is not None:
1429
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1430
+ if attention_mask is not None:
1431
+ attention_mask = attention_mask.to(self.decoder.first_device)
1432
+ if decoder_attention_mask is not None:
1433
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1434
+
1435
+ # Decode
1436
+ decoder_outputs = self.decoder(
1437
+ input_ids=decoder_input_ids,
1438
+ attention_mask=decoder_attention_mask,
1439
+ inputs_embeds=decoder_inputs_embeds,
1440
+ past_key_values=past_key_values,
1441
+ encoder_hidden_states=hidden_states,
1442
+ encoder_attention_mask=attention_mask,
1443
+ head_mask=decoder_head_mask,
1444
+ cross_attn_head_mask=cross_attn_head_mask,
1445
+ use_cache=use_cache,
1446
+ output_attentions=output_attentions,
1447
+ output_hidden_states=output_hidden_states,
1448
+ return_dict=return_dict,
1449
+ )
1450
+
1451
+ sequence_output = decoder_outputs[0]
1452
+
1453
+ # Set device for model parallelism
1454
+ if self.model_parallel:
1455
+ torch.cuda.set_device(self.encoder.first_device)
1456
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1457
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1458
+
1459
+ if self.config.tie_word_embeddings:
1460
+ # Rescale output before projecting on vocab
1461
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1462
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1463
+
1464
+ sequence_output = self.conv_block(sequence_output)
1465
+ lm_logits = self.lm_head(sequence_output)
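+ # In the multivariate case: sequence_output [batch_size, num_seqs, seq_len, d_model]
+ # -> lm_logits [batch_size, num_seqs, seq_len, vocab_size]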
1466
+
1467
+ loss = None
1468
+ if labels is not None:
1469
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1470
+ # move labels to correct device to enable PP
1471
+ labels = labels.to(lm_logits.device)
1472
+ if len(labels.shape) == 2:
1473
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1474
+ else:
1475
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))
1476
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1477
+
1478
+ if not return_dict:
1479
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1480
+ return ((loss,) + output) if loss is not None else output
1481
+
1482
+ return Seq2SeqLMOutput(
1483
+ loss=loss,
1484
+ logits=lm_logits,
1485
+ past_key_values=decoder_outputs.past_key_values,
1486
+ decoder_hidden_states=decoder_outputs.hidden_states,
1487
+ decoder_attentions=decoder_outputs.attentions,
1488
+ cross_attentions=decoder_outputs.cross_attentions,
1489
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1490
+ encoder_hidden_states=encoder_outputs.hidden_states,
1491
+ encoder_attentions=encoder_outputs.attentions,
1492
+ )
1493
+
1494
+ def prepare_inputs_for_generation(
1495
+ self,
1496
+ input_ids,
1497
+ past_key_values=None,
1498
+ attention_mask=None,
1499
+ head_mask=None,
1500
+ decoder_head_mask=None,
1501
+ decoder_attention_mask=None,
1502
+ cross_attn_head_mask=None,
1503
+ use_cache=None,
1504
+ encoder_outputs=None,
1505
+ **kwargs,
1506
+ ):
1507
+ # cut decoder_input_ids if past_key_values is used
1508
+ if past_key_values is not None:
1509
+ past_length = past_key_values[0][0].shape[2]
1510
+
1511
+ # Some generation methods already pass only the last input ID
1512
+ if input_ids.shape[1] > past_length:
1513
+ remove_prefix_length = past_length
1514
+ else:
1515
+ # Default to old behavior: keep only final ID
1516
+ remove_prefix_length = input_ids.shape[1] - 1
1517
+
1518
+ input_ids = input_ids[:, remove_prefix_length:]
1519
+
1520
+ return {
1521
+ "decoder_input_ids": input_ids,
1522
+ "past_key_values": past_key_values,
1523
+ "encoder_outputs": encoder_outputs,
1524
+ "attention_mask": attention_mask,
1525
+ "head_mask": head_mask,
1526
+ "decoder_head_mask": decoder_head_mask,
1527
+ "decoder_attention_mask": decoder_attention_mask,
1528
+ "cross_attn_head_mask": cross_attn_head_mask,
1529
+ "use_cache": use_cache,
1530
+ }
1531
+
1532
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1533
+ return self._shift_right(labels)
1534
+
1535
+ def _reorder_cache(self, past_key_values, beam_idx):
1536
+ # if decoder past is not included in output
1537
+ # speedy decoding is disabled and no need to reorder
1538
+ if past_key_values is None:
1539
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1540
+ return past_key_values
1541
+
1542
+ reordered_decoder_past = ()
1543
+ for layer_past_states in past_key_values:
1544
+ # get the correct batch idx from layer past batch dim
1545
+ # batch dim of `past` is at 2nd position
1546
+ reordered_layer_past_states = ()
1547
+ for layer_past_state in layer_past_states:
1548
+ # need to set correct `past` for each of the four key / value states
1549
+ reordered_layer_past_states = reordered_layer_past_states + (
1550
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1551
+ )
1552
+
1553
+ if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
1554
+ raise ValueError(
1555
+ f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
1556
+ )
1557
+ if len(reordered_layer_past_states) != len(layer_past_states):
1558
+ raise ValueError(
1559
+ f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
1560
+ )
1561
+
1562
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1563
+ return reordered_decoder_past
1564
+
1565
+
1566
+
1567
+ class T5MIMOEncoderModel(T5PreTrainedModel):
1568
+ _tied_weights_keys = ["encoder.embed_tokens.weight"]
1569
+ _keys_to_ignore_on_load_unexpected = [r"decoder"]
1570
+
1571
+ def __init__(self, config: T5MIMOConfig):
1572
+ super().__init__(config)
1573
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1574
+
1575
+ encoder_config = copy.deepcopy(config)
1576
+ encoder_config.use_cache = False
1577
+ encoder_config.is_encoder_decoder = False
1578
+ self.encoder = T5Stack(encoder_config, self.shared)
1579
+
1580
+ # Initialize weights and apply final processing
1581
+ self.post_init()
1582
+
1583
+ # Model parallel
1584
+ self.model_parallel = False
1585
+ self.device_map = None
1586
+
1587
+ def parallelize(self, device_map=None):
1588
+ warnings.warn(
1589
+ "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1590
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1591
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
1592
+ " 'block.1': 1, ...}",
1593
+ FutureWarning,
1594
+ )
1595
+ self.device_map = (
1596
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1597
+ if device_map is None
1598
+ else device_map
1599
+ )
1600
+ assert_device_map(self.device_map, len(self.encoder.block))
1601
+ self.encoder.parallelize(self.device_map)
1602
+ self.model_parallel = True
1603
+
1604
+ def deparallelize(self):
1605
+ warnings.warn(
1606
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1607
+ FutureWarning,
1608
+ )
1609
+ self.encoder.deparallelize()
1610
+ self.encoder = self.encoder.to("cpu")
1611
+ self.model_parallel = False
1612
+ self.device_map = None
1613
+ torch.cuda.empty_cache()
1614
+
1615
+ def get_input_embeddings(self):
1616
+ return self.shared
1617
+
1618
+ def set_input_embeddings(self, new_embeddings):
1619
+ self.shared = new_embeddings
1620
+ self.encoder.set_input_embeddings(new_embeddings)
1621
+
1622
+ def _tie_weights(self):
1623
+ if self.config.tie_word_embeddings:
1624
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1625
+
1626
+ def get_encoder(self):
1627
+ return self.encoder
1628
+
1629
+ def _prune_heads(self, heads_to_prune):
1630
+ """
1631
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1632
+ class PreTrainedModel
1633
+ """
1634
+ for layer, heads in heads_to_prune.items():
1635
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1636
+
1637
+ def forward(
1638
+ self,
1639
+ input_ids: Optional[torch.LongTensor] = None,
1640
+ attention_mask: Optional[torch.FloatTensor] = None,
1641
+ head_mask: Optional[torch.FloatTensor] = None,
1642
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1643
+ output_attentions: Optional[bool] = None,
1644
+ output_hidden_states: Optional[bool] = None,
1645
+ return_dict: Optional[bool] = None,
1646
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
1647
+ r"""
1648
+ Returns:
1649
+
1650
+ Example:
1651
+
1652
+ ```python
1653
+ >>> from transformers import AutoTokenizer, T5EncoderModel
1654
+
1655
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1656
+ >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
1657
+ >>> input_ids = tokenizer(
1658
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1659
+ ... ).input_ids # Batch size 1
1660
+ >>> outputs = model(input_ids=input_ids)
1661
+ >>> last_hidden_states = outputs.last_hidden_state
1662
+ ```"""
1663
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1664
+
1665
+ encoder_outputs = self.encoder(
1666
+ input_ids=input_ids,
1667
+ attention_mask=attention_mask,
1668
+ inputs_embeds=inputs_embeds,
1669
+ head_mask=head_mask,
1670
+ output_attentions=output_attentions,
1671
+ output_hidden_states=output_hidden_states,
1672
+ return_dict=return_dict,
1673
+ )
1674
+
1675
+ return encoder_outputs
1676
+
1677
+
1678
+
1679
+
1680
+ class MultivariateConvBlock(nn.Module):
1681
+ def __init__(self, num_seqs, model_dim, kernel_size=3, num_filters=64, stride=1, padding=1):
1682
+ """
1683
+ Multivariate convolutional block to capture cross-sequence interactions and temporal patterns.
1684
+
1685
+ Args:
1686
+ num_seqs (int): Number of sequences (multivariate time series).
1687
+ model_dim (int): Dimension of each feature vector (typically 256).
1688
+ kernel_size (int): Size of the convolutional kernel. Default is 3.
1689
+ num_filters (int): Number of convolutional filters (output channels). Default is 64.
1690
+ stride (int): Stride of the convolutional kernel. Default is 1.
1691
+ padding (int): Padding for the convolutional kernel. Default is 1 (to preserve sequence length).
1692
+ """
1693
+ super(MultivariateConvBlock, self).__init__()
1694
+
1695
+
1696
+ # 2D Convolution across sequences and time
1697
+ self.conv1 = nn.Conv2d(
1698
+ in_channels=num_seqs,
1699
+ out_channels=num_filters,
1700
+ kernel_size=kernel_size, # Square kernel spanning neighbouring feature dims and time steps
1701
+ stride=1, # Unit stride over both the feature and time axes
1702
+ padding=1 # Fixed padding of 1 on both axes to preserve the spatial size (the stride/padding constructor args are only applied in conv2)
1703
+ )
1704
+
1705
+ # Batch normalization for stabilization and faster convergence
1706
+ self.bn1 = nn.BatchNorm2d(num_filters)
1707
+
1708
+ # Second convolution layer to further mix the filtered representations
1709
+ self.conv2 = nn.Conv2d(
1710
+ in_channels=num_filters,
1711
+ out_channels=num_filters,
1712
+ kernel_size=(kernel_size, 1), # Kernel spans the feature (model_dim) axis only, given the [batch, filters, model_dim, seq_len] layout
1713
+ stride=(stride, 1),
1714
+ padding=(padding, 0)
1715
+ )
1716
+
1717
+ # Batch normalization after second convolution
1718
+ self.bn2 = nn.BatchNorm2d(num_filters)
1719
+
1720
+ # 1x1 Convolution to reduce the channel dimension back to num_seqs
1721
+ self.conv3 = nn.Conv2d(
1722
+ in_channels=num_filters,
1723
+ out_channels=num_seqs, # Back to the original number of sequences (channels)
1724
+ kernel_size=(1, 1)
1725
+ )
1726
+
1727
+ def forward(self, x):
1728
+ """
1729
+ Forward pass of the multivariate convolutional block.
1730
+
1731
+ Args:
1732
+ x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].
1733
+
1734
+ Returns:
1735
+ torch.Tensor: Output tensor of shape [batch_size, num_seqs, seq_len, model_dim].
1736
+ """
1737
+ # Permute from [batch_size, num_seqs, seq_len, model_dim] to [batch_size, num_seqs, model_dim, seq_len]
1738
+ x = x.permute(0, 1, 3, 2)
1739
+
1740
+ # Apply first convolution and activation
1741
+ x = nn.functional.relu(self.bn1(self.conv1(x)))
1742
+ # Apply second convolution and activation
1743
+ x = nn.functional.relu(self.bn2(self.conv2(x)))
1744
+
1745
+ # Reduce channel dimension back to num_seqs
1746
+ x = self.conv3(x)
1747
+
1748
+ # Permute back to original shape [batch_size, num_seqs, seq_len, model_dim]
1749
+ x = x.permute(0, 1, 3, 2)
1750
+
1751
+ return x
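+ 
+ 
+ # Minimal smoke-test sketch, assuming T5MIMOConfig mirrors the upstream T5Config
+ # arguments and additionally accepts num_seqs / num_filters (the attributes used above);
+ # sizes and values below are illustrative only.
+ if __name__ == "__main__":
+     config = T5MIMOConfig(vocab_size=128, d_model=64, d_kv=16, d_ff=128, num_layers=2, num_decoder_layers=2, num_heads=4, num_seqs=3, num_filters=8)
+     model = T5MIMOForConditionalGeneration(config)
+     input_ids = torch.randint(0, config.vocab_size, (2, 3, 16))  # [batch, num_seqs, seq_len]
+     labels = torch.randint(0, config.vocab_size, (2, 3, 16))
+     out = model(input_ids=input_ids, labels=labels)
+     print(out.loss.item(), out.logits.shape)  # expected logits shape: [2, 3, 16, 128]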