marinone94 committed
Commit 9e467b3 (1 parent: e023bb4)

use tokenizer to batch_decode

run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -458,7 +458,7 @@ def load_maybe_streaming_dataset(
     return dataset
 
 
-def print_data_samples(dataset, processor, max_samples=5):
+def print_data_samples(dataset, tokenizer, max_samples=5):
     shown_samples = 0
     for batch in dataset:
         print("Target: ", tokenizer.batch_decode(batch["labels"]))
@@ -786,7 +786,7 @@ def main():
     # 12. Training
     if training_args.do_train:
         logger.info("*** Train ***")
-        print_data_samples(vectorized_datasets["train"], processor)
+        print_data_samples(vectorized_datasets["train"], tokenizer)
         checkpoint = None
         if training_args.resume_from_checkpoint is not None:
             checkpoint = training_args.resume_from_checkpoint
@@ -824,15 +824,15 @@ def main():
             num_beams=training_args.generation_num_beams,
         )
         logger.info("*** Test prediction done ***")
-        preds = processor.batch_decode(predictions.predictions)
-        labels = processor.batch_decode(predictions.label_ids)
+        preds = tokenizer.batch_decode(predictions.predictions)
+        labels = tokenizer.batch_decode(predictions.label_ids)
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
         logger.info("Before setting language and task")
         logger.info(f"{pred_labels}")
         trainer.model.config.forced_decoder_ids = \
-            processor.get_decoder_prompt_ids(language=data_args.language_eval, task=data_args.task, no_timestamps=True)
-        preds = processor.batch_decode(predictions.predictions)
-        labels = processor.batch_decode(predictions.label_ids)
+            tokenizer.get_decoder_prompt_ids(language=data_args.language_eval, task=data_args.task, no_timestamps=True)
+        preds = tokenizer.batch_decode(predictions.predictions)
+        labels = tokenizer.batch_decode(predictions.label_ids)
         pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
         logger.info("After setting language and task")
         logger.info(f"{pred_labels}")
@@ -841,7 +841,7 @@ def main():
     results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        print_data_samples(vectorized_datasets["eval"], processor)
+        print_data_samples(vectorized_datasets["eval"], tokenizer)
         metrics = trainer.evaluate(
             metric_key_prefix="eval",
             max_length=training_args.generation_max_length,
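
For context on the fix: before this commit, print_data_samples accepted a processor argument while its body called tokenizer.batch_decode, so the name used inside the function did not match its parameter. The commit renames the parameter and switches the other call sites to the tokenizer as well. A minimal sketch of the resulting pattern, assuming a Whisper-style tokenizer from transformers; the checkpoint name, the toy batch, and the loop body past the print (the hunk only shows the first lines of the function) are illustrative assumptions, not part of the script:

    from transformers import WhisperTokenizer

    # Illustrative checkpoint; the script loads the tokenizer selected by
    # its own command-line arguments.
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")


    def print_data_samples(dataset, tokenizer, max_samples=5):
        # Decode label token ids back to text and print them, stopping
        # after max_samples batches (assumed continuation of the hunk).
        shown_samples = 0
        for batch in dataset:
            print("Target: ", tokenizer.batch_decode(batch["labels"]))
            shown_samples += 1
            if shown_samples >= max_samples:
                break


    # Toy "dataset": an iterable of batches carrying tokenized labels.
    toy_batches = [{"labels": [tokenizer("hej världen").input_ids]}]
    print_data_samples(toy_batches, tokenizer)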
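
The same swap appears where the test predictions are decoded and where forced_decoder_ids are rebuilt before the second decode. Worth hedging: in recent transformers releases both WhisperProcessor and WhisperTokenizer expose batch_decode and get_decoder_prompt_ids, so the change reads as using one object consistently rather than a hard API requirement. A small sketch of the decode-and-compare step; the -100 padding replacement, skip_special_tokens, and the language/task values are assumptions for illustration:

    import numpy as np
    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")  # illustrative


    def format_predictions(pred_ids, label_ids, tokenizer):
        # Assumption: labels are padded with -100, which batch_decode
        # cannot map to a token, so restore the pad token id first.
        label_ids = np.asarray(label_ids)
        label_ids = np.where(label_ids == -100, tokenizer.pad_token_id, label_ids)
        preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        return [f"Prediction: {p}\nLabel: {l}\n" for p, l in zip(preds, labels)]


    # get_decoder_prompt_ids returns (position, token_id) pairs that force
    # the language/task special tokens at the start of generation.
    forced_decoder_ids = tokenizer.get_decoder_prompt_ids(
        language="swedish", task="transcribe", no_timestamps=True
    )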