feat: get rid of global_step + log more metrics
dev/seq2seq/run_seq2seq_flax.py (+29 -20)
```diff
@@ -419,11 +419,10 @@ def create_learning_rate_fn(
 def wandb_log(metrics, step=None, prefix=None):
     if jax.process_index() == 0:
         log_metrics = {
-            f"{prefix}/{k}" if prefix is not None else k: v
-            for k, v in metrics.items()
+            f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
         }
         if step is not None:
-            log_metrics["train/step"] = step
+            log_metrics["train/step"] = unreplicate(step)
         wandb.log(log_metrics)
 
 
```
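The `unreplicate(step)` change is what lets `wandb_log` accept `state.step` directly: under `pmap`, every leaf of the train state carries a leading device axis, so the raw step is a per-device array rather than a scalar. A minimal sketch of that round-trip, assuming `flax.jax_utils` (the same module the script's other `unreplicate` calls come from):

```python
# A pmap-replicated value has a leading device axis; unreplicate keeps a
# single copy, which is what wandb.log expects for a step value.
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate

step = jnp.array(42)
replicated = replicate(step)    # shape (jax.local_device_count(),)
print(replicated.shape)         # e.g. (8,) on an 8-device host
print(unreplicate(replicated))  # 42, a single copy, safe to log
```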
```diff
@@ -512,10 +511,6 @@ def main():
             save_code=True,
         )
 
-        # set default x-axis as 'train/step'
-        wandb.define_metric("train/step")
-        wandb.define_metric("*", step_metric="train/step")
-
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
```
```diff
@@ -867,13 +862,27 @@ def main():
         f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(
-        f" Total train batch size (w. parallel & distributed) = {…}"
+        f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}"
     )
     logger.info(f" Total global steps = {total_steps}")
     logger.info(f" Total optimization steps = {total_optimization_steps}")
 
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
-
+
+    # set default x-axis as 'train/step'
+    wandb_log({}, step=state.step)
+    wandb.define_metric("*", step_metric="train/step")
+
+    # add interesting config parameters
+    wandb.config.update(
+        {
+            "len_train": len_train_dataset,
+            "len_eval": len_eval_dataset,
+            "batch_size_per_update": batch_size_per_update,
+            "total_steps": total_steps,
+            "total_optimization_steps": total_optimization_steps,
+        }
+    )
 
     def run_evaluation():
         # ======================== Evaluating ==============================
```
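Moving the x-axis setup here (it was deleted from the `wandb.init` block above) lets the run seed `train/step` once, via `wandb_log({}, step=state.step)`, before declaring it the step metric for everything else. `define_metric` with a `step_metric` is wandb's documented way to plot all metrics against a custom counter; a self-contained sketch with a hypothetical project name:

```python
# All metrics matched by "*" are plotted against "train/step" instead of
# wandb's internal step counter. The project name below is hypothetical.
import wandb

wandb.init(project="step-metric-demo")
wandb.define_metric("*", step_metric="train/step")
for step in range(0, 100, 10):
    wandb.log({"train/loss": 1.0 / (step + 1), "train/step": step})
wandb.finish()
```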
```diff
@@ -900,7 +909,7 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
         # log metrics
-        wandb_log(eval_metrics, step=global_step, prefix="eval")
+        wandb_log(eval_metrics, step=state.step, prefix="eval")
 
         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
```
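For reference, the `prefix="eval"` argument only rewrites the metric keys; this is exactly the comprehension from the `wandb_log` definition in the first hunk:

```python
# The key-prefixing done by wandb_log, reproduced standalone.
metrics = {"loss": 0.31, "lr": 3e-4}
prefix = "eval"
print({f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()})
# {'eval/loss': 0.31, 'eval/lr': 0.0003}
```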
```diff
@@ -923,6 +932,7 @@ def main():
             tokenizer.save_pretrained(training_args.output_dir)
 
             # save state
+            # TODO: maybe we should just save the full state object without params
            state = unreplicate(state)
            with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
                f.write(to_bytes(state.opt_state))
```
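The `opt_state.msgpack` dump uses flax's generic pytree serialization, and restoring it later needs a template pytree of the same structure. A toy round-trip, assuming `flax.serialization` (where the script's `to_bytes` comes from) and a made-up stand-in for the real optimizer state:

```python
# from_bytes takes a template pytree plus the raw bytes and returns the
# restored pytree; the template only supplies structure and dtypes.
import jax.numpy as jnp
from flax.serialization import from_bytes, to_bytes

opt_state = {"count": jnp.array(3), "mu": jnp.zeros(4)}  # stand-in pytree
raw = to_bytes(opt_state)
restored = from_bytes(opt_state, raw)
```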
```diff
@@ -978,7 +988,7 @@ def main():
 
     for epoch in epochs:
         # ======================== Training ================================
-        wandb_log({"train/epoch": epoch}, step=global_step)
+        wandb_log({"train/epoch": epoch}, step=state.step)
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
```
```diff
@@ -999,21 +1009,20 @@ def main():
             leave=False,
             total=steps_per_epoch,
         ):
-            global_step += 1
             state, train_metric = p_train_step(state, batch)
 
-            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
+            if state.step % data_args.log_interval == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
+                wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
 
-            if training_args.eval_steps and global_step % training_args.eval_steps == 0:
+            if training_args.eval_steps and state.step % training_args.eval_steps == 0:
                 run_evaluation()
 
-            if global_step % data_args.save_model_steps == 0:
-                run_save_model(state, global_step, epoch)
+            if state.step % data_args.save_model_steps == 0:
+                run_save_model(state, state.step, epoch)
 
         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
+        wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
 
         train_metric = unreplicate(train_metric)
         epochs.write(
```
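Dropping the hand-maintained `global_step` is safe because `flax.training.train_state.TrainState` increments its own `step` on every `apply_gradients` call, which happens inside `p_train_step`. A minimal sketch with a made-up one-parameter model:

```python
# TrainState.apply_gradients bumps .step itself, so a separate counter
# would only drift from the real number of optimizer updates.
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

state = TrainState.create(
    apply_fn=lambda params, x: x * params["w"],  # toy model
    params={"w": jnp.ones(())},
    tx=optax.sgd(0.1),
)
print(state.step)  # 0
state = state.apply_gradients(grads={"w": jnp.ones(())})
print(state.step)  # 1, incremented automatically
```

Under `pmap` this counter lives on every device, which is exactly why `wandb_log` now passes it through `unreplicate`.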
```diff
@@ -1023,8 +1032,8 @@ def main():
         # Final evaluation
         eval_metrics = run_evaluation()
 
-        # save checkpoint after each epoch
-        run_save_model(state, global_step, epoch, eval_metrics)
+        # save checkpoint after each epoch
+        run_save_model(state, state.step, epoch, eval_metrics)
 
 
 if __name__ == "__main__":
```