marinone94
committed on
Commit
•
6ea4d4a
1
Parent(s):
ee5b1b2
allow multiple languages and datasets
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -49,6 +49,7 @@ from transformers import (
|
|
49 |
set_seed,
|
50 |
)
|
51 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
|
|
52 |
from transformers.trainer_pt_utils import IterableDatasetShard
|
53 |
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
54 |
from transformers.utils import check_min_version, send_example_telemetry
|
@@ -61,6 +62,9 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
|
|
61 |
|
62 |
logger = logging.getLogger(__name__)
|
63 |
|
|
|
|
|
|
|
64 |
wandb_token = os.environ.get("WANDB_TOKEN", "None")
|
65 |
hf_token = os.environ.get("HF_TOKEN", None)
|
66 |
if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
|
@@ -160,10 +164,16 @@ class DataTrainingArguments:
|
|
160 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
161 |
"""
|
162 |
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
165 |
)
|
166 |
-
|
167 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
168 |
)
|
169 |
text_column: Optional[str] = field(
|
@@ -232,7 +242,16 @@ class DataTrainingArguments:
|
|
232 |
default=True,
|
233 |
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
234 |
)
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
default=None,
|
237 |
metadata={
|
238 |
"help": (
|
@@ -273,6 +292,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
273 |
|
274 |
processor: Any
|
275 |
decoder_start_token_id: int
|
|
|
276 |
|
277 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
278 |
# split inputs and labels since they have to be of different lengths and need
|
@@ -280,6 +300,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
280 |
model_input_name = self.processor.model_input_names[0]
|
281 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
282 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
|
|
283 |
|
284 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
285 |
|
@@ -292,6 +313,15 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
292 |
# cut bos token here as it's appended later anyway
|
293 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
294 |
labels = labels[:, 1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
|
296 |
batch["labels"] = labels
|
297 |
|
@@ -316,7 +346,7 @@ def notify_me(recipient, message=None):
|
|
316 |
from email.mime.text import MIMEText
|
317 |
|
318 |
msg = MIMEText(message)
|
319 |
-
msg["Subject"] = "Training
|
320 |
msg["From"] = "[email protected]"
|
321 |
msg["To"] = recipient
|
322 |
|
@@ -334,16 +364,26 @@ def load_maybe_streaming_dataset(dataset_names, dataset_config_names, split="tra
|
|
334 |
each split is loaded individually and then splits combined by taking alternating examples from
|
335 |
each (interleaving).
|
336 |
"""
|
|
|
|
|
|
|
|
|
337 |
if "," in dataset_names or "+" in split:
|
338 |
# load multiple splits separated by the `+` symbol with streaming mode
|
339 |
dataset_splits = []
|
340 |
-
for dataset_name, dataset_config_name, split_names in zip(
|
341 |
-
dataset_names.split(","), dataset_config_names.split(","), split.split(",")
|
342 |
):
|
343 |
for split_name in split_names.split("+"):
|
344 |
-
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
dataset_splits.append(dataset)
|
346 |
-
|
347 |
# interleave multiple splits to form one dataset
|
348 |
interleaved_dataset = interleave_datasets(dataset_splits)
|
349 |
return interleaved_dataset
|
@@ -426,20 +466,23 @@ def main():
|
|
426 |
|
427 |
if training_args.do_train:
|
428 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
429 |
-
data_args.
|
430 |
-
data_args.
|
431 |
split=data_args.train_split_name,
|
432 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
433 |
streaming=data_args.streaming,
|
|
|
|
|
434 |
)
|
435 |
|
436 |
if training_args.do_eval:
|
437 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
438 |
-
data_args.
|
439 |
-
data_args.
|
440 |
split=data_args.eval_split_name,
|
441 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
442 |
streaming=data_args.streaming,
|
|
|
443 |
)
|
444 |
|
445 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
@@ -451,6 +494,7 @@ def main():
|
|
451 |
f"{', '.join(raw_datasets_features)}."
|
452 |
)
|
453 |
|
|
|
454 |
if data_args.text_column_name not in raw_datasets_features:
|
455 |
raise ValueError(
|
456 |
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
@@ -504,9 +548,13 @@ def main():
|
|
504 |
if model_args.freeze_encoder:
|
505 |
model.freeze_encoder()
|
506 |
|
507 |
-
if data_args.language is not None:
|
508 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
|
|
509 |
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
|
|
|
|
|
|
510 |
|
511 |
# 6. Resample speech dataset if necessary
|
512 |
logger.info("*** Resample dataset ***")
|
@@ -558,6 +606,7 @@ def main():
|
|
558 |
return batch
|
559 |
|
560 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
|
|
561 |
vectorized_datasets = raw_datasets.map(
|
562 |
prepare_dataset,
|
563 |
remove_columns=raw_datasets_features,
|
@@ -617,9 +666,14 @@ def main():
|
|
617 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
618 |
|
619 |
# 10. Define data collator
|
|
|
|
|
|
|
|
|
620 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
621 |
processor=processor,
|
622 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
|
|
623 |
)
|
624 |
|
625 |
# 11. Configure Trainer
|
@@ -716,20 +770,24 @@ def main():
|
|
716 |
if model_args.model_index_name is not None:
|
717 |
kwargs["model_name"] = model_args.model_index_name
|
718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
719 |
if training_args.push_to_hub:
|
720 |
logger.info("*** Pushing to hub ***")
|
721 |
trainer.push_to_hub(**kwargs)
|
722 |
logger.info("*** Pushed to hub ***")
|
|
|
|
|
723 |
else:
|
724 |
logger.info("*** Creating model card ***")
|
725 |
trainer.create_model_card(**kwargs)
|
726 |
logger.info("*** Model card created ***")
|
727 |
-
|
728 |
-
|
729 |
-
logger.info("*** Sending notification ***")
|
730 |
-
notify_me(recipient="[email protected]", message=json.dumps(kwargs, indent=4))
|
731 |
-
|
732 |
-
logger.info("*** Training complete!!! ***")
|
733 |
|
734 |
return results
|
735 |
|
|
|
49 |
set_seed,
|
50 |
)
|
51 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
52 |
+
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
|
53 |
from transformers.trainer_pt_utils import IterableDatasetShard
|
54 |
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
55 |
from transformers.utils import check_min_version, send_example_telemetry
|
|
|
62 |
|
63 |
logger = logging.getLogger(__name__)
|
64 |
|
65 |
+
SENDING_NOTIFICATION = "*** Sending notification to email ***"
|
66 |
+
RECIPIENT_ADDRESS = "[email protected]"
|
67 |
+
|
68 |
wandb_token = os.environ.get("WANDB_TOKEN", "None")
|
69 |
hf_token = os.environ.get("HF_TOKEN", None)
|
70 |
if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
|
|
|
164 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
165 |
"""
|
166 |
|
167 |
+
dataset_train_name: str = field(
|
168 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
169 |
+
)
|
170 |
+
dataset_train_config_name: Optional[str] = field(
|
171 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
172 |
+
)
|
173 |
+
dataset_eval_name: str = field(
|
174 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
175 |
)
|
176 |
+
dataset_eval_config_name: Optional[str] = field(
|
177 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
178 |
)
|
179 |
text_column: Optional[str] = field(
|
|
|
242 |
default=True,
|
243 |
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
244 |
)
|
245 |
+
language_train: str = field(
|
246 |
+
default=None,
|
247 |
+
metadata={
|
248 |
+
"help": (
|
249 |
+
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
|
250 |
+
"only. For English speech recognition, it should be set to `None`."
|
251 |
+
)
|
252 |
+
},
|
253 |
+
)
|
254 |
+
language_eval: str = field(
|
255 |
default=None,
|
256 |
metadata={
|
257 |
"help": (
|
|
|
292 |
|
293 |
processor: Any
|
294 |
decoder_start_token_id: int
|
295 |
+
task_id: int
|
296 |
|
297 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
298 |
# split inputs and labels since they have to be of different lengths and need
|
|
|
300 |
model_input_name = self.processor.model_input_names[0]
|
301 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
302 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
303 |
+
lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
|
304 |
|
305 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
306 |
|
|
|
313 |
# cut bos token here as it's appended later anyway
|
314 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
315 |
labels = labels[:, 1:]
|
316 |
+
lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
317 |
+
# Replace the language and task tokens if they are already at the beginning, otherwise prepend them
|
318 |
+
if (labels[:, 1] == self.task_id).all().cpu().item():
|
319 |
+
labels[:, 0] = lang_token_ids
|
320 |
+
labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
|
321 |
+
else:
|
322 |
+
# convert task id to tensor of labels dim to concatenate
|
323 |
+
task_id = torch.full_like(labels[:, 0], self.task_id)
|
324 |
+
labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
|
325 |
|
326 |
batch["labels"] = labels
|
327 |
|
|
|
346 |
from email.mime.text import MIMEText
|
347 |
|
348 |
msg = MIMEText(message)
|
349 |
+
msg["Subject"] = "Training updates..."
|
350 |
msg["From"] = "[email protected]"
|
351 |
msg["To"] = recipient
|
352 |
|
|
|
364 |
each split is loaded individually and then splits combined by taking alternating examples from
|
365 |
each (interleaving).
|
366 |
"""
|
367 |
+
column_names = None
|
368 |
+
if "column_names" in kwargs:
|
369 |
+
column_names = kwargs.pop("column_names").split(",")
|
370 |
+
|
371 |
if "," in dataset_names or "+" in split:
|
372 |
# load multiple splits separated by the `+` symbol with streaming mode
|
373 |
dataset_splits = []
|
374 |
+
for dataset_name, dataset_config_name, split_names, lang in zip(
|
375 |
+
dataset_names.split(","), dataset_config_names.split(","), split.split(","), kwargs.pop("language").split(",")
|
376 |
):
|
377 |
for split_name in split_names.split("+"):
|
378 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
379 |
+
raw_datasets_features = list(next(iter(dataset.values())).features.keys())
|
380 |
+
if column_names[0] not in raw_datasets_features:
|
381 |
+
if len(column_names) == 1 or column_names[1] not in raw_datasets_features:
|
382 |
+
raise ValueError("Column name not found in dataset.")
|
383 |
+
dataset = dataset.rename_columns(column_names[1], column_names[0])
|
384 |
+
dataset["language"] = lang
|
385 |
dataset_splits.append(dataset)
|
386 |
+
|
387 |
# interleave multiple splits to form one dataset
|
388 |
interleaved_dataset = interleave_datasets(dataset_splits)
|
389 |
return interleaved_dataset
|
|
|
466 |
|
467 |
if training_args.do_train:
|
468 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
469 |
+
data_args.dataset_train_name,
|
470 |
+
data_args.dataset_train_config_name,
|
471 |
split=data_args.train_split_name,
|
472 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
473 |
streaming=data_args.streaming,
|
474 |
+
column_names=data_args.text_column_name,
|
475 |
+
language=data_args.language_train
|
476 |
)
|
477 |
|
478 |
if training_args.do_eval:
|
479 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
480 |
+
data_args.dataset_eval_name,
|
481 |
+
data_args.dataset_eval_config_name,
|
482 |
split=data_args.eval_split_name,
|
483 |
use_auth_token=hf_token if model_args.use_auth_token else None,
|
484 |
streaming=data_args.streaming,
|
485 |
+
language=data_args.language_eval
|
486 |
)
|
487 |
|
488 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
|
|
494 |
f"{', '.join(raw_datasets_features)}."
|
495 |
)
|
496 |
|
497 |
+
data_args.text_column_name = data_args.text_column_name.split(",")[0]
|
498 |
if data_args.text_column_name not in raw_datasets_features:
|
499 |
raise ValueError(
|
500 |
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
|
|
548 |
if model_args.freeze_encoder:
|
549 |
model.freeze_encoder()
|
550 |
|
551 |
+
if data_args.language is not None and len(data_args.language.split(",")) == 1:
|
552 |
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
553 |
+
# If more than one language is specified, the language token is set per example in the data collator
|
554 |
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
555 |
+
elif data_args.language is not None and len(data_args.language.split(",")) > 1:
|
556 |
+
# make sure language and task are not stored in the model config
|
557 |
+
model.config.forced_decoder_ids = None
|
558 |
|
559 |
# 6. Resample speech dataset if necessary
|
560 |
logger.info("*** Resample dataset ***")
|
|
|
606 |
return batch
|
607 |
|
608 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
609 |
+
raw_datasets_features.remove("language")
|
610 |
vectorized_datasets = raw_datasets.map(
|
611 |
prepare_dataset,
|
612 |
remove_columns=raw_datasets_features,
|
|
|
666 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
667 |
|
668 |
# 10. Define data collator
|
669 |
+
task_token = data_args.task
|
670 |
+
if not task_token.startswith('<|'):
|
671 |
+
task_token = f'<{task_token}>'
|
672 |
+
task_id = tokenizer(task_token).input_ids[0]
|
673 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
674 |
processor=processor,
|
675 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
676 |
+
task_id=task_id
|
677 |
)
|
678 |
|
679 |
# 11. Configure Trainer
|
|
|
770 |
if model_args.model_index_name is not None:
|
771 |
kwargs["model_name"] = model_args.model_index_name
|
772 |
|
773 |
+
# Training complete notification
|
774 |
+
logger.info(SENDING_NOTIFICATION)
|
775 |
+
notify_me(recipient=RECIPIENT_ADDRESS, message=json.dumps(kwargs, indent=4))
|
776 |
+
logger.info("*** Training complete!!! ***")
|
777 |
+
|
778 |
+
|
779 |
if training_args.push_to_hub:
|
780 |
logger.info("*** Pushing to hub ***")
|
781 |
trainer.push_to_hub(**kwargs)
|
782 |
logger.info("*** Pushed to hub ***")
|
783 |
+
logger.info(SENDING_NOTIFICATION)
|
784 |
+
notify_me(recipient=RECIPIENT_ADDRESS, message="Model pushed to hub")
|
785 |
else:
|
786 |
logger.info("*** Creating model card ***")
|
787 |
trainer.create_model_card(**kwargs)
|
788 |
logger.info("*** Model card created ***")
|
789 |
+
logger.info(SENDING_NOTIFICATION)
|
790 |
+
notify_me(recipient=RECIPIENT_ADDRESS, message="Model card created")
|
|
|
|
|
|
|
|
|
791 |
|
792 |
return results
|
793 |
|
test_run_nordic.sh
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
python $1run_speech_recognition_seq2seq_streaming.py \
|
2 |
--model_name_or_path="openai/whisper-tiny" \
|
3 |
-
--
|
4 |
-
--
|
5 |
-
--
|
6 |
-
--train_split_name="train+validation,train+validation,train+validation,train,train+test,
|
|
|
|
|
|
|
7 |
--eval_split_name="test" \
|
8 |
--model_index_name="Whisper Tiny Swedish" \
|
9 |
--max_train_samples="64" \
|
@@ -22,7 +25,7 @@ python $1run_speech_recognition_seq2seq_streaming.py \
|
|
22 |
--generation_max_length="225" \
|
23 |
--length_column_name="input_length" \
|
24 |
--max_duration_in_seconds="30" \
|
25 |
-
--text_column_name="sentence" \
|
26 |
--freeze_feature_encoder="False" \
|
27 |
--report_to="wandb" \
|
28 |
--metric_for_best_model="wer" \
|
|
|
1 |
python $1run_speech_recognition_seq2seq_streaming.py \
|
2 |
--model_name_or_path="openai/whisper-tiny" \
|
3 |
+
--dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
|
4 |
+
--dataset_train_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk,sv_se,da_dk,nb_no" \
|
5 |
+
--language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
|
6 |
+
--train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
|
7 |
+
--dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
|
8 |
+
--dataset_eval_config_name="sv-SE,da,nn-NO" \
|
9 |
+
--language_eval="swedish,danish,norwegian" \
|
10 |
--eval_split_name="test" \
|
11 |
--model_index_name="Whisper Tiny Swedish" \
|
12 |
--max_train_samples="64" \
|
|
|
25 |
--generation_max_length="225" \
|
26 |
--length_column_name="input_length" \
|
27 |
--max_duration_in_seconds="30" \
|
28 |
+
--text_column_name="sentence,text" \
|
29 |
--freeze_feature_encoder="False" \
|
30 |
--report_to="wandb" \
|
31 |
--metric_for_best_model="wer" \
|