Model training
Hello Sanchit,
I'm here again :) Now trying to add BART into the model.
Right now my model lingers around ~14% WER, and I find that the most common errors come from long sentences, where the model tends to skip long segments when generating the transcript.
May I ask if you have any idea about that?
Thanks!
Hey @mutiann !
In my experience, running a wandb sweep over the generation hyper-parameters (num_beams, length_penalty) works well! This blog explains generation very succinctly: https://huggingface.co/blog/how-to-generate
Sweeping over length_penalty: https://wandb.ai/sanchit-gandhi/tedlium/sweeps/186lxl40?workspace=user-sanchit-gandhi
You can find the sweep config here: https://wandb.ai/sanchit-gandhi/tedlium/sweeps/186lxl40/overview?workspace=user-sanchit-gandhi
You can play around with the sweep config to sweep over num_beams as well, or fix it to a value > 5 for adequate results.
Another example sweep! One that I'm currently running for LibriSpeech: https://wandb.ai/sanchit-gandhi/librispeech-960h-dev/sweeps/po1rn2fe?workspace=user-sanchit-gandhi
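For orientation, a sweep over just these two generation hyper-parameters can be defined as a plain Python dict and launched with the wandb API. This is a minimal sketch, not the exact config behind the links above: the project name, metric name, value ranges and the run_eval helper are all assumptions.

```python
import wandb

# Hypothetical sweep config over the generation hyper-parameters only.
# Metric name, values and project are illustrative, not the exact config linked above.
sweep_config = {
    "method": "grid",
    "metric": {"name": "eval/wer", "goal": "minimize"},
    "parameters": {
        "length_penalty": {"values": [0.6, 0.8, 1.0, 1.2, 1.4, 1.6]},
        "num_beams": {"values": [1, 5, 8, 12]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="tedlium-dev")  # placeholder project name


def evaluate():
    # Each agent run reads the sampled hyper-parameters from wandb.config,
    # generates on the dev set and logs the resulting WER.
    with wandb.init() as run:
        cfg = run.config
        # run_eval is a user-supplied (hypothetical) function that decodes the
        # dev set with the given generation settings and returns the WER.
        wer = run_eval(num_beams=cfg.num_beams, length_penalty=cfg.length_penalty)
        wandb.log({"eval/wer": wer})


wandb.agent(sweep_id, function=evaluate)
```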
Thank you so much! Unfortunately, the optimal generation hyper-parameters only give a ~0.1% improvement :(
I tried to mimic your training recipe and could only reach 10.1% WER at 50k steps. Continuing training helps it reach 9.5% at 100k, and by normalizing the loss not across utterances but across tokens (i.e. longer utterances getting higher weight) it reaches 10.0% at 50k and 8.5% at 100k. But I'm sure you could reach better results by doing the same with your models, so that's not the point. What surprises me is your WER curve (around 18% from 10k to 40k steps, and then suddenly improved to 9.0% at 50k steps), while mine doesn't behave like that (it's around 12% at 40k steps). So I guess there must be some special trick :) Do you have any idea?
Thanks in advance!
Hey @mutiann , replying in-line to your comments!
Unfortunately, the optimal generation hyper-parameters only give a ~0.1% improvement :(
That suggests the model is indeed under-trained!
by normalizing the loss not across utterances but across tokens (i.e. longer utterances getting higher weight) it reaches 10.0% at 50k and 8.5% at 100k
In the training loop, we normalise the CE loss by the number of non-padded tokens: https://github.com/sanchit-gandhi/seq2seq-speech/blob/cd25022f24677835f8bd26f528e0e444b80702c6/run_flax_speech_recognition_seq2seq.py#L1265-L1267
Therefore, the loss is weighted equally on a token-by-token basis. Did you weight longer utterances more heavily?
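For clarity, here is a minimal sketch of what that token-level normalisation looks like in JAX/Flax; it is the idea only, not the exact code from the linked lines, and the function name and argument shapes are assumptions:

```python
import jax
import jax.numpy as jnp
import optax


def loss_fn(logits, labels, label_padding_mask):
    """Cross-entropy normalised by the number of non-padded tokens (sketch only).

    logits: (batch, seq_len, vocab)
    labels: (batch, seq_len) integer token ids
    label_padding_mask: (batch, seq_len), 1.0 for real tokens, 0.0 for padding
    """
    per_token_loss = optax.softmax_cross_entropy(
        logits, jax.nn.one_hot(labels, logits.shape[-1])
    )
    # Zero out the contribution of padding positions.
    per_token_loss = per_token_loss * label_padding_mask
    # Dividing by the total non-padded token count (rather than by the number of
    # utterances) weights every token equally, so longer utterances contribute more.
    return per_token_loss.sum() / label_padding_mask.sum()
```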
around 18% from 10k to 40k steps, and then suddenly improved to 9.0% at 50k steps
With regards to the sharp drop-off, this is because for the first 40k steps I evaluate with greedy, and then for the last evaluation/prediction steps I evaluate with beam search!
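For reference, the difference at generation time is roughly the following. This is a hedged sketch: the checkpoint path is a placeholder, the dummy audio sample is only for illustration, and the max_length/num_beams/length_penalty values are illustrative rather than the settings used in the runs above.

```python
from datasets import load_dataset
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel

# Placeholder checkpoint path; substitute your own Wav2Vec2-2-BART model.
checkpoint = "path/to/wav2vec2-2-bart-tedlium"
model = SpeechEncoderDecoderModel.from_pretrained(checkpoint)
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# A dummy audio sample purely for illustration.
sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]
inputs = feature_extractor(sample["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

# Greedy decoding: what the intermediate eval steps used.
greedy_ids = model.generate(inputs.input_values, max_length=225)

# Beam search: what the final evaluation/prediction step used.
beam_ids = model.generate(
    inputs.input_values,
    max_length=225,
    num_beams=8,          # fixing num_beams > 5, as suggested above
    length_penalty=1.2,   # illustrative value; worth sweeping
)

print(tokenizer.batch_decode(beam_ids, skip_special_tokens=True))
```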
As you correctly pointed out on https://github.com/sanchit-gandhi/seq2seq-speech/issues/77, we should not filter the eval/test samples by our audio/text length criterion! I re-ran training without this filtering step, and evaluated using beam search at each evaluation step (so no sharp drop-off). The results are here: https://wandb.ai/sanchit-gandhi/tedlium/runs/2pefefil?workspace=user-sanchit-gandhi
I achieved an eval WER of 12.1% after 50k train steps, and a test WER of 7.0%. Optimising the generation HPs, I get an eval WER of 11.8% and a test WER of 7.0% https://wandb.ai/sanchit-gandhi/tedlium-dev/sweeps/mqmy3sfo?workspace=user-sanchit-gandhi
Are your quoted training results with or without the eval sample filtering?
Thank you for your reply!
I've been evaluating models without filtering. An eval WER of around 12% roughly matches my results, so actually there isn't any gap and we're on the same ground :) The only mystery (or challenge) is the gap between BART and CTC now.
As for normalization, I previously normalized the loss by utterance, so longer utterances were actually underweighted. This doesn't make any difference up to 50k steps, but the curves diverge after that, leading to a 1.5% WER gap later compared to normalizing by tokens as we both do now. Maybe weighting longer utterances even more heavily would help further, since most errors come from skipping within longer utterances.
The only mystery (or challenge) is the gap between BART and CTC now.
The Seq2Seq model (Wav2Vec2-2-BART) is able to leverage the knowledge of the target text domain gained by the BART model during pre-training (massive amounts of text data!). This greatly helps it with phonetic error corrections (foebe -> Phoebe). CTC is still very much susceptible to these errors, basing its predictions almost entirely on the audio inputs.
You can boost your CTC results by incorporating a language model (LM) and performing LM-boosted beam search. In doing so, you can correct phonetic errors and get a 25-30% reduction in WER, bringing CTC close to (if not better than) Seq2Seq performance! Refer to this (fantastic) blog for a tutorial on n-grams with Transformers (a rough decoding sketch also follows the links below): https://huggingface.co/blog/wav2vec2-with-ngram
You can train an n-gram LM using this PR: https://github.com/sanchit-gandhi/seq2seq-speech/pull/83/files
You can evaluate with an n-gram using this script: https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_flax_speech_recognition_ctc_ngram.py
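As a rough illustration (not the exact evaluation script above), LM-boosted CTC decoding with pyctcdecode looks roughly like this; the public CTC checkpoint, the KenLM path and the dummy audio sample are placeholders/assumptions:

```python
import torch
from datasets import load_dataset
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Any CTC checkpoint works for illustration; the n-gram path is a placeholder.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Order the vocabulary by token id so it lines up with the logit dimensions.
vocab = processor.tokenizer.get_vocab()
labels = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]

# Beam-search decoder over the CTC vocabulary, boosted by a KenLM n-gram
# (e.g. one trained with the PR linked above).
decoder = build_ctcdecoder(labels, kenlm_model_path="path/to/5gram.arpa")

# A dummy audio sample purely for illustration.
sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]
inputs = processor(sample["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits[0].numpy()

# The n-gram re-scores candidate beams during decoding, correcting phonetic
# errors that acoustic-only (greedy) CTC decoding leaves in.
print(decoder.decode(logits))
```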
All the best!
Sanchit
Yes indeed, the LM definitely helps, so we would expect BART to usually improve performance as well. The mystery I mentioned is why, in both of our experiments, the Wav2Vec2 + BART models have higher WER than Wav2Vec2 + CTC. BART seems to be very unstable when generating long transcripts and often skips part of the transcript.
Thanks a lot!
Ah I see! Yes, that's an interesting one - removing the filtering constraint on the eval dataset increased my Seq2Seq eval WER by several percentage points. Interesting to hear that you've found Seq2Seq to fare poorly on long transcripts; this aligns with the numbers I got from evaluation! The BART decoder must be weighted quite heavily in generation and try to 'correct' these long transcriptions. Curious to hear if you've looked into remedying this problem! I've tried to tackle it with the generation HPs: https://wandb.ai/sanchit-gandhi/tedlium-dev/sweeps/p5e3l8rw?workspace=user-sanchit-gandhi
Still a way off CTC on the eval set!
I've actually found Seq2Seq to outperform vanilla CTC on the test set - 7.0% vs 8.2%. CTC + n-gram will almost certainly close the gap.
I don't have many ideas for tackling the issue at the moment... Of course there are many tricks that could be tried, but I doubt they are generally applicable. Anyway, I'm trying to investigate whether there is a better approach.