Spaces:
Running
Running
add tokenizer save to wandb:
Browse files
Former-commit-id: 36b4af0d456410a4c2996d1476525e91205d3d1c
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -811,13 +811,16 @@ def main():
|
|
811 |
params=params,
|
812 |
)
|
813 |
|
|
|
|
|
|
|
814 |
# save state
|
815 |
state = unreplicate(state)
|
816 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
817 |
f.write(to_bytes(state.opt_state))
|
818 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
819 |
json.dump({'step': state.step.item()}, f)
|
820 |
-
|
821 |
# save to W&B
|
822 |
if data_args.log_model:
|
823 |
metadata = {'step': step, 'epoch': epoch}
|
@@ -827,6 +830,11 @@ def main():
|
|
827 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
828 |
)
|
829 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
|
|
|
|
|
|
|
|
|
|
830 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
831 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
832 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|
|
|
811 |
params=params,
|
812 |
)
|
813 |
|
814 |
+
# save tokenizer
|
815 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
816 |
+
|
817 |
# save state
|
818 |
state = unreplicate(state)
|
819 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
820 |
f.write(to_bytes(state.opt_state))
|
821 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
822 |
json.dump({'step': state.step.item()}, f)
|
823 |
+
|
824 |
# save to W&B
|
825 |
if data_args.log_model:
|
826 |
metadata = {'step': step, 'epoch': epoch}
|
|
|
830 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
831 |
)
|
832 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
833 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
|
834 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
|
835 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
|
836 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'added_tokens.json'))
|
837 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
|
838 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
839 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
840 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|