feat(train): restore opt_state efficiently
tools/train/train.py (+40 -35)
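Resuming a run used to go through TrainState.restore_state, which deserialized opt_state.msgpack on top of an optimizer state that had already been initialized once. The optimizer state is now deserialized straight into the shape-only pytree produced by jax.eval_shape and handed to the pjit'ed init_state together with the params, with both arguments donated, so it is only ever materialized once and directly in its partitioned layout. The tokenizer is also switched from AutoTokenizer to DalleBartTokenizer.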
@@ -42,7 +42,7 @@ from flax.training.common_utils import onehot, stack_forest
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit
 from tqdm import tqdm
-from transformers import AutoTokenizer, HfArgumentParser
+from transformers import HfArgumentParser
 
 import wandb
 from dalle_mini.data import Dataset
@@ -375,23 +375,6 @@ class TrainState(train_state.TrainState):
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen
 
-    def restore_state(self, artifact_dir):
-        # restore optimizer state
-        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
-            new_opt_state = from_bytes(self.opt_state, f.read())
-
-        # restore other parameters
-        with (Path(artifact_dir) / "training_state.json").open("r") as f:
-            training_state = json.load(f)
-
-        # replace state
-        return self.replace(
-            opt_state=new_opt_state,
-            step=training_state["step"],
-            train_time=training_state["train_time"],
-            train_samples=training_state["train_samples"],
-        )
-
 
 class MetricsLogger:
     def __init__(self, state):
@@ -528,7 +511,7 @@ def main():
 
     # Load tokenizer
    if model_args.tokenizer_name is not None:
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = DalleBartTokenizer.from_pretrained(
            model_args.tokenizer_name, use_fast=True
        )
    else:
@@ -648,8 +631,7 @@ def main():
    )
 
    # get opt_state shape without actual init
-
-    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), param_shape)
+    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
 
    # get PartitionSpec for model params
    param_spec = set_partitions(model.params)
@@ -692,28 +674,51 @@ def main():
        tx=optimizer,
    )
 
+    opt_state, attr_state = None, None
+    if training_args.resume_from_checkpoint is not None:
+        # restore opt_state
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
+            opt_state = from_bytes(opt_state_shape, f.read())
+        # need to freeze dict for pjit
+        opt_state = jax.tree_map(
+            lambda x: freeze(x) if isinstance(x, dict) else x,
+            opt_state,
+            is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
+        )
+        # restore other attributes
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
+            attr_state = json.load(f)
+
    # create training state
-    def init_state(params):
-        state = TrainState.create(
-            apply_fn=model.__call__,
-            tx=optimizer,
-            params=freeze(params),
-            dropout_rng=dropout_rng,
-        )
+    def init_state(params, opt_state):
+        if training_args.resume_from_checkpoint is None:
+            state = TrainState.create(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                dropout_rng=dropout_rng,
+            )
+        else:
+            state = TrainState(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                opt_state=opt_state,
+                dropout_rng=dropout_rng,
+                **attr_state,
+            )
        return state
 
    with maps.mesh(mesh.devices, mesh.axis_names):
        state = pjit(
            init_state,
-            in_axis_resources=param_spec,
+            in_axis_resources=(param_spec, opt_state_spec),
            out_axis_resources=state_spec,
-            donate_argnums=(0,),
-        )(freeze(model.params))
+            donate_argnums=(0, 1),
+        )(freeze(model.params), opt_state)
 
-    if training_args.resume_from_checkpoint is not None:
-        # restore optimizer state and other parameters
-        # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
-        state = state.restore_state(artifact_dir)
+    # free memory from large parameters
+    del model._params, opt_state
 
    # label smoothed cross entropy
    def loss_fn(logits, labels):
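The other half of the change is feeding that restored state into the pjit'ed init_state with donate_argnums=(0, 1). A single-axis sketch, assuming the same jax.experimental API (maps/pjit, pre jax.sharding) that this train.py imports, and replicated (None) partitioning in place of the real param_spec / opt_state_spec:

import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.experimental import maps
from jax.experimental.pjit import pjit

params = {"w": jnp.ones((8, 8))}
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)  # stands in for the restored state

# one mesh axis over all available devices
mesh = maps.Mesh(np.asarray(jax.devices()), ("dp",))

def init_state(params, opt_state):
    # train.py builds a TrainState here; passing both through is enough
    # to show the donation mechanics
    return params, opt_state

with maps.mesh(mesh.devices, mesh.axis_names):
    # donate_argnums lets XLA reuse the input buffers for the outputs, so
    # params and opt_state never exist twice on the devices
    state = pjit(
        init_state,
        in_axis_resources=(None, None),
        out_axis_resources=None,
        donate_argnums=(0, 1),
    )(params, opt_state)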
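Design note: the resume branch constructs TrainState directly rather than via TrainState.create, since create calls tx.init and would allocate a fresh optimizer state again. Donating both arguments lets XLA reuse the incoming buffers for the TrainState output, and the closing del model._params, opt_state drops the last host-side references (per the "# free memory from large parameters" comment) so those buffers can actually be reclaimed.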