feat: avoid OOM
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -475,6 +475,8 @@ def main():
 
         # load model
         model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
+        print(model.params)
 
         # load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
@@ -529,7 +531,10 @@ def main():
             config=config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
+            ignore_mismatched_sizes=True,
         )
+        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
+        print(model.params)
     else:
         model = CustomFlaxBartForConditionalGeneration(
             config,
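
For reference, a minimal standalone sketch of the same pattern, assuming the stock FlaxBartForConditionalGeneration from transformers and a generic BART checkpoint in place of the repo's CustomFlaxBartForConditionalGeneration and artifact_dir. Names and arguments are illustrative rather than the exact call from this script; the point is simply that the parameters are touched (here via print) immediately after from_pretrained, which is the workaround the linked flax issue motivates.

import jax.numpy as jnp
from transformers import FlaxBartForConditionalGeneration

# Hypothetical checkpoint name, used only for illustration; the script itself
# loads its custom class from a local artifact_dir.
model = FlaxBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",
    dtype=jnp.float32,
    ignore_mismatched_sizes=True,  # as added in the second hunk above
)

# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
# Touch the params right after loading, mirroring the print(model.params)
# lines added by this commit.
print(model.params)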