marinone94
commited on
Commit
•
bb38ae2
1
Parent(s):
0713f7f
reset to set prefix tokens
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -330,13 +330,13 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
330 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
331 |
|
332 |
# remove start of sentence token from labels
|
333 |
-
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
334 |
-
|
335 |
|
336 |
-
# add start of sentence token to labels + language + task
|
337 |
-
labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
|
338 |
-
labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
|
339 |
-
labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
|
340 |
|
341 |
batch["labels"] = labels
|
342 |
|
@@ -640,14 +640,16 @@ def main():
|
|
640 |
|
641 |
if model_args.freeze_encoder:
|
642 |
model.freeze_encoder()
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
|
|
|
|
651 |
|
652 |
# 6. Resample speech dataset if necessary
|
653 |
# logger.info("*** Resample dataset ***")
|
|
|
330 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
331 |
|
332 |
# remove start of sentence token from labels
|
333 |
+
# if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
334 |
+
# labels = labels[:, 1:]
|
335 |
|
336 |
+
# # add start of sentence token to labels + language + task
|
337 |
+
# labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
|
338 |
+
# labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
|
339 |
+
# labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
|
340 |
|
341 |
batch["labels"] = labels
|
342 |
|
|
|
640 |
|
641 |
if model_args.freeze_encoder:
|
642 |
model.freeze_encoder()
|
643 |
+
|
644 |
+
tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
|
645 |
+
|
646 |
+
# if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
|
647 |
+
# # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
648 |
+
# # If more than a langugae is specified, it will be specified in the data collator
|
649 |
+
# tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
|
650 |
+
# elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
|
651 |
+
# # make sure language and task are not stored in the model config
|
652 |
+
# model.config.forced_decoder_ids = None
|
653 |
|
654 |
# 6. Resample speech dataset if necessary
|
655 |
# logger.info("*** Resample dataset ***")
|