---
license: apache-2.0
---
# MoH: Multi-Head Attention as Mixture-of-Head Attention

**Paper or resources for more information:**
[[Paper]()] [[Code](https://github.com/SkyworkAI/MoH)]

## ⚡ Overview
We propose Mixture-of-Head attention (MoH), a new architecture that treats attention heads as experts in the Mixture-of-Experts (MoE) mechanism. MoH has two significant advantages:
* First, MoH enables each token to select the appropriate attention heads, enhancing inference efficiency without compromising accuracy or increasing the number of parameters.
* Second, MoH replaces the standard summation in multi-head attention with a weighted summation, introducing flexibility to the attention mechanism and unlocking extra performance potential.

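To make the two points above concrete, here is a minimal PyTorch-style sketch of the idea. It is illustrative only, not the released implementation: the class, the `top_k` value, and the routing details are assumptions, and details of the full method (e.g., how the router is trained) are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoHAttention(nn.Module):
    """Sketch of mixture-of-head attention: a router selects the top-k heads for
    each token, and the selected head outputs are combined by a weighted sum."""

    def __init__(self, dim: int, num_heads: int = 8, top_k: int = 6):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads, self.top_k, self.head_dim = num_heads, top_k, dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.router = nn.Linear(dim, num_heads)  # one routing logit per head
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).view(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)              # each (B, H, N, d)
        heads = F.scaled_dot_product_attention(q, k, v)   # per-head outputs

        # Route: keep the top-k heads per token, renormalize their scores into
        # weights, and zero out the heads that were not selected.
        logits = self.router(x)                           # (B, N, H)
        topk_val, topk_idx = logits.topk(self.top_k, dim=-1)
        weights = torch.zeros_like(logits).scatter_(-1, topk_idx, topk_val.softmax(dim=-1))

        # Weighted summation of head outputs (standard multi-head attention
        # weights all heads equally); the output projection sums the weighted heads.
        heads = heads.transpose(1, 2) * weights.unsqueeze(-1)  # (B, N, H, d)
        return self.proj(heads.reshape(B, N, C))
```
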
## 😮 Highlights
### 💡 General Framework
We evaluate our proposed MoH across various popular model frameworks, including Vision Transformers (ViT) for image classification, Diffusion models with Transformers (DiT) for class-conditional image generation, and Large Language Models (LLMs) for language tasks.

<div align=center>

| Code | HuggingFace Model |
|:---:|:---:|
| **[MoH-ViT](https://github.com/SkyworkAI/MoH/tree/main/MoH-ViT)** | 🤗 [MoH-ViT-B-75](https://huggingface.co/Chat-UniVi/MoH-ViT-B-75), [MoH-ViT-B-50](https://huggingface.co/Chat-UniVi/MoH-ViT-B-50), [MoH-ViT-S-80](https://huggingface.co/Chat-UniVi/MoH-ViT-S-80), [MoH-ViT-S-75](https://huggingface.co/Chat-UniVi/MoH-ViT-S-75) |
| **[MoH-DiT](https://github.com/SkyworkAI/MoH/tree/main/MoH-DiT)** | 😊 [MoH-DiT-90](https://huggingface.co/Chat-UniVi/MoH-DiT-XL-90) |
| **[MoH-LLaMA3-8B](https://github.com/SkyworkAI/MoH/tree/main/MoH-LLaMA3)** | 😊 [MoH-LLaMA3-8B](https://huggingface.co/Chat-UniVi/MoH-LLaMA3-8B) |

</div>

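If you want to fetch one of the checkpoints listed above to a local folder, one option is the `huggingface_hub` client. The repository ID below is just an example; any of the model IDs in the table works the same way.

```python
from huggingface_hub import snapshot_download

# Download a checkpoint repository from the Hugging Face Hub to a local folder.
local_dir = snapshot_download(repo_id="Chat-UniVi/MoH-ViT-B-75")
print(f"Checkpoint files downloaded to: {local_dir}")
```
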
### 🔥 High Performance
Extensive experiments on ViT, DiT, and LLMs demonstrate that MoH outperforms multi-head attention while using only **50%~90%** of the attention heads.

### 🤝 Support for Continue-Tuning from Multi-Head Attention Models
We demonstrate that pre-trained multi-head attention models, such as LLaMA3-8B, can be further continue-tuned into our MoH models. Notably, MoH-LLaMA3-8B achieves an average accuracy of 64.0% across 14 benchmarks, outperforming LLaMA3-8B by 2.4% while utilizing only 75% of the attention heads.

The MoH model quickly recovers to over **95%** of the original model's performance within a training budget of 10B tokens, and performance then continues to improve as the number of training tokens increases.

## 🤗 API for Model Inference
If you want to load the model from the Hugging Face model hub or from a local path, you can use the following code snippets.

### Base Model Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

question = "Hello!"

# Load the MoH model and tokenizer from the Hub (or a local path)
model = AutoModelForCausalLM.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("Chat-UniVi/MoH-LLaMA3-8B", trust_remote_code=True)

# Tokenize the prompt, generate up to 128 tokens, and decode the output
inputs = tokenizer(question, return_tensors='pt').to(model.device)
response = model.generate(inputs.input_ids, max_length=128)
print(tokenizer.decode(response.cpu()[0], skip_special_tokens=True))
```

### Chat Model Inference
Coming soon...

## 🗝️ Training & Validating
* The training code is built on [Skywork-MoE](https://github.com/SkyworkAI/Skywork-MoE). Until the Skywork-MoE training code is open-sourced, we cannot release the MoH-LLaMA3 training code on its own; we will release it once approval is completed.
* The evaluation is performed on multiple key benchmarks using the [Eleuther AI Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

```bash
# For example, test MoH-LLaMA3-8B on winogrande
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
  --main_process_port 2004 -m lm_eval --model hf \
  --model_args pretrained=Chat-UniVi/MoH-LLaMA3-8B \
  --tasks winogrande \
  --batch_size 1 \
  --output_path Results/winogrande
```