feat: log num_parameters early
tools/train/train.py  CHANGED  (+32 -31)
@@ -558,6 +558,35 @@ def main():
     )
     num_params = model.num_params
 
+    logger.info("***** Running training *****")
+    logger.info(f" Num examples = {len_train_dataset}")
+    logger.info(f" Num Epochs = {num_epochs}")
+    logger.info(
+        f" Batch size per device = {training_args.per_device_train_batch_size}"
+    )
+    logger.info(f" Number of devices = {jax.device_count()}")
+    logger.info(
+        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
+    )
+    logger.info(f" Batch size per update = {batch_size_per_step}")
+    logger.info(f" Model parameters = {num_params:,}")
+
+    # create wandb run
+    if jax.process_index() == 0:
+        # set default x-axis as 'train/step'
+        wandb.define_metric("*", step_metric="train/step")
+
+        # add interesting config parameters
+        wandb.config.update(
+            {
+                "len_train_dataset": len_train_dataset,
+                "len_eval_dataset": len_eval_dataset,
+                "batch_size_per_step": batch_size_per_step,
+                "num_params": num_params,
+                "num_devices": jax.device_count(),
+            }
+        )
+
     # Create learning rate schedule
     def create_learning_rate_fn() -> Callable[[int], jnp.array]:
         """Create the learning rate function."""

@@ -915,42 +944,14 @@ def main():
         out_axis_resources=None,
     )
 
-    logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len_train_dataset}")
-    logger.info(f" Num Epochs = {num_epochs}")
-    logger.info(
-        f" Batch size per device = {training_args.per_device_train_batch_size}"
-    )
-    logger.info(f" Number of devices = {jax.device_count()}")
-    logger.info(
-        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
-    )
-    logger.info(f" Batch size per update = {batch_size_per_step}")
-    logger.info(f" Model parameters = {num_params:,}")
-    epochs = tqdm(
-        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
-    )
-
     # init variables
     last_time = time.perf_counter()
     train_metrics = None
     step = int(state.step)
     metrics_logger = MetricsLogger(step)
-
-
-
-    wandb.define_metric("*", step_metric="train/step")
-
-    # add interesting config parameters
-    wandb.config.update(
-        {
-            "len_train_dataset": len_train_dataset,
-            "len_eval_dataset": len_eval_dataset,
-            "batch_size_per_step": batch_size_per_step,
-            "num_params": num_params,
-            "num_devices": jax.device_count(),
-        }
-    )
+    epochs = tqdm(
+        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
+    )
 
     def run_evaluation():
         # ======================== Evaluating ==============================
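Note on the relocated block: wandb.define_metric("*", step_metric="train/step") tells Weights & Biases to plot every later metric against the custom train/step axis instead of its internal step counter, and wandb.config.update(...) attaches run-level configuration such as the parameter count, now available before training starts. The stand-alone sketch below shows those calls in isolation; the project name, the offline mode, and all literal values are placeholders, not values taken from train.py.

import wandb

# Placeholder values standing in for the script's len_train_dataset, num_params, etc.
len_train_dataset = 1_000_000
batch_size_per_step = 512
num_params = 437_000_000

# Offline mode so the sketch runs without a W&B account; train.py sets up its run elsewhere.
wandb.init(project="demo-run", mode="offline")

# Plot every metric against the custom "train/step" axis rather than wandb's internal step.
wandb.define_metric("*", step_metric="train/step")

# Record run-level configuration once; it appears in the run's config panel.
wandb.config.update(
    {
        "len_train_dataset": len_train_dataset,
        "batch_size_per_step": batch_size_per_step,
        "num_params": num_params,
    }
)

# Metrics logged together with "train/step" are then charted against that axis.
wandb.log({"train/step": 100, "train/loss": 2.31})
wandb.finish()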
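On the commit title itself: the point of the change is that num_params is known and logged before training begins. The model.num_params attribute is defined elsewhere in the repository; the snippet below is only a minimal sketch, assuming the usual JAX/Flax setup where parameters live in a pytree of arrays, of how such a count can be obtained. The params dictionary and its shapes are hypothetical.

import jax
import jax.numpy as jnp

# Hypothetical params pytree; in train.py the count comes from model.num_params,
# whose definition lives outside this diff.
params = {
    "encoder": {"kernel": jnp.zeros((256, 256)), "bias": jnp.zeros((256,))},
    "lm_head": {"kernel": jnp.zeros((256, 1000))},
}

def count_params(tree) -> int:
    # Sum the element count over every array leaf of the pytree.
    return sum(x.size for x in jax.tree_util.tree_leaves(tree))

num_params = count_params(params)
print(f"Model parameters = {num_params:,}")  # mirrors the logger.info line in the diff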