marinone94 committed
Commit: eeecd97
1 Parent(s): 5e0ceba

add logs
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -460,6 +460,7 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
+    logger.info("Loaded pretrained model, tokenizer, and feature extractor")
 
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
@@ -498,6 +499,7 @@ def main():
             if data_args.streaming
             else raw_datasets["train"].select(range(data_args.max_train_samples))
         )
+        logger.info("Using %d train samples", data_args.max_train_samples)
 
     if data_args.max_eval_samples is not None:
         raw_datasets["eval"] = (
@@ -505,6 +507,7 @@ def main():
             if data_args.streaming
             else raw_datasets["eval"].select(range(data_args.max_eval_samples))
         )
+        logger.info("Using %d eval samples", data_args.max_eval_samples)
 
     def prepare_dataset(batch):
         # process audio
@@ -526,6 +529,7 @@ def main():
             prepare_dataset,
             remove_columns=raw_datasets_features,
         ).with_format("torch")
+        logger.info("Dataset map pre-processing done")
 
         if training_args.do_train and data_args.streaming:
             # manually shuffle if streaming (done by the trainer for non-streaming)
@@ -533,6 +537,7 @@ def main():
                 buffer_size=data_args.shuffle_buffer_size,
                 seed=training_args.seed,
             )
+            logger.info("Shuffled dataset")
 
     # filter training data that is shorter than min_input_length or longer than
     # max_input_length
@@ -544,10 +549,12 @@ def main():
             is_audio_in_length_range,
             input_columns=["input_length"],
         )
+        logger.info("Filtered training dataset")
 
     # 8. Load Metric
     metric = evaluate.load("wer")
     do_normalize_eval = data_args.do_normalize_eval
+    logger.info("Loaded metric")
 
     def compute_metrics(pred):
         pred_ids = pred.predictions
@@ -577,12 +584,13 @@ def main():
         config.save_pretrained(training_args.output_dir)
 
     processor = AutoProcessor.from_pretrained(training_args.output_dir)
-
+
     # 10. Define data collator
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
         processor=processor,
         decoder_start_token_id=model.config.decoder_start_token_id,
     )
+    logger.info("Defined data collator")
 
     # 11. Configure Trainer
     # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
@@ -594,6 +602,9 @@ def main():
             elif isinstance(train_dataloader.dataset, IterableDataset):
                 train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
 
+    push_to_hub = training_args.push_to_hub
+    training_args.push_to_hub = False
+
     # Initialize Trainer
     trainer = Seq2SeqTrainer(
         model=model,
@@ -605,9 +616,7 @@ def main():
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
         callbacks=[ShuffleCallback()] if data_args.streaming else None,
     )
-
-    push_to_hub = training_args.push_to_hub
-    training_args.push_to_hub = False
+    logger.info("Initialized Trainer")
 
     # 12. Training
     if training_args.do_train:
@@ -643,6 +652,7 @@ def main():
         trainer.save_metrics("eval", metrics)
 
     # 14. Write Training Stats
+    logger.info("Training completed. Writing training stats")
     kwargs = {
         "finetuned_from": model_args.model_name_or_path,
         "tasks": "automatic-speech-recognition",
@@ -659,11 +669,13 @@ def main():
     if model_args.model_index_name is not None:
         kwargs["model_name"] = model_args.model_index_name
 
-
+    logger.info("Pushing model to the hub") if push_to_hub else logger.info("Not pushing model to the hub - creating model card only")
+    trainer.args.push_to_hub = push_to_hub
     if training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
     else:
         trainer.create_model_card(**kwargs)
+    logger.info("*** DONE! ***")
 
     return results
 
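Note: the logger.info(...) calls added by this commit assume a module-level logger already configured near the top of the script, as the Hugging Face example scripts usually do. A minimal sketch of that assumed setup is below (the format string and level are illustrative, not part of this diff):

import logging
import sys

# Module-level logger assumed by the logger.info(...) calls in the diff above
logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger.setLevel(logging.INFO)

# With this in place, a call such as
#   logger.info("Using %d train samples", data_args.max_train_samples)
# prints a timestamped message to stdout, using lazy %-style argument formatting.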