feat(train): progress on pjit
- src/dalle_mini/data.py +0 -2
- tools/train/train.py +34 -31
src/dalle_mini/data.py CHANGED
@@ -191,7 +191,6 @@ class Dataset:
                     lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
                     batch,
                 )
-                batch = shard(batch)
                 yield batch
 
         def _dataloader_datasets_streaming(
@@ -232,7 +231,6 @@ class Dataset:
                             ),
                             batch,
                         )
-                        batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
                 first_loop = False
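The two dropped `batch = shard(batch)` calls are the `pmap`-era step: `shard` reshapes each array to `(local_devices, per_device_batch, ...)` so `jax.pmap` can split it. With `pjit`, the dataloader can keep yielding the flat global batch and let the mesh's `batch` axis decide how rows land on devices. A minimal sketch of the two styles, assuming the same `jax.experimental.pjit` API this commit imports; the batch and `step` below are toy stand-ins, not the repo's training step:

import jax
import numpy as np
from flax.training.common_utils import shard
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# toy global batch of 8 examples (hypothetical data, for illustration only)
batch = {"input_ids": np.arange(32, dtype=np.int32).reshape(8, 4)}

# pmap style (what the removed shard() supported): reshape to
# (num_devices, per_device_batch, ...) before feeding jax.pmap
sharded = shard(batch)  # requires 8 % jax.local_device_count() == 0

# pjit style (what this commit moves toward): keep the global (8, 4) batch
# and let the mesh's "batch" axis partition the leading dimension
devices = np.asarray(jax.devices()).reshape(-1, 1)
mesh = maps.Mesh(devices, ("batch", "mp"))

def step(b):
    # placeholder computation standing in for a training step
    return jax.tree_map(lambda x: x.sum(), b)

p_step = pjit(step, in_axis_resources=PartitionSpec("batch"), out_axis_resources=None)

with maps.mesh(mesh.devices, mesh.axis_names):
    # the global batch size must be divisible by the "batch" mesh axis size
    out = p_step(batch)

The partitioning decision thus moves out of the data pipeline and into the `pjit`-compiled step, which is why the dataloader no longer needs `shard`.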
tools/train/train.py CHANGED
@@ -34,13 +34,11 @@ import numpy as np
 import optax
 import transformers
 from datasets import Dataset
-from distributed_shampoo import GraftingType, distributed_shampoo
-from flax import
-from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
-from flax.jax_utils import unreplicate
+from distributed_shampoo import GraftingType, distributed_shampoo
+from flax.core.frozen_dict import freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import get_metrics, onehot
+from flax.training.common_utils import get_metrics, onehot
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit
 from tqdm import tqdm
@@ -402,14 +400,14 @@ class MetricsLogger:
 
     def get_all_train_metrics(self, train_metrics, state):
         """Make a dict of training metrics to be logged"""
-        metrics =
+        metrics = train_metrics
         # get state parameters
         state_dict = {
-            k.split("_")[-1]:
+            k.split("_")[-1]: getattr(state, k)
             for k in ["epoch", "train_time", "train_samples"]
         }
         # timing metrics
-        new_step = int(
+        new_step = int(state.step)
         new_time = time.perf_counter()
         if new_step > self.step:
             time_per_step = (new_time - self.time) / (new_step - self.step)
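The filled-in lines amount to: start from the raw `train_metrics`, pull `epoch`/`train_time`/`train_samples` off the state with `getattr`, and derive a seconds-per-step figure from the step counter. The timing bookkeeping in isolation looks roughly like this standalone sketch (a hypothetical `StepTimer`, not the repo's `MetricsLogger`):

import time

class StepTimer:
    """Tracks average wall-clock seconds per optimizer step between calls."""

    def __init__(self):
        self.step = 0
        self.time = time.perf_counter()

    def update(self, new_step):
        new_time = time.perf_counter()
        metrics = {}
        if new_step > self.step:
            # average time per step since the previous update
            metrics["time_per_step"] = (new_time - self.time) / (new_step - self.step)
            self.step, self.time = new_step, new_time
        return metrics

timer = StepTimer()
# ... run some training steps, then:
print(timer.update(new_step=10))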
@@ -551,7 +549,7 @@ def main():
 
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed_model)
-    rng,
+    rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
     num_epochs = training_args.num_train_epochs
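The added split gives dropout its own key: JAX PRNG state is explicit, so deriving `dropout_rng` from the model seed rather than reusing `rng` directly keeps the two random streams independent. A toy illustration (seed 0 is arbitrary, not the script's `seed_model`):

import jax

rng = jax.random.PRNGKey(0)               # stands in for seed_model
rng, dropout_rng = jax.random.split(rng)  # independent key for dropout
# dropout_rng can now be threaded through training steps while rng
# remains available for other one-off uses (e.g. data shuffling)
noise = jax.random.normal(dropout_rng, (2, 3))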
@@ -681,34 +679,39 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    #
+    # Setup train state
+    def init_state(params, opt_state):
+        return TrainState(
+            apply_fn=model.__call__,
+            tx=optimizer,
+            params=params,
+            opt_state=opt_state,
+            dropout_rng=dropout_rng,
+            step=0,
+        )
+
+    state_spec = init_state(param_spec, opt_state_spec)
+    state_spec = state_spec.replace(
+        dropout_rng=None,
+        step=None,
+        epoch=None,
+        train_time=None,
+        train_samples=None,
+    )
+
     with maps.mesh(mesh.devices, mesh.axis_names):
+        # move params & init opt_state over specified devices
         params, opt_state = pjit(
             lambda x: (x, optimizer.init(x)),
             in_axis_resources=None,
             out_axis_resources=(param_spec, opt_state_spec),
         )(freeze(model.params))
-
-
-
-
-
-            opt_state
-            tx=optimizer,
-            dropout_rng=dropout_rng,
-            step=0,
-        )
-
-    # create PartitionSpec for state
-    state_spec = {
-        "params": param_spec,
-        "opt_state": opt_state_spec,
-        "dropout_rng": PartitionSpec("batch", None),
-        "epoch": None,
-        "step": None,
-        "train_samples": None,
-        "train_time": None,
-    }
+        # create training state
+        state = pjit(
+            init_state,
+            in_axis_resources=(param_spec, opt_state_spec),
+            out_axis_resources=state_spec,
+        )(params, opt_state)
 
     if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
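The interesting move in this hunk is building `state_spec` by calling `init_state` on the partition specs themselves: because `TrainState` is a pytree, feeding it `param_spec`/`opt_state_spec` yields a spec tree with exactly the structure `pjit` expects for `out_axis_resources`, after which the scalar-like fields are reset to `None` (replicated). A condensed, self-contained sketch of the pattern, assuming the same `jax.experimental.pjit`/Flax APIs the diff imports; the parameters, specs, and optimizer below are toy stand-ins, not dalle-mini's:

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core.frozen_dict import freeze
from flax.training import train_state
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# toy parameters and matching partition specs (kernel sharded on "mp")
params = freeze({"dense": {"kernel": np.ones((4, 4), np.float32),
                           "bias": np.zeros((4,), np.float32)}})
param_spec = freeze({"dense": {"kernel": PartitionSpec(None, "mp"),
                               "bias": None}})
optimizer = optax.adam(1e-3)
opt_state_spec = None  # replicate the whole optimizer state in this toy setup


class TrainState(train_state.TrainState):
    # extra field, mirroring the custom TrainState the training script defines
    dropout_rng: jnp.ndarray = None


def init_state(params, opt_state):
    return TrainState(
        apply_fn=lambda *a, **kw: None,  # placeholder apply_fn for the sketch
        tx=optimizer,
        params=params,
        opt_state=opt_state,
        dropout_rng=jax.random.PRNGKey(0),
        step=0,
    )


# calling init_state on the specs gives a TrainState-shaped tree of
# PartitionSpecs; scalar-ish fields become None, i.e. replicated
state_spec = init_state(param_spec, opt_state_spec).replace(
    dropout_rng=None, step=None
)

devices = np.asarray(jax.devices()).reshape(-1, 1)
mesh = maps.Mesh(devices, ("batch", "mp"))

with maps.mesh(mesh.devices, mesh.axis_names):
    # place params on the mesh and build the optimizer state there
    params, opt_state = pjit(
        lambda p: (p, optimizer.init(p)),
        in_axis_resources=None,
        out_axis_resources=(param_spec, opt_state_spec),
    )(params)
    # assemble the full training state with matching in/out specs
    state = pjit(
        init_state,
        in_axis_resources=(param_spec, opt_state_spec),
        out_axis_resources=state_spec,
    )(params, opt_state)

The payoff is that both `pjit` calls produce values that are already laid out according to `state_spec`, so the full training state never has to be materialized unpartitioned on a single host.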