boris commited on
Commit
80b41d1
1 Parent(s): 0a77f72

feat: avoid OOM

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +5 -0
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -475,6 +475,8 @@ def main():
475
 
476
  # load model
477
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
 
478
 
479
  # load tokenizer
480
  tokenizer = AutoTokenizer.from_pretrained(
@@ -529,7 +531,10 @@ def main():
529
  config=config,
530
  seed=training_args.seed_model,
531
  dtype=getattr(jnp, model_args.dtype),
 
532
  )
 
 
533
  else:
534
  model = CustomFlaxBartForConditionalGeneration(
535
  config,
 
475
 
476
  # load model
477
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
478
+ # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
479
+ print(model.params)
480
 
481
  # load tokenizer
482
  tokenizer = AutoTokenizer.from_pretrained(
 
531
  config=config,
532
  seed=training_args.seed_model,
533
  dtype=getattr(jnp, model_args.dtype),
534
+ ignore_mismatched_sizes=True,
535
  )
536
+ # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
537
+ print(model.params)
538
  else:
539
  model = CustomFlaxBartForConditionalGeneration(
540
  config,