marinone94 committed on
Commit
0713f7f
1 Parent(s): c4280a8

fix tensor dim

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -334,9 +334,9 @@ class DataCollatorSpeechSeq2SeqWithPadding:
334
  labels = labels[:, 1:]
335
 
336
  # add start of sentence token to labels + language + task
337
- labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=0)
338
- labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=0)
339
- labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=0)
340
 
341
  batch["labels"] = labels
342
 
 
334
  labels = labels[:, 1:]
335
 
336
  # add start of sentence token to labels + language + task
337
+ labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
338
+ labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
339
+ labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
340
 
341
  batch["labels"] = labels
342