marinone94 committed

Commit cbbd9a3
Parent(s): dbe5e43

use multiple datasets
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -394,12 +394,16 @@ def load_maybe_streaming_dataset(
         else:
             dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
         raw_datasets_features = list(dataset.features.keys())
+        logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
         if text_col_name_ref not in raw_datasets_features:
             if len(text_column_names) == 1:
-                raise ValueError("
+                raise ValueError("None of the text column names provided found in dataset."
+                                 f"Text columns: {text_column_names}"
+                                 f"Dataset columns: {raw_datasets_features}")
             flag = False
             for text_column_name in text_column_names:
-                if text_column_name in raw_datasets_features:
+                if text_column_name in raw_datasets_features:
+                    logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
                     dataset = dataset.rename_column(text_column_name, text_col_name_ref)
                     flag = True
                     break
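Note: the hunk above normalizes heterogeneous transcript columns so every corpus exposes the same reference column downstream. A minimal standalone sketch of the same pattern, assuming illustrative column names and a public dataset (FLEURS ships its transcripts as "transcription"):

```python
# Sketch of the normalization above: rename whatever text column a corpus
# ships with to one reference name. Candidate/reference names are illustrative.
from datasets import load_dataset

text_column_names = ["sentence", "transcription", "text"]  # candidate columns
text_col_name_ref = "sentence"                             # name expected downstream

ds = load_dataset("google/fleurs", "sv_se", split="train", streaming=True)
if text_col_name_ref not in ds.features:
    for col in text_column_names:
        if col in ds.features:
            ds = ds.rename_column(col, text_col_name_ref)  # FLEURS: "transcription"
            break
```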
@@ -408,9 +412,15 @@ def load_maybe_streaming_dataset(
                                  f"Text columns: {text_column_names}"
                                  f"Dataset columns: {raw_datasets_features}")
         if audio_column_name is not None and sampling_rate is not None:
-            dataset = dataset.cast_column(
-                audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
-            )
+            ds_sr = int(dataset.features[audio_column_name].sampling_rate)
+            if ds_sr != sampling_rate:
+                dataset = dataset.cast_column(
+                    audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+                )
+            raw_datasets_features = list(dataset.features.keys())
+            raw_datasets_features.remove(audio_column_name)
+            raw_datasets_features.remove(text_col_name_ref)
+            dataset = dataset.remove_columns(column_names=raw_datasets_features)
         dataset_splits.append(dataset)
 
     # interleave multiple splits to form one dataset
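Note: casting the audio column makes 🤗 Datasets resample each example lazily at decode time, so the commit only re-casts corpora whose native rate differs from the target, then drops every column except audio and text. A sketch of the same steps in isolation, assuming the "audio"/"sentence" column names and that the gated Common Voice repo is accessible with a token:

```python
# Sketch of the conditional resample + column pruning above. Common Voice
# audio is 48 kHz, so the cast applies; decoding then resamples on access.
import datasets
from datasets import load_dataset

target_sr = 16000  # what the Whisper feature extractor expects
ds = load_dataset("mozilla-foundation/common_voice_11_0", "sv-SE",
                  split="train", streaming=True, use_auth_token=True)

if int(ds.features["audio"].sampling_rate) != target_sr:
    ds = ds.cast_column("audio", datasets.features.Audio(sampling_rate=target_sr))

keep = {"audio", "sentence"}
ds = ds.remove_columns([c for c in ds.features if c not in keep])
```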
@@ -422,6 +432,36 @@ def load_maybe_streaming_dataset(
     return dataset
 
 
+def load_common_voice_like_dataset(
+    dataset_name,
+    config,
+    split,
+    audio_column_name=None,
+    sampling_rate=None,
+    streaming=True,
+    use_auth_token=False
+):
+
+    """
+    Utility function to load the Common Voice dataset.
+    """
+    dataset = load_dataset(
+        dataset_name,
+        config,
+        split=split,
+        streaming=streaming,
+        use_auth_token=use_auth_token,
+    )
+    if audio_column_name is not None and sampling_rate is not None:
+        dataset = dataset.cast_column(
+            audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+        )
+    return dataset
+
+
+# def load_nst_nbailab(config, split, )
+
+
 def main():
     # 1. Parse input arguments
     # See all possible arguments in src/transformers/training_args.py
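Note: a hypothetical call site for the helper added above; argument values mirror test_run_nordic.sh and are illustrative rather than taken from the script. Common Voice 11 is gated on the Hub, hence the token flag:

```python
# Hypothetical usage of load_common_voice_like_dataset from this commit.
cv_sv = load_common_voice_like_dataset(
    "mozilla-foundation/common_voice_11_0",
    "sv-SE",
    split="train",
    audio_column_name="audio",
    sampling_rate=16000,
    streaming=True,
    use_auth_token=True,  # gated dataset
)
```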
@@ -510,7 +550,7 @@ def main():
         streaming=data_args.streaming,
         text_column_name=data_args.text_column_name,
         audio_column_name=data_args.audio_column_name,
-        sampling_rate=feature_extractor.sampling_rate,
+        sampling_rate=int(feature_extractor.sampling_rate),
         # language=data_args.language_train
     )
 
@@ -523,7 +563,7 @@ def main():
         streaming=data_args.streaming,
         text_column_name=data_args.text_column_name,
         audio_column_name=data_args.audio_column_name,
-        sampling_rate=feature_extractor.sampling_rate,
+        sampling_rate=int(feature_extractor.sampling_rate),
         # language=data_args.language_eval
     )
 
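Note: once every split has the same sampling rate and the same two columns, the per-corpus streams can be merged as the "# interleave multiple splits to form one dataset" comment describes; interleaving requires identical features across splits, which is exactly what the pruning above guarantees. A minimal sketch with 🤗 Datasets, using illustrative public corpora and an illustrative seed:

```python
# Sketch: merge several streaming splits into one training stream.
from datasets import interleave_datasets, load_dataset

splits = [
    load_dataset("google/fleurs", "sv_se", split="train", streaming=True),
    load_dataset("google/fleurs", "da_dk", split="train", streaming=True),
]
dataset = interleave_datasets(splits, seed=42)  # alternates examples across splits
```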
test_run_nordic.sh
CHANGED
@@ -1,9 +1,9 @@
 python $1run_speech_recognition_seq2seq_streaming.py \
     --model_name_or_path="openai/whisper-tiny" \
-    --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,
-    --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant
-    --language_train="swedish,danish,norwegian,swedish,norwegian,
-    --train_split_name="train+validation,train+validation,train+validation,train,train+test,train
+    --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
+    --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
+    --language_train="swedish,danish,norwegian,swedish,norwegian,norwegian,swedish,danish,norwegian" \
+    --train_split_name="train+validation,train+validation,train+validation,train,train+test,train+validation,train+validation,train+validation,train+validation" \
     --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
     --dataset_eval_config_name="sv-SE,da,nn-NO" \
     --language_eval="swedish,danish,norwegian" \
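Note: the script passes parallel comma-separated lists, one entry per training corpus (name, config, language, splits all line up positionally). A sketch of how such arguments are typically zipped back together on the Python side; the actual parsing in run_speech_recognition_seq2seq_streaming.py may differ, and the values are abridged from the arguments above:

```python
# Sketch: align the parallel comma-separated CLI lists, one record per corpus.
names = "mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST".split(",")
configs = "sv-SE,nst,no-distant".split(",")
splits = "train+validation,train,train+test".split(",")

for name, config, split in zip(names, configs, splits):
    print(f"{name} [{config}] -> splits: {split}")
```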