feat(train): use pjit (#125)
Files changed:
- src/dalle_mini/data.py +0 -3
- src/dalle_mini/model/__init__.py +1 -0
- src/dalle_mini/model/modeling.py +10 -2
- tools/train/train.py +179 -137
src/dalle_mini/data.py
@@ -6,7 +6,6 @@ import jax.numpy as jnp
 import numpy as np
 from braceexpand import braceexpand
 from datasets import Dataset, load_dataset
-from flax.training.common_utils import shard
 
 from .text import TextNormalizer
 
@@ -191,7 +190,6 @@ class Dataset:
                 lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
                 batch,
             )
-            batch = shard(batch)
             yield batch
 
     def _dataloader_datasets_streaming(
@@ -232,7 +230,6 @@ class Dataset:
                     ),
                     batch,
                 )
-                batch = shard(batch)
                 yield batch
                 batch = {k: [] for k in keys}
                 first_loop = False
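Background on the dropped shard calls: flax.training.common_utils.shard is the pmap-era helper that reshapes each host batch to (local_devices, per_device_batch, ...). Under pjit the batch keeps its global shape and the split across devices is declared separately with a PartitionSpec, so the dataloader no longer reshapes anything. A minimal illustration of the difference, with a made-up batch rather than the dalle_mini dataloader:

import jax
import numpy as np
from flax.training.common_utils import shard

batch = {"input_ids": np.zeros((16, 64), dtype=np.int32)}  # global batch of 16

# pmap style: add a leading device axis, e.g. (8, 2, 64) on a host with 8 devices
print(jax.tree_map(lambda x: x.shape, shard(batch)))

# pjit style: keep the global (16, 64) shape; how it is split across devices is
# declared at the pjit call site, e.g. in_axis_resources=PartitionSpec("batch", None)
print(jax.tree_map(lambda x: x.shape, batch))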
src/dalle_mini/model/__init__.py
@@ -1,3 +1,4 @@
 from .configuration import DalleBartConfig
 from .modeling import DalleBart
+from .partitions import set_partitions
 from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/modeling.py
@@ -300,6 +300,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
     - config_class replaced to DalleBartConfig
     - __init__ accepts abstract_init which does uses parameter shape to initialize the model
+    - init weights on CPU
     """
 
     config_class = DalleBartConfig
@@ -311,6 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         abstract_init: bool = False,
+        load_on_cpu: bool = True,
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -330,15 +332,21 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         self.key = PRNGKey(seed)
         self.dtype = dtype
 
+        # init weights on CPU
+        if load_on_cpu:
+            init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
+        else:
+            init_fn = self.init_weights
+
         # randomly initialized parameters
         if abstract_init:
             # init the model weights only abstractly, eval_shape will return a pytree
             # with the structure as weights but without any actual values, this will just contain
             # the shape information. Weights need to be loaded later.
-            init_fn = partial(…
+            init_fn = partial(init_fn, input_shape=input_shape)
             random_params = jax.eval_shape(init_fn, self.key)
         else:
-            random_params = …
+            random_params = init_fn(self.key, input_shape)
 
         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
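For reference, the abstract_init / load_on_cpu logic above combines two JAX mechanisms: jax.eval_shape traces the init function and returns a pytree of shapes and dtypes without allocating any parameter memory, and jax.jit(..., backend="cpu") forces real initialization onto the host so accelerator memory is never touched. A standalone sketch of the same idea on a toy Flax module (the module and sizes below are invented for illustration):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Toy(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1024)(x)

model = Toy()
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 512))

# shape-only init: a pytree of jax.ShapeDtypeStruct, no weights materialized
abstract_params = jax.eval_shape(lambda key: model.init(key, x), rng)

# real init, compiled for the CPU backend so it never allocates on TPU/GPU
init_on_cpu = jax.jit(model.init, backend="cpu")
cpu_params = init_on_cpu(rng, x)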
tools/train/train.py
@@ -30,21 +30,28 @@ from typing import Callable, Optional
 import datasets
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
 import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax import …
-from flax.jax_utils import unreplicate
+from flax.core.frozen_dict import freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import …
+from flax.training.common_utils import onehot, stack_forest
+from jax.experimental import PartitionSpec, maps
+from jax.experimental.pjit import pjit
 from tqdm import tqdm
-from transformers import …
+from transformers import HfArgumentParser
 
 from dalle_mini.data import Dataset
-from dalle_mini.model import …
+from dalle_mini.model import (
+    DalleBart,
+    DalleBartConfig,
+    DalleBartTokenizer,
+    set_partitions,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -223,7 +230,6 @@ class TrainingArguments:
             "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
         },
     )
-    weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
     beta1: float = field(
         default=0.9,
         metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
@@ -332,6 +338,13 @@ class TrainingArguments:
         metadata={"help": "Verify that TPU is not in use."},
     )
 
+    mp_devices: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
+        },
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
@@ -340,9 +353,6 @@ class TrainingArguments:
         ], f"Selected optimizer not supported: {self.optim}"
         if self.per_device_eval_batch_size is None:
             self.per_device_eval_batch_size = self.per_device_train_batch_size
-        if self.weight_decay is None:
-            if self.optim in ["distributed_shampoo", "adam"]:
-                self.weight_decay = 0.0
         if (
             os.path.exists(self.output_dir)
             and os.listdir(self.output_dir)
@@ -353,6 +363,10 @@ class TrainingArguments:
                 f"Output directory ({self.output_dir}) already exists and is not empty."
                 "Use --overwrite_output_dir to overcome."
             )
+        assert (
+            jax.device_count() % self.mp_devices == 0
+        ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
+        self.dp_devices = jax.device_count() // self.mp_devices
 
 
 class TrainState(train_state.TrainState):
@@ -361,28 +375,6 @@ class TrainState(train_state.TrainState):
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen
 
-    def replicate(self):
-        return jax_utils.replicate(self).replace(
-            dropout_rng=shard_prng_key(self.dropout_rng)
-        )
-
-    def restore_state(self, artifact_dir):
-        # restore optimizer state
-        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
-            new_opt_state = from_bytes(self.opt_state, f.read())
-
-        # restore other parameters
-        with (Path(artifact_dir) / "training_state.json").open("r") as f:
-            training_state = json.load(f)
-
-        # replace state
-        return self.replace(
-            opt_state=new_opt_state,
-            step=training_state["step"],
-            train_time=training_state["train_time"],
-            train_samples=training_state["train_samples"],
-        )
-
 
 class MetricsLogger:
     def __init__(self, state):
@@ -391,14 +383,14 @@ class MetricsLogger:
 
     def get_all_train_metrics(self, train_metrics, state):
         """Make a dict of training metrics to be logged"""
-        metrics = …
+        metrics = train_metrics
         # get state parameters
         state_dict = {
-            k.split("_")[-1]: …
+            k.split("_")[-1]: getattr(state, k)
             for k in ["epoch", "train_time", "train_samples"]
         }
         # timing metrics
-        new_step = int(…
+        new_step = int(state.step)
         new_time = time.perf_counter()
         if new_step > self.step:
             time_per_step = (new_time - self.time) / (new_step - self.step)
@@ -487,8 +479,6 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
         )
-        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
-        print(model.params)
 
         # load tokenizer
         tokenizer = DalleBartTokenizer.from_pretrained(
@@ -512,8 +502,6 @@ def main():
            dtype=getattr(jnp, model_args.dtype),
            abstract_init=True,
        )
-        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
-        print(model.params)
     else:
         model = DalleBart(
             config,
@@ -523,7 +511,7 @@ def main():
 
     # Load tokenizer
     if model_args.tokenizer_name is not None:
-        tokenizer = …
+        tokenizer = DalleBartTokenizer.from_pretrained(
             model_args.tokenizer_name, use_fast=True
         )
     else:
@@ -601,32 +589,9 @@ def main():
 
     learning_rate_fn = create_learning_rate_fn()
 
-    # We use Optax's "masking" functionality to not apply weight decay
-    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
-    # mask boolean with the same structure as the parameters.
-    # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBart.
-    def decay_mask_fn(params):
-        flat_params = traverse_util.flatten_dict(params)
-        layer_norm_params = [
-            (name, "scale")
-            for name in [
-                "self_attn_layer_norm",
-                "layernorm_embedding",
-                "final_layer_norm",
-            ]
-        ]
-        flat_mask = {
-            path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
-            for path in flat_params
-        }
-        return traverse_util.unflatten_dict(flat_mask)
-
     # create adam optimizer
     if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
-        # Notes:
-        # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=training_args.block_size,
@@ -634,7 +599,6 @@ def main():
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=training_args.weight_decay,
             start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
@@ -657,30 +621,104 @@ def main():
             b1=training_args.beta1,
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
-            weight_decay=training_args.weight_decay,
-            mask=decay_mask_fn,
         )
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
 
-    # …
-    …
+    # get opt_state shape without actual init
+    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
+
+    # get PartitionSpec for model params
+    param_spec = set_partitions(model.params)
+
+    # create PartitionSpec for opt_state
+    def opt_state_spec_per_leaf(x):
+        if training_args.optim in ["adam", "adafactor"]:
+            if isinstance(x, dict):
+                # variables with same structure as params
+                return param_spec
+            else:
+                # other variables such as count
+                return None
+        else:
+            # TODO: create spec for Distributed Shampoo
+            raise NotImplementedError
+
+    opt_state_spec = jax.tree_map(
+        opt_state_spec_per_leaf,
+        opt_state_shape,
+        # return None spec for empty elements
+        is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
+    )
+
+    # create a mesh
+    mesh_shape = (training_args.dp_devices, training_args.mp_devices)
+    devices = np.asarray(jax.devices()).reshape(*mesh_shape)
+    mesh = maps.Mesh(devices, ("batch", "mp"))
+
+    # Create state spec
+    state_spec = TrainState(
+        params=param_spec,
+        opt_state=opt_state_spec,
+        dropout_rng=None,
+        step=None,
+        epoch=None,
+        train_time=None,
+        train_samples=None,
         apply_fn=model.__call__,
-        params=model.params,
         tx=optimizer,
-        dropout_rng=dropout_rng,
     )
+
+    opt_state, attr_state = None, None
     if training_args.resume_from_checkpoint is not None:
-        # restore …
-        …
-        …
+        # restore opt_state
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
+            opt_state = from_bytes(opt_state_shape, f.read())
+        # need to freeze dict for pjit
+        opt_state = jax.tree_map(
+            lambda x: freeze(x) if isinstance(x, dict) else x,
+            opt_state,
+            is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
+        )
+        # restore other attributes
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
+            attr_state = json.load(f)
+
+    # create training state
+    def init_state(params, opt_state):
+        if training_args.resume_from_checkpoint is None:
+            state = TrainState.create(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                dropout_rng=dropout_rng,
+            )
+        else:
+            state = TrainState(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                opt_state=opt_state,
+                dropout_rng=dropout_rng,
+                **attr_state,
+            )
+        return state
+
+    with maps.mesh(mesh.devices, mesh.axis_names):
+        state = pjit(
+            init_state,
+            in_axis_resources=(param_spec, opt_state_spec),
+            out_axis_resources=state_spec,
+            donate_argnums=(0, 1),
+        )(freeze(model.params), opt_state)
+
+    # free memory from large parameters
+    del model._params, opt_state
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
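A note on the optimizer-state handling introduced just above: jax.eval_shape(lambda x: optimizer.init(x), model.params) yields the optimizer-state pytree as shapes only, which is enough both to build a matching PartitionSpec tree and to deserialize a checkpoint into the right structure, all without allocating the real state. A self-contained sketch of that pattern with a toy parameter tree; PARAM_SPEC is a placeholder standing in for set_partitions(model.params):

import jax
import jax.numpy as jnp
import optax

params = {"dense": {"kernel": jnp.zeros((4, 4)), "bias": jnp.zeros((4,))}}
optimizer = optax.adamw(learning_rate=1e-3)

# optimizer-state structure and shapes, with no memory committed
opt_state_shape = jax.eval_shape(lambda p: optimizer.init(p), params)

PARAM_SPEC = "<PartitionSpec tree for params>"  # placeholder for illustration

# Adam's state holds param-shaped moments (dicts) plus scalar counters:
# give param-like leaves the param spec and leave everything else unpartitioned
opt_state_spec = jax.tree_map(
    lambda x: PARAM_SPEC if isinstance(x, dict) else None,
    opt_state_shape,
    is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
)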
@@ -691,6 +729,8 @@ def main():
     # Define gradient update step fn
     def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+        # use a different rng per node
+        dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
 
         def compute_loss(params, minibatch):
             labels = minibatch.pop("labels")
@@ -728,7 +768,6 @@ def main():
             ),
         )
 
-        grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
             dropout_rng=new_dropout_rng,
@@ -740,7 +779,6 @@ def main():
             "loss": loss,
             "learning_rate": learning_rate_fn(state.step),
         }
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
 
         return state, metrics
 
@@ -752,12 +790,20 @@ def main():
 
         # summarize metrics
         metrics = {"loss": loss}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = …
-    …
+    p_train_step = pjit(
+        train_step,
+        in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
+        out_axis_resources=(state_spec, None),
+        donate_argnums=(0,),
+    )
+    p_eval_step = pjit(
+        eval_step,
+        in_axis_resources=(param_spec, PartitionSpec("batch", None)),
+        out_axis_resources=None,
+    )
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")
@@ -792,9 +838,6 @@ def main():
             }
         )
 
-    # replicate state on each device
-    state = state.replicate()
-
     def run_evaluation():
         # ======================== Evaluating ==============================
         eval_metrics = []
@@ -819,13 +862,11 @@ def main():
                 eval_metrics.append(metrics)
 
             # normalize eval metrics
-            eval_metrics = …
+            eval_metrics = stack_forest(eval_metrics)
            eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
             # log metrics
-            metrics_logger.log(
-                eval_metrics, step=unreplicate(state.step), prefix="eval"
-            )
+            metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
 
             # Print metrics and update progress bar
             desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -836,7 +877,7 @@ def main():
 
     def run_save_model(state, eval_metrics=None):
         if jax.process_index() == 0:
-            params = jax.device_get(…
+            params = jax.device_get(state.params)
             # save model locally
             model.save_pretrained(
                 training_args.output_dir,
@@ -847,11 +888,11 @@ def main():
             tokenizer.save_pretrained(training_args.output_dir)
 
             # save state
-            opt_state = …
+            opt_state = jax.device_get(state.opt_state)
             with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
                 f.write(to_bytes(opt_state))
             state_dict = {
-                k: jax.device_get(…
+                k: jax.device_get(getattr(state, k)).item()
                 for k in ["step", "epoch", "train_time", "train_samples"]
             }
             with (Path(training_args.output_dir) / "training_state.json").open(
@@ -912,63 +953,64 @@ def main():
     last_time = time.perf_counter()
     train_metrics = None
 
-        …
-        )
-        step = unreplicate(state.step)
-        …
+    with maps.mesh(mesh.devices, mesh.axis_names):
+        for epoch in epochs:
+            state.replace(epoch=epoch)
+            # ======================== Training ================================
+            metrics_logger.log({"train/epoch": epoch}, step=state.step)
+
+            # Generate an epoch by shuffling sampling indices from the train dataset
+            train_loader = dataset.dataloader(
+                "train",
+                training_args.per_device_train_batch_size,
+                training_args.gradient_accumulation_steps,
+                epoch,
+            )
+            # train
+            for batch in tqdm(
+                train_loader,
+                desc="Training...",
+                position=1,
+                leave=False,
+                total=steps_per_epoch,
+            ):
+
+                # calculate delta time (we have a lag of one step but it's ok)
+                new_time = time.perf_counter()
+                delta_time = new_time - last_time
+                last_time = new_time
+
+                # train step
+                state, train_metrics = p_train_step(state, batch, delta_time)
+                step = state.step
+
+                if step % training_args.logging_steps == 0 and jax.process_index() == 0:
+                    all_metrics = metrics_logger.get_all_train_metrics(
+                        train_metrics, state
+                    )
+                    metrics_logger.log(all_metrics, step=step, prefix="train")
+
+                eval_metrics = None
+                if training_args.eval_steps and step % training_args.eval_steps == 0:
+                    eval_metrics = run_evaluation()
+
+                if step % training_args.save_steps == 0:
+                    run_save_model(state, eval_metrics)
+
+            # log final train metrics
+            if train_metrics is not None:
+                all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
+                metrics_logger.log(all_metrics, step=step, prefix="train")
+
+            epochs.write(
+                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
+            )
+
+            # Final evaluation
+            eval_metrics = run_evaluation()
+
+            # save checkpoint after each epoch
+            run_save_model(state, eval_metrics)
 
 
 if __name__ == "__main__":
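Overall, the training flow in this commit follows the standard pjit recipe: build a 2-D device mesh with a "batch" (data-parallel) axis and an "mp" (model-parallel) axis, describe how every input and output is laid out with PartitionSpec, and call the pjit-compiled functions inside the mesh context. A minimal end-to-end sketch using the same jax.experimental API as this commit (newer JAX releases expose these under jax.sharding and in_shardings); the shapes and loss function are invented for illustration:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# toy 2-D mesh: all devices on the data-parallel axis, one model-parallel "column"
dp_devices, mp_devices = jax.device_count(), 1
devices = np.asarray(jax.devices()).reshape(dp_devices, mp_devices)
mesh = maps.Mesh(devices, ("batch", "mp"))

def loss_fn(w, x):
    return jnp.mean((x @ w) ** 2)

p_loss = pjit(
    loss_fn,
    # w is replicated on every device, x is split along its first axis
    # across the "batch" mesh axis
    in_axis_resources=(None, PartitionSpec("batch", None)),
    out_axis_resources=None,
)

w = jnp.ones((8, 8))
x = jnp.ones((16, 8))  # global batch; must be divisible by the "batch" axis size
with maps.mesh(mesh.devices, mesh.axis_names):
    print(p_loss(w, x))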