marinone94
committed on
Commit
•
5e0ceba
1
Parent(s):
5e05341
check if ds is loaded
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -387,6 +387,7 @@ def main():
|
|
387 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
388 |
|
389 |
if training_args.do_train:
|
|
|
390 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
391 |
data_args.dataset_name,
|
392 |
data_args.dataset_config_name,
|
@@ -394,8 +395,10 @@ def main():
|
|
394 |
use_auth_token=True if model_args.use_auth_token else None,
|
395 |
streaming=data_args.streaming,
|
396 |
)
|
|
|
397 |
|
398 |
if training_args.do_eval:
|
|
|
399 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
400 |
data_args.dataset_name,
|
401 |
data_args.dataset_config_name,
|
@@ -403,6 +406,7 @@ def main():
|
|
403 |
use_auth_token=True if model_args.use_auth_token else None,
|
404 |
streaming=data_args.streaming,
|
405 |
)
|
|
|
406 |
|
407 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
408 |
|
@@ -602,8 +606,12 @@ def main():
|
|
602 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
603 |
)
|
604 |
|
|
|
|
|
|
|
605 |
# 12. Training
|
606 |
if training_args.do_train:
|
|
|
607 |
checkpoint = None
|
608 |
if training_args.resume_from_checkpoint is not None:
|
609 |
checkpoint = training_args.resume_from_checkpoint
|
@@ -651,6 +659,7 @@ def main():
|
|
651 |
if model_args.model_index_name is not None:
|
652 |
kwargs["model_name"] = model_args.model_index_name
|
653 |
|
|
|
654 |
if training_args.push_to_hub:
|
655 |
trainer.push_to_hub(**kwargs)
|
656 |
else:
|
|
|
387 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
388 |
|
389 |
if training_args.do_train:
|
390 |
+
logger.info("Loading training dataset")
|
391 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
392 |
data_args.dataset_name,
|
393 |
data_args.dataset_config_name,
|
|
|
395 |
use_auth_token=True if model_args.use_auth_token else None,
|
396 |
streaming=data_args.streaming,
|
397 |
)
|
398 |
+
logger.info("Loaded training dataset")
|
399 |
|
400 |
if training_args.do_eval:
|
401 |
+
logger.info("Loading evaluation dataset")
|
402 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
403 |
data_args.dataset_name,
|
404 |
data_args.dataset_config_name,
|
|
|
406 |
use_auth_token=True if model_args.use_auth_token else None,
|
407 |
streaming=data_args.streaming,
|
408 |
)
|
409 |
+
logger.info("Loaded evaluation dataset")
|
410 |
|
411 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
412 |
|
|
|
606 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
607 |
)
|
608 |
|
609 |
+
push_to_hub = training_args.push_to_hub
|
610 |
+
training_args.push_to_hub = False
|
611 |
+
|
612 |
# 12. Training
|
613 |
if training_args.do_train:
|
614 |
+
logger.info("*** Train ***")
|
615 |
checkpoint = None
|
616 |
if training_args.resume_from_checkpoint is not None:
|
617 |
checkpoint = training_args.resume_from_checkpoint
|
|
|
659 |
if model_args.model_index_name is not None:
|
660 |
kwargs["model_name"] = model_args.model_index_name
|
661 |
|
662 |
+
training_args.push_to_hub = push_to_hub
|
663 |
if training_args.push_to_hub:
|
664 |
trainer.push_to_hub(**kwargs)
|
665 |
else:
|