FlashAttention support for Mistral HF Implementation
Hi,
First of all, thank you all for releasing such an amazing model! I'm trying to further train Mistral-7B-v0.1 on some custom data.
I noticed that the official implementation (https://github.com/mistralai/mistral-src/blob/main/mistral/model.py) has Flash Attention built in.
However, the HuggingFace version (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py) doesn't seem to have Flash Attention integrated.
Would it be possible to provide a script so that we can replace the standard attention with Flash Attention after we've loaded the model via
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
Thank you for your time and effort :)
It's still a WIP, but I used this and it seems to work fine for FA2: https://github.com/huggingface/transformers/pull/26464
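For reference, here is a minimal sketch of how loading with FA2 can look once that PR is in your installed transformers version. It assumes the flash-attn package is installed and that the loading flag is named use_flash_attention_2 (the name may differ depending on the transformers release you have), so treat it as a starting point rather than the exact API.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes a transformers build that includes the FA2 integration for Mistral
# and a working flash-attn install on a supported GPU.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,   # FA2 expects fp16/bf16 weights
    use_flash_attention_2=True,   # flag name assumed from the PR above
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
If the flag isn't recognized, check the PR discussion for the argument name used in your version.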