ydshieh commited on
Commit
a244e91
1 Parent(s): 8364b8b
.gitattributes CHANGED
@@ -16,3 +16,5 @@
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
18
  wit_data_dir/train/train.tsv filter=lfs diff=lfs merge=lfs -text
 
 
 
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
18
  wit_data_dir/train/train.tsv filter=lfs diff=lfs merge=lfs -text
19
+ wit_data_dir/dev/dev.tsv filter=lfs diff=lfs merge=lfs -text
20
+ wit_data_dir/test/test.tsv filter=lfs diff=lfs merge=lfs -text
generate.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+
7
+
8
+ # Vit - as encoder
9
+ from transformers import ViTFeatureExtractor
10
+ from PIL import Image
11
+ import requests
12
+ import numpy as np
13
+
14
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
15
+ image = Image.open(requests.get(url, stream=True).raw)
16
+
17
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
18
+ encoder_inputs = feature_extractor(images=image, return_tensors="jax")
19
+ pixel_values = encoder_inputs.pixel_values
20
+
21
+ # GPT2 / GPT2LM - as decoder
22
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
23
+
24
+ name = 'asi/gpt-fr-cased-small'
25
+ tokenizer = GPT2Tokenizer.from_pretrained(name)
26
+ decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax", )
27
+ print(decoder_inputs)
28
+
29
+ # Setup the tokenizer for targets
30
+ with tokenizer.as_target_tokenizer():
31
+ labels = tokenizer(
32
+ ['un chien super beau' + ' ' + tokenizer.eos_token, 'un chat' + ' ' + tokenizer.eos_token], max_length=5, padding="max_length", truncation=True, return_tensors="np"
33
+ )
34
+ print(labels)
35
+ exit(0)
36
+
37
+ inputs = dict(decoder_inputs)
38
+ inputs['pixel_values'] = pixel_values
39
+ #print(inputs)
40
+
41
+
42
+ # With the LM head in GPT2LM
43
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
44
+ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./outputs-small-ds/ckpt_3',)
45
+
46
+ logits = flax_vit_gpt2_lm(**inputs)[0]
47
+ preds = np.argmax(logits, axis=-1)
48
+ print('=' * 60)
49
+ print('Flax: Vit + modified GPT2LM')
50
+ #print(preds)
51
+
52
+ max_length = 32
53
+ num_beams = 16
54
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
55
+ batch = {'pixel_values': pixel_values}
56
+ generation = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
57
+ print(generation)
58
+
59
+ token_ids = np.array(generation.sequences)[0]
60
+ generation = tokenizer.decode(token_ids)
61
+ print(generation)
62
+
63
+ del flax_vit_gpt2_lm
run_summarization.py ADDED
@@ -0,0 +1,832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for summarization.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import sys, os
22
+
23
+ current_path = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append(current_path)
25
+
26
+ import logging
27
+ import os
28
+ import sys
29
+ import time
30
+ from dataclasses import dataclass, field
31
+ from functools import partial
32
+ from pathlib import Path
33
+ from typing import Callable, Optional
34
+
35
+ import datasets
36
+ import nltk # Here to have a nice missing dependency error message early on
37
+ import numpy as np
38
+ from datasets import Dataset, load_dataset, load_metric
39
+ from tqdm import tqdm
40
+
41
+ import jax
42
+ import jax.numpy as jnp
43
+ import optax
44
+ import transformers
45
+ from filelock import FileLock
46
+ from flax import jax_utils, traverse_util
47
+ from flax.jax_utils import unreplicate
48
+ from flax.training import train_state
49
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
50
+ from transformers import (
51
+ CONFIG_MAPPING,
52
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
53
+ AutoConfig,
54
+ AutoTokenizer,
55
+ FlaxAutoModelForSeq2SeqLM,
56
+ HfArgumentParser,
57
+ TrainingArguments,
58
+ is_tensorboard_available,
59
+ )
60
+ from transformers.file_utils import is_offline_mode
61
+
62
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
63
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
64
+
65
+ logger = logging.getLogger(__name__)
66
+
67
+ try:
68
+ nltk.data.find("tokenizers/punkt")
69
+ except (LookupError, OSError):
70
+ if is_offline_mode():
71
+ raise LookupError(
72
+ "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
73
+ )
74
+ with FileLock(".lock") as lock:
75
+ nltk.download("punkt", quiet=True)
76
+
77
+
78
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
79
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
80
+
81
+
82
+ @dataclass
83
+ class ModelArguments:
84
+ """
85
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
86
+ """
87
+
88
+ model_name_or_path: Optional[str] = field(
89
+ default=None,
90
+ metadata={
91
+ "help": "The model checkpoint for weights initialization."
92
+ "Don't set if you want to train a model from scratch."
93
+ },
94
+ )
95
+ model_type: Optional[str] = field(
96
+ default=None,
97
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
98
+ )
99
+ config_name: Optional[str] = field(
100
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
101
+ )
102
+ tokenizer_name: Optional[str] = field(
103
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
104
+ )
105
+ cache_dir: Optional[str] = field(
106
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
107
+ )
108
+ use_fast_tokenizer: bool = field(
109
+ default=True,
110
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
111
+ )
112
+ dtype: Optional[str] = field(
113
+ default="float32",
114
+ metadata={
115
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
116
+ },
117
+ )
118
+
119
+
120
+ @dataclass
121
+ class DataTrainingArguments:
122
+ """
123
+ Arguments pertaining to what data we are going to input our model for training and eval.
124
+ """
125
+
126
+ dataset_name: Optional[str] = field(
127
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
128
+ )
129
+ dataset_config_name: Optional[str] = field(
130
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
131
+ )
132
+ text_column: Optional[str] = field(
133
+ default=None,
134
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
135
+ )
136
+ summary_column: Optional[str] = field(
137
+ default=None,
138
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
139
+ )
140
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
141
+ validation_file: Optional[str] = field(
142
+ default=None,
143
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
144
+ )
145
+ max_source_length: Optional[int] = field(
146
+ default=1024,
147
+ metadata={
148
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
149
+ "than this will be truncated, sequences shorter will be padded."
150
+ },
151
+ )
152
+ max_target_length: Optional[int] = field(
153
+ default=128,
154
+ metadata={
155
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
156
+ "than this will be truncated, sequences shorter will be padded."
157
+ },
158
+ )
159
+ val_max_target_length: Optional[int] = field(
160
+ default=None,
161
+ metadata={
162
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
163
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
164
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
165
+ "during evaluation."
166
+ },
167
+ )
168
+ max_train_samples: Optional[int] = field(
169
+ default=None,
170
+ metadata={
171
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
172
+ "value if set."
173
+ },
174
+ )
175
+ max_eval_samples: Optional[int] = field(
176
+ default=None,
177
+ metadata={
178
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
179
+ "value if set."
180
+ },
181
+ )
182
+ max_predict_samples: Optional[int] = field(
183
+ default=None,
184
+ metadata={
185
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
186
+ "value if set."
187
+ },
188
+ )
189
+ preprocessing_num_workers: Optional[int] = field(
190
+ default=None,
191
+ metadata={"help": "The number of processes to use for the preprocessing."},
192
+ )
193
+ source_prefix: Optional[str] = field(
194
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
195
+ )
196
+ predict_with_generate: bool = field(
197
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
198
+ )
199
+ num_beams: Optional[int] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
203
+ "which is used during evaluation."
204
+ },
205
+ )
206
+ overwrite_cache: bool = field(
207
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
208
+ )
209
+
210
+ def __post_init__(self):
211
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
212
+ raise ValueError("Need either a dataset name or a training/validation file.")
213
+ else:
214
+ if self.train_file is not None:
215
+ extension = self.train_file.split(".")[-1]
216
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
217
+ if self.validation_file is not None:
218
+ extension = self.validation_file.split(".")[-1]
219
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
220
+ if self.val_max_target_length is None:
221
+ self.val_max_target_length = self.max_target_length
222
+
223
+
224
+ summarization_name_mapping = {
225
+ "amazon_reviews_multi": ("review_body", "review_title"),
226
+ "big_patent": ("description", "abstract"),
227
+ "cnn_dailymail": ("article", "highlights"),
228
+ "orange_sum": ("text", "summary"),
229
+ "pn_summary": ("article", "summary"),
230
+ "psc": ("extract_text", "summary_text"),
231
+ "samsum": ("dialogue", "summary"),
232
+ "thaisum": ("body", "summary"),
233
+ "xglue": ("news_body", "news_title"),
234
+ "xsum": ("document", "summary"),
235
+ "wiki_summary": ("article", "highlights"),
236
+ }
237
+
238
+
239
+ class TrainState(train_state.TrainState):
240
+ dropout_rng: jnp.ndarray
241
+
242
+ def replicate(self):
243
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
244
+
245
+
246
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
247
+ """
248
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
249
+ Shuffle batches if `shuffle` is `True`.
250
+ """
251
+ steps_per_epoch = len(dataset) // batch_size
252
+
253
+ if shuffle:
254
+ batch_idx = jax.random.permutation(rng, len(dataset))
255
+ else:
256
+ batch_idx = jnp.arange(len(dataset))
257
+
258
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
259
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
260
+
261
+ for idx in batch_idx:
262
+ batch = dataset[idx]
263
+ batch = {k: jnp.array(v) for k, v in batch.items()}
264
+
265
+ batch = shard(batch)
266
+
267
+ yield batch
268
+
269
+
270
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
271
+ summary_writer.scalar("train_time", train_time, step)
272
+
273
+ train_metrics = get_metrics(train_metrics)
274
+ for key, vals in train_metrics.items():
275
+ tag = f"train_{key}"
276
+ for i, val in enumerate(vals):
277
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
278
+
279
+ for metric_name, value in eval_metrics.items():
280
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
281
+
282
+
283
+ def create_learning_rate_fn(
284
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
285
+ ) -> Callable[[int], jnp.array]:
286
+ """Returns a linear warmup, linear_decay learning rate function."""
287
+ steps_per_epoch = train_ds_size // train_batch_size
288
+ num_train_steps = steps_per_epoch * num_train_epochs
289
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
290
+ decay_fn = optax.linear_schedule(
291
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
292
+ )
293
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
294
+ return schedule_fn
295
+
296
+
297
+ def main():
298
+ # See all possible arguments in src/transformers/training_args.py
299
+ # or by passing the --help flag to this script.
300
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
301
+
302
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
303
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
304
+ # If we pass only one argument to the script and it's the path to a json file,
305
+ # let's parse it to get our arguments.
306
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
307
+ else:
308
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
309
+
310
+ if (
311
+ os.path.exists(training_args.output_dir)
312
+ and os.listdir(training_args.output_dir)
313
+ and training_args.do_train
314
+ and not training_args.overwrite_output_dir
315
+ ):
316
+ raise ValueError(
317
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
318
+ "Use --overwrite_output_dir to overcome."
319
+ )
320
+
321
+ # Make one log on every process with the configuration for debugging.
322
+ logging.basicConfig(
323
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
324
+ datefmt="%m/%d/%Y %H:%M:%S",
325
+ level=logging.INFO,
326
+ )
327
+ # Setup logging, we only want one process per machine to log things on the screen.
328
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
329
+ if jax.process_index() == 0:
330
+ datasets.utils.logging.set_verbosity_warning()
331
+ transformers.utils.logging.set_verbosity_info()
332
+ else:
333
+ datasets.utils.logging.set_verbosity_error()
334
+ transformers.utils.logging.set_verbosity_error()
335
+
336
+ # Set the verbosity to info of the Transformers logger (on main process only):
337
+ logger.info(f"Training/evaluation parameters {training_args}")
338
+
339
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
340
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
341
+ # (the dataset will be downloaded automatically from the datasets Hub).
342
+ #
343
+ # For CSV/JSON files this script will use the first column for the full texts and the second column for the
344
+ # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
345
+ #
346
+ if data_args.dataset_name is not None:
347
+ # Downloading and loading a dataset from the hub.
348
+ dataset = load_dataset(
349
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir='./wit_data_dir/'
350
+ )
351
+ else:
352
+ data_files = {}
353
+ if data_args.train_file is not None:
354
+ data_files["train"] = data_args.train_file
355
+ extension = data_args.train_file.split(".")[-1]
356
+ if data_args.validation_file is not None:
357
+ data_files["validation"] = data_args.validation_file
358
+ extension = data_args.validation_file.split(".")[-1]
359
+ if data_args.test_file is not None:
360
+ data_files["test"] = data_args.test_file
361
+ extension = data_args.test_file.split(".")[-1]
362
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
363
+
364
+ vit_name_path = 'google/vit-base-patch16-224-in21k'
365
+ gpt2_name_path = 'asi/gpt-fr-cased-small'
366
+
367
+ gpt2_config = GPT2Config.from_pretrained(gpt2_name_path)
368
+ gpt2_config.add_cross_attention = True
369
+
370
+
371
+ vit_gpt2_name_path = ''
372
+
373
+ feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name_path)
374
+
375
+ tokenizer = GPT2Tokenizer.from_pretrained(gpt2_name_path)
376
+
377
+ if not vit_gpt2_name_path:
378
+ assert vit_name_path
379
+ assert gpt2_name_path
380
+ vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
381
+ vit_name_path, gpt2_name_path
382
+ )
383
+ else:
384
+ vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
385
+ vit_gpt2_name_path
386
+ )
387
+
388
+ model = vit_gpt2_model
389
+ model.config.is_encoder_decoder = True
390
+ model.config.decoder_start_token_id = gpt2_config.bos_token_id
391
+ model.config.bos_token_id = gpt2_config.bos_token_id
392
+ model.config.eos_token_id = gpt2_config.eos_token_id
393
+ model.config.pad_token_id = gpt2_config.pad_token_id
394
+
395
+ # Preprocessing the datasets.
396
+ # We need to tokenize inputs and targets.
397
+ if training_args.do_train:
398
+ column_names = dataset["train"].column_names
399
+ elif training_args.do_eval:
400
+ column_names = dataset["validation"].column_names
401
+ elif training_args.do_predict:
402
+ column_names = dataset["test"].column_names
403
+ else:
404
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
405
+ return
406
+
407
+ image_file_column = 'image_file'
408
+ caption_column = 'caption'
409
+ pixels_file_column = 'pixels_file'
410
+
411
+ # Temporarily set max_target_length for training.
412
+ max_target_length = data_args.max_target_length
413
+
414
+ # In Flax, for seq2seq models we need to pass `decoder_input_ids`
415
+ # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
416
+ # for that dynamically import the `shift_tokens_right` function from the model file
417
+ model_module = __import__(vit_gpt2_model.__module__, fromlist=["shift_tokens_right"])
418
+ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
419
+
420
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
421
+ def preprocess_function(examples):
422
+
423
+ pixels_file = examples[pixels_file_column]
424
+ if not pixels_file:
425
+ assert examples[image_file_column]
426
+ _pixel_values = []
427
+ for y in examples[image_file_column]:
428
+ with Image.open(y) as image:
429
+ encoder_inputs = feature_extractor(images=image, return_tensors="np")
430
+ x = encoder_inputs.pixel_values
431
+ _pixel_values.append(x)
432
+ pixel_values = np.concatenate(_pixel_values)
433
+ else:
434
+ pixel_values = np.concatenate([np.load(x) for x in pixels_file])
435
+
436
+ targets = examples[caption_column]
437
+
438
+ # Add eos_token!!
439
+ targets = [x + ' ' + tokenizer.eos_token for x in targets]
440
+
441
+ model_inputs = {}
442
+ model_inputs['pixel_values'] = pixel_values
443
+
444
+ # Setup the tokenizer for targets
445
+ with tokenizer.as_target_tokenizer():
446
+ labels = tokenizer(
447
+ targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
448
+ )
449
+
450
+ model_inputs["labels"] = labels["input_ids"]
451
+
452
+ #print(labels["input_ids"])
453
+ #print(gpt2_config.pad_token_id)
454
+ #rint(gpt2_config.bos_token_id)
455
+
456
+ decoder_input_ids = shift_tokens_right_fn(
457
+ jnp.array(labels["input_ids"]), gpt2_config.pad_token_id, gpt2_config.bos_token_id
458
+ )
459
+ model_inputs["input_ids"] = np.asarray(decoder_input_ids)
460
+
461
+ # We need decoder_attention_mask so we can ignore pad tokens from loss
462
+ model_inputs["attention_mask"] = labels["attention_mask"]
463
+
464
+ return model_inputs
465
+
466
+ if training_args.do_train:
467
+ if "train" not in dataset:
468
+ raise ValueError("--do_train requires a train dataset")
469
+ train_dataset = dataset["train"]
470
+ if data_args.max_train_samples is not None:
471
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
472
+
473
+ train_dataset = train_dataset.map(
474
+ preprocess_function,
475
+ batched=True,
476
+ num_proc=data_args.preprocessing_num_workers,
477
+ remove_columns=column_names,
478
+ load_from_cache_file=not data_args.overwrite_cache,
479
+ desc="Running tokenizer on train dataset",
480
+ )
481
+
482
+ if training_args.do_eval:
483
+ max_target_length = data_args.val_max_target_length
484
+ if "validation" not in dataset:
485
+ raise ValueError("--do_eval requires a validation dataset")
486
+ eval_dataset = dataset["validation"]
487
+ if data_args.max_eval_samples is not None:
488
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
489
+ eval_dataset = eval_dataset.map(
490
+ preprocess_function,
491
+ batched=True,
492
+ num_proc=data_args.preprocessing_num_workers,
493
+ remove_columns=column_names,
494
+ load_from_cache_file=not data_args.overwrite_cache,
495
+ desc="Running tokenizer on validation dataset",
496
+ )
497
+
498
+ if training_args.do_predict:
499
+ max_target_length = data_args.val_max_target_length
500
+ if "test" not in dataset:
501
+ raise ValueError("--do_predict requires a test dataset")
502
+ predict_dataset = dataset["test"]
503
+ if data_args.max_predict_samples is not None:
504
+ predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
505
+ predict_dataset = predict_dataset.map(
506
+ preprocess_function,
507
+ batched=True,
508
+ num_proc=data_args.preprocessing_num_workers,
509
+ remove_columns=column_names,
510
+ load_from_cache_file=not data_args.overwrite_cache,
511
+ desc="Running tokenizer on prediction dataset",
512
+ )
513
+
514
+ # Metric
515
+ metric = load_metric("rouge")
516
+
517
+ def postprocess_text(preds, labels):
518
+ preds = [pred.strip() for pred in preds]
519
+ labels = [label.strip() for label in labels]
520
+
521
+ # rougeLSum expects newline after each sentence
522
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
523
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
524
+
525
+ return preds, labels
526
+
527
+ def compute_metrics(preds, labels):
528
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
529
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
530
+
531
+ # Some simple post-processing
532
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
533
+
534
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
535
+ # Extract a few results from ROUGE
536
+ result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
537
+
538
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
539
+ result["gen_len"] = np.mean(prediction_lens)
540
+ result = {k: round(v, 4) for k, v in result.items()}
541
+ return result
542
+
543
+ # Enable tensorboard only on the master node
544
+ has_tensorboard = is_tensorboard_available()
545
+ if has_tensorboard and jax.process_index() == 0:
546
+ try:
547
+ from flax.metrics.tensorboard import SummaryWriter
548
+
549
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
550
+ except ImportError as ie:
551
+ has_tensorboard = False
552
+ logger.warning(
553
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
554
+ )
555
+ else:
556
+ logger.warning(
557
+ "Unable to display metrics through TensorBoard because the package is not installed: "
558
+ "Please run pip install tensorboard to enable."
559
+ )
560
+
561
+ # Initialize our training
562
+ rng = jax.random.PRNGKey(training_args.seed)
563
+ rng, dropout_rng = jax.random.split(rng)
564
+
565
+ # Store some constant
566
+ num_epochs = int(training_args.num_train_epochs)
567
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
568
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
569
+ steps_per_epoch = len(train_dataset) // train_batch_size
570
+ total_train_steps = steps_per_epoch * num_epochs
571
+
572
+ # Create learning rate schedule
573
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
574
+ len(train_dataset),
575
+ train_batch_size,
576
+ training_args.num_train_epochs,
577
+ training_args.warmup_steps,
578
+ training_args.learning_rate,
579
+ )
580
+
581
+ # We use Optax's "masking" functionality to not apply weight decay
582
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
583
+ # mask boolean with the same structure as the parameters.
584
+ # The mask is True for parameters that should be decayed.
585
+ # Note that this mask is specifically adapted for FlaxBart.
586
+ # For FlaxT5, one should correct the layer norm parameter naming
587
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
588
+ def decay_mask_fn(params):
589
+ flat_params = traverse_util.flatten_dict(params)
590
+ layer_norm_params = [
591
+ (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
592
+ ]
593
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
594
+ return traverse_util.unflatten_dict(flat_mask)
595
+
596
+ # create adam optimizer
597
+ adamw = optax.adamw(
598
+ learning_rate=linear_decay_lr_schedule_fn,
599
+ b1=training_args.adam_beta1,
600
+ b2=training_args.adam_beta2,
601
+ eps=training_args.adam_epsilon,
602
+ weight_decay=training_args.weight_decay,
603
+ mask=decay_mask_fn,
604
+ )
605
+
606
+ # Setup train state
607
+ state = TrainState.create(apply_fn=vit_gpt2_model.__call__, params=vit_gpt2_model.params, tx=adamw, dropout_rng=dropout_rng)
608
+
609
+ # label smoothed cross entropy
610
+ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
611
+ """
612
+ The label smoothing implementation is adapted from Flax's official example:
613
+ https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
614
+ """
615
+ vocab_size = logits.shape[-1]
616
+ confidence = 1.0 - label_smoothing_factor
617
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
618
+ normalizing_constant = -(
619
+ confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
620
+ )
621
+ soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
622
+
623
+ loss = optax.softmax_cross_entropy(logits, soft_labels)
624
+ loss = loss - normalizing_constant
625
+
626
+ # ignore padded tokens from loss
627
+ loss = loss * padding_mask
628
+ loss = loss.sum() / padding_mask.sum()
629
+ return loss
630
+
631
+ # Define gradient update step fn
632
+ def train_step(state, batch, label_smoothing_factor=0.0):
633
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
634
+
635
+ def compute_loss(params):
636
+ labels = batch.pop("labels")
637
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
638
+ loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
639
+ return loss
640
+
641
+ grad_fn = jax.value_and_grad(compute_loss)
642
+ loss, grad = grad_fn(state.params)
643
+ grad = jax.lax.pmean(grad, "batch")
644
+
645
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
646
+
647
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
648
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
649
+
650
+ return new_state, metrics
651
+
652
+ # Define eval fn
653
+ def eval_step(params, batch, label_smoothing_factor=0.0):
654
+ labels = batch.pop("labels")
655
+ logits = model(**batch, params=params, train=False)[0]
656
+ loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
657
+
658
+ # summarize metrics
659
+ metrics = {"loss": loss}
660
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
661
+ return metrics
662
+
663
+ # Define generation function
664
+ max_length = (
665
+ data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
666
+ )
667
+ num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
668
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
669
+
670
+ def generate_step(params, batch):
671
+ model.params = params
672
+ # output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
673
+
674
+ #encoder_outputs = model.encode(pixel_values=batch['pixel_values'])
675
+ #output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], encoder_outputs=encoder_outputs, **gen_kwargs)
676
+
677
+ # encoder_outputs = model.encode(pixel_values=batch['pixel_values'], params=params, train=False)
678
+ output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
679
+
680
+
681
+ return output_ids.sequences
682
+
683
+ # Create parallel version of the train and eval step
684
+ p_train_step = jax.pmap(
685
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
686
+ )
687
+ p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
688
+ p_generate_step = jax.pmap(generate_step, "batch")
689
+
690
+ # Replicate the train state on each device
691
+ state = state.replicate()
692
+
693
+ logger.info("***** Running training *****")
694
+ logger.info(f" Num examples = {len(train_dataset)}")
695
+ logger.info(f" Num Epochs = {num_epochs}")
696
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
697
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
698
+ logger.info(f" Total optimization steps = {total_train_steps}")
699
+
700
+ train_time = 0
701
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
702
+ for epoch in epochs:
703
+ # ======================== Training ================================
704
+ train_start = time.time()
705
+
706
+ # Create sampling rng
707
+ rng, input_rng = jax.random.split(rng)
708
+ train_metrics = []
709
+
710
+ # Generate an epoch by shuffling sampling indices from the train dataset
711
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
712
+ steps_per_epoch = len(train_dataset) // train_batch_size
713
+ # train
714
+ for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
715
+ batch = next(train_loader)
716
+ state, train_metric = p_train_step(state, batch)
717
+ train_metrics.append(train_metric)
718
+
719
+ train_time += time.time() - train_start
720
+
721
+ train_metric = unreplicate(train_metric)
722
+
723
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
724
+ epochs.write(desc)
725
+ epochs.desc = desc
726
+ logger.info(desc)
727
+ with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
728
+ fp.write(desc + '\n')
729
+
730
+
731
+ # ======================== Evaluating ==============================
732
+ eval_metrics = []
733
+ eval_preds = []
734
+ eval_labels = []
735
+
736
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
737
+ eval_steps = len(eval_dataset) // eval_batch_size
738
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
739
+ # Model forward
740
+ batch = next(eval_loader)
741
+ labels = batch["labels"]
742
+
743
+ metrics = p_eval_step(state.params, batch)
744
+ eval_metrics.append(metrics)
745
+
746
+ # generation
747
+ if data_args.predict_with_generate:
748
+ generated_ids = p_generate_step(state.params, batch)
749
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
750
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
751
+
752
+ # normalize eval metrics
753
+ eval_metrics = get_metrics(eval_metrics)
754
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
755
+
756
+ # compute ROUGE metrics
757
+ rouge_desc = ""
758
+ if data_args.predict_with_generate:
759
+ rouge_metrics = compute_metrics(eval_preds, eval_labels)
760
+ eval_metrics.update(rouge_metrics)
761
+ rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
762
+
763
+ # Print metrics and update progress bar
764
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
765
+ epochs.write(desc)
766
+ epochs.desc = desc
767
+ logger.info(desc)
768
+ with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
769
+ fp.write(desc + '\n')
770
+
771
+
772
+ # Save metrics
773
+ if has_tensorboard and jax.process_index() == 0:
774
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
775
+ write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
776
+
777
+ # ======================== Prediction loop ==============================
778
+ if training_args.do_predict:
779
+ logger.info("*** Predict ***")
780
+
781
+ pred_metrics = []
782
+ pred_generations = []
783
+ pred_labels = []
784
+
785
+ pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
786
+ pred_steps = len(predict_dataset) // eval_batch_size
787
+ for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
788
+ # Model forward
789
+ batch = next(pred_loader)
790
+ labels = batch["labels"]
791
+
792
+ metrics = p_eval_step(state.params, batch)
793
+ pred_metrics.append(metrics)
794
+
795
+ # generation
796
+ if data_args.predict_with_generate:
797
+ generated_ids = p_generate_step(state.params, batch)
798
+ pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
799
+ pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
800
+
801
+ # normalize prediction metrics
802
+ pred_metrics = get_metrics(pred_metrics)
803
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
804
+
805
+ # compute ROUGE metrics
806
+ rouge_desc = ""
807
+ if data_args.predict_with_generate:
808
+ rouge_metrics = compute_metrics(pred_generations, pred_labels)
809
+ pred_metrics.update(rouge_metrics)
810
+ rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
811
+
812
+ # Print metrics
813
+ desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
814
+ epochs.write(desc)
815
+ epochs.desc = desc
816
+ logger.info(desc)
817
+ with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
818
+ fp.write(desc + '\n')
819
+
820
+
821
+ # save checkpoint after each epoch and push checkpoint to the hub
822
+ if jax.process_index() == 0:
823
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
824
+ model.save_pretrained(
825
+ os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'),
826
+ params=params,
827
+ push_to_hub=training_args.push_to_hub,
828
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
829
+ )
830
+
831
+ if __name__ == "__main__":
832
+ main()
test_vit_gpt2.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+ # Vit - as encoder
7
+ from transformers import ViTFeatureExtractor
8
+ from PIL import Image
9
+ import requests
10
+ import numpy as np
11
+
12
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
13
+ image = Image.open(requests.get(url, stream=True).raw)
14
+
15
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
16
+ encoder_inputs = feature_extractor(images=image, return_tensors="jax")
17
+ pixel_values = encoder_inputs.pixel_values
18
+
19
+ # GPT2 / GPT2LM - as decoder
20
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
21
+
22
+ name = 'asi/gpt-fr-cased-small'
23
+ tokenizer = GPT2Tokenizer.from_pretrained(name)
24
+ decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
25
+
26
+ inputs = dict(decoder_inputs)
27
+ inputs['pixel_values'] = pixel_values
28
+ print(inputs)
29
+
30
+ # With new added LM head
31
+ from vit_gpt2.modeling_flax_vit_gpt2 import FlaxViTGPT2ForConditionalGeneration
32
+ flax_vit_gpt2 = FlaxViTGPT2ForConditionalGeneration.from_vit_gpt2_pretrained(
33
+ 'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
34
+ )
35
+ logits = flax_vit_gpt2(**inputs)[0]
36
+ preds = np.argmax(logits, axis=-1)
37
+ print('=' * 60)
38
+ print('Flax: Vit + modified GPT2 + LM')
39
+ print(preds)
40
+
41
+ del flax_vit_gpt2
42
+
43
+ # With the LM head in GPT2LM
44
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
45
+ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
46
+ 'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
47
+ )
48
+
49
+ logits = flax_vit_gpt2_lm(**inputs)[0]
50
+ preds = np.argmax(logits, axis=-1)
51
+ print('=' * 60)
52
+ print('Flax: Vit + modified GPT2LM')
53
+ print(preds)
54
+
55
+ del flax_vit_gpt2_lm
56
+
57
+ # With PyTorch [Vit + unmodified GPT2LMHeadModel]
58
+ import torch
59
+ from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
60
+
61
+ vit_model_pt = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
62
+ encoder_inputs = feature_extractor(images=image, return_tensors="pt")
63
+ vit_outputs = vit_model_pt(**encoder_inputs)
64
+ vit_last_hidden_states = vit_outputs.last_hidden_state
65
+
66
+ del vit_model_pt
67
+
68
+ inputs_pt = tokenizer("mon chien est mignon", return_tensors="pt")
69
+ inputs_pt = dict(inputs_pt)
70
+ inputs_pt['encoder_hidden_states'] = vit_last_hidden_states
71
+
72
+ config = GPT2Config.from_pretrained('asi/gpt-fr-cased-small')
73
+ config.add_cross_attention = True
74
+ gpt2_model_pt = GPT2LMHeadModel.from_pretrained('asi/gpt-fr-cased-small', config=config)
75
+
76
+ gp2lm_outputs = gpt2_model_pt(**inputs_pt)
77
+ logits_pt = gp2lm_outputs.logits
78
+ preds_pt = torch.argmax(logits_pt, dim=-1).cpu().detach().numpy()
79
+ print('=' * 60)
80
+ print('Pytorch: Vit + unmodified GPT2LM')
81
+ print(preds_pt)
82
+
83
+ del gpt2_model_pt
test_wit_dataset_script.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import os
4
+
5
+ import datasets
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ ds = datasets.load_dataset('./wit_dataset_script.py', data_dir='./wit_data_dir/')
10
+ test_ds = ds['test']
11
+
12
+
13
+ def transform(example):
14
+
15
+ example['pixel_values'] = np.load(example['pixels_file'])
16
+ return example
17
+
18
+
19
+ test_ds = test_ds.map(transform)
20
+
21
+ for x in test_ds:
22
+ print(x)
23
+ break
tests_load.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+ # Vit - as encoder
7
+ from transformers import ViTFeatureExtractor
8
+ from PIL import Image
9
+ import requests
10
+ import numpy as np
11
+
12
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
13
+ image = Image.open(requests.get(url, stream=True).raw)
14
+
15
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
16
+ encoder_inputs = feature_extractor(images=image, return_tensors="jax")
17
+ pixel_values = encoder_inputs.pixel_values
18
+
19
+ # GPT2 / GPT2LM - as decoder
20
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
21
+
22
+ name = 'asi/gpt-fr-cased-small'
23
+ tokenizer = GPT2Tokenizer.from_pretrained(name)
24
+ decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
25
+
26
+ inputs = dict(decoder_inputs)
27
+ inputs['pixel_values'] = pixel_values
28
+ print(inputs)
29
+
30
+
31
+
32
+
33
+
34
+ # With the LM head in GPT2LM
35
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
36
+ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
37
+ '.',
38
+ )
39
+
40
+ logits = flax_vit_gpt2_lm(**inputs)[0]
41
+ preds = np.argmax(logits, axis=-1)
42
+ print('=' * 60)
43
+ print('Flax: Vit + modified GPT2LM')
44
+ print(preds)
45
+
46
+ # flax_vit_gpt2_lm.save_pretrained('.')
47
+
48
+ del flax_vit_gpt2_lm
tests_save.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+ # Vit - as encoder
7
+ from transformers import ViTFeatureExtractor
8
+ from PIL import Image
9
+ import requests
10
+ import numpy as np
11
+
12
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
13
+ image = Image.open(requests.get(url, stream=True).raw)
14
+
15
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
16
+ encoder_inputs = feature_extractor(images=image, return_tensors="jax")
17
+ pixel_values = encoder_inputs.pixel_values
18
+
19
+ # GPT2 / GPT2LM - as decoder
20
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
21
+
22
+ name = 'asi/gpt-fr-cased-small'
23
+ tokenizer = GPT2Tokenizer.from_pretrained(name)
24
+ decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
25
+
26
+ inputs = dict(decoder_inputs)
27
+ inputs['pixel_values'] = pixel_values
28
+ print(inputs)
29
+
30
+
31
+
32
+
33
+
34
+ # With the LM head in GPT2LM
35
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
36
+ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
37
+ 'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
38
+ )
39
+
40
+ logits = flax_vit_gpt2_lm(**inputs)[0]
41
+ preds = np.argmax(logits, axis=-1)
42
+ print('=' * 60)
43
+ print('Flax: Vit + modified GPT2LM')
44
+ print(preds)
45
+
46
+ flax_vit_gpt2_lm.save_pretrained('.')
47
+
48
+ del flax_vit_gpt2_lm
vit_gpt2/__init__.py ADDED
File without changes
vit_gpt2/configuration_vit_gpt2.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from transformers import GPT2Config, ViTConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class ViTGPT2Config(PretrainedConfig):
11
+
12
+ model_type = "vit-gpt2"
13
+ is_composition = True
14
+
15
+ def __init__(self, **kwargs):
16
+ super().__init__(**kwargs)
17
+
18
+ if "vit_config" not in kwargs:
19
+ raise ValueError("`vit_config` can not be `None`.")
20
+
21
+ if "gpt2_config" not in kwargs:
22
+ raise ValueError("`gpt2_config` can not be `None`.")
23
+
24
+ vit_config = kwargs.pop("vit_config")
25
+ gpt2_config = kwargs.pop("gpt2_config")
26
+
27
+ self.vit_config = ViTConfig(**vit_config)
28
+ self.gpt2_config = GPT2Config(**gpt2_config)
29
+
30
+ @classmethod
31
+ def from_vit_gpt2_configs(
32
+ cls, vit_config: PretrainedConfig, gpt2_config: PretrainedConfig, **kwargs
33
+ ):
34
+ return cls(
35
+ vit_config=vit_config.to_dict(),
36
+ gpt2_config=gpt2_config.to_dict(),
37
+ **kwargs
38
+ )
39
+
40
+ def to_dict(self):
41
+ output = copy.deepcopy(self.__dict__)
42
+ output["vit_config"] = self.vit_config.to_dict()
43
+ output["gpt2_config"] = self.gpt2_config.to_dict()
44
+ output["model_type"] = self.__class__.model_type
45
+ return output
vit_gpt2/modeling_flax_gpt2.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional, Tuple
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from flax.core.frozen_dict import FrozenDict, unfreeze
22
+ from flax.linen import combine_masks, make_causal_mask
23
+ from flax.linen.attention import dot_product_attention_weights
24
+ from jax import lax
25
+
26
+ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
27
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxSeq2SeqLMOutput
28
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
29
+ from transformers.utils import logging
30
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ _CHECKPOINT_FOR_DOC = "gpt2"
36
+ _CONFIG_FOR_DOC = "GPT2Config"
37
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
38
+
39
+
40
+ GPT2_START_DOCSTRING = r"""
41
+
42
+ This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
43
+ generic methods the library implements for all its model (such as downloading or saving, resizing the input
44
+ embeddings, pruning heads etc.)
45
+
46
+ This model is also a Flax Linen `flax.nn.Module
47
+ <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
48
+ Module and refer to the Flax documentation for all matter related to general usage and behavior.
49
+
50
+ Finally, this model supports inherent JAX features such as:
51
+
52
+ - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
53
+ - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
54
+ - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
55
+ - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
56
+
57
+ Parameters:
58
+ config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
59
+ Initializing with a config file does not load the weights associated with the model, only the
60
+ configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
61
+ model weights.
62
+ """
63
+
64
+ GPT2_INPUTS_DOCSTRING = r"""
65
+ Args:
66
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
67
+ :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
68
+
69
+ Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
70
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
71
+ details.
72
+
73
+ `What are input IDs? <../glossary.html#input-ids>`__
74
+ attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
75
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
76
+
77
+ - 1 for tokens that are **not masked**,
78
+ - 0 for tokens that are **masked**.
79
+
80
+ `What are attention masks? <../glossary.html#attention-mask>`__
81
+ position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
82
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
83
+ config.max_position_embeddings - 1]``.
84
+ past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
85
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
86
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
87
+ output_attentions (:obj:`bool`, `optional`):
88
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
89
+ tensors for more detail.
90
+ output_hidden_states (:obj:`bool`, `optional`):
91
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
92
+ more detail.
93
+ return_dict (:obj:`bool`, `optional`):
94
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
95
+ """
96
+
97
+
98
+ class FlaxConv1D(nn.Module):
99
+ features: int
100
+ use_bias: bool = True
101
+ dtype: Any = jnp.float32
102
+ precision: Any = None
103
+
104
+ @nn.compact
105
+ def __call__(self, inputs):
106
+ inputs = jnp.asarray(inputs, self.dtype)
107
+ kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1]))
108
+ kernel = jnp.asarray(kernel.transpose(), self.dtype)
109
+ y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision)
110
+ if self.use_bias:
111
+ bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
112
+ bias = jnp.asarray(bias, self.dtype)
113
+ y = y + bias
114
+ return y
115
+
116
+
117
+ class FlaxGPT2Attention(nn.Module):
118
+ config: GPT2Config
119
+ dtype: jnp.dtype = jnp.float32
120
+ causal: bool = True
121
+
122
+ def setup(self):
123
+ config = self.config
124
+ self.embed_dim = config.hidden_size
125
+ self.num_heads = config.num_attention_heads
126
+ self.head_dim = self.embed_dim // self.num_heads
127
+
128
+ self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
129
+ self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
130
+
131
+ self.c_attn_for_k_v = FlaxConv1D(features=2 * self.embed_dim, dtype=self.dtype)
132
+
133
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
134
+
135
+ if self.causal:
136
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
137
+
138
+ def _split_heads(self, hidden_states):
139
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
140
+
141
+ def _merge_heads(self, hidden_states):
142
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
143
+
144
+ @nn.compact
145
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
146
+ """
147
+ This function takes projected key, value states from a single input token and concatenates the states to cached
148
+ states from previous steps. This function is slighly adapted from the official Flax repository:
149
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
150
+ """
151
+ # detect if we're initializing by absence of existing cache data.
152
+ is_initialized = self.has_variable("cache", "cached_key")
153
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
154
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
155
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
156
+
157
+ if is_initialized:
158
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
159
+ # update key, value caches with our new 1d spatial slices
160
+ cur_index = cache_index.value
161
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
162
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
163
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
164
+ cached_key.value = key
165
+ cached_value.value = value
166
+ num_updated_cache_vectors = query.shape[1]
167
+ cache_index.value = cache_index.value + num_updated_cache_vectors
168
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
169
+ pad_mask = jnp.broadcast_to(
170
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
171
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
172
+ )
173
+ attention_mask = combine_masks(pad_mask, attention_mask)
174
+ return key, value, attention_mask
175
+
176
+ def __call__(
177
+ self,
178
+ hidden_states,
179
+ key_value_states: Optional[jnp.ndarray] = None,
180
+ attention_mask=None,
181
+ deterministic: bool = True,
182
+ init_cache: bool = False,
183
+ output_attentions: bool = False,
184
+ ):
185
+
186
+ # if key_value_states are provided this layer is used as a cross-attention layer
187
+ # for the decoder
188
+ is_cross_attention = key_value_states is not None
189
+
190
+ qkv_out = self.c_attn(hidden_states)
191
+ query, key, value = jnp.split(qkv_out, 3, axis=2)
192
+
193
+ if is_cross_attention:
194
+ _qkv_out = self.c_attn_for_k_v(key_value_states)
195
+ key, value = jnp.split(_qkv_out, 2, axis=2)
196
+
197
+ query = self._split_heads(query)
198
+ key = self._split_heads(key)
199
+ value = self._split_heads(value)
200
+
201
+ query_length, key_length = query.shape[1], key.shape[1]
202
+
203
+ if self.causal:
204
+ if self.has_variable("cache", "cached_key"):
205
+ mask_shift = self.variables["cache"]["cache_index"]
206
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
207
+ causal_mask = lax.dynamic_slice(
208
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
209
+ )
210
+ else:
211
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
212
+
213
+ batch_size = hidden_states.shape[0]
214
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
215
+
216
+ # combine masks if needed
217
+ if attention_mask is not None and self.causal:
218
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
219
+ attention_mask = combine_masks(attention_mask, causal_mask)
220
+ elif self.causal:
221
+ attention_mask = causal_mask
222
+ elif attention_mask is not None:
223
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
224
+
225
+ dropout_rng = None
226
+ if not deterministic and self.config.attn_pdrop > 0.0:
227
+ dropout_rng = self.make_rng("dropout")
228
+
229
+ # During fast autoregressive decoding, we feed one position at a time,
230
+ # and cache the keys and values step by step.
231
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
232
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
233
+
234
+ # transform boolean mask into float mask
235
+ if attention_mask is not None:
236
+ attention_bias = lax.select(
237
+ attention_mask > 0,
238
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
239
+ jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
240
+ )
241
+ else:
242
+ attention_bias = None
243
+
244
+ # usual dot product attention
245
+ attn_weights = dot_product_attention_weights(
246
+ query,
247
+ key,
248
+ bias=attention_bias,
249
+ dropout_rng=dropout_rng,
250
+ dropout_rate=self.config.attn_pdrop,
251
+ deterministic=deterministic,
252
+ dtype=self.dtype,
253
+ precision=None,
254
+ )
255
+
256
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
257
+ attn_output = self._merge_heads(attn_output)
258
+ attn_output = self.c_proj(attn_output)
259
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
260
+
261
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
262
+ return outputs
263
+
264
+
265
+ class FlaxGPT2MLP(nn.Module):
266
+ config: GPT2Config
267
+ intermediate_size: int
268
+ dtype: jnp.dtype = jnp.float32
269
+
270
+ def setup(self):
271
+ embed_dim = self.config.hidden_size
272
+ self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype)
273
+ self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype)
274
+ self.act = ACT2FN[self.config.activation_function]
275
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
276
+
277
+ def __call__(self, hidden_states, deterministic: bool = True):
278
+ hidden_states = self.c_fc(hidden_states)
279
+ hidden_states = self.act(hidden_states)
280
+ hidden_states = self.c_proj(hidden_states)
281
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
282
+ return hidden_states
283
+
284
+
285
+ class FlaxGPT2Block(nn.Module):
286
+ config: GPT2Config
287
+ dtype: jnp.dtype = jnp.float32
288
+
289
+ def setup(self):
290
+ hidden_size = self.config.hidden_size
291
+ inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
292
+
293
+ self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
294
+ self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
295
+ self.ln_3 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
296
+ self.encoder_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype)
297
+ self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
298
+ self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
299
+
300
+ def __call__(
301
+ self,
302
+ hidden_states,
303
+ attention_mask=None,
304
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
305
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
306
+ deterministic: bool = True,
307
+ init_cache: bool = False,
308
+ output_attentions: bool = False,
309
+ ):
310
+ residual = hidden_states
311
+ hidden_states = self.ln_1(hidden_states)
312
+ outputs = self.attn(
313
+ hidden_states,
314
+ attention_mask=attention_mask,
315
+ deterministic=deterministic,
316
+ init_cache=init_cache,
317
+ output_attentions=output_attentions,
318
+ )
319
+ # residual connection
320
+ attn_output = outputs[0]
321
+ hidden_states = attn_output + residual
322
+
323
+ # Cross-Attention Block
324
+ if encoder_hidden_states is not None:
325
+
326
+ residual = hidden_states
327
+ hidden_states = self.ln_3(hidden_states)
328
+
329
+ cross_attn_outputs = self.encoder_attn(
330
+ hidden_states=hidden_states,
331
+ key_value_states=encoder_hidden_states,
332
+ attention_mask=encoder_attention_mask,
333
+ deterministic=deterministic,
334
+ output_attentions=output_attentions,
335
+ )
336
+
337
+ # residual connection
338
+ cross_attn_output = cross_attn_outputs[0]
339
+ hidden_states = cross_attn_output + residual
340
+
341
+ residual = hidden_states
342
+ hidden_states = self.ln_2(hidden_states)
343
+ feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
344
+ # residual connection
345
+ hidden_states = residual + feed_forward_hidden_states
346
+
347
+ output = (hidden_states,) + outputs[1:]
348
+ if encoder_hidden_states is not None:
349
+ output = output + cross_attn_outputs[1:]
350
+
351
+ return output
352
+
353
+
354
+ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
355
+ """
356
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
357
+ models.
358
+ """
359
+
360
+ config_class = GPT2Config
361
+ base_model_prefix = "transformer"
362
+ module_class: nn.Module = None
363
+
364
+ def __init__(
365
+ self,
366
+ config: GPT2Config,
367
+ input_shape: Tuple = (1, 1),
368
+ seed: int = 0,
369
+ dtype: jnp.dtype = jnp.float32,
370
+ **kwargs,
371
+ ):
372
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
373
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
374
+
375
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
376
+ # init input tensors
377
+ input_ids = jnp.zeros(input_shape, dtype="i4")
378
+ attention_mask = jnp.ones_like(input_ids)
379
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
380
+ params_rng, dropout_rng = jax.random.split(rng)
381
+ rngs = {"params": params_rng, "dropout": dropout_rng}
382
+
383
+ if self.config.add_cross_attention:
384
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
385
+ encoder_attention_mask = attention_mask
386
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, encoder_hidden_states, encoder_attention_mask, return_dict=False)
387
+ else:
388
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
389
+
390
+ return module_init_outputs["params"]
391
+
392
+ @classmethod
393
+ def _from_config(cls, config, **kwargs):
394
+ return super()._from_config(config, **kwargs)
395
+
396
+ def init_cache(self, batch_size, max_length):
397
+ r"""
398
+ Args:
399
+ batch_size (:obj:`int`):
400
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
401
+ max_length (:obj:`int`):
402
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
403
+ cache.
404
+ """
405
+ # init input variables to retrieve cache
406
+ input_ids = jnp.ones((batch_size, max_length))
407
+ attention_mask = jnp.ones_like(input_ids)
408
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
409
+
410
+ init_variables = self.module.init(
411
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
412
+ )
413
+ return init_variables["cache"]
414
+
415
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
416
+ def __call__(
417
+ self,
418
+ input_ids,
419
+ attention_mask=None,
420
+ position_ids=None,
421
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
422
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
423
+ params: dict = None,
424
+ past_key_values: dict = None,
425
+ dropout_rng: jax.random.PRNGKey = None,
426
+ train: bool = False,
427
+ output_attentions: Optional[bool] = None,
428
+ output_hidden_states: Optional[bool] = None,
429
+ return_dict: Optional[bool] = None,
430
+ ):
431
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
432
+ output_hidden_states = (
433
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
434
+ )
435
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
436
+
437
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
438
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
439
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
440
+
441
+ batch_size, sequence_length = input_ids.shape
442
+
443
+ if position_ids is None:
444
+ if past_key_values is not None:
445
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
446
+
447
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
448
+
449
+ if attention_mask is None:
450
+ attention_mask = jnp.ones((batch_size, sequence_length))
451
+
452
+ # Handle any PRNG if needed
453
+ rngs = {}
454
+ if dropout_rng is not None:
455
+ rngs["dropout"] = dropout_rng
456
+
457
+ inputs = {"params": params or self.params}
458
+
459
+ # If past_key_values are passed, the cache is already initialized; a private flag init_cache has to be passed down to ensure the cache is used. The cache must also be marked as mutable so that it can be changed by the FlaxGPT2Attention module.
460
+ if past_key_values:
461
+ inputs["cache"] = past_key_values
462
+ mutable = ["cache"]
463
+ else:
464
+ mutable = False
465
+
466
+ outputs = self.module.apply(
467
+ inputs,
468
+ jnp.array(input_ids, dtype="i4"),
469
+ jnp.array(attention_mask, dtype="i4"),
470
+ jnp.array(position_ids, dtype="i4"),
471
+ encoder_hidden_states,
472
+ encoder_attention_mask,
473
+ not train,
474
+ False,
475
+ output_attentions,
476
+ output_hidden_states,
477
+ return_dict,
478
+ rngs=rngs,
479
+ mutable=mutable,
480
+ )
481
+
482
+ # add updated cache to model output
483
+ if past_key_values is not None and return_dict:
484
+ outputs, past_key_values = outputs
485
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
486
+ return outputs
487
+ elif past_key_values is not None and not return_dict:
488
+ outputs, past_key_values = outputs
489
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
490
+
491
+ return outputs
492
+
493
+
494
+ class FlaxGPT2BlockCollection(nn.Module):
495
+ config: GPT2Config
496
+ dtype: jnp.dtype = jnp.float32
497
+
498
+ def setup(self):
499
+ self.blocks = [
500
+ FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
501
+ ]
502
+
503
+ def __call__(
504
+ self,
505
+ hidden_states,
506
+ attention_mask=None,
507
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
508
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
509
+ deterministic: bool = True,
510
+ init_cache: bool = False,
511
+ output_attentions: bool = False,
512
+ output_hidden_states: bool = False,
513
+ return_dict: bool = True,
514
+ ):
515
+ all_attentions = () if output_attentions else None
516
+ all_hidden_states = () if output_hidden_states else None
517
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
518
+
519
+ for block in self.blocks:
520
+ if output_hidden_states:
521
+ all_hidden_states += (hidden_states,)
522
+
523
+ layer_outputs = block(
524
+ hidden_states,
525
+ attention_mask,
526
+ encoder_hidden_states=encoder_hidden_states,
527
+ encoder_attention_mask=encoder_attention_mask,
528
+ deterministic=deterministic,
529
+ init_cache=init_cache,
530
+ output_attentions=output_attentions,
531
+ )
532
+ hidden_states = layer_outputs[0]
533
+
534
+ if output_attentions:
535
+ all_attentions += (layer_outputs[1],)
536
+ if encoder_hidden_states is not None:
537
+ all_cross_attentions += (layer_outputs[2],)
538
+
539
+ if output_hidden_states:
540
+ all_hidden_states += (hidden_states,)
541
+
542
+ outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]
543
+
544
+ if not return_dict:
545
+ return tuple(v for v in outputs if v is not None)
546
+
547
+ if encoder_hidden_states is None:
548
+ return FlaxBaseModelOutputWithPast(
549
+ last_hidden_state=hidden_states,
550
+ past_key_values=None,
551
+ hidden_states=all_hidden_states,
552
+ attentions=all_attentions,
553
+ )
554
+ else:
555
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
556
+ last_hidden_state=hidden_states,
557
+ past_key_values=None,
558
+ hidden_states=all_hidden_states,
559
+ attentions=all_attentions,
560
+ cross_attentions=all_cross_attentions,
561
+ )
562
+
563
+ class FlaxGPT2Module(nn.Module):
564
+ config: GPT2Config
565
+ dtype: jnp.dtype = jnp.float32
566
+
567
+ def setup(self):
568
+ self.embed_dim = self.config.hidden_size
569
+
570
+ self.wte = nn.Embed(
571
+ self.config.vocab_size,
572
+ self.embed_dim,
573
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
574
+ dtype=self.dtype,
575
+ )
576
+ self.wpe = nn.Embed(
577
+ self.config.max_position_embeddings,
578
+ self.embed_dim,
579
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
580
+ dtype=self.dtype,
581
+ )
582
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
583
+ self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
584
+ self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
585
+
586
+ def __call__(
587
+ self,
588
+ input_ids,
589
+ attention_mask,
590
+ position_ids,
591
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
592
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
593
+ deterministic=True,
594
+ init_cache: bool = False,
595
+ output_attentions: bool = False,
596
+ output_hidden_states: bool = False,
597
+ return_dict: bool = True,
598
+ ):
599
+ input_embeds = self.wte(input_ids.astype("i4"))
600
+ position_embeds = self.wpe(position_ids.astype("i4"))
601
+
602
+ hidden_states = input_embeds + position_embeds
603
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
604
+
605
+ outputs = self.h(
606
+ hidden_states,
607
+ attention_mask,
608
+ encoder_hidden_states,
609
+ encoder_attention_mask,
610
+ deterministic=deterministic,
611
+ init_cache=init_cache,
612
+ output_attentions=output_attentions,
613
+ output_hidden_states=output_hidden_states,
614
+ return_dict=return_dict,
615
+ )
616
+
617
+ hidden_states = outputs[0]
618
+ hidden_states = self.ln_f(hidden_states)
619
+
620
+ if not return_dict:
621
+ return (hidden_states,) + outputs[1:]
622
+
623
+ if encoder_hidden_states is None:
624
+ return FlaxBaseModelOutput(
625
+ last_hidden_state=hidden_states,
626
+ hidden_states=outputs.hidden_states,
627
+ attentions=outputs.attentions,
628
+ )
629
+ else:
630
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
631
+ last_hidden_state=hidden_states,
632
+ hidden_states=outputs.hidden_states,
633
+ attentions=outputs.attentions,
634
+ cross_attentions=outputs.cross_attentions,
635
+ )
636
+
637
+ @add_start_docstrings(
638
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
639
+ GPT2_START_DOCSTRING,
640
+ )
641
+ class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
642
+ module_class = FlaxGPT2Module
643
+
644
+
645
+ append_call_sample_docstring(
646
+ FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
647
+ )
648
+
649
+
650
+ class FlaxGPT2LMHeadModule(nn.Module):
651
+ config: GPT2Config
652
+ dtype: jnp.dtype = jnp.float32
653
+
654
+ def setup(self):
655
+ self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype)
656
+ self.lm_head = nn.Dense(
657
+ self.config.vocab_size,
658
+ use_bias=False,
659
+ dtype=self.dtype,
660
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
661
+ )
662
+
663
+ def __call__(
664
+ self,
665
+ input_ids,
666
+ attention_mask,
667
+ position_ids,
668
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
669
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
670
+ deterministic: bool = True,
671
+ init_cache: bool = False,
672
+ output_attentions: bool = False,
673
+ output_hidden_states: bool = False,
674
+ return_dict: bool = True,
675
+ ):
676
+ outputs = self.transformer(
677
+ input_ids,
678
+ attention_mask,
679
+ position_ids,
680
+ encoder_hidden_states,
681
+ encoder_attention_mask,
682
+ deterministic=deterministic,
683
+ init_cache=init_cache,
684
+ output_attentions=output_attentions,
685
+ output_hidden_states=output_hidden_states,
686
+ return_dict=return_dict,
687
+ )
688
+
689
+ hidden_states = outputs[0]
690
+
691
+ if self.config.tie_word_embeddings:
692
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
693
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
694
+ else:
695
+ lm_logits = self.lm_head(hidden_states)
696
+
697
+ if not return_dict:
698
+ return (lm_logits,) + outputs[1:]
699
+
700
+ if encoder_hidden_states is None:
701
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
702
+ else:
703
+ return FlaxSeq2SeqLMOutput(
704
+ logits=lm_logits,
705
+ decoder_hidden_states=outputs.hidden_states,
706
+ decoder_attentions=outputs.attentions,
707
+ cross_attentions=outputs.cross_attentions,
708
+ encoder_last_hidden_state=encoder_hidden_states,
709
+ encoder_hidden_states=None,
710
+ encoder_attentions=None,
711
+ )
712
+
713
+ @add_start_docstrings(
714
+ """
715
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
716
+ embeddings).
717
+ """,
718
+ GPT2_START_DOCSTRING,
719
+ )
720
+ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
721
+ module_class = FlaxGPT2LMHeadModule
722
+
723
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
724
+ # initializing the cache
725
+ batch_size, seq_length = input_ids.shape
726
+
727
+ past_key_values = self.init_cache(batch_size, max_length)
728
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
729
+ # But since GPT2 uses a causal mask, those positions are masked anyways.
730
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
731
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
732
+ if attention_mask is not None:
733
+ position_ids = attention_mask.cumsum(axis=-1) - 1
734
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
735
+ else:
736
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
737
+
738
+ return {
739
+ "past_key_values": past_key_values,
740
+ "attention_mask": extended_attention_mask,
741
+ "position_ids": position_ids,
742
+ }
743
+
744
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
745
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
746
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
747
+ return model_kwargs
748
+
749
+
750
+ append_call_sample_docstring(
751
+ FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
752
+ )
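The main change relative to the upstream Flax GPT-2 implementation is the optional cross-attention path driven by encoder_hidden_states. A hedged smoke-test sketch with a tiny, randomly initialised config; all sizes below are illustrative assumptions.

import jax.numpy as jnp
from transformers import GPT2Config

config = GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2,
                    add_cross_attention=True)
model = FlaxGPT2LMHeadModel(config)
input_ids = jnp.ones((1, 4), dtype="i4")
encoder_hidden_states = jnp.ones((1, 7, config.n_embd))      # e.g. 7 image patch embeddings
outputs = model(input_ids, encoder_hidden_states=encoder_hidden_states)
print(outputs.logits.shape)                                  # (1, 4, 100)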
vit_gpt2/modeling_flax_vit_gpt2.py ADDED
@@ -0,0 +1,704 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+ from jax import lax
8
+ from jax.random import PRNGKey
9
+ from transformers import GPT2Config, FlaxViTModel, ViTConfig
10
+ from transformers.modeling_flax_outputs import (
11
+ FlaxCausalLMOutputWithCrossAttentions,
12
+ FlaxSeq2SeqLMOutput,
13
+ FlaxSeq2SeqModelOutput,
14
+ )
15
+ from transformers.models.bart.modeling_flax_bart import (
16
+ shift_tokens_right,
17
+ )
18
+ from .modeling_flax_gpt2 import (
19
+ FlaxGPT2Module,
20
+ FlaxGPT2Model,
21
+ FlaxPreTrainedModel
22
+ )
23
+ from transformers.models.vit.modeling_flax_vit import FlaxViTModule
24
+
25
+ from .configuration_vit_gpt2 import ViTGPT2Config
26
+
27
+
28
+ class FlaxViTGPT2Module(nn.Module):
29
+ config: ViTGPT2Config
30
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
31
+
32
+ def setup(self):
33
+
34
+ self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
35
+ self.decoder = FlaxGPT2Module(self.config.gpt2_config, dtype=self.dtype)
36
+
37
+ def _get_encoder_module(self):
38
+ return self.encoder
39
+
40
+ def _get_decoder_module(self):
41
+ return self.decoder
42
+
43
+ def __call__(
44
+ self,
45
+ pixel_values,
46
+ input_ids,
47
+ attention_mask,
48
+ position_ids,
49
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
50
+ output_attentions: bool = False,
51
+ output_hidden_states: bool = False,
52
+ return_dict: bool = True,
53
+ deterministic: bool = True,
54
+ ):
55
+ encoder_outputs = self.encoder(
56
+ pixel_values=pixel_values,
57
+ deterministic=deterministic,
58
+ output_attentions=output_attentions,
59
+ output_hidden_states=output_hidden_states,
60
+ return_dict=return_dict,
61
+ )
62
+
63
+ decoder_outputs = self.decoder(
64
+ input_ids=input_ids,
65
+ attention_mask=attention_mask,
66
+ position_ids=position_ids,
67
+ encoder_hidden_states=encoder_outputs[0],
68
+ encoder_attention_mask=encoder_attention_mask,
69
+ deterministic=deterministic,
70
+ output_attentions=output_attentions,
71
+ output_hidden_states=output_hidden_states,
72
+ return_dict=return_dict
73
+ )
74
+
75
+ return FlaxSeq2SeqModelOutput(
76
+ last_hidden_state=decoder_outputs.last_hidden_state,
77
+ decoder_hidden_states=decoder_outputs.hidden_states,
78
+ decoder_attentions=decoder_outputs.attentions,
79
+ cross_attentions=decoder_outputs.cross_attentions,
80
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
81
+ encoder_hidden_states=encoder_outputs.hidden_states,
82
+ encoder_attentions=encoder_outputs.attentions,
83
+ )
84
+
85
+
86
+ class FlaxViTGPT2ForConditionalGenerationModule(nn.Module):
87
+ config: ViTGPT2Config
88
+ dtype: jnp.dtype = jnp.float32
89
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
90
+
91
+ def setup(self):
92
+ self.model = FlaxViTGPT2Module(config=self.config, dtype=self.dtype)
93
+ self.lm_head = nn.Dense(
94
+ self.model.decoder.embed_dim,
95
+ use_bias=False,
96
+ dtype=self.dtype,
97
+ kernel_init=jax.nn.initializers.normal(
98
+ self.config.gpt2_config.initializer_range, self.dtype
99
+ ),
100
+ )
101
+ self.final_logits_bias = self.param(
102
+ "final_logits_bias", self.bias_init, (1, self.model.decoder.embed_dim)
103
+ )
104
+
105
+ def _get_encoder_module(self):
106
+ return self.model.encoder
107
+
108
+ def _get_decoder_module(self):
109
+ return self.model.decoder
110
+
111
+ def __call__(
112
+ self,
113
+ pixel_values,
114
+ input_ids,
115
+ attention_mask,
116
+ position_ids,
117
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
118
+ output_attentions: bool = False,
119
+ output_hidden_states: bool = False,
120
+ return_dict: bool = True,
121
+ deterministic: bool = True,
122
+ ):
123
+ outputs = self.model(
124
+ pixel_values=pixel_values,
125
+ input_ids=input_ids,
126
+ attention_mask=attention_mask,
127
+ position_ids=position_ids,
128
+ encoder_attention_mask=encoder_attention_mask,
129
+ output_attentions=output_attentions,
130
+ output_hidden_states=output_hidden_states,
131
+ return_dict=return_dict,
132
+ deterministic=deterministic,
133
+ )
134
+
135
+ hidden_states = outputs[0]
136
+ lm_logits = self.lm_head(hidden_states)
137
+ lm_logits += self.final_logits_bias
138
+
139
+ if not return_dict:
140
+ output = (lm_logits,) + outputs[1:]
141
+ return output
142
+
143
+ return FlaxSeq2SeqLMOutput(
144
+ logits=lm_logits,
145
+ decoder_hidden_states=outputs.decoder_hidden_states,
146
+ decoder_attentions=outputs.decoder_attentions,
147
+ cross_attentions=outputs.cross_attentions,
148
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
149
+ encoder_hidden_states=outputs.encoder_hidden_states,
150
+ encoder_attentions=outputs.encoder_attentions,
151
+ )
152
+
153
+ class FlaxViTGPT2PreTrainedModel(FlaxPreTrainedModel):
154
+ config_class = ViTGPT2Config
155
+ base_model_prefix: str = "model"
156
+ module_class: nn.Module = None
157
+
158
+ def __init__(
159
+ self,
160
+ config: ViTGPT2Config,
161
+ input_shape: Tuple = None,
162
+ seed: int = 0,
163
+ dtype: jnp.dtype = jnp.float32,
164
+ **kwargs,
165
+ ):
166
+ if input_shape is None:
167
+ input_shape = (
168
+ (1, config.vit_config.image_size, config.vit_config.image_size, 3),
169
+ (1, 1),
170
+ )
171
+
172
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
173
+ super().__init__(
174
+ config, module, input_shape=input_shape, seed=seed, dtype=dtype
175
+ )
176
+
177
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
178
+ # init input tensors
179
+ pixel_values = jax.random.normal(rng, input_shape[0])
180
+ # # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
181
+ # input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
182
+
183
+ input_ids = jnp.zeros(input_shape[1], dtype="i4")
184
+ attention_mask = jnp.ones_like(input_ids)
185
+
186
+ batch_size, sequence_length = input_ids.shape
187
+ position_ids = jnp.broadcast_to(
188
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
189
+ )
190
+
191
+ params_rng, dropout_rng = jax.random.split(rng)
192
+ rngs = {"params": params_rng, "dropout": dropout_rng}
193
+
194
+ return self.module.init(
195
+ rngs,
196
+ pixel_values,
197
+ input_ids,
198
+ attention_mask,
199
+ position_ids,
200
+ )["params"]
201
+
202
+ def init_cache(self, batch_size, max_length, encoder_outputs):
203
+
204
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
205
+ attention_mask = jnp.ones_like(input_ids)
206
+ position_ids = jnp.broadcast_to(
207
+ jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
208
+ input_ids.shape,
209
+ )
210
+
211
+ def _decoder_forward(
212
+ module,
213
+ input_ids,
214
+ attention_mask,
215
+ position_ids,
216
+ **kwargs,
217
+ ):
218
+ decoder_module = module._get_decoder_module()
219
+ return decoder_module(
220
+ input_ids,
221
+ attention_mask,
222
+ position_ids,
223
+ **kwargs,
224
+ )
225
+
226
+ init_variables = self.module.init(
227
+ jax.random.PRNGKey(0),
228
+ input_ids=input_ids,
229
+ attention_mask=attention_mask,
230
+ position_ids=position_ids,
231
+ encoder_hidden_states=encoder_outputs[0],
232
+ init_cache=True,
233
+ method=_decoder_forward, # we only need to call the decoder to init the cache
234
+ )
235
+ return unfreeze(init_variables["cache"])
236
+
237
+ def encode(
238
+ self,
239
+ pixel_values: jnp.ndarray,
240
+ output_attentions: Optional[bool] = None,
241
+ output_hidden_states: Optional[bool] = None,
242
+ return_dict: Optional[bool] = None,
243
+ train: bool = False,
244
+ params: dict = None,
245
+ dropout_rng: PRNGKey = None,
246
+ ):
247
+ output_attentions = (
248
+ output_attentions
249
+ if output_attentions is not None
250
+ else self.config.output_attentions
251
+ )
252
+ output_hidden_states = (
253
+ output_hidden_states
254
+ if output_hidden_states is not None
255
+ else self.config.output_hidden_states
256
+ )
257
+ return_dict = (
258
+ return_dict if return_dict is not None else self.config.return_dict
259
+ )
260
+
261
+ # Handle any PRNG if needed
262
+ rngs = {}
263
+ if dropout_rng is not None:
264
+ rngs["dropout"] = dropout_rng
265
+
266
+ def _encoder_forward(module, pixel_values, **kwargs):
267
+ encode_module = module._get_encoder_module()
268
+ return encode_module(pixel_values, **kwargs)
269
+
270
+ return self.module.apply(
271
+ {"params": params or self.params},
272
+ pixel_values=jnp.array(pixel_values, dtype=jnp.float32),  # pixel values are floats; an integer cast would corrupt them
273
+ output_attentions=output_attentions,
274
+ output_hidden_states=output_hidden_states,
275
+ return_dict=return_dict,
276
+ deterministic=not train,
277
+ rngs=rngs,
278
+ method=_encoder_forward,
279
+ )
280
+
281
+ def decode(
282
+ self,
283
+ input_ids,
284
+ encoder_outputs,
285
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
286
+ attention_mask: Optional[jnp.ndarray] = None,
287
+ position_ids: Optional[jnp.ndarray] = None,
288
+ past_key_values: dict = None,
289
+ output_attentions: Optional[bool] = None,
290
+ output_hidden_states: Optional[bool] = None,
291
+ return_dict: Optional[bool] = None,
292
+ train: bool = False,
293
+ params: dict = None,
294
+ dropout_rng: PRNGKey = None,
295
+ ):
296
+
297
+ output_attentions = (
298
+ output_attentions
299
+ if output_attentions is not None
300
+ else self.config.output_attentions
301
+ )
302
+ output_hidden_states = (
303
+ output_hidden_states
304
+ if output_hidden_states is not None
305
+ else self.config.output_hidden_states
306
+ )
307
+ return_dict = (
308
+ return_dict if return_dict is not None else self.config.return_dict
309
+ )
310
+
311
+ encoder_hidden_states = encoder_outputs[0]
312
+ if encoder_attention_mask is None:
313
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
314
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
315
+
316
+ batch_size, sequence_length = input_ids.shape
317
+ if attention_mask is None:
318
+ attention_mask = jnp.ones((batch_size, sequence_length))
319
+
320
+ if position_ids is None:
321
+ if past_key_values is not None:
322
+ raise ValueError(
323
+ "Make sure to provide `position_ids` when passing `past_key_values`."
324
+ )
325
+
326
+ position_ids = jnp.broadcast_to(
327
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
328
+ )
329
+
330
+ # Handle any PRNG if needed
331
+ rngs = {}
332
+ if dropout_rng is not None:
333
+ rngs["dropout"] = dropout_rng
334
+
335
+ inputs = {"params": params or self.params}
336
+
337
+ # If past_key_values are passed, the cache is already initialized; a private flag init_cache has to be
338
+ # passed down to ensure the cache is used. The cache must also be marked as mutable so that
339
+ # it can be changed by the FlaxGPT2Attention module.
340
+ if past_key_values:
341
+ inputs["cache"] = past_key_values
342
+ mutable = ["cache"]
343
+ else:
344
+ mutable = False
345
+
346
+ def _decoder_forward(
347
+ module,
348
+ input_ids,
349
+ attention_mask,
350
+ position_ids,
351
+ **kwargs,
352
+ ):
353
+ decoder_module = module._get_decoder_module()
354
+ return decoder_module(
355
+ input_ids,
356
+ attention_mask,
357
+ position_ids,
358
+ **kwargs,
359
+ )
360
+
361
+ outputs = self.module.apply(
362
+ inputs,
363
+ input_ids=jnp.array(input_ids, dtype="i4"),
364
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
365
+ position_ids=jnp.array(position_ids, dtype="i4"),
366
+ encoder_hidden_states=encoder_hidden_states,
367
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
368
+ output_attentions=output_attentions,
369
+ output_hidden_states=output_hidden_states,
370
+ return_dict=return_dict,
371
+ deterministic=not train,
372
+ rngs=rngs,
373
+ mutable=mutable,
374
+ method=_decoder_forward,
375
+ )
376
+
377
+ # add updated cache to model output
378
+ if past_key_values is not None and return_dict:
379
+ outputs, past = outputs
380
+ outputs["past_key_values"] = unfreeze(past["cache"])
381
+ return outputs
382
+ elif past_key_values is not None and not return_dict:
383
+ outputs, past = outputs
384
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
385
+
386
+ return outputs
387
+
388
+ def __call__(
389
+ self,
390
+ pixel_values: jnp.ndarray,
391
+ input_ids: Optional[jnp.ndarray] = None,
392
+ attention_mask: Optional[jnp.ndarray] = None,
393
+ position_ids: Optional[jnp.ndarray] = None,
394
+ output_attentions: Optional[bool] = None,
395
+ output_hidden_states: Optional[bool] = None,
396
+ return_dict: Optional[bool] = None,
397
+ train: bool = False,
398
+ params: dict = None,
399
+ dropout_rng: PRNGKey = None,
400
+ ):
401
+ output_attentions = (
402
+ output_attentions
403
+ if output_attentions is not None
404
+ else self.config.output_attentions
405
+ )
406
+ output_hidden_states = (
407
+ output_hidden_states
408
+ if output_hidden_states is not None
409
+ else self.config.output_hidden_states
410
+ )
411
+ return_dict = (
412
+ return_dict if return_dict is not None else self.config.return_dict
413
+ )
414
+
415
+ pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
416
+
417
+ # # prepare encoder inputs
418
+ # if encoder_attention_mask is None:
419
+ # encoder_attention_mask = jnp.ones_like(input_ids)
420
+
421
+ # if position_ids is None:
422
+ # batch_size, sequence_length = input_ids.shape
423
+ # position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
424
+
425
+ # prepare decoder inputs
426
+ # if decoder_input_ids is None:
427
+ # decoder_input_ids = shift_tokens_right(
428
+ # input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
429
+ # ) # TODO: Check how to use this
430
+
431
+ if attention_mask is None:
432
+ attention_mask = jnp.ones_like(input_ids)
433
+ if position_ids is None:
434
+ batch_size, sequence_length = input_ids.shape
435
+ position_ids = jnp.broadcast_to(
436
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
437
+ )
438
+
439
+ # Handle any PRNG if needed
440
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
441
+
442
+ return self.module.apply(
443
+ {"params": params or self.params},
444
+ pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
445
+ input_ids=jnp.array(input_ids, dtype="i4"),
446
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
447
+ position_ids=jnp.array(position_ids, dtype="i4"),
448
+ output_attentions=output_attentions,
449
+ output_hidden_states=output_hidden_states,
450
+ return_dict=return_dict,
451
+ deterministic=not train,
452
+ rngs=rngs,
453
+ )
454
+
455
+
456
+ class FlaxViTGPT2ForConditionalGeneration(FlaxViTGPT2PreTrainedModel):
457
+ module_class = FlaxViTGPT2ForConditionalGenerationModule
458
+ dtype: jnp.dtype = jnp.float32
459
+
460
+ def decode(
461
+ self,
462
+ input_ids,
463
+ encoder_outputs,
464
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
465
+ attention_mask: Optional[jnp.ndarray] = None,
466
+ position_ids: Optional[jnp.ndarray] = None,
467
+ past_key_values: dict = None,
468
+ output_attentions: Optional[bool] = None,
469
+ output_hidden_states: Optional[bool] = None,
470
+ return_dict: Optional[bool] = None,
471
+ deterministic: bool = True,
472
+ params: dict = None,
473
+ dropout_rng: PRNGKey = None,
474
+ ):
475
+ output_attentions = (
476
+ output_attentions
477
+ if output_attentions is not None
478
+ else self.config.output_attentions
479
+ )
480
+ output_hidden_states = (
481
+ output_hidden_states
482
+ if output_hidden_states is not None
483
+ else self.config.output_hidden_states
484
+ )
485
+ return_dict = (
486
+ return_dict if return_dict is not None else self.config.return_dict
487
+ )
488
+
489
+ encoder_hidden_states = encoder_outputs[0]
490
+ if encoder_attention_mask is None:
491
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
492
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
493
+
494
+ batch_size, sequence_length = input_ids.shape
495
+ if attention_mask is None:
496
+ attention_mask = jnp.ones((batch_size, sequence_length))
497
+
498
+ if position_ids is None:
499
+ if past_key_values is not None:
500
+ raise ValueError(
501
+ "Make sure to provide `position_ids` when passing `past_key_values`."
502
+ )
503
+
504
+ position_ids = jnp.broadcast_to(
505
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
506
+ )
507
+
508
+ # Handle any PRNG if needed
509
+ rngs = {}
510
+ if dropout_rng is not None:
511
+ rngs["dropout"] = dropout_rng
512
+
513
+ inputs = {"params": params or self.params}
514
+
515
+ # If past_key_values are passed, the cache is already initialized; a private flag init_cache has to be
516
+ # passed down to ensure the cache is used. The cache must also be marked as mutable so that
517
+ # it can be changed by the FlaxGPT2Attention module.
518
+ if past_key_values:
519
+ inputs["cache"] = past_key_values
520
+ mutable = ["cache"]
521
+ else:
522
+ mutable = False
523
+
524
+ def _decoder_forward(
525
+ module,
526
+ input_ids,
527
+ attention_mask,
528
+ position_ids,
529
+ **kwargs,
530
+ ):
531
+ decoder_module = module._get_decoder_module()
532
+ outputs = decoder_module(
533
+ input_ids,
534
+ attention_mask,
535
+ position_ids,
536
+ **kwargs,
537
+ )
538
+ hidden_states = outputs[0]
539
+
540
+ if self.config.tie_word_embeddings:
541
+ shared_embedding = module.model.variables["params"]["shared"][
542
+ "embedding"
543
+ ]
544
+ lm_logits = module.lm_head.apply(
545
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
546
+ )
547
+ else:
548
+ lm_logits = module.lm_head(hidden_states)
549
+
550
+ lm_logits += module.final_logits_bias
551
+ return lm_logits, outputs
552
+
553
+ outputs = self.module.apply(
554
+ inputs,
555
+ input_ids=jnp.array(input_ids, dtype="i4"),
556
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
557
+ position_ids=jnp.array(position_ids, dtype="i4"),
558
+ encoder_hidden_states=encoder_hidden_states,
559
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
560
+ output_attentions=output_attentions,
561
+ output_hidden_states=output_hidden_states,
562
+ return_dict=return_dict,
563
+ deterministic=deterministic,
564
+ rngs=rngs,
565
+ mutable=mutable,
566
+ method=_decoder_forward,
567
+ )
568
+
569
+ if past_key_values is None:
570
+ lm_logits, outputs = outputs
571
+ else:
572
+ (lm_logits, outputs), past = outputs
573
+
574
+ if return_dict:
575
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
576
+ logits=lm_logits,
577
+ hidden_states=outputs.hidden_states,
578
+ attentions=outputs.attentions,
579
+ cross_attentions=outputs.cross_attentions,
580
+ )
581
+ else:
582
+ outputs = (lm_logits,) + outputs[1:]
583
+
584
+ # add updated cache to model output
585
+ if past_key_values is not None and return_dict:
586
+ outputs["past_key_values"] = unfreeze(past["cache"])
587
+ return outputs
588
+ elif past_key_values is not None and not return_dict:
589
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
590
+
591
+ return outputs
592
+
593
+ def prepare_inputs_for_generation(
594
+ self,
595
+ input_ids,
596
+ max_length,
597
+ encoder_attention_mask: Optional[jnp.DeviceArray] = None,
598
+ attention_mask: Optional[jnp.DeviceArray] = None,
599
+ encoder_outputs=None,
600
+ **kwargs,
601
+ ):
602
+ # initializing the cache
603
+ batch_size, seq_length = input_ids.shape
604
+
605
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
606
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
607
+ # But since the decoder uses a causal mask, those positions are masked anyways.
608
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
609
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
610
+ if attention_mask is not None:
611
+ position_ids = attention_mask.cumsum(axis=-1) - 1
612
+ extended_attention_mask = lax.dynamic_update_slice(
613
+ extended_attention_mask, attention_mask, (0, 0)
614
+ )
615
+ else:
616
+ position_ids = jnp.broadcast_to(
617
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
618
+ )
619
+
620
+ return {
621
+ "past_key_values": past_key_values,
622
+ "encoder_outputs": encoder_outputs,
623
+ "encoder_attention_mask": encoder_attention_mask,
624
+ "attention_mask": extended_attention_mask,
625
+ "position_ids": position_ids,
626
+ }
627
+
628
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
629
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
630
+ model_kwargs["position_ids"] = (
631
+ model_kwargs["position_ids"][:, -1:] + 1
632
+ )
633
+ return model_kwargs
634
+
635
+ @classmethod
636
+ def from_vit_gpt2_pretrained(
637
+ cls,
638
+ vit_model_name_or_path: str = None,
639
+ gpt2_model_name_or_path: str = None,
640
+ *model_args,
641
+ **kwargs,
642
+ ) -> FlaxViTGPT2PreTrainedModel:
643
+
644
+ kwargs_gpt2 = {
645
+ argument[len("gpt2_") :]: value
646
+ for argument, value in kwargs.items()
647
+ if argument.startswith("gpt2_")
648
+ }
649
+
650
+ kwargs_vit = {
651
+ argument[len("vit_") :]: value
652
+ for argument, value in kwargs.items()
653
+ if argument.startswith("vit_")
654
+ }
655
+
656
+ # remove gpt2, vit kwargs from kwargs
657
+ for key in kwargs_gpt2.keys():
658
+ del kwargs["gpt2_" + key]
659
+ for key in kwargs_vit.keys():
660
+ del kwargs["vit_" + key]
661
+
662
+ # Load and initialize the gpt2 and vit model
663
+ gpt2_model = kwargs_gpt2.pop("model", None)
664
+ if gpt2_model is None:
665
+ assert (
666
+ gpt2_model_name_or_path is not None
667
+ ), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"
668
+
669
+ if "config" not in kwargs_gpt2:
670
+ gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
671
+ kwargs_gpt2["config"] = gpt2_config
672
+
673
+ kwargs_gpt2["config"].add_cross_attention = True
674
+ gpt2_model = FlaxGPT2Model.from_pretrained(
675
+ gpt2_model_name_or_path, *model_args, **kwargs_gpt2
676
+ )
677
+
678
+ vit_model = kwargs_vit.pop("model", None)
679
+ if vit_model is None:
680
+ assert (
681
+ vit_model_name_or_path is not None
682
+ ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"
683
+
684
+ if "config" not in kwargs_vit:
685
+ vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
686
+ kwargs_vit["config"] = vit_config
687
+
688
+ vit_model = FlaxViTModel.from_pretrained(
689
+ vit_model_name_or_path, *model_args, **kwargs_vit
690
+ )
691
+
692
+ # instantiate config with corresponding kwargs
693
+ dtype = kwargs.pop("dtype", jnp.float32)
694
+ config = ViTGPT2Config.from_vit_gpt2_configs(
695
+ vit_model.config, gpt2_model.config, **kwargs
696
+ )
697
+
698
+ # init model
699
+ model = cls(config, *model_args, dtype=dtype, **kwargs)
700
+ model.params["model"]["encoder"] = vit_model.params
701
+ model.params["model"]["decoder"] = gpt2_model.params
702
+
703
+ return model
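A hedged sketch of assembling the encoder-decoder from two independently pretrained checkpoints; the model names below are assumptions chosen for illustration.

model = FlaxViTGPT2ForConditionalGeneration.from_vit_gpt2_pretrained(
    vit_model_name_or_path="google/vit-base-patch16-224-in21k",
    gpt2_model_name_or_path="gpt2",
)
# The ViT weights land under params["model"]["encoder"] and the GPT-2 weights under
# params["model"]["decoder"]; the cross-attention projections do not exist in a plain
# GPT-2 checkpoint and are therefore newly initialised.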
704
+
vit_gpt2/modeling_flax_vit_gpt2_lm.py ADDED
@@ -0,0 +1,684 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+ from jax import lax
8
+ from jax.random import PRNGKey
9
+ from transformers import GPT2Config, FlaxViTModel, ViTConfig
10
+ from transformers.modeling_flax_outputs import (
11
+ FlaxCausalLMOutputWithCrossAttentions,
12
+ FlaxSeq2SeqLMOutput,
13
+ FlaxSeq2SeqModelOutput,
14
+ )
15
+ from transformers.models.bart.modeling_flax_bart import (
16
+ shift_tokens_right,
17
+ )
18
+ from .modeling_flax_gpt2 import (
19
+ FlaxGPT2Module,
20
+ FlaxGPT2Model,
21
+ FlaxGPT2LMHeadModule,
22
+ FlaxGPT2LMHeadModel,
23
+ FlaxPreTrainedModel
24
+ )
25
+ from transformers.models.vit.modeling_flax_vit import FlaxViTModule
26
+
27
+ from .configuration_vit_gpt2 import ViTGPT2Config
28
+
29
+
30
+ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
31
+ """
32
+ Shift input ids one token to the right.
33
+ """
34
+ shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
35
+ shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)  # jax.ops.index_update has been removed from JAX
36
+ # replace possible -100 values in labels by `pad_token_id`
37
+ shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
38
+
39
+ return shifted_input_ids
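A quick worked example of the shift (values are illustrative): the decoder input is the label sequence rotated right by one, with the first slot overwritten by the start token and any -100 label padding mapped back to pad_token_id.

# Example:
#   shift_tokens_right(jnp.array([[5, -100, -100]]), pad_token_id=0, decoder_start_token_id=50256)
#   -> [[50256, 5, 0]]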
40
+
41
+ class FlaxViTGPT2LMModule(nn.Module):
42
+ config: ViTGPT2Config
43
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
44
+
45
+ def setup(self):
46
+
47
+ self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
48
+ self.decoder = FlaxGPT2LMHeadModule(self.config.gpt2_config, dtype=self.dtype)
49
+
50
+ def _get_encoder_module(self):
51
+ return self.encoder
52
+
53
+ def _get_decoder_module(self):
54
+ return self.decoder
55
+
56
+ def __call__(
57
+ self,
58
+ pixel_values,
59
+ input_ids,
60
+ attention_mask,
61
+ position_ids,
62
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
63
+ output_attentions: bool = False,
64
+ output_hidden_states: bool = False,
65
+ return_dict: bool = True,
66
+ deterministic: bool = True,
67
+ ):
68
+ encoder_outputs = self.encoder(
69
+ pixel_values=pixel_values,
70
+ deterministic=deterministic,
71
+ output_attentions=output_attentions,
72
+ output_hidden_states=output_hidden_states,
73
+ return_dict=return_dict,
74
+ )
75
+
76
+ decoder_outputs = self.decoder(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ position_ids=position_ids,
80
+ encoder_hidden_states=encoder_outputs[0],
81
+ encoder_attention_mask=encoder_attention_mask,
82
+ deterministic=deterministic,
83
+ output_attentions=output_attentions,
84
+ output_hidden_states=output_hidden_states,
85
+ return_dict=return_dict
86
+ )
87
+
88
+ if not return_dict:
89
+ return decoder_outputs + encoder_outputs
90
+
91
+ return FlaxSeq2SeqLMOutput(
92
+ logits=decoder_outputs.logits,
93
+ decoder_hidden_states=decoder_outputs.decoder_hidden_states,
94
+ decoder_attentions=decoder_outputs.decoder_attentions,
95
+ cross_attentions=decoder_outputs.cross_attentions,
96
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
97
+ encoder_hidden_states=encoder_outputs.hidden_states,
98
+ encoder_attentions=encoder_outputs.attentions,
99
+ )
100
+
101
+ class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
102
+ config: ViTGPT2Config
103
+ dtype: jnp.dtype = jnp.float32
104
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
105
+
106
+ def setup(self):
107
+ self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)
108
+
109
+ def _get_encoder_module(self):
110
+ return self.model.encoder
111
+
112
+ def _get_decoder_module(self):
113
+ return self.model.decoder
114
+
115
+ def __call__(
116
+ self,
117
+ pixel_values,
118
+ input_ids,
119
+ attention_mask,
120
+ position_ids,
121
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
122
+ output_attentions: bool = False,
123
+ output_hidden_states: bool = False,
124
+ return_dict: bool = True,
125
+ deterministic: bool = True,
126
+ ):
127
+ outputs = self.model(
128
+ pixel_values=pixel_values,
129
+ input_ids=input_ids,
130
+ attention_mask=attention_mask,
131
+ position_ids=position_ids,
132
+ encoder_attention_mask=encoder_attention_mask,
133
+ output_attentions=output_attentions,
134
+ output_hidden_states=output_hidden_states,
135
+ return_dict=return_dict,
136
+ deterministic=deterministic,
137
+ )
138
+
139
+ return outputs
140
+
141
+
142
+ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
143
+ config_class = ViTGPT2Config
144
+ base_model_prefix: str = "model"
145
+ module_class: nn.Module = None
146
+
147
+ def __init__(
148
+ self,
149
+ config: ViTGPT2Config,
150
+ input_shape: Tuple = None,
151
+ seed: int = 0,
152
+ dtype: jnp.dtype = jnp.float32,
153
+ **kwargs,
154
+ ):
155
+ if input_shape is None:
156
+ input_shape = (
157
+ (1, config.vit_config.image_size, config.vit_config.image_size, 3),
158
+ (1, 1),
159
+ )
160
+
161
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
162
+ super().__init__(
163
+ config, module, input_shape=input_shape, seed=seed, dtype=dtype
164
+ )
165
+
166
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
167
+ # init input tensors
168
+ pixel_values = jax.random.normal(rng, input_shape[0])
169
+ # # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
170
+ # input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
171
+
172
+ input_ids = jnp.zeros(input_shape[1], dtype="i4")
173
+ attention_mask = jnp.ones_like(input_ids)
174
+
175
+ batch_size, sequence_length = input_ids.shape
176
+ position_ids = jnp.broadcast_to(
177
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
178
+ )
179
+
180
+ params_rng, dropout_rng = jax.random.split(rng)
181
+ rngs = {"params": params_rng, "dropout": dropout_rng}
182
+
183
+ return self.module.init(
184
+ rngs,
185
+ pixel_values,
186
+ input_ids,
187
+ attention_mask,
188
+ position_ids,
189
+ )["params"]
190
+
191
+ def init_cache(self, batch_size, max_length, encoder_outputs):
192
+
193
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
194
+ attention_mask = jnp.ones_like(input_ids)
195
+ position_ids = jnp.broadcast_to(
196
+ jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
197
+ input_ids.shape,
198
+ )
199
+
200
+ def _decoder_forward(
201
+ module,
202
+ input_ids,
203
+ attention_mask,
204
+ position_ids,
205
+ **kwargs,
206
+ ):
207
+ decoder_module = module._get_decoder_module()
208
+ return decoder_module(
209
+ input_ids,
210
+ attention_mask,
211
+ position_ids,
212
+ **kwargs,
213
+ )
214
+
215
+ init_variables = self.module.init(
216
+ jax.random.PRNGKey(0),
217
+ input_ids=input_ids,
218
+ attention_mask=attention_mask,
219
+ position_ids=position_ids,
220
+ encoder_hidden_states=encoder_outputs[0],
221
+ init_cache=True,
222
+ method=_decoder_forward, # we only need to call the decoder to init the cache
223
+ )
224
+ return unfreeze(init_variables["cache"])
225
+
226
+ def encode(
227
+ self,
228
+ pixel_values: jnp.ndarray,
229
+ attention_mask: Optional[jnp.ndarray] = None,
230
+ output_attentions: Optional[bool] = None,
231
+ output_hidden_states: Optional[bool] = None,
232
+ return_dict: Optional[bool] = None,
233
+ train: bool = False,
234
+ params: dict = None,
235
+ dropout_rng: PRNGKey = None,
236
+ ):
237
+ output_attentions = (
238
+ output_attentions
239
+ if output_attentions is not None
240
+ else self.config.output_attentions
241
+ )
242
+ output_hidden_states = (
243
+ output_hidden_states
244
+ if output_hidden_states is not None
245
+ else self.config.output_hidden_states
246
+ )
247
+ return_dict = (
248
+ return_dict if return_dict is not None else self.config.return_dict
249
+ )
250
+
251
+ pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
252
+
253
+ # Handle any PRNG if needed
254
+ rngs = {}
255
+ if dropout_rng is not None:
256
+ rngs["dropout"] = dropout_rng
257
+
258
+ def _encoder_forward(module, pixel_values, **kwargs):
259
+ encode_module = module._get_encoder_module()
260
+ return encode_module(pixel_values, **kwargs)
261
+
262
+ return self.module.apply(
263
+ {"params": params or self.params},
264
+ pixel_values=jnp.array(pixel_values, dtype=jnp.float32),  # pixel values are floats; an integer cast would corrupt them
265
+ output_attentions=output_attentions,
266
+ output_hidden_states=output_hidden_states,
267
+ return_dict=return_dict,
268
+ deterministic=not train,
269
+ rngs=rngs,
270
+ method=_encoder_forward,
271
+ )
272
+
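Since `encode` transposes `pixel_values` internally from NCHW to NHWC, callers are expected to pass channel-first arrays. A minimal sketch with dummy data (the 224x224 resolution is an assumption taken from typical ViT configurations):

import jax.numpy as jnp

pixel_values = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)  # (batch, channels, height, width)
encoder_outputs = model.encode(pixel_values=pixel_values)
encoder_hidden_states = encoder_outputs[0]  # (batch, sequence_length, hidden_size)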
273
+ def decode(
274
+ self,
275
+ input_ids,
276
+ encoder_outputs,
277
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
278
+ attention_mask: Optional[jnp.ndarray] = None,
279
+ position_ids: Optional[jnp.ndarray] = None,
280
+ past_key_values: dict = None,
281
+ output_attentions: Optional[bool] = None,
282
+ output_hidden_states: Optional[bool] = None,
283
+ return_dict: Optional[bool] = None,
284
+ train: bool = False,
285
+ params: dict = None,
286
+ dropout_rng: PRNGKey = None,
287
+ ):
288
+
289
+ output_attentions = (
290
+ output_attentions
291
+ if output_attentions is not None
292
+ else self.config.output_attentions
293
+ )
294
+ output_hidden_states = (
295
+ output_hidden_states
296
+ if output_hidden_states is not None
297
+ else self.config.output_hidden_states
298
+ )
299
+ return_dict = (
300
+ return_dict if return_dict is not None else self.config.return_dict
301
+ )
302
+
303
+ encoder_hidden_states = encoder_outputs[0]
304
+ if encoder_attention_mask is None:
305
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
306
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
307
+
308
+ batch_size, sequence_length = input_ids.shape
309
+ if attention_mask is None:
310
+ attention_mask = jnp.ones((batch_size, sequence_length))
311
+
312
+ if position_ids is None:
313
+ if past_key_values is not None:
314
+ raise ValueError(
315
+ "Make sure to provide `position_ids` when passing `past_key_values`."
316
+ )
317
+
318
+ position_ids = jnp.broadcast_to(
319
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
320
+ )
321
+
322
+ # Handle any PRNG if needed
323
+ rngs = {}
324
+ if dropout_rng is not None:
325
+ rngs["dropout"] = dropout_rng
326
+
327
+ inputs = {"params": params or self.params}
328
+
329
+ # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
330
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
331
+ # it can be changed by the FlaxGPT2Attention module.
332
+ if past_key_values:
333
+ inputs["cache"] = past_key_values
334
+ mutable = ["cache"]
335
+ else:
336
+ mutable = False
337
+
338
+ def _decoder_forward(
339
+ module,
340
+ input_ids,
341
+ attention_mask,
342
+ position_ids,
343
+ **kwargs,
344
+ ):
345
+ decoder_module = module._get_decoder_module()
346
+ return decoder_module(
347
+ input_ids,
348
+ attention_mask,
349
+ position_ids,
350
+ **kwargs,
351
+ )
352
+
353
+ outputs = self.module.apply(
354
+ inputs,
355
+ input_ids=jnp.array(input_ids, dtype="i4"),
356
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
357
+ position_ids=jnp.array(position_ids, dtype="i4"),
358
+ encoder_hidden_states=encoder_hidden_states,
359
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
360
+ output_attentions=output_attentions,
361
+ output_hidden_states=output_hidden_states,
362
+ return_dict=return_dict,
363
+ deterministic=not train,
364
+ rngs=rngs,
365
+ mutable=mutable,
366
+ method=_decoder_forward,
367
+ )
368
+
369
+ # add updated cache to model output
370
+ if past_key_values is not None and return_dict:
371
+ outputs, past = outputs
372
+ outputs["past_key_values"] = unfreeze(past["cache"])
373
+ return outputs
374
+ elif past_key_values is not None and not return_dict:
375
+ outputs, past = outputs
376
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
377
+
378
+ return outputs
379
+
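Putting `encode`, `init_cache`, and this `decode` together, a single cache-backed decoding step might look roughly like the sketch below (assuming `return_dict` defaults to True; `bos_token_id` is a placeholder for whatever start token the tokenizer uses):

encoder_outputs = model.encode(pixel_values=pixel_values)
past_key_values = model.init_cache(batch_size=1, max_length=32, encoder_outputs=encoder_outputs)

# position_ids must be given explicitly once past_key_values is passed (see the ValueError above).
input_ids = jnp.full((1, 1), bos_token_id, dtype="i4")
position_ids = jnp.zeros((1, 1), dtype="i4")
outputs = model.decode(
    input_ids,
    encoder_outputs,
    position_ids=position_ids,
    past_key_values=past_key_values,
)
past_key_values = outputs["past_key_values"]  # updated cache for the next step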
380
+ def __call__(
381
+ self,
382
+ pixel_values: jnp.ndarray,
383
+ input_ids: Optional[jnp.ndarray] = None,
384
+ attention_mask: Optional[jnp.ndarray] = None,
385
+ position_ids: Optional[jnp.ndarray] = None,
386
+ output_attentions: Optional[bool] = None,
387
+ output_hidden_states: Optional[bool] = None,
388
+ return_dict: Optional[bool] = None,
389
+ train: bool = False,
390
+ params: dict = None,
391
+ dropout_rng: PRNGKey = None,
392
+ ):
393
+ output_attentions = (
394
+ output_attentions
395
+ if output_attentions is not None
396
+ else self.config.output_attentions
397
+ )
398
+ output_hidden_states = (
399
+ output_hidden_states
400
+ if output_hidden_states is not None
401
+ else self.config.output_hidden_states
402
+ )
403
+ return_dict = (
404
+ return_dict if return_dict is not None else self.config.return_dict
405
+ )
406
+
407
+ pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
408
+
409
+ # # prepare encoder inputs
410
+ # if encoder_attention_mask is None:
411
+ # encoder_attention_mask = jnp.ones_like(input_ids)
412
+
413
+ # if position_ids is None:
414
+ # batch_size, sequence_length = input_ids.shape
415
+ # position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
416
+
417
+ # prepare decoder inputs
418
+ # if decoder_input_ids is None:
419
+ # decoder_input_ids = shift_tokens_right(
420
+ # input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
421
+ # ) # TODO: Check how to use this
422
+
423
+ if attention_mask is None:
424
+ attention_mask = jnp.ones_like(input_ids)
425
+ if position_ids is None:
426
+ batch_size, sequence_length = input_ids.shape
427
+ position_ids = jnp.broadcast_to(
428
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
429
+ )
430
+
431
+ # Handle any PRNG if needed
432
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
433
+
434
+ return self.module.apply(
435
+ {"params": params or self.params},
436
+ pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
437
+ input_ids=jnp.array(input_ids, dtype="i4"),
438
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
439
+ position_ids=jnp.array(position_ids, dtype="i4"),
440
+ output_attentions=output_attentions,
441
+ output_hidden_states=output_hidden_states,
442
+ return_dict=return_dict,
443
+ deterministic=not train,
444
+ rngs=rngs,
445
+ )
446
+
447
+
448
+ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
449
+ module_class = FlaxViTGPT2LMForConditionalGenerationModule
450
+ dtype: jnp.dtype = jnp.float32
451
+
452
+ def decode(
453
+ self,
454
+ input_ids,
455
+ encoder_outputs,
456
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
457
+ attention_mask: Optional[jnp.ndarray] = None,
458
+ position_ids: Optional[jnp.ndarray] = None,
459
+ past_key_values: dict = None,
460
+ output_attentions: Optional[bool] = None,
461
+ output_hidden_states: Optional[bool] = None,
462
+ return_dict: Optional[bool] = None,
463
+ deterministic: bool = True,
464
+ params: dict = None,
465
+ dropout_rng: PRNGKey = None,
466
+ ):
467
+ output_attentions = (
468
+ output_attentions
469
+ if output_attentions is not None
470
+ else self.config.output_attentions
471
+ )
472
+ output_hidden_states = (
473
+ output_hidden_states
474
+ if output_hidden_states is not None
475
+ else self.config.output_hidden_states
476
+ )
477
+ return_dict = (
478
+ return_dict if return_dict is not None else self.config.return_dict
479
+ )
480
+
481
+ encoder_hidden_states = encoder_outputs[0]
482
+ if encoder_attention_mask is None:
483
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
484
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
485
+
486
+ batch_size, sequence_length = input_ids.shape
487
+ if attention_mask is None:
488
+ attention_mask = jnp.ones((batch_size, sequence_length))
489
+
490
+ if position_ids is None:
491
+ if past_key_values is not None:
492
+ raise ValueError(
493
+ "Make sure to provide `position_ids` when passing `past_key_values`."
494
+ )
495
+
496
+ position_ids = jnp.broadcast_to(
497
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
498
+ )
499
+
500
+ # Handle any PRNG if needed
501
+ rngs = {}
502
+ if dropout_rng is not None:
503
+ rngs["dropout"] = dropout_rng
504
+
505
+ inputs = {"params": params or self.params}
506
+
507
+ # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
508
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
509
+ # it can be changed by the FlaxGPT2Attention module.
510
+ if past_key_values:
511
+ inputs["cache"] = past_key_values
512
+ mutable = ["cache"]
513
+ else:
514
+ mutable = False
515
+
516
+ def _decoder_forward(
517
+ module,
518
+ input_ids,
519
+ attention_mask,
520
+ position_ids,
521
+ **kwargs,
522
+ ):
523
+ decoder_module = module._get_decoder_module()
524
+ outputs = decoder_module(
525
+ input_ids,
526
+ attention_mask,
527
+ position_ids,
528
+ **kwargs,
529
+ )
530
+ lm_logits = outputs[0]
531
+
532
+ return lm_logits, outputs
533
+
534
+ outputs = self.module.apply(
535
+ inputs,
536
+ input_ids=jnp.array(input_ids, dtype="i4"),
537
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
538
+ position_ids=jnp.array(position_ids, dtype="i4"),
539
+ encoder_hidden_states=encoder_hidden_states,
540
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
541
+ output_attentions=output_attentions,
542
+ output_hidden_states=output_hidden_states,
543
+ return_dict=return_dict,
544
+ deterministic=deterministic,
545
+ rngs=rngs,
546
+ mutable=mutable,
547
+ method=_decoder_forward,
548
+ )
549
+
550
+ if past_key_values is None:
551
+ lm_logits, outputs = outputs
552
+ else:
553
+ (lm_logits, outputs), past = outputs
554
+
555
+ if return_dict:
556
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
557
+ logits=lm_logits,
558
+ hidden_states=outputs.decoder_hidden_states,
559
+ attentions=outputs.decoder_attentions,
560
+ cross_attentions=outputs.cross_attentions,
561
+ )
562
+ else:
563
+ outputs = (lm_logits,) + outputs[1:]
564
+
565
+ # add updated cache to model output
566
+ if past_key_values is not None and return_dict:
567
+ outputs["past_key_values"] = unfreeze(past["cache"])
568
+ return outputs
569
+ elif past_key_values is not None and not return_dict:
570
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
571
+
572
+ return outputs
573
+
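Unlike the base class, this `decode` surfaces language-model logits. A hedged sketch of one manual greedy step (an illustration only; real decoding would normally go through `generate`, and `input_ids`, `position_ids`, `encoder_outputs`, and `past_key_values` are assumed to be set up as in the earlier sketches):

outputs = model.decode(
    input_ids,
    encoder_outputs,
    position_ids=position_ids,
    past_key_values=past_key_values,
)
next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)  # greedy pick from the last position
past_key_values = outputs.past_key_values                   # thread the updated cache into the next call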
574
+ def prepare_inputs_for_generation(
575
+ self,
576
+ input_ids,
577
+ max_length,
578
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
579
+ attention_mask: Optional[jnp.ndarray] = None,
580
+ encoder_outputs=None,
581
+ **kwargs,
582
+ ):
583
+ # initializing the cache
584
+ batch_size, seq_length = input_ids.shape
585
+
586
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
587
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
588
+ # But since the decoder uses a causal mask, those positions are masked anyway.
589
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
590
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
591
+ if attention_mask is not None:
592
+ position_ids = attention_mask.cumsum(axis=-1) - 1
593
+ extended_attention_mask = lax.dynamic_update_slice(
594
+ extended_attention_mask, attention_mask, (0, 0)
595
+ )
596
+ else:
597
+ position_ids = jnp.broadcast_to(
598
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
599
+ )
600
+
601
+ return {
602
+ "past_key_values": past_key_values,
603
+ "encoder_outputs": encoder_outputs,
604
+ "encoder_attention_mask": encoder_attention_mask,
605
+ "attention_mask": extended_attention_mask,
606
+ "position_ids": position_ids,
607
+ }
608
+
609
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
610
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
611
+ model_kwargs["position_ids"] = (
612
+ model_kwargs["position_ids"][:, -1:] + 1
613
+ )
614
+ return model_kwargs
615
+
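To make the static-mask trick described above concrete, here is a tiny numeric sketch of how the extended attention mask and position ids come out (values are illustrative):

import jax.numpy as jnp
from jax import lax

attention_mask = jnp.array([[1, 1, 1, 0]], dtype="i4")      # three real tokens, one pad
extended_attention_mask = jnp.ones((1, 8), dtype="i4")      # static mask over max_length = 8
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
# -> [[1, 1, 1, 0, 1, 1, 1, 1]]; the trailing ones are harmless under the causal mask.
position_ids = attention_mask.cumsum(axis=-1) - 1           # [[0, 1, 2, 2]]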
616
+ @classmethod
617
+ def from_vit_gpt2_pretrained(
618
+ cls,
619
+ vit_model_name_or_path: str = None,
620
+ gpt2_model_name_or_path: str = None,
621
+ *model_args,
622
+ **kwargs,
623
+ ) -> FlaxViTGPT2LMPreTrainedModel:
624
+
625
+ kwargs_gpt2 = {
626
+ argument[len("gpt2_") :]: value
627
+ for argument, value in kwargs.items()
628
+ if argument.startswith("gpt2_")
629
+ }
630
+
631
+ kwargs_vit = {
632
+ argument[len("vit_") :]: value
633
+ for argument, value in kwargs.items()
634
+ if argument.startswith("vit_")
635
+ }
636
+
637
+ # remove gpt2, vit kwargs from kwargs
638
+ for key in kwargs_gpt2.keys():
639
+ del kwargs["gpt2_" + key]
640
+ for key in kwargs_vit.keys():
641
+ del kwargs["vit_" + key]
642
+
643
+ # Load and initialize the gpt2 and vit model
644
+ gpt2_model = kwargs_gpt2.pop("model", None)
645
+ if gpt2_model is None:
646
+ assert (
647
+ gpt2_model_name_or_path is not None
648
+ ), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"
649
+
650
+ if "config" not in kwargs_gpt2:
651
+ gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
652
+ kwargs_gpt2["config"] = gpt2_config
653
+
654
+ kwargs_gpt2["config"].add_cross_attention = True
655
+ gpt2_model = FlaxGPT2LMHeadModel.from_pretrained(
656
+ gpt2_model_name_or_path, *model_args, **kwargs_gpt2
657
+ )
658
+
659
+ vit_model = kwargs_vit.pop("model", None)
660
+ if vit_model is None:
661
+ assert (
662
+ vit_model_name_or_path is not None
663
+ ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"
664
+
665
+ if "config" not in kwargs_vit:
666
+ vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
667
+ kwargs_vit["config"] = vit_config
668
+
669
+ vit_model = FlaxViTModel.from_pretrained(
670
+ vit_model_name_or_path, *model_args, **kwargs_vit
671
+ )
672
+
673
+ # instantiate config with corresponding kwargs
674
+ dtype = kwargs.pop("dtype", jnp.float32)
675
+ config = ViTGPT2Config.from_vit_gpt2_configs(
676
+ vit_model.config, gpt2_model.config, **kwargs
677
+ )
678
+
679
+ # init model
680
+ model = cls(config, *model_args, dtype=dtype, **kwargs)
681
+ model.params["model"]["encoder"] = vit_model.params
682
+ model.params["model"]["decoder"] = gpt2_model.params
683
+
684
+ return model
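For example, the composite model could presumably be assembled from two separately pretrained checkpoints along these lines (the checkpoint names below are placeholders, not actual model ids):

model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
    vit_model_name_or_path="some-org/some-vit-checkpoint",
    gpt2_model_name_or_path="some-org/some-gpt2-checkpoint",
)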
wit_data_dir/dev/dev.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef1ecdcd132885a8f29c8707fad649431c6ff3d9bbd295d56b8520e7046c0eb7
3
+ size 1418232
wit_data_dir/test/test.tsv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0517292749005808b1d1d75343c76b8b16c3ed74fde030f7af8b611ad7b4d5d
3
+ size 1406997
wit_dataset_script.py ADDED
@@ -0,0 +1,145 @@
1
+ import csv
2
+ import json
3
+ import os
4
+
5
+ import datasets
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+
10
+ # TODO: Add BibTeX citation
11
+ # Find for instance the citation on arxiv or on the dataset repo/website
12
+ _CITATION = """\
13
+ @InProceedings{huggingface:dataset,
14
+ title = {A great new dataset},
15
+ author={huggingface, Inc.
16
+ },
17
+ year={2020}
18
+ }
19
+ """
20
+
21
+ # TODO: Add description of the dataset here
22
+ # You can copy an official description
23
+ _DESCRIPTION = """\
24
+ This new dataset is designed to solve this great NLP task and is crafted with a lot of care.
25
+ """
26
+
27
+ # TODO: Add a link to an official homepage for the dataset here
28
+ _HOMEPAGE = ""
29
+
30
+ # TODO: Add the licence for the dataset here if you can find it
31
+ _LICENSE = ""
32
+
33
+ # TODO: Add link to the official dataset URLs here
34
+ # The HuggingFace datasets library doesn't host the datasets but only points to the original files
35
+ # This can be an arbitrary nested dict/list of URLs (see below in `_split_generators` method)
36
+ _URLs = {
37
+ }
38
+
39
+
40
+ # TODO: Name of the dataset usually matches the script name with CamelCase instead of snake_case
41
+ class WITDataset(datasets.GeneratorBasedBuilder):
42
+ """TODO: Short description of my dataset."""
43
+
44
+ VERSION = datasets.Version("1.1.0")
45
+
46
+ DEFAULT_CONFIG_NAME = "en"
47
+
48
+ def _info(self):
49
+ # TODO: This method specifies the datasets.DatasetInfo object which contains information and typings for the dataset
50
+
51
+ features = datasets.Features(
52
+ {
53
+ "id": datasets.Value("int64"),
54
+ "lang": datasets.Value("string"),
55
+ "caption": datasets.Value("string"),
56
+ "context": datasets.Value("string"),
57
+ "image_url": datasets.Value("string"),
58
+ "page_url": datasets.Value("string"),
59
+ "image_file": datasets.Value("string"),
60
+ "pixels_file": datasets.Value("string")
61
+ # These are the features of your dataset like images, labels ...
62
+ }
63
+ )
64
+
65
+ return datasets.DatasetInfo(
66
+ # This is the description that will appear on the datasets page.
67
+ description=_DESCRIPTION,
68
+ # This defines the different columns of the dataset and their types
69
+ features=features, # Here we define them above because they are different between the two configurations
70
+ # If there's a common (input, target) tuple from the features,
71
+ # specify them here. They'll be used if as_supervised=True in
72
+ # builder.as_dataset.
73
+ supervised_keys=None,
74
+ # Homepage of the dataset for documentation
75
+ homepage=_HOMEPAGE,
76
+ # License for the dataset if available
77
+ license=_LICENSE,
78
+ # Citation for the dataset
79
+ citation=_CITATION,
80
+ )
81
+
82
+ def _split_generators(self, dl_manager):
83
+ """Returns SplitGenerators."""
84
+ # TODO: This method is tasked with downloading/extracting the data and defining the splits depending on the configuration
85
+ # If several configurations are possible (listed in BUILDER_CONFIGS), the configuration selected by the user is in self.config.name
86
+
87
+ data_dir = self.config.data_dir
88
+
89
+ return [
90
+ datasets.SplitGenerator(
91
+ name=datasets.Split.TRAIN,
92
+ # These kwargs will be passed to _generate_examples
93
+ gen_kwargs={
94
+ "data_dir": os.path.join(data_dir, "train"),
95
+ "split": "train",
96
+ },
97
+ ),
98
+ datasets.SplitGenerator(
99
+ name=datasets.Split.TEST,
100
+ # These kwargs will be passed to _generate_examples
101
+ gen_kwargs={
102
+ "data_dir": os.path.join(data_dir, "test"),
103
+ "split": "test"
104
+ },
105
+ ),
106
+ datasets.SplitGenerator(
107
+ name=datasets.Split.VALIDATION,
108
+ # These kwargs will be passed to _generate_examples
109
+ gen_kwargs={
110
+ "data_dir": os.path.join(data_dir, "dev"),
111
+ "split": "dev",
112
+ },
113
+ ),
114
+ ]
115
+
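Based on the gen_kwargs above and the file paths used in `_generate_examples` below, the builder seems to expect a `data_dir` laid out roughly as follows (a sketch inferred from this script, not an authoritative spec):

# <data_dir>/train/train.tsv   <data_dir>/train/images/<id>.jpg   <data_dir>/train/numpy/<id>.npy
# <data_dir>/dev/dev.tsv       <data_dir>/dev/images/<id>.jpg     <data_dir>/dev/numpy/<id>.npy
# <data_dir>/test/test.tsv     <data_dir>/test/images/<id>.jpg    <data_dir>/test/numpy/<id>.npy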
116
+ def _generate_examples(
117
+ self, data_dir, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
118
+ ):
119
+ """ Yields examples as (key, example) tuples. """
120
+ # This method handles input defined in _split_generators to yield (key, example) tuples from the dataset.
121
+ # The `key` is here for legacy reason (tfds) and is not important in itself.
122
+
123
+ df = pd.read_csv(os.path.join(data_dir, f'{split}.tsv'), sep='\t')
124
+
125
+ for id_, row in df.iterrows():
126
+
127
+ _id = row[0]
128
+
129
+ # skip rows whose caption or context is null
130
+ if type(row[4]) != str or type(row[5]) != str:
131
+ continue
132
+
133
+ image_file = os.path.join(data_dir, 'images', f'{_id}.jpg')
134
+ pixels_file = os.path.join(data_dir, 'numpy', f'{_id}.npy')
135
+
136
+ yield id_, {
137
+ "id": row[0],
138
+ "lang": row[1],
139
+ "caption": row[4],
140
+ "context": row[5],
141
+ "image_url": row[2],
142
+ "page_url": row[3],
143
+ "image_file": image_file,
144
+ "pixels_file": pixels_file
145
+ }
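Given the indexing in `_generate_examples`, the TSV columns are assumed to be, in order: id, lang, image_url, page_url, caption, context. A hedged sketch of loading the dataset through this script (the paths are assumptions based on the wit_data_dir files added in this commit):

from datasets import load_dataset

# data_dir must contain the train/dev/test sub-folders described above.
dataset = load_dataset("wit_dataset_script.py", data_dir="wit_data_dir")
print(dataset["train"][0]["caption"])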