boris committed
Commit 708a42c
1 Parent(s): 272552a

fix(seq2seq): memory issue

Files changed (1)
  1. dev/seq2seq/run_seq2seq_flax.py +2 -18
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -100,12 +100,6 @@ class ModelArguments:
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Pretrained tokenizer name or path if not the same as model_name"
-        },
-    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={
@@ -539,11 +533,6 @@ def main():
         )
 
     else:
-        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-            model_args.model_name_or_path,
-            seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype),
-        )
         # Set up our new model config
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
         config.tie_word_embeddings = False
@@ -568,11 +557,6 @@ def main():
             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
 
-        # Use pre-trained weights for encoder
-        model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
-        model.params["model"]["shared"] = base_model.params["model"]["shared"]
-        del base_model
-
         # Load tokenizer if it has not been set
         if tokenizer is None:
             tokenizer = AutoTokenizer.from_pretrained(
@@ -960,12 +944,12 @@ def main():
             artifact.add_file(
                 str(Path(training_args.output_dir) / "training_state.json")
             )
-            wandb.run.log_artifact(artifact)
-
             # save some space
             c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
             c.cleanup(wandb.util.from_human_size("5GB"))
 
+            wandb.run.log_artifact(artifact)
+
             # save to the hub
             if training_args.push_to_hub:
                 model.save_pretrained(
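For context on the memory issue: before this change the script instantiated a full pretrained seq2seq model only to copy its encoder and shared-embedding weights into the freshly configured model, so two complete parameter trees lived in host memory until `del base_model` ran. The sketch below is illustrative only, not the repository's code: the checkpoint name and dtype are placeholders, and it simply mirrors the removed pattern to show where the peak-memory spike came from. The commit drops the extra load entirely.

# Illustrative sketch of the pattern this commit removes (placeholder checkpoint name).
# Two full parameter trees coexist until `del base_model`, roughly doubling host memory.
import jax.numpy as jnp
from transformers import BartConfig, FlaxAutoModelForSeq2SeqLM, FlaxBartForConditionalGeneration

checkpoint = "facebook/bart-large-cnn"  # placeholder, not the project's actual checkpoint

base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(checkpoint)   # parameter tree #1
config = BartConfig.from_pretrained(checkpoint)
config.tie_word_embeddings = False
model = FlaxBartForConditionalGeneration(config, dtype=jnp.float32)  # parameter tree #2

# Copying subtrees only moves references; peak memory has already included
# both full models by the time base_model is deleted.
model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
model.params["model"]["shared"] = base_model.params["model"]["shared"]
del base_model

The remaining hunk only reorders the wandb calls so that the local artifacts cache is cleaned up before the new artifact is logged.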