ydshieh commited on
Commit
6d4331f
1 Parent(s): 43c04d5

save tokenizer to each ckpt folder

Browse files
Files changed (1) hide show
  1. run_image_caption_flax.py +1 -1
run_image_caption_flax.py CHANGED
@@ -999,7 +999,7 @@ def main():
999
  if jax.process_index() == 0:
1000
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1001
  model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
1002
- tokenizer.save_pretrained(training_args.output_dir)
1003
  if training_args.push_to_hub:
1004
  repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
1005
 
 
999
  if jax.process_index() == 0:
1000
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1001
  model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
1002
+ tokenizer.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'))
1003
  if training_args.push_to_hub:
1004
  repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
1005