boris committed
Commit a96f44d
1 Parent(s): 39caefb

feat: handle streaming

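This commit switches the seq2seq training scripts from local TSV files to a dataset streamed from the Hugging Face Hub and adds a batch generator over the resulting iterable dataset. As orientation before the file diffs, here is a condensed sketch of that loader pattern; it mirrors the data_loader_streaming function added in run_seq2seq_flax.py below, with the column names taken from the diff and the helper name chosen here only for illustration:

import jax.numpy as jnp
from flax.training.common_utils import shard


def stream_batches(dataset, batch_size):
    """Group a streamed (iterable) dataset into fixed-size, device-sharded batches."""
    keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
    batch = {k: [] for k in keys}
    for item in dataset:
        for k in keys:
            batch[k].append(item[k])
        if len(batch[keys[0]]) == batch_size:
            # stack the lists into arrays and shard across local devices for pmap
            # (batch_size must be divisible by jax.local_device_count())
            yield shard({k: jnp.array(v) for k, v in batch.items()})
            batch = {k: [] for k in keys}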
dev/seq2seq/do_big_run.sh CHANGED
@@ -1,7 +1,11 @@
 python run_seq2seq_flax.py \
     --max_source_length 128 \
-    --train_file /data/CC12M/encoded-small-train.tsv \
-    --validation_file /data/CC12M/encoded-small-valid.tsv \
+    --dataset_repo_or_path dalle-mini/encoded \
+    --train_file **/train/*/*.jsonl \
+    --validation_file **/valid/*/*.jsonl \
+    --streaming \
+    --len_train 1000000 \
+    --len_eval 100 \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
dev/seq2seq/do_small_run.sh CHANGED
@@ -1,7 +1,10 @@
 python run_seq2seq_flax.py \
-    --max_source_length 128 \
-    --train_file /data/CC12M/encoded-small-train.tsv \
-    --validation_file /data/CC12M/encoded-small-valid.tsv \
+    --dataset_repo_or_path dalle-mini/encoded \
+    --train_file **/train/*/*.jsonl \
+    --validation_file **/valid/*/*.jsonl \
+    --streaming \
+    --len_train 1000000 \
+    --len_eval 1000 \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
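The new flags in both scripts replace the old TSV paths: --dataset_repo_or_path together with the --train_file / --validation_file globs points at encoded JSONL shards on the Hub, --streaming switches the datasets library to streaming mode, and --len_train / --len_eval supply the split sizes that can no longer be read with len(). Inside run_seq2seq_flax.py (see the diff below) this roughly resolves to the call sketched here, shown with the do_big_run.sh values rather than the parsed arguments:

from datasets import load_dataset

data_files = {
    "train": "**/train/*/*.jsonl",       # --train_file
    "validation": "**/valid/*/*.jsonl",  # --validation_file
}
dataset = load_dataset(
    "dalle-mini/encoded",  # --dataset_repo_or_path
    data_files=data_files,
    streaming=True,        # --streaming: splits become iterable datasets
)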
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -20,9 +20,8 @@ Script adapted from run_summarization_flax.py
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import os
-import logging as pylogging # To avoid collision with transformers.utils.logging
 import sys
-import time
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -30,7 +29,6 @@ from typing import Callable, Optional
 import json
 
 import datasets
-import nltk # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import Dataset, load_dataset, load_metric
 from tqdm import tqdm
@@ -47,9 +45,7 @@ from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
-    CONFIG_MAPPING,
     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForSeq2SeqLM,
     FlaxBartForConditionalGeneration,
@@ -61,17 +57,9 @@ from transformers.file_utils import is_offline_mode
 
 import wandb
 
-logger = pylogging.getLogger(__name__)
 
-try:
-    nltk.data.find("tokenizers/punkt")
-except (LookupError, OSError):
-    if is_offline_mode():
-        raise LookupError(
-            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
-        )
-    with FileLock(".lock") as lock:
-        nltk.download("punkt", quiet=True)
 
 
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
@@ -83,7 +71,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
 OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
 BOS_TOKEN_ID = 16384
-BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large
 
 
 @dataclass
@@ -101,20 +89,34 @@ class ModelArguments:
101
  )
102
  model_type: Optional[str] = field(
103
  default=None,
104
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
 
 
 
105
  )
106
  config_name: Optional[str] = field(
107
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
 
 
108
  )
109
  tokenizer_name: Optional[str] = field(
110
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
 
 
 
111
  )
112
  cache_dir: Optional[str] = field(
113
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
 
 
 
114
  )
115
  use_fast_tokenizer: bool = field(
116
  default=True,
117
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
 
118
  )
119
  dtype: Optional[str] = field(
120
  default="float32",
@@ -137,27 +139,51 @@ class DataTrainingArguments:
137
  """
138
 
139
  dataset_name: Optional[str] = field(
140
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
 
141
  )
142
  dataset_config_name: Optional[str] = field(
143
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
 
 
 
144
  )
145
  text_column: Optional[str] = field(
146
- default='caption',
147
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
 
 
148
  )
149
  encoding_column: Optional[str] = field(
150
- default='encoding',
151
- metadata={"help": "The name of the column in the datasets containing the image encodings."},
 
 
 
 
 
 
 
 
 
152
  )
153
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
154
  validation_file: Optional[str] = field(
155
  default=None,
156
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
 
 
 
 
 
 
157
  )
158
- test_file: Optional[str] = field(
159
  default=None,
160
- metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
 
 
 
 
161
  )
162
  max_source_length: Optional[int] = field(
163
  default=128,
@@ -167,7 +193,8 @@ class DataTrainingArguments:
167
  },
168
  )
169
  no_decay: bool = field(
170
- default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
 
171
  )
172
  max_target_length: Optional[int] = field(
173
  default=OUTPUT_LENGTH,
@@ -199,60 +226,65 @@ class DataTrainingArguments:
199
  "value if set."
200
  },
201
  )
202
- max_predict_samples: Optional[int] = field(
203
- default=None,
204
- metadata={
205
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
206
- "value if set."
207
- },
208
  )
209
  preprocessing_num_workers: Optional[int] = field(
210
- default=80, # ensure we have the same datasets cached data and avoid using too much space
211
  metadata={"help": "The number of processes to use for the preprocessing."},
212
  )
213
  source_prefix: Optional[str] = field(
214
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
215
- )
216
- predict_with_generate: bool = field(
217
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
218
- )
219
- num_beams: Optional[int] = field(
220
  default=None,
221
  metadata={
222
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
223
- "which is used during evaluation."
224
  },
225
  )
226
  overwrite_cache: bool = field(
227
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
228
  )
229
  log_interval: Optional[int] = field(
230
  default=40,
231
- metadata={
232
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
233
- "value if set."
234
- },
235
  )
236
  log_model: bool = field(
237
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
238
  )
239
  save_model_steps: Optional[int] = field(
240
- default=3000, # about once every hour in our experiments
241
  metadata={
242
  "help": "For logging the model more frequently. Used only when `log_model` is set."
243
  },
244
  )
245
 
246
  def __post_init__(self):
247
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
248
- raise ValueError("Need either a dataset name or a training/validation file.")
 
 
 
 
 
 
249
  else:
250
  if self.train_file is not None:
251
  extension = self.train_file.split(".")[-1]
252
- assert extension in ["tsv", "csv", "json"], "`train_file` should be a tsv, csv or json file."
 
 
 
 
 
253
  if self.validation_file is not None:
254
  extension = self.validation_file.split(".")[-1]
255
- assert extension in ["tsv", "csv", "json"], "`validation_file` should be a tsv, csv or json file."
 
 
 
 
 
256
  if self.val_max_target_length is None:
257
  self.val_max_target_length = self.max_target_length
258
 
@@ -263,14 +295,20 @@ class TrainState(train_state.TrainState):
263
  optimizer_step: int
264
 
265
  def replicate(self):
266
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
 
267
 
268
 
269
  class CustomFlaxBartModule(FlaxBartModule):
270
  def setup(self):
271
  # check config is valid, otherwise set default values
272
- self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
273
- self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
 
 
 
 
274
 
275
  # we keep shared to easily load pre-trained weights
276
  self.shared = nn.Embed(
@@ -286,18 +324,29 @@ class CustomFlaxBartModule(FlaxBartModule):
286
  embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
287
  dtype=self.dtype,
288
  )
289
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
 
 
290
 
291
  # the decoder has a different config
292
  decoder_config = BartConfig(self.config.to_dict())
293
- decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
 
 
294
  decoder_config.vocab_size = self.config.vocab_size_output
295
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 
 
296
 
297
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
 
 
298
  def setup(self):
299
  # check config is valid, otherwise set default values
300
- self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
 
 
301
 
302
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
303
  self.lm_head = nn.Dense(
@@ -306,13 +355,18 @@ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerat
306
  dtype=self.dtype,
307
  kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
308
  )
309
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 
 
310
 
311
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
312
  module_class = CustomFlaxBartForConditionalGenerationModule
313
-
314
 
315
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
 
 
 
316
  """
317
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
318
  Shuffle batches if `shuffle` is `True`.
@@ -330,33 +384,58 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
330
  for idx in batch_idx:
331
  batch = dataset[idx]
332
  batch = {k: jnp.array(v) for k, v in batch.items()}
333
-
334
  batch = shard(batch)
335
-
336
  yield batch
337
 
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  def create_learning_rate_fn(
340
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 
 
 
 
 
341
  ) -> Callable[[int], jnp.array]:
342
  """Returns a linear warmup, linear_decay learning rate function."""
343
  steps_per_epoch = train_ds_size // train_batch_size
344
  num_train_steps = steps_per_epoch * num_train_epochs
345
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
 
 
346
  if no_decay:
347
  return warmup_fn
348
  decay_fn = optax.linear_schedule(
349
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
 
 
 
 
 
350
  )
351
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
352
  return schedule_fn
353
 
354
 
355
  def wandb_log(metrics, step=None, prefix=None):
356
  if jax.process_index() == 0:
357
- log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
 
 
 
358
  if step is not None:
359
- log_metrics['train/step'] = step
360
  wandb.log(log_metrics)
361
 
362
 
@@ -365,11 +444,15 @@ def main():
365
  # or by passing the --help flag to this script.
366
  # We now keep distinct sets of args, for a cleaner separation of concerns.
367
 
368
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 
 
369
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
370
  # If we pass only one argument to the script and it's the path to a json file,
371
  # let's parse it to get our arguments.
372
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
 
373
  else:
374
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
375
 
@@ -383,18 +466,18 @@ def main():
383
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
384
  "Use --overwrite_output_dir to overcome."
385
  )
386
-
387
  # Set up wandb run
388
  wandb.init(
389
- entity='wandb',
390
- project='hf-flax-dalle-mini',
391
- job_type='Seq2SeqVQGAN',
392
- config=parser.parse_args()
393
  )
394
 
395
  # set default x-axis as 'train/step'
396
- wandb.define_metric('train/step')
397
- wandb.define_metric('*', step_metric='train/step')
398
 
399
  # Make one log on every process with the configuration for debugging.
400
  pylogging.basicConfig(
@@ -418,16 +501,13 @@ def main():
418
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
419
  # (the dataset will be downloaded automatically from the datasets Hub).
420
  #
421
- data_files = {}
422
- if data_args.train_file is not None:
423
- data_files["train"] = data_args.train_file
424
- if data_args.validation_file is not None:
425
- data_files["validation"] = data_args.validation_file
426
- if data_args.test_file is not None:
427
- data_files["test"] = data_args.test_file
428
- dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
429
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
430
- # https://huggingface.co/docs/datasets/loading_datasets.html.
431
 
432
  # Set up items to load or create
433
  tokenizer = None
@@ -435,17 +515,17 @@ def main():
435
 
436
  def restore_state(state, artifact_dir):
437
  # restore optimizer state
438
- with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
439
  opt_state = from_bytes(state.opt_state, f.read())
440
-
441
  # restore steps
442
- with (Path(artifact_dir) / 'training_state.json').open('r') as f:
443
  training_state = json.load(f)
444
- step = training_state['step']
445
  optimizer_step = step // training_args.gradient_accumulation_steps
446
 
447
  return step, optimizer_step, opt_state
448
-
449
  if model_args.from_checkpoint is not None:
450
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
451
  artifact_dir = artifact.download()
@@ -461,40 +541,54 @@ def main():
461
  config = model.config
462
 
463
  # load tokenizer if present
464
- if (Path(artifact_dir) / 'tokenizer_config.json').exists():
465
  tokenizer = AutoTokenizer.from_pretrained(
466
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
467
- )
 
 
468
 
469
  else:
470
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
471
- model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
 
472
  )
473
  # Set up our new model config
474
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
475
  config.tie_word_embeddings = False
476
  config.decoder_start_token_id = BOS_TOKEN_ID # for first token
477
- config.bos_token_id = BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
478
- config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
 
 
 
 
479
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
480
  config.forced_bos_token_id = None # we don't need this token
481
  config.forced_eos_token_id = None # we don't need this token
482
- config.force_bos_token_to_be_generated = False # otherwise it sets bos_token_id at loading
 
 
483
  config.min_length = data_args.max_target_length
484
  config.max_length = data_args.max_target_length
485
 
486
  # Create a custom model and initialize it randomly
487
- model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
 
488
 
489
  # Use pre-trained weights for encoder
490
- model.params['model']['encoder'] = base_model.params['model']['encoder']
491
- model.params['model']['shared'] = base_model.params['model']['shared']
492
  del base_model
493
 
494
  # Load tokenizer if it has not been set
495
  if tokenizer is None:
496
  tokenizer = AutoTokenizer.from_pretrained(
497
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
498
  )
499
 
500
  print(f"TPUs: {jax.device_count()}")
@@ -504,23 +598,11 @@ def main():
504
 
505
  # Preprocessing the datasets.
506
  # We need to tokenize inputs and targets.
507
- if training_args.do_train:
508
- column_names = dataset["train"].column_names
509
- elif training_args.do_eval:
510
- column_names = dataset["validation"].column_names
511
- elif training_args.do_predict:
512
- column_names = dataset["test"].column_names
513
- else:
514
- logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
515
- return
516
 
517
  # Get the column names for input/target.
518
  text_column = data_args.text_column
519
  encoding_column = data_args.encoding_column
520
 
521
- # Temporarily set max_target_length for training.
522
- max_target_length = data_args.max_target_length
523
-
524
  def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
525
  """
526
  Shift input ids one token to the right.
@@ -530,18 +612,28 @@ def main():
530
  shifted_input_ids[:, 0] = decoder_start_token_id
531
  return shifted_input_ids
532
 
 
 
 
 
 
 
533
  def preprocess_function(examples):
534
  inputs = examples[text_column]
535
- inputs = [prefix + inp for inp in inputs]
536
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
537
  model_inputs = tokenizer(
538
- inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
 
 
 
 
539
  )
540
 
541
  # set up targets
542
  # Note: labels correspond to our target indices
543
  # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
544
- labels = [eval(indices) for indices in examples['encoding']]
545
  labels = np.asarray(labels)
546
 
547
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
@@ -558,46 +650,75 @@ def main():
558
  raise ValueError("--do_train requires a train dataset")
559
  train_dataset = dataset["train"]
560
  if data_args.max_train_samples is not None:
561
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
562
- train_dataset = train_dataset.map(
563
- preprocess_function,
564
- batched=True,
565
- num_proc=data_args.preprocessing_num_workers,
566
- remove_columns=column_names,
567
- load_from_cache_file=not data_args.overwrite_cache,
568
- desc="Running tokenizer on train dataset",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  )
570
 
571
  if training_args.do_eval:
572
- max_target_length = data_args.val_max_target_length
573
  if "validation" not in dataset:
574
  raise ValueError("--do_eval requires a validation dataset")
575
  eval_dataset = dataset["validation"]
576
  if data_args.max_eval_samples is not None:
577
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
578
- eval_dataset = eval_dataset.map(
579
- preprocess_function,
580
- batched=True,
581
- num_proc=data_args.preprocessing_num_workers,
582
- remove_columns=column_names,
583
- load_from_cache_file=not data_args.overwrite_cache,
584
- desc="Running tokenizer on validation dataset",
585
- )
586
-
587
- if training_args.do_predict:
588
- max_target_length = data_args.val_max_target_length
589
- if "test" not in dataset:
590
- raise ValueError("--do_predict requires a test dataset")
591
- predict_dataset = dataset["test"]
592
- if data_args.max_predict_samples is not None:
593
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
594
- predict_dataset = predict_dataset.map(
595
- preprocess_function,
596
- batched=True,
597
- num_proc=data_args.preprocessing_num_workers,
598
- remove_columns=column_names,
599
- load_from_cache_file=not data_args.overwrite_cache,
600
- desc="Running tokenizer on prediction dataset",
 
 
 
 
 
 
601
  )
602
 
603
  # Initialize our training
@@ -606,21 +727,40 @@ def main():
606
 
607
  # Store some constant
608
  num_epochs = int(training_args.num_train_epochs)
609
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
 
 
610
  total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
611
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
612
- steps_per_epoch = len(train_dataset) // train_batch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
  total_steps = steps_per_epoch * num_epochs
614
- total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
615
 
616
  # Create learning rate schedule
617
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
618
- len(train_dataset),
619
  total_batch_size,
620
  training_args.num_train_epochs,
621
  training_args.warmup_steps,
622
  training_args.learning_rate,
623
- data_args.no_decay
624
  )
625
 
626
  # We use Optax's "masking" functionality to not apply weight decay
@@ -633,9 +773,17 @@ def main():
633
  def decay_mask_fn(params):
634
  flat_params = traverse_util.flatten_dict(params)
635
  layer_norm_params = [
636
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
 
 
 
 
 
637
  ]
638
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
 
 
 
639
  return traverse_util.unflatten_dict(flat_mask)
640
 
641
  # create adam optimizer
@@ -667,7 +815,9 @@ def main():
667
  if model_args.from_checkpoint is not None:
668
  # restore optimizer state, step and optimizer_step
669
  step, optimizer_step, opt_state = restore_state(state, artifact_dir)
670
- state = state.replace(step=step, optimizer_step=optimizer_step, opt_state=opt_state)
 
 
671
 
672
  # label smoothed cross entropy
673
  def loss_fn(logits, labels):
@@ -681,7 +831,9 @@ def main():
681
 
682
  def compute_loss(params):
683
  labels = batch.pop("labels")
684
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
 
685
  loss = loss_fn(logits, labels)
686
  return loss
687
 
@@ -690,10 +842,14 @@ def main():
690
  grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
691
 
692
  def update_fn():
693
- grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
 
 
694
  grads = jax.lax.pmean(grads, "batch")
695
  new_state = state.apply_gradients(
696
- grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
 
 
697
  )
698
  return new_state
699
 
@@ -704,7 +860,10 @@ def main():
704
  None,
705
  )
706
 
707
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
 
 
 
708
  metrics = jax.lax.pmean(metrics, axis_name="batch")
709
 
710
  return new_state.replace(dropout_rng=new_dropout_rng), metrics
@@ -720,39 +879,25 @@ def main():
720
  metrics = jax.lax.pmean(metrics, axis_name="batch")
721
  return metrics
722
 
723
- # Define generation function
724
- max_length = (
725
- data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
726
- )
727
- num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
728
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
729
-
730
- def generate_step(params, batch):
731
- model.params = params
732
- output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
733
- return output_ids.sequences
734
-
735
  # Create parallel version of the train and eval step
736
- p_train_step = jax.pmap(
737
- train_step, "batch", donate_argnums=(0,)
738
- )
739
  p_eval_step = jax.pmap(eval_step, "batch")
740
- p_generate_step = jax.pmap(generate_step, "batch")
741
 
742
  # Replicate the train state on each device
743
  state = state.replicate()
744
 
745
  logger.info("***** Running training *****")
746
- logger.info(f" Num examples = {len(train_dataset)}")
747
  logger.info(f" Num Epochs = {num_epochs}")
748
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
 
 
749
  logger.info(
750
  f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
751
  )
752
  logger.info(f" Total global steps = {total_steps}")
753
  logger.info(f" Total optimization steps = {total_optimization_steps}")
754
 
755
- train_time = 0
756
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
757
  global_step = 0
758
 
@@ -760,31 +905,31 @@ def main():
760
  # ======================== Evaluating ==============================
761
  eval_metrics = []
762
  if training_args.do_eval:
763
- eval_preds = []
764
- eval_labels = []
765
-
766
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
767
- eval_steps = len(eval_dataset) // eval_batch_size
768
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
 
 
 
 
 
 
769
  # Model forward
770
- batch = next(eval_loader)
771
- labels = batch["labels"]
772
-
773
  metrics = p_eval_step(state.params, batch)
774
  eval_metrics.append(metrics)
775
 
776
- # generation
777
- if data_args.predict_with_generate:
778
- generated_ids = p_generate_step(state.params, batch)
779
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
780
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
781
-
782
  # normalize eval metrics
 
783
  eval_metrics = get_metrics(eval_metrics)
 
784
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
785
 
786
  # log metrics
787
- wandb_log(eval_metrics, step=global_step, prefix='eval')
788
 
789
  # Print metrics and update progress bar
790
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -808,28 +953,42 @@ def main():
808
 
809
  # save state
810
  state = unreplicate(state)
811
- with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
812
  f.write(to_bytes(state.opt_state))
813
- with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
814
- json.dump({'step': state.step.item()}, f)
 
 
815
 
816
  # save to W&B
817
  if data_args.log_model:
818
- metadata = {'step': step, 'epoch': epoch}
819
  if eval_metrics is not None:
820
- metadata['eval/loss'] = eval_metrics['loss']
821
  artifact = wandb.Artifact(
822
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
823
  )
824
- artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
825
- artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
826
- artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer.json'))
827
- artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
828
- artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
829
- artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
830
- artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
831
- artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
832
- artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
 
 
 
 
 
 
 
 
 
 
 
 
833
  wandb.run.log_artifact(artifact)
834
 
835
  # save some space
@@ -843,39 +1002,47 @@ def main():
843
  params=params,
844
  push_to_hub=training_args.push_to_hub,
845
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
846
- temp_dir=True # avoid issues with being in a repository
847
  )
848
-
849
  for epoch in epochs:
850
  # ======================== Training ================================
851
- train_start = time.time()
852
 
853
  # Create sampling rng
854
  rng, input_rng = jax.random.split(rng)
855
 
856
  # Generate an epoch by shuffling sampling indices from the train dataset
857
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
858
- steps_per_epoch = len(train_dataset) // train_batch_size
 
 
 
 
 
859
  # train
860
- for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
861
- global_step +=1
862
- batch = next(train_loader)
 
 
 
 
 
863
  state, train_metric = p_train_step(state, batch)
864
 
865
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
866
  # log metrics
867
- wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
868
 
869
  if training_args.eval_steps and global_step % training_args.eval_steps == 0:
870
  run_evaluation()
871
-
872
  if global_step % data_args.save_model_steps == 0:
873
  run_save_model(state, global_step, epoch)
874
-
875
  # log final train metrics
876
- wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
877
 
878
- train_time += time.time() - train_start
879
  train_metric = unreplicate(train_metric)
880
  epochs.write(
881
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
@@ -888,38 +1055,5 @@ def main():
888
  run_save_model(state, global_step, epoch, eval_metrics)
889
 
890
 
891
- # ======================== Prediction loop ==============================
892
- if training_args.do_predict:
893
- logger.info("*** Predict ***")
894
-
895
- pred_metrics = []
896
- pred_generations = []
897
- pred_labels = []
898
-
899
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
900
- pred_steps = len(predict_dataset) // eval_batch_size
901
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
902
- # Model forward
903
- batch = next(pred_loader)
904
- labels = batch["labels"]
905
-
906
- metrics = p_eval_step(state.params, batch)
907
- pred_metrics.append(metrics)
908
-
909
- # generation
910
- if data_args.predict_with_generate:
911
- generated_ids = p_generate_step(state.params, batch)
912
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
913
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
914
-
915
- # normalize prediction metrics
916
- pred_metrics = get_metrics(pred_metrics)
917
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
918
-
919
- # Print metrics
920
- desc = f"Predict Loss: {pred_metrics['loss']})"
921
- logger.info(desc)
922
-
923
-
924
  if __name__ == "__main__":
925
  main()
 
20
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
 
22
  import os
23
+ import logging as pylogging # To avoid collision with transformers.utils.logging
24
  import sys
 
25
  from dataclasses import dataclass, field
26
  from functools import partial
27
  from pathlib import Path
 
29
  import json
30
 
31
  import datasets
 
32
  import numpy as np
33
  from datasets import Dataset, load_dataset, load_metric
34
  from tqdm import tqdm
 
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
  from transformers import (
 
48
  FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
 
49
  AutoTokenizer,
50
  FlaxAutoModelForSeq2SeqLM,
51
  FlaxBartForConditionalGeneration,
 
57
 
58
  import wandb
59
 
60
+ from dalle_mini.text import TextNormalizer
61
 
62
+ logger = pylogging.getLogger(__name__)
 
 
 
 
 
 
 
 
63
 
64
 
65
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
 
71
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
72
  OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
73
  BOS_TOKEN_ID = 16384
74
+ BASE_MODEL = "facebook/bart-large-cnn" # we currently have issues with bart-large
75
 
76
 
77
  @dataclass
 
89
  )
90
  model_type: Optional[str] = field(
91
  default=None,
92
+ metadata={
93
+ "help": "If training from scratch, pass a model type from the list: "
94
+ + ", ".join(MODEL_TYPES)
95
+ },
96
  )
97
  config_name: Optional[str] = field(
98
+ default=None,
99
+ metadata={
100
+ "help": "Pretrained config name or path if not the same as model_name"
101
+ },
102
  )
103
  tokenizer_name: Optional[str] = field(
104
+ default=None,
105
+ metadata={
106
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
107
+ },
108
  )
109
  cache_dir: Optional[str] = field(
110
+ default=None,
111
+ metadata={
112
+ "help": "Where do you want to store the pretrained models downloaded from s3"
113
+ },
114
  )
115
  use_fast_tokenizer: bool = field(
116
  default=True,
117
+ metadata={
118
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
119
+ },
120
  )
121
  dtype: Optional[str] = field(
122
  default="float32",
 
139
  """
140
 
141
  dataset_name: Optional[str] = field(
142
+ default=None,
143
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
144
  )
145
  dataset_config_name: Optional[str] = field(
146
+ default=None,
147
+ metadata={
148
+ "help": "The configuration name of the dataset to use (via the datasets library)."
149
+ },
150
  )
151
  text_column: Optional[str] = field(
152
+ default="caption",
153
+ metadata={
154
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."
155
+ },
156
  )
157
  encoding_column: Optional[str] = field(
158
+ default="encoding",
159
+ metadata={
160
+ "help": "The name of the column in the datasets containing the image encodings."
161
+ },
162
+ )
163
+ dataset_repo_or_path: Optional[str] = field(
164
+ default=None,
165
+ metadata={"help": "The dataset repository containing encoded files."},
166
+ )
167
+ train_file: Optional[str] = field(
168
+ default=None, metadata={"help": "The input training data file (a text file)."}
169
  )
 
170
  validation_file: Optional[str] = field(
171
  default=None,
172
+ metadata={
173
+ "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
174
+ },
175
+ )
176
+    streaming: bool = field(
+        default=False,
+        metadata={"help": "Whether to stream the dataset."},
     )
+    len_train: Optional[int] = field(
         default=None,
+        metadata={"help": "Length of training dataset, required for streaming"},
+    )
+    len_eval: Optional[int] = field(
+        default=None,
+        metadata={"help": "Length of validation dataset, required for streaming"},
     )
188
  max_source_length: Optional[int] = field(
189
  default=128,
 
193
  },
194
  )
195
  no_decay: bool = field(
196
+ default=False,
197
+ metadata={"help": "Whether to use decay in the learning rate scheduler."},
198
  )
199
  max_target_length: Optional[int] = field(
200
  default=OUTPUT_LENGTH,
 
226
  "value if set."
227
  },
228
  )
229
+ normalize_text: bool = field(
230
+ default=False,
231
+ metadata={"help": "Normalize/Simplify text"},
 
 
 
232
  )
233
  preprocessing_num_workers: Optional[int] = field(
234
+ default=80, # ensure we have the same datasets cached data and avoid using too much space
235
  metadata={"help": "The number of processes to use for the preprocessing."},
236
  )
237
  source_prefix: Optional[str] = field(
 
 
 
 
 
 
238
  default=None,
239
  metadata={
240
+ "help": "A prefix to add before every source text (useful for T5 models)."
 
241
  },
242
  )
243
  overwrite_cache: bool = field(
244
+ default=False,
245
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
246
  )
247
  log_interval: Optional[int] = field(
248
  default=40,
249
+ metadata={"help": "Log frequency for metrics"},
 
 
 
250
  )
251
  log_model: bool = field(
252
+ default=False,
253
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
254
  )
255
  save_model_steps: Optional[int] = field(
256
+ default=3000, # about once every hour in our experiments
257
  metadata={
258
  "help": "For logging the model more frequently. Used only when `log_model` is set."
259
  },
260
  )
261
 
262
  def __post_init__(self):
263
+ if (
264
+ self.dataset_name is None
265
+ and self.train_file is None
266
+ and self.validation_file is None
267
+ ):
268
+ raise ValueError(
269
+ "Need either a dataset name or a training/validation file."
270
+ )
271
  else:
272
  if self.train_file is not None:
273
  extension = self.train_file.split(".")[-1]
274
+ assert extension in [
275
+ "tsv",
276
+ "csv",
277
+ "json",
278
+ "jsonl",
279
+ ], "`train_file` should be a tsv, csv or json file."
280
  if self.validation_file is not None:
281
  extension = self.validation_file.split(".")[-1]
282
+ assert extension in [
283
+ "tsv",
284
+ "csv",
285
+ "json",
286
+ "jsonl",
287
+ ], "`validation_file` should be a tsv, csv or json file."
288
  if self.val_max_target_length is None:
289
  self.val_max_target_length = self.max_target_length
290
 
 
295
  optimizer_step: int
296
 
297
  def replicate(self):
298
+ return jax_utils.replicate(self).replace(
299
+ dropout_rng=shard_prng_key(self.dropout_rng)
300
+ )
301
 
302
 
303
  class CustomFlaxBartModule(FlaxBartModule):
304
  def setup(self):
305
  # check config is valid, otherwise set default values
306
+ self.config.vocab_size_output = getattr(
307
+ self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
308
+ )
309
+ self.config.max_position_embeddings_decoder = getattr(
310
+ self.config, "max_position_embeddings_decoder", OUTPUT_LENGTH
311
+ )
312
 
313
  # we keep shared to easily load pre-trained weights
314
  self.shared = nn.Embed(
 
324
  embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
325
  dtype=self.dtype,
326
  )
327
+ self.encoder = FlaxBartEncoder(
328
+ self.config, dtype=self.dtype, embed_tokens=self.shared
329
+ )
330
 
331
  # the decoder has a different config
332
  decoder_config = BartConfig(self.config.to_dict())
333
+ decoder_config.max_position_embeddings = (
334
+ self.config.max_position_embeddings_decoder
335
+ )
336
  decoder_config.vocab_size = self.config.vocab_size_output
337
+ self.decoder = FlaxBartDecoder(
338
+ decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
339
+ )
340
+
341
 
342
+ class CustomFlaxBartForConditionalGenerationModule(
343
+ FlaxBartForConditionalGenerationModule
344
+ ):
345
  def setup(self):
346
  # check config is valid, otherwise set default values
347
+ self.config.vocab_size_output = getattr(
348
+ self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
349
+ )
350
 
351
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
352
  self.lm_head = nn.Dense(
 
355
  dtype=self.dtype,
356
  kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
357
  )
358
+ self.final_logits_bias = self.param(
359
+ "final_logits_bias", self.bias_init, (1, self.config.vocab_size_output)
360
+ )
361
+
362
 
363
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
364
  module_class = CustomFlaxBartForConditionalGenerationModule
 
365
 
366
+
367
+ def data_loader(
368
+ rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
369
+ ):
370
  """
371
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
372
  Shuffle batches if `shuffle` is `True`.
 
384
  for idx in batch_idx:
385
  batch = dataset[idx]
386
  batch = {k: jnp.array(v) for k, v in batch.items()}
 
387
  batch = shard(batch)
 
388
  yield batch
389
 
390
 
391
+ def data_loader_streaming(dataset: Dataset, batch_size: int):
392
+ keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
393
+ batch = {k: [] for k in keys}
394
+ for item in dataset:
395
+ for k, v in item.items():
396
+ batch[k].append(v)
397
+ if len(batch[keys[0]]) == batch_size:
398
+ batch = {k: jnp.array(v) for k, v in batch.items()}
399
+ batch = shard(batch)
400
+ yield batch
401
+ batch = {k: [] for k in keys}
402
+
403
+
404
  def create_learning_rate_fn(
405
+ train_ds_size: int,
406
+ train_batch_size: int,
407
+ num_train_epochs: int,
408
+ num_warmup_steps: int,
409
+ learning_rate: float,
410
+ no_decay: bool,
411
  ) -> Callable[[int], jnp.array]:
412
  """Returns a linear warmup, linear_decay learning rate function."""
413
  steps_per_epoch = train_ds_size // train_batch_size
414
  num_train_steps = steps_per_epoch * num_train_epochs
415
+ warmup_fn = optax.linear_schedule(
416
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
417
+ )
418
  if no_decay:
419
  return warmup_fn
420
  decay_fn = optax.linear_schedule(
421
+ init_value=learning_rate,
422
+ end_value=0,
423
+ transition_steps=num_train_steps - num_warmup_steps,
424
+ )
425
+ schedule_fn = optax.join_schedules(
426
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
427
  )
 
428
  return schedule_fn
429
 
430
 
431
  def wandb_log(metrics, step=None, prefix=None):
432
  if jax.process_index() == 0:
433
+ log_metrics = {
434
+ f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v)
435
+ for k, v in metrics.items()
436
+ }
437
  if step is not None:
438
+ log_metrics["train/step"] = step
439
  wandb.log(log_metrics)
440
 
441
 
 
444
  # or by passing the --help flag to this script.
445
  # We now keep distinct sets of args, for a cleaner separation of concerns.
446
 
447
+ parser = HfArgumentParser(
448
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
449
+ )
450
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
451
  # If we pass only one argument to the script and it's the path to a json file,
452
  # let's parse it to get our arguments.
453
+ model_args, data_args, training_args = parser.parse_json_file(
454
+ json_file=os.path.abspath(sys.argv[1])
455
+ )
456
  else:
457
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
458
 
 
466
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
467
  "Use --overwrite_output_dir to overcome."
468
  )
469
+
470
  # Set up wandb run
471
  wandb.init(
472
+ entity="dalle-mini",
473
+ project="dalle-mini",
474
+ job_type="Seq2Seq",
475
+ config=parser.parse_args(),
476
  )
477
 
478
  # set default x-axis as 'train/step'
479
+ wandb.define_metric("train/step")
480
+ wandb.define_metric("*", step_metric="train/step")
481
 
482
  # Make one log on every process with the configuration for debugging.
483
  pylogging.basicConfig(
 
501
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
502
  # (the dataset will be downloaded automatically from the datasets Hub).
503
  #
504
+    data_files = {
+        "train": data_args.train_file,
+        "validation": data_args.validation_file,
+    }
+    dataset = load_dataset(
+        data_args.dataset_repo_or_path, data_files=data_files, streaming=True
+    )
 
 
 
511
 
512
  # Set up items to load or create
513
  tokenizer = None
 
515
 
516
  def restore_state(state, artifact_dir):
517
  # restore optimizer state
518
+ with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
519
  opt_state = from_bytes(state.opt_state, f.read())
520
+
521
  # restore steps
522
+ with (Path(artifact_dir) / "training_state.json").open("r") as f:
523
  training_state = json.load(f)
524
+ step = training_state["step"]
525
  optimizer_step = step // training_args.gradient_accumulation_steps
526
 
527
  return step, optimizer_step, opt_state
528
+
529
  if model_args.from_checkpoint is not None:
530
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
531
  artifact_dir = artifact.download()
 
541
  config = model.config
542
 
543
  # load tokenizer if present
544
+ if (Path(artifact_dir) / "tokenizer_config.json").exists():
545
  tokenizer = AutoTokenizer.from_pretrained(
546
+ model_args.model_name_or_path,
547
+ cache_dir=model_args.cache_dir,
548
+ use_fast=model_args.use_fast_tokenizer,
549
+ )
550
 
551
  else:
552
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
553
+ model_args.model_name_or_path,
554
+ seed=training_args.seed,
555
+ dtype=getattr(jnp, model_args.dtype),
556
  )
557
  # Set up our new model config
558
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
559
  config.tie_word_embeddings = False
560
  config.decoder_start_token_id = BOS_TOKEN_ID # for first token
561
+ config.bos_token_id = (
562
+ BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
563
+ )
564
+ config.pos_token_id = (
565
+ BOS_TOKEN_ID # should not be needed (as we generate until max_length)
566
+ )
567
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
568
  config.forced_bos_token_id = None # we don't need this token
569
  config.forced_eos_token_id = None # we don't need this token
570
+ config.force_bos_token_to_be_generated = (
571
+ False # otherwise it sets bos_token_id at loading
572
+ )
573
  config.min_length = data_args.max_target_length
574
  config.max_length = data_args.max_target_length
575
 
576
  # Create a custom model and initialize it randomly
577
+ model = CustomFlaxBartForConditionalGeneration(
578
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
579
+ )
580
 
581
  # Use pre-trained weights for encoder
582
+ model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
583
+ model.params["model"]["shared"] = base_model.params["model"]["shared"]
584
  del base_model
585
 
586
  # Load tokenizer if it has not been set
587
  if tokenizer is None:
588
  tokenizer = AutoTokenizer.from_pretrained(
589
+ model_args.model_name_or_path,
590
+ cache_dir=model_args.cache_dir,
591
+ use_fast=model_args.use_fast_tokenizer,
592
  )
593
 
594
  print(f"TPUs: {jax.device_count()}")
 
598
 
599
  # Preprocessing the datasets.
600
  # We need to tokenize inputs and targets.
 
 
 
 
 
 
 
 
 
601
 
602
  # Get the column names for input/target.
603
  text_column = data_args.text_column
604
  encoding_column = data_args.encoding_column
605
 
 
 
 
606
  def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
607
  """
608
  Shift input ids one token to the right.
 
612
  shifted_input_ids[:, 0] = decoder_start_token_id
613
  return shifted_input_ids
614
 
615
+ text_normalizer = TextNormalizer() if data_args.normalize_text else None
616
+
617
+ def normalize_text(example):
618
+ example[text_column] = text_normalizer(example[text_column])
619
+ return example
620
+
621
  def preprocess_function(examples):
622
  inputs = examples[text_column]
623
+ inputs = [prefix + inp for inp in inputs] if prefix else inputs
624
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
625
  model_inputs = tokenizer(
626
+ inputs,
627
+ max_length=data_args.max_source_length,
628
+ padding="max_length",
629
+ truncation=True,
630
+ return_tensors="np",
631
  )
632
 
633
  # set up targets
634
  # Note: labels correspond to our target indices
635
  # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
636
+ labels = examples[encoding_column]
637
  labels = np.asarray(labels)
638
 
639
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
 
650
  raise ValueError("--do_train requires a train dataset")
651
  train_dataset = dataset["train"]
652
  if data_args.max_train_samples is not None:
653
+ train_dataset = (
654
+ train_dataset.take(data_args.max_train_samples)
655
+ if data_args.streaming
656
+ else train_dataset.select(range(data_args.max_train_samples))
657
+ )
658
+ if data_args.streaming:
659
+ train_dataset = train_dataset.shuffle(1000, training_args.seed)
660
+ if data_args.normalize_text:
661
+ train_dataset = (
662
+ train_dataset.map(text_normalizer)
663
+ if data_args.streaming
664
+ else train_dataset.map(
665
+ normalize_text,
666
+ num_proc=data_args.preprocessing_num_workers,
667
+ load_from_cache_file=not data_args.overwrite_cache,
668
+ desc="Normalizing the validation dataset",
669
+ )
670
+ )
671
+ train_dataset = (
672
+ train_dataset.map(
673
+ preprocess_function,
674
+ batched=True,
675
+ )
676
+ if data_args.streaming
677
+ else train_dataset.map(
678
+ preprocess_function,
679
+ batched=True,
680
+ num_proc=data_args.preprocessing_num_workers,
681
+ remove_columns=train_dataset.column_names,
682
+ load_from_cache_file=not data_args.overwrite_cache,
683
+ desc="Running tokenizer on validation dataset",
684
+ )
685
  )
686
 
687
  if training_args.do_eval:
 
688
  if "validation" not in dataset:
689
  raise ValueError("--do_eval requires a validation dataset")
690
  eval_dataset = dataset["validation"]
691
  if data_args.max_eval_samples is not None:
692
+ eval_dataset = (
693
+ eval_dataset.take(data_args.max_train_samples)
694
+ if data_args.streaming
695
+ else eval_dataset.select(range(data_args.max_train_samples))
696
+ )
697
+ if data_args.normalize_text:
698
+ eval_dataset = (
699
+ eval_dataset.map(text_normalizer)
700
+ if data_args.streaming
701
+ else eval_dataset.map(
702
+ normalize_text,
703
+ num_proc=data_args.preprocessing_num_workers,
704
+ load_from_cache_file=not data_args.overwrite_cache,
705
+ desc="Normalizing the validation dataset",
706
+ )
707
+ )
708
+ eval_dataset = (
709
+ eval_dataset.map(
710
+ preprocess_function,
711
+ batched=True,
712
+ )
713
+ if data_args.streaming
714
+ else eval_dataset.map(
715
+ preprocess_function,
716
+ batched=True,
717
+ num_proc=data_args.preprocessing_num_workers,
718
+ remove_columns=eval_dataset.column_names,
719
+ load_from_cache_file=not data_args.overwrite_cache,
720
+ desc="Running tokenizer on validation dataset",
721
+ )
722
  )
723
 
724
  # Initialize our training
 
727
 
728
  # Store some constant
729
  num_epochs = int(training_args.num_train_epochs)
730
+ train_batch_size = (
731
+ int(training_args.per_device_train_batch_size) * jax.device_count()
732
+ )
733
  total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
734
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
735
+    if data_args.streaming:
+        len_train_dataset = data_args.len_train
+        if (
+            data_args.max_train_samples is not None
+            and data_args.max_train_samples < len_train_dataset
+        ):
+            len_train_dataset = data_args.max_train_samples
+
+        len_eval_dataset = data_args.len_eval
+        if (
+            data_args.max_eval_samples is not None
+            and data_args.max_eval_samples < len_eval_dataset
+        ):
+            len_eval_dataset = data_args.max_eval_samples
+    else:
+        len_train_dataset = len(train_dataset)
+        len_eval_dataset = len(eval_dataset)
+    steps_per_epoch = len_train_dataset // train_batch_size
753
  total_steps = steps_per_epoch * num_epochs
754
+ total_optimization_steps = (len_train_dataset // total_batch_size) * num_epochs
755
 
756
  # Create learning rate schedule
757
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
758
+ len_train_dataset,
759
  total_batch_size,
760
  training_args.num_train_epochs,
761
  training_args.warmup_steps,
762
  training_args.learning_rate,
763
+ data_args.no_decay,
764
  )
765
 
766
  # We use Optax's "masking" functionality to not apply weight decay
 
773
  def decay_mask_fn(params):
774
  flat_params = traverse_util.flatten_dict(params)
775
  layer_norm_params = [
776
+ (name, "scale")
777
+ for name in [
778
+ "self_attn_layer_norm",
779
+ "layernorm_embedding",
780
+ "final_layer_norm",
781
+ ]
782
  ]
783
+ flat_mask = {
784
+ path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
785
+ for path in flat_params
786
+ }
787
  return traverse_util.unflatten_dict(flat_mask)
788
 
789
  # create adam optimizer
 
815
  if model_args.from_checkpoint is not None:
816
  # restore optimizer state, step and optimizer_step
817
  step, optimizer_step, opt_state = restore_state(state, artifact_dir)
818
+ state = state.replace(
819
+ step=step, optimizer_step=optimizer_step, opt_state=opt_state
820
+ )
821
 
822
  # label smoothed cross entropy
823
  def loss_fn(logits, labels):
 
831
 
832
  def compute_loss(params):
833
  labels = batch.pop("labels")
834
+ logits = state.apply_fn(
835
+ **batch, params=params, dropout_rng=dropout_rng, train=True
836
+ )[0]
837
  loss = loss_fn(logits, labels)
838
  return loss
839
 
 
842
  grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
843
 
844
  def update_fn():
845
+ grads = jax.tree_map(
846
+ lambda x: x / training_args.gradient_accumulation_steps, grad_accum
847
+ )
848
  grads = jax.lax.pmean(grads, "batch")
849
  new_state = state.apply_gradients(
850
+ grads=grads,
851
+ grad_accum=jax.tree_map(jnp.zeros_like, grads),
852
+ optimizer_step=state.optimizer_step + 1,
853
  )
854
  return new_state
855
 
 
860
  None,
861
  )
862
 
863
+ metrics = {
864
+ "loss": loss,
865
+ "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step),
866
+ }
867
  metrics = jax.lax.pmean(metrics, axis_name="batch")
868
 
869
  return new_state.replace(dropout_rng=new_dropout_rng), metrics
 
879
  metrics = jax.lax.pmean(metrics, axis_name="batch")
880
  return metrics
881
 
 
 
 
 
 
 
 
 
 
 
 
 
882
  # Create parallel version of the train and eval step
883
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
 
 
884
  p_eval_step = jax.pmap(eval_step, "batch")
 
885
 
886
  # Replicate the train state on each device
887
  state = state.replicate()
888
 
889
  logger.info("***** Running training *****")
890
+ logger.info(f" Num examples = {len_train_dataset}")
891
  logger.info(f" Num Epochs = {num_epochs}")
892
+ logger.info(
893
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
894
+ )
895
  logger.info(
896
  f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
897
  )
898
  logger.info(f" Total global steps = {total_steps}")
899
  logger.info(f" Total optimization steps = {total_optimization_steps}")
900
 
 
901
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
902
  global_step = 0
903
 
 
905
  # ======================== Evaluating ==============================
906
  eval_metrics = []
907
  if training_args.do_eval:
908
+ if data_args.streaming:
909
+ eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
910
+ else:
911
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
912
+ eval_steps = len_eval_dataset // eval_batch_size
913
+ for batch in tqdm(
914
+ eval_loader,
915
+ desc="Evaluating...",
916
+ position=2,
917
+ leave=False,
918
+ total=eval_steps,
919
+ ):
920
  # Model forward
 
 
 
921
  metrics = p_eval_step(state.params, batch)
922
  eval_metrics.append(metrics)
923
 
 
 
 
 
 
 
924
  # normalize eval metrics
925
+ breakpoint()
926
  eval_metrics = get_metrics(eval_metrics)
927
+ breakpoint()
928
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
929
+ breakpoint()
930
 
931
  # log metrics
932
+ wandb_log(eval_metrics, step=global_step, prefix="eval")
933
 
934
  # Print metrics and update progress bar
935
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
953
 
954
  # save state
955
  state = unreplicate(state)
956
+ with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
957
  f.write(to_bytes(state.opt_state))
958
+ with (Path(training_args.output_dir) / "training_state.json").open(
959
+ "w"
960
+ ) as f:
961
+ json.dump({"step": state.step.item()}, f)
962
 
963
  # save to W&B
964
  if data_args.log_model:
965
+ metadata = {"step": step, "epoch": epoch}
966
  if eval_metrics is not None:
967
+ metadata["eval/loss"] = eval_metrics["loss"]
968
  artifact = wandb.Artifact(
969
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
970
  )
971
+ artifact.add_file(
972
+ str(Path(training_args.output_dir) / "flax_model.msgpack")
973
+ )
974
+ artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
975
+ artifact.add_file(
976
+ str(Path(training_args.output_dir) / "tokenizer.json")
977
+ )
978
+ artifact.add_file(
979
+ str(Path(training_args.output_dir) / "tokenizer_config.json")
980
+ )
981
+ artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
982
+ artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
983
+ artifact.add_file(
984
+ str(Path(training_args.output_dir) / "special_tokens_map.json")
985
+ )
986
+ artifact.add_file(
987
+ str(Path(training_args.output_dir) / "opt_state.msgpack")
988
+ )
989
+ artifact.add_file(
990
+ str(Path(training_args.output_dir) / "training_state.json")
991
+ )
992
  wandb.run.log_artifact(artifact)
993
 
994
  # save some space
 
1002
  params=params,
1003
  push_to_hub=training_args.push_to_hub,
1004
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
1005
+ temp_dir=True, # avoid issues with being in a repository
1006
  )
1007
+
1008
  for epoch in epochs:
1009
  # ======================== Training ================================
 
1010
 
1011
  # Create sampling rng
1012
  rng, input_rng = jax.random.split(rng)
1013
 
1014
  # Generate an epoch by shuffling sampling indices from the train dataset
1015
+        if data_args.streaming:
+            train_dataset.set_epoch(epoch)
+            train_loader = data_loader_streaming(train_dataset, train_batch_size)
+        else:
+            train_loader = data_loader(
+                input_rng, train_dataset, train_batch_size, shuffle=True
+            )
         # train
+        for batch in tqdm(
+            train_loader,
+            desc="Training...",
+            position=1,
+            leave=False,
+            total=steps_per_epoch,
+        ):
+            global_step += 1
1031
  state, train_metric = p_train_step(state, batch)
1032
 
1033
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
1034
  # log metrics
1035
+ wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
1036
 
1037
  if training_args.eval_steps and global_step % training_args.eval_steps == 0:
1038
  run_evaluation()
1039
+
1040
  if global_step % data_args.save_model_steps == 0:
1041
  run_save_model(state, global_step, epoch)
1042
+
1043
  # log final train metrics
1044
+ wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
1045
 
 
1046
  train_metric = unreplicate(train_metric)
1047
  epochs.write(
1048
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
 
1055
  run_save_model(state, global_step, epoch, eval_metrics)
1056
 
1057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1058
  if __name__ == "__main__":
1059
  main()
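Because an iterable dataset has no length, the step bookkeeping above falls back on the user-supplied --len_train / --len_eval values. A small worked example with the do_big_run.sh numbers; the device count, gradient accumulation and epoch count are illustrative assumptions, not values from this commit:

len_train = 1_000_000     # --len_train
per_device_bs = 56        # --per_device_train_batch_size
n_devices = 8             # assumption: a single TPU v3-8 host
grad_accum_steps = 1      # assumption: default gradient accumulation
num_epochs = 1            # assumption

train_batch_size = per_device_bs * n_devices                             # 448
total_batch_size = train_batch_size * grad_accum_steps                   # 448
steps_per_epoch = len_train // train_batch_size                          # 2232
total_steps = steps_per_epoch * num_epochs                               # 2232
total_optimization_steps = (len_train // total_batch_size) * num_epochs  # 2232
print(steps_per_epoch, total_steps, total_optimization_steps)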