Size mismatch error during training
Hey @taqwa92
The issue is with your target label sequences: some of them exceed the model's maximum generation length. These must be very long sequences, as the maximum generation length is 448 tokens, the longest sequence the model is configured to handle (model.config.max_length).
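You can verify this yourself. Here's a minimal sketch, assuming you're fine-tuning a Whisper checkpoint such as openai/whisper-small:

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
print(model.config.max_length)  # 448 for the Whisper checkpoints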
We've got two options here:
- Filter out any label sequences longer than the max length
- Increase the model's max length
What we can do is compute the label length of each target sequence:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute input length in samples
    batch["input_length"] = len(audio["array"])

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids

    # compute label length in tokens
    batch["labels_length"] = len(batch["labels"])
    return batch
And then filter those that exceed the model's maximum length:
MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000

def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length

max_label_length = model.config.max_length

def filter_labels(labels_length):
    """Filter label sequences longer than max length (448)"""
    return labels_length < max_label_length
You can then apply the prepare_dataset function and the two new filter functions to your dataset common_voice as follows:
# pre-process
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"])

# filter by audio length
common_voice = common_voice.filter(filter_inputs, input_columns=["input_length"], remove_columns=["input_length"])

# filter by label length
common_voice = common_voice.filter(filter_labels, input_columns=["labels_length"], remove_columns=["labels_length"])
That should pre-process the dataset and remove any label sequences that are too long for the model.
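As a quick sanity check (a minimal sketch, assuming your dataset has a "train" split and still contains the labels column), you can confirm the longest remaining label sequence fits within the limit:

# sanity check: the longest remaining label sequence should fit within max_label_length
longest_label = max(len(labels) for labels in common_voice["train"]["labels"])
print(longest_label, "<", max_label_length)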
Alternatively, we can change the model’s max length to any value we want:
model.config.max_length = 500
This updates the max length to 500 tokens. Make sure to do this before filtering so that the new value takes effect:
max_label_length = model.config.max_length = 500

def filter_labels(labels_length):
    """Filter label sequences longer than the new max length (500)"""
    return labels_length < max_label_length
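Then re-run the label filter so the new limit is applied (this assumes the labels_length column is still present in the dataset at this point):

# re-apply the filter with the increased limit
common_voice = common_voice.filter(filter_labels, input_columns=["labels_length"], remove_columns=["labels_length"])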
Hope that helps!
Thanks a lot, @sanchit-gandhi, it really helped me!