wissamantoun committed
Commit
226f8e0
1 Parent(s): 26801de

Update README.md and config.json

Files changed (4)
  1. README.md +37 -14
  2. config.json +10 -4
  3. configuration_aragpt2.py +275 -0
  4. modeling_aragpt2.py +1975 -0
README.md CHANGED
@@ -1,11 +1,14 @@
1
  ---
2
  language: ar
3
  datasets:
4
  - wikipedia
5
  - Osian
6
- - 1.5B-Arabic-Corpus
7
- - oscar-arabic-unshuffled
8
- - Assafir(private)
9
  inference: false
10
  widget:
11
  - text: "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال"
@@ -29,17 +32,17 @@ Both models are trained using the `adafactor` optimizer, since the `adam` and `l
29
 
30
  AraGPT2 is trained on the same large Arabic Dataset as AraBERTv2.
31
 
32
- # Usage
33
 
34
  ## Testing the model using `transformers`:
35
 
36
  ```python
37
- from transformers import GPT2TokenizerFast, pipeline
38
- #for base and medium
39
- from transformers import GPT2LMHeadModel
40
- #for large and mega
41
- # pip install arabert
42
- from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel
43
 
44
  from arabert.preprocess import ArabertPreprocessor
45
 
@@ -49,13 +52,15 @@ arabert_prep = ArabertPreprocessor(model_name=MODEL_NAME)
49
  text=""
50
  text_clean = arabert_prep.preprocess(text)
51
 
52
- model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
53
  tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
54
- generation_pipeline = pipeline("text-generation",model=model,tokenizer=tokenizer)
55
 
56
  #feel free to try different decoding settings
57
  generation_pipeline(text,
58
- pad_token_id=tokenizer.eos_token_id,
59
  num_beams=10,
60
  max_length=200,
61
  top_p=0.9,
@@ -79,7 +84,25 @@ python create_pretraining_data.py
79
 
80
  Finetuning:
81
  ```bash
82
- python3 run_pretraining.py \\\r\n --input_file="gs://<GS_BUCKET>/pretraining_data/*" \\\r\n --output_dir="gs://<GS_BUCKET>/pretraining_model/" \\\r\n --config_file="config/small_hparams.json" \\\r\n --batch_size=128 \\\r\n --eval_batch_size=8 \\\r\n --num_train_steps= \\\r\n --num_warmup_steps= \\\r\n --learning_rate= \\\r\n --save_checkpoints_steps= \\\r\n --max_seq_length=1024 \\\r\n --max_eval_steps= \\\r\n --optimizer="lamb" \\\r\n --iterations_per_loop=5000 \\\r\n --keep_checkpoint_max=10 \\\r\n --use_tpu=True \\\r\n --tpu_name=<TPU NAME> \\\r\n --do_train=True \\\r\n --do_eval=False
83
  ```
84
  # Model Sizes
85
 
 
1
  ---
2
  language: ar
3
+ license: other
4
+ license_name: custom
5
+ license_link: https://github.com/aub-mind/arabert/blob/master/aragpt2/LICENSE
6
  datasets:
7
  - wikipedia
8
  - Osian
9
+ - arabic-billion-words
10
+ - oscar
11
+ - Assafir-private
12
  inference: false
13
  widget:
14
  - text: "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال"
 
32
 
33
  AraGPT2 is trained on the same large Arabic Dataset as AraBERTv2.
34
 
35
+
36
+ # NOTE: The model expects the input to be preprocessed using the `arabert` library;
37
+ otherwise, the model won't be able to generate the correct output.
38
 
39
  ## Testing the model using `transformers`:
40
 
41
+ The model code is now hosted on HuggingFace, so you need to pass the `trust_remote_code` flag. The model can then be used as follows:
42
+
43
+
44
  ```python
45
+ from transformers import GPT2TokenizerFast, AutoModelForCausalLM, pipeline
46
 
47
  from arabert.preprocess import ArabertPreprocessor
48
 
 
52
  text=""
53
  text_clean = arabert_prep.preprocess(text)
54
 
55
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
56
  tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
57
+ generation_pipeline = pipeline(
58
+ "text-generation", model=MODEL_NAME, trust_remote_code=True
59
+ )
60
 
61
  #feel free to try different decoding settings
62
  generation_pipeline(text,
63
+ pad_token_id=generation_pipeline.tokenizer.eos_token_id,
64
  num_beams=10,
65
  max_length=200,
66
  top_p=0.9,
 
84
 
85
  Finetuning:
86
  ```bash
87
+ python3 run_pretraining.py \
88
+ --input_file="gs://<GS_BUCKET>/pretraining_data/*" \
89
+ --output_dir="gs://<GS_BUCKET>/pretraining_model/" \
90
+ --config_file="config/small_hparams.json" \
91
+ --batch_size=128 \
92
+ --eval_batch_size=8 \
93
+ --num_train_steps= \
94
+ --num_warmup_steps= \
95
+ --learning_rate= \
96
+ --save_checkpoints_steps= \
97
+ --max_seq_length=1024 \
98
+ --max_eval_steps= \
99
+ --optimizer="lamb" \
100
+ --iterations_per_loop=5000 \
101
+ --keep_checkpoint_max=10 \
102
+ --use_tpu=True \
103
+ --tpu_name=<TPU NAME> \
104
+ --do_train=True \
105
+ --do_eval=False
106
  ```
107
  # Model Sizes
108
 
config.json CHANGED
@@ -1,8 +1,13 @@
1
  {
2
  "activation_function": "gelu_new",
3
  "architectures": [
4
- "GPT2LMHeadModel"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "attn_pdrop": 0.1,
8
  "bos_token_id": 0,
@@ -32,9 +37,10 @@
32
  "max_length": 50,
33
  "num_beams": 5,
34
  "top_p": 0.95,
35
- "repetition_penalty": 3.0,
36
  "no_repeat_ngram_size": 3
37
  }
38
  },
39
- "vocab_size": 64000
40
- }
 
 
1
  {
2
  "activation_function": "gelu_new",
3
  "architectures": [
4
+ "AraGPT2LMHeadModel"
5
  ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_aragpt2.AraGPT2Config",
8
+ "AutoModelForCausalLM": "modeling_aragpt2.AraGPT2LMHeadModel",
9
+ "AutoModel": "modeling_aragpt2.AraGPT2Model"
10
+ },
11
  "attention_probs_dropout_prob": 0.1,
12
  "attn_pdrop": 0.1,
13
  "bos_token_id": 0,
 
37
  "max_length": 50,
38
  "num_beams": 5,
39
  "top_p": 0.95,
40
+ "repetition_penalty": 3.0,
41
  "no_repeat_ngram_size": 3
42
  }
43
  },
44
+ "vocab_size": 64000,
45
+ "tokenizer_class": "GPT2Tokenizer"
46
+ }
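
The new `auto_map` block is what lets the stock `Auto*` classes resolve to the custom AraGPT2 implementations shipped with the checkpoint when `trust_remote_code=True` is passed. Below is a minimal sketch (not part of this commit) of how those entries are consumed; the repo id `aubmindlab/aragpt2-mega` is assumed from the config archive map in `configuration_aragpt2.py`:

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "aubmindlab/aragpt2-mega"  # assumed repo id hosting these files

# auto_map["AutoConfig"] -> configuration_aragpt2.AraGPT2Config
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# auto_map["AutoModelForCausalLM"] -> modeling_aragpt2.AraGPT2LMHeadModel
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# "tokenizer_class": "GPT2Tokenizer" tells AutoTokenizer which tokenizer class to instantiate
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(type(config).__name__, type(model).__name__)  # AraGPT2Config AraGPT2LMHeadModel
```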
configuration_aragpt2.py ADDED
@@ -0,0 +1,275 @@
1
+ # coding=utf-8
2
+ """ AraGPT2 configuration"""
3
+ from collections import OrderedDict
4
+ from typing import Any, List, Mapping, Optional
5
+
6
+ from transformers import PreTrainedTokenizer, TensorType, is_torch_available
7
+ from transformers.configuration_utils import PretrainedConfig
8
+ from transformers.onnx import OnnxConfigWithPast, PatchingSpec
9
+ from transformers.utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+ AraGPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
15
+ "aubmindlab/aragpt2-mega": "https://huggingface.co/aubmindlab/aragpt2-mega/resolve/main/config.json",
16
+ }
17
+
18
+
19
+ class AraGPT2Config(PretrainedConfig):
20
+ """
21
+ This is the configuration class to store the configuration of an [`AraGPT2Model`] or a [`TFAraGPT2Model`]. It is used to
22
+ instantiate an AraGPT2 model according to the specified arguments, defining the model architecture. Instantiating a
23
+ configuration with the defaults will yield a similar configuration to that of the AraGPT2
24
+ [aubmindlab/aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) architecture.
25
+
26
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
27
+ documentation from [`PretrainedConfig`] for more information.
28
+
29
+
30
+ Args:
31
+ vocab_size (`int`, *optional*, defaults to 64000):
32
+ Vocabulary size of the AraGPT2 model. Defines the number of different tokens that can be represented by the
33
+ `inputs_ids` passed when calling [`AraGPT2Model`] or [`TFAraGPT2Model`].
34
+ n_positions (`int`, *optional*, defaults to 1024):
35
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
36
+ just in case (e.g., 512 or 1024 or 2048).
37
+ n_embd (`int`, *optional*, defaults to 768):
38
+ Dimensionality of the embeddings and hidden states.
39
+ n_layer (`int`, *optional*, defaults to 12):
40
+ Number of hidden layers in the Transformer encoder.
41
+ n_head (`int`, *optional*, defaults to 12):
42
+ Number of attention heads for each attention layer in the Transformer encoder.
43
+ n_inner (`int`, *optional*):
44
+ Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
45
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
46
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
47
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
48
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
49
+ embd_pdrop (`float`, *optional*, defaults to 0.1):
50
+ The dropout ratio for the embeddings.
51
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
52
+ The dropout ratio for the attention.
53
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
54
+ The epsilon to use in the layer normalization layers.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ summary_type (`string`, *optional*, defaults to `"cls_index"`):
58
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
59
+ [`TFGPT2DoubleHeadsModel`].
60
+
61
+ Has to be one of the following options:
62
+
63
+ - `"last"`: Take the last token hidden state (like XLNet).
64
+ - `"first"`: Take the first token hidden state (like BERT).
65
+ - `"mean"`: Take the mean of all tokens hidden states.
66
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/AraGPT2).
67
+ - `"attn"`: Not implemented now, use multi-head attention.
68
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
69
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
70
+ [`TFGPT2DoubleHeadsModel`].
71
+
72
+ Whether or not to add a projection after the vector extraction.
73
+ summary_activation (`str`, *optional*):
74
+ Argument used when doing sequence summary. Used in for the multiple choice head in
75
+ [`GPT2DoubleHeadsModel`].
76
+
77
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
78
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
79
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
80
+ [`TFGPT2DoubleHeadsModel`].
81
+
82
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
83
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
84
+ Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
85
+ [`TFGPT2DoubleHeadsModel`].
86
+
87
+ The dropout ratio to be used after the projection and activation.
88
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
89
+ Scale attention weights by dividing by sqrt(hidden_size).
90
+ use_cache (`bool`, *optional*, defaults to `True`):
91
+ Whether or not the model should return the last key/values attentions (not used by all models).
92
+ bos_token_id (`int`, *optional*, defaults to 0):
93
+ Id of the beginning of sentence token in the vocabulary.
94
+ eos_token_id (`int`, *optional*, defaults to 0):
95
+ Id of the end of sentence token in the vocabulary.
96
+ scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
97
+ Whether to additionally scale attention weights by `1 / layer_idx + 1`.
98
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
99
+ Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
100
+ dot-product/softmax to float() when training with mixed precision.
101
+
102
+ Example:
103
+
104
+ ```python
105
+ >>> from transformers import AraGPT2Config, AraGPT2Model
106
+
107
+ >>> # Initializing an AraGPT2 configuration
108
+ >>> configuration = AraGPT2Config()
109
+
110
+ >>> # Initializing a model (with random weights) from the configuration
111
+ >>> model = AraGPT2Model(configuration)
112
+
113
+ >>> # Accessing the model configuration
114
+ >>> configuration = model.config
115
+ ```"""
116
+
117
+ model_type = "aragpt2"
118
+ keys_to_ignore_at_inference = ["past_key_values"]
119
+ attribute_map = {
120
+ "hidden_size": "n_embd",
121
+ "max_position_embeddings": "n_positions",
122
+ "num_attention_heads": "n_head",
123
+ "num_hidden_layers": "n_layer",
124
+ }
125
+
126
+ def __init__(
127
+ self,
128
+ vocab_size=64000,
129
+ n_positions=1024,
130
+ n_embd=768,
131
+ n_layer=12,
132
+ n_head=12,
133
+ n_inner=None,
134
+ activation_function="gelu_new",
135
+ resid_pdrop=0.1,
136
+ embd_pdrop=0.1,
137
+ attn_pdrop=0.1,
138
+ layer_norm_epsilon=1e-5,
139
+ initializer_range=0.02,
140
+ summary_type="cls_index",
141
+ summary_use_proj=True,
142
+ summary_activation=None,
143
+ summary_proj_to_labels=True,
144
+ summary_first_dropout=0.1,
145
+ scale_attn_weights=True,
146
+ use_cache=True,
147
+ bos_token_id=0,
148
+ eos_token_id=0,
149
+ scale_attn_by_inverse_layer_idx=False,
150
+ reorder_and_upcast_attn=False,
151
+ **kwargs,
152
+ ):
153
+ self.vocab_size = vocab_size
154
+ self.n_positions = n_positions
155
+ self.n_embd = n_embd
156
+ self.n_layer = n_layer
157
+ self.n_head = n_head
158
+ self.n_inner = n_inner
159
+ self.activation_function = activation_function
160
+ self.resid_pdrop = resid_pdrop
161
+ self.embd_pdrop = embd_pdrop
162
+ self.attn_pdrop = attn_pdrop
163
+ self.layer_norm_epsilon = layer_norm_epsilon
164
+ self.initializer_range = initializer_range
165
+ self.summary_type = summary_type
166
+ self.summary_use_proj = summary_use_proj
167
+ self.summary_activation = summary_activation
168
+ self.summary_first_dropout = summary_first_dropout
169
+ self.summary_proj_to_labels = summary_proj_to_labels
170
+ self.scale_attn_weights = scale_attn_weights
171
+ self.use_cache = use_cache
172
+ self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
173
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
174
+
175
+ self.bos_token_id = bos_token_id
176
+ self.eos_token_id = eos_token_id
177
+
178
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
179
+
180
+
181
+ class AraGPT2OnnxConfig(OnnxConfigWithPast):
182
+ def __init__(
183
+ self,
184
+ config: PretrainedConfig,
185
+ task: str = "default",
186
+ patching_specs: List[PatchingSpec] = None,
187
+ use_past: bool = False,
188
+ ):
189
+ super().__init__(
190
+ config, task=task, patching_specs=patching_specs, use_past=use_past
191
+ )
192
+ if not getattr(self._config, "pad_token_id", None):
193
+ # TODO: how to do that better?
194
+ self._config.pad_token_id = 0
195
+
196
+ @property
197
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
198
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
199
+ if self.use_past:
200
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
201
+ common_inputs["attention_mask"] = {
202
+ 0: "batch",
203
+ 1: "past_sequence + sequence",
204
+ }
205
+ else:
206
+ common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
207
+
208
+ return common_inputs
209
+
210
+ @property
211
+ def num_layers(self) -> int:
212
+ return self._config.n_layer
213
+
214
+ @property
215
+ def num_attention_heads(self) -> int:
216
+ return self._config.n_head
217
+
218
+ def generate_dummy_inputs(
219
+ self,
220
+ tokenizer: PreTrainedTokenizer,
221
+ batch_size: int = -1,
222
+ seq_length: int = -1,
223
+ is_pair: bool = False,
224
+ framework: Optional[TensorType] = None,
225
+ ) -> Mapping[str, Any]:
226
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
227
+ tokenizer,
228
+ batch_size=batch_size,
229
+ seq_length=seq_length,
230
+ is_pair=is_pair,
231
+ framework=framework,
232
+ )
233
+
234
+ # We need to order the inputs in the way they appear in the forward()
235
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
236
+
237
+ # Need to add the past_keys
238
+ if self.use_past:
239
+ if not is_torch_available():
240
+ raise ValueError(
241
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
242
+ )
243
+ else:
244
+ import torch
245
+
246
+ batch, seqlen = common_inputs["input_ids"].shape
247
+ # Not using the same length for past_key_values
248
+ past_key_values_length = seqlen + 2
249
+ past_shape = (
250
+ batch,
251
+ self.num_attention_heads,
252
+ past_key_values_length,
253
+ self._config.hidden_size // self.num_attention_heads,
254
+ )
255
+ ordered_inputs["past_key_values"] = [
256
+ (torch.zeros(past_shape), torch.zeros(past_shape))
257
+ for _ in range(self.num_layers)
258
+ ]
259
+
260
+ ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
261
+ if self.use_past:
262
+ mask_dtype = ordered_inputs["attention_mask"].dtype
263
+ ordered_inputs["attention_mask"] = torch.cat(
264
+ [
265
+ ordered_inputs["attention_mask"],
266
+ torch.ones(batch, past_key_values_length, dtype=mask_dtype),
267
+ ],
268
+ dim=1,
269
+ )
270
+
271
+ return ordered_inputs
272
+
273
+ @property
274
+ def default_onnx_opset(self) -> int:
275
+ return 13
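
Besides the configuration class, the file defines `AraGPT2OnnxConfig`, which mirrors the GPT-2 ONNX export config: it declares the dynamic-axis inputs, exposes layer/head counts, pins the default opset to 13, and builds dummy inputs (extending the attention mask when `use_past=True`). A small sketch of how it might be exercised, assuming the file above is importable locally and PyTorch is installed:

```python
import torch
from transformers import GPT2TokenizerFast, TensorType

# assumes configuration_aragpt2.py has been downloaded next to this script
from configuration_aragpt2 import AraGPT2Config, AraGPT2OnnxConfig

config = AraGPT2Config()  # defaults: vocab_size=64000, n_positions=1024, n_layer=12, n_head=12
onnx_config = AraGPT2OnnxConfig(config, task="default", use_past=True)

# Dynamic axes declared for the exported graph (input_ids, past_key_values.*, attention_mask)
print(onnx_config.inputs)
print(onnx_config.num_layers, onnx_config.num_attention_heads, onnx_config.default_onnx_opset)

# Dummy inputs suitable for tracing/export; the tokenizer repo id is an assumption
tokenizer = GPT2TokenizerFast.from_pretrained("aubmindlab/aragpt2-mega")
dummy = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print({k: v.shape if isinstance(v, torch.Tensor) else len(v) for k, v in dummy.items()})
```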
modeling_aragpt2.py ADDED
@@ -0,0 +1,1975 @@
1
+ # coding=utf-8
2
+ """PyTorch AraGPT2 model."""
3
+
4
+ import math
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.cuda.amp import autocast
14
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPastAndCrossAttentions,
19
+ CausalLMOutputWithCrossAttentions,
20
+ QuestionAnsweringModelOutput,
21
+ SequenceClassifierOutputWithPast,
22
+ TokenClassifierOutput,
23
+ )
24
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
25
+ from transformers.pytorch_utils import (
26
+ Conv1D,
27
+ find_pruneable_heads_and_indices,
28
+ prune_conv1d_layer,
29
+ )
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ add_code_sample_docstrings,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ logging,
36
+ replace_return_docstrings,
37
+ )
38
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
39
+ from .configuration_aragpt2 import AraGPT2Config
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ _CHECKPOINT_FOR_DOC = "aubmindlab/aragpt2-mega"
45
+ _CONFIG_FOR_DOC = "AraGPT2Config"
46
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
47
+
48
+ ARAGPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
+ "aubmindlab/aragpt2-large",
50
+ "aubmindlab/aragpt2-mega",
51
+ # See all AraGPT2 models at https://huggingface.co/models?filter=aragpt2
52
+ ]
53
+
54
+ _ARAGPT2_ML_TF_TO_TORCH = {
55
+ "LayerNorm_embed_norm": "emb_norm",
56
+ "pos_embed": "wpe.weight",
57
+ "word_embed": "wte.weight",
58
+ "layer": "h",
59
+ # Most importantly, these two layer norms must be placed in the same position as in gpt2-ml,
60
+ # or the generated output is bad (it just repeats the last token)
61
+ "LayerNorm_mlp_ln0": "ln_1",
62
+ "LayerNorm_mlp_ln1": "ln_2",
63
+ "intermediate": "mlp.c_fc",
64
+ "output": "mlp.c_proj",
65
+ "query_layer": "attn.c_attn",
66
+ "key_layer": "attn.c_attn",
67
+ "value_layer": "attn.c_attn",
68
+ "context_projection_layer": "attn.c_proj",
69
+ "gamma": "weight",
70
+ "kernel": "weight",
71
+ "beta": "bias",
72
+ "bias": "bias",
73
+ }
74
+
75
+ WEIGHTS_NAME = "pytorch_model.bin"
76
+ CONFIG_NAME = "config.json"
77
+
78
+
79
+ def convert_gpt2_checkpoint_to_pytorch(
80
+ aragpt2_checkpoint_path, aragpt2_config_file, pytorch_dump_folder_path
81
+ ):
82
+ # Construct model
83
+ if aragpt2_config_file == "":
84
+ config = AraGPT2Config()
85
+ else:
86
+ config = AraGPT2Config.from_json_file(aragpt2_config_file)
87
+ model = AraGPT2Model(config)
88
+
89
+ # Load weights from numpy
90
+ load_tf_weights_in_aragpt2(model, config, aragpt2_checkpoint_path)
91
+
92
+ # Save pytorch-model
93
+ pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
94
+ pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
95
+ print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
96
+ torch.save(model.state_dict(), pytorch_weights_dump_path)
97
+ print("Save configuration file to {}".format(pytorch_config_dump_path))
98
+ with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
99
+ f.write(config.to_json_string())
100
+
101
+
102
+ # XXX: Must be called like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
103
+ # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
104
+ def load_tf_weights_in_aragpt2(model, config, aragpt2_checkpoint_path):
105
+ """Load tf checkpoints in a pytorch model"""
106
+ try:
107
+ import re
108
+ import tensorflow as tf
109
+ except ImportError:
110
+ logger.error(
111
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
112
+ "https://www.tensorflow.org/install/ for installation instructions."
113
+ )
114
+ raise
115
+ tf_path = os.path.abspath(aragpt2_checkpoint_path)
116
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
117
+ # Load weights from TF model
118
+ init_vars = tf.train.list_variables(tf_path)
119
+ names = []
120
+ arrays = []
121
+ for name, shape in init_vars:
122
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
123
+ array = tf.train.load_variable(tf_path, name)
124
+ names.append(name)
125
+ arrays.append(array.squeeze())
126
+
127
+ import copy
128
+
129
+ orig_model = copy.deepcopy(model)
130
+
131
+ for name, array in zip(names, arrays):
132
+ name = name[6:] # skip "model/"
133
+ name = name.split("/")
134
+ pointer = model
135
+
136
+ attn_layer = ""
137
+ for m_name in name:
138
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
139
+ scope_names = re.split(r"(\d+)", m_name)
140
+ else:
141
+ scope_names = [m_name]
142
+ sname = scope_names[0]
143
+
144
+ if sname == "" or sname == "embeddings":
145
+ continue
146
+ elif sname not in _ARAGPT2_ML_TF_TO_TORCH:
147
+ print("=========================================================")
148
+ logger.info("Skip var name {}".format(scope_names))
149
+ pointer = None
150
+ break
151
+ else:
152
+ tname = _ARAGPT2_ML_TF_TO_TORCH[sname]
153
+ if "." in tname:
154
+ parent, child = tname.split(".")
155
+ pointer = getattr(pointer, parent)
156
+ pointer = getattr(pointer, child)
157
+ else:
158
+ pointer = getattr(pointer, tname)
159
+
160
+ if tname == "attn.c_attn":
161
+ attn_layer = sname
162
+
163
+ if len(scope_names) >= 2:
164
+ num = int(scope_names[1])
165
+ pointer = pointer[num]
166
+
167
+ if pointer is None:
168
+ continue
169
+ if attn_layer == "":
170
+ try:
171
+ assert pointer.shape == array.shape
172
+ except AssertionError as e:
173
+ e.args += (pointer.shape, array.shape)
174
+ raise
175
+ logger.info(
176
+ "Initialize PyTorch weight {}, {}, {}".format(
177
+ name, array.mean(), pointer.mean()
178
+ )
179
+ )
180
+ if attn_layer == "":
181
+ pointer.data = torch.from_numpy(array)
182
+ else:
183
+ shape = pointer.shape
184
+ d = torch.from_numpy(array)
185
+ is_bias = len(shape) == 1
186
+ end = int(shape[0 if is_bias else 1] / 3)
187
+ m = dict(
188
+ query_layer=0,
189
+ key_layer=end,
190
+ value_layer=end * 2,
191
+ )
192
+ start = m[attn_layer]
193
+ end = start + end
194
+ if is_bias:
195
+ pointer.data[start:end] = d
196
+ else:
197
+ pointer.data[:, start:end] = d
198
+ logger.info(
199
+ "Initialize PyTorch weight {}, {}, {}".format(
200
+ name, array.mean(), pointer.mean()
201
+ )
202
+ )
203
+
204
+ for name, params in orig_model.named_parameters():
205
+ for n, p in model.named_parameters():
206
+ if name == n:
207
+ if params.equal(p):
208
+ print("--------------------------")
209
+ print(" %s not changed!" % n)
210
+ return model
211
+
212
+
213
+ class AraGPT2Attention(nn.Module):
214
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
215
+ super().__init__()
216
+
217
+ max_positions = config.max_position_embeddings
218
+ self.register_buffer(
219
+ "bias",
220
+ torch.tril(
221
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
222
+ ).view(1, 1, max_positions, max_positions),
223
+ persistent=False,
224
+ )
225
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
226
+
227
+ self.embed_dim = config.hidden_size
228
+ self.num_heads = config.num_attention_heads
229
+ self.head_dim = self.embed_dim // self.num_heads
230
+ self.split_size = self.embed_dim
231
+ if self.head_dim * self.num_heads != self.embed_dim:
232
+ raise ValueError(
233
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
234
+ f" {self.num_heads})."
235
+ )
236
+
237
+ self.scale_attn_weights = config.scale_attn_weights
238
+ self.is_cross_attention = is_cross_attention
239
+
240
+ # Layer-wise attention scaling, reordering, and upcasting
241
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
242
+ self.layer_idx = layer_idx
243
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
244
+
245
+ if self.is_cross_attention:
246
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
247
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
248
+ else:
249
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
250
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
251
+
252
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
253
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
254
+
255
+ self.pruned_heads = set()
256
+
257
+ def prune_heads(self, heads):
258
+ if len(heads) == 0:
259
+ return
260
+ heads, index = find_pruneable_heads_and_indices(
261
+ heads, self.num_heads, self.head_dim, self.pruned_heads
262
+ )
263
+ index_attn = torch.cat(
264
+ [index, index + self.split_size, index + (2 * self.split_size)]
265
+ )
266
+
267
+ # Prune conv1d layers
268
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
269
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
270
+
271
+ # Update hyper params
272
+ self.split_size = (self.split_size // self.num_heads) * (
273
+ self.num_heads - len(heads)
274
+ )
275
+ self.num_heads = self.num_heads - len(heads)
276
+ self.pruned_heads = self.pruned_heads.union(heads)
277
+
278
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
279
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
280
+
281
+ if self.scale_attn_weights:
282
+ attn_weights = attn_weights / torch.full(
283
+ [],
284
+ value.size(-1) ** 0.5,
285
+ dtype=attn_weights.dtype,
286
+ device=attn_weights.device,
287
+ )
288
+
289
+ # Layer-wise attention scaling
290
+ if self.scale_attn_by_inverse_layer_idx:
291
+ attn_weights = attn_weights / float(self.layer_idx + 1)
292
+
293
+ if not self.is_cross_attention:
294
+ # if only "normal" attention layer implements causal mask
295
+ query_length, key_length = query.size(-2), key.size(-2)
296
+ causal_mask = self.bias[
297
+ :, :, key_length - query_length : key_length, :key_length
298
+ ]
299
+ mask_value = torch.finfo(attn_weights.dtype).min
300
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
301
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
302
+ mask_value = torch.full(
303
+ [], mask_value, dtype=attn_weights.dtype, device=attn_weights.device
304
+ )
305
+ attn_weights = torch.where(
306
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
307
+ )
308
+
309
+ if attention_mask is not None:
310
+ # Apply the attention mask
311
+ attn_weights = attn_weights + attention_mask
312
+
313
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
314
+
315
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
316
+ attn_weights = attn_weights.type(value.dtype)
317
+ attn_weights = self.attn_dropout(attn_weights)
318
+
319
+ # Mask heads if we want to
320
+ if head_mask is not None:
321
+ attn_weights = attn_weights * head_mask
322
+
323
+ attn_output = torch.matmul(attn_weights, value)
324
+
325
+ return attn_output, attn_weights
326
+
327
+ def _upcast_and_reordered_attn(
328
+ self, query, key, value, attention_mask=None, head_mask=None
329
+ ):
330
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
331
+ bsz, num_heads, q_seq_len, dk = query.size()
332
+ _, _, k_seq_len, _ = key.size()
333
+
334
+ # Preallocate attn_weights for `baddbmm`
335
+ attn_weights = torch.empty(
336
+ bsz * num_heads,
337
+ q_seq_len,
338
+ k_seq_len,
339
+ dtype=torch.float32,
340
+ device=query.device,
341
+ )
342
+
343
+ # Compute Scale Factor
344
+ scale_factor = 1.0
345
+ if self.scale_attn_weights:
346
+ scale_factor /= float(value.size(-1)) ** 0.5
347
+
348
+ if self.scale_attn_by_inverse_layer_idx:
349
+ scale_factor /= float(self.layer_idx + 1)
350
+
351
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
352
+ with autocast(enabled=False):
353
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
354
+ -1, dk, k_seq_len
355
+ )
356
+ attn_weights = torch.baddbmm(
357
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
358
+ )
359
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
360
+
361
+ if not self.is_cross_attention:
362
+ # if only "normal" attention layer implements causal mask
363
+ query_length, key_length = query.size(-2), key.size(-2)
364
+ causal_mask = self.bias[
365
+ :, :, key_length - query_length : key_length, :key_length
366
+ ]
367
+ mask_value = torch.finfo(attn_weights.dtype).min
368
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
369
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
370
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
371
+ attn_weights.device
372
+ )
373
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
374
+
375
+ if attention_mask is not None:
376
+ # Apply the attention mask
377
+ attn_weights = attn_weights + attention_mask
378
+
379
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
380
+
381
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
382
+ if attn_weights.dtype != torch.float32:
383
+ raise RuntimeError(
384
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
385
+ )
386
+ attn_weights = attn_weights.type(value.dtype)
387
+ attn_weights = self.attn_dropout(attn_weights)
388
+
389
+ # Mask heads if we want to
390
+ if head_mask is not None:
391
+ attn_weights = attn_weights * head_mask
392
+
393
+ attn_output = torch.matmul(attn_weights, value)
394
+
395
+ return attn_output, attn_weights
396
+
397
+ def _split_heads(self, tensor, num_heads, attn_head_size):
398
+ """
399
+ Splits hidden_size dim into attn_head_size and num_heads
400
+ """
401
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
402
+ tensor = tensor.view(new_shape)
403
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
404
+
405
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
406
+ """
407
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
408
+ """
409
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
410
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
411
+ return tensor.view(new_shape)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
416
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
417
+ attention_mask: Optional[torch.FloatTensor] = None,
418
+ head_mask: Optional[torch.FloatTensor] = None,
419
+ encoder_hidden_states: Optional[torch.Tensor] = None,
420
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
421
+ use_cache: Optional[bool] = False,
422
+ output_attentions: Optional[bool] = False,
423
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
424
+ if encoder_hidden_states is not None:
425
+ if not hasattr(self, "q_attn"):
426
+ raise ValueError(
427
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
428
+ "Please make sure to instantiate class with `AraGPT2Attention(..., is_cross_attention=True)`."
429
+ )
430
+
431
+ query = self.q_attn(hidden_states)
432
+ key, value = self.c_attn(encoder_hidden_states).split(
433
+ self.split_size, dim=2
434
+ )
435
+ attention_mask = encoder_attention_mask
436
+ else:
437
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
438
+
439
+ query = self._split_heads(query, self.num_heads, self.head_dim)
440
+ key = self._split_heads(key, self.num_heads, self.head_dim)
441
+ value = self._split_heads(value, self.num_heads, self.head_dim)
442
+
443
+ if layer_past is not None:
444
+ past_key, past_value = layer_past
445
+ key = torch.cat((past_key, key), dim=-2)
446
+ value = torch.cat((past_value, value), dim=-2)
447
+
448
+ if use_cache is True:
449
+ present = (key, value)
450
+ else:
451
+ present = None
452
+
453
+ if self.reorder_and_upcast_attn:
454
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
455
+ query, key, value, attention_mask, head_mask
456
+ )
457
+ else:
458
+ attn_output, attn_weights = self._attn(
459
+ query, key, value, attention_mask, head_mask
460
+ )
461
+
462
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
463
+ attn_output = self.c_proj(attn_output)
464
+ attn_output = self.resid_dropout(attn_output)
465
+
466
+ outputs = (attn_output, present)
467
+ if output_attentions:
468
+ outputs += (attn_weights,)
469
+
470
+ return outputs # a, present, (attentions)
471
+
472
+
473
+ class AraGPT2MLP(nn.Module):
474
+ def __init__(self, intermediate_size, config):
475
+ super().__init__()
476
+ embed_dim = config.hidden_size
477
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
478
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
479
+ self.act = ACT2FN[config.activation_function]
480
+ self.dropout = nn.Dropout(config.resid_pdrop)
481
+
482
+ def forward(
483
+ self, hidden_states: Optional[Tuple[torch.FloatTensor]]
484
+ ) -> torch.FloatTensor:
485
+ hidden_states = self.c_fc(hidden_states)
486
+ hidden_states = self.act(hidden_states)
487
+ hidden_states = self.c_proj(hidden_states)
488
+ hidden_states = self.dropout(hidden_states)
489
+ return hidden_states
490
+
491
+
492
+ class AraGPT2Block(nn.Module):
493
+ def __init__(self, config, layer_idx=None):
494
+ super().__init__()
495
+ hidden_size = config.hidden_size
496
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
497
+
498
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
499
+ self.attn = AraGPT2Attention(config, layer_idx=layer_idx)
500
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
501
+
502
+ if config.add_cross_attention:
503
+ self.crossattention = AraGPT2Attention(
504
+ config, is_cross_attention=True, layer_idx=layer_idx
505
+ )
506
+ self.ln_cross_attn = nn.LayerNorm(
507
+ hidden_size, eps=config.layer_norm_epsilon
508
+ )
509
+
510
+ self.mlp = AraGPT2MLP(inner_dim, config)
511
+
512
+ def forward(
513
+ self,
514
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
515
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
516
+ attention_mask: Optional[torch.FloatTensor] = None,
517
+ head_mask: Optional[torch.FloatTensor] = None,
518
+ encoder_hidden_states: Optional[torch.Tensor] = None,
519
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
520
+ use_cache: Optional[bool] = False,
521
+ output_attentions: Optional[bool] = False,
522
+ ) -> Union[
523
+ Tuple[torch.Tensor],
524
+ Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
525
+ ]:
526
+
527
+ # removed in GROVER
528
+ # residual = hidden_states
529
+ # hidden_states = self.ln_1(hidden_states)
530
+ attn_outputs = self.attn(
531
+ hidden_states,
532
+ layer_past=layer_past,
533
+ attention_mask=attention_mask,
534
+ head_mask=head_mask,
535
+ use_cache=use_cache,
536
+ output_attentions=output_attentions,
537
+ )
538
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
539
+ outputs = attn_outputs[1:]
540
+ # residual connection
541
+ hidden_states = attn_output + hidden_states
542
+
543
+ if encoder_hidden_states is not None:
544
+ # add one self-attention block for cross-attention
545
+ if not hasattr(self, "crossattention"):
546
+ raise ValueError(
547
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
548
+ "cross-attention layers by setting `config.add_cross_attention=True`"
549
+ )
550
+ # removed in GROVER
551
+ # residual = hidden_states
552
+ # hidden_states = self.ln_cross_attn(hidden_states)
553
+ cross_attn_outputs = self.crossattention(
554
+ hidden_states,
555
+ attention_mask=attention_mask,
556
+ head_mask=head_mask,
557
+ encoder_hidden_states=encoder_hidden_states,
558
+ encoder_attention_mask=encoder_attention_mask,
559
+ output_attentions=output_attentions,
560
+ )
561
+ attn_output = cross_attn_outputs[0]
562
+ # residual connection
563
+ hidden_states = attn_output + hidden_states
564
+ outputs = (
565
+ outputs + cross_attn_outputs[2:]
566
+ ) # add cross attentions if we output attention weights
567
+
568
+ residual = hidden_states
569
+ hidden_states = self.ln_1(hidden_states)
570
+ feed_forward_hidden_states = self.mlp(hidden_states)
571
+ # residual connection
572
+ hidden_states = residual + feed_forward_hidden_states
573
+
574
+ hidden_states = self.ln_2(hidden_states) # Added in GROVER
575
+
576
+ if use_cache:
577
+ outputs = (hidden_states,) + outputs
578
+ else:
579
+ outputs = (hidden_states,) + outputs[1:]
580
+
581
+ return outputs # hidden_states, present, (attentions, cross_attentions)
582
+
583
+
584
+ class AraGPT2PreTrainedModel(PreTrainedModel):
585
+ """
586
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
587
+ models.
588
+ """
589
+
590
+ config_class = AraGPT2Config
591
+ load_tf_weights = load_tf_weights_in_aragpt2
592
+ base_model_prefix = "transformer"
593
+ is_parallelizable = True
594
+ supports_gradient_checkpointing = True
595
+ _no_split_modules = ["AraGPT2Block"]
596
+ _skip_keys_device_placement = "past_key_values"
597
+
598
+ def __init__(self, *inputs, **kwargs):
599
+ super().__init__(*inputs, **kwargs)
600
+
601
+ def _init_weights(self, module):
602
+ """Initialize the weights."""
603
+ if isinstance(module, (nn.Linear, Conv1D)):
604
+ # Slightly different from the TF version which uses truncated_normal for initialization
605
+ # cf https://github.com/pytorch/pytorch/pull/5617
606
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
607
+ if module.bias is not None:
608
+ module.bias.data.zero_()
609
+ elif isinstance(module, nn.Embedding):
610
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
611
+ if module.padding_idx is not None:
612
+ module.weight.data[module.padding_idx].zero_()
613
+ elif isinstance(module, nn.LayerNorm):
614
+ module.bias.data.zero_()
615
+ module.weight.data.fill_(1.0)
616
+
617
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
618
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
619
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
620
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
621
+ #
622
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
623
+ for name, p in module.named_parameters():
624
+ if "c_proj" in name and "weight" in name:
625
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
626
+ p.data.normal_(
627
+ mean=0.0,
628
+ std=(
629
+ self.config.initializer_range
630
+ / math.sqrt(2 * self.config.n_layer)
631
+ ),
632
+ )
633
+
634
+
635
+ @dataclass
636
+ class AraGPT2DoubleHeadsModelOutput(ModelOutput):
637
+ """
638
+ Base class for outputs of models predicting if two sentences are consecutive or not.
639
+
640
+ Args:
641
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
642
+ Language modeling loss.
643
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
644
+ Multiple choice classification loss.
645
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
646
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
647
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
648
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
649
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
650
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
651
+ sequence_length, embed_size_per_head)`).
652
+
653
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
654
+ `past_key_values` input) to speed up sequential decoding.
655
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
656
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
657
+ shape `(batch_size, sequence_length, hidden_size)`.
658
+
659
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
660
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
661
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
662
+ sequence_length)`.
663
+
664
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
665
+ self-attention heads.
666
+ """
667
+
668
+ loss: Optional[torch.FloatTensor] = None
669
+ mc_loss: Optional[torch.FloatTensor] = None
670
+ logits: torch.FloatTensor = None
671
+ mc_logits: torch.FloatTensor = None
672
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
673
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
674
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
675
+
676
+
677
+ AraGPT2_START_DOCSTRING = r"""
678
+
679
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
680
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
681
+ etc.)
682
+
683
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
684
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
685
+ and behavior.
686
+
687
+ Parameters:
688
+ config ([`AraGPT2Config`]): Model configuration class with all the parameters of the model.
689
+ Initializing with a config file does not load the weights associated with the model, only the
690
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
691
+ """
692
+
693
+ GPT2_INPUTS_DOCSTRING = r"""
694
+ Args:
695
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
696
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
697
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
698
+ sequence tokens in the vocabulary.
699
+
700
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
701
+ `input_ids`.
702
+
703
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
704
+ [`PreTrainedTokenizer.__call__`] for details.
705
+
706
+ [What are input IDs?](../glossary#input-ids)
707
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
708
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
709
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
710
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
711
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
712
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
713
+
714
+ - 1 for tokens that are **not masked**,
715
+ - 0 for tokens that are **masked**.
716
+
717
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
718
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
719
+ `len(past_key_values) + len(input_ids)`
720
+
721
+ [What are attention masks?](../glossary#attention-mask)
722
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
723
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
724
+ 1]`:
725
+
726
+ - 0 corresponds to a *sentence A* token,
727
+ - 1 corresponds to a *sentence B* token.
728
+
729
+ [What are token type IDs?](../glossary#token-type-ids)
730
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
731
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
732
+ config.max_position_embeddings - 1]`.
733
+
734
+ [What are position IDs?](../glossary#position-ids)
735
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
736
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
737
+
738
+ - 1 indicates the head is **not masked**,
739
+ - 0 indicates the head is **masked**.
740
+
741
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
742
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
743
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
744
+ model's internal embedding lookup matrix.
745
+
746
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
747
+ `past_key_values`).
748
+ use_cache (`bool`, *optional*):
749
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
750
+ `past_key_values`).
751
+ output_attentions (`bool`, *optional*):
752
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
753
+ tensors for more detail.
754
+ output_hidden_states (`bool`, *optional*):
755
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
756
+ more detail.
757
+ return_dict (`bool`, *optional*):
758
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
759
+ """
760
+ PARALLELIZE_DOCSTRING = r"""
761
+ This is an experimental feature and is subject to change at a moment's notice.
762
+
763
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
764
+ it will evenly distribute blocks across all devices.
765
+
766
+ Args:
767
+ device_map (`Dict[int, list]`, optional, defaults to None):
768
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
769
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
770
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
771
+ following number of attention modules:
772
+
773
+ - aubmindlab/aragpt2-mega: 48
774
+
775
+ Example:
776
+
777
+ ```python
778
+ # Here is an example of a device map on a machine with 4 GPUs using aubmindlab/aragpt2-mega, which has a total of 48 attention modules:
779
+ model = AraGPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-mega")
780
+ device_map = {
781
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
782
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
783
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
784
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
785
+ }
786
+ model.parallelize(device_map)
787
+ ```
788
+ """
789
+ DEPARALLELIZE_DOCSTRING = r"""
790
+ Moves the model to cpu from a model parallel state.
791
+
792
+ Example:
793
+
794
+ ```python
795
+ # On a 4 GPU machine with aubmindlab/aragpt2-mega:
796
+ model = AraGPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-mega")
797
+ device_map = {
798
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
799
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
800
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
801
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
802
+ }
803
+ model.parallelize(device_map) # Splits the model across several devices
804
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
805
+ ```
806
+ """
807
+
808
+
809
+ @add_start_docstrings(
810
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
811
+ AraGPT2_START_DOCSTRING,
812
+ )
813
+ class AraGPT2Model(AraGPT2PreTrainedModel):
814
+ _keys_to_ignore_on_load_unexpected = ["attn.masked_bias"]
815
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
816
+
817
+ def __init__(self, config: AraGPT2Config):
818
+ super().__init__(config)
819
+
820
+ self.embed_dim = config.hidden_size
821
+
822
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
823
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
824
+ self.emb_norm = nn.LayerNorm(
825
+ config.n_embd, eps=config.layer_norm_epsilon
826
+ ) # Added in GROVER
827
+ self.drop = nn.Dropout(config.embd_pdrop)
828
+ self.h = nn.ModuleList(
829
+ [AraGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
830
+ )
831
+ # Removed in GROVER
832
+ # self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
833
+
834
+ # Model parallel
835
+ self.model_parallel = False
836
+ self.device_map = None
837
+ self.gradient_checkpointing = False
838
+
839
+ # Initialize weights and apply final processing
840
+ self.post_init()
841
+
842
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
843
+ def parallelize(self, device_map=None):
844
+ # Check validity of device_map
845
+ warnings.warn(
846
+ "`AraGPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
847
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
848
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
849
+ " ...}",
850
+ FutureWarning,
851
+ )
852
+ self.device_map = (
853
+ get_device_map(len(self.h), range(torch.cuda.device_count()))
854
+ if device_map is None
855
+ else device_map
856
+ )
857
+ assert_device_map(self.device_map, len(self.h))
858
+ self.model_parallel = True
859
+ self.first_device = (
860
+ "cpu"
861
+ if "cpu" in self.device_map.keys()
862
+ else "cuda:" + str(min(self.device_map.keys()))
863
+ )
864
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
865
+ self.wte = self.wte.to(self.first_device)
866
+ self.wpe = self.wpe.to(self.first_device)
867
+
868
+ # Added in GROVER
869
+ # Wissam: not sure whether it is fine on CPU or better on GPU
870
+ self.emb_norm = self.emb_norm.to(
871
+ "cuda:" + str(min(self.device_map.keys()))
872
+ ) # GPU
873
+ # self.emb_norm = self.emb_norm.to(self.first_device) # CPU
874
+
875
+ # Load onto devices
876
+ for k, v in self.device_map.items():
877
+ for block in v:
878
+ cuda_device = "cuda:" + str(k)
879
+ self.h[block] = self.h[block].to(cuda_device)
880
+ # ln_f to last
881
+ # Removed in GROVER
882
+ # self.ln_f = self.ln_f.to(self.last_device)
883
+
884
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
885
+ def deparallelize(self):
886
+ warnings.warn(
887
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
888
+ FutureWarning,
889
+ )
890
+ self.model_parallel = False
891
+ self.device_map = None
892
+ self.first_device = "cpu"
893
+ self.last_device = "cpu"
894
+ self.wte = self.wte.to("cpu")
895
+ self.wpe = self.wpe.to("cpu")
896
+ # Added in GROVER
897
+ self.emb_norm = self.emb_norm.to("cpu")
898
+ for index in range(len(self.h)):
899
+ self.h[index] = self.h[index].to("cpu")
900
+ # Removed in GROVER
901
+ # self.ln_f = self.ln_f.to("cpu")
902
+ torch.cuda.empty_cache()
903
+
904
+ def get_input_embeddings(self):
905
+ return self.wte
906
+
907
+ def set_input_embeddings(self, new_embeddings):
908
+ self.wte = new_embeddings
909
+
910
+ def _prune_heads(self, heads_to_prune):
911
+ """
912
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
913
+ """
914
+ for layer, heads in heads_to_prune.items():
915
+ self.h[layer].attn.prune_heads(heads)
916
+
917
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
918
+ @add_code_sample_docstrings(
919
+ processor_class=_TOKENIZER_FOR_DOC,
920
+ checkpoint=_CHECKPOINT_FOR_DOC,
921
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
922
+ config_class=_CONFIG_FOR_DOC,
923
+ )
924
+ def forward(
925
+ self,
926
+ input_ids: Optional[torch.LongTensor] = None,
927
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
928
+ attention_mask: Optional[torch.FloatTensor] = None,
929
+ token_type_ids: Optional[torch.LongTensor] = None,
930
+ position_ids: Optional[torch.LongTensor] = None,
931
+ head_mask: Optional[torch.FloatTensor] = None,
932
+ inputs_embeds: Optional[torch.FloatTensor] = None,
933
+ encoder_hidden_states: Optional[torch.Tensor] = None,
934
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
935
+ use_cache: Optional[bool] = None,
936
+ output_attentions: Optional[bool] = None,
937
+ output_hidden_states: Optional[bool] = None,
938
+ return_dict: Optional[bool] = None,
939
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
940
+ output_attentions = (
941
+ output_attentions
942
+ if output_attentions is not None
943
+ else self.config.output_attentions
944
+ )
945
+ output_hidden_states = (
946
+ output_hidden_states
947
+ if output_hidden_states is not None
948
+ else self.config.output_hidden_states
949
+ )
950
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
951
+ return_dict = (
952
+ return_dict if return_dict is not None else self.config.use_return_dict
953
+ )
954
+
955
+ if input_ids is not None and inputs_embeds is not None:
956
+ raise ValueError(
957
+ "You cannot specify both input_ids and inputs_embeds at the same time"
958
+ )
959
+ elif input_ids is not None:
960
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
961
+ input_shape = input_ids.size()
962
+ input_ids = input_ids.view(-1, input_shape[-1])
963
+ batch_size = input_ids.shape[0]
964
+ elif inputs_embeds is not None:
965
+ input_shape = inputs_embeds.size()[:-1]
966
+ batch_size = inputs_embeds.shape[0]
967
+ else:
968
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
969
+
970
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
971
+
972
+ if token_type_ids is not None:
973
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
974
+
975
+ if past_key_values is None:
976
+ past_length = 0
977
+ past_key_values = tuple([None] * len(self.h))
978
+ else:
979
+ past_length = past_key_values[0][0].size(-2)
980
+ if position_ids is None:
981
+ position_ids = torch.arange(
982
+ past_length,
983
+ input_shape[-1] + past_length,
984
+ dtype=torch.long,
985
+ device=device,
986
+ )
987
+ position_ids = position_ids.unsqueeze(0)
988
+
989
+ # AraGPT2Attention mask.
990
+ if attention_mask is not None:
991
+ if batch_size <= 0:
992
+ raise ValueError("batch_size has to be defined and > 0")
993
+ attention_mask = attention_mask.view(batch_size, -1)
994
+ # We create a 3D attention mask from a 2D tensor mask.
995
+ # Sizes are [batch_size, 1, 1, to_seq_length]
996
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
997
+ # this attention mask is more simple than the triangular masking of causal attention
998
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
999
+ attention_mask = attention_mask[:, None, None, :]
1000
+
1001
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1002
+ # masked positions, this operation will create a tensor which is 0.0 for
1003
+ # positions we want to attend and the dtype's smallest value for masked positions.
1004
+ # Since we are adding it to the raw scores before the softmax, this is
1005
+ # effectively the same as removing these entirely.
1006
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
1007
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
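+ # For example (illustrative): a padding mask of [1, 1, 0] becomes
+ # [0.0, 0.0, torch.finfo(dtype).min] after the two lines above, so padded
+ # positions are effectively removed from the pre-softmax attention scores.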
1008
+
1009
+ # If a 2D or 3D attention mask is provided for the cross-attention
1010
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1011
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
1012
+ encoder_batch_size, encoder_sequence_length, _ = (
1013
+ encoder_hidden_states.size()
1014
+ )
1015
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1016
+ if encoder_attention_mask is None:
1017
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1018
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1019
+ else:
1020
+ encoder_attention_mask = None
1021
+
1022
+ # Prepare head mask if needed
1023
+ # 1.0 in head_mask indicate we keep the head
1024
+ # attention_probs has shape bsz x n_heads x N x N
1025
+ # head_mask has shape n_layer x batch x n_heads x N x N
1026
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1027
+
1028
+ if inputs_embeds is None:
1029
+ inputs_embeds = self.wte(input_ids)
1030
+ position_embeds = self.wpe(position_ids)
1031
+ hidden_states = inputs_embeds + position_embeds
1032
+
1033
+ if token_type_ids is not None:
1034
+ token_type_embeds = self.wte(token_type_ids)
1035
+ hidden_states = hidden_states + token_type_embeds
1036
+
1037
+ hidden_states = self.drop(hidden_states)
1038
+ # Added in Grover
1039
+ hidden_states = self.emb_norm(hidden_states)
1040
+
1041
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
1042
+
1043
+ if self.gradient_checkpointing and self.training:
1044
+ if use_cache:
1045
+ logger.warning_once(
1046
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1047
+ )
1048
+ use_cache = False
1049
+
1050
+ presents = () if use_cache else None
1051
+ all_self_attentions = () if output_attentions else None
1052
+ all_cross_attentions = (
1053
+ () if output_attentions and self.config.add_cross_attention else None
1054
+ )
1055
+ all_hidden_states = () if output_hidden_states else None
1056
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1057
+ # Model parallel
1058
+ if self.model_parallel:
1059
+ torch.cuda.set_device(hidden_states.device)
1060
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
1061
+ if layer_past is not None:
1062
+ layer_past = tuple(
1063
+ past_state.to(hidden_states.device) for past_state in layer_past
1064
+ )
1065
+ # Ensure that attention_mask is always on the same device as hidden_states
1066
+ if attention_mask is not None:
1067
+ attention_mask = attention_mask.to(hidden_states.device)
1068
+ if isinstance(head_mask, torch.Tensor):
1069
+ head_mask = head_mask.to(hidden_states.device)
1070
+ if output_hidden_states:
1071
+ all_hidden_states = all_hidden_states + (hidden_states,)
1072
+
1073
+ if self.gradient_checkpointing and self.training:
1074
+ outputs = self._gradient_checkpointing_func(
1075
+ block.__call__,
1076
+ hidden_states,
1077
+ None,
1078
+ attention_mask,
1079
+ head_mask[i],
1080
+ encoder_hidden_states,
1081
+ encoder_attention_mask,
1082
+ use_cache,
1083
+ output_attentions,
1084
+ )
1085
+ else:
1086
+ outputs = block(
1087
+ hidden_states,
1088
+ layer_past=layer_past,
1089
+ attention_mask=attention_mask,
1090
+ head_mask=head_mask[i],
1091
+ encoder_hidden_states=encoder_hidden_states,
1092
+ encoder_attention_mask=encoder_attention_mask,
1093
+ use_cache=use_cache,
1094
+ output_attentions=output_attentions,
1095
+ )
1096
+
1097
+ hidden_states = outputs[0]
1098
+ if use_cache is True:
1099
+ presents = presents + (outputs[1],)
1100
+
1101
+ if output_attentions:
1102
+ all_self_attentions = all_self_attentions + (
1103
+ outputs[2 if use_cache else 1],
1104
+ )
1105
+ if self.config.add_cross_attention:
1106
+ all_cross_attentions = all_cross_attentions + (
1107
+ outputs[3 if use_cache else 2],
1108
+ )
1109
+
1110
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1111
+ if self.model_parallel:
1112
+ for k, v in self.device_map.items():
1113
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1114
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1115
+
1116
+ # Removed in Grover
1117
+ # hidden_states = self.ln_f(hidden_states)
1118
+
1119
+ hidden_states = hidden_states.view(output_shape)
1120
+ # Add last hidden state
1121
+ if output_hidden_states:
1122
+ all_hidden_states = all_hidden_states + (hidden_states,)
1123
+
1124
+ if not return_dict:
1125
+ return tuple(
1126
+ v
1127
+ for v in [
1128
+ hidden_states,
1129
+ presents,
1130
+ all_hidden_states,
1131
+ all_self_attentions,
1132
+ all_cross_attentions,
1133
+ ]
1134
+ if v is not None
1135
+ )
1136
+
1137
+ return BaseModelOutputWithPastAndCrossAttentions(
1138
+ last_hidden_state=hidden_states,
1139
+ past_key_values=presents,
1140
+ hidden_states=all_hidden_states,
1141
+ attentions=all_self_attentions,
1142
+ cross_attentions=all_cross_attentions,
1143
+ )
1144
+
1145
+
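+ # Usage sketch (illustrative, not part of the original file): obtaining the
+ # last hidden states from the bare AraGPT2Model. The tokenizer class and the
+ # "aubmindlab/aragpt2-mega" checkpoint follow the docstrings above; any other
+ # AraGPT2 checkpoint works the same way.
+ #
+ #     from transformers import GPT2TokenizerFast
+ #     tokenizer = GPT2TokenizerFast.from_pretrained("aubmindlab/aragpt2-mega")
+ #     model = AraGPT2Model.from_pretrained("aubmindlab/aragpt2-mega")
+ #     inputs = tokenizer("مرحبا بالعالم", return_tensors="pt")
+ #     last_hidden = model(**inputs).last_hidden_state  # (batch, seq_len, hidden_size)
+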
1146
+ @add_start_docstrings(
+ """
+ The AraGPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ AraGPT2_START_DOCSTRING,
+ )
1153
+ class AraGPT2LMHeadModel(AraGPT2PreTrainedModel):
1154
+ _keys_to_ignore_on_load_unexpected = [
1155
+ r"attn.masked_bias",
1156
+ r"attn.bias",
1157
+ r"lm_head.weight",
1158
+ ]
1159
+ _keys_to_ignore_on_load_missing = [
1160
+ r"attn.masked_bias",
1161
+ r"attn.bias",
1162
+ r"lm_head.weight",
1163
+ ]
1164
+ _tied_weights_keys = ["lm_head.weight"]
1165
+
1166
+ def __init__(self, config: AraGPT2Config):
1167
+ super().__init__(config)
1168
+ self.transformer = AraGPT2Model(config)
1169
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1170
+
1171
+ # Model parallel
1172
+ self.model_parallel = False
1173
+ self.device_map = None
1174
+
1175
+ # Initialize weights and apply final processing
1176
+ self.post_init()
1177
+
1178
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1179
+ def parallelize(self, device_map=None):
1180
+ warnings.warn(
1181
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1182
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1183
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1184
+ " 0, 'transformer.h.1': 1, ...}",
1185
+ FutureWarning,
1186
+ )
1187
+ self.device_map = (
1188
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1189
+ if device_map is None
1190
+ else device_map
1191
+ )
1192
+ assert_device_map(self.device_map, len(self.transformer.h))
1193
+ self.transformer.parallelize(self.device_map)
1194
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1195
+ self.model_parallel = True
1196
+
1197
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1198
+ def deparallelize(self):
1199
+ warnings.warn(
1200
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1201
+ FutureWarning,
1202
+ )
1203
+ self.transformer.deparallelize()
1204
+ self.transformer = self.transformer.to("cpu")
1205
+ self.lm_head = self.lm_head.to("cpu")
1206
+ self.model_parallel = False
1207
+ torch.cuda.empty_cache()
1208
+
1209
+ def get_output_embeddings(self):
1210
+ return self.lm_head
1211
+
1212
+ def set_output_embeddings(self, new_embeddings):
1213
+ self.lm_head = new_embeddings
1214
+
1215
+ def prepare_inputs_for_generation(
1216
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
1217
+ ):
1218
+ token_type_ids = kwargs.get("token_type_ids", None)
1219
+ # Omit tokens covered by past_key_values
1220
+ if past_key_values:
1221
+ past_length = past_key_values[0][0].shape[2]
1222
+
1223
+ # Some generation methods already pass only the last input ID
1224
+ if input_ids.shape[1] > past_length:
1225
+ remove_prefix_length = past_length
1226
+ else:
1227
+ # Default to old behavior: keep only final ID
1228
+ remove_prefix_length = input_ids.shape[1] - 1
1229
+
1230
+ input_ids = input_ids[:, remove_prefix_length:]
1231
+ if token_type_ids is not None:
1232
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1233
+
1234
+ attention_mask = kwargs.get("attention_mask", None)
1235
+ position_ids = kwargs.get("position_ids", None)
1236
+
1237
+ if attention_mask is not None and position_ids is None:
1238
+ # create position_ids on the fly for batch generation
1239
+ position_ids = attention_mask.long().cumsum(-1) - 1
1240
+ position_ids.masked_fill_(attention_mask == 0, 1)
1241
+ if past_key_values:
1242
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1243
+ else:
1244
+ position_ids = None
1245
+
1246
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1247
+ if inputs_embeds is not None and past_key_values is None:
1248
+ model_inputs = {"inputs_embeds": inputs_embeds}
1249
+ else:
1250
+ model_inputs = {"input_ids": input_ids}
1251
+
1252
+ model_inputs.update(
1253
+ {
1254
+ "past_key_values": past_key_values,
1255
+ "use_cache": kwargs.get("use_cache"),
1256
+ "position_ids": position_ids,
1257
+ "attention_mask": attention_mask,
1258
+ "token_type_ids": token_type_ids,
1259
+ }
1260
+ )
1261
+
1262
+ return model_inputs
1263
+
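+ # Generation sketch (illustrative): `generate()` calls
+ # `prepare_inputs_for_generation` above at every decoding step, feeding back
+ # `past_key_values` so only the newest token ids are re-encoded. The decoding
+ # settings here are arbitrary examples, not recommendations.
+ #
+ #     prompt_ids = tokenizer("مرحبا", return_tensors="pt").input_ids
+ #     output_ids = model.generate(
+ #         prompt_ids,
+ #         max_new_tokens=64,
+ #         do_sample=True,
+ #         top_p=0.9,
+ #         pad_token_id=tokenizer.eos_token_id,
+ #     )
+ #     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+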
1264
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1265
+ @add_code_sample_docstrings(
1266
+ processor_class=_TOKENIZER_FOR_DOC,
1267
+ checkpoint=_CHECKPOINT_FOR_DOC,
1268
+ output_type=CausalLMOutputWithCrossAttentions,
1269
+ config_class=_CONFIG_FOR_DOC,
1270
+ )
1271
+ def forward(
1272
+ self,
1273
+ input_ids: Optional[torch.LongTensor] = None,
1274
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1275
+ attention_mask: Optional[torch.FloatTensor] = None,
1276
+ token_type_ids: Optional[torch.LongTensor] = None,
1277
+ position_ids: Optional[torch.LongTensor] = None,
1278
+ head_mask: Optional[torch.FloatTensor] = None,
1279
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1280
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1281
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1282
+ labels: Optional[torch.LongTensor] = None,
1283
+ use_cache: Optional[bool] = None,
1284
+ output_attentions: Optional[bool] = None,
1285
+ output_hidden_states: Optional[bool] = None,
1286
+ return_dict: Optional[bool] = None,
1287
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1288
+ r"""
1289
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1290
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1291
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
1293
+ """
1294
+ return_dict = (
1295
+ return_dict if return_dict is not None else self.config.use_return_dict
1296
+ )
1297
+
1298
+ transformer_outputs = self.transformer(
1299
+ input_ids,
1300
+ past_key_values=past_key_values,
1301
+ attention_mask=attention_mask,
1302
+ token_type_ids=token_type_ids,
1303
+ position_ids=position_ids,
1304
+ head_mask=head_mask,
1305
+ inputs_embeds=inputs_embeds,
1306
+ encoder_hidden_states=encoder_hidden_states,
1307
+ encoder_attention_mask=encoder_attention_mask,
1308
+ use_cache=use_cache,
1309
+ output_attentions=output_attentions,
1310
+ output_hidden_states=output_hidden_states,
1311
+ return_dict=return_dict,
1312
+ )
1313
+ hidden_states = transformer_outputs[0]
1314
+
1315
+ # Set device for model parallelism
1316
+ if self.model_parallel:
1317
+ torch.cuda.set_device(self.transformer.first_device)
1318
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1319
+
1320
+ lm_logits = self.lm_head(hidden_states)
1321
+
1322
+ loss = None
1323
+ if labels is not None:
1324
+ # move labels to correct device to enable model parallelism
1325
+ labels = labels.to(lm_logits.device)
1326
+ # Shift so that tokens < n predict n
1327
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1328
+ shift_labels = labels[..., 1:].contiguous()
1329
+ # Flatten the tokens
1330
+ loss_fct = CrossEntropyLoss()
1331
+ loss = loss_fct(
1332
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1333
+ )
1334
+
1335
+ if not return_dict:
1336
+ output = (lm_logits,) + transformer_outputs[1:]
1337
+ return ((loss,) + output) if loss is not None else output
1338
+
1339
+ return CausalLMOutputWithCrossAttentions(
1340
+ loss=loss,
1341
+ logits=lm_logits,
1342
+ past_key_values=transformer_outputs.past_key_values,
1343
+ hidden_states=transformer_outputs.hidden_states,
1344
+ attentions=transformer_outputs.attentions,
1345
+ cross_attentions=transformer_outputs.cross_attentions,
1346
+ )
1347
+
1348
+ @staticmethod
1349
+ def _reorder_cache(
1350
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1351
+ ) -> Tuple[Tuple[torch.Tensor]]:
1352
+ """
1353
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1354
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1355
+ beam_idx at every generation step.
1356
+ """
1357
+ return tuple(
1358
+ tuple(
1359
+ past_state.index_select(0, beam_idx.to(past_state.device))
1360
+ for past_state in layer_past
1361
+ )
1362
+ for layer_past in past_key_values
1363
+ )
1364
+
1365
+
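+ # Loss sketch (illustrative): since the labels are shifted inside
+ # `AraGPT2LMHeadModel.forward`, a standard causal-LM loss can be computed by
+ # passing the input ids as labels:
+ #
+ #     batch = tokenizer("مثال قصير", return_tensors="pt")
+ #     out = model(input_ids=batch.input_ids, labels=batch.input_ids)
+ #     out.loss.backward()  # cross-entropy over next-token predictions
+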
1366
+ @add_start_docstrings(
+ """
+ The AraGPT2 Model transformer with a language modeling and a multiple-choice classification head on top, e.g. for
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
+ input embeddings; the classification head takes as input the hidden state at a specified classification token index
+ in the input sequence.
+ """,
+ AraGPT2_START_DOCSTRING,
+ )
1375
+ class AraGPT2DoubleHeadsModel(AraGPT2PreTrainedModel):
1376
+ _keys_to_ignore_on_load_unexpected = [
1377
+ r"attn.masked_bias",
1378
+ r"attn.bias",
1379
+ r"lm_head.weight",
1380
+ ]
1381
+ _keys_to_ignore_on_load_missing = [
1382
+ r"attn.masked_bias",
1383
+ r"attn.bias",
1384
+ r"lm_head.weight",
1385
+ ]
1386
+ _tied_weights_keys = ["lm_head.weight"]
1387
+
1388
+ def __init__(self, config: AraGPT2Config):
1389
+ super().__init__(config)
1390
+ config.num_labels = 1
1391
+ self.transformer = AraGPT2Model(config)
1392
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1393
+ self.multiple_choice_head = SequenceSummary(config)
1394
+
1395
+ # Model parallel
1396
+ self.model_parallel = False
1397
+ self.device_map = None
1398
+
1399
+ # Initialize weights and apply final processing
1400
+ self.post_init()
1401
+
1402
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1403
+ def parallelize(self, device_map=None):
1404
+ warnings.warn(
1405
+ "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1406
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1407
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1408
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1409
+ FutureWarning,
1410
+ )
1411
+ self.device_map = (
1412
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1413
+ if device_map is None
1414
+ else device_map
1415
+ )
1416
+ assert_device_map(self.device_map, len(self.transformer.h))
1417
+ self.transformer.parallelize(self.device_map)
1418
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1419
+ self.multiple_choice_head = self.multiple_choice_head.to(
1420
+ self.transformer.first_device
1421
+ )
1422
+ self.model_parallel = True
1423
+
1424
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1425
+ def deparallelize(self):
1426
+ warnings.warn(
1427
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1428
+ FutureWarning,
1429
+ )
1430
+ self.transformer.deparallelize()
1431
+ self.transformer = self.transformer.to("cpu")
1432
+ self.lm_head = self.lm_head.to("cpu")
1433
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1434
+ self.model_parallel = False
1435
+ torch.cuda.empty_cache()
1436
+
1437
+ def get_output_embeddings(self):
1438
+ return self.lm_head
1439
+
1440
+ def set_output_embeddings(self, new_embeddings):
1441
+ self.lm_head = new_embeddings
1442
+
1443
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
1444
+ token_type_ids = kwargs.get("token_type_ids", None)
1445
+ # Omit tokens covered by past_key_values
1446
+ if past_key_values:
1447
+ past_length = past_key_values[0][0].shape[2]
1448
+
1449
+ # Some generation methods already pass only the last input ID
1450
+ if input_ids.shape[1] > past_length:
1451
+ remove_prefix_length = past_length
1452
+ else:
1453
+ # Default to old behavior: keep only final ID
1454
+ remove_prefix_length = input_ids.shape[1] - 1
1455
+
1456
+ input_ids = input_ids[:, remove_prefix_length:]
1457
+ if token_type_ids is not None:
1458
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1459
+
1460
+ attention_mask = kwargs.get("attention_mask", None)
1461
+ position_ids = kwargs.get("position_ids", None)
1462
+
1463
+ if attention_mask is not None and position_ids is None:
1464
+ # create position_ids on the fly for batch generation
1465
+ position_ids = attention_mask.long().cumsum(-1) - 1
1466
+ position_ids.masked_fill_(attention_mask == 0, 1)
1467
+ if past_key_values:
1468
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1469
+ else:
1470
+ position_ids = None
1471
+
1472
+ return {
1473
+ "input_ids": input_ids,
1474
+ "past_key_values": past_key_values,
1475
+ "use_cache": kwargs.get("use_cache"),
1476
+ "position_ids": position_ids,
1477
+ "attention_mask": attention_mask,
1478
+ "token_type_ids": token_type_ids,
1479
+ }
1480
+
1481
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1482
+ @replace_return_docstrings(
1483
+ output_type=AraGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
1484
+ )
1485
+ def forward(
1486
+ self,
1487
+ input_ids: Optional[torch.LongTensor] = None,
1488
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1489
+ attention_mask: Optional[torch.FloatTensor] = None,
1490
+ token_type_ids: Optional[torch.LongTensor] = None,
1491
+ position_ids: Optional[torch.LongTensor] = None,
1492
+ head_mask: Optional[torch.FloatTensor] = None,
1493
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1494
+ mc_token_ids: Optional[torch.LongTensor] = None,
1495
+ labels: Optional[torch.LongTensor] = None,
1496
+ mc_labels: Optional[torch.LongTensor] = None,
1497
+ use_cache: Optional[bool] = None,
1498
+ output_attentions: Optional[bool] = None,
1499
+ output_hidden_states: Optional[bool] = None,
1500
+ return_dict: Optional[bool] = None,
1501
+ **kwargs,
1502
+ ) -> Union[Tuple, AraGPT2DoubleHeadsModelOutput]:
1503
+ r"""
1504
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
1505
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1506
+ 1]`.
1507
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1508
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1509
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1510
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1511
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1512
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
1513
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1514
+
1515
+ Return:
1516
+
1517
+ Example:
1518
+
1519
+ ```python
1520
+ >>> import torch
1521
+ >>> from transformers import GPT2Tokenizer
+
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("aubmindlab/aragpt2-mega")
+ >>> model = AraGPT2DoubleHeadsModel.from_pretrained("aubmindlab/aragpt2-mega")
1525
+
1526
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1527
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1528
+ >>> # Update the model embeddings with the new vocabulary size
1529
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1530
+
1531
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1532
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1533
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1534
+
1535
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1536
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1537
+
1538
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1539
+ >>> lm_logits = outputs.logits
1540
+ >>> mc_logits = outputs.mc_logits
1541
+ ```"""
1542
+ return_dict = (
1543
+ return_dict if return_dict is not None else self.config.use_return_dict
1544
+ )
1545
+
1546
+ transformer_outputs = self.transformer(
1547
+ input_ids,
1548
+ past_key_values=past_key_values,
1549
+ attention_mask=attention_mask,
1550
+ token_type_ids=token_type_ids,
1551
+ position_ids=position_ids,
1552
+ head_mask=head_mask,
1553
+ inputs_embeds=inputs_embeds,
1554
+ use_cache=use_cache,
1555
+ output_attentions=output_attentions,
1556
+ output_hidden_states=output_hidden_states,
1557
+ return_dict=return_dict,
1558
+ )
1559
+
1560
+ hidden_states = transformer_outputs[0]
1561
+
1562
+ # Set device for model parallelism
1563
+ if self.model_parallel:
1564
+ torch.cuda.set_device(self.transformer.first_device)
1565
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1566
+
1567
+ lm_logits = self.lm_head(hidden_states)
1568
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1569
+
1570
+ mc_loss = None
1571
+ if mc_labels is not None:
1572
+ loss_fct = CrossEntropyLoss()
1573
+ mc_loss = loss_fct(
1574
+ mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1575
+ )
1576
+ lm_loss = None
1577
+ if labels is not None:
1578
+ labels = labels.to(lm_logits.device)
1579
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1580
+ shift_labels = labels[..., 1:].contiguous()
1581
+ loss_fct = CrossEntropyLoss()
1582
+ lm_loss = loss_fct(
1583
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1584
+ )
1585
+
1586
+ if not return_dict:
1587
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1588
+ if mc_loss is not None:
1589
+ output = (mc_loss,) + output
1590
+ return ((lm_loss,) + output) if lm_loss is not None else output
1591
+
1592
+ return AraGPT2DoubleHeadsModelOutput(
1593
+ loss=lm_loss,
1594
+ mc_loss=mc_loss,
1595
+ logits=lm_logits,
1596
+ mc_logits=mc_logits,
1597
+ past_key_values=transformer_outputs.past_key_values,
1598
+ hidden_states=transformer_outputs.hidden_states,
1599
+ attentions=transformer_outputs.attentions,
1600
+ )
1601
+
1602
+ @staticmethod
1603
+ def _reorder_cache(
1604
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1605
+ ) -> Tuple[Tuple[torch.Tensor]]:
1606
+ """
1607
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1608
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1609
+ beam_idx at every generation step.
1610
+ """
1611
+ return tuple(
1612
+ tuple(
1613
+ past_state.index_select(0, beam_idx.to(past_state.device))
1614
+ for past_state in layer_past
1615
+ )
1616
+ for layer_past in past_key_values
1617
+ )
1618
+
1619
+
1620
+ @add_start_docstrings(
+ """
+ The AraGPT2 Model transformer with a sequence classification head on top (linear layer).
+
+ [`AraGPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-1) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+ each row of the batch).
+ """,
+ AraGPT2_START_DOCSTRING,
+ )
1635
+ class AraGPT2ForSequenceClassification(AraGPT2PreTrainedModel):
1636
+ _keys_to_ignore_on_load_unexpected = [
1637
+ r"h\.\d+\.attn\.masked_bias",
1638
+ r"lm_head.weight",
1639
+ ]
1640
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
1641
+
1642
+ def __init__(self, config: AraGPT2Config):
1643
+ super().__init__(config)
1644
+ self.num_labels = config.num_labels
1645
+ self.transformer = AraGPT2Model(config)
1646
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1647
+
1648
+ # Model parallel
1649
+ self.model_parallel = False
1650
+ self.device_map = None
1651
+
1652
+ # Initialize weights and apply final processing
1653
+ self.post_init()
1654
+
1655
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1656
+ @add_code_sample_docstrings(
1657
+ processor_class=_TOKENIZER_FOR_DOC,
1658
+ output_type=SequenceClassifierOutputWithPast,
1659
+ config_class=_CONFIG_FOR_DOC,
1660
+ )
1661
+ def forward(
1662
+ self,
1663
+ input_ids: Optional[torch.LongTensor] = None,
1664
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1665
+ attention_mask: Optional[torch.FloatTensor] = None,
1666
+ token_type_ids: Optional[torch.LongTensor] = None,
1667
+ position_ids: Optional[torch.LongTensor] = None,
1668
+ head_mask: Optional[torch.FloatTensor] = None,
1669
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1670
+ labels: Optional[torch.LongTensor] = None,
1671
+ use_cache: Optional[bool] = None,
1672
+ output_attentions: Optional[bool] = None,
1673
+ output_hidden_states: Optional[bool] = None,
1674
+ return_dict: Optional[bool] = None,
1675
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1676
+ r"""
1677
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1678
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1679
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1680
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1681
+ """
1682
+ return_dict = (
1683
+ return_dict if return_dict is not None else self.config.use_return_dict
1684
+ )
1685
+
1686
+ transformer_outputs = self.transformer(
1687
+ input_ids,
1688
+ past_key_values=past_key_values,
1689
+ attention_mask=attention_mask,
1690
+ token_type_ids=token_type_ids,
1691
+ position_ids=position_ids,
1692
+ head_mask=head_mask,
1693
+ inputs_embeds=inputs_embeds,
1694
+ use_cache=use_cache,
1695
+ output_attentions=output_attentions,
1696
+ output_hidden_states=output_hidden_states,
1697
+ return_dict=return_dict,
1698
+ )
1699
+ hidden_states = transformer_outputs[0]
1700
+ logits = self.score(hidden_states)
1701
+
1702
+ if input_ids is not None:
1703
+ batch_size, sequence_length = input_ids.shape[:2]
1704
+ else:
1705
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1706
+
1707
+ assert (
1708
+ self.config.pad_token_id is not None or batch_size == 1
1709
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1710
+ if self.config.pad_token_id is None:
1711
+ sequence_lengths = -1
1712
+ else:
1713
+ if input_ids is not None:
1714
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1715
+ sequence_lengths = (
1716
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1717
+ )
1718
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1719
+ sequence_lengths = sequence_lengths.to(logits.device)
1720
+ else:
1721
+ sequence_lengths = -1
1722
+ logger.warning(
1723
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1724
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1725
+ )
1726
+
1727
+ pooled_logits = logits[
1728
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1729
+ ]
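+ # For example (illustrative): with pad_token_id=0 and
+ # input_ids = [[7, 8, 9, 0, 0]], sequence_lengths is 2, so pooled_logits
+ # takes the logits at the last non-padding position of each row.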
1730
+
1731
+ loss = None
1732
+ if labels is not None:
1733
+ if self.config.problem_type is None:
1734
+ if self.num_labels == 1:
1735
+ self.config.problem_type = "regression"
1736
+ elif self.num_labels > 1 and (
1737
+ labels.dtype == torch.long or labels.dtype == torch.int
1738
+ ):
1739
+ self.config.problem_type = "single_label_classification"
1740
+ else:
1741
+ self.config.problem_type = "multi_label_classification"
1742
+
1743
+ if self.config.problem_type == "regression":
1744
+ loss_fct = MSELoss()
1745
+ if self.num_labels == 1:
1746
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1747
+ else:
1748
+ loss = loss_fct(pooled_logits, labels)
1749
+ elif self.config.problem_type == "single_label_classification":
1750
+ loss_fct = CrossEntropyLoss()
1751
+ loss = loss_fct(
1752
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1753
+ )
1754
+ elif self.config.problem_type == "multi_label_classification":
1755
+ loss_fct = BCEWithLogitsLoss()
1756
+ loss = loss_fct(pooled_logits, labels)
1757
+ if not return_dict:
1758
+ output = (pooled_logits,) + transformer_outputs[1:]
1759
+ return ((loss,) + output) if loss is not None else output
1760
+
1761
+ return SequenceClassifierOutputWithPast(
1762
+ loss=loss,
1763
+ logits=pooled_logits,
1764
+ past_key_values=transformer_outputs.past_key_values,
1765
+ hidden_states=transformer_outputs.hidden_states,
1766
+ attentions=transformer_outputs.attentions,
1767
+ )
1768
+
1769
+
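+ # Usage sketch (illustrative): sequence classification pools the logits at the
+ # last non-padding token, so a `pad_token_id` must be set on the config when
+ # batching. The label count below is an arbitrary example.
+ #
+ #     model = AraGPT2ForSequenceClassification.from_pretrained(
+ #         "aubmindlab/aragpt2-mega", num_labels=2
+ #     )
+ #     model.config.pad_token_id = model.config.eos_token_id
+ #     batch = tokenizer(["نص أول", "نص ثان أطول قليلا"], padding=True, return_tensors="pt")
+ #     logits = model(**batch).logits  # (batch_size, num_labels)
+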
1770
+ @add_start_docstrings(
+ """
+ AraGPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output), e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ AraGPT2_START_DOCSTRING,
+ )
1777
+ class AraGPT2ForTokenClassification(AraGPT2PreTrainedModel):
1778
+ def __init__(self, config: AraGPT2Config):
1779
+ super().__init__(config)
1780
+ self.num_labels = config.num_labels
1781
+
1782
+ self.transformer = AraGPT2Model(config)
1783
+ if (
1784
+ hasattr(config, "classifier_dropout")
1785
+ and config.classifier_dropout is not None
1786
+ ):
1787
+ classifier_dropout = config.classifier_dropout
1788
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1789
+ classifier_dropout = config.hidden_dropout
1790
+ else:
1791
+ classifier_dropout = 0.1
1792
+ self.dropout = nn.Dropout(classifier_dropout)
1793
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1794
+
1795
+ # Model parallel
1796
+ self.model_parallel = False
1797
+ self.device_map = None
1798
+
1799
+ # Initialize weights and apply final processing
1800
+ self.post_init()
1801
+
1802
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1803
+ # fmt: off
1804
+ @add_code_sample_docstrings(
1805
+ processor_class=_TOKENIZER_FOR_DOC,
1806
+ output_type=TokenClassifierOutput,
1807
+ config_class=_CONFIG_FOR_DOC,
1808
+ )
1809
+ # fmt: on
1810
+ def forward(
1811
+ self,
1812
+ input_ids: Optional[torch.LongTensor] = None,
1813
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1814
+ attention_mask: Optional[torch.FloatTensor] = None,
1815
+ token_type_ids: Optional[torch.LongTensor] = None,
1816
+ position_ids: Optional[torch.LongTensor] = None,
1817
+ head_mask: Optional[torch.FloatTensor] = None,
1818
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1819
+ labels: Optional[torch.LongTensor] = None,
1820
+ use_cache: Optional[bool] = None,
1821
+ output_attentions: Optional[bool] = None,
1822
+ output_hidden_states: Optional[bool] = None,
1823
+ return_dict: Optional[bool] = None,
1824
+ ) -> Union[Tuple, TokenClassifierOutput]:
1825
+ r"""
1826
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1827
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1828
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1829
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1830
+ """
1831
+ return_dict = (
1832
+ return_dict if return_dict is not None else self.config.use_return_dict
1833
+ )
1834
+
1835
+ transformer_outputs = self.transformer(
1836
+ input_ids,
1837
+ past_key_values=past_key_values,
1838
+ attention_mask=attention_mask,
1839
+ token_type_ids=token_type_ids,
1840
+ position_ids=position_ids,
1841
+ head_mask=head_mask,
1842
+ inputs_embeds=inputs_embeds,
1843
+ use_cache=use_cache,
1844
+ output_attentions=output_attentions,
1845
+ output_hidden_states=output_hidden_states,
1846
+ return_dict=return_dict,
1847
+ )
1848
+
1849
+ hidden_states = transformer_outputs[0]
1850
+ hidden_states = self.dropout(hidden_states)
1851
+ logits = self.classifier(hidden_states)
1852
+
1853
+ loss = None
1854
+ if labels is not None:
1855
+ labels = labels.to(logits.device)
1856
+ loss_fct = CrossEntropyLoss()
1857
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1858
+
1859
+ if not return_dict:
1860
+ output = (logits,) + transformer_outputs[2:]
1861
+ return ((loss,) + output) if loss is not None else output
1862
+
1863
+ return TokenClassifierOutput(
1864
+ loss=loss,
1865
+ logits=logits,
1866
+ hidden_states=transformer_outputs.hidden_states,
1867
+ attentions=transformer_outputs.attentions,
1868
+ )
1869
+
1870
+
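+ # Usage sketch (illustrative): token classification returns one logit vector
+ # per input token; `num_labels` below is an arbitrary example.
+ #
+ #     model = AraGPT2ForTokenClassification.from_pretrained(
+ #         "aubmindlab/aragpt2-mega", num_labels=9
+ #     )
+ #     batch = tokenizer("ولد جبران خليل جبران في بشري", return_tensors="pt")
+ #     logits = model(**batch).logits  # (batch_size, seq_len, num_labels)
+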
1871
+ @add_start_docstrings(
1872
+ """
1873
+ The AraGPT2 Model transformer with a span classification head on top for extractive question-answering tasks like
1874
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1875
+ """,
1876
+ AraGPT2_START_DOCSTRING,
1877
+ )
1878
+ class AraGPT2ForQuestionAnswering(AraGPT2PreTrainedModel):
1879
+ def __init__(self, config: AraGPT2Config):
1880
+ super().__init__(config)
1881
+ self.num_labels = config.num_labels
1882
+ self.transformer = AraGPT2Model(config)
1883
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1884
+
1885
+ # Model parallel
1886
+ self.model_parallel = False
1887
+ self.device_map = None
1888
+
1889
+ # Initialize weights and apply final processing
1890
+ self.post_init()
1891
+
1892
+ @add_start_docstrings_to_model_forward(
1893
+ GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1894
+ )
1895
+ @add_code_sample_docstrings(
1896
+ checkpoint=_CHECKPOINT_FOR_DOC,
1897
+ output_type=QuestionAnsweringModelOutput,
1898
+ config_class=_CONFIG_FOR_DOC,
1899
+ real_checkpoint=_CHECKPOINT_FOR_DOC,
1900
+ )
1901
+ def forward(
1902
+ self,
1903
+ input_ids: Optional[torch.LongTensor] = None,
1904
+ attention_mask: Optional[torch.FloatTensor] = None,
1905
+ token_type_ids: Optional[torch.LongTensor] = None,
1906
+ position_ids: Optional[torch.LongTensor] = None,
1907
+ head_mask: Optional[torch.FloatTensor] = None,
1908
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1909
+ start_positions: Optional[torch.LongTensor] = None,
1910
+ end_positions: Optional[torch.LongTensor] = None,
1911
+ output_attentions: Optional[bool] = None,
1912
+ output_hidden_states: Optional[bool] = None,
1913
+ return_dict: Optional[bool] = None,
1914
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1915
+ r"""
1916
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1917
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1918
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1919
+ are not taken into account for computing the loss.
1920
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1921
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1922
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1923
+ are not taken into account for computing the loss.
1924
+ """
1925
+ return_dict = (
1926
+ return_dict if return_dict is not None else self.config.use_return_dict
1927
+ )
1928
+
1929
+ outputs = self.transformer(
1930
+ input_ids,
1931
+ attention_mask=attention_mask,
1932
+ token_type_ids=token_type_ids,
1933
+ position_ids=position_ids,
1934
+ head_mask=head_mask,
1935
+ inputs_embeds=inputs_embeds,
1936
+ output_attentions=output_attentions,
1937
+ output_hidden_states=output_hidden_states,
1938
+ return_dict=return_dict,
1939
+ )
1940
+
1941
+ sequence_output = outputs[0]
1942
+
1943
+ logits = self.qa_outputs(sequence_output)
1944
+ start_logits, end_logits = logits.split(1, dim=-1)
1945
+ start_logits = start_logits.squeeze(-1).contiguous()
1946
+ end_logits = end_logits.squeeze(-1).contiguous()
1947
+
1948
+ total_loss = None
1949
+ if start_positions is not None and end_positions is not None:
1950
+ # If we are on multi-GPU, the split may add an extra dimension; squeeze it
1951
+ if len(start_positions.size()) > 1:
1952
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1953
+ if len(end_positions.size()) > 1:
1954
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1955
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1956
+ ignored_index = start_logits.size(1)
1957
+ start_positions = start_positions.clamp(0, ignored_index)
1958
+ end_positions = end_positions.clamp(0, ignored_index)
1959
+
1960
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1961
+ start_loss = loss_fct(start_logits, start_positions)
1962
+ end_loss = loss_fct(end_logits, end_positions)
1963
+ total_loss = (start_loss + end_loss) / 2
1964
+
1965
+ if not return_dict:
1966
+ output = (start_logits, end_logits) + outputs[2:]
1967
+ return ((total_loss,) + output) if total_loss is not None else output
1968
+
1969
+ return QuestionAnsweringModelOutput(
1970
+ loss=total_loss,
1971
+ start_logits=start_logits,
1972
+ end_logits=end_logits,
1973
+ hidden_states=outputs.hidden_states,
1974
+ attentions=outputs.attentions,
1975
+ )
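+
+ # Usage sketch (illustrative): extractive question answering predicts start and
+ # end logits over the concatenated question/context tokens.
+ #
+ #     model = AraGPT2ForQuestionAnswering.from_pretrained("aubmindlab/aragpt2-mega")
+ #     batch = tokenizer("من هو المؤلف؟", "المؤلف هو جبران خليل جبران.", return_tensors="pt")
+ #     out = model(**batch)
+ #     start = out.start_logits.argmax(-1)
+ #     end = out.end_logits.argmax(-1)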