MInference 1.0: 10x Faster Million Context Inference with a Single GPU
This article introduces MInference 1.0: a pre-filling acceleration method based on dynamic sparse attention. The method searches for the optimal sparse pattern for each attention head, dynamically constructs sparse indices based on the input prompt, and finally performs sparse computation for long-context attention.
On a single A100 GPU, MInference 1.0 achieves up to 10x acceleration in pre-filling while matching or even surpassing the accuracy of full attention on numerous tasks.
Inference Bottleneck for Long-context LLMs
Multi-head attention is an extremely expensive token-mixing operation because its cost grows quadratically with sequence length.
In the era of BERT and GPT-2, the latency it introduced was generally acceptable because input sequences were short. Now that LLMs widely pursue long-context capabilities, however, this latency has become one of the main inference bottlenecks.
Test hardware: single A100-80G GPU
As shown in Figure (a) above, when the input prompt length exceeds 500K, the pre-filling stage of inference (a.k.a. time-to-first-token, TTFT) takes nearly 10 minutes, with the majority of time spent on attention. When the prompt length continues to grow to 1M, this waiting time will reach 30 minutes.
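A back-of-the-envelope FLOPs estimate shows why attention dominates here. The numbers below assume a LLaMA-2-7B-like shape (32 layers, 32 heads, head dim 128) and are illustrative only, not measured:

# Rough prefill-attention cost for a 1M-token prompt (illustrative assumptions)
seq_len = 1_000_000
layers, heads, head_dim = 32, 32, 128
# QK^T and PV each take about 2 * seq_len^2 * head_dim FLOPs per head
attn_flops = layers * heads * 2 * (2 * seq_len ** 2 * head_dim)
print(f"{attn_flops:.1e}")  # ~5.2e17 FLOPs
# At ~3.1e14 FLOP/s (A100 BF16 peak), this alone is on the order of half an hour,
# roughly consistent with the TTFT numbers above.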
At the same time, we also find that attention is extremely inefficient in long-context scenarios. As shown in Figure (b), when the prompt length is 128K, using only the Top-4K columns for the attention computation already recalls more than 96% of the global attention scores.
This indicates that the attention matrix is extremely sparse for long texts. In other words, the quadratic-complexity attention operation wastes a large amount of compute on near-zero elements, making the entire pre-filling stage inefficient, and the problem is especially pronounced in long-context scenarios.
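As a rough illustration of how such a recall figure can be measured, the sketch below computes, for a single head, the fraction of total attention mass covered by the globally strongest key columns. It is a toy version of the experiment only: the 96% number comes from real LLM attention on 128K-token prompts, not from random tensors, and the causal mask is omitted for brevity.

import torch

def topk_column_recall(q, k, topk):
    """Fraction of total attention mass kept when every query attends only to
    the `topk` key columns with the largest aggregate attention score."""
    scores = torch.softmax(q @ k.T / q.shape[-1] ** 0.5, dim=-1)  # [S, S]
    col_mass = scores.sum(dim=0)                                  # total score per key column
    top_cols = col_mass.topk(topk).indices
    return (scores[:, top_cols].sum() / scores.sum()).item()

# Toy sizes; with real long-context attention, a few thousand columns already
# recall the vast majority of the attention scores.
S, d = 8192, 64
q, k = torch.randn(S, d), torch.randn(S, d)
print(topk_column_recall(q, k, topk=512))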
The sparsity of the attention matrix has been analyzed in detail in many works. As shown in the figure below, this property has also led to many works on efficient attention.
However, these sparse attention methods often rely on fixed patterns - for example, sliding windows always attend to local tokens, while StreamingLLM-like patterns always attend to local plus initial tokens. Such fixed patterns conflict with another property of attention - its dynamicity.
Looking at this figure again, (b) shows that computing only the Top-4K columns covers most of the attention scores; but if we reuse the Top-4K indices from that prompt on a different prompt, we get the result in (c): the recall of attention scores drops sharply to 83%, and even lower in some layers.
This indicates that the sparse distribution of attention is highly input-dependent: it varies greatly across inputs, which makes it difficult for fixed patterns to generalize well across scenarios and inputs.
In summary, an ideal efficient long-context attention needs to account for both the sparsity and the dynamicity of attention: dynamically estimate an input-dependent sparse mask from the attention inputs, and then perform the sparse attention computation.
MInference 1.0
MInference 1.0 aims to construct dynamic sparse attention that can be computed at high speed on GPUs.
Overall, MInference analyzes the sparse distribution of attention in long-context LLMs, proposes three sparse attention patterns that are easy to accelerate on GPUs, and implements an efficient GPU kernel for each pattern.
A-shape sparse, vertical-slash (v-s) sparse, and block sparse
We found that the multi-head attention matrices of long-context LLMs show clear spatial clustering, and most heads fall into one of the following three patterns.
For example, some heads show an A-shaped sparse pattern: attention concentrates on local and initial tokens (Figure 1 above; this is the sparse pattern of StreamingLLM). Others show vertical lines that emphasize a few special tokens, paired with diagonal lines that emphasize fixed relative positions (vertical-slash, Figure 2 above). Finally, some heads show a block sparse pattern whose attention clusters into blocks (Figure 3 above).
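To make the three patterns concrete, the sketch below writes each one as a dense boolean mask. This is purely illustrative: MInference never materializes such masks (its kernels operate directly on sparse indices), and the window sizes and index arguments here are placeholders.

import torch

def a_shape_mask(S, n_init=128, n_local=1024):
    """A-shape: initial tokens plus a local sliding window (causal)."""
    i = torch.arange(S)[:, None]
    j = torch.arange(S)[None, :]
    return (j <= i) & ((j < n_init) | (i - j < n_local))

def vertical_slash_mask(S, v_idx, s_idx):
    """Vertical-slash: selected key columns plus selected diagonals (relative offsets)."""
    i = torch.arange(S)[:, None]
    j = torch.arange(S)[None, :]
    vertical = torch.zeros(S, S, dtype=torch.bool)
    vertical[:, v_idx] = True
    slash = torch.zeros(S, S, dtype=torch.bool)
    for off in s_idx:                       # each off is a relative position i - j
        slash |= (i - j) == off
    return (j <= i) & (vertical | slash)

def block_sparse_mask(S, block_idx, block=64):
    """Block sparse: a set of active (query-block, key-block) pairs."""
    mask = torch.zeros(S, S, dtype=torch.bool)
    for bi, bj in block_idx:
        mask[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block] = True
    i = torch.arange(S)[:, None]
    j = torch.arange(S)[None, :]
    return mask & (j <= i)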
Proposing these three patterns has two main benefits:
- The combination of the three can cover almost all sparse attention distributions.
- The characteristics of these patterns allow them to be computed efficiently on GPUs using FlashAttention- and PIT-style methods.
Comparing the three sparse patterns of MInference against Top-K sparse attention
Question: Why can these patterns be computed efficiently, but not Top-K sparse attention?
Because MInference's three patterns show clear spatial clustering, GPU kernels can complete the sparse operations efficiently using blocks of 64×64 (A-shape and block-sparse heads) or 64×1 (v-s heads).
On the other hand, Top-K sparse attention (i.e., keeping only the top-k K vectors for each Q, shown on the right of the figure) takes a long time to build its sparse index because the sparsity is too fine-grained, and it produces a large number of wasted operations when computed block-wise on GPUs.
After the three sparse patterns are specified, MInference completes the following three steps (a high-level sketch follows the list):
- Find the optimal pattern for each attention head under a given FLOPs budget.
- Dynamically calculate the optimal sparse distribution for each input (e.g., the position of vertical-slash lines, or the index of blocks).
- Perform sparse computation of attention based on the sparse index obtained in step 2.
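Reusing the mask helpers sketched earlier, one head's prefill attention can be written as the dense reference below. It only illustrates how the three steps fit together; the real kernels never build the full S×S matrix.

def minference_prefill_head(q, k, v, pattern, index, block=64):
    """q, k, v: [seq_len, head_dim]. `pattern` comes from the offline search (step 1);
    `index` is the dynamically built sparse index (step 2); the masked softmax
    stands in for the sparse kernel (step 3)."""
    S, d = q.shape
    if pattern == "a_shape":
        mask = a_shape_mask(S)
    elif pattern == "vertical_slash":
        mask = vertical_slash_mask(S, *index)        # index = (v_idx, s_idx)
    else:  # "block_sparse"
        mask = block_sparse_mask(S, index, block)    # index = [(q_block, k_block), ...]
    # Assumes every query row keeps at least one key (in practice the local /
    # diagonal entries are always kept, so this holds).
    logits = (q @ k.T / d ** 0.5).masked_fill(~mask, float("-inf"))
    return torch.softmax(logits, dim=-1) @ v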
Searching for the Optimal Pattern for Each Head
search space
For each attention head, MInference will search for the optimal sparse pattern, that is, find which pattern can recall as many attention scores as possible on that attention head under the constraint of total FLOPs.
To this end, MInference proposes Kernel-Aware Sparse Pattern Search. Under a given FLOPs budget, this algorithm searches for the optimal pattern based on the actual computation performed by the GPU kernels (hence kernel-aware) rather than the nominal sparsity of attention. It consists of the following two steps:
- Decide which pattern the head belongs to: A-shape, vertical-slash, or block sparse.
- Decide the optimal sparsity rate of the head (i.e., the number of vertical and diagonal lines in the vertical-slash pattern, the number of blocks in the block sparse pattern, etc.).
According to our observations, the sparse pattern type of an attention head is input-independent: a head that shows a vertical-slash pattern will show a vertical-slash pattern for any input (note: the specific positions of the vertical and diagonal lines are still input-dependent).
MInference therefore performs the Kernel-Aware Sparse Pattern Search offline.
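A stripped-down version of this search, per head and on an offline profiling prompt, might look like the sketch below. The candidate list and its cost values are placeholders: in MInference the cost is the work the GPU kernel actually performs for that configuration (number of lines or blocks), which is what makes the search kernel-aware.

import torch

def pattern_search_for_head(ref_attn, flops_budget, candidates):
    """ref_attn: [S, S] attention scores of this head on a profiling prompt.
    candidates: list of (pattern_name, config, kernel_cost, mask) tuples, where
    `mask` covers exactly what the kernel would compute for that config."""
    best_recall, best = -1.0, None
    for name, config, kernel_cost, mask in candidates:
        if kernel_cost > flops_budget:            # kernel-aware FLOPs constraint
            continue
        recall = ((ref_attn * mask).sum() / ref_attn.sum()).item()
        if recall > best_recall:
            best_recall, best = recall, (name, config)
    return best, best_recall                      # pattern type + sparsity config for this head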
Dynamic Sparse Index Building
v-s pattern: using last_q as a proxy; block pattern: using pooling for dimensionality reduction
After confirming the type and sparsity rate of the sparse pattern, MInference will dynamically construct sparse indices based on the input: for the vertical-slash pattern, this step is to determine the position of the most important vertical and diagonal lines in the attention matrix; for the block sparse pattern, this step is to find the most relevant blocks from the entire attention matrix.
However, building the sparse indices introduces extra overhead, so this step needs very efficient algorithms; otherwise the latency savings from sparse attention would be eaten up by index building.
We use two efficient estimation methods to approximate the attention matrix, one for the vertical-slash pattern and one for the block sparse pattern, and obtain the sparse indices needed for the sparse computation (both are sketched in code after the list):
- For the vertical-slash pattern, we found that using a few queries at the end of the Q matrix (i.e., last_q in Algorithm 2) accurately recovers the positions of the globally important vertical and diagonal lines.
- For the block sparse pattern, we apply a pooling layer of size 64 along the seq_len dimension to reduce the dimensionality of the Q and K matrices, compute the low-resolution scores Q̂K̂^T, and keep the blocks with the largest magnitude.
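The two estimators can be sketched as follows; the defaults (last_q = 64, and the number of lines/blocks to keep) are illustrative stand-ins for the per-head budgets found by the offline search.

import torch

def estimate_vertical_slash(q, k, last_q=64, n_vertical=1000, n_slash=6000):
    """Score every key column and every diagonal using only the last `last_q`
    queries, then keep the strongest ones (the last_q proxy idea)."""
    S, d = q.shape
    q_pos = torch.arange(S - last_q, S)[:, None]
    k_pos = torch.arange(S)[None, :]
    logits = (q[-last_q:] @ k.T) / d ** 0.5
    logits = logits.masked_fill(k_pos > q_pos, float("-inf"))     # causal mask
    scores = torch.softmax(logits, dim=-1)                        # [last_q, S]
    v_idx = scores.sum(0).topk(min(n_vertical, S)).indices        # strongest vertical lines
    offsets = (q_pos - k_pos).clamp(min=0)                        # diagonal id = i - j
    diag_mass = torch.zeros(S).scatter_add_(0, offsets.flatten(), scores.flatten())
    s_idx = diag_mass.topk(min(n_slash, S)).indices               # strongest slash lines
    return v_idx, s_idx

def estimate_blocks(q, k, n_blocks=1024, block=64):
    """Mean-pool Q and K along seq_len, score 64x64 blocks with the pooled
    product, and keep the blocks with the largest scores."""
    pool = torch.nn.AvgPool1d(block, stride=block)
    q_hat = pool(q.T.unsqueeze(0)).squeeze(0).T                   # [S/64, d]
    k_hat = pool(k.T.unsqueeze(0)).squeeze(0).T
    block_scores = q_hat @ k_hat.T                                # low-resolution QK^T
    flat = block_scores.flatten().topk(min(n_blocks, block_scores.numel())).indices
    cols = block_scores.shape[1]
    return torch.stack((torch.div(flat, cols, rounding_mode="floor"), flat % cols), dim=1)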
In terms of complexity, the index building of the vertical-slash pattern (Algorithm 2) costs O(Nd), since it uses only the last few queries. As for Algorithm 3, because a pooling layer of size 64 is used for dimensionality reduction, its complexity is only 1/(64×64), i.e., roughly 1/4096 of full attention. Both algorithms are therefore very cheap relative to the attention they replace.
Sparse Attention Computing
MInference implemented corresponding GPU kernels for the three proposed sparse attention patterns: for the A-shape and block sparse patterns, the kernel computes in blocks of size 64×64, while for the vertical-slash pattern, the kernel computes in blocks of size 64×1.
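The GPU kernels themselves are beyond the scope of this post, but the block-level computation they perform can be mimicked in plain PyTorch: for each query block, gather only its selected key/value blocks, run softmax over those keys, and write the output. Unlike the real kernel, this reference is not fused and ignores the causal mask inside diagonal blocks.

import torch
from collections import defaultdict

def block_sparse_attention_ref(q, k, v, block_idx, block=64):
    """Reference computation that touches only the selected 64x64 tiles.
    block_idx: iterable of (query_block, key_block) pairs from index building."""
    S, d = q.shape
    out = torch.zeros_like(q)
    per_qblock = defaultdict(list)
    for bi, bj in block_idx:
        per_qblock[int(bi)].append(int(bj))
    for bi, kjs in per_qblock.items():
        qs = q[bi * block:(bi + 1) * block]                             # [64, d]
        ks = torch.cat([k[bj * block:(bj + 1) * block] for bj in kjs])  # [64 * len(kjs), d]
        vs = torch.cat([v[bj * block:(bj + 1) * block] for bj in kjs])
        attn = torch.softmax(qs @ ks.T / d ** 0.5, dim=-1)
        out[bi * block:(bi + 1) * block] = attn @ vs
    return out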
Results
MInference 1.0 can process input prompts of 500K length 6.8 times faster on a single A100, and when the prompt length is 1M, MInference can achieve 10 times pre-filling acceleration.
All-green Needle-in-a-Haystack results + 10x acceleration
At the same time, compared with full attention, MInference's dynamic sparse attention matches or even surpasses the accuracy of full attention on numerous tasks.
RULER
Better accuracy than full attention, and effective context length surpassing full attention
"Needle in a Haystack"
MInference achieves the same or better performance as full attention on GLM-4, Yi, Phi-3, Qwen2 and other LLMs
Using MInference with vLLM and HF models
Using MInference 1.0 with HuggingFace models and vLLM models requires only three additional lines of code:
# For HuggingFace models
from transformers import AutoModelForCausalLM
+from minference import MInference
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+minference_patch = MInference("minference", "meta-llama/Llama-2-7b-hf")
+model = minference_patch(model)
# For vLLM models
from vllm import LLM
+from minference import MInference
llm = LLM(model="meta-llama/Llama-2-7b-hf")
+minference_patch = MInference("vllm", "meta-llama/Llama-2-7b-hf")
+llm = minference_patch(llm)