Spaces:
Running
Running
feat: no gradient checkpointing for params init
Browse files- tools/train/train.py +7 -6
tools/train/train.py
CHANGED
@@ -531,8 +531,6 @@ def main():
|
|
531 |
# Set up our new model config
|
532 |
if model_args.config_name:
|
533 |
config = DalleBartConfig.from_pretrained(model_args.config_name)
|
534 |
-
# initializing params with gradient checkpointing create issues
|
535 |
-
config.gradient_checkpointing = False
|
536 |
else:
|
537 |
config = None
|
538 |
|
@@ -545,6 +543,9 @@ def main():
|
|
545 |
dtype=getattr(jnp, model_args.dtype),
|
546 |
abstract_init=True,
|
547 |
load_on_cpu=True,
|
|
|
|
|
|
|
548 |
)
|
549 |
else:
|
550 |
model = DalleBart(
|
@@ -552,6 +553,7 @@ def main():
|
|
552 |
seed=training_args.seed_model,
|
553 |
dtype=getattr(jnp, model_args.dtype),
|
554 |
load_on_cpu=True,
|
|
|
555 |
)
|
556 |
|
557 |
# update model config per training args
|
@@ -559,11 +561,10 @@ def main():
|
|
559 |
# This is still considered correctly during training as function is pjitted
|
560 |
model.config.gradient_checkpointing = training_args.gradient_checkpointing
|
561 |
|
562 |
-
# eval model cannot use remat
|
563 |
-
eval_config = copy.deepcopy(model.config)
|
564 |
-
eval_config.gradient_checkpointing = False
|
565 |
-
|
566 |
if training_args.gradient_checkpointing:
|
|
|
|
|
|
|
567 |
eval_model = DalleBart(
|
568 |
eval_config,
|
569 |
seed=training_args.seed_model,
|
|
|
531 |
# Set up our new model config
|
532 |
if model_args.config_name:
|
533 |
config = DalleBartConfig.from_pretrained(model_args.config_name)
|
|
|
|
|
534 |
else:
|
535 |
config = None
|
536 |
|
|
|
543 |
dtype=getattr(jnp, model_args.dtype),
|
544 |
abstract_init=True,
|
545 |
load_on_cpu=True,
|
546 |
+
# initializing params with gradient checkpointing creates issues
|
547 |
+
# we correctly set it later per training_args
|
548 |
+
gradient_checkpointing=False,
|
549 |
)
|
550 |
else:
|
551 |
model = DalleBart(
|
|
|
553 |
seed=training_args.seed_model,
|
554 |
dtype=getattr(jnp, model_args.dtype),
|
555 |
load_on_cpu=True,
|
556 |
+
gradient_checkpointing=False,
|
557 |
)
|
558 |
|
559 |
# update model config per training args
|
|
|
561 |
# This is still considered correctly during training as function is pjitted
|
562 |
model.config.gradient_checkpointing = training_args.gradient_checkpointing
|
563 |
|
|
|
|
|
|
|
|
|
564 |
if training_args.gradient_checkpointing:
|
565 |
+
# eval model cannot use remat
|
566 |
+
eval_config = copy.deepcopy(model.config)
|
567 |
+
eval_config.gradient_checkpointing = False
|
568 |
eval_model = DalleBart(
|
569 |
eval_config,
|
570 |
seed=training_args.seed_model,
|