marinone94
commited on
Commit
•
dbe5e43
1
Parent(s):
b7db389
WIP: mix datasets
Browse files- run_speech_recognition_seq2seq_streaming.py +79 -45
- test_run_nordic.sh +2 -2
- test_run_nordic_cv.sh +41 -0
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -87,16 +87,17 @@ if hf_token is not None:
|
|
87 |
with open("/root/.huggingface/token", "w") as f:
|
88 |
f.write(hf_token)
|
89 |
logger.info("Huggingface API key set")
|
90 |
-
except PermissionError:
|
91 |
logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
|
92 |
else:
|
93 |
logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
|
94 |
|
95 |
-
wandb.login(key=wandb_token, relogin=True, timeout=5)
|
96 |
-
wandb.init(project="whisper", entity="pn-aa")
|
97 |
|
98 |
logger.info("Wandb API key set, logging to wandb")
|
99 |
|
|
|
100 |
@dataclass
|
101 |
class ModelArguments:
|
102 |
"""
|
@@ -300,7 +301,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
300 |
model_input_name = self.processor.model_input_names[0]
|
301 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
302 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
303 |
-
lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
|
304 |
|
305 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
306 |
|
@@ -313,15 +314,19 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
313 |
# cut bos token here as it's append later anyways
|
314 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
315 |
labels = labels[:, 1:]
|
316 |
-
lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
317 |
-
# Replace language and task if they are in the beginning, otherwise add them
|
318 |
-
if (labels[:, 1] == self.task_id).all().cpu().item():
|
319 |
-
|
320 |
-
|
321 |
-
else:
|
322 |
-
|
323 |
-
|
324 |
-
|
|
|
|
|
|
|
|
|
325 |
|
326 |
batch["labels"] = labels
|
327 |
|
@@ -358,30 +363,54 @@ def notify_me(recipient, message=None):
|
|
358 |
smtp_obj.quit()
|
359 |
|
360 |
|
361 |
-
def load_maybe_streaming_dataset(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
"""
|
363 |
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
364 |
each split is loaded individually and then splits combined by taking alternating examples from
|
365 |
each (interleaving).
|
366 |
"""
|
367 |
-
|
368 |
-
if "
|
369 |
-
|
|
|
370 |
|
371 |
if "," in dataset_names or "+" in split:
|
372 |
# load multiple splits separated by the `+` symbol with streaming mode
|
373 |
dataset_splits = []
|
374 |
-
for dataset_name, dataset_config_name, split_names
|
375 |
-
dataset_names.split(","), dataset_config_names.split(","), split.split(",")
|
376 |
):
|
377 |
for split_name in split_names.split("+"):
|
378 |
-
|
|
|
|
|
|
|
379 |
raw_datasets_features = list(dataset.features.keys())
|
380 |
-
if
|
381 |
-
if len(
|
382 |
raise ValueError("Column name not found in dataset.")
|
383 |
-
|
384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
dataset_splits.append(dataset)
|
386 |
|
387 |
# interleave multiple splits to form one dataset
|
@@ -460,6 +489,14 @@ def main():
|
|
460 |
# Set seed before initializing model.
|
461 |
set_seed(training_args.seed)
|
462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
# 4. Load dataset
|
464 |
logger.info("*** Load dataset ***")
|
465 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
@@ -471,8 +508,10 @@ def main():
|
|
471 |
split=data_args.train_split_name,
|
472 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
473 |
streaming=data_args.streaming,
|
474 |
-
|
475 |
-
|
|
|
|
|
476 |
)
|
477 |
|
478 |
if training_args.do_eval:
|
@@ -482,7 +521,10 @@ def main():
|
|
482 |
split=data_args.eval_split_name,
|
483 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
484 |
streaming=data_args.streaming,
|
485 |
-
|
|
|
|
|
|
|
486 |
)
|
487 |
|
488 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
@@ -518,12 +560,6 @@ def main():
|
|
518 |
if training_args.gradient_checkpointing:
|
519 |
config.update({"use_cache": False})
|
520 |
|
521 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
522 |
-
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
523 |
-
cache_dir=model_args.cache_dir,
|
524 |
-
revision=model_args.model_revision,
|
525 |
-
use_auth_token=hf_token if model_args.use_auth_token else None,
|
526 |
-
)
|
527 |
tokenizer = AutoTokenizer.from_pretrained(
|
528 |
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
529 |
cache_dir=model_args.cache_dir,
|
@@ -548,21 +584,19 @@ def main():
|
|
548 |
if model_args.freeze_encoder:
|
549 |
model.freeze_encoder()
|
550 |
|
551 |
-
if data_args.
|
552 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
553 |
# If more than a langugae is specified, it will be specified in the data collator
|
554 |
-
tokenizer.set_prefix_tokens(language=data_args.
|
555 |
-
elif data_args.
|
556 |
# make sure language and task are not stored in the model config
|
557 |
model.config.forced_decoder_ids = None
|
558 |
|
559 |
# 6. Resample speech dataset if necessary
|
560 |
-
logger.info("*** Resample dataset ***")
|
561 |
-
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
562 |
-
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
563 |
-
|
564 |
-
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
565 |
-
)
|
566 |
|
567 |
# 7. Preprocessing the datasets.
|
568 |
# We need to read the audio files as arrays and tokenize the targets.
|
@@ -606,7 +640,7 @@ def main():
|
|
606 |
return batch
|
607 |
|
608 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
609 |
-
raw_datasets_features.remove("language")
|
610 |
vectorized_datasets = raw_datasets.map(
|
611 |
prepare_dataset,
|
612 |
remove_columns=raw_datasets_features,
|
@@ -765,8 +799,8 @@ def main():
|
|
765 |
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
766 |
else:
|
767 |
kwargs["dataset"] = data_args.dataset_name
|
768 |
-
if "common_voice" in data_args.dataset_name:
|
769 |
-
|
770 |
if model_args.model_index_name is not None:
|
771 |
kwargs["model_name"] = model_args.model_index_name
|
772 |
|
|
|
87 |
with open("/root/.huggingface/token", "w") as f:
|
88 |
f.write(hf_token)
|
89 |
logger.info("Huggingface API key set")
|
90 |
+
except (PermissionError, OSError):
|
91 |
logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
|
92 |
else:
|
93 |
logger.warning("Huggingface API key not set, relying on ~/.huggingface/token")
|
94 |
|
95 |
+
# wandb.login(key=wandb_token, relogin=True, timeout=5)
|
96 |
+
# wandb.init(project="whisper", entity="pn-aa")
|
97 |
|
98 |
logger.info("Wandb API key set, logging to wandb")
|
99 |
|
100 |
+
|
101 |
@dataclass
|
102 |
class ModelArguments:
|
103 |
"""
|
|
|
301 |
model_input_name = self.processor.model_input_names[0]
|
302 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
303 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
304 |
+
# lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
|
305 |
|
306 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
307 |
|
|
|
314 |
# cut bos token here as it's append later anyways
|
315 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
316 |
labels = labels[:, 1:]
|
317 |
+
# lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
318 |
+
# # Replace language and task if they are in the beginning, otherwise add them
|
319 |
+
# if (labels[:, 1] == self.task_id).all().cpu().item():
|
320 |
+
# labels[:, 0] = lang_token_ids
|
321 |
+
# labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
|
322 |
+
# else:
|
323 |
+
# # convert task id to tensor of labels dim to concatenate
|
324 |
+
# task_id = torch.full_like(labels[:, 0], self.task_id)
|
325 |
+
# labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
|
326 |
+
|
327 |
+
# Set language and task to pad token
|
328 |
+
labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
329 |
+
labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
330 |
|
331 |
batch["labels"] = labels
|
332 |
|
|
|
363 |
smtp_obj.quit()
|
364 |
|
365 |
|
366 |
+
def load_maybe_streaming_dataset(
|
367 |
+
dataset_names,
|
368 |
+
dataset_config_names,
|
369 |
+
split="train",
|
370 |
+
streaming=True,
|
371 |
+
audio_column_name=None,
|
372 |
+
sampling_rate=None,
|
373 |
+
**kwargs
|
374 |
+
):
|
375 |
"""
|
376 |
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
377 |
each split is loaded individually and then splits combined by taking alternating examples from
|
378 |
each (interleaving).
|
379 |
"""
|
380 |
+
text_column_names = None
|
381 |
+
if "text_column_name" in kwargs:
|
382 |
+
text_column_names = kwargs.pop("text_column_name").split(",")
|
383 |
+
text_col_name_ref = text_column_names[0]
|
384 |
|
385 |
if "," in dataset_names or "+" in split:
|
386 |
# load multiple splits separated by the `+` symbol with streaming mode
|
387 |
dataset_splits = []
|
388 |
+
for dataset_name, dataset_config_name, split_names in zip(
|
389 |
+
dataset_names.split(","), dataset_config_names.split(","), split.split(",")
|
390 |
):
|
391 |
for split_name in split_names.split("+"):
|
392 |
+
if dataset_config_name:
|
393 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
394 |
+
else:
|
395 |
+
dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
|
396 |
raw_datasets_features = list(dataset.features.keys())
|
397 |
+
if text_col_name_ref not in raw_datasets_features:
|
398 |
+
if len(text_column_names) == 1:
|
399 |
raise ValueError("Column name not found in dataset.")
|
400 |
+
flag = False
|
401 |
+
for text_column_name in text_column_names:
|
402 |
+
if text_column_name in raw_datasets_features:
|
403 |
+
dataset = dataset.rename_column(text_column_name, text_col_name_ref)
|
404 |
+
flag = True
|
405 |
+
break
|
406 |
+
if flag is False:
|
407 |
+
raise ValueError("None of the text column names provided found in dataset."
|
408 |
+
f"Text columns: {text_column_names}"
|
409 |
+
f"Dataset columns: {raw_datasets_features}")
|
410 |
+
if audio_column_name is not None and sampling_rate is not None:
|
411 |
+
dataset = dataset.cast_column(
|
412 |
+
audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
|
413 |
+
)
|
414 |
dataset_splits.append(dataset)
|
415 |
|
416 |
# interleave multiple splits to form one dataset
|
|
|
489 |
# Set seed before initializing model.
|
490 |
set_seed(training_args.seed)
|
491 |
|
492 |
+
# Load feature extractor
|
493 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
494 |
+
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
495 |
+
cache_dir=model_args.cache_dir,
|
496 |
+
revision=model_args.model_revision,
|
497 |
+
use_auth_token=hf_token if model_args.use_auth_token else None,
|
498 |
+
)
|
499 |
+
|
500 |
# 4. Load dataset
|
501 |
logger.info("*** Load dataset ***")
|
502 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
|
|
508 |
split=data_args.train_split_name,
|
509 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
510 |
streaming=data_args.streaming,
|
511 |
+
text_column_name=data_args.text_column_name,
|
512 |
+
audio_column_name=data_args.audio_column_name,
|
513 |
+
sampling_rate=feature_extractor.sampling_rate,
|
514 |
+
# language=data_args.language_train
|
515 |
)
|
516 |
|
517 |
if training_args.do_eval:
|
|
|
521 |
split=data_args.eval_split_name,
|
522 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
523 |
streaming=data_args.streaming,
|
524 |
+
text_column_name=data_args.text_column_name,
|
525 |
+
audio_column_name=data_args.audio_column_name,
|
526 |
+
sampling_rate=feature_extractor.sampling_rate,
|
527 |
+
# language=data_args.language_eval
|
528 |
)
|
529 |
|
530 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
|
|
560 |
if training_args.gradient_checkpointing:
|
561 |
config.update({"use_cache": False})
|
562 |
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
tokenizer = AutoTokenizer.from_pretrained(
|
564 |
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
565 |
cache_dir=model_args.cache_dir,
|
|
|
584 |
if model_args.freeze_encoder:
|
585 |
model.freeze_encoder()
|
586 |
|
587 |
+
if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
|
588 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
589 |
# If more than a langugae is specified, it will be specified in the data collator
|
590 |
+
tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
|
591 |
+
elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
|
592 |
# make sure language and task are not stored in the model config
|
593 |
model.config.forced_decoder_ids = None
|
594 |
|
595 |
# 6. Resample speech dataset if necessary
|
596 |
+
# logger.info("*** Resample dataset ***")
|
597 |
+
# dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
598 |
+
# if dataset_sampling_rate != feature_extractor.sampling_rate:
|
599 |
+
|
|
|
|
|
600 |
|
601 |
# 7. Preprocessing the datasets.
|
602 |
# We need to read the audio files as arrays and tokenize the targets.
|
|
|
640 |
return batch
|
641 |
|
642 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
643 |
+
# raw_datasets_features.remove("language")
|
644 |
vectorized_datasets = raw_datasets.map(
|
645 |
prepare_dataset,
|
646 |
remove_columns=raw_datasets_features,
|
|
|
799 |
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
800 |
else:
|
801 |
kwargs["dataset"] = data_args.dataset_name
|
802 |
+
# if "common_voice" in data_args.dataset_name:
|
803 |
+
# kwargs["language"] = data_args.dataset_config_name[:2]
|
804 |
if model_args.model_index_name is not None:
|
805 |
kwargs["model_name"] = model_args.model_index_name
|
806 |
|
test_run_nordic.sh
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
python $1run_speech_recognition_seq2seq_streaming.py \
|
2 |
--model_name_or_path="openai/whisper-tiny" \
|
3 |
--dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
|
4 |
-
--dataset_train_config_name="sv-SE,da,nn-NO
|
5 |
--language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
|
6 |
--train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
|
7 |
--dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
|
@@ -25,7 +25,7 @@ python $1run_speech_recognition_seq2seq_streaming.py \
|
|
25 |
--generation_max_length="225" \
|
26 |
--length_column_name="input_length" \
|
27 |
--max_duration_in_seconds="30" \
|
28 |
-
--text_column_name="sentence,text" \
|
29 |
--freeze_feature_encoder="False" \
|
30 |
--report_to="wandb" \
|
31 |
--metric_for_best_model="wer" \
|
|
|
1 |
python $1run_speech_recognition_seq2seq_streaming.py \
|
2 |
--model_name_or_path="openai/whisper-tiny" \
|
3 |
--dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
|
4 |
+
--dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
|
5 |
--language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
|
6 |
--train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
|
7 |
--dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
|
|
|
25 |
--generation_max_length="225" \
|
26 |
--length_column_name="input_length" \
|
27 |
--max_duration_in_seconds="30" \
|
28 |
+
--text_column_name="sentence,text,raw_transcription" \
|
29 |
--freeze_feature_encoder="False" \
|
30 |
--report_to="wandb" \
|
31 |
--metric_for_best_model="wer" \
|
test_run_nordic_cv.sh
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python $1run_speech_recognition_seq2seq_streaming.py \
|
2 |
+
--model_name_or_path="openai/whisper-tiny" \
|
3 |
+
--dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
|
4 |
+
--dataset_train_config_name="sv-SE,da,nn-NO" \
|
5 |
+
--language_train="swedish,danish,norwegian" \
|
6 |
+
--train_split_name="train+validation,train+validation,train+validation" \
|
7 |
+
--dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
|
8 |
+
--dataset_eval_config_name="sv-SE,da,nn-NO" \
|
9 |
+
--language_eval="swedish,danish,norwegian" \
|
10 |
+
--eval_split_name="test" \
|
11 |
+
--model_index_name="Whisper Tiny Swedish" \
|
12 |
+
--max_train_samples="64" \
|
13 |
+
--max_eval_samples="32" \
|
14 |
+
--max_steps="500" \
|
15 |
+
--output_dir="./" \
|
16 |
+
--per_device_train_batch_size="8" \
|
17 |
+
--per_device_eval_batch_size="4" \
|
18 |
+
--logging_steps="25" \
|
19 |
+
--learning_rate="1e-5" \
|
20 |
+
--warmup_steps="500" \
|
21 |
+
--evaluation_strategy="steps" \
|
22 |
+
--eval_steps="1000" \
|
23 |
+
--save_strategy="steps" \
|
24 |
+
--save_steps="1000" \
|
25 |
+
--generation_max_length="225" \
|
26 |
+
--length_column_name="input_length" \
|
27 |
+
--max_duration_in_seconds="30" \
|
28 |
+
--text_column_name="sentence,text" \
|
29 |
+
--freeze_feature_encoder="False" \
|
30 |
+
--metric_for_best_model="wer" \
|
31 |
+
--greater_is_better="False" \
|
32 |
+
--load_best_model_at_end \
|
33 |
+
--gradient_checkpointing \
|
34 |
+
--overwrite_output_dir \
|
35 |
+
--do_train \
|
36 |
+
--do_eval \
|
37 |
+
--predict_with_generate \
|
38 |
+
--do_normalize_eval \
|
39 |
+
--streaming \
|
40 |
+
--use_auth_token \
|
41 |
+
--push_to_hub
|