Fine-tuning bge-reranker-v2-gemma results in torch.cuda.OutOfMemoryError even with 4 GPUs
#2
by jackkwok
I am fine-tuning bge-reranker-v2-gemma on my custom training dataset, using 4x NVIDIA A10G GPUs with 24 GB of memory each, so there is quite a lot of memory available. Still, I hit a CUDA OOM error shortly into training. Any ideas?
My command:
torchrun --nproc_per_node 4 \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--output_dir model_artifacts \
--token <secret redacted> \
--model_name_or_path google/gemma-2b \
--train_data ./jsonl/train.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj
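For context, stage1.json is just a ZeRO stage 1 DeepSpeed config (the traceback below goes through deepspeed/runtime/zero/stage_1_and_2.py). I have not pasted my exact file, but a minimal sketch of that kind of config, with illustrative values only, looks roughly like this:

# Minimal ZeRO stage 1 config sketch (illustrative, not my exact file);
# "auto" lets the HF Trainer integration fill values in from the CLI flags.
cat > stage1.json <<'EOF'
{
  "zero_optimization": { "stage": 1 },
  "bf16": { "enabled": "auto" },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto"
}
EOF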
Training runs for about 40 steps (roughly 8% of the single epoch) and then crashes with CUDA OOM:
{'loss': 2.9226, 'grad_norm': 5.388334274291992, 'learning_rate': 0.0, 'epoch': 0.0}
{'loss': 2.8772, 'grad_norm': 4.197434902191162, 'learning_rate': 3.562071871080222e-05, 'epoch': 0.0}
{'loss': 2.8868, 'grad_norm': 4.446421146392822, 'learning_rate': 5.645750340535797e-05, 'epoch': 0.01}
{'loss': 2.6807, 'grad_norm': 3.2873618602752686, 'learning_rate': 7.124143742160444e-05, 'epoch': 0.01}
{'loss': 2.6077, 'grad_norm': 3.4100160598754883, 'learning_rate': 8.270874753469163e-05, 'epoch': 0.01}
{'loss': 2.5243, 'grad_norm': 3.642030715942383, 'learning_rate': 9.207822211616019e-05, 'epoch': 0.01}
{'loss': 2.6225, 'grad_norm': 3.3111226558685303, 'learning_rate': 0.0001, 'epoch': 0.01}
{'loss': 2.7275, 'grad_norm': 4.486476898193359, 'learning_rate': 0.00010686215613240667, 'epoch': 0.02}
{'loss': 2.7934, 'grad_norm': 3.5005621910095215, 'learning_rate': 0.00011291500681071594, 'epoch': 0.02}
{'loss': 2.6622, 'grad_norm': 3.005181312561035, 'learning_rate': 0.00011832946624549386, 'epoch': 0.02}
{'loss': 2.6258, 'grad_norm': 3.1711008548736572, 'learning_rate': 0.0001232274405867344, 'epoch': 0.02}
{'loss': 2.489, 'grad_norm': 3.336585283279419, 'learning_rate': 0.0001276989408269624, 'epoch': 0.02}
{'loss': 2.4983, 'grad_norm': 3.2431273460388184, 'learning_rate': 0.0001318123223061841, 'epoch': 0.03}
{'loss': 2.4421, 'grad_norm': 3.682769298553467, 'learning_rate': 0.00013562071871080222, 'epoch': 0.03}
{'loss': 2.6211, 'grad_norm': 3.925990343093872, 'learning_rate': 0.0001391662509400496, 'epoch': 0.03}
{'loss': 2.8245, 'grad_norm': 4.797396659851074, 'learning_rate': 0.00014248287484320887, 'epoch': 0.03}
{'loss': 2.5547, 'grad_norm': 4.069711685180664, 'learning_rate': 0.0001455983641090348, 'epoch': 0.03}
{'loss': 2.8645, 'grad_norm': 5.024136543273926, 'learning_rate': 0.00014853572552151815, 'epoch': 0.04}
{'loss': 2.3737, 'grad_norm': 3.9875905513763428, 'learning_rate': 0.00015131423106025147, 'epoch': 0.04}
{'loss': 2.6015, 'grad_norm': 3.5503010749816895, 'learning_rate': 0.00015395018495629608, 'epoch': 0.04}
{'loss': 2.4038, 'grad_norm': 4.3279194831848145, 'learning_rate': 0.000156457503405358, 'epoch': 0.04}
{'loss': 2.374, 'grad_norm': 3.7719438076019287, 'learning_rate': 0.00015884815929753662, 'epoch': 0.04}
{'loss': 2.2142, 'grad_norm': 3.907940626144409, 'learning_rate': 0.00016113252800759313, 'epoch': 0.05}
{'loss': 2.526, 'grad_norm': 3.962578296661377, 'learning_rate': 0.00016331965953776464, 'epoch': 0.05}
{'loss': 2.4629, 'grad_norm': 4.234306812286377, 'learning_rate': 0.00016541749506938325, 'epoch': 0.05}
{'loss': 2.0923, 'grad_norm': 4.046939373016357, 'learning_rate': 0.00016743304101698634, 'epoch': 0.05}
{'loss': 2.4197, 'grad_norm': 5.140893459320068, 'learning_rate': 0.0001693725102160739, 'epoch': 0.06}
{'loss': 2.266, 'grad_norm': 4.84731912612915, 'learning_rate': 0.00017124143742160445, 'epoch': 0.06}
{'loss': 2.3097, 'grad_norm': 4.363345623016357, 'learning_rate': 0.00017304477452986233, 'epoch': 0.06}
{'loss': 2.2013, 'grad_norm': 6.07499885559082, 'learning_rate': 0.00017478696965085182, 'epoch': 0.06}
{'loss': 2.3062, 'grad_norm': 4.905301570892334, 'learning_rate': 0.0001764720332103851, 'epoch': 0.06}
{'loss': 2.1395, 'grad_norm': 11.986255645751953, 'learning_rate': 0.00017810359355401113, 'epoch': 0.07}
{'loss': 2.3626, 'grad_norm': 7.714587211608887, 'learning_rate': 0.00017968494399209236, 'epoch': 0.07}
{'loss': 2.3273, 'grad_norm': 10.299479484558105, 'learning_rate': 0.00018121908281983702, 'epoch': 0.07}
{'loss': 2.1057, 'grad_norm': 6.595864295959473, 'learning_rate': 0.00018270874753469163, 'epoch': 0.07}
{'loss': 2.3477, 'grad_norm': 9.525544166564941, 'learning_rate': 0.00018415644423232038, 'epoch': 0.07}
{'loss': 2.027, 'grad_norm': 13.262056350708008, 'learning_rate': 0.00018556447297411074, 'epoch': 0.08}
{'loss': 2.087, 'grad_norm': 8.810412406921387, 'learning_rate': 0.0001869349497710537, 'epoch': 0.08}
{'loss': 2.1892, 'grad_norm': 6.087996006011963, 'learning_rate': 0.00018826982571154205, 'epoch': 0.08}
{'loss': 1.5995, 'grad_norm': 7.068698406219482, 'learning_rate': 0.00018957090366709828, 'epoch': 0.08}
8%|████████████ | 40/490 [33:43<6:10:42, 49.43s/it]
Traceback (most recent call last):
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/FlagEmbedding/llm_reranker/finetune_for_instruction/run.py", line 131, in <module>
main()
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/FlagEmbedding/llm_reranker/finetune_for_instruction/run.py", line 118, in main
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/transformers/trainer.py", line 1885, in train
return inner_training_loop(
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/transformers/trainer.py", line 3250, in training_step
self.accelerator.backward(loss)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/accelerate/accelerator.py", line 2117, in backward
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
self.engine.backward(loss, **kwargs)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
self.optimizer.backward(loss, retain_graph=retain_graph)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2056, in backward
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
scaled_loss.backward(retain_graph=retain_graph)
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.32 GiB. GPU 0 has a total capacty of 22.19 GiB of which 7.13 GiB is free. Including non-PyTorch memory, this process has 15.06 GiB memory in use. Of the allocated memory 14.13 GiB is allocated by PyTorch, and 535.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
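The allocator hint at the end of that message refers to the PYTORCH_CUDA_ALLOC_CONF environment variable, which would be set before launching torchrun. The split size below is only an example value, and it targets fragmentation of memory PyTorch has already reserved rather than the overall shortfall:

# Optional allocator tuning, as suggested by the error message
# (128 MiB is an example split size; this only reduces fragmentation).
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128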
My dependencies:
sentence-transformers~=2.7.0
FlagEmbedding~=1.2.10
peft~=0.11.1
deepspeed~=0.14.2
flash-attn~=2.5.9.post1
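For anyone trying to reproduce the setup, the environment can be installed with the same pins (flash-attn usually needs torch installed first and may require --no-build-isolation):

pip install "sentence-transformers~=2.7.0" "FlagEmbedding~=1.2.10" \
            "peft~=0.11.1" "deepspeed~=0.14.2"
pip install "flash-attn~=2.5.9.post1" --no-build-isolation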
Update: the issue was fixed by decreasing both of these parameters (see the sketch after this list):
--query_max_len
--passage_max_len
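A minimal sketch of the change, assuming all other flags stay exactly as in the command above. The 256-token values are only an example, not necessarily what you need; the reason it helps is that with train_group_size 16 each device scores 16 query+passage sequences per step, so activation memory scales roughly with train_group_size * (query_max_len + passage_max_len).

# Only change: lower the two sequence-length flags
# (256 is an illustrative value -- pick what fits your data;
#  every other flag stays as in the original command).
torchrun --nproc_per_node 4 \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--query_max_len 256 \
--passage_max_len 256 \
...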