marinone94 commited on
Commit
eeecd97
1 Parent(s): 5e0ceba
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -460,6 +460,7 @@ def main():
460
  revision=model_args.model_revision,
461
  use_auth_token=True if model_args.use_auth_token else None,
462
  )
 
463
 
464
  if model.config.decoder_start_token_id is None:
465
  raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -498,6 +499,7 @@ def main():
498
  if data_args.streaming
499
  else raw_datasets["train"].select(range(data_args.max_train_samples))
500
  )
 
501
 
502
  if data_args.max_eval_samples is not None:
503
  raw_datasets["eval"] = (
@@ -505,6 +507,7 @@ def main():
505
  if data_args.streaming
506
  else raw_datasets["eval"].select(range(data_args.max_eval_samples))
507
  )
 
508
 
509
  def prepare_dataset(batch):
510
  # process audio
@@ -526,6 +529,7 @@ def main():
526
  prepare_dataset,
527
  remove_columns=raw_datasets_features,
528
  ).with_format("torch")
 
529
 
530
  if training_args.do_train and data_args.streaming:
531
  # manually shuffle if streaming (done by the trainer for non-streaming)
@@ -533,6 +537,7 @@ def main():
533
  buffer_size=data_args.shuffle_buffer_size,
534
  seed=training_args.seed,
535
  )
 
536
 
537
  # filter training data that is shorter than min_input_length or longer than
538
  # max_input_length
@@ -544,10 +549,12 @@ def main():
544
  is_audio_in_length_range,
545
  input_columns=["input_length"],
546
  )
 
547
 
548
  # 8. Load Metric
549
  metric = evaluate.load("wer")
550
  do_normalize_eval = data_args.do_normalize_eval
 
551
 
552
  def compute_metrics(pred):
553
  pred_ids = pred.predictions
@@ -577,12 +584,13 @@ def main():
577
  config.save_pretrained(training_args.output_dir)
578
 
579
  processor = AutoProcessor.from_pretrained(training_args.output_dir)
580
-
581
  # 10. Define data collator
582
  data_collator = DataCollatorSpeechSeq2SeqWithPadding(
583
  processor=processor,
584
  decoder_start_token_id=model.config.decoder_start_token_id,
585
  )
 
586
 
587
  # 11. Configure Trainer
588
  # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
@@ -594,6 +602,9 @@ def main():
594
  elif isinstance(train_dataloader.dataset, IterableDataset):
595
  train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
596
 
 
 
 
597
  # Initialize Trainer
598
  trainer = Seq2SeqTrainer(
599
  model=model,
@@ -605,9 +616,7 @@ def main():
605
  compute_metrics=compute_metrics if training_args.predict_with_generate else None,
606
  callbacks=[ShuffleCallback()] if data_args.streaming else None,
607
  )
608
-
609
- push_to_hub = training_args.push_to_hub
610
- training_args.push_to_hub = False
611
 
612
  # 12. Training
613
  if training_args.do_train:
@@ -643,6 +652,7 @@ def main():
643
  trainer.save_metrics("eval", metrics)
644
 
645
  # 14. Write Training Stats
 
646
  kwargs = {
647
  "finetuned_from": model_args.model_name_or_path,
648
  "tasks": "automatic-speech-recognition",
@@ -659,11 +669,13 @@ def main():
659
  if model_args.model_index_name is not None:
660
  kwargs["model_name"] = model_args.model_index_name
661
 
662
- training_args.push_to_hub = push_to_hub
 
663
  if training_args.push_to_hub:
664
  trainer.push_to_hub(**kwargs)
665
  else:
666
  trainer.create_model_card(**kwargs)
 
667
 
668
  return results
669
 
 
460
  revision=model_args.model_revision,
461
  use_auth_token=True if model_args.use_auth_token else None,
462
  )
463
+ logger.info("Loaded pretrained model, tokenizer, and feature extractor")
464
 
465
  if model.config.decoder_start_token_id is None:
466
  raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
499
  if data_args.streaming
500
  else raw_datasets["train"].select(range(data_args.max_train_samples))
501
  )
502
+ logger.info("Using %d train samples", data_args.max_train_samples)
503
 
504
  if data_args.max_eval_samples is not None:
505
  raw_datasets["eval"] = (
 
507
  if data_args.streaming
508
  else raw_datasets["eval"].select(range(data_args.max_eval_samples))
509
  )
510
+ logger.info("Using %d eval samples", data_args.max_eval_samples)
511
 
512
  def prepare_dataset(batch):
513
  # process audio
 
529
  prepare_dataset,
530
  remove_columns=raw_datasets_features,
531
  ).with_format("torch")
532
+ logger.info("Dataset map pre-processing done")
533
 
534
  if training_args.do_train and data_args.streaming:
535
  # manually shuffle if streaming (done by the trainer for non-streaming)
 
537
  buffer_size=data_args.shuffle_buffer_size,
538
  seed=training_args.seed,
539
  )
540
+ logger.info("Shuffled dataset")
541
 
542
  # filter training data that is shorter than min_input_length or longer than
543
  # max_input_length
 
549
  is_audio_in_length_range,
550
  input_columns=["input_length"],
551
  )
552
+ logger.info("Filtered training dataset")
553
 
554
  # 8. Load Metric
555
  metric = evaluate.load("wer")
556
  do_normalize_eval = data_args.do_normalize_eval
557
+ logger.info("Loaded metric")
558
 
559
  def compute_metrics(pred):
560
  pred_ids = pred.predictions
 
584
  config.save_pretrained(training_args.output_dir)
585
 
586
  processor = AutoProcessor.from_pretrained(training_args.output_dir)
587
+
588
  # 10. Define data collator
589
  data_collator = DataCollatorSpeechSeq2SeqWithPadding(
590
  processor=processor,
591
  decoder_start_token_id=model.config.decoder_start_token_id,
592
  )
593
+ logger.info("Defined data collator")
594
 
595
  # 11. Configure Trainer
596
  # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
 
602
  elif isinstance(train_dataloader.dataset, IterableDataset):
603
  train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
604
 
605
+ push_to_hub = training_args.push_to_hub
606
+ training_args.push_to_hub = False
607
+
608
  # Initialize Trainer
609
  trainer = Seq2SeqTrainer(
610
  model=model,
 
616
  compute_metrics=compute_metrics if training_args.predict_with_generate else None,
617
  callbacks=[ShuffleCallback()] if data_args.streaming else None,
618
  )
619
+ logger.info("Initialized Trainer")
 
 
620
 
621
  # 12. Training
622
  if training_args.do_train:
 
652
  trainer.save_metrics("eval", metrics)
653
 
654
  # 14. Write Training Stats
655
+ logger.info("Training completed. Writing training stats")
656
  kwargs = {
657
  "finetuned_from": model_args.model_name_or_path,
658
  "tasks": "automatic-speech-recognition",
 
669
  if model_args.model_index_name is not None:
670
  kwargs["model_name"] = model_args.model_index_name
671
 
672
+ logger.info("Pushing model to the hub") if push_to_hub else logger.info("Not pushing model to the hub - creating model card only")
673
+ trainer.args.push_to_hub = push_to_hub
674
  if training_args.push_to_hub:
675
  trainer.push_to_hub(**kwargs)
676
  else:
677
  trainer.create_model_card(**kwargs)
678
+ logger.info("*** DONE! ***")
679
 
680
  return results
681