Spaces:
Running
Running
feat(train): overhead from 70% to 1% 🥳
Browse files- tools/train/train.py +21 -5
tools/train/train.py
CHANGED
@@ -777,9 +777,10 @@ def main():
|
|
777 |
def train_step(state, batch, delta_time):
|
778 |
# batch is (gradient_accumulation_steps, minibatch_size, ...)
|
779 |
# check correct batch shape during compilation
|
780 |
-
assert batch["labels"].shape[0:
|
781 |
training_args.gradient_accumulation_steps,
|
782 |
-
|
|
|
783 |
), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
|
784 |
|
785 |
# get a minibatch (one gradient accumulation slice)
|
@@ -801,13 +802,27 @@ def main():
|
|
801 |
grad_fn = jax.value_and_grad(compute_loss)
|
802 |
|
803 |
def loss_and_grad(grad_idx, dropout_rng):
|
|
|
804 |
minibatch = get_minibatch(batch, grad_idx)
|
805 |
# ensure batch is sharded over devices
|
806 |
minibatch = jax.tree_map(
|
807 |
lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
|
808 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
809 |
# return loss and grads
|
810 |
-
return
|
811 |
|
812 |
# create a new rng
|
813 |
dropout_rng, _ = jax.random.split(state.dropout_rng)
|
@@ -1061,12 +1076,13 @@ def main():
|
|
1061 |
delta_time = new_time - last_time
|
1062 |
last_time = new_time
|
1063 |
|
1064 |
-
# reshape data into (gradient_accumulation_steps,
|
1065 |
batch = jax.tree_map(
|
1066 |
lambda x: x.reshape(
|
1067 |
(
|
1068 |
training_args.gradient_accumulation_steps,
|
1069 |
-
|
|
|
1070 |
)
|
1071 |
+ x.shape[1:]
|
1072 |
),
|
|
|
777 |
def train_step(state, batch, delta_time):
|
778 |
# batch is (gradient_accumulation_steps, minibatch_size, ...)
|
779 |
# check correct batch shape during compilation
|
780 |
+
assert batch["labels"].shape[0:3] == (
|
781 |
training_args.gradient_accumulation_steps,
|
782 |
+
training_args.dp_devices,
|
783 |
+
training_args.per_device_train_batch_size,
|
784 |
), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
|
785 |
|
786 |
# get a minibatch (one gradient accumulation slice)
|
|
|
802 |
grad_fn = jax.value_and_grad(compute_loss)
|
803 |
|
804 |
def loss_and_grad(grad_idx, dropout_rng):
|
805 |
+
# minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
|
806 |
minibatch = get_minibatch(batch, grad_idx)
|
807 |
# ensure batch is sharded over devices
|
808 |
minibatch = jax.tree_map(
|
809 |
lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
|
810 |
)
|
811 |
+
# calculate loss and grads independently per dp_device
|
812 |
+
loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
|
813 |
+
state.params, minibatch, dropout_rng
|
814 |
+
)
|
815 |
+
# ensure they are sharded over devices
|
816 |
+
loss_grads = jax.tree_map(
|
817 |
+
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
818 |
+
loss_grads,
|
819 |
+
)
|
820 |
+
|
821 |
+
# average across all devices
|
822 |
+
loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
|
823 |
+
|
824 |
# return loss and grads
|
825 |
+
return loss_grads
|
826 |
|
827 |
# create a new rng
|
828 |
dropout_rng, _ = jax.random.split(state.dropout_rng)
|
|
|
1076 |
delta_time = new_time - last_time
|
1077 |
last_time = new_time
|
1078 |
|
1079 |
+
# reshape data into (gradient_accumulation_steps, dp_devices, batch_per_dp, ...)
|
1080 |
batch = jax.tree_map(
|
1081 |
lambda x: x.reshape(
|
1082 |
(
|
1083 |
training_args.gradient_accumulation_steps,
|
1084 |
+
training_args.dp_devices,
|
1085 |
+
training_args.per_device_train_batch_size,
|
1086 |
)
|
1087 |
+ x.shape[1:]
|
1088 |
),
|