Spaces:
Running
Running
feat(train): distributed_shampoo with pjit
Browse files- src/dalle_mini/model/modeling.py +1 -1
- tools/train/train.py +96 -62
src/dalle_mini/model/modeling.py
CHANGED
@@ -312,7 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
312 |
seed: int = 0,
|
313 |
dtype: jnp.dtype = jnp.float32,
|
314 |
abstract_init: bool = False,
|
315 |
-
load_on_cpu: bool =
|
316 |
**kwargs,
|
317 |
):
|
318 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
|
|
312 |
seed: int = 0,
|
313 |
dtype: jnp.dtype = jnp.float32,
|
314 |
abstract_init: bool = False,
|
315 |
+
load_on_cpu: bool = False,
|
316 |
**kwargs,
|
317 |
):
|
318 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
tools/train/train.py
CHANGED
@@ -36,7 +36,7 @@ import transformers
|
|
36 |
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
-
from flax.core.frozen_dict import freeze
|
40 |
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot, stack_forest
|
@@ -478,6 +478,7 @@ def main():
|
|
478 |
artifact_dir,
|
479 |
dtype=getattr(jnp, model_args.dtype),
|
480 |
abstract_init=True,
|
|
|
481 |
)
|
482 |
|
483 |
# load tokenizer
|
@@ -501,12 +502,14 @@ def main():
|
|
501 |
seed=training_args.seed_model,
|
502 |
dtype=getattr(jnp, model_args.dtype),
|
503 |
abstract_init=True,
|
|
|
504 |
)
|
505 |
else:
|
506 |
model = DalleBart(
|
507 |
config,
|
508 |
seed=training_args.seed_model,
|
509 |
dtype=getattr(jnp, model_args.dtype),
|
|
|
510 |
)
|
511 |
|
512 |
# Load tokenizer
|
@@ -606,7 +609,10 @@ def main():
|
|
606 |
graft_type=GraftingType.RMSPROP_NORMALIZED,
|
607 |
nesterov=False,
|
608 |
exponent_override=0,
|
609 |
-
|
|
|
|
|
|
|
610 |
inverse_failure_threshold=0.1,
|
611 |
moving_average_for_momentum=True,
|
612 |
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
|
@@ -630,31 +636,48 @@ def main():
|
|
630 |
clipping_threshold=training_args.max_grad_norm,
|
631 |
)
|
632 |
|
633 |
-
# get opt_state shape without actual init
|
634 |
-
opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
|
635 |
-
|
636 |
# get PartitionSpec for model params
|
637 |
param_spec = set_partitions(model.params)
|
638 |
|
639 |
-
#
|
640 |
-
def
|
641 |
-
if training_args.optim
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
648 |
else:
|
649 |
-
# TODO: create spec for Distributed Shampoo
|
650 |
raise NotImplementedError
|
|
|
651 |
|
652 |
-
opt_state_spec =
|
653 |
-
opt_state_spec_per_leaf,
|
654 |
-
opt_state_shape,
|
655 |
-
# return None spec for empty elements
|
656 |
-
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
657 |
-
)
|
658 |
|
659 |
# create a mesh
|
660 |
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
@@ -674,51 +697,62 @@ def main():
|
|
674 |
tx=optimizer,
|
675 |
)
|
676 |
|
677 |
-
opt_state, attr_state = None, None
|
678 |
-
if training_args.resume_from_checkpoint is not None:
|
679 |
-
# restore opt_state
|
680 |
-
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
681 |
-
opt_state = from_bytes(opt_state_shape, f.read())
|
682 |
-
# need to freeze dict for pjit
|
683 |
-
opt_state = jax.tree_map(
|
684 |
-
lambda x: freeze(x) if isinstance(x, dict) else x,
|
685 |
-
opt_state,
|
686 |
-
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
687 |
-
)
|
688 |
-
# restore other attributes
|
689 |
-
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
690 |
-
attr_state = json.load(f)
|
691 |
-
|
692 |
# create training state
|
693 |
-
|
694 |
if training_args.resume_from_checkpoint is None:
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
else:
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
opt_state=
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
722 |
|
723 |
# label smoothed cross entropy
|
724 |
def loss_fn(logits, labels):
|
|
|
36 |
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
+
from flax.core.frozen_dict import freeze, unfreeze
|
40 |
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot, stack_forest
|
|
|
478 |
artifact_dir,
|
479 |
dtype=getattr(jnp, model_args.dtype),
|
480 |
abstract_init=True,
|
481 |
+
load_on_cpu=True,
|
482 |
)
|
483 |
|
484 |
# load tokenizer
|
|
|
502 |
seed=training_args.seed_model,
|
503 |
dtype=getattr(jnp, model_args.dtype),
|
504 |
abstract_init=True,
|
505 |
+
load_on_cpu=True,
|
506 |
)
|
507 |
else:
|
508 |
model = DalleBart(
|
509 |
config,
|
510 |
seed=training_args.seed_model,
|
511 |
dtype=getattr(jnp, model_args.dtype),
|
512 |
+
load_on_cpu=True,
|
513 |
)
|
514 |
|
515 |
# Load tokenizer
|
|
|
609 |
graft_type=GraftingType.RMSPROP_NORMALIZED,
|
610 |
nesterov=False,
|
611 |
exponent_override=0,
|
612 |
+
statistics_partition_spec=PartitionSpec(None, "batch", None),
|
613 |
+
preconditioner_partition_spec=PartitionSpec("batch", None, None),
|
614 |
+
num_devices_for_pjit=training_args.dp_devices,
|
615 |
+
shard_optimizer_states=True,
|
616 |
inverse_failure_threshold=0.1,
|
617 |
moving_average_for_momentum=True,
|
618 |
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
|
|
|
636 |
clipping_threshold=training_args.max_grad_norm,
|
637 |
)
|
638 |
|
|
|
|
|
|
|
639 |
# get PartitionSpec for model params
|
640 |
param_spec = set_partitions(model.params)
|
641 |
|
642 |
+
# get PartitionSpec for optimizer state
|
643 |
+
def get_opt_state_spec_and_shape(param_spec):
|
644 |
+
if training_args.optim == "adam":
|
645 |
+
# get opt_state shape without actual init
|
646 |
+
opt_state_shape = jax.eval_shape(optimizer.init, model.params)
|
647 |
+
|
648 |
+
def _opt_state_spec_per_leaf(x):
|
649 |
+
if isinstance(x, dict):
|
650 |
+
# variables with same structure as params
|
651 |
+
return param_spec
|
652 |
+
else:
|
653 |
+
# other variables such as count
|
654 |
+
return None
|
655 |
+
|
656 |
+
opt_state_spec = jax.tree_map(
|
657 |
+
_opt_state_spec_per_leaf,
|
658 |
+
opt_state_shape,
|
659 |
+
# return None spec for empty elements
|
660 |
+
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
661 |
+
)
|
662 |
+
|
663 |
+
elif training_args.optim == "adafactor":
|
664 |
+
# factorized state must be replicated (rank different than params)
|
665 |
+
opt_state_spec = None
|
666 |
+
|
667 |
+
elif training_args.optim == "distributed_shampoo":
|
668 |
+
# memory efficient in distributed_shampoo, fake init
|
669 |
+
_opt_state = optimizer.init(model.params)
|
670 |
+
opt_state_spec = _opt_state.pspec_fn(
|
671 |
+
params=model.params,
|
672 |
+
params_partition_spec=unfreeze(param_spec),
|
673 |
+
partition_spec_for_statistics=PartitionSpec(None, "batch", None),
|
674 |
+
)
|
675 |
+
opt_state_shape = _opt_state.shape_and_dtype_fn(model.params)
|
676 |
else:
|
|
|
677 |
raise NotImplementedError
|
678 |
+
return opt_state_spec, opt_state_shape
|
679 |
|
680 |
+
opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
|
|
|
|
|
|
|
|
|
|
|
681 |
|
682 |
# create a mesh
|
683 |
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
|
|
697 |
tx=optimizer,
|
698 |
)
|
699 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
700 |
# create training state
|
701 |
+
with maps.mesh(mesh.devices, mesh.axis_names):
|
702 |
if training_args.resume_from_checkpoint is None:
|
703 |
+
|
704 |
+
def init_state(params):
|
705 |
+
return TrainState.create(
|
706 |
+
apply_fn=model.__call__,
|
707 |
+
tx=optimizer,
|
708 |
+
params=params,
|
709 |
+
dropout_rng=dropout_rng,
|
710 |
+
)
|
711 |
+
|
712 |
+
state = pjit(
|
713 |
+
init_state,
|
714 |
+
in_axis_resources=(param_spec,),
|
715 |
+
out_axis_resources=state_spec,
|
716 |
+
donate_argnums=(0,),
|
717 |
+
)(freeze(model.params))
|
718 |
+
|
719 |
else:
|
720 |
+
# restore opt_state
|
721 |
+
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
722 |
+
opt_state = from_bytes(opt_state_shape, f.read())
|
723 |
+
# need to freeze dict for pjit
|
724 |
+
opt_state = jax.tree_map(
|
725 |
+
lambda x: freeze(x) if isinstance(x, dict) else x,
|
726 |
+
opt_state,
|
727 |
+
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
728 |
+
)
|
729 |
|
730 |
+
# restore other attributes
|
731 |
+
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
732 |
+
attr_state = json.load(f)
|
733 |
+
|
734 |
+
def restore_state(params, opt_state):
|
735 |
+
return TrainState(
|
736 |
+
apply_fn=model.__call__,
|
737 |
+
tx=optimizer,
|
738 |
+
params=params,
|
739 |
+
opt_state=opt_state,
|
740 |
+
dropout_rng=dropout_rng,
|
741 |
+
**attr_state,
|
742 |
+
)
|
743 |
+
|
744 |
+
state = pjit(
|
745 |
+
restore_state,
|
746 |
+
in_axis_resources=(param_spec, opt_state_spec),
|
747 |
+
out_axis_resources=state_spec,
|
748 |
+
donate_argnums=(0, 1),
|
749 |
+
)(freeze(model.params), opt_state)
|
750 |
+
|
751 |
+
# remove opt_state from CPU
|
752 |
+
del opt_state
|
753 |
+
|
754 |
+
# free memory
|
755 |
+
del model._params
|
756 |
|
757 |
# label smoothed cross entropy
|
758 |
def loss_fn(logits, labels):
|