marinone94
commited on
Commit
•
98dfb11
1
Parent(s):
a9f9b4a
avoid pushing checkpoints
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -797,6 +797,9 @@ def main():
|
|
797 |
)
|
798 |
logger.info("*** Trainer initialized ***")
|
799 |
|
|
|
|
|
|
|
800 |
# 12. Training
|
801 |
if training_args.do_train:
|
802 |
logger.info("*** Train ***")
|
@@ -812,10 +815,7 @@ def main():
|
|
812 |
# We don't want to push the model to the hub now
|
813 |
# so we temporarily set to false the push_to_hub attribute
|
814 |
# and then reset it to the original value
|
815 |
-
orig_push_to_hub = trainer.args.push_to_hub
|
816 |
-
trainer.args.push_to_hub = False
|
817 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
818 |
-
trainer.args.push_to_hub = orig_push_to_hub
|
819 |
logger.info("*** Model saved ***")
|
820 |
metrics = train_result.metrics
|
821 |
if data_args.max_train_samples:
|
@@ -909,7 +909,7 @@ def main():
|
|
909 |
notify_me(recipient=RECIPIENT_ADDRESS,
|
910 |
message=f"Training complete! {train_results = } {eval_results = }")
|
911 |
|
912 |
-
|
913 |
if training_args.push_to_hub:
|
914 |
logger.info("*** Pushing to hub ***")
|
915 |
trainer.push_to_hub(**kwargs)
|
|
|
797 |
)
|
798 |
logger.info("*** Trainer initialized ***")
|
799 |
|
800 |
+
orig_push_to_hub = trainer.args.push_to_hub
|
801 |
+
trainer.args.push_to_hub = False
|
802 |
+
|
803 |
# 12. Training
|
804 |
if training_args.do_train:
|
805 |
logger.info("*** Train ***")
|
|
|
815 |
# We don't want to push the model to the hub now
|
816 |
# so we temporarily set to false the push_to_hub attribute
|
817 |
# and then reset it to the original value
|
|
|
|
|
818 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
|
|
819 |
logger.info("*** Model saved ***")
|
820 |
metrics = train_result.metrics
|
821 |
if data_args.max_train_samples:
|
|
|
909 |
notify_me(recipient=RECIPIENT_ADDRESS,
|
910 |
message=f"Training complete! {train_results = } {eval_results = }")
|
911 |
|
912 |
+
trainer.args.push_to_hub = orig_push_to_hub
|
913 |
if training_args.push_to_hub:
|
914 |
logger.info("*** Pushing to hub ***")
|
915 |
trainer.push_to_hub(**kwargs)
|