ydshieh committed
Commit
3bffad7
1 Parent(s): a1a9885

Use new FlaxVisionEncoderDecoderModel class

Files changed (2)
  1. run_image_caption.py +255 -101
  2. run_summarization_flax.py +265 -100
run_image_caption.py CHANGED
@@ -18,11 +18,6 @@ Fine-tuning the library models for summarization.
  """
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

- import sys, os
-
- current_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(current_path)
-
  import logging
  import os
  import sys
@@ -48,20 +43,21 @@ from flax import jax_utils, traverse_util
  from flax.jax_utils import unreplicate
  from flax.training import train_state
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
  from transformers import (
  CONFIG_MAPPING,
- FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
  AutoConfig,
  AutoTokenizer,
  FlaxAutoModelForSeq2SeqLM,
  HfArgumentParser,
  TrainingArguments,
  is_tensorboard_available,
  )
- from transformers.file_utils import is_offline_mode

- from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

  logger = logging.getLogger(__name__)

@@ -76,10 +72,23 @@ except (LookupError, OSError):
  nltk.download("punkt", quiet=True)


- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


  @dataclass
  class ModelArguments:
  """
@@ -93,15 +102,46 @@ class ModelArguments:
  "Don't set if you want to train a model from scratch."
  },
  )
  model_type: Optional[str] = field(
  default=None,
  metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
  )
  config_name: Optional[str] = field(
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
  )
  tokenizer_name: Optional[str] = field(
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
  )
  cache_dir: Optional[str] = field(
  default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
@@ -130,19 +170,26 @@ class DataTrainingArguments:
  dataset_config_name: Optional[str] = field(
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
  )
- text_column: Optional[str] = field(
  default=None,
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
  )
- summary_column: Optional[str] = field(
  default=None,
- metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
  )
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
  validation_file: Optional[str] = field(
  default=None,
  metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
  )
  max_source_length: Optional[int] = field(
  default=1024,
  metadata={
@@ -191,9 +238,6 @@ class DataTrainingArguments:
  default=None,
  metadata={"help": "The number of processes to use for the preprocessing."},
  )
- source_prefix: Optional[str] = field(
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
- )
  predict_with_generate: bool = field(
  default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
  )
@@ -222,18 +266,8 @@ class DataTrainingArguments:
  self.val_max_target_length = self.max_target_length


- summarization_name_mapping = {
- "amazon_reviews_multi": ("review_body", "review_title"),
- "big_patent": ("description", "abstract"),
- "cnn_dailymail": ("article", "highlights"),
- "orange_sum": ("text", "summary"),
- "pn_summary": ("article", "summary"),
- "psc": ("extract_text", "summary_text"),
- "samsum": ("dialogue", "summary"),
- "thaisum": ("body", "summary"),
- "xglue": ("news_body", "news_title"),
- "xsum": ("document", "summary"),
- "wiki_summary": ("article", "highlights"),
  }

@@ -337,6 +371,16 @@ def main():
  # Set the verbosity to info of the Transformers logger (on main process only):
  logger.info(f"Training/evaluation parameters {training_args}")

  # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
  # (the dataset will be downloaded automatically from the datasets Hub).
@@ -347,7 +391,7 @@ def main():
  if data_args.dataset_name is not None:
  # Downloading and loading a dataset from the hub.
  dataset = load_dataset(
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir='/home/33611/caption/'
  )
  else:
  data_files = {}
@@ -360,38 +404,152 @@ def main():
  if data_args.test_file is not None:
  data_files["test"] = data_args.test_file
  extension = data_args.test_file.split(".")[-1]
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

- vit_name_path = 'google/vit-base-patch16-224-in21k'
- gpt2_name_path = 'asi/gpt-fr-cased-small'
-
- gpt2_config = GPT2Config.from_pretrained(gpt2_name_path)
- gpt2_config.add_cross_attention = True
-

- vit_gpt2_name_path = ''

- feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name_path)

- tokenizer = GPT2Tokenizer.from_pretrained(gpt2_name_path)

- if not vit_gpt2_name_path:
- assert vit_name_path
- assert gpt2_name_path
- vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
- vit_name_path, gpt2_name_path
- )
  else:
- vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
- vit_gpt2_name_path
  )

- model = vit_gpt2_model
- model.config.is_encoder_decoder = True
- model.config.decoder_start_token_id = gpt2_config.bos_token_id
- model.config.bos_token_id = gpt2_config.bos_token_id
- model.config.eos_token_id = gpt2_config.eos_token_id
- model.config.pad_token_id = gpt2_config.pad_token_id

  # Preprocessing the datasets.
  # We need to tokenize inputs and targets.
@@ -405,8 +563,26 @@ def main():
  logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
  return

- image_file_column = 'image_file'
- caption_column = 'fr'

  # Temporarily set max_target_length for training.
  max_target_length = data_args.max_target_length
@@ -414,29 +590,25 @@ def main():
  # In Flax, for seq2seq models we need to pass `decoder_input_ids`
  # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
  # for that dynamically import the `shift_tokens_right` function from the model file
- model_module = __import__(vit_gpt2_model.__module__, fromlist=["shift_tokens_right"])
- shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

  # Setting padding="max_length" as we need fixed length inputs for jitted functions
  def preprocess_function(examples):
-
- _pixel_values = []
- _captions = []
- for y, z in zip(examples[image_file_column], examples[caption_column]):
- with Image.open(y) as image:
  try:
  encoder_inputs = feature_extractor(images=image, return_tensors="np")
  except:
  continue
- x = encoder_inputs.pixel_values
- _pixel_values.append(x)
- _captions.append(z + ' ' + tokenizer.eos_token)
- pixel_values = np.concatenate(_pixel_values)

- targets = _captions
-
- # Add eos_token!!
- #targets = [x + ' ' + tokenizer.eos_token for x in targets]

  model_inputs = {}
  model_inputs['pixel_values'] = pixel_values
@@ -448,18 +620,13 @@ def main():
  )

  model_inputs["labels"] = labels["input_ids"]
-
- #print(labels["input_ids"])
- #print(gpt2_config.pad_token_id)
- #rint(gpt2_config.bos_token_id)
-
  decoder_input_ids = shift_tokens_right_fn(
- jnp.array(labels["input_ids"]), gpt2_config.pad_token_id, gpt2_config.bos_token_id
  )
- model_inputs["input_ids"] = np.asarray(decoder_input_ids)

  # We need decoder_attention_mask so we can ignore pad tokens from loss
- model_inputs["attention_mask"] = labels["attention_mask"]

  return model_inputs

@@ -469,7 +636,6 @@ def main():
  train_dataset = dataset["train"]
  if data_args.max_train_samples is not None:
  train_dataset = train_dataset.select(range(data_args.max_train_samples))
-
  train_dataset = train_dataset.map(
  preprocess_function,
  batched=True,
@@ -604,7 +770,7 @@ def main():
  )

  # Setup train state
- state = TrainState.create(apply_fn=vit_gpt2_model.__call__, params=vit_gpt2_model.params, tx=adamw, dropout_rng=dropout_rng)

  # label smoothed cross entropy
  def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@@ -635,7 +801,7 @@ def main():
  def compute_loss(params):
  labels = batch.pop("labels")
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
- loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
  return loss

  grad_fn = jax.value_and_grad(compute_loss)
@@ -653,7 +819,7 @@ def main():
  def eval_step(params, batch, label_smoothing_factor=0.0):
  labels = batch.pop("labels")
  logits = model(**batch, params=params, train=False)[0]
- loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)

  # summarize metrics
  metrics = {"loss": loss}
@@ -669,15 +835,7 @@ def main():

  def generate_step(params, batch):
  model.params = params
- # output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
-
- #encoder_outputs = model.encode(pixel_values=batch['pixel_values'])
- #output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], encoder_outputs=encoder_outputs, **gen_kwargs)
-
- # encoder_outputs = model.encode(pixel_values=batch['pixel_values'], params=params, train=False)
  output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
-
-
  return output_ids.sequences

  # Create parallel version of the train and eval step
@@ -727,7 +885,6 @@ def main():
  with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
  fp.write(desc + '\n')

-
  # ======================== Evaluating ==============================
  eval_metrics = []
  eval_preds = []
@@ -768,7 +925,6 @@ def main():
  with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
  fp.write(desc + '\n')

-
  # Save metrics
  if has_tensorboard and jax.process_index() == 0:
  cur_step = epoch * (len(train_dataset) // train_batch_size)
@@ -816,17 +972,15 @@ def main():
  logger.info(desc)
  with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
  fp.write(desc + '\n')
-

  # save checkpoint after each epoch and push checkpoint to the hub
  if jax.process_index() == 0:
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
- model.save_pretrained(
- os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'),
- params=params,
- push_to_hub=training_args.push_to_hub,
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
- )

  if __name__ == "__main__":
  main()
 
  """
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

  import logging
  import os
  import sys

  from flax.jax_utils import unreplicate
  from flax.training import train_state
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+ from huggingface_hub import Repository
  from transformers import (
  CONFIG_MAPPING,
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
  AutoConfig,
+ AutoFeatureExtractor,
  AutoTokenizer,
  FlaxAutoModelForSeq2SeqLM,
  HfArgumentParser,
  TrainingArguments,
  is_tensorboard_available,
+ FlaxAutoModelForVision2Seq,
  )
+ from transformers.file_utils import get_full_repo_name, is_offline_mode


  logger = logging.getLogger(__name__)

  nltk.download("punkt", quiet=True)


+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


+ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = np.zeros_like(input_ids)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
+
  @dataclass
  class ModelArguments:
  """
 
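For reference, a minimal, self-contained check of what the `shift_tokens_right` helper added above produces (the ids, pad id 0, and start id 2 are toy values for the example):

    import numpy as np

    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
        # Same logic as the helper above, copied here so the snippet runs standalone.
        shifted_input_ids = np.zeros_like(input_ids)
        shifted_input_ids[:, 1:] = input_ids[:, :-1]
        shifted_input_ids[:, 0] = decoder_start_token_id
        return np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)

    labels = np.array([[11, 12, 13, 0, 0]])  # toy caption ids, 0 = pad
    print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2))
    # [[ 2 11 12 13  0]] -- the decoder input is the caption shifted one step right
    # behind the start token, so the model predicts token t from tokens < t.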
  "Don't set if you want to train a model from scratch."
  },
  )
+ encoder_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The encoder model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ },
+ )
+ decoder_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The decoder model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ },
+ )
  model_type: Optional[str] = field(
  default=None,
  metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
  )
+ encoder_model_type: Optional[str] = field(
+ default=None,
+ metadata={"help": "If training from scratch, pass a encoder model type from the list: " + ", ".join(MODEL_TYPES)},
+ )
+ decoder_model_type: Optional[str] = field(
+ default=None,
+ metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)},
+ )
  config_name: Optional[str] = field(
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
  )
+ encoder_config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as encoder_model_name"}
+ )
+ decoder_config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as decoder_model_name"}
+ )
+ feature_extractor_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained feature extractor_name name or path if not the same as encoder_model_name"}
+ )
  tokenizer_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as decoder_model_name"}
  )
  cache_dir: Optional[str] = field(
  default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}

  dataset_config_name: Optional[str] = field(
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
  )
+ data_dir: Optional[str] = field(
+ default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."}
+ )
+ image_column: Optional[str] = field(
  default=None,
+ metadata={"help": "The name of the column in the datasets containing the full image file paths (for image captioning)."},
  )
+ caption_column: Optional[str] = field(
  default=None,
+ metadata={"help": "The name of the column in the datasets containing the image captions (for image captioning)."},
  )
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
  validation_file: Optional[str] = field(
  default=None,
  metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
  )
+ test_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
+ )
  max_source_length: Optional[int] = field(
  default=1024,
  metadata={

  default=None,
  metadata={"help": "The number of processes to use for the preprocessing."},
  )
  predict_with_generate: bool = field(
  default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
  )

  self.val_max_target_length = self.max_target_length


+ image_captioning_name_mapping = {
+ "image_caption_dataset.py": ("image_file", "caption"),
  }


  # Set the verbosity to info of the Transformers logger (on main process only):
  logger.info(f"Training/evaluation parameters {training_args}")

+ # Handle the repository creation
+ if training_args.push_to_hub:
+ if training_args.hub_model_id is None:
+ repo_name = get_full_repo_name(
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+ )
+ else:
+ repo_name = training_args.hub_model_id
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
+
  # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
  # (the dataset will be downloaded automatically from the datasets Hub).

  if data_args.dataset_name is not None:
  # Downloading and loading a dataset from the hub.
  dataset = load_dataset(
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir=data_args.data_dir
  )
  else:
  data_files = {}

  if data_args.test_file is not None:
  data_files["test"] = data_args.test_file
  extension = data_args.test_file.split(".")[-1]
+ # TODO: Check
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, data_dir=data_args.data_dir)
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+ # https://huggingface.co/docs/datasets/loading_datasets.html.

+ # Load pretrained model and tokenizer
+
+ encoder_cache_dir, decoder_cache_dir = None, None
+ if model_args.cache_dir:
+ encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
+ decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
+
+ if model_args.config_name:
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+ elif model_args.model_name_or_path:
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+ elif getattr(CONFIG_MAPPING[model_args.model_type], "from_encoder_decoder_configs", None):
+
+ config_class = CONFIG_MAPPING[model_args.model_type]

+ if model_args.encoder_config_name:
+ encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
+ elif model_args.encoder_model_name_or_path:
+ encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
+ else:
+ encoder_config = CONFIG_MAPPING[model_args.encoder_model_type]()
+ logger.warning("You are instantiating a new config instance from scratch for the encoder.")

+ if model_args.decoder_config_name:
+ decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
+ elif model_args.decoder_model_name_or_path:
+ decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
+ else:
+ decoder_config = CONFIG_MAPPING[model_args.decoder_model_type]()
+ logger.warning("You are instantiating a new config instance from scratch for the decoder.")

+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True

+ config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
  else:
+ config = CONFIG_MAPPING[model_args.model_type]()
+ logger.warning("You are instantiating a new config instance from scratch.")
+
+ decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
+ if not decoder_start_token_id and getattr(config, "decoder", None):
+ decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
+ bos_token_id = getattr(config, "bos_token_id", None)
+ if not bos_token_id and getattr(config, "decoder", None):
+ bos_token_id = getattr(config.decoder, "bos_token_id", None)
+ eos_token_id = getattr(config, "eos_token_id", None)
+ if not eos_token_id and getattr(config, "decoder", None):
+ eos_token_id = getattr(config.decoder, "eos_token_id", None)
+ pad_token_id = getattr(config, "pad_token_id", None)
+ if not pad_token_id and getattr(config, "decoder", None):
+ pad_token_id = getattr(config.decoder, "pad_token_id", None)
+
+ if decoder_start_token_id is None:
+ decoder_start_token_id = bos_token_id
+ if pad_token_id is None:
+ pad_token_id = eos_token_id
+
+ config.decoder_start_token_id = decoder_start_token_id
+ config.bos_token_id = bos_token_id
+ config.eos_token_id = eos_token_id
+ config.pad_token_id = pad_token_id
+
+ if getattr(config, "decoder", None):
+ config.decoder.decoder_start_token_id = decoder_start_token_id
+ config.decoder.bos_token_id = bos_token_id
+ config.decoder.eos_token_id = eos_token_id
+ config.decoder.pad_token_id = pad_token_id
+
+ feature_extractor = None
+ if model_args.feature_extractor_name:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
  )
+ elif model_args.model_name_or_path:
+ try:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir
+ )
+ except ValueError as e:
+ logger.warning(e)
+ if not feature_extractor:
+ if model_args.encoder_model_name_or_path:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
+ )
+ else:
+ raise ValueError(
+ "You are instantiating a new feature extractor from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
+ )

+ tokenizer = None
+ if model_args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+ )
+ elif model_args.model_name_or_path:
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+ )
+ except ValueError as e:
+ logger.warning(e)
+ if not tokenizer:
+ if model_args.decoder_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+ )
+ else:
+ raise ValueError(
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+ )
+ tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
+
+ if model_args.model_name_or_path:
+ model = FlaxAutoModelForVision2Seq.from_pretrained(
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+ )
+ elif model_args.encoder_model_name_or_path and model_args.decoder_model_name_or_path:
+ model_class = FlaxAutoModelForVision2Seq.from_config(config).__class__
+ model = model_class.from_encoder_decoder_pretrained(
+ model_args.encoder_model_name_or_path,
+ model_args.decoder_model_name_or_path,
+ encoder_config=config.encoder,
+ decoder_config=config.decoder,
+ encoder_seed=training_args.seed,
+ decoder_seed=training_args.seed,
+ encoder_dtype=getattr(jnp, model_args.dtype),
+ decoder_dtype=getattr(jnp, model_args.dtype),
+ )
+ # Set `encoder-decoder` (top-level) specific config
+ model.config.decoder_start_token_id = decoder_start_token_id
+ model.config.bos_token_id = bos_token_id
+ model.config.eos_token_id = eos_token_id
+ model.config.pad_token_id = pad_token_id
+ else:
+ model = FlaxAutoModelForVision2Seq.from_config(
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+ )

  # Preprocessing the datasets.
  # We need to tokenize inputs and targets.
  logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
  return

+ # Get the column names for input/target.
+ dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
+ if data_args.image_column is None:
+ assert dataset_columns is not None
+ image_column = dataset_columns[0]
+ else:
+ image_column = data_args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if data_args.caption_column is None:
+ assert dataset_columns is not None
+ caption_column = dataset_columns[1]
+ else:
+ caption_column = data_args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )

  # Temporarily set max_target_length for training.
  max_target_length = data_args.max_target_length

  # In Flax, for seq2seq models we need to pass `decoder_input_ids`
  # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
  # for that dynamically import the `shift_tokens_right` function from the model file
+ model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)

  # Setting padding="max_length" as we need fixed length inputs for jitted functions
  def preprocess_function(examples):
+
+ pixel_values = []
+ captions = []
+ for image_file, caption in zip(examples[image_column], examples[caption_column]):
+ with Image.open(image_file) as image:
  try:
  encoder_inputs = feature_extractor(images=image, return_tensors="np")
  except:
  continue
+ pixel_values.append(encoder_inputs.pixel_values)
+ captions.append(caption + ' ' + tokenizer.eos_token)

+ pixel_values = np.concatenate(pixel_values)
+ targets = captions

  model_inputs = {}
  model_inputs['pixel_values'] = pixel_values

  )

  model_inputs["labels"] = labels["input_ids"]
  decoder_input_ids = shift_tokens_right_fn(
+ jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
  )
+ model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

  # We need decoder_attention_mask so we can ignore pad tokens from loss
+ model_inputs["decoder_attention_mask"] = labels["attention_mask"]

  return model_inputs

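A sketch of the batch layout `preprocess_function` ends up producing, with toy shapes (224x224 is the ViT default input size; length 8 stands in for `max_target_length`):

    import numpy as np

    batch = {
        # stacked feature-extractor outputs: (batch, channels, height, width)
        "pixel_values": np.zeros((2, 3, 224, 224), dtype=np.float32),
        # caption ids padded to max_target_length; 1 = eos, 0 = pad (toy values)
        "labels": np.array([[11, 12, 13, 1, 0, 0, 0, 0]] * 2),
    }
    # decoder_input_ids = labels shifted right behind a start token (what
    # shift_tokens_right_fn does above; np.roll is only a stand-in here)
    batch["decoder_input_ids"] = np.roll(batch["labels"], 1, axis=-1)
    batch["decoder_input_ids"][:, 0] = 2  # toy decoder_start_token_id
    # mask of real (non-pad) positions, used to drop pad tokens from the loss
    batch["decoder_attention_mask"] = (batch["labels"] != 0).astype("i4")

    print({k: v.shape for k, v in batch.items()})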
  train_dataset = dataset["train"]
  if data_args.max_train_samples is not None:
  train_dataset = train_dataset.select(range(data_args.max_train_samples))
  train_dataset = train_dataset.map(
  preprocess_function,
  batched=True,

  )

  # Setup train state
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

  # label smoothed cross entropy
  def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):

  def compute_loss(params):
  labels = batch.pop("labels")
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
  return loss

  grad_fn = jax.value_and_grad(compute_loss)
 
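`loss_fn` itself is unchanged by this commit and its body is not shown in the diff; what changes is that both call sites now pass `decoder_attention_mask` as the padding mask. A simplified sketch (without the label smoothing the real function applies) of how that mask enters the loss:

    import jax
    import jax.numpy as jnp
    from flax.training.common_utils import onehot

    def masked_cross_entropy(logits, labels, padding_mask):
        vocab_size = logits.shape[-1]
        per_token = -jnp.sum(onehot(labels, vocab_size) * jax.nn.log_softmax(logits), axis=-1)
        per_token = per_token * padding_mask           # zero out pad positions
        return per_token.sum() / padding_mask.sum()    # average over real tokens only

    logits = jnp.zeros((1, 4, 10))                     # (batch, seq, vocab) dummy logits
    labels = jnp.array([[3, 5, 0, 0]])
    mask = jnp.array([[1.0, 1.0, 0.0, 0.0]])
    print(masked_cross_entropy(logits, labels, mask))  # log(10) ~= 2.3026 for uniform logits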
  def eval_step(params, batch, label_smoothing_factor=0.0):
  labels = batch.pop("labels")
  logits = model(**batch, params=params, train=False)[0]
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

  # summarize metrics
  metrics = {"loss": loss}


  def generate_step(params, batch):
  model.params = params
  output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
  return output_ids.sequences

  # Create parallel version of the train and eval step
 
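The pmapped versions created just below this hunk are unchanged code and not shown; they consume device-sharded batches. A self-contained illustration of that shard-then-pmap pattern, using the `shard` helper the script already imports:

    import numpy as np
    import jax
    from flax.training.common_utils import shard

    @jax.pmap                      # one replica per local device
    def double(x):
        return x * 2

    n = jax.local_device_count()
    batch = np.arange(n * 4, dtype=np.float32).reshape(n * 2, 2)
    sharded = shard(batch)         # adds a leading (n_devices, per_device_batch) layout
    print(double(sharded).shape)   # (n, 2, 2): results come back with a device axis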
  with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
  fp.write(desc + '\n')

  # ======================== Evaluating ==============================
  eval_metrics = []
  eval_preds = []

  with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
  fp.write(desc + '\n')

  # Save metrics
  if has_tensorboard and jax.process_index() == 0:
  cur_step = epoch * (len(train_dataset) // train_batch_size)

  logger.info(desc)
  with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
  fp.write(desc + '\n')

  # save checkpoint after each epoch and push checkpoint to the hub
  if jax.process_index() == 0:
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+ model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
+ tokenizer.save_pretrained(training_args.output_dir)
+ if training_args.push_to_hub:
+ repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
+

  if __name__ == "__main__":
  main()

run_summarization_flax.py CHANGED
@@ -32,6 +32,7 @@ import nltk # Here to have a nice missing dependency error message early on
  import numpy as np
  from datasets import Dataset, load_dataset, load_metric
  from tqdm import tqdm

  import jax
  import jax.numpy as jnp
@@ -45,13 +46,15 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
  from huggingface_hub import Repository
  from transformers import (
  CONFIG_MAPPING,
- FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
  AutoConfig,
  AutoTokenizer,
  FlaxAutoModelForSeq2SeqLM,
  HfArgumentParser,
  TrainingArguments,
  is_tensorboard_available,
  )
  from transformers.file_utils import get_full_repo_name, is_offline_mode

@@ -69,10 +72,23 @@ except (LookupError, OSError):
  nltk.download("punkt", quiet=True)


- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


  @dataclass
  class ModelArguments:
  """
@@ -86,15 +102,46 @@ class ModelArguments:
  "Don't set if you want to train a model from scratch."
  },
  )
  model_type: Optional[str] = field(
  default=None,
  metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
  )
  config_name: Optional[str] = field(
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
  )
  tokenizer_name: Optional[str] = field(
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
  )
  cache_dir: Optional[str] = field(
  default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
@@ -123,13 +170,16 @@ class DataTrainingArguments:
  dataset_config_name: Optional[str] = field(
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
  )
- text_column: Optional[str] = field(
  default=None,
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
  )
- summary_column: Optional[str] = field(
  default=None,
- metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
  )
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
  validation_file: Optional[str] = field(
@@ -188,9 +238,6 @@ class DataTrainingArguments:
  default=None,
  metadata={"help": "The number of processes to use for the preprocessing."},
  )
- source_prefix: Optional[str] = field(
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
- )
  predict_with_generate: bool = field(
  default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
  )
@@ -219,18 +266,8 @@ class DataTrainingArguments:
  self.val_max_target_length = self.max_target_length


- summarization_name_mapping = {
- "amazon_reviews_multi": ("review_body", "review_title"),
- "big_patent": ("description", "abstract"),
- "cnn_dailymail": ("article", "highlights"),
- "orange_sum": ("text", "summary"),
- "pn_summary": ("article", "summary"),
- "psc": ("extract_text", "summary_text"),
- "samsum": ("dialogue", "summary"),
- "thaisum": ("body", "summary"),
- "xglue": ("news_body", "news_title"),
- "xsum": ("document", "summary"),
- "wiki_summary": ("article", "highlights"),
  }

@@ -354,7 +391,7 @@ def main():
  if data_args.dataset_name is not None:
  # Downloading and loading a dataset from the hub.
  dataset = load_dataset(
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
  )
  else:
  data_files = {}
@@ -367,48 +404,153 @@ def main():
  if data_args.test_file is not None:
  data_files["test"] = data_args.test_file
  extension = data_args.test_file.split(".")[-1]
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
  # https://huggingface.co/docs/datasets/loading_datasets.html.

  # Load pretrained model and tokenizer

  if model_args.config_name:
  config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
  elif model_args.model_name_or_path:
  config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
  else:
  config = CONFIG_MAPPING[model_args.model_type]()
  logger.warning("You are instantiating a new config instance from scratch.")

  if model_args.tokenizer_name:
  tokenizer = AutoTokenizer.from_pretrained(
  model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
  )
  elif model_args.model_name_or_path:
- tokenizer = AutoTokenizer.from_pretrained(
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
- )
- else:
- raise ValueError(
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
- )

  if model_args.model_name_or_path:
- model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
  )
  else:
- model = FlaxAutoModelForSeq2SeqLM.from_config(
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
  )

- if model.config.decoder_start_token_id is None:
- raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
-
- prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-
  # Preprocessing the datasets.
  # We need to tokenize inputs and targets.
  if training_args.do_train:
@@ -422,22 +564,24 @@ def main():
  return

  # Get the column names for input/target.
- dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
- if data_args.text_column is None:
- text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
  else:
- text_column = data_args.text_column
- if text_column not in column_names:
  raise ValueError(
- f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
  )
- if data_args.summary_column is None:
- summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
  else:
- summary_column = data_args.summary_column
- if summary_column not in column_names:
  raise ValueError(
- f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
  )

  # Temporarily set max_target_length for training.
@@ -446,17 +590,28 @@ def main():
  # In Flax, for seq2seq models we need to pass `decoder_input_ids`
  # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
  # for that dynamically import the `shift_tokens_right` function from the model file
- model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
- shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

  # Setting padding="max_length" as we need fixed length inputs for jitted functions
  def preprocess_function(examples):
- inputs = examples[text_column]
- targets = examples[summary_column]
- inputs = [prefix + inp for inp in inputs]
- model_inputs = tokenizer(
- inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
- )

  # Setup the tokenizer for targets
  with tokenizer.as_target_tokenizer():
@@ -680,7 +835,7 @@ def main():

  def generate_step(params, batch):
  model.params = params
- output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
  return output_ids.sequences

  # Create parallel version of the train and eval step
@@ -723,9 +878,12 @@ def main():

  train_metric = unreplicate(train_metric)

- epochs.write(
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
- )

  # ======================== Evaluating ==============================
  eval_metrics = []
763
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
764
  epochs.write(desc)
765
  epochs.desc = desc
 
 
 
766
 
767
  # Save metrics
768
  if has_tensorboard and jax.process_index() == 0:
769
  cur_step = epoch * (len(train_dataset) // train_batch_size)
770
  write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
771
 
772
- # ======================== Prediction loop ==============================
773
- if training_args.do_predict:
774
- logger.info("*** Predict ***")
775
-
776
- pred_metrics = []
777
- pred_generations = []
778
- pred_labels = []
779
-
780
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
781
- pred_steps = len(predict_dataset) // eval_batch_size
782
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
783
- # Model forward
784
- batch = next(pred_loader)
785
- labels = batch["labels"]
786
-
787
- metrics = p_eval_step(state.params, batch)
788
- pred_metrics.append(metrics)
789
-
790
- # generation
 
 
 
 
 
 
 
 
 
 
 
791
  if data_args.predict_with_generate:
792
- generated_ids = p_generate_step(state.params, batch)
793
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
794
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
795
-
796
- # normalize prediction metrics
797
- pred_metrics = get_metrics(pred_metrics)
798
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
799
-
800
- # compute ROUGE metrics
801
- rouge_desc = ""
802
- if data_args.predict_with_generate:
803
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
804
- pred_metrics.update(rouge_metrics)
805
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
806
-
807
- # Print metrics
808
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
809
- logger.info(desc)
810
 
811
  # save checkpoint after each epoch and push checkpoint to the hub
812
  if jax.process_index() == 0:
813
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
814
- model.save_pretrained(training_args.output_dir, params=params)
815
  tokenizer.save_pretrained(training_args.output_dir)
816
  if training_args.push_to_hub:
817
  repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
 
  import numpy as np
  from datasets import Dataset, load_dataset, load_metric
  from tqdm import tqdm
+ from PIL import Image

  import jax
  import jax.numpy as jnp

  from huggingface_hub import Repository
  from transformers import (
  CONFIG_MAPPING,
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
  AutoConfig,
+ AutoFeatureExtractor,
  AutoTokenizer,
  FlaxAutoModelForSeq2SeqLM,
  HfArgumentParser,
  TrainingArguments,
  is_tensorboard_available,
+ FlaxAutoModelForVision2Seq,
  )
  from transformers.file_utils import get_full_repo_name, is_offline_mode

  nltk.download("punkt", quiet=True)


+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


+ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = np.zeros_like(input_ids)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
+
  @dataclass
  class ModelArguments:
  """
 
  "Don't set if you want to train a model from scratch."
  },
  )
+ encoder_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The encoder model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ },
+ )
+ decoder_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The decoder model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ },
+ )
  model_type: Optional[str] = field(
  default=None,
  metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
  )
+ encoder_model_type: Optional[str] = field(
+ default=None,
+ metadata={"help": "If training from scratch, pass a encoder model type from the list: " + ", ".join(MODEL_TYPES)},
+ )
+ decoder_model_type: Optional[str] = field(
+ default=None,
+ metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)},
+ )
  config_name: Optional[str] = field(
  default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
  )
+ encoder_config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as encoder_model_name"}
+ )
+ decoder_config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as decoder_model_name"}
+ )
+ feature_extractor_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained feature extractor_name name or path if not the same as encoder_model_name"}
+ )
  tokenizer_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as decoder_model_name"}
  )
  cache_dir: Optional[str] = field(
  default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}

  dataset_config_name: Optional[str] = field(
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
  )
+ data_dir: Optional[str] = field(
+ default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."}
+ )
+ image_column: Optional[str] = field(
  default=None,
+ metadata={"help": "The name of the column in the datasets containing the full image file paths (for image captioning)."},
  )
+ caption_column: Optional[str] = field(
  default=None,
+ metadata={"help": "The name of the column in the datasets containing the image captions (for image captioning)."},
  )
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
  validation_file: Optional[str] = field(
 
  default=None,
  metadata={"help": "The number of processes to use for the preprocessing."},
  )
  predict_with_generate: bool = field(
  default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
  )

  self.val_max_target_length = self.max_target_length


+ image_captioning_name_mapping = {
+ "image_caption_dataset.py": ("image_file", "caption"),
  }

 
  if data_args.dataset_name is not None:
  # Downloading and loading a dataset from the hub.
  dataset = load_dataset(
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir=data_args.data_dir
  )
  else:
  data_files = {}
 
  if data_args.test_file is not None:
  data_files["test"] = data_args.test_file
  extension = data_args.test_file.split(".")[-1]
+ # TODO: Check
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, data_dir=data_args.data_dir)
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
  # https://huggingface.co/docs/datasets/loading_datasets.html.

  # Load pretrained model and tokenizer

+ encoder_cache_dir, decoder_cache_dir = None, None
+ if model_args.cache_dir:
+ encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
+ decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
+
  if model_args.config_name:
  config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
  elif model_args.model_name_or_path:
  config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+ elif getattr(CONFIG_MAPPING[model_args.model_type], "from_encoder_decoder_configs", None):
+
+ config_class = CONFIG_MAPPING[model_args.model_type]
+
+ if model_args.encoder_config_name:
+ encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
+ elif model_args.encoder_model_name_or_path:
+ encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
+ else:
+ encoder_config = CONFIG_MAPPING[model_args.encoder_model_type]()
+ logger.warning("You are instantiating a new config instance from scratch for the encoder.")
+
+ if model_args.decoder_config_name:
+ decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
+ elif model_args.decoder_model_name_or_path:
+ decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
+ else:
+ decoder_config = CONFIG_MAPPING[model_args.decoder_model_type]()
+ logger.warning("You are instantiating a new config instance from scratch for the decoder.")
+
+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
  else:
  config = CONFIG_MAPPING[model_args.model_type]()
  logger.warning("You are instantiating a new config instance from scratch.")

+ decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
+ if not decoder_start_token_id and getattr(config, "decoder", None):
+ decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
+ bos_token_id = getattr(config, "bos_token_id", None)
+ if not bos_token_id and getattr(config, "decoder", None):
+ bos_token_id = getattr(config.decoder, "bos_token_id", None)
+ eos_token_id = getattr(config, "eos_token_id", None)
+ if not eos_token_id and getattr(config, "decoder", None):
+ eos_token_id = getattr(config.decoder, "eos_token_id", None)
+ pad_token_id = getattr(config, "pad_token_id", None)
+ if not pad_token_id and getattr(config, "decoder", None):
+ pad_token_id = getattr(config.decoder, "pad_token_id", None)
+
+ if decoder_start_token_id is None:
+ decoder_start_token_id = bos_token_id
+ if pad_token_id is None:
+ pad_token_id = eos_token_id
+
+ config.decoder_start_token_id = decoder_start_token_id
+ config.bos_token_id = bos_token_id
+ config.eos_token_id = eos_token_id
+ config.pad_token_id = pad_token_id
+
+ if getattr(config, "decoder", None):
+ config.decoder.decoder_start_token_id = decoder_start_token_id
+ config.decoder.bos_token_id = bos_token_id
+ config.decoder.eos_token_id = eos_token_id
+ config.decoder.pad_token_id = pad_token_id
+
+ feature_extractor = None
+ if model_args.feature_extractor_name:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
+ )
+ elif model_args.model_name_or_path:
+ try:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir
+ )
+ except ValueError as e:
+ logger.warning(e)
+ if not feature_extractor:
+ if model_args.encoder_model_name_or_path:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
+ )
+ else:
+ raise ValueError(
+ "You are instantiating a new feature extractor from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
+ )
+
+ tokenizer = None
  if model_args.tokenizer_name:
  tokenizer = AutoTokenizer.from_pretrained(
  model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
  )
  elif model_args.model_name_or_path:
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+ )
+ except ValueError as e:
+ logger.warning(e)
+ if not tokenizer:
+ if model_args.decoder_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+ )
+ else:
+ raise ValueError(
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+ )
+ tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)

  if model_args.model_name_or_path:
+ model = FlaxAutoModelForVision2Seq.from_pretrained(
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
  )
+ elif model_args.encoder_model_name_or_path and model_args.decoder_model_name_or_path:
+ model_class = FlaxAutoModelForVision2Seq.from_config(config).__class__
+ model = model_class.from_encoder_decoder_pretrained(
+ model_args.encoder_model_name_or_path,
+ model_args.decoder_model_name_or_path,
+ encoder_config=config.encoder,
+ decoder_config=config.decoder,
+ encoder_seed=training_args.seed,
+ decoder_seed=training_args.seed,
+ encoder_dtype=getattr(jnp, model_args.dtype),
+ decoder_dtype=getattr(jnp, model_args.dtype),
+ )
+ # Set `encoder-decoder` (top-level) specific config
+ model.config.decoder_start_token_id = decoder_start_token_id
+ model.config.bos_token_id = bos_token_id
+ model.config.eos_token_id = eos_token_id
+ model.config.pad_token_id = pad_token_id
  else:
+ model = FlaxAutoModelForVision2Seq.from_config(
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
  )

  # Preprocessing the datasets.
  # We need to tokenize inputs and targets.
  if training_args.do_train:
 
  return

  # Get the column names for input/target.
+ dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
+ if data_args.image_column is None:
+ assert dataset_columns is not None
+ image_column = dataset_columns[0]
  else:
+ image_column = data_args.image_column
+ if image_column not in column_names:
  raise ValueError(
+ f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
  )
+ if data_args.caption_column is None:
+ assert dataset_columns is not None
+ caption_column = dataset_columns[1]
  else:
+ caption_column = data_args.caption_column
+ if caption_column not in column_names:
  raise ValueError(
+ f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
  )

  # Temporarily set max_target_length for training.
 
  # In Flax, for seq2seq models we need to pass `decoder_input_ids`
  # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
  # for that dynamically import the `shift_tokens_right` function from the model file
+ model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)

  # Setting padding="max_length" as we need fixed length inputs for jitted functions
  def preprocess_function(examples):
+
+ pixel_values = []
+ captions = []
+ for image_file, caption in zip(examples[image_column], examples[caption_column]):
+ with Image.open(image_file) as image:
+ try:
+ encoder_inputs = feature_extractor(images=image, return_tensors="np")
+ except:
+ continue
+ pixel_values.append(encoder_inputs.pixel_values)
+ captions.append(caption + ' ' + tokenizer.eos_token)
+
+ pixel_values = np.concatenate(pixel_values)
+ targets = captions
+
+ model_inputs = {}
+ model_inputs['pixel_values'] = pixel_values

  # Setup the tokenizer for targets
  with tokenizer.as_target_tokenizer():
 

  def generate_step(params, batch):
  model.params = params
+ output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
  return output_ids.sequences

  # Create parallel version of the train and eval step
 

  train_metric = unreplicate(train_metric)

+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+ epochs.write(desc)
+ epochs.desc = desc
+ logger.info(desc)
+ with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
+ fp.write(desc + '\n')

  # ======================== Evaluating ==============================
  eval_metrics = []
 
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
  epochs.write(desc)
  epochs.desc = desc
+ logger.info(desc)
+ with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
+ fp.write(desc + '\n')

  # Save metrics
  if has_tensorboard and jax.process_index() == 0:
  cur_step = epoch * (len(train_dataset) // train_batch_size)
  write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

+ # ======================== Prediction loop ==============================
+ if training_args.do_predict:
+ logger.info("*** Predict ***")
+
+ pred_metrics = []
+ pred_generations = []
+ pred_labels = []
+
+ pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
+ pred_steps = len(predict_dataset) // eval_batch_size
+ for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
+ # Model forward
+ batch = next(pred_loader)
+ labels = batch["labels"]
+
+ metrics = p_eval_step(state.params, batch)
+ pred_metrics.append(metrics)
+
+ # generation
+ if data_args.predict_with_generate:
+ generated_ids = p_generate_step(state.params, batch)
+ pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+ pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+ # normalize prediction metrics
+ pred_metrics = get_metrics(pred_metrics)
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+
+ # compute ROUGE metrics
+ rouge_desc = ""
  if data_args.predict_with_generate:
+ rouge_metrics = compute_metrics(pred_generations, pred_labels)
+ pred_metrics.update(rouge_metrics)
+ rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
+
+ # Print metrics
+ desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+ epochs.write(desc)
+ epochs.desc = desc
+ logger.info(desc)
+ with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
+ fp.write(desc + '\n')

  # save checkpoint after each epoch and push checkpoint to the hub
  if jax.process_index() == 0:
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+ model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
  tokenizer.save_pretrained(training_args.output_dir)
  if training_args.push_to_hub:
  repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
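Once trained, a per-epoch checkpoint written by `save_pretrained` above can be reloaded with the new class for inference; a sketch (the paths and image file are hypothetical):

    from PIL import Image
    from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

    model = FlaxVisionEncoderDecoderModel.from_pretrained("output_dir/ckpt_3")  # hypothetical epoch-3 checkpoint
    feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
    tokenizer = AutoTokenizer.from_pretrained("output_dir")                     # saved next to the checkpoints

    pixel_values = feature_extractor(images=Image.open("some_image.jpg"), return_tensors="np").pixel_values
    output_ids = model.generate(pixel_values, max_length=32, num_beams=4).sequences
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))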