Pre-training MMS-300M with a new language

#2
by Tirthankar - opened

I can successfully pre-train patrickvonplaten/wav2vec2-base-v2 on an unlabeled WAV corpus in a new language by following the steps outlined in https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-pretraining on a single A100 GPU.

However, when I try the same with patrickvonplaten/mms-300 or facebook/wav2vec2-xls-r-300m, it fails at accelerator.backward(loss) with the following error:

RuntimeError: handle_0 INTERNAL ASSERT FAILED at "../c10/cuda/driver_api.cpp":15, please report a bug to PyTorch.

Any pointers towards a resolution would be great.

Hey @Tirthankar - could you please provide a full stack trace and a reproducible code snippet, along with your current environment details? These can be obtained by copying the output of the command:

transformers-cli env

That said, judging by the runtime error, it looks like a bug in the accelerate library and/or PyTorch, so IMO it's definitely worth opening issues there already to get some direct help!
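
Since the suspected culprits are accelerate and PyTorch, it may help to report their installed versions alongside the transformers-cli env output. The snippet below is a minimal sketch using only the Python standard library; the helper name collect_versions and the package list are illustrative assumptions, not part of any Hugging Face tooling.

```python
from importlib import metadata
import platform

# Hypothetical list of packages relevant to this report; adjust as needed.
PACKAGES = ("transformers", "accelerate", "torch", "datasets")


def collect_versions(packages=PACKAGES):
    """Return a dict mapping each package name to its installed version."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages explicitly instead of failing.
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```

The printed lines can be pasted directly into an issue on the accelerate or PyTorch tracker.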

Hi Sanchit - thanks for your reply. Below is a Jupyter notebook that runs the original wav2vec2_no_trainer.py cell by cell, modified to fit the current requirements. The last part of the transformers-cli env output is also pasted below. This is on a single A100 GPU with 40GB of memory. Pre-training wav2vec2 worked without problems, but mms-300m fails.

  • transformers version: 4.30.2
  • Platform: Linux-4.15.0-189-generic-x86_64-with-glibc2.27
  • Python version: 3.8.16
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.1.0.dev20230523+cu117 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

The notebook:
The error occurs at accelerator.backward(loss) [see cell 21].
An assertion in cell 20 interrupts execution right after the loss calculation, just before that call.

https://github.com/Tirthankar-iiitb/mms_pretrain/blob/main/wav2vec2_no_trainer_kas.ipynb

Hope this information helps you advise further.

Thanks/Tirthankar.
