marinone94
committed on
Commit
•
aac3fbb
1
Parent(s):
9e467b3
use decode to inspect ds
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -294,6 +294,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
294 |
processor: Any
|
295 |
decoder_start_token_id: int
|
296 |
task_id: int
|
|
|
|
|
297 |
|
298 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
299 |
# split inputs and labels since they have to be of different lengths and need
|
@@ -312,8 +314,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
312 |
|
313 |
# if bos token is appended in previous tokenization step,
|
314 |
# cut bos token here as it's appended later anyway
|
315 |
-
|
316 |
-
# labels = labels[:, 1:]
|
317 |
# lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
318 |
# # Replace language and task if they are in the beginning, otherwise add them
|
319 |
# if (labels[:, 1] == self.task_id).all().cpu().item():
|
@@ -328,6 +329,15 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
328 |
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
329 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
batch["labels"] = labels
|
332 |
|
333 |
return batch
|
@@ -461,7 +471,7 @@ def load_maybe_streaming_dataset(
|
|
461 |
def print_data_samples(dataset, tokenizer, max_samples=5):
|
462 |
shown_samples = 0
|
463 |
for batch in dataset:
|
464 |
-
print("Target: ", tokenizer.
|
465 |
shown_samples += len(batch)
|
466 |
if shown_samples >= max_samples:
|
467 |
break
|
|
|
294 |
processor: Any
|
295 |
decoder_start_token_id: int
|
296 |
task_id: int
|
297 |
+
# TODO: remove - infer language from dataset
|
298 |
+
language_id: int = -100
|
299 |
|
300 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
301 |
# split inputs and labels since they have to be of different lengths and need
|
|
|
314 |
|
315 |
# if bos token is appended in previous tokenization step,
|
316 |
# cut bos token here as it's appended later anyway
|
317 |
+
|
|
|
318 |
# lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
319 |
# # Replace language and task if they are in the beginning, otherwise add them
|
320 |
# if (labels[:, 1] == self.task_id).all().cpu().item():
|
|
|
329 |
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
330 |
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
331 |
|
332 |
+
# remove start of sentence token from labels
|
333 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
334 |
+
labels = labels[:, 1:]
|
335 |
+
|
336 |
+
# add start of sentence token to labels + language + task
|
337 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.task_id), labels), dim=1)
|
338 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.language_id), labels), dim=1)
|
339 |
+
labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id), labels), dim=1)
|
340 |
+
|
341 |
batch["labels"] = labels
|
342 |
|
343 |
return batch
|
|
|
471 |
def print_data_samples(dataset, tokenizer, max_samples=5):
    """Print decoded target transcriptions for the first few dataset samples.

    Quick sanity check that tokenization/label preparation produced
    sensible text before training starts.

    Args:
        dataset: iterable yielding dict-like examples with a "labels" field
            (a 1-D sequence of token ids; may contain ``-100`` ignore-index
            ids inserted by the data collator).
        tokenizer: object exposing ``decode(ids) -> str``.
        max_samples: stop after printing this many examples.

    Returns:
        None. Output goes to stdout.
    """
    shown_samples = 0
    for sample in dataset:
        # Labels may use -100 as the loss ignore-index (see the collator's
        # use of -100 fills) — real tokenizers cannot decode that id, so
        # drop such ids before decoding.
        label_ids = [tok for tok in sample["labels"] if tok != -100]
        print("Target: ", tokenizer.decode(label_ids))
        # Each iteration yields one example dict; len(sample) would count
        # the dict's KEYS, not examples, so count explicitly instead.
        shown_samples += 1
        if shown_samples >= max_samples:
            break
|