ydshieh
commited on
Commit
•
6d4331f
1
Parent(s):
43c04d5
save tokenizer to each ckpt folder
Browse files
run_image_caption_flax.py
CHANGED
@@ -999,7 +999,7 @@ def main():
|
|
999 |
if jax.process_index() == 0:
|
1000 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
1001 |
model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
|
1002 |
-
tokenizer.save_pretrained(training_args.output_dir)
|
1003 |
if training_args.push_to_hub:
|
1004 |
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
1005 |
|
|
|
999 |
if jax.process_index() == 0:
|
1000 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
1001 |
model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
|
1002 |
+
tokenizer.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'))
|
1003 |
if training_args.push_to_hub:
|
1004 |
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
1005 |
|