FSDP Finetuning
Hi, I'm trying to fine-tune Mixtral with the FSDP framework and I get this error during the first backward pass:
Exception: Attention mask should be of size (1, 1, 4096, 8192), but is torch.Size([1, 1, 4096, 4096])
I'm using the same logic and the same data I used to finetune Mistral 7B...
Getting this error as well.
Thanks, could you open an issue on https://github.com/huggingface/transformers with a full reproducer?
Is there any corresponding issue?
I believe it has been recently fixed by: https://github.com/huggingface/transformers/pull/28061
You can use the main branch of transformers: pip install -U git+https://github.com/huggingface/transformers.git
@rganti Can you please share your FSDP config?
I am trying a full fine-tuning (not LoRA) using auto_wrap_policy={MixtralDecoderLayer}, activation_checkpointing_policy={MixtralDecoderLayer}
according to https://lightning.ai/docs/fabric/stable/advanced/model_parallel/fsdp.html
It is giving me a recomputed tensor size mismatch error. A detailed bug report is here.
FYI: I tried the latest transformers and lightning libraries installed from git+https
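For reference, a minimal sketch of that Fabric setup, assuming Lightning ≥ 2.1; the checkpoint ID, device count, and precision below are placeholders rather than details from the post above:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

# Shard and activation-checkpoint the model per Mixtral decoder layer,
# matching the two policies mentioned above.
strategy = FSDPStrategy(
    auto_wrap_policy={MixtralDecoderLayer},
    activation_checkpointing_policy={MixtralDecoderLayer},
)

fabric = Fabric(devices=8, strategy=strategy, precision="bf16-mixed")
fabric.launch()

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
)
model = fabric.setup(model)  # wraps each decoder layer in FSDP units
```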
{
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_backward_prefetch_policy": "BACKWARD_PRE",
"fsdp_cpu_ram_efficient_loading": "False",
"fsdp_forward_prefetch": "True",
"fsdp_offload_params": "False",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_sync_module_states": "False",
"fsdp_transformer_layer_cls_to_wrap": "MixtralDecoderLayer",
"fsdp_use_orig_params": "True",
"activation_checkpointing": "True"
}
I am using SFTTrainer
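For context, a rough sketch of how a config like the one above can be expressed directly through TrainingArguments for SFTTrainer; the model ID, dataset, and exact fsdp_config key names are assumptions and may differ across transformers/trl versions:

```python
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Placeholder dataset for illustration only.
dataset = load_dataset("imdb", split="train")

args = TrainingArguments(
    output_dir="mixtral-fsdp",
    per_device_train_batch_size=1,
    bf16=True,
    # Mirrors the accelerate-style config above; recent transformers
    # versions accept these keys (without the "fsdp_" prefix) here.
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "backward_prefetch": "backward_pre",
        "forward_prefetch": True,
        "transformer_layer_cls_to_wrap": ["MixtralDecoderLayer"],
        "use_orig_params": True,
        "activation_checkpointing": True,
    },
)

trainer = SFTTrainer(
    model="mistralai/Mixtral-8x7B-v0.1",  # placeholder checkpoint ID
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=4096,
)
trainer.train()
```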
btw @hrushikesh1 -- some other model (GPTBigCode) is giving me this trouble (size/shape mismatch) as well; it used to work well in the past for me :)
@hrushikesh1 To update, it seems to be flaky and dependent on the PyTorch and HF versions that are installed. I am still trying to figure out the "right" combination, but perhaps @ybelkada or someone from HF/PT teams can comment?
Specifically: torch 2.2.0.dev20231121+cu118, transformers 4.37.0.dev0, and Python 3.11.
Thanks for the info @rganti!
I was able to solve it by explicitly calling model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': True}) after loading the model with AutoModel.from_pretrained().
The issue reappears if I set use_reentrant: False in the above call. The Lightning library might be defaulting to use_reentrant: False.
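In code, the fix looks roughly like this; the checkpoint ID is a placeholder, and AutoModelForCausalLM is used here in place of AutoModel:

```python
from transformers import AutoModelForCausalLM

# Placeholder checkpoint ID; substitute your own.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Force the reentrant checkpointing implementation right after loading,
# before the trainer gets a chance to configure it otherwise.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)
```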
There are a lot of notes and warnings from PyTorch on the reentrant behavior here.
As of torch 2.1 it defaults to True, but they plan to move to use_reentrant=False as the default in the future; that might be causing the flakiness you observe across versions.
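At the torch level, the flag can be passed explicitly instead of relying on the changing default; a minimal, self-contained illustration with a toy function:

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Toy function standing in for a transformer block.
    return torch.relu(x @ x.T)

x = torch.randn(4, 4, requires_grad=True)
# Pass use_reentrant explicitly rather than depending on the default,
# which is True as of torch 2.1 but is planned to change to False.
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()
```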
@hrushikesh1 I was able to LoRA-tune Mixtral on the latest PT nightlies and latest HF main after adding the gradient_checkpointing_enable line above, thanks!
Wanted to share a note for a future data scientist in trouble:
I was trying LoRA fine-tuning of Mistral-7B using the FSDP strategy and the PyTorch Lightning trainer. It kept getting stuck at step 1.
It turned out that, since there are some frozen parameters without gradients, I cannot use gradient clipping.
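A rough sketch of such a setup with clipping left disabled; the device count is illustrative and the LightningModule wrapping the PEFT model is hypothetical:

```python
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

strategy = FSDPStrategy(auto_wrap_policy={MistralDecoderLayer})

trainer = L.Trainer(
    strategy=strategy,
    devices=8,
    precision="bf16-mixed",
    # With LoRA most parameters are frozen and carry no gradients,
    # so leave gradient_clip_val unset (no clipping).
)
# trainer.fit(lit_model)  # lit_model: your LightningModule wrapping the PEFT model
```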
{ "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_backward_prefetch_policy": "BACKWARD_PRE", "fsdp_cpu_ram_efficient_loading": "False", "fsdp_forward_prefetch": "True", "fsdp_offload_params": "False", "fsdp_state_dict_type": "SHARDED_STATE_DICT", "fsdp_sync_module_states": "False", "fsdp_transformer_layer_cls_to_wrap": "MixtralDecoderLayer", "fsdp_use_orig_params": "True", "activation_checkpointing": "True" }
I am using
SFTTrainer
Hi, did you fine-tune Mixtral 8x7B with an adapter, or was it regular full fine-tuning with FSDP? Can you provide your GPU compute resource info?