marinone94
commited on
Commit
•
fa4a655
1
Parent(s):
bb38ae2
hard code language
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -325,7 +325,9 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
325 |
# task_id = torch.full_like(labels[:, 0], self.task_id)
|
326 |
# labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
|
327 |
|
328 |
-
# Set language
|
|
|
|
|
329 |
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
330 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
331 |
|
@@ -641,7 +643,7 @@ def main():
|
|
641 |
if model_args.freeze_encoder:
|
642 |
model.freeze_encoder()
|
643 |
|
644 |
-
tokenizer.set_prefix_tokens(language=
|
645 |
|
646 |
# if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
|
647 |
# # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
|
|
325 |
# task_id = torch.full_like(labels[:, 0], self.task_id)
|
326 |
# labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
|
327 |
|
328 |
+
# Set language to pad token
|
329 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
330 |
+
labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
331 |
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
332 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
333 |
|
|
|
643 |
if model_args.freeze_encoder:
|
644 |
model.freeze_encoder()
|
645 |
|
646 |
+
tokenizer.set_prefix_tokens(language="swedish", task=data_args.task)
|
647 |
|
648 |
# if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
|
649 |
# # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|