marinone94 committed
Commit • 9e467b3 • 1 Parent(s): e023bb4

use tokenizer to batch_decode
run_speech_recognition_seq2seq_streaming.py (CHANGED)
@@ -458,7 +458,7 @@ def load_maybe_streaming_dataset(
     return dataset
 
 
-def print_data_samples(dataset,
+def print_data_samples(dataset, tokenizer, max_samples=5):
     shown_samples = 0
     for batch in dataset:
         print("Target: ", tokenizer.batch_decode(batch["labels"]))
@@ -786,7 +786,7 @@ def main():
     # 12. Training
     if training_args.do_train:
         logger.info("*** Train ***")
-        print_data_samples(vectorized_datasets["train"],
+        print_data_samples(vectorized_datasets["train"], tokenizer)
         checkpoint = None
         if training_args.resume_from_checkpoint is not None:
             checkpoint = training_args.resume_from_checkpoint
@@ -824,15 +824,15 @@ def main():
             num_beams=training_args.generation_num_beams,
         )
         logger.info("*** Test prediction done ***")
-        preds =
-        labels =
+        preds = tokenizer.batch_decode(predictions.predictions)
+        labels = tokenizer.batch_decode(predictions.label_ids)
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
         logger.info("Before setting language and task")
         logger.info(f"{pred_labels}")
         trainer.model.config.forced_decoder_ids = \
-
-        preds =
-        labels =
+            tokenizer.get_decoder_prompt_ids(language=data_args.language_eval, task=data_args.task, no_timestamps=True)
+        preds = tokenizer.batch_decode(predictions.predictions)
+        labels = tokenizer.batch_decode(predictions.label_ids)
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
         logger.info("After setting language and task")
         logger.info(f"{pred_labels}")
@@ -841,7 +841,7 @@ def main():
     results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        print_data_samples(vectorized_datasets["eval"],
+        print_data_samples(vectorized_datasets["eval"], tokenizer)
         metrics = trainer.evaluate(
             metric_key_prefix="eval",
             max_length=training_args.generation_max_length,
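
For readers following the diff: the new print_data_samples helper is cut off after its first lines by the hunk context. A minimal sketch of how the full helper plausibly continues, assuming max_samples caps how many batches are printed; everything past the visible lines is an assumption, not the commit's actual code:

def print_data_samples(dataset, tokenizer, max_samples=5):
    shown_samples = 0
    for batch in dataset:
        print("Target: ", tokenizer.batch_decode(batch["labels"]))
        # Assumption: stop once enough samples have been shown; the real
        # loop body beyond the lines visible in the diff is not shown here.
        shown_samples += 1
        if shown_samples >= max_samples:
            break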
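
The pattern this commit settles on, decoding raw trainer output IDs with tokenizer.batch_decode and forcing Whisper's language/task tokens via tokenizer.get_decoder_prompt_ids, can be exercised standalone. A hedged sketch, assuming an openai/whisper-tiny checkpoint and Swedish transcription purely for illustration (the script itself takes these values from data_args.language_eval and data_args.task):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# batch_decode turns batches of token IDs (e.g. trainer.predict outputs)
# back into strings; skip_special_tokens drops the <|...|> control tokens.
sample_ids = tokenizer("hej hej").input_ids
print(tokenizer.batch_decode([sample_ids], skip_special_tokens=True))

# get_decoder_prompt_ids returns (position, token_id) pairs that pin the
# language/task control tokens at the start of generation; the commit
# assigns exactly this to trainer.model.config.forced_decoder_ids.
forced = tokenizer.get_decoder_prompt_ids(
    language="swedish", task="transcribe", no_timestamps=True
)
print(forced)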