fix(seq2seq): memory issue
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -100,12 +100,6 @@ class ModelArguments:
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained tokenizer name or path if not the same as model_name"
-        },
-    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
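Note: this hunk drops the tokenizer_name argument from ModelArguments. For orientation, below is a minimal sketch (not taken from this file) of how such dataclass fields are typically declared and parsed with HfArgumentParser; only the cache_dir field mirrors the diff, the other field name and the parsing entry point are illustrative assumptions.

# Minimal sketch of the ModelArguments pattern; only cache_dir mirrors the diff,
# model_name_or_path and the __main__ block are illustrative assumptions.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained model or model identifier"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store downloaded pretrained models"},
    )


if __name__ == "__main__":
    # parse_args_into_dataclasses returns one instance per dataclass passed in
    (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
    print(model_args)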
@@ -539,11 +533,6 @@ def main():
         )
 
     else:
-        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-            model_args.model_name_or_path,
-            seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype),
-        )
         # Set up our new model config
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
         config.tie_word_embeddings = False
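With base_model no longer loaded in the else branch, the new model is built straight from a BartConfig. A minimal sketch of that pattern, assuming transformers' Flax BART classes; the checkpoint name and seed are placeholders, not values from this commit.

# Sketch: instantiate a Flax BART model from a config only, so no second set of
# pretrained weights has to be held in host memory. Checkpoint name and seed
# are placeholders, not values from this commit.
import jax
import jax.numpy as jnp
from transformers import BartConfig, FlaxBartForConditionalGeneration

config = BartConfig.from_pretrained("facebook/bart-large-cnn")  # placeholder checkpoint
config.tie_word_embeddings = False  # untie input/output embeddings, as in the diff

# Building from the config initializes fresh parameters instead of first loading
# a full pretrained seq2seq model.
model = FlaxBartForConditionalGeneration(config, seed=42, dtype=jnp.float32)
print(sum(p.size for p in jax.tree_util.tree_leaves(model.params)))  # parameter count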
@@ -568,11 +557,6 @@ def main():
         config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
     )
 
-    # Use pre-trained weights for encoder
-    model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
-    model.params["model"]["shared"] = base_model.params["model"]["shared"]
-    del base_model
-
     # Load tokenizer if it has not been set
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(
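The deleted block above is what copied the pretrained encoder and shared embeddings out of base_model, which meant two full parameter trees were alive at the same time. As a rough illustration (my own sketch, not code from this commit), the host-memory cost of a Flax parameter tree can be measured like this:

# Rough memory-accounting helper; illustrative only, not part of the script.
import jax


def param_bytes(params) -> int:
    """Total bytes held by all leaves of a (nested) Flax parameter tree."""
    return sum(leaf.size * leaf.dtype.itemsize for leaf in jax.tree_util.tree_leaves(params))


# With a donor model kept around (as the removed code did), peak host usage is
# roughly param_bytes(model.params) + param_bytes(base_model.params).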
@@ -960,12 +944,12 @@ def main():
             artifact.add_file(
                 str(Path(training_args.output_dir) / "training_state.json")
             )
-            wandb.run.log_artifact(artifact)
-
             # save some space
             c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
             c.cleanup(wandb.util.from_human_size("5GB"))
 
+            wandb.run.log_artifact(artifact)
+
             # save to the hub
             if training_args.push_to_hub:
                 model.save_pretrained(
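The last hunk only reorders the W&B calls: the artifact is now logged after the local artifacts cache has been trimmed to 5GB. A minimal sketch of the resulting flow, with placeholder run, artifact, and path names, reusing the same (older) wandb SDK helpers the script already calls:

# Sketch of the reordered W&B artifact flow; project, artifact, and output path
# are placeholders, and training_state.json is assumed to exist on disk.
from pathlib import Path

import wandb

wandb.init(project="my-project", job_type="train")  # placeholder project

artifact = wandb.Artifact("model-checkpoint", type="model")
artifact.add_file(str(Path("output") / "training_state.json"))

# Trim the local artifacts cache first, then log the artifact (the new ordering).
cache = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
cache.cleanup(wandb.util.from_human_size("5GB"))
wandb.run.log_artifact(artifact)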