ammarnasr committed
Commit
e9fcc04
1 Parent(s): a81a2e3

Upload T5MIMOForConditionalGeneration

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
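
The card's "How to Get Started" section is still a placeholder. Since config.json below maps `AutoModelForSeq2SeqLM` to `modeling_t5mimo.T5MIMOForConditionalGeneration`, a minimal, untested loading sketch would look like the following; the repository id is a placeholder, no tokenizer is shipped in this commit, and the 3-D input shape is inferred from the `is_mimo=True` branch of the modeling code.

```python
# Hedged sketch only: repo_id is a placeholder, and the shapes follow the
# is_mimo=True code path, where input_ids are (batch, num_seqs, seq_len).
import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM

repo_id = "<this-repository-id>"  # placeholder; replace with the actual Hub id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)

batch, src_len, tgt_len = 2, 16, 8
input_ids = torch.randint(0, config.vocab_size, (batch, config.num_seqs, src_len))
labels = torch.randint(0, config.vocab_size, (batch, config.num_seqs, tgt_len))

outputs = model(input_ids=input_ids, labels=labels)  # decoder inputs are shifted internally
print(outputs.loss, outputs.logits.shape)
```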
config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "T5MIMOForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_t5mimo.T5MIMOConfig",
+     "AutoModel": "modeling_t5mimo.T5MIMOModel",
+     "AutoModelForSeq2SeqLM": "modeling_t5mimo.T5MIMOForConditionalGeneration"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 1024,
+   "d_kv": 64,
+   "d_model": 256,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 0.05,
+   "is_encoder_decoder": true,
+   "is_gated_act": false,
+   "is_mimo": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5mimo",
+   "num_decoder_layers": 4,
+   "num_filters": 64,
+   "num_heads": 4,
+   "num_layers": 4,
+   "num_seqs": 3,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.1",
+   "use_cache": true,
+   "vocab_size": 4096
+ }
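
For orientation: the attention projection width implied by these values is `num_heads * d_kv = 4 * 64 = 256`, which equals `d_model`, and `d_ff = 1024` is the usual 4x feed-forward expansion; `num_seqs = 3` is the number of parallel sequences the MIMO-specific blocks expect. A dependency-free sanity check of that arithmetic:

```python
# Plain-Python check of the relationships between the config values above.
cfg = {"d_model": 256, "d_kv": 64, "num_heads": 4, "d_ff": 1024, "num_seqs": 3}
inner_dim = cfg["num_heads"] * cfg["d_kv"]   # width of the q/k/v/o projections
assert inner_dim == cfg["d_model"] == 256
assert cfg["d_ff"] == 4 * cfg["d_model"]     # feed-forward expansion factor
print(inner_dim, cfg["d_ff"], cfg["num_seqs"])  # 256 1024 3
```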
configuration_t5mimo.py ADDED
@@ -0,0 +1,154 @@
+ from typing import Mapping
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.onnx import OnnxSeq2SeqConfigWithPast
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class T5MIMOConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
+     instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the T5
+     [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Arguments:
+         vocab_size (`int`, *optional*, defaults to 32128):
+             Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
+         d_model (`int`, *optional*, defaults to 512):
+             Size of the encoder layers and the pooler layer.
+         d_kv (`int`, *optional*, defaults to 64):
+             Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
+             be defined as `num_heads * d_kv`.
+         d_ff (`int`, *optional*, defaults to 2048):
+             Size of the intermediate feed forward layer in each `T5Block`.
+         num_layers (`int`, *optional*, defaults to 6):
+             Number of hidden layers in the Transformer encoder.
+         num_decoder_layers (`int`, *optional*):
+             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+         num_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+             The number of buckets to use for each attention layer.
+         relative_attention_max_distance (`int`, *optional*, defaults to 128):
+             The maximum distance of the longer sequences for the bucket separation.
+         dropout_rate (`float`, *optional*, defaults to 0.1):
+             The ratio for all dropout layers.
+         classifier_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for classifier.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+             The epsilon used by the layer normalization layers.
+         initializer_factor (`float`, *optional*, defaults to 1):
+             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+             testing).
+         feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
+             `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+     """
+
+     model_type = "t5mimo"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+     def __init__(
+         self,
+         vocab_size=32128,
+         d_model=512,
+         d_kv=64,
+         d_ff=2048,
+         num_layers=6,
+         num_decoder_layers=None,
+         num_heads=8,
+         relative_attention_num_buckets=32,
+         relative_attention_max_distance=128,
+         dropout_rate=0.1,
+         layer_norm_epsilon=1e-6,
+         initializer_factor=1.0,
+         feed_forward_proj="relu",
+         is_encoder_decoder=True,
+         use_cache=True,
+         pad_token_id=0,
+         eos_token_id=1,
+         decoder_start_token_id=0,
+         classifier_dropout=0.0,
+         num_seqs=3,
+         num_filters=64,
+         is_mimo=True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.d_kv = d_kv
+         self.d_ff = d_ff
+         self.num_layers = num_layers
+         self.num_decoder_layers = (
+             num_decoder_layers if num_decoder_layers is not None else self.num_layers
+         )  # default = symmetry
+         self.num_heads = num_heads
+         self.relative_attention_num_buckets = relative_attention_num_buckets
+         self.relative_attention_max_distance = relative_attention_max_distance
+         self.dropout_rate = dropout_rate
+         self.classifier_dropout = classifier_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_factor = initializer_factor
+         self.feed_forward_proj = feed_forward_proj
+         self.use_cache = use_cache
+         self.num_seqs = num_seqs
+         self.num_filters = num_filters
+         self.is_mimo = is_mimo
+
+         act_info = self.feed_forward_proj.split("-")
+         self.dense_act_fn = act_info[-1]
+         self.is_gated_act = act_info[0] == "gated"
+
+         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+             raise ValueError(
+                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
+                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                 "'gated-gelu' or 'relu'"
+             )
+
+         # for backwards compatibility
+         if feed_forward_proj == "gated-gelu":
+             self.dense_act_fn = "gelu_new"
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             decoder_start_token_id=decoder_start_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             **kwargs,
+         )
+
+
+ class T5MIMOOnnxConfig(OnnxSeq2SeqConfigWithPast):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         common_inputs = {
+             "input_ids": {0: "batch", 1: "encoder_sequence"},
+             "attention_mask": {0: "batch", 1: "encoder_sequence"},
+         }
+         if self.use_past:
+             common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+             common_inputs["decoder_input_ids"] = {0: "batch"}
+             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+         else:
+             common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+         if self.use_past:
+             self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+         return common_inputs
+
+     @property
+     def default_onnx_opset(self) -> int:
+         return 13
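
The constructor above derives `dense_act_fn` and `is_gated_act` by splitting `feed_forward_proj` on `-`, remaps `"gated-gelu"` to `"gelu_new"` for backwards compatibility, and defaults `num_decoder_layers` to `num_layers`. A short sketch of that behaviour, assuming `configuration_t5mimo.py` is on the import path:

```python
# Sketch: how T5MIMOConfig parses feed_forward_proj and aliases attribute names.
from configuration_t5mimo import T5MIMOConfig

cfg = T5MIMOConfig(feed_forward_proj="gated-gelu", num_seqs=3, is_mimo=True)
assert cfg.is_gated_act is True
assert cfg.dense_act_fn == "gelu_new"            # backwards-compatibility remap
assert cfg.num_decoder_layers == cfg.num_layers  # defaults to a symmetric decoder
assert cfg.hidden_size == cfg.d_model            # via attribute_map
print(cfg.dense_act_fn, cfg.num_seqs)
```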
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.41.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9d97da94a794f0b0aad1566b9e13205267ea4b3b70ae8c6cd147e6fe6e651cb
+ size 33588312
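
The LFS pointer records a 33,588,312-byte checkpoint; at 4 bytes per parameter (`torch_dtype` is `float32` in config.json) that is roughly 8.4M parameters, which fits the small `d_model=256`, 4+4-layer configuration above. A back-of-the-envelope check:

```python
# Rough parameter count from the safetensors size (float32 = 4 bytes/param);
# ignores the small safetensors header, so this is an approximation.
size_bytes = 33_588_312
print(f"~{size_bytes / 4 / 1e6:.1f}M parameters")  # ~8.4M
```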
modeling_t5mimo.py ADDED
@@ -0,0 +1,1753 @@
1
+ import copy
2
+ import math
3
+ import warnings
4
+ from typing import Optional, Tuple, Union
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.modeling_outputs import (
10
+ BaseModelOutput,
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ Seq2SeqLMOutput,
13
+ Seq2SeqModelOutput,
14
+ )
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
17
+ from transformers.utils import (
18
+ DUMMY_INPUTS,
19
+ DUMMY_MASK,
20
+ is_torch_fx_proxy,
21
+ logging,
22
+ )
23
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
24
+ from .configuration_t5mimo import T5MIMOConfig
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+
31
+ class T5LayerNorm(nn.Module):
32
+ def __init__(self, hidden_size, eps=1e-6):
33
+ """
34
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
35
+ """
36
+ super().__init__()
37
+ self.weight = nn.Parameter(torch.ones(hidden_size))
38
+ self.variance_epsilon = eps
39
+
40
+ def forward(self, hidden_states):
41
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
42
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
43
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
44
+ # half-precision inputs is done in fp32
45
+
46
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
47
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
48
+
49
+ # convert into half-precision if necessary
50
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
51
+ hidden_states = hidden_states.to(self.weight.dtype)
52
+
53
+ return self.weight * hidden_states
54
+
55
+
56
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
57
+
58
+
59
+ class T5DenseActDense(nn.Module):
60
+ def __init__(self, config: T5MIMOConfig):
61
+ super().__init__()
62
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
63
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
64
+ self.dropout = nn.Dropout(config.dropout_rate)
65
+ self.act = ACT2FN[config.dense_act_fn]
66
+
67
+ def forward(self, hidden_states):
68
+ hidden_states = self.wi(hidden_states)
69
+ hidden_states = self.act(hidden_states)
70
+ hidden_states = self.dropout(hidden_states)
71
+ if (
72
+ isinstance(self.wo.weight, torch.Tensor)
73
+ and hidden_states.dtype != self.wo.weight.dtype
74
+ and self.wo.weight.dtype != torch.int8
75
+ ):
76
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
77
+ hidden_states = self.wo(hidden_states)
78
+ return hidden_states
79
+
80
+
81
+ class T5DenseGatedActDense(nn.Module):
82
+ def __init__(self, config: T5MIMOConfig):
83
+ super().__init__()
84
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
85
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
86
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
87
+ self.dropout = nn.Dropout(config.dropout_rate)
88
+ self.act = ACT2FN[config.dense_act_fn]
89
+
90
+ def forward(self, hidden_states):
91
+ hidden_gelu = self.act(self.wi_0(hidden_states))
92
+ hidden_linear = self.wi_1(hidden_states)
93
+ hidden_states = hidden_gelu * hidden_linear
94
+ hidden_states = self.dropout(hidden_states)
95
+
96
+ # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
97
+ # See https://github.com/huggingface/transformers/issues/20287
98
+ # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
99
+ if (
100
+ isinstance(self.wo.weight, torch.Tensor)
101
+ and hidden_states.dtype != self.wo.weight.dtype
102
+ and self.wo.weight.dtype != torch.int8
103
+ ):
104
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
105
+
106
+ hidden_states = self.wo(hidden_states)
107
+ return hidden_states
108
+
109
+
110
+ class T5LayerFF(nn.Module):
111
+ def __init__(self, config: T5MIMOConfig):
112
+ super().__init__()
113
+ if config.is_gated_act:
114
+ self.DenseReluDense = T5DenseGatedActDense(config)
115
+ else:
116
+ self.DenseReluDense = T5DenseActDense(config)
117
+
118
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
119
+ self.dropout = nn.Dropout(config.dropout_rate)
120
+
121
+ def forward(self, hidden_states):
122
+ forwarded_states = self.layer_norm(hidden_states)
123
+ forwarded_states = self.DenseReluDense(forwarded_states)
124
+ hidden_states = hidden_states + self.dropout(forwarded_states)
125
+ return hidden_states
126
+
127
+
128
+
129
+ class MultivariateConvBlock(nn.Module):
130
+ def __init__(self, config: T5MIMOConfig, kernel_size=3, stride=1, padding=1):
131
+ super().__init__()
132
+ # 2D Convolution across sequences and time
133
+ self.conv1 = nn.Conv2d(
134
+ in_channels=config.num_seqs,
135
+ out_channels=config.num_filters,
136
+ kernel_size=kernel_size, # Kernel spans across time and all features
137
+ stride=1, # Stride across time, no stride across features
138
+ padding=1 # Padding to preserve sequence length, no padding across features
139
+ )
140
+
141
+ # Batch normalization for stabilization and faster convergence
142
+ self.bn1 = nn.BatchNorm2d(config.num_filters)
143
+
144
+ # Second convolution layer to further model interactions and temporal patterns
145
+ self.conv2 = nn.Conv2d(
146
+ in_channels=config.num_filters,
147
+ out_channels=config.num_filters,
148
+ kernel_size=(kernel_size, 1), # Focus only on temporal patterns
149
+ stride=(stride, 1),
150
+ padding=(padding, 0)
151
+ )
152
+
153
+ # Batch normalization after second convolution
154
+ self.bn2 = nn.BatchNorm2d(config.num_filters)
155
+
156
+ # 1x1 Convolution to reduce the channel dimension back to num_seqs
157
+ self.conv3 = nn.Conv2d(
158
+ in_channels=config.num_filters,
159
+ out_channels=config.num_seqs, # Back to the original number of sequences (channels)
160
+ kernel_size=(1, 1)
161
+ )
162
+
163
+ def forward(self, x):
164
+ """
165
+ Forward pass of the multivariate convolutional block.
166
+
167
+ Args:
168
+ x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].
169
+
170
+ Returns:
171
+ torch.Tensor: Output tensor of shape [batch_size, num_seqs, seq_len, model_dim].
172
+ """
173
+ # Permute to [batch_size, num_seqs, seq_len, model_dim] -> [batch_size, num_seqs, model_dim, seq_len]
174
+ x = x.permute(0, 1, 3, 2)
175
+
176
+ # Apply first convolution and activation
177
+ x = nn.functional.relu(self.bn1(self.conv1(x)))
178
+ # Apply second convolution and activation
179
+ x = nn.functional.relu(self.bn2(self.conv2(x)))
180
+
181
+ # Reduce channel dimension back to num_seqs
182
+ x = self.conv3(x)
183
+
184
+ # Permute back to original shape [batch_size, num_seqs, seq_len, model_dim]
185
+ x = x.permute(0, 1, 3, 2)
186
+
187
+ return x
188
+
189
+
190
+
191
+ class T5Attention(nn.Module):
192
+ def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
193
+ super().__init__()
194
+ self.is_decoder = config.is_decoder
195
+ self.has_relative_attention_bias = has_relative_attention_bias
196
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
197
+ self.relative_attention_max_distance = config.relative_attention_max_distance
198
+ self.d_model = config.d_model
199
+ self.key_value_proj_dim = config.d_kv
200
+ self.n_heads = config.num_heads
201
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
202
+ self.dropout = config.dropout_rate
203
+ self.config = config
204
+
205
+ # Mesh TensorFlow initialization to avoid scaling before softmax
206
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
207
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
208
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
209
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
210
+
211
+ if self.has_relative_attention_bias:
212
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
213
+ self.pruned_heads = set()
214
+ self.gradient_checkpointing = False
215
+
216
+ def prune_heads(self, heads):
217
+ if len(heads) == 0:
218
+ return
219
+ heads, index = find_pruneable_heads_and_indices(
220
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
221
+ )
222
+ # Prune linear layers
223
+ self.q = prune_linear_layer(self.q, index)
224
+ self.k = prune_linear_layer(self.k, index)
225
+ self.v = prune_linear_layer(self.v, index)
226
+ self.o = prune_linear_layer(self.o, index, dim=1)
227
+ # Update hyper params
228
+ self.n_heads = self.n_heads - len(heads)
229
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
230
+ self.pruned_heads = self.pruned_heads.union(heads)
231
+
232
+ @staticmethod
233
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
234
+ """
235
+ Adapted from Mesh Tensorflow:
236
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
237
+
238
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
239
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
240
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
241
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
242
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
243
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
244
+
245
+ Args:
246
+ relative_position: an int32 Tensor
247
+ bidirectional: a boolean - whether the attention is bidirectional
248
+ num_buckets: an integer
249
+ max_distance: an integer
250
+
251
+ Returns:
252
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
253
+ """
254
+ relative_buckets = 0
255
+ if bidirectional:
256
+ num_buckets //= 2
257
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
258
+ relative_position = torch.abs(relative_position)
259
+ else:
260
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
261
+ # now relative_position is in the range [0, inf)
262
+
263
+ # half of the buckets are for exact increments in positions
264
+ max_exact = num_buckets // 2
265
+ is_small = relative_position < max_exact
266
+
267
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
268
+ relative_position_if_large = max_exact + (
269
+ torch.log(relative_position.float() / max_exact)
270
+ / math.log(max_distance / max_exact)
271
+ * (num_buckets - max_exact)
272
+ ).to(torch.long)
273
+ relative_position_if_large = torch.min(
274
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
275
+ )
276
+
277
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
278
+ return relative_buckets
279
+
280
+ def compute_bias(self, query_length, key_length, device=None, multivar_dim=None):
281
+ """Compute binned relative position bias"""
282
+ if device is None:
283
+ device = self.relative_attention_bias.weight.device
284
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
285
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
286
+ relative_position = memory_position - context_position # shape (query_length, key_length)
287
+ relative_position_bucket = self._relative_position_bucket(
288
+ relative_position, # shape (query_length, key_length)
289
+ bidirectional=(not self.is_decoder),
290
+ num_buckets=self.relative_attention_num_buckets,
291
+ max_distance=self.relative_attention_max_distance,
292
+ )
293
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
294
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
295
+ if self.config.is_mimo:
296
+ if multivar_dim is None:
297
+ raise ValueError("multivar_dim cannot be None when config.is_mimo=True")
298
+ values = values.unsqueeze(0)# shape (1, 1, num_heads, query_length, key_length)
299
+ values = values.repeat(1, multivar_dim, 1, 1, 1) # shape (1, multivar_dim, num_heads, query_length, key_length)
300
+ return values
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states,
305
+ mask=None,
306
+ key_value_states=None,
307
+ position_bias=None,
308
+ past_key_value=None,
309
+ layer_head_mask=None,
310
+ query_length=None,
311
+ use_cache=False,
312
+ output_attentions=False,
313
+ ):
314
+ """
315
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
316
+ """
317
+ # Input is (batch_size, seq_length, dim)
318
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
319
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
320
+
321
+ if self.config.is_mimo:
322
+ batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
323
+ else:
324
+ batch_size, seq_length = hidden_states.shape[:2]
325
+ multivar_dim=None
326
+ real_seq_length = seq_length
327
+
328
+ if past_key_value is not None:
329
+ if len(past_key_value) != 2:
330
+ raise ValueError(f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states")
331
+ if self.config.is_mimo:
332
+ real_seq_length += past_key_value[0].shape[3] if query_length is None else query_length
333
+ else:
334
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
335
+
336
+ if self.config.is_mimo:
337
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[2]
338
+ else:
339
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
340
+
341
+
342
+
343
+ def shape(states):
344
+ """projection"""
345
+ if self.config.is_mimo:
346
+ return states.view(batch_size, multivar_dim, -1, self.n_heads, self.key_value_proj_dim).transpose(2, 3)
347
+ else:
348
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
349
+
350
+
351
+ def unshape(states):
352
+ """reshape"""
353
+ if self.config.is_mimo:
354
+ return states.transpose(2, 3).contiguous().view(batch_size, multivar_dim, -1, self.inner_dim)
355
+ else:
356
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
357
+
358
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
359
+ """projects hidden states correctly to key/query states"""
360
+ if key_value_states is None:
361
+ # self-attn
362
+ # (batch_size, n_heads, seq_length, dim_per_head)
363
+ hidden_states = shape(proj_layer(hidden_states))
364
+ elif past_key_value is None:
365
+ # cross-attn
366
+ # (batch_size, n_heads, seq_length, dim_per_head)
367
+ hidden_states = shape(proj_layer(key_value_states))
368
+ if past_key_value is not None:
369
+ if key_value_states is None:
370
+ # self-attn
371
+ # (batch_size, n_heads, key_length, dim_per_head)
372
+ if self.config.is_mimo:
373
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=3)
374
+ else:
375
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
376
+ elif past_key_value.shape[2] != key_value_states.shape[1]:
377
+ # checking that the `sequence_length` of the `past_key_value` is the same as
378
+ # the provided `key_value_states` to support prefix tuning
379
+ # cross-attn
380
+ # (batch_size, n_heads, seq_length, dim_per_head)
381
+ hidden_states = shape(proj_layer(key_value_states))
382
+ else:
383
+ # cross-attn
384
+ hidden_states = past_key_value
385
+ return hidden_states
386
+
387
+ # get query states
388
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
389
+
390
+
391
+ # get key/value states
392
+ key_states = project(
393
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
394
+ )
395
+ value_states = project(
396
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
397
+ )
398
+
399
+
400
+
401
+ # compute scores
402
+ if self.config.is_mimo:
403
+ scores = torch.matmul(query_states, key_states.transpose(4, 3))
404
+ else:
405
+ scores = torch.matmul(query_states, key_states.transpose(3, 2)) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
406
+
407
+
408
+
409
+
410
+ if position_bias is None:
411
+ if not self.has_relative_attention_bias:
412
+ if self.config.is_mimo:
413
+ position_bias = torch.zeros((1,multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
414
+ else:
415
+ position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
416
+ if self.gradient_checkpointing and self.training:
417
+ position_bias.requires_grad = True
418
+ else:
419
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device, multivar_dim=multivar_dim)
420
+
421
+
422
+ # if key and values are already calculated
423
+ # we want only the last query position bias
424
+ if past_key_value is not None:
425
+ if self.config.is_mimo:
426
+ position_bias = position_bias[:, :, :, -hidden_states.size(2) :, :]
427
+ else:
428
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
429
+
430
+ if mask is not None:
431
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
432
+
433
+
434
+
435
+
436
+ if self.pruned_heads:
437
+ mask = torch.ones(position_bias.shape[1])
438
+ mask[list(self.pruned_heads)] = 0
439
+ position_bias_masked = position_bias[:, mask.bool()]
440
+ else:
441
+ position_bias_masked = position_bias
442
+
443
+
444
+ scores += position_bias_masked
445
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (batch_size, n_heads, seq_length, key_length)
446
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # (batch_size, n_heads, seq_length, key_length)
447
+
448
+ # Mask heads if we want to
449
+ if layer_head_mask is not None:
450
+ attn_weights = attn_weights * layer_head_mask
451
+
452
+
453
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
454
+ attn_output = self.o(attn_output)
455
+
456
+
457
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
458
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
459
+
460
+
461
+ if output_attentions:
462
+ outputs = outputs + (attn_weights,)
463
+
464
+ return outputs
465
+
466
+
467
+ class T5LayerSelfAttention(nn.Module):
468
+ def __init__(self, config, has_relative_attention_bias=False):
469
+ super().__init__()
470
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
471
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
472
+ self.dropout = nn.Dropout(config.dropout_rate)
473
+
474
+ def forward(
475
+ self,
476
+ hidden_states,
477
+ attention_mask=None,
478
+ position_bias=None,
479
+ layer_head_mask=None,
480
+ past_key_value=None,
481
+ use_cache=False,
482
+ output_attentions=False,
483
+ ):
484
+ normed_hidden_states = self.layer_norm(hidden_states)
485
+ attention_output = self.SelfAttention(
486
+ normed_hidden_states,
487
+ mask=attention_mask,
488
+ position_bias=position_bias,
489
+ layer_head_mask=layer_head_mask,
490
+ past_key_value=past_key_value,
491
+ use_cache=use_cache,
492
+ output_attentions=output_attentions,
493
+ )
494
+
495
+ hidden_states = hidden_states + self.dropout(attention_output[0])
496
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
497
+ return outputs
498
+
499
+
500
+ class T5LayerCrossAttention(nn.Module):
501
+ def __init__(self, config):
502
+ super().__init__()
503
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
504
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
505
+ self.dropout = nn.Dropout(config.dropout_rate)
506
+
507
+ def forward(
508
+ self,
509
+ hidden_states,
510
+ key_value_states,
511
+ attention_mask=None,
512
+ position_bias=None,
513
+ layer_head_mask=None,
514
+ past_key_value=None,
515
+ use_cache=False,
516
+ query_length=None,
517
+ output_attentions=False,
518
+ ):
519
+ normed_hidden_states = self.layer_norm(hidden_states)
520
+ attention_output = self.EncDecAttention(
521
+ normed_hidden_states,
522
+ mask=attention_mask,
523
+ key_value_states=key_value_states,
524
+ position_bias=position_bias,
525
+ layer_head_mask=layer_head_mask,
526
+ past_key_value=past_key_value,
527
+ use_cache=use_cache,
528
+ query_length=query_length,
529
+ output_attentions=output_attentions,
530
+ )
531
+ layer_output = hidden_states + self.dropout(attention_output[0])
532
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
533
+ return outputs
534
+
535
+
536
+ class T5Block(nn.Module):
537
+ def __init__(self, config, has_relative_attention_bias=False):
538
+ super().__init__()
539
+ self.is_decoder = config.is_decoder
540
+ self.layer = nn.ModuleList()
541
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
542
+ if self.is_decoder:
543
+ self.layer.append(T5LayerCrossAttention(config))
544
+
545
+ self.layer.append(T5LayerFF(config))
546
+
547
+ self.config = config
548
+
549
+ def forward(
550
+ self,
551
+ hidden_states,
552
+ attention_mask=None,
553
+ position_bias=None,
554
+ encoder_hidden_states=None,
555
+ encoder_attention_mask=None,
556
+ encoder_decoder_position_bias=None,
557
+ layer_head_mask=None,
558
+ cross_attn_layer_head_mask=None,
559
+ past_key_value=None,
560
+ use_cache=False,
561
+ output_attentions=False,
562
+ return_dict=True,
563
+ ):
564
+ if past_key_value is not None:
565
+ if not self.is_decoder:
566
+ logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
567
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
568
+
569
+ if len(past_key_value) != expected_num_past_key_values:
570
+ raise ValueError(
571
+ f"There should be {expected_num_past_key_values} past states. "
572
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
573
+ f"Got {len(past_key_value)} past key / value states"
574
+ )
575
+
576
+ self_attn_past_key_value = past_key_value[:2]
577
+ cross_attn_past_key_value = past_key_value[2:]
578
+ else:
579
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
580
+
581
+ self_attention_outputs = self.layer[0](
582
+ hidden_states,
583
+ attention_mask=attention_mask,
584
+ position_bias=position_bias,
585
+ layer_head_mask=layer_head_mask,
586
+ past_key_value=self_attn_past_key_value,
587
+ use_cache=use_cache,
588
+ output_attentions=output_attentions,
589
+ )
590
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
591
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
592
+
593
+ # clamp inf values to enable fp16 training
594
+ if hidden_states.dtype == torch.float16:
595
+ clamp_value = torch.where(
596
+ torch.isinf(hidden_states).any(),
597
+ torch.finfo(hidden_states.dtype).max - 1000,
598
+ torch.finfo(hidden_states.dtype).max,
599
+ )
600
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
601
+
602
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
603
+ if do_cross_attention:
604
+ # the actual query length is unknown for cross attention
605
+ # if using past key value states. Need to inject it here
606
+ if present_key_value_state is not None:
607
+ if self.config.is_mimo:
608
+ query_length = present_key_value_state[0].shape[3]
609
+ else:
610
+ query_length = present_key_value_state[0].shape[2]
611
+ else:
612
+ query_length = None
613
+
614
+ cross_attention_outputs = self.layer[1](
615
+ hidden_states,
616
+ key_value_states=encoder_hidden_states,
617
+ attention_mask=encoder_attention_mask,
618
+ position_bias=encoder_decoder_position_bias,
619
+ layer_head_mask=cross_attn_layer_head_mask,
620
+ past_key_value=cross_attn_past_key_value,
621
+ query_length=query_length,
622
+ use_cache=use_cache,
623
+ output_attentions=output_attentions,
624
+ )
625
+ hidden_states = cross_attention_outputs[0]
626
+
627
+ # clamp inf values to enable fp16 training
628
+ if hidden_states.dtype == torch.float16:
629
+ clamp_value = torch.where(
630
+ torch.isinf(hidden_states).any(),
631
+ torch.finfo(hidden_states.dtype).max - 1000,
632
+ torch.finfo(hidden_states.dtype).max,
633
+ )
634
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
635
+
636
+ # Combine self attn and cross attn key value states
637
+ if present_key_value_state is not None:
638
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
639
+
640
+ # Keep cross-attention outputs and relative position weights
641
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
642
+
643
+ # Apply Feed Forward layer
644
+ hidden_states = self.layer[-1](hidden_states)
645
+
646
+ # clamp inf values to enable fp16 training
647
+ if hidden_states.dtype == torch.float16:
648
+ clamp_value = torch.where(
649
+ torch.isinf(hidden_states).any(),
650
+ torch.finfo(hidden_states.dtype).max - 1000,
651
+ torch.finfo(hidden_states.dtype).max,
652
+ )
653
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
654
+
655
+ outputs = (hidden_states,)
656
+
657
+ if use_cache:
658
+ outputs = outputs + (present_key_value_state,) + attention_outputs
659
+ else:
660
+ outputs = outputs + attention_outputs
661
+
662
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
663
+
664
+
665
+ class T5ClassificationHead(nn.Module):
666
+ """Head for sentence-level classification tasks."""
667
+
668
+ def __init__(self, config: T5MIMOConfig):
669
+ super().__init__()
670
+ self.dense = nn.Linear(config.d_model, config.d_model)
671
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
672
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
673
+
674
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
675
+ hidden_states = self.dropout(hidden_states)
676
+ hidden_states = self.dense(hidden_states)
677
+ hidden_states = torch.tanh(hidden_states)
678
+ hidden_states = self.dropout(hidden_states)
679
+ hidden_states = self.out_proj(hidden_states)
680
+ return hidden_states
681
+
682
+
683
+ class T5PreTrainedModel(PreTrainedModel):
684
+ """
685
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
686
+ models.
687
+ """
688
+
689
+ config_class = T5MIMOConfig
690
+ base_model_prefix = "transformer"
691
+ is_parallelizable = True
692
+ supports_gradient_checkpointing = True
693
+ _no_split_modules = ["T5Block"]
694
+ _keep_in_fp32_modules = ["wo"]
695
+
696
+ @property
697
+ def dummy_inputs(self):
698
+ input_ids = torch.tensor(DUMMY_INPUTS)
699
+ input_mask = torch.tensor(DUMMY_MASK)
700
+ dummy_inputs = {
701
+ "decoder_input_ids": input_ids,
702
+ "input_ids": input_ids,
703
+ "decoder_attention_mask": input_mask,
704
+ }
705
+ return dummy_inputs
706
+
707
+ def _init_weights(self, module):
708
+ """Initialize the weights"""
709
+ factor = self.config.initializer_factor # Used for testing weights initialization
710
+ if isinstance(module, T5LayerNorm):
711
+ module.weight.data.fill_(factor * 1.0)
712
+ elif isinstance(
713
+ module,
714
+ (T5MIMOModel, T5MIMOForConditionalGeneration, T5MIMOEncoderModel),
715
+ ):
716
+ # Mesh TensorFlow embeddings initialization
717
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
718
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
719
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
720
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
721
+ if hasattr(module, "qa_outputs"):
722
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
723
+ module.qa_outputs.bias.data.zero_()
724
+ elif isinstance(module, T5ClassificationHead):
725
+ module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
726
+ if hasattr(module.dense, "bias") and module.dense.bias is not None:
727
+ module.dense.bias.data.zero_()
728
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
729
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
730
+ module.out_proj.bias.data.zero_()
731
+ elif isinstance(module, T5DenseActDense):
732
+ # Mesh TensorFlow FF initialization
733
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
734
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
735
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
736
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
737
+ module.wi.bias.data.zero_()
738
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
739
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
740
+ module.wo.bias.data.zero_()
741
+ elif isinstance(module, T5DenseGatedActDense):
742
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
743
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
744
+ module.wi_0.bias.data.zero_()
745
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
746
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
747
+ module.wi_1.bias.data.zero_()
748
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
749
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
750
+ module.wo.bias.data.zero_()
751
+ elif isinstance(module, T5Attention):
752
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
753
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
754
+ d_model = self.config.d_model
755
+ key_value_proj_dim = self.config.d_kv
756
+ n_heads = self.config.num_heads
757
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
758
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
759
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
760
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
761
+ if module.has_relative_attention_bias:
762
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
763
+
764
+ def _shift_right(self, input_ids):
765
+ decoder_start_token_id = self.config.decoder_start_token_id
766
+ pad_token_id = self.config.pad_token_id
767
+
768
+ if decoder_start_token_id is None:
769
+ raise ValueError(
770
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
771
+ "See T5 docs for more information."
772
+ )
773
+
774
+ # shift inputs to the right
775
+ if is_torch_fx_proxy(input_ids):
776
+ # Item assignment is not supported natively for proxies.
777
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
778
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
779
+ else:
780
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
781
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
782
+ shifted_input_ids[..., 0] = decoder_start_token_id
783
+
784
+ if pad_token_id is None:
785
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
786
+ # replace possible -100 values in labels by `pad_token_id`
787
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
788
+
789
+ return shifted_input_ids
790
+
791
+
792
+ class T5Stack(T5PreTrainedModel):
793
+ def __init__(self, config, embed_tokens=None):
794
+ super().__init__(config)
795
+
796
+ self.embed_tokens = embed_tokens
797
+ self.is_decoder = config.is_decoder
798
+
799
+ self.block = nn.ModuleList(
800
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
801
+ )
802
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
803
+ self.dropout = nn.Dropout(config.dropout_rate)
804
+
805
+ # Initialize weights and apply final processing
806
+ self.post_init()
807
+ # Model parallel
808
+ self.model_parallel = False
809
+ self.device_map = None
810
+ self.gradient_checkpointing = False
811
+
812
+ def parallelize(self, device_map=None):
813
+ warnings.warn(
814
+ "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
815
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
816
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
817
+ " 'block.1': 1, ...}",
818
+ FutureWarning,
819
+ )
820
+ # Check validity of device_map
821
+ self.device_map = (
822
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
823
+ )
824
+ assert_device_map(self.device_map, len(self.block))
825
+ self.model_parallel = True
826
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
827
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
828
+         # Load onto devices
+         for k, v in self.device_map.items():
+             for layer in v:
+                 cuda_device = "cuda:" + str(k)
+                 self.block[layer] = self.block[layer].to(cuda_device)
+
+         # Set embed_tokens to first layer
+         self.embed_tokens = self.embed_tokens.to(self.first_device)
+         # Set final layer norm to last device
+         self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+
+     def deparallelize(self):
+         warnings.warn(
+             "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+             FutureWarning,
+         )
+         self.model_parallel = False
+         self.device_map = None
+         self.first_device = "cpu"
+         self.last_device = "cpu"
+         for i in range(len(self.block)):
+             self.block[i] = self.block[i].to("cpu")
+         self.embed_tokens = self.embed_tokens.to("cpu")
+         self.final_layer_norm = self.final_layer_norm.to("cpu")
+         torch.cuda.empty_cache()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, new_embeddings):
+         self.embed_tokens = new_embeddings
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         inputs_embeds=None,
+         head_mask=None,
+         cross_attn_head_mask=None,
+         past_key_values=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         # Model parallel
+         if self.model_parallel:
+             torch.cuda.set_device(self.first_device)
+             self.embed_tokens = self.embed_tokens.to(self.first_device)
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             err_msg_prefix = "decoder_" if self.is_decoder else ""
+             raise ValueError(
+                 f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             err_msg_prefix = "decoder_" if self.is_decoder else ""
+             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+         if inputs_embeds is None:
+             if self.embed_tokens is None:
+                 raise ValueError("You have to initialize the model with valid token embeddings")
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if self.config.is_mimo:
+             batch_size, multivar_seqs, seq_length = input_shape
+         else:
+             batch_size, seq_length = input_shape
+
+         # required mask seq length can be calculated via length of past
+         if self.config.is_mimo:
+             mask_seq_length = past_key_values[0][0].shape[3] + seq_length if past_key_values is not None else seq_length
+         else:
+             mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+         if use_cache is True:
+             if not self.is_decoder:
+                 raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+
+         # initialize past_key_values with `None` if past does not exist
+         if past_key_values is None:
+             past_key_values = [None] * len(self.block)
+         if attention_mask is None:
+             if self.config.is_mimo:
+                 attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
+             else:
+                 attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+         # ourselves in which case we just need to make it broadcastable to all heads.
+         if self.config.is_mimo:
+             extended_attention_mask = self.get_extended_attention_mask(attention_mask[:, 0, :], (input_shape[0], input_shape[2]))
+             extended_attention_mask = extended_attention_mask.unsqueeze(1)
+             extended_attention_mask = extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
+         else:
+             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.is_decoder and encoder_hidden_states is not None:
+             if self.config.is_mimo:
+                 encoder_batch_size, multivar_dem, encoder_sequence_length, _ = encoder_hidden_states.size()
+             else:
+                 encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
+
+                 if self.config.is_mimo:
+                     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                     encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(0)
+                     encoder_extended_attention_mask = encoder_extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
+                 else:
+                     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+             else:
+                 if self.config.is_mimo:
+                     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                     encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 2, 1, 3)
+                     encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(3)
+                 else:
+                     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_extended_attention_mask = None
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                 use_cache = False
+
+         # Prepare head mask if needed
+         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+         cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+         present_key_value_states = () if use_cache else None
+         all_hidden_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+         all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+         position_bias = None
+         encoder_decoder_position_bias = None
+
+         hidden_states = self.dropout(inputs_embeds)
+
+         for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+             layer_head_mask = head_mask[i]
+             cross_attn_layer_head_mask = cross_attn_head_mask[i]
+             # Model parallel
+             if self.model_parallel:
+                 torch.cuda.set_device(hidden_states.device)
+                 # Ensure that attention_mask is always on the same device as hidden_states
+                 if attention_mask is not None:
+                     attention_mask = attention_mask.to(hidden_states.device)
+                 if position_bias is not None:
+                     position_bias = position_bias.to(hidden_states.device)
+                 if encoder_hidden_states is not None:
+                     encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+                 if encoder_extended_attention_mask is not None:
+                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+                 if encoder_decoder_position_bias is not None:
+                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                 if layer_head_mask is not None:
+                     layer_head_mask = layer_head_mask.to(hidden_states.device)
+                 if cross_attn_layer_head_mask is not None:
+                     cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.forward,
+                     hidden_states,
+                     extended_attention_mask,
+                     position_bias,
+                     encoder_hidden_states,
+                     encoder_extended_attention_mask,
+                     encoder_decoder_position_bias,
+                     layer_head_mask,
+                     cross_attn_layer_head_mask,
+                     None,  # past_key_value is always None with gradient checkpointing
+                     use_cache,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(
+                     hidden_states,
+                     attention_mask=extended_attention_mask,
+                     position_bias=position_bias,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_extended_attention_mask,
+                     encoder_decoder_position_bias=encoder_decoder_position_bias,
+                     layer_head_mask=layer_head_mask,
+                     cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                     past_key_value=past_key_value,
+                     use_cache=use_cache,
+                     output_attentions=output_attentions,
+                 )
+
+             # layer_outputs is a tuple with:
+             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+             if use_cache is False:
+                 layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+             hidden_states, present_key_value_state = layer_outputs[:2]
+
+             # We share the position biases between the layers - the first layer stores them
+             # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+             # (cross-attention position bias), (cross-attention weights)
+             position_bias = layer_outputs[2]
+             if self.is_decoder and encoder_hidden_states is not None:
+                 encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+             # append next layer key value states
+             if use_cache:
+                 present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[3],)
+                 if self.is_decoder:
+                     all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+             # Model Parallel: If it's the last layer for that device, put things on the next device
+             if self.model_parallel:
+                 for k, v in self.device_map.items():
+                     if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                         hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+         hidden_states = self.final_layer_norm(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+
+         # Add last layer
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     present_key_value_states,
+                     all_hidden_states,
+                     all_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=present_key_value_states,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
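The `is_mimo` branches in `T5Stack.forward` above broadcast the usual 4D additive attention mask across an extra series dimension. A minimal shape sketch of that expansion, approximating the non-causal `get_extended_attention_mask` by hand; the sizes are illustrative and not taken from any config:

```python
import torch

# Hypothetical MIMO sizes: 2 samples, 3 parallel series, 8 tokens each.
batch_size, num_series, seq_len = 2, 3, 8
attention_mask = torch.ones(batch_size, num_series, seq_len)

# Mirror of the is_mimo self-attention branch: build the usual 4D additive
# mask from the first series, then repeat it across the series dimension.
base = attention_mask[:, 0, :]                                # (batch, seq)
extended = base[:, None, None, :].to(torch.float32)           # (batch, 1, 1, seq)
extended = (1.0 - extended) * torch.finfo(torch.float32).min  # 0 -> large negative
extended = extended.unsqueeze(1)                              # (batch, 1, 1, 1, seq)
extended = extended.repeat(1, num_series, 1, 1, 1)            # (batch, series, 1, 1, seq)
print(extended.shape)  # torch.Size([2, 3, 1, 1, 8])
```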
+ class T5MIMOModel(T5PreTrainedModel):
1098
+ config_class = T5MIMOConfig
1099
+
1100
+ _keys_to_ignore_on_load_unexpected = [
1101
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1102
+ ]
1103
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1104
+
1105
+ def __init__(self, config: T5MIMOConfig):
1106
+ super().__init__(config)
1107
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1108
+
1109
+ encoder_config = copy.deepcopy(config)
1110
+ encoder_config.is_decoder = False
1111
+ encoder_config.use_cache = False
1112
+ encoder_config.is_encoder_decoder = False
1113
+ self.encoder = T5Stack(encoder_config, self.shared)
1114
+
1115
+ decoder_config = copy.deepcopy(config)
1116
+ decoder_config.is_decoder = True
1117
+ decoder_config.is_encoder_decoder = False
1118
+ decoder_config.num_layers = config.num_decoder_layers
1119
+ self.decoder = T5Stack(decoder_config, self.shared)
1120
+
1121
+ # Initialize weights and apply final processing
1122
+ self.post_init()
1123
+
1124
+ # Model parallel
1125
+ self.model_parallel = False
1126
+ self.device_map = None
1127
+
1128
+
1129
+ def parallelize(self, device_map=None):
1130
+ warnings.warn(
1131
+ "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
1132
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1133
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
1134
+ " 0, 'encoder.block.1': 1, ...}",
1135
+ FutureWarning,
1136
+ )
1137
+ self.device_map = (
1138
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1139
+ if device_map is None
1140
+ else device_map
1141
+ )
1142
+ assert_device_map(self.device_map, len(self.encoder.block))
1143
+ self.encoder.parallelize(self.device_map)
1144
+ self.decoder.parallelize(self.device_map)
1145
+ self.model_parallel = True
1146
+
1147
+
1148
+ def deparallelize(self):
1149
+ warnings.warn(
1150
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1151
+ FutureWarning,
1152
+ )
1153
+ self.encoder.deparallelize()
1154
+ self.decoder.deparallelize()
1155
+ self.encoder = self.encoder.to("cpu")
1156
+ self.decoder = self.decoder.to("cpu")
1157
+ self.model_parallel = False
1158
+ self.device_map = None
1159
+ torch.cuda.empty_cache()
1160
+
1161
+ def get_input_embeddings(self):
1162
+ return self.shared
1163
+
1164
+ def set_input_embeddings(self, new_embeddings):
1165
+ self.shared = new_embeddings
1166
+ self.encoder.set_input_embeddings(new_embeddings)
1167
+ self.decoder.set_input_embeddings(new_embeddings)
1168
+
1169
+ def _tie_weights(self):
1170
+ if self.config.tie_word_embeddings:
1171
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1172
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1173
+
1174
+ def get_encoder(self):
1175
+ return self.encoder
1176
+
1177
+ def get_decoder(self):
1178
+ return self.decoder
1179
+
1180
+ def _prune_heads(self, heads_to_prune):
1181
+ """
1182
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1183
+ class PreTrainedModel
1184
+ """
1185
+ for layer, heads in heads_to_prune.items():
1186
+ self.encoder.layer[layer].attention.prune_heads(heads)
1187
+
1188
+ def forward(
1189
+ self,
1190
+ input_ids: Optional[torch.LongTensor] = None,
1191
+ attention_mask: Optional[torch.FloatTensor] = None,
1192
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1193
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1194
+ head_mask: Optional[torch.FloatTensor] = None,
1195
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1196
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1197
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1198
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1199
+ inputs_embeds: Optional[torch.Tensor] = None,
1200
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1201
+ use_cache: Optional[bool] = None,
1202
+ output_attentions: Optional[bool] = None,
1203
+ output_hidden_states: Optional[bool] = None,
1204
+ return_dict: Optional[bool] = None,
1205
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1206
+ r"""
1207
+ Returns:
1208
+
1209
+ Example:
1210
+
1211
+ ```python
1212
+ >>> from transformers import AutoTokenizer, T5Model
1213
+
1214
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1215
+ >>> model = T5Model.from_pretrained("google-t5/t5-small")
1216
+
1217
+ >>> input_ids = tokenizer(
1218
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1219
+ ... ).input_ids # Batch size 1
1220
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1221
+
1222
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
1223
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
1224
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1225
+
1226
+ >>> # forward pass
1227
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1228
+ >>> last_hidden_states = outputs.last_hidden_state
1229
+ ```"""
1230
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1231
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1232
+
1233
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1234
+ if head_mask is not None and decoder_head_mask is None:
1235
+ if self.config.num_layers == self.config.num_decoder_layers:
1236
+ decoder_head_mask = head_mask
1237
+
1238
+ # Encode if needed (training, first prediction pass)
1239
+ if encoder_outputs is None:
1240
+ encoder_outputs = self.encoder(
1241
+ input_ids=input_ids,
1242
+ attention_mask=attention_mask,
1243
+ inputs_embeds=inputs_embeds,
1244
+ head_mask=head_mask,
1245
+ output_attentions=output_attentions,
1246
+ output_hidden_states=output_hidden_states,
1247
+ return_dict=return_dict,
1248
+ )
1249
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1250
+ encoder_outputs = BaseModelOutput(
1251
+ last_hidden_state=encoder_outputs[0],
1252
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1253
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1254
+ )
1255
+
1256
+ hidden_states = encoder_outputs[0]
1257
+
1258
+ # Set device for model parallelism
1259
+ if self.model_parallel:
1260
+ torch.cuda.set_device(self.decoder.first_device)
1261
+ hidden_states = hidden_states.to(self.decoder.first_device)
1262
+ if decoder_input_ids is not None:
1263
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1264
+ if attention_mask is not None:
1265
+ attention_mask = attention_mask.to(self.decoder.first_device)
1266
+ if decoder_attention_mask is not None:
1267
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1268
+
1269
+ # Decode
1270
+ decoder_outputs = self.decoder(
1271
+ input_ids=decoder_input_ids,
1272
+ attention_mask=decoder_attention_mask,
1273
+ inputs_embeds=decoder_inputs_embeds,
1274
+ past_key_values=past_key_values,
1275
+ encoder_hidden_states=hidden_states,
1276
+ encoder_attention_mask=attention_mask,
1277
+ head_mask=decoder_head_mask,
1278
+ cross_attn_head_mask=cross_attn_head_mask,
1279
+ use_cache=use_cache,
1280
+ output_attentions=output_attentions,
1281
+ output_hidden_states=output_hidden_states,
1282
+ return_dict=return_dict,
1283
+ )
1284
+
1285
+ if not return_dict:
1286
+ return decoder_outputs + encoder_outputs
1287
+
1288
+ return Seq2SeqModelOutput(
1289
+ last_hidden_state=decoder_outputs.last_hidden_state,
1290
+ past_key_values=decoder_outputs.past_key_values,
1291
+ decoder_hidden_states=decoder_outputs.hidden_states,
1292
+ decoder_attentions=decoder_outputs.attentions,
1293
+ cross_attentions=decoder_outputs.cross_attentions,
1294
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1295
+ encoder_hidden_states=encoder_outputs.hidden_states,
1296
+ encoder_attentions=encoder_outputs.attentions,
1297
+ )
1298
+
1299
+
1300
+
1301
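`T5MIMOModel` above (and the conditional-generation head below) consume inputs with an extra series dimension when `config.is_mimo` is set. A small sketch of that convention, inferred from the `is_mimo` branches in `T5Stack.forward`; the vocab and hidden sizes here are made up for illustration:

```python
import torch

# Assumed MIMO input convention: token ids carry an extra series dimension,
# i.e. (batch, num_series, seq_len) instead of (batch, seq_len).
batch_size, num_series, seq_len = 2, 3, 16
input_ids = torch.randint(0, 32128, (batch_size, num_series, seq_len))
attention_mask = torch.ones(batch_size, num_series, seq_len, dtype=torch.long)

# The shared nn.Embedding then yields 4D hidden states, matching the
# (encoder_batch_size, multivar_dem, encoder_sequence_length, _) unpack above.
embed = torch.nn.Embedding(32128, 64)
print(embed(input_ids).shape)  # torch.Size([2, 3, 16, 64])
```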
+ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
1302
+ config_class = T5MIMOConfig
1303
+
1304
+ _keys_to_ignore_on_load_unexpected = [
1305
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1306
+ ]
1307
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1308
+
1309
+ def __init__(self, config: T5MIMOConfig):
1310
+ super().__init__(config)
1311
+ self.model_dim = config.d_model
1312
+
1313
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1314
+
1315
+ encoder_config = copy.deepcopy(config)
1316
+ encoder_config.is_decoder = False
1317
+ encoder_config.use_cache = False
1318
+ encoder_config.is_encoder_decoder = False
1319
+ self.encoder = T5Stack(encoder_config, self.shared)
1320
+
1321
+ decoder_config = copy.deepcopy(config)
1322
+ decoder_config.is_decoder = True
1323
+ decoder_config.is_encoder_decoder = False
1324
+ decoder_config.num_layers = config.num_decoder_layers
1325
+ self.decoder = T5Stack(decoder_config, self.shared)
1326
+
1327
+
1328
+ # self.conv_block = MultivariateConvBlock(config)
1329
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1330
+
1331
+ # Initialize weights and apply final processing
1332
+ self.post_init()
1333
+
1334
+ # Model parallel
1335
+ self.model_parallel = False
1336
+ self.device_map = None
1337
+
1338
+
1339
+ def parallelize(self, device_map=None):
1340
+ warnings.warn(
1341
+ "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
1342
+ " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
1343
+ " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1344
+ " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
1345
+ FutureWarning,
1346
+ )
1347
+ self.device_map = (
1348
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1349
+ if device_map is None
1350
+ else device_map
1351
+ )
1352
+ assert_device_map(self.device_map, len(self.encoder.block))
1353
+ self.encoder.parallelize(self.device_map)
1354
+ self.decoder.parallelize(self.device_map)
1355
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1356
+ self.model_parallel = True
1357
+
1358
+
1359
+ def deparallelize(self):
1360
+ warnings.warn(
1361
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1362
+ FutureWarning,
1363
+ )
1364
+ self.encoder.deparallelize()
1365
+ self.decoder.deparallelize()
1366
+ self.encoder = self.encoder.to("cpu")
1367
+ self.decoder = self.decoder.to("cpu")
1368
+ self.lm_head = self.lm_head.to("cpu")
1369
+ self.model_parallel = False
1370
+ self.device_map = None
1371
+ torch.cuda.empty_cache()
1372
+
1373
+ def get_input_embeddings(self):
1374
+ return self.shared
1375
+
1376
+ def set_input_embeddings(self, new_embeddings):
1377
+ self.shared = new_embeddings
1378
+ self.encoder.set_input_embeddings(new_embeddings)
1379
+ self.decoder.set_input_embeddings(new_embeddings)
1380
+
1381
+ def _tie_weights(self):
1382
+ if self.config.tie_word_embeddings:
1383
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1384
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1385
+
1386
+ def set_output_embeddings(self, new_embeddings):
1387
+ self.lm_head = new_embeddings
1388
+
1389
+ def get_output_embeddings(self):
1390
+ return self.lm_head
1391
+
1392
+ def get_encoder(self):
1393
+ return self.encoder
1394
+
1395
+ def get_decoder(self):
1396
+ return self.decoder
1397
+
1398
+ def forward(
1399
+ self,
1400
+ input_ids: Optional[torch.LongTensor] = None,
1401
+ attention_mask: Optional[torch.FloatTensor] = None,
1402
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1403
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1404
+ head_mask: Optional[torch.FloatTensor] = None,
1405
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1406
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1407
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1408
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1409
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1410
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1411
+ labels: Optional[torch.LongTensor] = None,
1412
+ use_cache: Optional[bool] = None,
1413
+ output_attentions: Optional[bool] = None,
1414
+ output_hidden_states: Optional[bool] = None,
1415
+ return_dict: Optional[bool] = None,
1416
+ use_conv: Optional[bool] = True,
1417
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1418
+ r"""
1419
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1420
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
1421
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1422
+ labels in `[0, ..., config.vocab_size]`
1423
+
1424
+ Returns:
1425
+
1426
+ Examples:
1427
+
1428
+ ```python
1429
+ >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
1430
+
1431
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1432
+ >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
1433
+
1434
+ >>> # training
1435
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1436
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1437
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1438
+ >>> loss = outputs.loss
1439
+ >>> logits = outputs.logits
1440
+
1441
+ >>> # inference
1442
+ >>> input_ids = tokenizer(
1443
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1444
+ ... ).input_ids # Batch size 1
1445
+ >>> outputs = model.generate(input_ids)
1446
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1447
+ >>> # studies have shown that owning a dog is good for you.
1448
+ ```"""
1449
+
1450
+
1451
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1452
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1453
+
1454
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1455
+ if head_mask is not None and decoder_head_mask is None:
1456
+ if self.config.num_layers == self.config.num_decoder_layers:
1457
+ decoder_head_mask = head_mask
1458
+
1459
+
1460
+
1461
+ # Encode if needed (training, first prediction pass)
1462
+ if encoder_outputs is None:
1463
+ # Convert encoder inputs in embeddings if needed
1464
+ encoder_outputs = self.encoder(
1465
+ input_ids=input_ids,
1466
+ attention_mask=attention_mask,
1467
+ inputs_embeds=inputs_embeds,
1468
+ head_mask=head_mask,
1469
+ output_attentions=output_attentions,
1470
+ output_hidden_states=output_hidden_states,
1471
+ return_dict=return_dict,
1472
+ )
1473
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1474
+ encoder_outputs = BaseModelOutput(
1475
+ last_hidden_state=encoder_outputs[0],
1476
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1477
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1478
+ )
1479
+
1480
+ hidden_states = encoder_outputs[0]
1481
+
1482
+ if self.model_parallel:
1483
+ torch.cuda.set_device(self.decoder.first_device)
1484
+
1485
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1486
+ # get decoder inputs from shifting lm labels to the right
1487
+ decoder_input_ids = self._shift_right(labels)
1488
+
1489
+ # Set device for model parallelism
1490
+ if self.model_parallel:
1491
+ torch.cuda.set_device(self.decoder.first_device)
1492
+ hidden_states = hidden_states.to(self.decoder.first_device)
1493
+ if decoder_input_ids is not None:
1494
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1495
+ if attention_mask is not None:
1496
+ attention_mask = attention_mask.to(self.decoder.first_device)
1497
+ if decoder_attention_mask is not None:
1498
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1499
+
1500
+
1501
+
1502
+ # Decode
1503
+ decoder_outputs = self.decoder(
1504
+ input_ids=decoder_input_ids,
1505
+ attention_mask=decoder_attention_mask,
1506
+ inputs_embeds=decoder_inputs_embeds,
1507
+ past_key_values=past_key_values,
1508
+ encoder_hidden_states=hidden_states,
1509
+ encoder_attention_mask=attention_mask,
1510
+ head_mask=decoder_head_mask,
1511
+ cross_attn_head_mask=cross_attn_head_mask,
1512
+ use_cache=use_cache,
1513
+ output_attentions=output_attentions,
1514
+ output_hidden_states=output_hidden_states,
1515
+ return_dict=return_dict,
1516
+ )
1517
+
1518
+         sequence_output = decoder_outputs[0]
+
+         # if use_conv:
+         #     sequence_output = self.conv_block(sequence_output)
+
+         # Set device for model parallelism
+         if self.model_parallel:
+             torch.cuda.set_device(self.encoder.first_device)
+             self.lm_head = self.lm_head.to(self.encoder.first_device)
+             sequence_output = sequence_output.to(self.lm_head.weight.device)
+
+         if self.config.tie_word_embeddings:
+             # Rescale output before projecting on vocab
+             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+             sequence_output = sequence_output * (self.model_dim**-0.5)
+
+         lm_logits = self.lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss(ignore_index=-100)
+             # move labels to correct device to enable PP
+             labels = labels.to(lm_logits.device)
+             if len(labels.shape) == 2:
+                 # standard (batch, seq_len) labels
+                 loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+             else:
+                 # labels with an extra (MIMO) series dimension are flattened with reshape()
+                 loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))
+             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+         if not return_dict:
+             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+             return ((loss,) + output) if loss is not None else output
+
+         seq2seqlmoutput = Seq2SeqLMOutput(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+         return seq2seqlmoutput
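The loss branch above flattens the logits to `(N, vocab_size)` rows whether or not the labels carry the extra series dimension; only the label flattening differs (`view` vs `reshape`). A standalone illustration with arbitrary sizes:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, num_series, seq_len, vocab_size = 2, 3, 8, 32
lm_logits = torch.randn(batch_size, num_series, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, num_series, seq_len))
labels[:, :, -2:] = -100  # padded positions are ignored by the loss

loss_fct = CrossEntropyLoss(ignore_index=-100)
# Same flattening as the MIMO branch: every (sample, series, step) becomes one row.
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))
print(loss.shape)  # torch.Size([]) -- a scalar
```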
+
1568
+ def prepare_inputs_for_generation(
1569
+ self,
1570
+ input_ids,
1571
+ past_key_values=None,
1572
+ attention_mask=None,
1573
+ head_mask=None,
1574
+ decoder_head_mask=None,
1575
+ decoder_attention_mask=None,
1576
+ cross_attn_head_mask=None,
1577
+ use_cache=None,
1578
+ encoder_outputs=None,
1579
+ **kwargs,
1580
+ ):
1581
+ # cut decoder_input_ids if past_key_values is used
1582
+ if past_key_values is not None:
1583
+ past_length = past_key_values[0][0].shape[2]
1584
+
1585
+ # Some generation methods already pass only the last input ID
1586
+ if input_ids.shape[1] > past_length:
1587
+ remove_prefix_length = past_length
1588
+ else:
1589
+ # Default to old behavior: keep only final ID
1590
+ remove_prefix_length = input_ids.shape[1] - 1
1591
+
1592
+ input_ids = input_ids[:, remove_prefix_length:]
1593
+
1594
+ return {
1595
+ "decoder_input_ids": input_ids,
1596
+ "past_key_values": past_key_values,
1597
+ "encoder_outputs": encoder_outputs,
1598
+ "attention_mask": attention_mask,
1599
+ "head_mask": head_mask,
1600
+ "decoder_head_mask": decoder_head_mask,
1601
+ "decoder_attention_mask": decoder_attention_mask,
1602
+ "cross_attn_head_mask": cross_attn_head_mask,
1603
+ "use_cache": use_cache,
1604
+ }
1605
+
1606
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1607
+ return self._shift_right(labels)
1608
+
1609
+ def _reorder_cache(self, past_key_values, beam_idx):
1610
+ # if decoder past is not included in output
1611
+ # speedy decoding is disabled and no need to reorder
1612
+ if past_key_values is None:
1613
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1614
+ return past_key_values
1615
+
1616
+ reordered_decoder_past = ()
1617
+ for layer_past_states in past_key_values:
1618
+ # get the correct batch idx from layer past batch dim
1619
+ # batch dim of `past` is at 2nd position
1620
+ reordered_layer_past_states = ()
1621
+ for layer_past_state in layer_past_states:
1622
+ # need to set correct `past` for each of the four key / value states
1623
+ reordered_layer_past_states = reordered_layer_past_states + (
1624
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1625
+ )
1626
+
1627
+ if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
1628
+ raise ValueError(
1629
+ f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
1630
+ )
1631
+ if len(reordered_layer_past_states) != len(layer_past_states):
1632
+ raise ValueError(
1633
+ f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
1634
+ )
1635
+
1636
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1637
+ return reordered_decoder_past
1638
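`_reorder_cache` above realigns every cached key/value tensor with the beams that survived the current step via `index_select` on the batch dimension. A minimal sketch with a fake single-layer cache:

```python
import torch

num_beams, num_heads, past_len, head_dim = 4, 2, 5, 8
# One decoder layer: (self-attn key, self-attn value, cross-attn key, cross-attn value)
layer_past = tuple(torch.randn(num_beams, num_heads, past_len, head_dim) for _ in range(4))

beam_idx = torch.tensor([2, 2, 0, 3])  # beams chosen at this decoding step
reordered = tuple(state.index_select(0, beam_idx) for state in layer_past)

print(reordered[0].shape)                              # unchanged: (4, 2, 5, 8)
print(torch.equal(reordered[0][0], layer_past[0][2]))  # True: slot 0 now holds old beam 2
```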
+
1639
+
1640
+
1641
+ class T5MIMOEncoderModel(T5PreTrainedModel):
1642
+ _tied_weights_keys = ["encoder.embed_tokens.weight"]
1643
+ _keys_to_ignore_on_load_unexpected = [r"decoder"]
1644
+
1645
+ def __init__(self, config: T5MIMOConfig):
1646
+ super().__init__(config)
1647
+
1648
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1649
+
1650
+ encoder_config = copy.deepcopy(config)
1651
+ encoder_config.use_cache = False
1652
+ encoder_config.is_encoder_decoder = False
1653
+ self.encoder = T5Stack(encoder_config, self.shared)
1654
+
1655
+ # Initialize weights and apply final processing
1656
+ self.post_init()
1657
+
1658
+ # Model parallel
1659
+ self.model_parallel = False
1660
+ self.device_map = None
1661
+
1662
+ def parallelize(self, device_map=None):
1663
+ warnings.warn(
1664
+ "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1665
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1666
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
1667
+ " 'block.1': 1, ...}",
1668
+ FutureWarning,
1669
+ )
1670
+ self.device_map = (
1671
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1672
+ if device_map is None
1673
+ else device_map
1674
+ )
1675
+ assert_device_map(self.device_map, len(self.encoder.block))
1676
+ self.encoder.parallelize(self.device_map)
1677
+ self.model_parallel = True
1678
+
1679
+ def deparallelize(self):
1680
+ warnings.warn(
1681
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1682
+ FutureWarning,
1683
+ )
1684
+ self.encoder.deparallelize()
1685
+ self.encoder = self.encoder.to("cpu")
1686
+ self.model_parallel = False
1687
+ self.device_map = None
1688
+ torch.cuda.empty_cache()
1689
+
1690
+ def get_input_embeddings(self):
1691
+ return self.shared
1692
+
1693
+ def set_input_embeddings(self, new_embeddings):
1694
+ self.shared = new_embeddings
1695
+ self.encoder.set_input_embeddings(new_embeddings)
1696
+
1697
+ def _tie_weights(self):
1698
+ if self.config.tie_word_embeddings:
1699
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1700
+
1701
+ def get_encoder(self):
1702
+ return self.encoder
1703
+
1704
+ def _prune_heads(self, heads_to_prune):
1705
+ """
1706
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1707
+ class PreTrainedModel
1708
+ """
1709
+ for layer, heads in heads_to_prune.items():
1710
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1711
+
1712
+ def forward(
1713
+ self,
1714
+ input_ids: Optional[torch.LongTensor] = None,
1715
+ attention_mask: Optional[torch.FloatTensor] = None,
1716
+ head_mask: Optional[torch.FloatTensor] = None,
1717
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1718
+ output_attentions: Optional[bool] = None,
1719
+ output_hidden_states: Optional[bool] = None,
1720
+ return_dict: Optional[bool] = None,
1721
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
1722
+ r"""
1723
+ Returns:
1724
+
1725
+ Example:
1726
+
1727
+ ```python
1728
+ >>> from transformers import AutoTokenizer, T5EncoderModel
1729
+
1730
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1731
+ >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
1732
+ >>> input_ids = tokenizer(
1733
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1734
+ ... ).input_ids # Batch size 1
1735
+ >>> outputs = model(input_ids=input_ids)
1736
+ >>> last_hidden_states = outputs.last_hidden_state
1737
+ ```"""
1738
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1739
+
1740
+ encoder_outputs = self.encoder(
1741
+ input_ids=input_ids,
1742
+ attention_mask=attention_mask,
1743
+ inputs_embeds=inputs_embeds,
1744
+ head_mask=head_mask,
1745
+ output_attentions=output_attentions,
1746
+ output_hidden_states=output_hidden_states,
1747
+ return_dict=return_dict,
1748
+ )
1749
+
1750
+ return encoder_outputs
1751
+
1752
+
1753
+
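Since this commit ships the modeling code together with the weights, downstream users would typically load it through the auto classes with remote code enabled. A hedged sketch only: the repository id is a placeholder, and it assumes an `auto_map` entry is configured in the repo's `config.json`:

```python
from transformers import AutoConfig, AutoModelForSeq2SeqLM

repo_id = "your-namespace/your-t5mimo-repo"  # placeholder, not the actual repo id

# trust_remote_code is required because T5MIMOForConditionalGeneration is not
# part of the Transformers library itself.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # T5MIMOForConditionalGeneration, if auto_map is set
```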