feat(train): merge logged dict
tools/train/train.py (+8 -8)
@@ -797,7 +797,7 @@ def main():
 
     # init variables
     last_time = time.perf_counter()
-    …
+    train_metrics = None
 
     for epoch in epochs:
         state.replace(epoch=jax_utils.replicate(epoch))
@@ -821,20 +821,20 @@
             last_time = new_time
 
             # train step
-            state, …
+            state, train_metrics = p_train_step(
                 state, batch, jax_utils.replicate(delta_time)
             )
             step = unreplicate(state.step)
 
             if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                 # log metrics
-                …
+                metrics = unreplicate(train_metrics)
                 # log state parameters
                 state_dict = {
                     k.split("_")[-1]: unreplicate(getattr(state, k))
                     for k in ["epoch", "train_time", "train_samples"]
                 }
-                wandb_log(state_dict, step=step, prefix="train")
+                wandb_log({**metrics, **state_dict}, step=step, prefix="train")
 
             eval_metrics = None
             if training_args.eval_steps and step % training_args.eval_steps == 0:
@@ -844,12 +844,12 @@
                 run_save_model(state, eval_metrics)
 
         # log final train metrics
-        if …
-            …
-            wandb_log(…
+        if train_metrics is not None:
+            train_metrics = unreplicate(train_metrics)
+            wandb_log(train_metrics, step=step, prefix="train")
 
             epochs.write(
-                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {…
+                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
             )
 
     # Final evaluation

(The removed lines marked with … were truncated in the original page and are shown only as far as they are recoverable.)
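Judging from the commit message and the surviving left-hand fragments, the patch routes the train metrics and the state scalars through a single wandb_log call per logging step instead of logging them separately, after pulling one copy of the pmap-replicated train_metrics off the device axis. The sketch below is illustrative only: wandb_log here is a stand-in that prints prefixed keys (the real helper in tools/train/train.py is known only from its call sites in the diff), and the replicated metrics and state values are toy data rather than the actual train state.

# Illustrative sketch of the merged-logging pattern; not the train.py implementation.
from flax import jax_utils
import jax.numpy as jnp


def wandb_log(metrics, step=None, prefix=None):
    # Stand-in for the wandb_log helper called in the diff: prefix every key
    # (e.g. "train/loss") and emit one record per step. Printing instead of
    # wandb.log so the sketch runs without an active wandb run.
    print(step, {f"{prefix}/{k}" if prefix else k: v for k, v in metrics.items()})


# Pretend p_train_step returned metrics replicated across the local devices.
replicated_metrics = jax_utils.replicate(
    {"loss": jnp.asarray(2.5), "learning_rate": jnp.asarray(1e-4)}
)

# What the new lines 831/837 do: take one copy off the device axis, merge it
# with the state scalars, and log everything in a single call.
metrics = jax_utils.unreplicate(replicated_metrics)
# Toy stand-ins for the unreplicate(getattr(state, k)) values built in the diff.
state_dict = {"epoch": 0, "time": 12.3, "samples": 256}
wandb_log({**metrics, **state_dict}, step=100, prefix="train")

Merging into one dict keeps every metric for a given step in the same wandb record, which is also why the final per-epoch log reuses the same wandb_log(..., prefix="train") call on the unreplicated train_metrics.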