File size: 29,176 Bytes
46cb01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d209547
803c7df
d209547
46cb01f
6523a6d
d209547
46cb01f
 
 
 
 
 
 
 
d209547
 
46cb01f
 
d209547
46cb01f
85c1b8e
d209547
 
46cb01f
85c1b8e
a6252c9
46cb01f
803c7df
3f0364c
 
46cb01f
 
 
 
 
 
 
803c7df
46cb01f
 
 
 
 
0a77f72
 
 
 
 
 
803c7df
 
a96f44d
803c7df
a96f44d
46cb01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a96f44d
 
 
 
46cb01f
3f0364c
a96f44d
 
 
 
 
85c1b8e
a96f44d
 
 
 
85c1b8e
 
46cb01f
 
 
85c1b8e
 
0a77f72
a96f44d
0a77f72
a96f44d
46cb01f
eac6890
 
 
 
 
 
46cb01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87fac28
46cb01f
 
 
a96f44d
87fac28
 
 
46cb01f
85c1b8e
 
 
 
 
 
 
85748ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498559f
1c44a7d
a96f44d
85748ef
1c44a7d
46cb01f
eac6890
85748ef
eac6890
 
 
 
85748ef
 
 
 
 
 
 
 
0a77f72
85748ef
0a77f72
85748ef
46cb01f
 
 
69cf636
6523a6d
 
 
46cb01f
 
a96f44d
 
 
46cb01f
3cd6d41
 
 
e2400cc
3cd6d41
6523a6d
3cd6d41
 
 
 
6523a6d
 
 
 
 
 
3cd6d41
46cb01f
 
a96f44d
 
87fac28
6523a6d
46cb01f
 
6523a6d
 
 
 
a96f44d
 
 
87fac28
5a3211f
46cb01f
a96f44d
 
 
 
 
 
46cb01f
 
 
 
19070ab
 
a96f44d
4a4820f
a96f44d
19070ab
47e006f
19070ab
 
 
46cb01f
85748ef
a96f44d
 
 
46cb01f
 
 
a96f44d
 
 
46cb01f
 
 
 
 
 
 
 
 
 
 
 
 
a96f44d
46cb01f
803c7df
46cb01f
 
803c7df
46cb01f
 
803c7df
46cb01f
 
 
 
 
 
 
53dade7
 
 
46cb01f
 
 
9bf9397
85c1b8e
0fe3e72
 
 
a96f44d
46cb01f
074c5e1
5b533b5
 
 
 
 
 
 
074c5e1
0a77f72
5b533b5
 
 
 
3d61350
803c7df
 
61c93f2
80b41d1
 
46cb01f
3cd6d41
 
 
803c7df
3cd6d41
4aced93
3d61350
 
0a77f72
a96f4dc
0a77f72
a96f4dc
3d61350
0a77f72
 
b257ca8
0a77f72
 
 
 
 
80b41d1
 
0a77f72
b257ca8
0a77f72
 
 
 
3d61350
3cd6d41
803c7df
 
 
 
 
 
 
 
 
4aced93
46cb01f
85c1b8e
46cb01f
0fe3e72
85c1b8e
 
 
a96f4dc
85c1b8e
46cb01f
 
eac6890
46cb01f
 
 
 
5b533b5
a96f44d
5b533b5
 
 
 
 
 
 
 
 
a96f44d
85c1b8e
6523a6d
5b533b5
 
 
6523a6d
0df810d
 
 
53dade7
46cb01f
 
69cf636
46cb01f
 
85748ef
0df810d
46cb01f
 
 
 
 
 
 
 
 
 
a96f44d
 
 
 
 
 
46cb01f
a96f44d
 
 
 
46cb01f
 
 
600ad79
 
 
 
69cf636
bab75aa
 
85748ef
600ad79
 
 
69cf636
600ad79
 
 
 
 
 
46cb01f
69cf636
 
 
 
 
 
46cb01f
c9e9575
 
 
5960e87
c9e9575
 
0a77f72
6523a6d
 
3cd6d41
46cb01f
 
9db361a
 
d61405b
46cb01f
 
 
6523a6d
46cb01f
 
69cf636
46cb01f
a96f44d
 
 
9db361a
46cb01f
 
 
69cf636
 
6523a6d
 
 
 
5b533b5
6523a6d
46cb01f
a96f44d
 
69cf636
a96f44d
46cb01f
6523a6d
 
46cb01f
 
9db361a
46cb01f
 
9db361a
46cb01f
 
 
 
 
 
 
a96f44d
9db361a
46cb01f
 
a96f44d
46cb01f
a96f44d
 
 
5b533b5
c9e9575
47e006f
c9e9575
53dade7
6523a6d
 
 
4a4820f
5b533b5
 
 
 
 
 
 
 
 
 
 
 
 
 
46cb01f
6523a6d
 
 
566d5f2
46cb01f
 
32dc2d8
85c1b8e
0df810d
 
 
 
 
a96f44d
 
 
 
 
 
 
32dc2d8
 
 
 
 
 
 
 
19070ab
0df810d
566d5f2
32dc2d8
0d94b71
32dc2d8
 
19070ab
566d5f2
 
6523a6d
d449092
6523a6d
6e89e9e
 
 
 
 
 
aecf3a7
 
 
a30dbd3
6523a6d
a96f44d
6523a6d
0df810d
85748ef
0df810d
 
a96f44d
 
 
6523a6d
0df810d
6523a6d
 
aecf3a7
5b533b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a96f44d
9bf9397
6523a6d
baa52db
9bf9397
566d5f2
6523a6d
566d5f2
0df810d
566d5f2
 
85c1b8e
566d5f2
a96f44d
 
 
 
 
 
 
6523a6d
 
 
 
 
 
 
baa52db
0df810d
 
 
566d5f2
85748ef
19070ab
baa52db
0df810d
 
 
 
 
baa52db
566d5f2
6523a6d
47e006f
6523a6d
a96f44d
6523a6d
 
a96f44d
3fef9c1
baa52db
 
 
566d5f2
9bf9397
baa52db
9bf9397
566d5f2
19070ab
566d5f2
46cb01f
4a4820f
6523a6d
754f876
1c44a7d
46cb01f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for seq2seq, text to image.
Script adapted from run_summarization_flax.py
"""

import json
import logging
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Callable, Optional

import datasets
import jax
import jax.numpy as jnp
import optax
import transformers
import wandb
from datasets import Dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser

from dalle_mini.data import Dataset
from dalle_mini.model import DalleBart, DalleBartConfig

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Parsed by `HfArgumentParser` in `main`, either from the command line or from a JSON file.
    """

    # checkpoint to load weights from; leave unset to train from scratch
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # config to use when it differs from `model_name_or_path`
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    # tokenizer to use when it differs from `model_name_or_path`
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
        },
    )
    # name of a `jax.numpy` dtype attribute (looked up with getattr in `main`)
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    All fields are forwarded as keyword arguments to the dataset loader in `main`
    (via `asdict`), so field names must match that loader's parameters.

    Raises:
        ValueError: if `dataset_repo_or_path` is not provided.
    """

    text_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    encoding_column: Optional[str] = field(
        default="encoding",
        metadata={
            "help": "The name of the column in the datasets containing the image encodings."
        },
    )
    # required in practice: validated in __post_init__ (a None default keeps it
    # optional for the argument parser so the custom error message is shown)
    dataset_repo_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset repository containing encoded files."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (glob acceptable)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (glob acceptable)."},
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: bool = field(
        default=True,
        metadata={"help": "Whether to stream the dataset."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Whether to use the authentication token for private datasets."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
        },
    )
    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
    seed_dataset: Optional[int] = field(
        default=None,
        metadata={
            "help": "Random seed for the dataset that will be set at the beginning of training."
        },
    )

    def __post_init__(self):
        """Validate that a dataset location was given."""
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")


@dataclass
class TrainingArguments:
    """
    Arguments pertaining to training parameters.

    `output_dir` is the only required field; every other field has a default.
    """

    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the dev set."}
    )

    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )

    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing a backward/update pass."
        },
    )

    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    adafactor: bool = field(
        default=False,
        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
    )
    # None means "no weight decay"; the annotation is Optional to match the default
    weight_decay: Optional[float] = field(
        default=None, metadata={"help": "Weight decay if we apply some."}
    )
    adam_beta1: float = field(
        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
    )
    adam_beta2: float = field(
        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
    use_decay: bool = field(
        default=False,
        metadata={"help": "Whether to use decay in the learning rate scheduler."},
    )

    # truncated to an int by the training script
    num_train_epochs: float = field(
        default=3.0, metadata={"help": "Total number of training epochs to perform."}
    )
    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )

    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    eval_steps: int = field(
        default=400, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
    )
    log_model: bool = field(
        default=False,
        metadata={"help": "Log model to wandb at `save_steps` frequency."},
    )

    seed_model: int = field(
        default=42,
        metadata={
            "help": "Random seed for the model that will be set at the beginning of training."
        },
    )

    push_to_hub: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to upload the trained model to the model hub after training."
        },
    )

    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "Reference to a wandb artifact for resuming training."},
    )


class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with the extra bookkeeping this script checkpoints."""

    dropout_rng: jnp.ndarray = None  # PRNG key, split on every train step
    epoch: int = 0  # epoch in progress, saved with each checkpoint
    train_time: float = 0.0  # total time the model trained
    train_samples: int = 0  # number of samples seen

    def replicate(self):
        """Replicate the state on local devices, sharding the dropout key across them."""
        return jax_utils.replicate(self).replace(
            dropout_rng=shard_prng_key(self.dropout_rng)
        )

    def restore_state(self, artifact_dir):
        """Return a copy of this state with optimizer state and counters loaded from disk.

        Args:
            artifact_dir: directory containing `opt_state.msgpack` and
                `training_state.json` as written by the checkpointing code.

        Returns:
            A new state; `self` is not modified.
        """
        artifact_dir = Path(artifact_dir)

        # restore optimizer state (uses the current opt_state as structure template)
        with (artifact_dir / "opt_state.msgpack").open("rb") as f:
            new_opt_state = from_bytes(self.opt_state, f.read())

        # restore other parameters
        with (artifact_dir / "training_state.json").open("r") as f:
            training_state = json.load(f)

        # replace state
        return self.replace(
            opt_state=new_opt_state,
            step=training_state["step"],
            # the epoch is written with every checkpoint but was never read back,
            # so resumed runs restarted from epoch 0; keep the current value for
            # checkpoints that lack the key
            epoch=training_state.get("epoch", self.epoch),
            train_time=training_state["train_time"],
            train_samples=training_state["train_samples"],
        )


def create_learning_rate_fn(
    num_warmup_steps: int,
    learning_rate: float,
    use_decay: bool,
    num_train_steps: Optional[int] = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function.

    The rate ramps linearly from 0 to `learning_rate` over `num_warmup_steps`.
    When `use_decay` is set, it then decays linearly to 0 over the remaining
    `num_train_steps - num_warmup_steps` steps.

    Raises:
        ValueError: if `use_decay` is set but `num_train_steps` is None.
    """
    if use_decay and num_train_steps is None:
        # explicit check rather than `assert`, which is stripped under `python -O`
        raise ValueError(
            "Learning rate with decay requires number of training steps"
        )
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    if not use_decay:
        return warmup_fn
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    # switch from warmup to decay once `num_warmup_steps` is reached
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
    )
    return schedule_fn


def wandb_log(metrics, step=None, prefix=None):
    """Send `metrics` to Weights & Biases, but only from process 0.

    Args:
        metrics: mapping of metric name to value.
        step: if given, also logged under the "train/step" key.
        prefix: if given, every metric name becomes "<prefix>/<name>".
    """
    if jax.process_index() != 0:
        return

    if prefix is None:
        payload = {name: value for name, value in metrics.items()}
    else:
        payload = {f"{prefix}/{name}": value for name, value in metrics.items()}

    if step is not None:
        payload["train/step"] = step
    wandb.log(payload)


def main():
    """Parse arguments, build model/tokenizer/optimizer, then train and evaluate.

    Arguments come from a single JSON file when it is the only CLI argument,
    otherwise from the command line. Metrics are logged to Weights & Biases from
    process 0. Checkpoints (model, tokenizer, optimizer state and training
    counters) are written to `output_dir` every `save_steps` and after each epoch.

    Raises:
        ValueError: if `output_dir` is non-empty while training without
            `--overwrite_output_dir`.
    """
    # See all possible arguments by passing the --help flag to this script.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"TPUs: {jax.device_count()}")
    assert jax.device_count() == 8, "TPUs in use, please check running processes"

    # Log the training configuration (visible on the main process only, see logger level above)
    logger.info(f"Training/evaluation parameters {training_args}")

    # Load dataset
    dataset = Dataset(
        **asdict(data_args),
        do_train=training_args.do_train,
        do_eval=training_args.do_eval,
    )

    # Set up wandb run
    if jax.process_index() == 0:
        wandb.init(
            entity="dalle-mini",
            project="dalle-mini",
            job_type="Seq2Seq",
            # use the already-parsed arguments: re-parsing sys.argv fails when
            # the arguments were supplied through a JSON file
            config={
                **asdict(model_args),
                **asdict(data_args),
                **asdict(training_args),
            },
        )

    if training_args.resume_from_checkpoint is not None:
        if jax.process_index() == 0:
            artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
        else:
            artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
        artifact_dir = artifact.download()

        # load model
        model = DalleBart.from_pretrained(artifact_dir)
        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
        print(model.params)

        # load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            artifact_dir,
            use_fast=True,
        )

    else:
        # Set up our new model config
        if model_args.config_name:
            config = DalleBartConfig.from_pretrained(model_args.config_name)
        else:
            config = DalleBartConfig.from_pretrained(model_args.model_name_or_path)

        # Load or create new model
        if model_args.model_name_or_path:
            model = DalleBart.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                seed=training_args.seed_model,
                dtype=getattr(jnp, model_args.dtype),
            )
            # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
            print(model.params)
        else:
            model = DalleBart(
                config,
                seed=training_args.seed_model,
                dtype=getattr(jnp, model_args.dtype),
            )

        # Load tokenizer
        if model_args.tokenizer_name is not None:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.tokenizer_name, use_fast=True
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name_or_path,
                use_fast=True,
            )

    # Preprocessing the datasets.
    # We need to normalize and tokenize inputs and targets.

    dataset.preprocess(
        tokenizer=tokenizer,
        decoder_start_token_id=model.config.decoder_start_token_id,
        normalize_text=model.config.normalize_text,
        max_length=model.config.max_text_length,
    )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed_model)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    # batch size per node
    train_batch_size = (
        int(training_args.per_device_train_batch_size) * jax.local_device_count()
    )
    batch_size_per_update = (
        train_batch_size
        * training_args.gradient_accumulation_steps
        * jax.process_count()
    )
    eval_batch_size = (
        int(training_args.per_device_eval_batch_size) * jax.local_device_count()
    )
    len_train_dataset, len_eval_dataset = dataset.length
    steps_per_epoch = (
        len_train_dataset // (train_batch_size * jax.process_count())
        if len_train_dataset is not None
        else None
    )
    num_train_steps = (
        steps_per_epoch * num_epochs if steps_per_epoch is not None else None
    )
    num_params = model.num_params

    # Create learning rate schedule
    learning_rate_fn = create_learning_rate_fn(
        training_args.warmup_steps,
        training_args.learning_rate,
        training_args.use_decay,
        num_train_steps,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # names absent from the model's parameters simply never match
        layer_norm_params = [
            (name, "scale")
            for name in [
                "self_attn_layer_norm",
                "encoder_attn_layer_norm",
                "layernorm_embedding",
                "final_layer_norm",
            ]
        ]
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create optimizer (Adafactor or AdamW)
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=learning_rate_fn,
            weight_decay_rate=training_args.weight_decay,
            weight_decay_mask=decay_mask_fn,
            clipping_threshold=training_args.max_grad_norm,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=learning_rate_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            # `weight_decay` defaults to None, which adamw cannot multiply with
            weight_decay=(
                training_args.weight_decay
                if training_args.weight_decay is not None
                else 0.0
            ),
            mask=decay_mask_fn,
        )

    # add gradient accumulation
    if training_args.gradient_accumulation_steps > 1:
        optimizer = optax.chain(
            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
        )

    # Setup train state
    state = TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
        dropout_rng=dropout_rng,
    )
    if training_args.resume_from_checkpoint is not None:
        # restore optimizer state and other parameters
        # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
        state = state.restore_state(artifact_dir)

    # cross entropy against one-hot labels, averaged over all tokens (no label smoothing)
    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        loss = loss.mean()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, delta_time):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params, batch):
            labels = batch.pop("labels")
            logits = state.apply_fn(
                **batch, params=params, dropout_rng=dropout_rng, train=True
            )[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grads = grad_fn(state.params, batch)
        # average gradients across devices before the update
        grads = jax.lax.pmean(grads, "batch")
        state = state.apply_gradients(
            grads=grads,
            dropout_rng=new_dropout_rng,
            train_time=state.train_time + delta_time,
            train_samples=state.train_samples + train_batch_size * jax.process_count(),
        )

        metrics = {
            "loss": loss,
            "learning_rate": learning_rate_fn(state.step),
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len_train_dataset}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(f"  Number of devices = {jax.device_count()}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
    )
    logger.info(f"  Model parameters = {num_params:,}")
    epochs = tqdm(
        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
    )

    if jax.process_index() == 0:
        # set default x-axis as 'train/step'
        wandb_log({}, step=state.step)
        wandb.define_metric("*", step_metric="train/step")

        # add interesting config parameters
        wandb.config.update(
            {
                "len_train_dataset": len_train_dataset,
                "len_eval_dataset": len_eval_dataset,
                "batch_size_per_update": batch_size_per_update,
                "num_params": num_params,
            }
        )

    # replicate state on each device
    state = state.replicate()

    def run_evaluation():
        """Evaluate on the eval split; returns averaged metrics, or None without `do_eval`."""
        # ======================== Evaluating ==============================
        eval_metrics = []
        if training_args.do_eval:
            eval_loader = dataset.dataloader("eval", eval_batch_size)
            eval_steps = (
                len_eval_dataset // eval_batch_size
                if len_eval_dataset is not None
                else None
            )
            for batch in tqdm(
                eval_loader,
                desc="Evaluating...",
                position=2,
                leave=False,
                total=eval_steps,
            ):
                # Model forward
                metrics = p_eval_step(state.params, batch)
                eval_metrics.append(metrics)

            # normalize eval metrics
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

            # log metrics
            wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")

            # Print metrics and update progress bar
            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
            epochs.write(desc)
            epochs.desc = desc

            return eval_metrics

    def run_save_model(state, eval_metrics=None):
        """Write a checkpoint to `output_dir` and optionally to W&B / the hub (process 0 only)."""
        if jax.process_index() != 0:
            return

        params = jax.device_get(unreplicate(state.params))
        # save model locally
        model.save_pretrained(
            training_args.output_dir,
            params=params,
        )

        # save tokenizer
        tokenizer.save_pretrained(training_args.output_dir)

        # save state
        opt_state = unreplicate(state.opt_state)
        with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
            f.write(to_bytes(opt_state))
        state_dict = {
            k: jax.device_get(unreplicate(getattr(state, k))).item()
            for k in ["step", "epoch", "train_time", "train_samples"]
        }
        with (Path(training_args.output_dir) / "training_state.json").open(
            "w"
        ) as f:
            json.dump(
                state_dict,
                f,
            )

        # save to W&B
        if training_args.log_model:
            # save some space
            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
            c.cleanup(wandb.util.from_human_size("10GB"))

            metadata = dict(state_dict)
            metadata["num_params"] = num_params
            if eval_metrics is not None:
                metadata["eval"] = eval_metrics
            artifact = wandb.Artifact(
                name=f"model-{wandb.run.id}",
                type="bart_model",
                metadata=metadata,
            )
            for filename in [
                "flax_model.msgpack",
                "config.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.json",
                "merges.txt",
                "special_tokens_map.json",
                "opt_state.msgpack",
                "training_state.json",
            ]:
                artifact.add_file(str(Path(training_args.output_dir) / filename))

            wandb.run.log_artifact(artifact)

        # save to the hub
        if training_args.push_to_hub:
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                # state.step already counts the update that was just applied
                commit_message=f"Saving weights and logs at step {unreplicate(state.step)}",
                temp_dir=True,  # avoid issues with being in a repository
            )

    # init variables
    last_time = time.perf_counter()

    for epoch in epochs:
        # record the epoch in progress so it is logged and checkpointed;
        # `replace` returns a new state, so the result must be kept
        state = state.replace(epoch=jax_utils.replicate(epoch))
        # replicated metrics of the latest train step in this epoch, if any
        train_metrics = None
        # ======================== Training ================================
        wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = dataset.dataloader("train", train_batch_size)
        # train
        for batch in tqdm(
            train_loader,
            desc="Training...",
            position=1,
            leave=False,
            total=steps_per_epoch,
        ):

            # calculate delta time (we have a lag of one step but it's ok)
            new_time = time.perf_counter()
            delta_time = new_time - last_time
            last_time = new_time

            # train step
            state, train_metrics = p_train_step(
                state, batch, jax_utils.replicate(delta_time)
            )
            step = unreplicate(state.step)

            if (
                training_args.logging_steps
                and step % training_args.logging_steps == 0
                and jax.process_index() == 0
            ):
                # log metrics
                metrics = unreplicate(train_metrics)
                # log state parameters
                state_dict = {
                    k.split("_")[-1]: unreplicate(getattr(state, k))
                    for k in ["epoch", "train_time", "train_samples"]
                }
                wandb_log({**metrics, **state_dict}, step=step, prefix="train")

            eval_metrics = None
            if training_args.eval_steps and step % training_args.eval_steps == 0:
                eval_metrics = run_evaluation()

            if training_args.save_steps and step % training_args.save_steps == 0:
                run_save_model(state, eval_metrics)

        # log final train metrics (kept separate so `train_metrics` stays replicated)
        if train_metrics is not None:
            last_train_metrics = unreplicate(train_metrics)
            wandb_log(last_train_metrics, step=step, prefix="train")

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {last_train_metrics['loss']}, Learning Rate: {last_train_metrics['learning_rate']})"
            )

        # Final evaluation
        eval_metrics = run_evaluation()

        # save checkpoint after each epoch
        run_save_model(state, eval_metrics)


if __name__ == "__main__":
    # run training/evaluation only when executed as a script, not on import
    main()