marinone94 committed on
Commit
5e0ceba
1 Parent(s): 5e05341

check if ds is loaded

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -387,6 +387,7 @@ def main():
387
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
388
 
389
  if training_args.do_train:
 
390
  raw_datasets["train"] = load_maybe_streaming_dataset(
391
  data_args.dataset_name,
392
  data_args.dataset_config_name,
@@ -394,8 +395,10 @@ def main():
394
  use_auth_token=True if model_args.use_auth_token else None,
395
  streaming=data_args.streaming,
396
  )
 
397
 
398
  if training_args.do_eval:
 
399
  raw_datasets["eval"] = load_maybe_streaming_dataset(
400
  data_args.dataset_name,
401
  data_args.dataset_config_name,
@@ -403,6 +406,7 @@ def main():
403
  use_auth_token=True if model_args.use_auth_token else None,
404
  streaming=data_args.streaming,
405
  )
 
406
 
407
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
408
 
@@ -602,8 +606,12 @@ def main():
602
  callbacks=[ShuffleCallback()] if data_args.streaming else None,
603
  )
604
 
 
 
 
605
  # 12. Training
606
  if training_args.do_train:
 
607
  checkpoint = None
608
  if training_args.resume_from_checkpoint is not None:
609
  checkpoint = training_args.resume_from_checkpoint
@@ -651,6 +659,7 @@ def main():
651
  if model_args.model_index_name is not None:
652
  kwargs["model_name"] = model_args.model_index_name
653
 
 
654
  if training_args.push_to_hub:
655
  trainer.push_to_hub(**kwargs)
656
  else:
 
387
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
388
 
389
  if training_args.do_train:
390
+ logger.info("Loading training dataset")
391
  raw_datasets["train"] = load_maybe_streaming_dataset(
392
  data_args.dataset_name,
393
  data_args.dataset_config_name,
 
395
  use_auth_token=True if model_args.use_auth_token else None,
396
  streaming=data_args.streaming,
397
  )
398
+ logger.info("Loaded training dataset")
399
 
400
  if training_args.do_eval:
401
+ logger.info("Loading evaluation dataset")
402
  raw_datasets["eval"] = load_maybe_streaming_dataset(
403
  data_args.dataset_name,
404
  data_args.dataset_config_name,
 
406
  use_auth_token=True if model_args.use_auth_token else None,
407
  streaming=data_args.streaming,
408
  )
409
+ logger.info("Loaded evaluation dataset")
410
 
411
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
412
 
 
606
  callbacks=[ShuffleCallback()] if data_args.streaming else None,
607
  )
608
 
609
+ push_to_hub = training_args.push_to_hub
610
+ training_args.push_to_hub = False
611
+
612
  # 12. Training
613
  if training_args.do_train:
614
+ logger.info("*** Train ***")
615
  checkpoint = None
616
  if training_args.resume_from_checkpoint is not None:
617
  checkpoint = training_args.resume_from_checkpoint
 
659
  if model_args.model_index_name is not None:
660
  kwargs["model_name"] = model_args.model_index_name
661
 
662
+ training_args.push_to_hub = push_to_hub
663
  if training_args.push_to_hub:
664
  trainer.push_to_hub(**kwargs)
665
  else: