---
license: apache-2.0
---

# MoH: Multi-Head Attention as Mixture-of-Head Attention
Paper or resources for more information: [Paper] [Code]
## ⚡ Overview
We propose Mixture-of-Head attention (MoH), a new architecture that treats attention heads as experts in the Mixture-of-Experts (MoE) mechanism. MoH has two significant advantages:
- First, MoH enables each token to select the appropriate attention heads, enhancing inference efficiency without compromising accuracy or increasing the number of parameters.
- Second, MoH replaces the standard summation in multi-head attention with a weighted summation, introducing flexibility to the attention mechanism and unlocking extra performance potential (see the sketch after this list).
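The snippet below is a minimal PyTorch sketch of these two ideas, not the released implementation: a per-token router scores the attention heads, only the top-k heads are kept for each token, and the surviving head outputs are combined with a weighted rather than plain summation. The module name, the single linear router, and all hyper-parameters are illustrative assumptions.

```python
# Minimal sketch of mixture-of-head attention (assumed names and hyper-parameters).
import torch
import torch.nn as nn


class MoHAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, heads_per_token: int = 4):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.heads_per_token = heads_per_token
        self.qkv = nn.Linear(dim, dim * 3)
        self.router = nn.Linear(dim, num_heads)  # one score per head, per token
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # Standard multi-head attention up to the per-head outputs.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)                 # each: (B, H, N, d)
        attn = (q @ k.transpose(-2, -1)) / self.head_dim**0.5
        head_out = (attn.softmax(dim=-1) @ v).permute(0, 2, 1, 3)  # (B, N, H, d)

        # Router: each token selects its top-k heads and keeps their scores.
        scores = self.router(x).softmax(dim=-1)              # (B, N, H)
        topk_val, topk_idx = scores.topk(self.heads_per_token, dim=-1)
        gate = torch.zeros_like(scores).scatter(-1, topk_idx, topk_val)

        # Weighted summation over heads instead of the standard (uniform) merge.
        out = (head_out * gate.unsqueeze(-1)).reshape(B, N, C)
        return self.proj(out)


if __name__ == "__main__":
    layer = MoHAttention(dim=64, num_heads=8, heads_per_token=4)
    y = layer(torch.randn(2, 16, 64))
    print(y.shape)  # torch.Size([2, 16, 64])
```

Because the gate zeroes out the unselected heads, an efficient implementation would skip their attention computation entirely; the dense version above is kept only for clarity.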
## 😮 Highlights
### 💡 General Framework
We evaluate our proposed MoH across various popular model frameworks, including Vision Transformers (ViT) for image classification, Diffusion models with Transformers (DiT) for class-conditional image generation, and Large Language Models (LLMs) for language tasks.
| Code | HuggingFace Model |
|---|---|
| MoH-ViT | 🤗 MoH-ViT-B-75, MoH-ViT-B-50, MoH-ViT-S-80, MoH-ViT-S-75 |
| MoH-DiT | 🤗 MoH-DiT-90 |
| MoH-LLaMA3-8B | 🤗 MoH-LLaMA3-8B |
### 🔥 High Performance
Extensive experiments on ViT, DiT, and LLMs demonstrate that MoH outperforms multi-head attention by using only 50%~90% of the attention heads.
### 🤗 Support for Continue-Tuning from Pre-Trained Multi-Head Attention Models
We demonstrate that pre-trained multi-head attention models, such as LLaMA3-8B, can be further continue-tuned into our MoH models. Notably, MoH-LLaMA3-8B achieves an average accuracy of 64.0% across 14 benchmarks, outperforming LLaMA3-8B by 2.4% while utilizing only 75% of the attention heads.
The MoH model recovers over 95% of the original model's performance within a training budget of 10B tokens, and its performance then improves gradually as the number of training tokens increases.
## 🤖 API for Model Inference
If you want to load the model from the Hugging Face Hub or from a local path, you can use the following code snippet.
### Base Model Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the MoH-LLaMA3-8B checkpoint and its tokenizer from the Hugging Face Hub.
# trust_remote_code=True is required because MoH ships custom modeling code.
model = AutoModelForCausalLM.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True)

# Tokenize the prompt, move it to the model's device, and generate a completion.
question = "Hello!"
inputs = tokenizer(question, return_tensors='pt').to(model.device)
response = model.generate(inputs.input_ids, max_length=128)
print(tokenizer.decode(response.cpu()[0], skip_special_tokens=True))
```
### Chat Model Inference
Coming soon...
## 🗝️ Training & Validating
- The training code is built on Skywork-MoE. Because Skywork-MoE is not open source, we cannot open-source the MoH-LLaMA3 training code on its own; we will release it once approval is complete.
- The evaluation is performed on multiple key benchmarks using the Eleuther AI Language Model Evaluation Harness.
```bash
# For example, test MoH-LLaMA3-8B on winogrande
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
    --main_process_port 2004 -m lm_eval --model hf \
    --model_args pretrained=Chat-UniVi/MoH-LLaMA3-8B \
    --tasks winogrande \
    --batch_size 1 \
    --output_path Results/winogrande
```
## ✏️ Citation
If you find this paper useful, please consider starring ⭐ this repo and citing 📝 our paper:
```bibtex
@article{jin2024moh,
  title={MoH: Multi-Head Attention as Mixture-of-Head Attention},
  author={Peng Jin and Bo Zhu and Li Yuan and Shuicheng Yan},
  year={2024}
}
```