marinone94 committed
Commit cbbd9a3
1 Parent(s): dbe5e43

use multiple datasets

run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -394,12 +394,16 @@ def load_maybe_streaming_dataset(
                else:
                    dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
                raw_datasets_features = list(dataset.features.keys())
+               logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
                if text_col_name_ref not in raw_datasets_features:
                    if len(text_column_names) == 1:
-                       raise ValueError("Column name not found in dataset.")
+                       raise ValueError("None of the text column names provided found in dataset."
+                                        f"Text columns: {text_column_names}"
+                                        f"Dataset columns: {raw_datasets_features}")
                    flag = False
                    for text_column_name in text_column_names:
-                       if text_column_name in raw_datasets_features:
+                       if text_column_name in raw_datasets_features:
+                           logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
                            dataset = dataset.rename_column(text_column_name, text_col_name_ref)
                            flag = True
                            break
@@ -408,9 +412,15 @@ def load_maybe_streaming_dataset(
                                         f"Text columns: {text_column_names}"
                                         f"Dataset columns: {raw_datasets_features}")
                if audio_column_name is not None and sampling_rate is not None:
-                   dataset = dataset.cast_column(
-                       audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
-                   )
+                   ds_sr = int(dataset.features[audio_column_name].sampling_rate)
+                   if ds_sr != sampling_rate:
+                       dataset = dataset.cast_column(
+                           audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+                       )
+               raw_datasets_features = list(dataset.features.keys())
+               raw_datasets_features.remove(audio_column_name)
+               raw_datasets_features.remove(text_col_name_ref)
+               dataset = dataset.remove_columns(column_names=raw_datasets_features)
                dataset_splits.append(dataset)

        # interleave multiple splits to form one dataset
@@ -422,6 +432,36 @@ def load_maybe_streaming_dataset(
        return dataset


+def load_common_voice_like_dataset(
+    dataset_name,
+    config,
+    split,
+    audio_column_name=None,
+    sampling_rate=None,
+    streaming=True,
+    use_auth_token=False
+):
+
+    """
+    Utility function to load the Common Voice dataset.
+    """
+    dataset = load_dataset(
+        dataset_name,
+        config,
+        split=split,
+        streaming=streaming,
+        use_auth_token=use_auth_token,
+    )
+    if audio_column_name is not None and sampling_rate is not None:
+        dataset = dataset.cast_column(
+            audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+        )
+    return dataset
+
+
+# def load_nst_nbailab(config, split, )
+
+
def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
@@ -510,7 +550,7 @@ def main():
            streaming=data_args.streaming,
            text_column_name=data_args.text_column_name,
            audio_column_name=data_args.audio_column_name,
-           sampling_rate=feature_extractor.sampling_rate,
+           sampling_rate=int(feature_extractor.sampling_rate),
            # language=data_args.language_train
        )

@@ -523,7 +563,7 @@ def main():
            streaming=data_args.streaming,
            text_column_name=data_args.text_column_name,
            audio_column_name=data_args.audio_column_name,
-           sampling_rate=feature_extractor.sampling_rate,
+           sampling_rate=int(feature_extractor.sampling_rate),
            # language=data_args.language_eval
        )

 
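A note on the new preprocessing above: each streamed split is now re-decoded at the feature extractor's sampling rate only when the source rate differs, and every column other than the audio and reference text columns is dropped so that splits from different corpora share one schema before interleaving. A minimal sketch of that pattern with the Hugging Face datasets library (the dataset, column names and 16 kHz target below are illustrative placeholders, not values taken from this commit):

    from datasets import Audio, load_dataset

    # Stream a single split (placeholder dataset/config/split).
    dataset = load_dataset(
        "mozilla-foundation/common_voice_11_0", "sv-SE",
        split="validation", streaming=True, use_auth_token=True,
    )

    audio_column_name = "audio"
    text_col_name_ref = "sentence"
    sampling_rate = 16000  # e.g. int(feature_extractor.sampling_rate) for Whisper

    # Re-decode audio at the target rate only if the source rate differs.
    if int(dataset.features[audio_column_name].sampling_rate) != sampling_rate:
        dataset = dataset.cast_column(audio_column_name, Audio(sampling_rate=sampling_rate))

    # Keep only the audio and reference text columns so that every split
    # exposes an identical schema before interleave_datasets() is applied.
    extra_columns = [
        col for col in dataset.features.keys()
        if col not in (audio_column_name, text_col_name_ref)
    ]
    dataset = dataset.remove_columns(extra_columns)
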
test_run_nordic.sh CHANGED
@@ -1,9 +1,9 @@
python $1run_speech_recognition_seq2seq_streaming.py \
    --model_name_or_path="openai/whisper-tiny" \
-   --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
-   --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
-   --language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
-   --train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
+   --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
+   --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
+   --language_train="swedish,danish,norwegian,swedish,norwegian,norwegian,swedish,danish,norwegian" \
+   --train_split_name="train+validation,train+validation,train+validation,train,train+test,train+validation,train+validation,train+validation,train+validation" \
    --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
    --dataset_eval_config_name="sv-SE,da,nn-NO" \
    --language_eval="swedish,danish,norwegian" \
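The comma-separated training flags above appear to be read positionally, one entry per dataset: dropping arpelarpe/nota from --dataset_train_name also removes the matching empty config, the "danish" language entry and the extra "train" split from the other lists. A rough sketch of that per-dataset pairing (the variable names and zip-based parsing are an assumption for illustration, not the script's exact code):

    # Assumption: illustrates how the positional comma-separated lists line up.
    dataset_names = "mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice".split(",")
    config_names = "sv-SE,nst".split(",")
    languages = "swedish,swedish".split(",")
    split_names = "train+validation,train".split(",")

    for name, config, language, splits in zip(dataset_names, config_names, languages, split_names):
        print(f"{name} [{config}] -> language={language}, splits={splits}")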