Merge pull request #90 from borisdayma/feat-new

dev/seq2seq/run_seq2seq_flax.py  CHANGED  (+18 -31)
@@ -100,12 +100,6 @@ class ModelArguments:
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained tokenizer name or path if not the same as model_name"
-        },
-    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
@@ -422,7 +416,7 @@ def wandb_log(metrics, step=None, prefix=None):
             f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
         }
         if step is not None:
-            log_metrics["train/step"] =
+            log_metrics["train/step"] = step
         wandb.log(log_metrics)
 
 
@@ -534,11 +528,6 @@ def main():
         )
 
     else:
-        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-            model_args.model_name_or_path,
-            seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype),
-        )
         # Set up our new model config
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
         config.tie_word_embeddings = False
@@ -563,11 +552,6 @@ def main():
         config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
     )
 
-    # Use pre-trained weights for encoder
-    model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
-    model.params["model"]["shared"] = base_model.params["model"]["shared"]
-    del base_model
-
     # Load tokenizer if it has not been set
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(
@@ -862,7 +846,7 @@ def main():
         f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(
-        f" Total train batch size (w. parallel &
+        f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
     logger.info(f" Total global steps = {total_steps}")
     logger.info(f" Total optimization steps = {total_optimization_steps}")
@@ -870,7 +854,7 @@ def main():
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 
     # set default x-axis as 'train/step'
-    wandb_log({}, step=state.step)
+    wandb_log({}, step=unreplicate(state.step))
     wandb.define_metric("*", step_metric="train/step")
 
     # add interesting config parameters
@@ -909,7 +893,7 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
         # log metrics
-        wandb_log(eval_metrics, step=state.step, prefix="eval")
+        wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
 
         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -943,6 +927,10 @@ def main():
 
         # save to W&B
         if data_args.log_model:
+            # save some space
+            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+            c.cleanup(wandb.util.from_human_size("5GB"))
+
             metadata = {"step": step, "epoch": epoch}
             if eval_metrics is not None:
                 metadata["eval/loss"] = eval_metrics["loss"]
@@ -970,11 +958,8 @@ def main():
             artifact.add_file(
                 str(Path(training_args.output_dir) / "training_state.json")
             )
-            wandb.run.log_artifact(artifact)
 
-
-            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
-            c.cleanup(wandb.util.from_human_size("5GB"))
+            wandb.run.log_artifact(artifact)
 
         # save to the hub
         if training_args.push_to_hub:
@@ -988,7 +973,8 @@ def main():
 
     for epoch in epochs:
         # ======================== Training ================================
-
+        step = unreplicate(state.step)
+        wandb_log({"train/epoch": epoch}, step=step)
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
@@ -1010,19 +996,20 @@ def main():
                 total=steps_per_epoch,
             ):
                 state, train_metric = p_train_step(state, batch)
+                step = unreplicate(state.step)
 
-                if
+                if step % data_args.log_interval == 0 and jax.process_index() == 0:
                     # log metrics
-                    wandb_log(unreplicate(train_metric), step=
+                    wandb_log(unreplicate(train_metric), step=step, prefix="train")
 
-                if training_args.eval_steps and
+                if training_args.eval_steps and step % training_args.eval_steps == 0:
                     run_evaluation()
 
-                if
-                    run_save_model(state,
+                if step % data_args.save_model_steps == 0:
+                    run_save_model(state, step, epoch)
 
         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=
+        wandb_log(unreplicate(train_metric), step=step, prefix="train")
 
         train_metric = unreplicate(train_metric)
         epochs.write(
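
For reference, the step-handling pattern the new training loop settles on looks roughly like the sketch below: unreplicate state.step once per batch on the host, attach it to every metric dict as "train/step", and let wandb.define_metric use it as the shared x-axis. This is a minimal standalone sketch rather than the script itself; the project name and log_interval value are illustrative (log_interval stands in for data_args.log_interval), and a dummy replicated counter replaces p_train_step/state.step.

import jax
import jax.numpy as jnp
import wandb
from flax.jax_utils import replicate, unreplicate


def wandb_log(metrics, step=None, prefix=None):
    # Same shape as the helper in run_seq2seq_flax.py: prefix the keys and
    # attach the host-side step so every chart shares the "train/step" x-axis.
    if jax.process_index() == 0:
        log_metrics = {
            f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
        }
        if step is not None:
            log_metrics["train/step"] = step
        wandb.log(log_metrics)


wandb.init(project="demo")  # illustrative project name
wandb_log({}, step=0)  # seed the step metric, as the script does before define_metric
wandb.define_metric("*", step_metric="train/step")

log_interval = 40  # stands in for data_args.log_interval
step = replicate(jnp.zeros((), dtype=jnp.int32))  # mimics state.step (one copy per device)

for _ in range(200):
    step = step + 1  # p_train_step would advance state.step on every device
    host_step = unreplicate(step).item()  # pull a single host value, as the new loop does
    if host_step % log_interval == 0 and jax.process_index() == 0:
        wandb_log({"loss": 0.0}, step=host_step, prefix="train")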
|
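
The checkpoint-to-W&B path follows the order the diff establishes: trim the local artifact cache before building a new artifact, attach the files, then call log_artifact once everything is added. A rough sketch under the wandb version the script targets; the project name, artifact type, and output path are illustrative, while get_artifacts_cache() and from_human_size are the same internal helpers the script itself calls.

from pathlib import Path

import wandb

run = wandb.init(project="demo")  # illustrative project name

# Keep the local artifact cache bounded before writing a new checkpoint
# (internal wandb helpers, used exactly as in the script above).
cache = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
cache.cleanup(wandb.util.from_human_size("5GB"))

output_dir = Path("output")  # stands in for training_args.output_dir
output_dir.mkdir(exist_ok=True)
(output_dir / "training_state.json").write_text('{"step": 0}')  # dummy file for the sketch

artifact = wandb.Artifact(
    f"model-{run.id}",
    type="seq2seq_flax_model",  # illustrative artifact type
    metadata={"step": 0, "epoch": 0},
)
artifact.add_file(str(output_dir / "training_state.json"))
wandb.run.log_artifact(artifact)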